このコードベースは、LORAを再実装します:大規模な言語モデルの低ランク適応(ICLR 2022)は、ロラリブに基づいて再構築されます。
loratorchとloralibの実装は非常に異なっています。次のように、 nn.Linearを例として取ります。
loralibのために、 どこ
loratorchのために、 loralib Computes loratorchが事前に訓練された重量をマージしますnn.Linear.forward()を使用するだけで結果を計算します。線形層にloralibとloratorchに違いはありません。しかし、いくつかの非線形または複雑な層では、このレイヤーが満たされるかどうかはわかりませんloralibを使用して、ロラを複雑な層に拡張することは困難です。それどころか、 loratorchで最初に重みを融合するという考えは、より一般的で拡張可能です。 loratorchのmerge_lora_param()を呼び出して重みをマージし、元のレイヤーでforward()を呼び出して結果を計算します。 loratorchの助けを借りて、 torch.nnのあらゆる種類の層にLoraを簡単に実装できます。
loralib | loratorch | ||
|---|---|---|---|
nn.Linear | ✓✓ | ✓✓ | linear.ipynb |
nn.Embedding | ✓✓ | ✓✓ | Embedding.ipynb |
nn.Conv1d | ✓✓ | ✓✓ | |
nn.Conv2d | ✓✓ | ✓✓ | |
nn.Conv3d | ✓✓ | ✓✓ | |
nn.MultiheadAttention | ✘ | ✓✓ | |
MergedLinear | ✓(エラー) | ✓✓ | mergedLinear.ipynb |
| 拡張が難しい | 簡単に拡張できます |
loralibとloratorchの結果を例で比較して、 loratorchの実装の正確性を示します。
loratorchの使用はloralibと同じです。
loratorchをインストールします。
pip install git+https://github.com/Baijiong-Lin/LoRA-Torch
# Alternatively for developers
# git clone https://github.com/Baijiong-Lin/LoRA-Torch
# cd LoRA-Torch
# pip install -e . loratorchを使用してLoraを使用したいレイヤーを交換します。
# ===== Before =====
# layer = nn.Linear(in_features, out_features)
# ===== After ======
import loratorch as lora
# Add a pair of low-rank adaptation matrices with rank r=16 and alpha=32
layer = lora . Linear ( in_features , out_features , r = 16 , lora_alpha = 32 )トレーニングループの前にトレーニング可能なLORAパラメーターのみをマークします。
model = Model ()
# (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names
lora . mark_only_lora_as_trainable ( model )
optimizer = torch . optim . SGD ( model . parameters (), lr = 0.1 )
# Training loop
for batch in dataloader :
model . train ()
# forward process
loss = forward_fun ( model , batch )
# backward process
optimizer . zero_grad ()
loss . backward ()
optimizer . step ()
# (!!!) reregister model param to ensure they are in model.state_dict() and model.parameters()
# (!!!) Without this line, the performance does not be affected but you will find that some weights are missing in model.state_dict() and model.parameters()
lora . register_model_param_after_backward ( model )LORAモデルを保存します(LORAマトリックスのみが保存されます)。
# ===== Before =====
# torch.save(model.state_dict(), checkpoint_path)
# ===== After =====
torch . save ( lora . lora_state_dict ( model ), checkpoint_path )LORAモデルをロードします(最初に事前に訓練されたモデルをロードする必要があります)。
# Load the pre-trained checkpoint first
model . load_state_dict ( torch . load ( 'ckpt_pretrained.pt' ), strict = False )
# Then load the LoRA checkpoint
model . load_state_dict ( torch . load ( 'ckpt_lora.pt' ), strict = False )loratorchは、Baijiong Linによって開発および維持されています。
質問や提案がある場合は、問題を提起するか、 [email protected]にメールを送信して、お気軽にお問い合わせください。
loratorchはloralibに大きく基づいています。著者は、素晴らしいオープンソースコードベースに感謝します。
loratorchが研究や開発に役立つと思われる場合は、以下を引用してください。
@inproceedings { hu2022lora ,
title = { Lo{RA}: Low-Rank Adaptation of Large Language Models } ,
author = { Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2022 } ,
}
@software { lin2023loratorch ,
author = { Baijiong Lin } ,
title = { {LoRA-Torch}: {PyTorch} Reimplementation of {LoRA} } ,
url = { https://github.com/Baijiong-Lin/LoRA-Torch } ,
year = { 2023 }
}loratorchはMITライセンスの下でリリースされます。