esm2_loras
1.0.0
This is an attempt at training a Low-Rank Adaptation (LoRA) of the protein language model ESM-2 for a token classification task. In particular, we try to train an RNA binding site predictor. There are still some issues to work out, and any feedback or suggestions would be greatly appreciated. The code uses a small model, so it should be able to run a WANDB sweep for hyperparameter search in a reasonable amount of time on almost any GPU. You can easily swap in a larger model if needed.
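For context, LoRA freezes the pretrained weight matrix W and learns only a low-rank update ΔW = B·A, scaled by α/r, so the adapted layer computes W + (α/r)·B·A. A minimal numerical sketch with NumPy (the dimensions and scaling factor here are illustrative, not taken from this repo's config):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 8, 2, 16           # illustrative sizes; r << d is the point of LoRA

W = rng.standard_normal((d, d))  # frozen pretrained weight
A = rng.standard_normal((r, d))  # trainable factor (random init)
B = np.zeros((d, r))             # trainable factor (zero init, so ΔW starts at 0)

delta_W = (alpha / r) * (B @ A)
assert np.allclose(W + delta_W, W)           # no change before training

B = rng.standard_normal((d, r))              # pretend B has been trained
delta_W = (alpha / r) * (B @ A)
assert np.linalg.matrix_rank(delta_W) <= r   # the update is at most rank r
print(delta_W.shape)                         # full (8, 8) shape, but low rank
```

Only A and B are trained, which is why LoRA fine-tuning fits on small GPUs even for larger base models.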
The model itself, "AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites", can be found on Hugging Face.
To set up the conda environment, clone the repo and run:
conda env create -f environment.yml
Then run:
conda activate lora_esm_2
To train the model, run:
from lora_esm2_script import train_protein_model
train_protein_model()

To use the model, try running:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"

# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"

# Tokenize the new sequence
inputs = loaded_tokenizer(new_protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")

# Make predictions
with torch.no_grad():
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

# Print logits for debugging
print("Logits:", logits)

# Convert predictions to a list
predicted_labels = predictions.squeeze().tolist()

# Get input IDs to identify padding and special tokens
input_ids = inputs['input_ids'].squeeze().tolist()

# Define a set of token IDs that correspond to special tokens
special_tokens_ids = {loaded_tokenizer.cls_token_id, loaded_tokenizer.pad_token_id, loaded_tokenizer.eos_token_id}

# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [label for label, token_id in zip(predicted_labels, input_ids) if token_id not in special_tokens_ids]

print("Predicted binding sites:", binding_sites)
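To make the filtered label list easier to interpret, you can pair each predicted label with its residue. A small illustrative helper (it assumes the labels are already aligned to the sequence with special tokens removed, and that label 1 marks a predicted binding site in this binary token-classification setup):

```python
def binding_residues(sequence, labels):
    """Return (position, residue) pairs where the model predicts a binding site.

    Assumes `labels` is aligned to `sequence` (special/padding tokens removed)
    and that label 1 means "binding site" in this binary classification setup.
    """
    return [(i, aa) for i, (aa, label) in enumerate(zip(sequence, labels)) if label == 1]

# Toy example: a 6-residue sequence with two predicted binding sites
print(binding_residues("MKVLRA", [0, 1, 0, 0, 1, 0]))  # → [(1, 'K'), (4, 'R')]
```

With the variables from the script above, `binding_residues(new_protein_sequence, binding_sites)` would list the predicted binding residues by position.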