
Proporcionamos bloques de construcción fácilmente personalizables para modelos de lenguaje de capacitación que incluyen implementaciones de algoritmos en política , funciones de recompensa , métricas , conjuntos de datos y políticas críticas de actor basadas en LM
Enlace en papel: https://arxiv.org/abs/2210.01241
Enlace del sitio web: https://rl4lms.apps.allenai.org/
Completamente probado y comparado con más de 2000 experimentos (Grue Benchmill?) En un conjunto completo de:
Todos estos bloques de construcción pueden ser personalizables, lo que permite a los usuarios capacitar a LMS basados en transformadores para optimizar cualquier función de recompensa arbitraria en cualquier conjunto de datos de su elección.
git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e . También proporcionamos un DockerFile para el desarrollo utilizando contenedores Docker que contienen todas las dependencias.
docker build . -t rl4lms Opcionalmente, las bibliotecas de CorenLP son necesarias para ciertas cálculos métricos (por ejemplo, Spice) que se pueden descargar a través de cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
Proporcionamos una API de capacitación simple que se puede invocar a través del script de trenes que permite entrenar PPO, NLPO o un modelo supervisado mediante el uso de un archivo de configuración (YAML).
Por ejemplo, para entrenar la base T5 en la resumen de CNN/DM en PPO usando Rouge-1 como función de recompensa, puede ejecutar:
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.ymlLos archivos de configuración para todas las tareas se pueden encontrar aquí.
El archivo de configuración contiene detalles sobre la configuración del hiper-parámetro para los bloques de construcción que se describen a continuación:
Conjunto de datos/tarea : conjunto de datos que contiene muestras con indicaciones de entrada y oraciones de referencia. Los conjuntos de datos disponibles se encuentran en la clase DataPoolRegistry en el registro. (Vea cómo crear su propio conjunto de datos aquí)
datapool :
id : cnn_daily_mail
args :
prompt_prefix : " Summarize: "Tokenizer : un tokenizador previamente capacitado que se utiliza para (DE) tokenizar secuencias de entrada y salida con configuraciones para relleno y truncamiento
tokenizer :
model_name : t5-base
padding_side : left
truncation_side : left
pad_token_as_eos_token : False Función de recompensa : función de recompensa que calcula los puntajes de nivel de token en cada paso de tiempo de MDP. Las funciones de recompensa disponibles se pueden encontrar en la clase RewardFunctionRegistry . (Vea cómo crear su propia función de recompensa aquí)
reward_fn :
id : rouge
args :
rouge_type : " rouge1 " Entorno : configura un entorno de generación de texto estilo gimnasio que simula episodios de MDP. Los despliegue se generan utilizando muestras de tren del conjunto de datos que consiste en textos de entrada y referencia. Además, envolvemos nuestro ENV con SubProcVecEnv a partir de baselines estables que procesa episodios n_envs en paralelo usando el procesamiento múltiple para calcular recompensas por paso.
Los ajustes de configuración adicionales incluyen:
max_episode_length : longitud máxima del episodiomax_prompt_length - longitud máxima del texto de entrada a considerarterminate_on_eos : si terminar el episodio tan pronto como se realiza la acción de EOSprompt_truncation_side - Lado de truncamiento para el texto de inmediatocontext_start_token - ID para token de contexto (corresponde al token inicial dado al decodificador en modelos de codificador de codificador) env :
n_envs : 10
args :
max_prompt_length : 512
max_episode_length : 100
terminate_on_eos : True
prompt_truncation_side : " right "
context_start_token : 0ALG en política : proporcionamos implementaciones de 4 algoritmos en política: PPO, NLPO, A2C y TRPO adaptados de Baselines3 estables a medida para trabajar con tareas de PNL que se pueden usar fuera de la caja con una política causal o una política SEQ2SQE LM. (Consulte cómo crear su propio algoritmo o política en política)
También proporcionamos un entrenador supervisado para fines de evaluación comparativa. Los modelos de inicio caluroso supervisados ya están cargados en Huggingface Hub y se especifican en los archivos de configuración respectivos.
Los hiper-parametros para el algoritmo se pueden especificar en alg/args .
Además, todos los algoritmos RL usan el controlador KL adaptativo para mantener el LM cerca del LM original configurando el coeficiente KL inicial ( alg/kl_div/coeff ) y el objetivo KL ( alg/kl_div/target_kl ).
Apoyamos dos tipos de política LM: política Causal LM (para modelos de decodificadores) y política SEQ2SEQ LM (para modelos de codificadores codificadores). Además para NLPO, también proporcionamos variantes enmascarables de estos. Las implementaciones de políticas se pueden encontrar aquí y se pueden adjuntar a algoritmos especificando alg/policy/id y alg/policy/args
alg :
id : ppo
args :
n_steps : 512
batch_size : 64
verbose : 1
learning_rate : 0.000002
n_epochs : 5
ent_coef : 0.0
kl_div :
coeff : 0.001
target_kl : 0.2
policy :
id : seq2seq_lm_actor_critic_policy
args :
model_name : t5-base
apply_model_parallel : True
prompt_truncation_side : " right "
generation_kwargs :
do_sample : True
top_k : 50
min_length : 50
max_new_tokens : 100 Configuración del entrenador : proporcionamos un entrenador en la política, un envoltorio de características que entiende instancias de bloques de construcción de sus configuraciones correspondientes y proporciona un bucle de entrenamiento externo que consiste en Train y Evalations train_evaluation/n_iters .
alg/args/n_steps x env/n_envs del algoritmo elegido.eval_every , LM se evalúa en la división de validación utilizando métricas enumeradas en train_evaluation/metrics con la generación de Kwargs proporcionados en train_evaluation/generation_kwargs (esto anule el despliegue de alg/policy/generation_kwargs solo para fines de inferencia) # train and evaluation
train_evaluation :
eval_batch_size : 100
n_iters : 100
eval_every : 10
save_every : 1
metrics :
- id : meteor
args : {}
- id : rouge
- id : bleu
args : {}
- id : bert_score
args :
language : en
- id : diversity
args : {}
generation_kwargs :
do_sample : True
top_k : 0
temperature : 0.7
min_length : 50
max_new_tokens : 100RL4LMS proporciona personalización completa, con respecto a agregar nuevas tareas/conjuntos de datos, funciones de recompensa, métrica de evaluación, algoritmos en política y políticas críticas de actor.
Los usuarios pueden crear sus propios conjuntos de datos subclaseando TextGenpool simplemente anulando prepare(cls, split: str, **args) -> 'TextGenPool': Método para devolver una instancia de TextGenpool. A continuación se muestra un ejemplo:
from rl4lms . data_pools . text_generation_pool import Sample , TextGenPool
class MyDataPool ( TextGenPool ):
@ classmethod
def prepare ( cls , split : str ):
..
samples = []
for ix , item in enumerate (..):
sample = Sample ( id = f" { split } _ { ix } " ,
prompt_or_input_text = item [ "document" ],
references = [ item [ "target" ]]
)
samples . append ( sample )
pool_instance = cls ( samples )
return pool_instance Las funciones de recompensas personalizadas se pueden implementar fácilmente mediante la clasificación de la función de recompensa (un llamado) que toma la observación (
from rl4lms . envs . text_generation . observation import Observation
from rl4lms . envs . text_generation . reward import RewardFunction
class MyRewardFunction ( RewardFunction ):
def __init__ ( self , * args ) -> None :
super (). __init__ ()
def __call__ ( self , prev_observation : Observation ,
action : int ,
current_observation : Observation ,
done : bool ,
meta_info : Dict [ str , Any ] = None ) -> float :
if done :
reward = ..
return reward
return 0Además de las métricas tradicionales de NLG, para la creación de prototipos rápidos, proporcionamos dos funciones de recompensa sintética que entrena a LMS para generar números en orden creciente y generar fechas. Estos se pueden usar para probar rápidamente diferentes algoritmos y políticas. Las configuraciones correspondientes se pueden encontrar aquí (números, fechas)
Los usuarios pueden crear su propia métrica de evaluación que luego se utilizará para evaluar periódicamente el modelo en la división de validación del conjunto de datos. Esto se puede hacer mediante subclase basemétrico que toma textos rápidos, textos generados, textos de referencia, meta_infos, modelo LM actual, nombre dividido como entradas y devuelve un dict con el nombre métrico como clave y valor que consiste en la tupla de puntajes a nivel de oración y puntajes de nivel de Corpus. Un ejemplo es el siguiente:
from rl4lms . envs . text_generation . metric import BaseMetric
class MyMetric ( BaseMetric ):
def __init__ ( self ) -> None :
super (). __init__ ()
def compute ( self ,
prompt_texts : List [ str ],
generated_texts : List [ str ],
reference_texts : List [ List [ str ]],
meta_infos : List [ Dict [ str , Any ]] = None ,
model : PreTrainedModel = None ,
split_name : str = None ):
metric_dict = {
"custom_metrics/my_metric" : ([ 0.4 , 0.7 , 0.9 ], 0.7 )
}
return metric_dict Además de los algoritmos admitidos sobre la política (PPO, NLPO, A2C, TRPO), los usuarios pueden implementar sus propios algoritmos en política con facilidad al sub-clasificando el Policyalgoritmo Onpolicyalgoritmo de Onpolicyalins3. Dado que proporcionamos envoltorios para algoritmos en política que manejan los despliegos utilizando políticas de LM, entorno, recompensas de computación, etc., los usuarios solo necesitan implementar el método train() con funciones de pérdida personalizadas.
from stable_baselines3 . common . on_policy_algorithm import OnPolicyAlgorithm
class MyOnPolicyAlgorithm ( OnPolicyAlgorithm ):
def __init__ ( ** args ):
super (). __init__ ( ** args )
def train ( self ) -> None :
# train for n_epochs epochs
for epoch in range ( self . n_epochs ):
# Do a complete pass on the rollout buffer
for rollout_data in self . rollout_buffer . get ( self . batch_size ):
# compute loss Proporcionamos implementaciones de políticas críticas basadas en LM que envuelven LM causales y SEQ2SEQ LMS. Estos también se pueden extender (por ejemplo: usar una arquitectura crítica diferente) anulando los métodos apropiados (por ejemplo, evaluate_actions() ))
Finalmente, simplemente registre sus componentes personalizados agregándolos al registro correspondiente, después de lo cual se pueden usar directamente desde configuraciones similares a los componentes predefinidos
Hemos proporcionado las plantillas de crowdsourcing que utilizamos en Mechanical Turk, junto con entradas de ejemplo en scripts/crowdworking_templates . Es posible que estos sean un punto de partida útil, ya sea para evaluar las generaciones de su propio modelo o para recopilar datos de capacitación para una función de recompensa aprendida.
Además, apoyamos el registro WANDB y el inicio cálido de la capacitación almacenando puntos de control y otros artefactos de capacitación en una ruta especificada por el usuario. Esto es especialmente útil para ejecutar trabajos preventibles en grupos grandes y programados.
Artifacts include (1) jsonl file containing rollout infos at specified intervals (2) jsonl file containing training infos at specified intervals (3) jsonl file containing validation metrics at specified intervals (4) jsonl file containing test metrics before and after training (5) json file with validation predictions at specified intervals (6) json file with test predictions before and after training (7) trained LM model (8) config json usado para ejecutar el experimento
El uso completo es el siguiente:
WANDB_API_KEY= < YOUR-WANDB-API-KEY-HERE > python scripts/training/train_text_generation.py
--config_path < PATH-TO-CONFIG-FILE >
--experiment_name < EXPERIMENT-NAME >
--base_path_to_store_results < PATH-TO-STORE-RESULTS >
--log_to_wandb @inproceedings { Ramamurthy2022IsRL ,
title = { Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization } ,
author = { Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi } ,
journal = { arXiv preprint arXiv:2210.01241 } ,
url = { https://arxiv.org/abs/2210.01241 } ,
year = { 2022 }
}Para discusión, preguntas, intercambio de ideas, únase a nuestro canal Slack