esm2_loras
1.0.0
이것은 토큰 분류 작업에 대한 단백질 언어 모델 ESM-2에 대한 저 순위 적응 (LORA)을 훈련시키기위한 시도입니다. 특히, 우리는 RNA 결합 부위 예측 변수를 훈련 시키려고 시도한다. 여전히 해결해야 할 문제가 있으며 피드백이나 조언에 크게 감사 할 것입니다. 이 코드는 소형 모델을위한 것이므로 거의 모든 GPU에서 합리적인 시간 내에 하이퍼 파라미터 검색을 위해 Wandb 스윕을 수행해야합니다. 원하는 경우 더 큰 모델을 쉽게 교체 할 수 있습니다.
모델 자체
"AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites"얼굴 포옹에서 찾을 수 있습니다.
콘다 환경을 설정하려면 Repo를 복제하고 실행하십시오.
conda env create -f environment.yml
그런 다음 실행 :
conda activate lora_esm_2
모델 실행을 훈련시키기 위해 :
from lora_esm2_script import train_protein_model
train_protein_model ()사용하려면 달리기를 시도하십시오.
from transformers import AutoModelForTokenClassification , AutoTokenizer
from peft import PeftModel
import torch
import numpy as np
import random
# Path to the saved LoRA model
model_path = "esm2_t6_8M-finetuned-lora_2023-08-03_18-32-25"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification . from_pretrained ( base_model_path )
loaded_model = PeftModel . from_pretrained ( base_model , model_path )
# Load the tokenizer
loaded_tokenizer = AutoTokenizer . from_pretrained ( model_path )
# New unseen protein sequence
new_protein_sequence = "FDLNDFLEQKVLVRMEAIINSMTMKERAKPEIIKGSRKRRIAAGSGMQVQDVNRLLKQFDDMQRMMKKM"
# Tokenize the new sequence
inputs = loaded_tokenizer ( new_protein_sequence , truncation = True , padding = 'max_length' , max_length = 512 , return_tensors = "pt" )
# Make predictions
with torch . no_grad ():
outputs = loaded_model ( ** inputs )
logits = outputs . logits
predictions = torch . argmax ( logits , dim = 2 )
# Print logits for debugging
print ( "Logits:" , logits )
# Convert predictions to a list
predicted_labels = predictions . squeeze (). tolist ()
# Get input IDs to identify padding and special tokens
input_ids = inputs [ 'input_ids' ]. squeeze (). tolist ()
# Define a set of token IDs that correspond to special tokens
special_tokens_ids = { loaded_tokenizer . cls_token_id , loaded_tokenizer . pad_token_id , loaded_tokenizer . eos_token_id }
# Filter the predicted labels using the special_tokens_ids to remove predictions for special tokens
binding_sites = [ label for label , token_id in zip ( predicted_labels , input_ids ) if token_id not in special_tokens_ids ]
print ( "Predicted binding sites:" , binding_sites )