TRLX adalah kerangka kerja pelatihan terdistribusi yang dirancang dari bawah ke atas untuk fokus pada model bahasa besar yang menyempurnakan dengan pembelajaran penguatan menggunakan fungsi hadiah yang disediakan atau dataset berlabel hadiah.
Dukungan pelatihan untuk? Model pemeluk wajah disediakan oleh pelatih yang didukung akselerasi, memungkinkan pengguna untuk menyempurnakan model bahasa sebab akibat dan T5 dari parameter hingga 20B, seperti facebook/opt-6.7b , EleutherAI/gpt-neox-20b , dan google/flan-t5-xxl . Untuk model di luar parameter 20B, TRLX menyediakan pelatih yang didukung Nvidia Nemo yang memanfaatkan teknik paralelisme yang efisien untuk skala secara efektif.
Algoritma RL berikut saat ini diimplementasikan:
| Algoritma | Mempercepat pelatih | Pelatih Nemo |
|---|---|---|
| Optimalisasi Kebijakan Proksimal (PPO) | ✅ | ✅ |
| Bahasa implisit Q-Learning (ILQL) | ✅ | ✅ |
Dokumentasi
? Keju kumpulkan anotasi manusia untuk aplikasi RL Anda dengan pustaka pengumpulan data manusia-in-loop kami.
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . Untuk penggunaan lebih lanjut, lihat contoh. Anda juga dapat mencoba notebook Colab di bawah ini:
| Keterangan | Link |
|---|---|
| Simulacra (GPT2, ILQL) | |
| Sentimen (GPT2, ILQL) |
Contoh terbaru adalah pada bobot & bias kami
Anda dapat melatih model menggunakan fungsi hadiah atau dataset berlabel hadiah.
trainer = trlx . train ( 'gpt2' , reward_fn = lambda samples , ** kwargs : [ sample . count ( 'cats' ) for sample in samples ])Untuk pelatihan model hadiah, lihat Perpustakaan Otokrit kami.
trainer = trlx . train ( 'EleutherAI/gpt-j-6B' , samples = [ 'dolphins' , 'geese' ], rewards = [ 1.0 , 100.0 ]) trainer = trlx . train ( 'gpt2' , samples = [[ 'Question: 1 + 2 Answer:' , '3' ], [ 'Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:' , '(pi ** 2)/ 6' ]]) trainer . generate ( ** tokenizer ( 'Q: Who rules the world? A:' , return_tensors = 'pt' ), do_sample = True ) from trlx . data . default_configs import default_ppo_config
config = default_ppo_config ()
config . model . model_path = 'EleutherAI/gpt-neox-20b'
config . tokenizer . tokenizer_path = 'EleutherAI/gpt-neox-20b'
config . train . seq_length = 2048
trainer = trlx . train ( config = config , reward_fn = lambda samples , ** kwargs : [ len ( sample ) for sample in samples ])Untuk mengurangi penggunaan memori (jika Anda mengalami CUDA keluar dari kesalahan memori), pertama -tama coba pengaturan terendah untuk hyperparameter berikut dan akhirnya meningkatkannya:
# micro batch size per gpu
config . train . batch_size = 1
# freeze all transformer layers
config . model . num_layers_unfrozen = 0
# maximum sample length, prompts or samples longer than that will be truncated
config . train . seq_length = 128
# micro batch size for sampling (specific for PPO)
config . method . chunk_size = 1
# use an additional Q-head (specific for ILQL)
config . method . two_qs = False trainer . save_pretrained ( '/path/to/output/folder/' )accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.pyIkuti instruksi pengaturan di Nemo Readme.
python examples/nemo_ilql_sentiments.pyUntuk penggunaan lebih lanjut, lihat Nemo Readme
ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.pymain TRLX python -m trlx.reference octocat/trlx-fork:fix-branch TRLX menggunakan pustaka logging python standar untuk mencatat informasi pelatihan ke konsol. Logger default diatur ke level INFO , yang berarti INFO , WARNING , ERROR , dan pesan tingkat CRITICAL akan dicetak ke output standar.
Untuk mengubah level log secara langsung, Anda dapat menggunakan setter verbositas. Misalnya, untuk mengatur level log ke penggunaan WARNING :
import trlx
trlx . logging . set_verbosity ( trlx . logging . WARNING ) Ini akan menekan pesan tingkat INFO , tetapi tetap mencetak WARNING , ERROR , dan pesan tingkat CRITICAL .
Anda juga dapat mengontrol verbositas logging dengan mengatur variabel lingkungan TRLX_VERBOSITY ke salah satu nama level logging standar:
CRITICAL ( trlx.logging.CRITICAL )ERROR ( trlx.logging.ERROR )WARNING ( trlx.logging.WARNING )INFO ( trlx.logging.INFO )DEBUG ( trlx.logging.DEBUG ) export TRLX_VERBOSITY=WARNING Secara default, bilah kemajuan tqdm digunakan untuk menampilkan kemajuan pelatihan. Anda dapat menonaktifkannya dengan memanggil trlx.logging.disable_progress_bar() , jika tidak trlx.logging.enable_progress_bar() untuk mengaktifkan.
Pesan dapat diformat dengan detail yang lebih besar dengan mengatur trlx.logging.enable_explicit_format() . Ini akan menyuntikkan informasi situs panggilan ke dalam setiap log yang mungkin bermanfaat untuk debugging.
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...Kiat: Untuk mengurangi jumlah output logging, Anda mungkin merasa terbantu untuk mengubah level log perpustakaan pihak ketiga yang digunakan oleh TRLX. Misalnya, coba tambahkan
transformers.logging.set_verbosity_error()ke bagian atas skrip TRLX Anda untuk membungkam pesan verbose dari pustakatransformers(lihat dokumen logging mereka untuk detail lebih lanjut).
Untuk pengembangan, lihat pedoman ini dan juga baca dokumen kami
@inproceedings{havrilla-etal-2023-trlx,
title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
author = "Havrilla, Alexander and
Zhuravinskyi, Maksym and
Phung, Duy and
Tiwari, Aman and
Tow, Jonathan and
Biderman, Stella and
Anthony, Quentin and
Castricato, Louis",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.530",
doi = "10.18653/v1/2023.emnlp-main.530",
pages = "8578--8595",
}
Terima kasih banyak kepada Leandro von Werra yang telah berkontribusi dengan TRL, perpustakaan yang awalnya menginspirasi repo ini.