
Nous fournissons des éléments constitutifs facilement personnalisables pour la formation de modèles de langage, y compris les implémentations d' algorithmes de politique , les fonctions de récompense , les métriques , les ensembles de données et les politiques acteurs-critiques basées sur LM
Lien papier: https://arxiv.org/abs/2210.01241
Lien de site Web: https://rl4lms.apps.allenai.org/
Testé entièrement et comparé avec plus de 2000 expériences (Grue Benchmark?) Sur un ensemble complet de:
Tous ces blocs de construction peuvent être personnalisables permettant aux utilisateurs de former des LM basés sur le transformateur pour optimiser toute fonction de récompense arbitraire sur tout ensemble de données de leur choix.
git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e . Nous fournissons également un Dockerfile pour le développement à l'aide de conteneurs Docker contenant toutes les dépendances.
docker build . -t rl4lms Facultativement, les bibliothèques corenlp sont requises pour certains calculs métriques (par exemple, Spice) qui peuvent être téléchargés via cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
Nous fournissons une API de formation simple qui peut être invoquée via le script de train qui permet de former PPO, NLPO ou un modèle supervisé en utilisant un fichier de configuration (YAML).
Par exemple, pour entraîner la base T5 sur la résumé CNN / DM sur PPO en utilisant Rouge-1 comme fonction de récompense, vous pouvez exécuter:
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.ymlLes fichiers de configuration pour toutes les tâches peuvent être trouvés ici.
Le fichier de configuration contient des détails sur les paramètres hyper-paramètres pour les blocs de construction décrits ci-dessous:
Ensemble de données / tâche : ensemble de données contenant des échantillons avec des invites d'entrée et des phrases de référence. Les ensembles de données disponibles se trouvent dans la classe DataPoolRegistry dans le registre. (Voir comment créer votre propre ensemble de données ici)
datapool :
id : cnn_daily_mail
args :
prompt_prefix : " Summarize: "Tokenizer - Un jetons pré-formé qui est utilisé pour (DE) des séquences d'entrée et de sortie de tokenize avec des paramètres pour le rembourrage et la troncature
tokenizer :
model_name : t5-base
padding_side : left
truncation_side : left
pad_token_as_eos_token : False Fonction de récompense : fonction de récompense qui calcule les scores de niveau de jeton à chaque pas de temps du MDP. Les fonctions de récompense disponibles peuvent être trouvées dans la classe RewardFunctionRegistry . (Voir comment créer votre propre fonction de récompense ici)
reward_fn :
id : rouge
args :
rouge_type : " rouge1 " Environnement : configure un environnement de génération de texte de style gym qui simule les épisodes MDP. Les déploiements sont générés à l'aide d'échantillons de train à partir d'un ensemble de données composé de textes d'entrée et de référence. De plus, nous enroulons notre Env avec SubProcVecEnv -procrovecenv à partir de basélines stables qui traite les épisodes n_envs en parallèle en utilisant le multi-processus pour calculer les récompenses étapes.
D'autres paramètres de configuration incluent:
max_episode_length : longueur maximale de l'épisodemax_prompt_length - longueur maximale du texte d'entrée à considérerterminate_on_eos - s'il faut terminer l'épisode dès que l'action EOS est effectuéeprompt_truncation_side - côté troncature pour le texte rapidecontext_start_token - id pour le jeton de contexte (correspond au jeton initial donné au décodeur dans des modèles d'encodeur-décodeur) env :
n_envs : 10
args :
max_prompt_length : 512
max_episode_length : 100
terminate_on_eos : True
prompt_truncation_side : " right "
context_start_token : 0ON-POLICY ALG : Nous fournissons des implémentations de 4 algorithmes de politique: PPO, NLPO, A2C et TRPO adaptés à partir de Basélines stables3 adaptés à des tâches NLP qui peuvent être utilisées hors de la boîte avec une politique causale ou une politique SEQP LM. (Voir comment créer votre propre algorithme ou politique de politique)
Nous fournissons également un entraîneur supervisé à des fins d'analyse comparative. Les modèles de démarrage à chaud supervisé sont déjà téléchargés sur HuggingFace Hub et spécifiés dans les fichiers de configuration respectifs.
Les hyper-paramètres pour l'algorithme peuvent être spécifiés à alg/args .
De plus, tous les algorithmes RL utilisent le contrôleur KL adaptatif pour garder le LM près de LM d'origine en définissant KL initial KL ( alg/kl_div/coeff ) et Target KL ( alg/kl_div/target_kl ).
Nous prenons en charge deux types de politique LM: la politique LM causale (pour les modèles de décodeur uniquement) et la politique SEQ2SEQ LM (pour les modèles d'encodeur-décodeur). De plus pour NLPO, nous en fournissons également des variantes masquées. Les implémentations de stratégie peuvent être trouvées ici et peuvent être jointes aux algorithmes en spécifiant alg/policy/id et alg/policy/args
alg :
id : ppo
args :
n_steps : 512
batch_size : 64
verbose : 1
learning_rate : 0.000002
n_epochs : 5
ent_coef : 0.0
kl_div :
coeff : 0.001
target_kl : 0.2
policy :
id : seq2seq_lm_actor_critic_policy
args :
model_name : t5-base
apply_model_parallel : True
prompt_truncation_side : " right "
generation_kwargs :
do_sample : True
top_k : 50
min_length : 50
max_new_tokens : 100 Configuration du formateur : Nous fournissons un entraîneur sur la politique - un emballage complet de fonctionnalités qui instancie des éléments constitutifs de leurs configurations correspondantes et fournit une boucle de formation extérieure composée de train et d'évaluation train_evaluation/n_iters .
alg/args/n_steps X env/n_envs de l'algorithme choisi.eval_every ITERS, LM est évalué sur la fraction de validation à l'aide de mesures répertoriées dans train_evaluation/metrics avec Génération Kwargs fournies dans train_evaluation/generation_kwargs (cela remplace le déploiement alg/policy/generation_kwargs à des fins d'inférence uniquement) # train and evaluation
train_evaluation :
eval_batch_size : 100
n_iters : 100
eval_every : 10
save_every : 1
metrics :
- id : meteor
args : {}
- id : rouge
- id : bleu
args : {}
- id : bert_score
args :
language : en
- id : diversity
args : {}
generation_kwargs :
do_sample : True
top_k : 0
temperature : 0.7
min_length : 50
max_new_tokens : 100RL4LMS offre une personnalisation complète - en ce qui concerne l'ajout de nouvelles tâches / ensembles de données, des fonctions de récompense, des métriques d'évaluation, des algorithmes de politique et des politiques acteurs-critiques.
Les utilisateurs peuvent créer leurs propres ensembles de données en sous-classe TextGenpool simplement en dépassant prepare(cls, split: str, **args) -> 'TextGenPool': Méthode pour renvoyer une instance de TextGenpool. Un exemple est indiqué ci-dessous:
from rl4lms . data_pools . text_generation_pool import Sample , TextGenPool
class MyDataPool ( TextGenPool ):
@ classmethod
def prepare ( cls , split : str ):
..
samples = []
for ix , item in enumerate (..):
sample = Sample ( id = f" { split } _ { ix } " ,
prompt_or_input_text = item [ "document" ],
references = [ item [ "target" ]]
)
samples . append ( sample )
pool_instance = cls ( samples )
return pool_instance Les finances de récompense personnalisées peuvent être mises en œuvre facilement par la récompense de sous-classe (un appelable) qui prend l'observation (
from rl4lms . envs . text_generation . observation import Observation
from rl4lms . envs . text_generation . reward import RewardFunction
class MyRewardFunction ( RewardFunction ):
def __init__ ( self , * args ) -> None :
super (). __init__ ()
def __call__ ( self , prev_observation : Observation ,
action : int ,
current_observation : Observation ,
done : bool ,
meta_info : Dict [ str , Any ] = None ) -> float :
if done :
reward = ..
return reward
return 0En plus des métriques NLG traditionnelles, pour le prototypage rapide, nous fournissons deux fonctions de récompense synthétique qui forme LMS à générer des nombres dans l'ordre croissant et générer des dates. Ceux-ci peuvent être utilisés pour tester rapidement différents algorithmes et politiques. Les configurations correspondantes peuvent être trouvées ici (nombres, dates)
Les utilisateurs peuvent créer leur propre métrique d'évaluation qui sera ensuite utilisée pour évaluer périodiquement le modèle sur la répartition de la validation de l'ensemble de données. Cela peut être fait par sous-classe BaseMetric qui prend des textes rapides, des textes générés, des textes de référence, Meta_infos, un modèle LM actuel, un nom de division en entrées et renvoie un dict avec le nom de métrique comme clé et valeur composée de tuple de scores au niveau de la phrase et de scores de niveau corpus. Un exemple est le suivant:
from rl4lms . envs . text_generation . metric import BaseMetric
class MyMetric ( BaseMetric ):
def __init__ ( self ) -> None :
super (). __init__ ()
def compute ( self ,
prompt_texts : List [ str ],
generated_texts : List [ str ],
reference_texts : List [ List [ str ]],
meta_infos : List [ Dict [ str , Any ]] = None ,
model : PreTrainedModel = None ,
split_name : str = None ):
metric_dict = {
"custom_metrics/my_metric" : ([ 0.4 , 0.7 , 0.9 ], 0.7 )
}
return metric_dict En plus des algorithmes de politique pris en charge (PPO, NLPO, A2C, TRPO), les utilisateurs peuvent mettre en œuvre leurs propres algorithmes de politique avec le sous-classement de l'OnPolithgorithme de l'OnPolicygorithme de Stable-Baselines3. Étant donné que nous fournissons des emballages pour les algorithmes de politique qui gère les déploiements à l'aide de politiques LM, d'environnement, de récompenses informatiques, etc., les utilisateurs ont juste besoin d'implémenter la méthode train() avec des fonctions de perte personnalisées.
from stable_baselines3 . common . on_policy_algorithm import OnPolicyAlgorithm
class MyOnPolicyAlgorithm ( OnPolicyAlgorithm ):
def __init__ ( ** args ):
super (). __init__ ( ** args )
def train ( self ) -> None :
# train for n_epochs epochs
for epoch in range ( self . n_epochs ):
# Do a complete pass on the rollout buffer
for rollout_data in self . rollout_buffer . get ( self . batch_size ):
# compute loss Nous fournissons des implémentations de politique acteur-critique basées sur LM qui enveloppent la causalité LM et SEQ2SEQ LMS. Ceux-ci peuvent également être étendus (pour EG: Utilisez une architecture critique différente) en dépassant les méthodes appropriées (par exemple evaluate_actions()
Enfin, enregistrez simplement vos composants personnalisés en les ajoutant au registre correspondant, après quoi ils peuvent être utilisés directement à partir de configurations similaires aux composants prédéfinis
Nous avons fourni les modèles de crowdsourcing que nous avons utilisés sur Mechanical Turk, ainsi que des exemples d'entrées dans scripts/crowdworking_templates . Vous pourriez trouver ceux-ci un point de départ utile soit pour évaluer les générations de votre propre modèle, soit pour la collecte de données de formation pour une fonction de récompense apprise.
De plus, nous prenons en charge la journalisation WANDB et le démarrage chaleureux de la formation en stockant des points de contrôle et d'autres artefacts de formation sur un chemin spécifié par l'utilisateur. Ceci est particulièrement utile pour gérer des travaux préemptibles sur de grands grappes programmées.
Les artefacts incluent (1) le fichier JSONL contenant des INFO de déploiement à intervalles spécifiés (2) Fichier JSONL contenant des infos de formation à intervalles spécifiés (3) Fichier JSONL contenant des métriques de validation à des intervalles spécifiés (4) Fichier JSONL contenant des métriques de test avant et après la formation (5) Fichier JSON avec des prévisions de validation à l'entraînement Spécifié (6) Fichier JSON avec des prédictions de test précédemment et après une formation. (8) Config JSON a utilisé pour exécuter l'expérience
L'utilisation complète est la suivante:
WANDB_API_KEY= < YOUR-WANDB-API-KEY-HERE > python scripts/training/train_text_generation.py
--config_path < PATH-TO-CONFIG-FILE >
--experiment_name < EXPERIMENT-NAME >
--base_path_to_store_results < PATH-TO-STORE-RESULTS >
--log_to_wandb @inproceedings { Ramamurthy2022IsRL ,
title = { Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization } ,
author = { Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi } ,
journal = { arXiv preprint arXiv:2210.01241 } ,
url = { https://arxiv.org/abs/2210.01241 } ,
year = { 2022 }
}Pour la discussion, les questions, l'échange d'idées, rejoignez notre canal Slack