esm2_loras
1.0.0
これは、トークン分類タスクのタンパク質言語モデルESM-2の低ランク適応(LORA)をトレーニングする試みです。特に、RNA結合部位の予測因子を訓練しようとします。まだいくつかの問題があり、フィードバックやアドバイスは大歓迎です。このコードは小さなモデル用であるため、ほぼすべてのGPUで妥当な時間でハイパーパラメーター検索のWANDBスイープを実行する必要があります。必要に応じて、より大きなモデルに簡単に交換できます。
モデル自体
"AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites"ここで顔を抱き締めることができます。
Conda環境をセットアップするには、リポジトリをクローンして実行します。
conda env create -f environment.yml
その後、実行:
conda activate lora_esm_2
モデルの実行をトレーニングするには:
from lora_esm2_script import train_protein_model
train_protein_model ()使用するには、実行してみてください。
from transformers import AutoModelForTokenClassification , AutoTokenizer
from peft import PeftModel
import torch
import numpy as np
import random
# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification . from_pretrained ( base_model_path )
loaded_model = PeftModel . from_pretrained ( base_model , model_path )
# Load the tokenizer
loaded_tokenizer = AutoTokenizer . from_pretrained ( model_path )
# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"
# Tokenize the new sequence
inputs = loaded_tokenizer ( new_protein_sequence , truncation = True , padding = 'max_length' , max_length = 512 , return_tensors = "pt" )
# Make predictions
with torch . no_grad ():
outputs = loaded_model ( ** inputs )
logits = outputs . logits
predictions = torch . argmax ( logits , dim = 2 )
# Print logits for debugging
print ( "Logits:" , logits )
# Convert predictions to a list
predicted_labels = predictions . squeeze (). tolist ()
# Get input IDs to identify padding and special tokens
input_ids = inputs [ 'input_ids' ]. squeeze (). tolist ()
# Define a set of token IDs that correspond to special tokens
special_tokens_ids = { loaded_tokenizer . cls_token_id , loaded_tokenizer . pad_token_id , loaded_tokenizer . eos_token_id }
# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [ label for label , token_id in zip ( predicted_labels , input_ids ) if token_id not in special_tokens_ids ]
print ( "Predicted binding sites:" , binding_sites )