Cette base de code réimplémente LORA: adaptation de faible rang des modèles de grands langues (ICLR 2022) et est reconstruit sur la base du Loralib.
Les implémentations de loratorch et loralib sont très différentes. Nous prenons le nn.Linear comme exemple comme suit.
loralib , où
loratorch , loralib calculez loratorch fusionne le poids pré-formé nn.Linear.forward() . Il n'y a pas de différence entre loralib et loratorch dans les couches linéaires. Mais dans certaines couches non linéaires ou complexes, nous ne savons pas si cette couche satisfait loralib . Au contraire, l'idée de fusionner les poids en premier dans loratorch est plus générale et plus extensible. Vous venez d'appeler merge_lora_param() dans loratorch pour fusionner les poids, puis d'appeler forward() dans la couche d'origine pour calculer les résultats. Avec l'aide de loratorch , vous pouvez facilement implémenter Lora sur tout type de couche de torch.nn .
loralib | loratorch | ||
|---|---|---|---|
nn.Linear | ✓ | ✓ | linéaire.ipynb |
nn.Embedding | ✓ | ✓ | embedding.ipynb |
nn.Conv1d | ✓ | ✓ | |
nn.Conv2d | ✓ | ✓ | |
nn.Conv3d | ✓ | ✓ | |
nn.MultiheadAttention | ✘ | ✓ | |
MergedLinear | ✓ (erreur) | ✓ | MergedLinear.ipynb |
| difficile à étendre | facile à étendre |
Nous comparons les résultats de loralib et loratorch dans des exemples pour démontrer l'exactitude de la mise en œuvre dans loratorch .
L'utilisation de loratorch est la même que loralib .
Installez loratorch .
pip install git+https://github.com/Baijiong-Lin/LoRA-Torch
# Alternatively for developers
# git clone https://github.com/Baijiong-Lin/LoRA-Torch
# cd LoRA-Torch
# pip install -e . Remplacez les couches où vous souhaitez utiliser Lora en utilisant loratorch .
# ===== Before =====
# layer = nn.Linear(in_features, out_features)
# ===== After ======
import loratorch as lora
# Add a pair of low-rank adaptation matrices with rank r=16 and alpha=32
layer = lora . Linear ( in_features , out_features , r = 16 , lora_alpha = 32 )Marquez uniquement les paramètres LORA comme entraînant avant la boucle de formation.
model = Model ()
# (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names
lora . mark_only_lora_as_trainable ( model )
optimizer = torch . optim . SGD ( model . parameters (), lr = 0.1 )
# Training loop
for batch in dataloader :
model . train ()
# forward process
loss = forward_fun ( model , batch )
# backward process
optimizer . zero_grad ()
loss . backward ()
optimizer . step ()
# (!!!) reregister model param to ensure they are in model.state_dict() and model.parameters()
# (!!!) Without this line, the performance does not be affected but you will find that some weights are missing in model.state_dict() and model.parameters()
lora . register_model_param_after_backward ( model )Enregistrer le modèle LORA (seules les matrices LORA seront enregistrées).
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch . save ( lora . lora_state_dict ( model ), checkpoint_path )Chargez le modèle LORA (besoin de charger d'abord le modèle pré-formé).
# Load the pre-trained checkpoint first
model . load_state_dict ( torch . load ( 'ckpt_pretrained.pt' ), strict = False )
# Then load the LoRA checkpoint
model . load_state_dict ( torch . load ( 'ckpt_lora.pt' ), strict = False ) loratorch est développé et maintenu par Baijiong Lin.
Si vous avez une question ou une suggestion, n'hésitez pas à nous contacter en soulevant un problème ou en envoyant un e-mail à [email protected] .
loratorch est fortement basé sur loralib . Nous remercions ses auteurs pour leur base de code merveilleuse et open source.
Si vous trouvez loratorch utile pour votre recherche ou votre développement, veuillez citer ce qui suit:
@inproceedings { hu2022lora ,
title = { Lo{RA}: Low-Rank Adaptation of Large Language Models } ,
author = { Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2022 } ,
}
@software { lin2023loratorch ,
author = { Baijiong Lin } ,
title = { {LoRA-Torch}: {PyTorch} Reimplementation of {LoRA} } ,
url = { https://github.com/Baijiong-Lin/LoRA-Torch } ,
year = { 2023 }
} loratorch est libéré sous la licence du MIT.