
Kami menyediakan blok bangunan yang mudah disesuaikan untuk model bahasa pelatihan termasuk implementasi algoritma dalam kebijakan , fungsi hadiah , metrik , dataset dan kebijakan aktor-kritik berbasis LM
Tautan kertas: https://arxiv.org/abs/2210.01241
Tautan situs web: https://rl4lms.apps.allenai.org/
Diuji secara menyeluruh dan dibandingkan dengan lebih dari 2000 percobaan (benchmark grue?) Pada seperangkat komprehensif:
Semua blok bangunan ini dapat disesuaikan memungkinkan pengguna untuk melatih LMS berbasis transformator untuk mengoptimalkan fungsi hadiah sewenang-wenang pada dataset pilihan mereka.
git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e . Kami juga menyediakan dockerfile untuk pengembangan menggunakan wadah Docker yang berisi semua dependensi.
docker build . -t rl4lms Secara opsional, perpustakaan Corenlp diperlukan untuk perhitungan metrik tertentu (mis. Spice) yang dapat diunduh melalui cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
Kami menyediakan API pelatihan sederhana yang dapat dipanggil melalui skrip kereta yang memungkinkan untuk melatih PPO, NLPO atau model yang diawasi dengan menggunakan file konfigurasi (YAML).
Misalnya, untuk melatih T5-Base pada ringkasan CNN/DM pada PPO menggunakan rouge-1 sebagai fungsi hadiah, Anda dapat menjalankan:
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.ymlFile konfigurasi untuk semua tugas dapat ditemukan di sini.
File config berisi detail tentang pengaturan hiper-parameter untuk blok bangunan yang dijelaskan di bawah ini:
Dataset/Tugas : Dataset yang berisi sampel dengan prompt input dan kalimat referensi. Dataset yang tersedia ditemukan di kelas DataPoolRegistry di Registry. (Lihat cara membuat dataset Anda sendiri di sini)
datapool :
id : cnn_daily_mail
args :
prompt_prefix : " Summarize: "Tokenizer - Tokenizer terlatih yang digunakan untuk (DE) tokenize input dan urutan output dengan pengaturan untuk bantalan dan pemotongan
tokenizer :
model_name : t5-base
padding_side : left
truncation_side : left
pad_token_as_eos_token : False Fungsi Hadiah : Fungsi Hadiah yang menghitung skor tingkat token pada setiap langkah MDP. Fungsi hadiah yang tersedia dapat ditemukan di kelas RewardFunctionRegistry . (Lihat cara membuat fungsi hadiah Anda sendiri di sini)
reward_fn :
id : rouge
args :
rouge_type : " rouge1 " Lingkungan : Mengkonfigurasi lingkungan pembuatan teks gaya gym yang mensimulasikan episode MDP. Peluncuran dihasilkan menggunakan sampel kereta dari dataset yang terdiri dari input dan teks referensi. Lebih lanjut, kami membungkus env kami dengan SubProcVecEnv dari stabil-baselines yang memproses episode n_envs secara paralel menggunakan multi-proses untuk menghitung imbalan langkah bijaksana.
Pengaturan konfigurasi lebih lanjut meliputi:
max_episode_length : Max Length of the Episodemax_prompt_length - Panjang maksimum teks input untuk dipertimbangkanterminate_on_eos - apakah akan mengakhiri episode segera setelah tindakan EOS dilakukanprompt_truncation_side - sisi pemotongan untuk teks promptcontext_start_token - ID untuk token konteks (sesuai dengan token awal yang diberikan kepada decoder dalam model encoder -decoder) env :
n_envs : 10
args :
max_prompt_length : 512
max_episode_length : 100
terminate_on_eos : True
prompt_truncation_side : " right "
context_start_token : 0ALG On-Policy : Kami menyediakan implementasi 4 algoritma kebijakan: PPO, NLPO, A2C dan TRPO yang diadaptasi dari stabil-baselines3 yang disesuaikan untuk bekerja dengan tugas NLP yang dapat digunakan di luar kotak dengan kebijakan kausal atau kebijakan SEQ2SEQ LM. (Lihat cara membuat algoritma atau kebijakan on-policy Anda sendiri)
Kami juga menyediakan pelatih yang diawasi untuk tujuan pembandingan. Model start hangat yang diawasi sudah diunggah ke hub HuggingFace dan ditentukan dalam file konfigurasi masing -masing.
Hyper-parameter untuk algoritma dapat ditentukan di alg/args .
Lebih lanjut, semua algoritma RL menggunakan pengontrol KL adaptif untuk menjaga LM tetap dekat dengan LM asli dengan mengatur koefisien KL awal ( alg/kl_div/coeff ) dan target KL ( alg/kl_div/target_kl ).
Kami mendukung dua jenis kebijakan LM: kebijakan LM kausal (untuk model decoder saja) dan kebijakan SEQ2SEQ LM (untuk model encoder-decoder). Lebih jauh untuk NLPO, kami juga memberikan varian topeng ini. Implementasi kebijakan dapat ditemukan di sini dan dapat dilampirkan pada algoritma dengan menentukan alg/policy/id dan alg/policy/args
alg :
id : ppo
args :
n_steps : 512
batch_size : 64
verbose : 1
learning_rate : 0.000002
n_epochs : 5
ent_coef : 0.0
kl_div :
coeff : 0.001
target_kl : 0.2
policy :
id : seq2seq_lm_actor_critic_policy
args :
model_name : t5-base
apply_model_parallel : True
prompt_truncation_side : " right "
generation_kwargs :
do_sample : True
top_k : 50
min_length : 50
max_new_tokens : 100 Konfigurasi Pelatih : Kami menyediakan pelatih on-policy-pembungkus fitur-lengkap yang membuat instantiates blok bangunan dari konfigurasi yang sesuai dan menyediakan loop pelatihan luar yang terdiri dari kereta dan evaluasi train_evaluation/n_iters .
alg/args/n_steps X env/n_envs dari algoritma yang dipilih.eval_every iters, LM dievaluasi pada split validasi menggunakan metrik yang tercantum dalam train_evaluation/metrics dengan generasi kwargs yang disediakan dalam train_evaluation/generation_kwargs (ini mengesampingkan peluncuran alg/policy/generation_kwargs hanya untuk tujuan inferensi) # train and evaluation
train_evaluation :
eval_batch_size : 100
n_iters : 100
eval_every : 10
save_every : 1
metrics :
- id : meteor
args : {}
- id : rouge
- id : bleu
args : {}
- id : bert_score
args :
language : en
- id : diversity
args : {}
generation_kwargs :
do_sample : True
top_k : 0
temperature : 0.7
min_length : 50
max_new_tokens : 100RL4LMS menyediakan kemampuan kustomisasi lengkap-sehubungan dengan menambahkan tugas/dataset baru, fungsi hadiah, metrik evaluasi, algoritma on-policy dan kebijakan aktor-kritik.
Pengguna dapat membuat set data mereka sendiri dengan sub -kelas TextGenPool hanya dengan mengesampingkan prepare(cls, split: str, **args) -> 'TextGenPool': metode untuk mengembalikan instance TextGenPool. Contoh ditunjukkan di bawah ini:
from rl4lms . data_pools . text_generation_pool import Sample , TextGenPool
class MyDataPool ( TextGenPool ):
@ classmethod
def prepare ( cls , split : str ):
..
samples = []
for ix , item in enumerate (..):
sample = Sample ( id = f" { split } _ { ix } " ,
prompt_or_input_text = item [ "document" ],
references = [ item [ "target" ]]
)
samples . append ( sample )
pool_instance = cls ( samples )
return pool_instance Kustom Hadiah Kustom dapat diimplementasikan dengan mudah dengan fungsi penghargaan sub-kelas (yang dapat dipanggil) yang mengambil observasi (
from rl4lms . envs . text_generation . observation import Observation
from rl4lms . envs . text_generation . reward import RewardFunction
class MyRewardFunction ( RewardFunction ):
def __init__ ( self , * args ) -> None :
super (). __init__ ()
def __call__ ( self , prev_observation : Observation ,
action : int ,
current_observation : Observation ,
done : bool ,
meta_info : Dict [ str , Any ] = None ) -> float :
if done :
reward = ..
return reward
return 0Selain metrik NLG tradisional, untuk prototipe cepat, kami menyediakan dua fungsi hadiah sintetis yang melatih LMS untuk menghasilkan angka dalam meningkatkan urutan dan menghasilkan tanggal. Ini dapat digunakan untuk dengan cepat menguji algoritma dan kebijakan yang berbeda. Konfigurasi yang sesuai dapat ditemukan di sini (angka, tanggal)
Pengguna dapat membuat metrik evaluasi mereka sendiri yang kemudian akan digunakan untuk mengevaluasi model secara berkala pada pemisahan dataset validasi. Ini dapat dilakukan dengan basemetrik sub-kelas yang mengambil teks cepat, teks yang dihasilkan, teks referensi, meta_infos, model LM saat ini, nama terpisah sebagai input dan mengembalikan dikt dengan nama metrik sebagai kunci dan nilai yang terdiri dari tuple skor tingkat kalimat dan skor tingkat korpus. Contohnya adalah sebagai berikut:
from rl4lms . envs . text_generation . metric import BaseMetric
class MyMetric ( BaseMetric ):
def __init__ ( self ) -> None :
super (). __init__ ()
def compute ( self ,
prompt_texts : List [ str ],
generated_texts : List [ str ],
reference_texts : List [ List [ str ]],
meta_infos : List [ Dict [ str , Any ]] = None ,
model : PreTrainedModel = None ,
split_name : str = None ):
metric_dict = {
"custom_metrics/my_metric" : ([ 0.4 , 0.7 , 0.9 ], 0.7 )
}
return metric_dict Selain algoritma kebijakan yang didukung (PPO, NLPO, A2C, TRPO), pengguna dapat mengimplementasikan algoritma on-policy mereka sendiri dengan mudah dengan sub-kelas-baselines3 onpolicyalgorithm. Karena kami menyediakan pembungkus untuk algoritma on-policy yang menangani peluncuran menggunakan kebijakan LM, lingkungan, penghargaan komputasi dll, pengguna hanya perlu menerapkan metode train() dengan fungsi kerugian khusus.
from stable_baselines3 . common . on_policy_algorithm import OnPolicyAlgorithm
class MyOnPolicyAlgorithm ( OnPolicyAlgorithm ):
def __init__ ( ** args ):
super (). __init__ ( ** args )
def train ( self ) -> None :
# train for n_epochs epochs
for epoch in range ( self . n_epochs ):
# Do a complete pass on the rollout buffer
for rollout_data in self . rollout_buffer . get ( self . batch_size ):
# compute loss Kami menyediakan implementasi kebijakan aktor-kritik berbasis LM yang membungkus LM kausal dan SEQ2SEQ LMS. Ini juga dapat diperluas (misalnya: Gunakan arsitektur kritik yang berbeda) dengan mengesampingkan metode yang tepat (mis. evaluate_actions() )
Akhirnya, cukup daftarkan komponen khusus Anda dengan menambahkannya ke registri yang sesuai, setelah itu dapat digunakan langsung dari konfigurasi yang mirip dengan komponen yang telah ditentukan sebelumnya
Kami telah memberikan templat crowdsourcing yang kami gunakan pada Turki mekanis, bersama dengan contoh input dalam scripts/crowdworking_templates . Anda mungkin menemukan ini titik awal yang bermanfaat baik untuk mengevaluasi generasi model Anda sendiri, atau untuk mengumpulkan data pelatihan untuk fungsi hadiah yang dipelajari.
Selain itu, kami mendukung penebangan wandb dan pelatihan yang hangat dengan menyimpan pos pemeriksaan dan artefak pelatihan lainnya di jalur yang ditentukan pengguna. Ini sangat berguna untuk menjalankan pekerjaan preemptible pada kelompok besar yang dijadwalkan.
Artefak meliputi (1) file jsonl yang berisi info peluncuran pada interval tertentu (2) file jsonl yang berisi info pelatihan pada interval tertentu (3) file jsonl yang berisi metrik validasi pada interval tertentu (4) file jsonL (6) prediksi (6) piring (6) pelatihan uji dengan prediksi json dengan prediksi validasi pada prediksi validasi yang ditentukan (6) sebelum prediksi (6) JSON dengan prediksi validasi pada prediksi 2) dengan prediksi tes (6) dengan prediksi 2) Json biasa menjalankan percobaan
Penggunaan lengkap adalah sebagai berikut:
WANDB_API_KEY= < YOUR-WANDB-API-KEY-HERE > python scripts/training/train_text_generation.py
--config_path < PATH-TO-CONFIG-FILE >
--experiment_name < EXPERIMENT-NAME >
--base_path_to_store_results < PATH-TO-STORE-RESULTS >
--log_to_wandb @inproceedings { Ramamurthy2022IsRL ,
title = { Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization } ,
author = { Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi } ,
journal = { arXiv preprint arXiv:2210.01241 } ,
url = { https://arxiv.org/abs/2210.01241 } ,
year = { 2022 }
}Untuk diskusi, pertanyaan, pertukaran ide, bergabunglah dengan Slack Channel kami