esm2_loras
1.0.0
هذه محاولة لتدريب تكيف منخفض (LORA) لنموذج لغة البروتين ESM-2 لمهمة تصنيف الرمز المميز. على وجه الخصوص ، نحاول تدريب تنبؤ موقع ربط الحمض النووي الريبي. لا تزال هناك بعض المشكلات التي يجب العمل عليها وأي ردود فعل أو نصيحة ستكون موضع تقدير كبير. هذا الرمز مخصص لنموذج صغير ، لذا يجب أن يقوم بتجهيزات WANDB للبحث عن ارتفاع الفائقة في فترة زمنية معقولة على أي وحدة معالجة الرسومات تقريبًا. يمكنك بسهولة تبديل نماذج أكبر على الرغم من أنك إذا أردت.
النموذج نفسه
"AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites"يمكن العثور عليها على الوجه المعانقة هنا.
لإعداد بيئة كوندا ، استنساخ الريبو والتشغيل:
conda env create -f environment.yml
ثم قم بالتشغيل:
conda activate lora_esm_2
لتدريب النموذج المدى:
from lora_esm2_script import train_protein_model
train_protein_model ()للاستخدام ، حاول التشغيل:
from transformers import AutoModelForTokenClassification , AutoTokenizer
from peft import PeftModel
import torch
import numpy as np
import random
# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification . from_pretrained ( base_model_path )
loaded_model = PeftModel . from_pretrained ( base_model , model_path )
# Load the tokenizer
loaded_tokenizer = AutoTokenizer . from_pretrained ( model_path )
# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"
# Tokenize the new sequence
inputs = loaded_tokenizer ( new_protein_sequence , truncation = True , padding = 'max_length' , max_length = 512 , return_tensors = "pt" )
# Make predictions
with torch . no_grad ():
outputs = loaded_model ( ** inputs )
logits = outputs . logits
predictions = torch . argmax ( logits , dim = 2 )
# Print logits for debugging
print ( "Logits:" , logits )
# Convert predictions to a list
predicted_labels = predictions . squeeze (). tolist ()
# Get input IDs to identify padding and special tokens
input_ids = inputs [ 'input_ids' ]. squeeze (). tolist ()
# Define a set of token IDs that correspond to special tokens
special_tokens_ids = { loaded_tokenizer . cls_token_id , loaded_tokenizer . pad_token_id , loaded_tokenizer . eos_token_id }
# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [ label for label , token_id in zip ( predicted_labels , input_ids ) if token_id not in special_tokens_ids ]
print ( "Predicted binding sites:" , binding_sites )