esm2_loras
1.0.0
This is an attempt at training a Low-Rank Adaptation (LoRA) of the protein language model ESM-2 for a token classification task. In particular, we try to train an RNA binding site predictor. There are still some issues to work out, so any feedback or suggestions would be greatly appreciated. The code works with a small model, so it should be able to run a wandb sweep for hyperparameter search in a reasonable amount of time on almost any GPU. You can easily swap in a larger model if needed.
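For context, a minimal sketch of what the LoRA setup might look like with peft is shown below. The target modules, rank, and alpha are illustrative assumptions, not necessarily what the training script actually uses:

from transformers import AutoModelForTokenClassification
from peft import LoraConfig, TaskType, get_peft_model

# Load the small ESM-2 base model with a two-class token classification head
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=2,  # non-binding vs. RNA binding site
)

# All LoRA hyperparameters below are assumptions for illustration
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,                                       # rank of the low-rank update matrices (assumed)
    lora_alpha=16,                             # scaling factor (assumed)
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # ESM-2 attention projections (assumed)
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # only the small LoRA adapters are trainable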
The model itself, "AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites", can be found here on Hugging Face.
To set up the conda environment, clone the repo and run:
conda env create -f environment.yml
Then run:
conda activate lora_esm_2
To train the model, run:
from lora_esm2_script import train_protein_model
train_protein_model()
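Since the intro mentions running wandb sweeps for hyperparameter search, a sweep over the LoRA hyperparameters might be wired up along the following lines; the metric name, parameter names, and ranges are all illustrative assumptions rather than the script's actual settings:

import wandb

# Illustrative sweep configuration; metric and parameter ranges are assumptions
sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16]},
        "lora_alpha": {"values": [8, 16, 32]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="esm2_loras")
# wandb.agent(sweep_id, function=train_protein_model)  # hypothetical hookup to the training entry point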
To use the trained model for inference, try running:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM-2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"

# Load the base model and apply the LoRA adapter weights
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)

# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"

# Tokenize the new sequence
inputs = loaded_tokenizer(new_protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")

# Make predictions
with torch.no_grad():
    outputs = loaded_model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

# Print logits for debugging
print("Logits:", logits)

# Convert predictions to a list
predicted_labels = predictions.squeeze().tolist()

# Get input IDs to identify padding and special tokens
input_ids = inputs['input_ids'].squeeze().tolist()

# Define the set of token IDs that correspond to special tokens
special_tokens_ids = {loaded_tokenizer.cls_token_id, loaded_tokenizer.pad_token_id, loaded_tokenizer.eos_token_id}

# Filter out predictions made for special tokens (CLS, EOS, padding)
binding_sites = [label for label, token_id in zip(predicted_labels, input_ids) if token_id not in special_tokens_ids]

print("Predicted binding sites:", binding_sites)