TRLX es un marco de capacitación distribuido diseñado desde cero para enfocarse en ajustar modelos de idiomas grandes con aprendizaje de refuerzo utilizando una función de recompensa proporcionada o un conjunto de datos marcado con recompensas.
¿Apoyo de capacitación para? Los entrenadores con respaldo de Accelerate proporcionan los modelos de Face, que permiten a los usuarios ajustar modelos de lenguaje causales y basados en T5 de parámetros de hasta 20B, como facebook/opt-6.7b , EleutherAI/gpt-neox-20b y google/flan-t5-xxl . Para los modelos más allá de los parámetros de 20B, TRLX proporciona entrenadores respaldados por NVIDIA NEMO que aprovechan las técnicas de paralelismo eficientes para escalar de manera efectiva.
Actualmente se implementan los siguientes algoritmos RL:
| Algoritmo | Acelerar el entrenador | Entrenador nemo |
|---|---|---|
| Optimización de políticas proximales (PPO) | ✅ | ✅ |
| Lenguaje implícito Q-Learning (ILQL) | ✅ | ✅ |
Documentación
? Cheese recopila anotaciones humanas para su aplicación RL con nuestra biblioteca de recopilación de datos humanos en el bucle.
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . Para más uso, consulte ejemplos. También puede probar los cuadernos Colab a continuación:
| Descripción | Enlace |
|---|---|
| Simulacra (GPT2, ILQL) | |
| Sentimiento (GPT2, ILQL) |
Las últimas carreras de los ejemplos están en nuestros pesos y prejuicios
Puede entrenar un modelo utilizando una función de recompensa o un conjunto de datos marcado con recompensas.
trainer = trlx . train ( 'gpt2' , reward_fn = lambda samples , ** kwargs : [ sample . count ( 'cats' ) for sample in samples ])Para el entrenamiento modelo de recompensa, consulte nuestra Biblioteca Autocrit.
trainer = trlx . train ( 'EleutherAI/gpt-j-6B' , samples = [ 'dolphins' , 'geese' ], rewards = [ 1.0 , 100.0 ]) trainer = trlx . train ( 'gpt2' , samples = [[ 'Question: 1 + 2 Answer:' , '3' ], [ 'Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:' , '(pi ** 2)/ 6' ]]) trainer . generate ( ** tokenizer ( 'Q: Who rules the world? A:' , return_tensors = 'pt' ), do_sample = True ) from trlx . data . default_configs import default_ppo_config
config = default_ppo_config ()
config . model . model_path = 'EleutherAI/gpt-neox-20b'
config . tokenizer . tokenizer_path = 'EleutherAI/gpt-neox-20b'
config . train . seq_length = 2048
trainer = trlx . train ( config = config , reward_fn = lambda samples , ** kwargs : [ len ( sample ) for sample in samples ])Para reducir el uso de la memoria (si está experimentando CUDA fuera de los errores de memoria), primero pruebe la configuración más baja para los siguientes hiperparámetros y eventualmente aumente:
# micro batch size per gpu
config . train . batch_size = 1
# freeze all transformer layers
config . model . num_layers_unfrozen = 0
# maximum sample length, prompts or samples longer than that will be truncated
config . train . seq_length = 128
# micro batch size for sampling (specific for PPO)
config . method . chunk_size = 1
# use an additional Q-head (specific for ILQL)
config . method . two_qs = False trainer . save_pretrained ( '/path/to/output/folder/' )accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.pySiga las instrucciones de configuración en Nemo ReadMe.
python examples/nemo_ilql_sentiments.pyPara más uso, consulte el readme nemo
ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.pymain de TRLX python -m trlx.reference octocat/trlx-fork:fix-branch TRLX utiliza la biblioteca logging de Python estándar para registrar la información de capacitación en la consola. El registrador predeterminado se establece en el nivel INFO , lo que significa que INFO , WARNING , ERROR y los mensajes de nivel CRITICAL se imprimirán en la salida estándar.
Para cambiar el nivel de registro directamente, puede usar el setter de verbosidad. Por ejemplo, para establecer el nivel de registro en el uso WARNING :
import trlx
trlx . logging . set_verbosity ( trlx . logging . WARNING ) Esto suprimirá los mensajes de nivel INFO , pero aún imprimirá WARNING , ERROR y los mensajes de nivel CRITICAL .
También puede controlar la verbosidad de registro estableciendo la variable de entorno TRLX_VERBOSITY en uno de los nombres de nivel de registro estándar:
CRITICAL ( trlx.logging.CRITICAL )ERROR ( trlx.logging.ERROR )WARNING ( trlx.logging.WARNING )INFO ( trlx.logging.INFO )DEBUG ( trlx.logging.DEBUG ) export TRLX_VERBOSITY=WARNING Por defecto, las barras de progreso tqdm se utilizan para mostrar el progreso de la capacitación. Puede deshabilitarlos llamando trlx.logging.disable_progress_bar() , de lo contrario trlx.logging.enable_progress_bar() para habilitar.
Los mensajes se pueden formatear con mayor detalle configurando trlx.logging.enable_explicit_format() . Esto inyectará información del sitio de llamadas en cada registro que puede ser útil para la depuración.
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...Consejo: Para reducir la cantidad de salida de registro, puede que sea útil cambiar los niveles de registro de las bibliotecas de terceros utilizadas por TRLX. Por ejemplo, intente agregar
transformers.logging.set_verbosity_error()en la parte superior de sus scripts TRLX para silenciar los mensajes verbosos de la bibliotecatransformers(consulte sus documentos de registro para obtener más detalles).
Para el desarrollo, consulte estas pautas y también lea nuestros documentos
@inproceedings{havrilla-etal-2023-trlx,
title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
author = "Havrilla, Alexander and
Zhuravinskyi, Maksym and
Phung, Duy and
Tiwari, Aman and
Tow, Jonathan and
Biderman, Stella and
Anthony, Quentin and
Castricato, Louis",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.530",
doi = "10.18653/v1/2023.emnlp-main.530",
pages = "8578--8595",
}
Muchas gracias a Leandro von Werra por contribuir con TRL, una biblioteca que inicialmente inspiró este repositorio.