
เราให้การสร้างบล็อกแบบปรับแต่งได้อย่างง่ายดายสำหรับแบบจำลองภาษาการฝึกอบรมรวมถึงการใช้งาน อัลกอริทึมตามนโยบาย , ฟังก์ชั่นรางวัล , ตัวชี้วัด , ชุดข้อมูล และ นโยบายนักแสดงตาม LM
ลิงค์กระดาษ: https://arxiv.org/abs/2210.01241
ลิงค์เว็บไซต์: https://rl4lms.apps.allenai.org/
ทดสอบ และ เปรียบเทียบ อย่างละเอียดด้วย การทดลองมากกว่า 2,000 ครั้ง (GRUE Benchmark?) ในชุดที่ครอบคลุม:
การสร้างบล็อกทั้งหมดเหล่านี้สามารถปรับแต่งได้ช่วยให้ผู้ใช้สามารถฝึกอบรม LMS ที่ใช้หม้อแปลงเพื่อเพิ่มประสิทธิภาพฟังก์ชั่นการให้รางวัลโดยพลการใด ๆ ในชุดข้อมูลใด ๆ ที่เลือก
git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e . นอกจากนี้เรายังจัดให้มี Dockerfile สำหรับการพัฒนาโดยใช้คอนเทนเนอร์ Docker ที่มีการอ้างอิงทั้งหมด
docker build . -t rl4lms เป็นทางเลือกจำเป็นต้องใช้ไลบรารี CorenLP สำหรับการคำนวณตัวชี้วัดบางอย่าง (เช่นเครื่องเทศ) ซึ่งสามารถดาวน์โหลดได้ผ่าน cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
เราให้บริการ API การฝึกอบรมอย่างง่ายที่สามารถเรียกใช้ผ่านสคริปต์รถไฟที่อนุญาตให้ฝึกอบรม PPO, NLPO หรือแบบจำลองภายใต้การดูแลโดยใช้ไฟล์ config (YAML)
ตัวอย่างเช่นในการฝึกอบรม T5-base ในการสรุป CNN/DM บน PPO โดยใช้ Rouge-1 เป็นฟังก์ชั่นรางวัลคุณสามารถเรียกใช้:
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.ymlไฟล์กำหนดค่าสำหรับงานทั้งหมดสามารถพบได้ที่นี่
ไฟล์กำหนดค่ามีรายละเอียดเกี่ยวกับการตั้งค่าพารามิเตอร์ไฮเปอร์สำหรับการสร้างบล็อกซึ่งอธิบายไว้ด้านล่าง:
ชุดข้อมูล/งาน : ชุดข้อมูลที่มีตัวอย่างที่มีพรอมต์อินพุตและประโยคอ้างอิง ชุดข้อมูลที่มีอยู่จะพบได้ในคลาส DataPoolRegistry ในรีจิสทรี (ดูวิธีสร้างชุดข้อมูลของคุณเองที่นี่)
datapool :
id : cnn_daily_mail
args :
prompt_prefix : " Summarize: "Tokenizer - tokenizer ที่ผ่านการฝึกอบรมมาก่อนที่ใช้ (de) tokenize ลำดับอินพุตและเอาต์พุตด้วยการตั้งค่าสำหรับการขยายและการตัดทอน
tokenizer :
model_name : t5-base
padding_side : left
truncation_side : left
pad_token_as_eos_token : False ฟังก์ชั่นรางวัล : ฟังก์ชั่นรางวัลที่คำนวณคะแนนระดับโทเค็นในแต่ละขั้นตอนของ MDP ฟังก์ชั่นการให้รางวัลที่มีอยู่สามารถพบได้ในชั้นเรียน RewardFunctionRegistry (ดูวิธีการสร้างฟังก์ชั่นรางวัลของคุณเองที่นี่)
reward_fn :
id : rouge
args :
rouge_type : " rouge1 " สภาพแวดล้อม : กำหนดค่าสภาพแวดล้อมการสร้างข้อความสไตล์ยิมซึ่งจำลองตอน MDP การเปิดตัวถูกสร้างขึ้นโดยใช้ตัวอย่างรถไฟจากชุดข้อมูลซึ่งประกอบด้วยข้อความอินพุตและอ้างอิง นอกจากนี้เราห่อ Env ของเราด้วย SubProcVecEnv จากเส้นเสถียรที่ประมวลผลตอน n_envs แบบคู่ขนานโดยใช้การประมวลผลแบบหลายครั้งเพื่อคำนวณรางวัลที่ชาญฉลาด
การตั้งค่าการกำหนดค่าเพิ่มเติม ได้แก่ :
max_episode_length : ความยาวสูงสุดของตอนmax_prompt_length - ความยาวสูงสุดของข้อความอินพุตที่จะพิจารณาterminate_on_eos - จะยุติตอนทันทีที่ดำเนินการ EOSprompt_truncation_side - ด้านการตัดทอนสำหรับข้อความพรอมต์context_start_token - id สำหรับโทเค็นบริบท (สอดคล้องกับโทเค็นเริ่มต้นที่มอบให้กับตัวถอดรหัสในโมเดลตัวเข้ารหัสตัวพิมพ์ใหญ่) env :
n_envs : 10
args :
max_prompt_length : 512
max_episode_length : 100
terminate_on_eos : True
prompt_truncation_side : " right "
context_start_token : 0ON-POCICY ALG : เราให้การใช้งานอัลกอริธึมตามนโยบาย 4: PPO, NLPO, A2C และ TRPO ที่ปรับให้เข้ามาจากการทำงานที่เสถียร 3 ซึ่งปรับแต่งให้ทำงานกับงาน NLP ซึ่งสามารถใช้งานนอกกรอบที่มีนโยบายเชิงสาเหตุหรือนโยบาย SEQ2SEQ LM (ดูวิธีสร้างอัลกอริทึมหรือนโยบายตามนโยบายของคุณเอง)
นอกจากนี้เรายังจัดให้มีผู้ฝึกสอนดูแลเพื่อวัตถุประสงค์ในการเปรียบเทียบ แบบจำลองการเริ่มต้นที่อบอุ่นภายใต้การดูแลได้อัปโหลดไปยัง HuggingFace Hub แล้วและระบุไว้ในไฟล์ config ที่เกี่ยวข้อง
พารามิเตอร์ไฮเปอร์สำหรับอัลกอริทึมสามารถระบุได้ที่ alg/args
นอกจากนี้อัลกอริทึม RL ทั้งหมดใช้ตัวควบคุม KL Adaptive เพื่อให้ LM อยู่ใกล้กับ LM ดั้งเดิมโดยการตั้งค่าเริ่มต้น KL Co-Efficient ( alg/kl_div/coeff ) และเป้าหมาย KL ( alg/kl_div/target_kl )
เราสนับสนุนนโยบาย LM สองประเภท: นโยบาย LM เชิงสาเหตุ (สำหรับตัวถอดรหัสเฉพาะโมเดล) และ นโยบาย SEQ2SEQ LM (สำหรับโมเดลตัวเข้ารหัส Decoder) นอกจากนี้สำหรับ NLPO เรายังมีตัวแปรที่สวมหน้ากากเหล่านี้ การใช้นโยบายสามารถพบได้ที่นี่และสามารถแนบกับอัลกอริทึมโดยการระบุ alg/policy/id และ alg/policy/args
alg :
id : ppo
args :
n_steps : 512
batch_size : 64
verbose : 1
learning_rate : 0.000002
n_epochs : 5
ent_coef : 0.0
kl_div :
coeff : 0.001
target_kl : 0.2
policy :
id : seq2seq_lm_actor_critic_policy
args :
model_name : t5-base
apply_model_parallel : True
prompt_truncation_side : " right "
generation_kwargs :
do_sample : True
top_k : 50
min_length : 50
max_new_tokens : 100 การกำหนดค่าเทรนเนอร์ : เราให้บริการเทรนเนอร์ตามนโยบาย-wrapper-feature-complete ที่มีการสร้างบล็อกแบบอินสแตนซ์จากการกำหนดค่าที่สอดคล้องกันและให้การฝึกอบรมด้านนอกซึ่งประกอบด้วย รถไฟ และการทำซ้ำการทำ ซ้ำ train_evaluation/n_iters
alg/args/n_steps X env/n_envs ของอัลกอริทึมที่เลือกeval_every iters, LM ได้รับการประเมินในการแยกการตรวจสอบความถูกต้องโดยใช้ตัวชี้วัดที่ระบุไว้ใน train_evaluation/metrics ที่มีการสร้าง Kwargs ที่ให้ไว้ใน train_evaluation/generation_kwargs alg/policy/generation_kwargs # train and evaluation
train_evaluation :
eval_batch_size : 100
n_iters : 100
eval_every : 10
save_every : 1
metrics :
- id : meteor
args : {}
- id : rouge
- id : bleu
args : {}
- id : bert_score
args :
language : en
- id : diversity
args : {}
generation_kwargs :
do_sample : True
top_k : 0
temperature : 0.7
min_length : 50
max_new_tokens : 100RL4LMS ให้การปรับแต่งได้อย่างสมบูรณ์-เกี่ยวกับการเพิ่มงาน/ชุดข้อมูลใหม่ฟังก์ชั่นรางวัลการประเมินผลการประเมินอัลกอริทึมตามนโยบายและนโยบายนักแสดง-นักวิจารณ์
ผู้ใช้สามารถสร้างชุดข้อมูลของตัวเองโดยการจัดประเภท TextGenPool โดยการแทนที่ prepare(cls, split: str, **args) -> 'TextGenPool': วิธีการส่งคืนอินสแตนซ์ของ TextGenPool ตัวอย่างแสดงด้านล่าง:
from rl4lms . data_pools . text_generation_pool import Sample , TextGenPool
class MyDataPool ( TextGenPool ):
@ classmethod
def prepare ( cls , split : str ):
..
samples = []
for ix , item in enumerate (..):
sample = Sample ( id = f" { split } _ { ix } " ,
prompt_or_input_text = item [ "document" ],
references = [ item [ "target" ]]
)
samples . append ( sample )
pool_instance = cls ( samples )
return pool_instance การให้รางวัลที่กำหนดเองสามารถนำไปใช้ได้อย่างง่ายดายโดยการจัดอันดับของรางวัล (เรียกได้) ซึ่งใช้การสังเกต (
from rl4lms . envs . text_generation . observation import Observation
from rl4lms . envs . text_generation . reward import RewardFunction
class MyRewardFunction ( RewardFunction ):
def __init__ ( self , * args ) -> None :
super (). __init__ ()
def __call__ ( self , prev_observation : Observation ,
action : int ,
current_observation : Observation ,
done : bool ,
meta_info : Dict [ str , Any ] = None ) -> float :
if done :
reward = ..
return reward
return 0นอกเหนือจากตัวชี้วัด NLG แบบดั้งเดิมสำหรับการสร้างต้นแบบอย่างรวดเร็วเรายังมีฟังก์ชั่นรางวัลสังเคราะห์สองฟังก์ชั่นที่ฝึก LMS เพื่อสร้างตัวเลขในการเพิ่มลำดับและสร้างวันที่ สิ่งเหล่านี้สามารถใช้เพื่อทดสอบอัลกอริทึมและนโยบายที่แตกต่างกันได้อย่างรวดเร็ว การกำหนดค่าที่สอดคล้องกันสามารถพบได้ที่นี่ (ตัวเลขวันที่)
ผู้ใช้สามารถสร้างตัวชี้วัดการประเมินของตนเองซึ่งจะใช้ในการประเมินแบบจำลองเป็นระยะในการแยกการตรวจสอบความถูกต้องของชุดข้อมูล สิ่งนี้สามารถทำได้โดย basemetric sub-classing ซึ่งใช้ข้อความแจ้งข้อความที่สร้างข้อความข้อความอ้างอิง meta_infos รุ่น LM ปัจจุบันชื่อแยกเป็นอินพุตและส่งคืนคำสั่งที่มีชื่อเมตริกเป็นคีย์และค่าที่ประกอบด้วยคะแนนระดับประโยคและคะแนนระดับคลังข้อมูล ตัวอย่างมีดังนี้:
from rl4lms . envs . text_generation . metric import BaseMetric
class MyMetric ( BaseMetric ):
def __init__ ( self ) -> None :
super (). __init__ ()
def compute ( self ,
prompt_texts : List [ str ],
generated_texts : List [ str ],
reference_texts : List [ List [ str ]],
meta_infos : List [ Dict [ str , Any ]] = None ,
model : PreTrainedModel = None ,
split_name : str = None ):
metric_dict = {
"custom_metrics/my_metric" : ([ 0.4 , 0.7 , 0.9 ], 0.7 )
}
return metric_dict นอกเหนือจากอัลกอริทึมที่รองรับนโยบาย (PPO, NLPO, A2C, TRPO) ผู้ใช้สามารถใช้อัลกอริทึมตามนโยบายของตนเองได้อย่างง่ายดายโดยการจัดประเภท sub-baselines3 เนื่องจากเราจัดเตรียม wrappers สำหรับอัลกอริทึมตามนโยบายที่จัดการกับการเปิดตัวโดยใช้นโยบาย LM สภาพแวดล้อมการคำนวณรางวัล ฯลฯ ผู้ใช้จึงจำเป็นต้องใช้วิธี train() ด้วยฟังก์ชั่นการสูญเสียที่กำหนดเอง
from stable_baselines3 . common . on_policy_algorithm import OnPolicyAlgorithm
class MyOnPolicyAlgorithm ( OnPolicyAlgorithm ):
def __init__ ( ** args ):
super (). __init__ ( ** args )
def train ( self ) -> None :
# train for n_epochs epochs
for epoch in range ( self . n_epochs ):
# Do a complete pass on the rollout buffer
for rollout_data in self . rollout_buffer . get ( self . batch_size ):
# compute loss เราให้บริการนโยบายนักแสดงนักแสดงที่ใช้ LM ซึ่งห่อหุ้ม LM และ SEQ2SEQ LMS สิ่งเหล่านี้สามารถขยายได้ (สำหรับเช่น: ใช้สถาปัตยกรรมการวิจารณ์ที่แตกต่างกัน) โดยการเอาชนะวิธีการที่เหมาะสม (เช่น evaluate_actions() )
ในที่สุดเพียงลงทะเบียนส่วนประกอบที่กำหนดเองของคุณโดยเพิ่มลงในรีจิสทรีที่สอดคล้องกันหลังจากนั้นพวกเขาสามารถใช้โดยตรงจากการกำหนดค่าคล้ายกับส่วนประกอบที่กำหนดไว้ล่วงหน้า
เราได้จัดเตรียมเทมเพลต crowdsourcing ที่เราใช้กับ Mechanical Turk พร้อมกับตัวอย่างอินพุตใน scripts/crowdworking_templates คุณอาจพบว่าจุดเริ่มต้นที่เป็นประโยชน์เหล่านี้สำหรับการประเมินรุ่นรุ่นของคุณเองหรือเพื่อรวบรวมข้อมูลการฝึกอบรมสำหรับฟังก์ชั่นรางวัลที่เรียนรู้
นอกจากนี้เรายังสนับสนุนการบันทึก Wandb และการเริ่มต้นที่อบอุ่นโดยการจัดเก็บจุดตรวจและสิ่งประดิษฐ์การฝึกอบรมอื่น ๆ ในเส้นทางที่ผู้ใช้ระบุ สิ่งนี้มีประโยชน์อย่างยิ่งสำหรับการทำงานที่ได้รับการจารึกไว้ในกลุ่มขนาดใหญ่ที่กำหนด
สิ่งประดิษฐ์รวมถึง (1) ไฟล์ JSONL ที่มีการเปิดตัว infos ตามช่วงเวลาที่กำหนด (2) ไฟล์ JSONL ที่มีการฝึกอบรม infos ตามช่วงเวลาที่กำหนด (3) ไฟล์ JSONL ที่มีการตรวจสอบความถูกต้องในช่วงเวลาที่ระบุ (4) ไฟล์การทดสอบ JSON (8) config json ใช้เรียกใช้การทดสอบ
การใช้งานให้เสร็จสมบูรณ์มีดังนี้:
WANDB_API_KEY= < YOUR-WANDB-API-KEY-HERE > python scripts/training/train_text_generation.py
--config_path < PATH-TO-CONFIG-FILE >
--experiment_name < EXPERIMENT-NAME >
--base_path_to_store_results < PATH-TO-STORE-RESULTS >
--log_to_wandb @inproceedings { Ramamurthy2022IsRL ,
title = { Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization } ,
author = { Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi } ,
journal = { arXiv preprint arXiv:2210.01241 } ,
url = { https://arxiv.org/abs/2210.01241 } ,
year = { 2022 }
}สำหรับการอภิปรายคำถามการแลกเปลี่ยนความคิดเข้าร่วมช่อง Slack ของเรา