TRLXは、提供された報酬機能または報酬標識データセットのいずれかを使用して、補強学習を備えた大規模な言語モデルを微調整することに焦点を当てるために設計された分散トレーニングフレームワークです。
のトレーニングサポート?ハギングフェイスモデルは、Accelerate Backed Trainersによって提供され、ユーザーはfacebook/opt-6.7b 、 EleutherAI/gpt-neox-20b 、 google/flan-t5-xxl 。 20Bパラメーターを超えるモデルの場合、TRLXは、効率的な並列性テクニックを活用して効果的にスケーリングするNVIDIAネモ支援トレーナーを提供します。
現在、次のRLアルゴリズムが実装されています。
| アルゴリズム | トレーナーを加速します | Nemoトレーナー |
|---|---|---|
| 近位政策最適化(PPO) | ✅ | ✅ |
| 暗黙の言語Qラーニング(ILQL) | ✅ | ✅ |
ドキュメント
?チーズは、人間のループデータ収集ライブラリを使用して、RLアプリケーションの人間の注釈を収集します。
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . その他の使用については、例を参照してください。以下のコラブノートブックをお試しください。
| 説明 | リンク |
|---|---|
| Simulacra(GPT2、ILQL) | |
| 感情(GPT2、ILQL) |
例の最新の実行は私たちのウェイトとバイアスにあります
報酬機能または報酬標識データセットを使用してモデルをトレーニングできます。
trainer = trlx . train ( 'gpt2' , reward_fn = lambda samples , ** kwargs : [ sample . count ( 'cats' ) for sample in samples ])報酬モデルトレーニングについては、オートクリットライブラリを参照してください。
trainer = trlx . train ( 'EleutherAI/gpt-j-6B' , samples = [ 'dolphins' , 'geese' ], rewards = [ 1.0 , 100.0 ]) trainer = trlx . train ( 'gpt2' , samples = [[ 'Question: 1 + 2 Answer:' , '3' ], [ 'Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:' , '(pi ** 2)/ 6' ]]) trainer . generate ( ** tokenizer ( 'Q: Who rules the world? A:' , return_tensors = 'pt' ), do_sample = True ) from trlx . data . default_configs import default_ppo_config
config = default_ppo_config ()
config . model . model_path = 'EleutherAI/gpt-neox-20b'
config . tokenizer . tokenizer_path = 'EleutherAI/gpt-neox-20b'
config . train . seq_length = 2048
trainer = trlx . train ( config = config , reward_fn = lambda samples , ** kwargs : [ len ( sample ) for sample in samples ])メモリの使用量を削減するには(メモリエラーからCUDAが発生している場合)、最初に次のハイパーパラメーターの最低設定を試して、最終的にそれらを増やしてください。
# micro batch size per gpu
config . train . batch_size = 1
# freeze all transformer layers
config . model . num_layers_unfrozen = 0
# maximum sample length, prompts or samples longer than that will be truncated
config . train . seq_length = 128
# micro batch size for sampling (specific for PPO)
config . method . chunk_size = 1
# use an additional Q-head (specific for ILQL)
config . method . two_qs = False trainer . save_pretrained ( '/path/to/output/folder/' )accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.pyNemo Readmeのセットアップ手順に従ってください。
python examples/nemo_ilql_sentiments.pyより使用するには、Nemo Readmeを参照してください
ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.pymainブランチに対してベンチマークしますpython -m trlx.reference octocat/trlx-fork:fix-branchTRLXは、標準のPython loggingライブラリを使用して、トレーニング情報をコンソールにログに記録します。デフォルトのロガーはINFOレベルに設定されています。つまり、 INFO 、 WARNING 、 ERROR 、およびCRITICALレベルのメッセージが標準出力に印刷されることを意味します。
ログレベルを直接変更するには、冗長セッターを使用できます。たとえば、ログレベルを使用WARNING使用するように設定するには:
import trlx
trlx . logging . set_verbosity ( trlx . logging . WARNING )これにより、 INFOレベルのメッセージが抑制されますが、それでもWARNING 、 ERROR 、 CRITICALレベルのメッセージを印刷します。
TRLX_VERBOSITY環境変数を標準ロギングレベルの名前の1つに設定することにより、ロギングの冗長性を制御することもできます。
CRITICAL ( trlx.logging.CRITICAL )ERROR ( trlx.logging.ERROR )WARNING ( trlx.logging.WARNING )INFO ( trlx.logging.INFO )DEBUG ( trlx.logging.DEBUG ) export TRLX_VERBOSITY=WARNINGデフォルトでは、 tqdm Progressバーを使用してトレーニングの進行状況を表示します。 trlx.logging.disable_progress_bar()を呼び出して、 trlx.logging.enable_progress_bar()有効にすることで無効にできます。
trlx.logging.enable_explicit_format()を設定することにより、メッセージをより詳細にフォーマットできます。これにより、各ログにコールサイト情報が挿入されます。これは、デバッグに役立つ場合があります。
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...ヒント:ロギング出力の量を減らすために、TRLXが使用するサードパーティライブラリのログレベルを変更すると役立つ場合があります。たとえば、
transformers.logging.set_verbosity_error()をtrlxスクリプトの上部に追加して、transformersライブラリからの冗長なメッセージを黙らせることを試みてください(詳細については、ロギングドキュメントを参照してください)。
開発のためにこれらのガイドラインをチェックして、私たちのドキュメントも読む
@inproceedings{havrilla-etal-2023-trlx,
title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
author = "Havrilla, Alexander and
Zhuravinskyi, Maksym and
Phung, Duy and
Tiwari, Aman and
Tow, Jonathan and
Biderman, Stella and
Anthony, Quentin and
Castricato, Louis",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.530",
doi = "10.18653/v1/2023.emnlp-main.530",
pages = "8578--8595",
}
Leandro Von Werraは、このレポを最初にインスピレーションを与えたライブラリであるTRLに貢献してくれたことに感謝します。