O TRLX é uma estrutura de treinamento distribuída projetada desde o início para se concentrar em modelos de linguagem de grande ajuste com aprendizado de reforço usando uma função de recompensa fornecida ou um conjunto de dados marcado com recompensa.
Suporte de treinamento para? Os modelos de face abraçados são fornecidos por treinadores apoiados por acelerar, permitindo que os usuários ajustem modelos de linguagem causal e baseados em T5 de parâmetros de até 20B, como facebook/opt-6.7b , EleutherAI/gpt-neox-20b e google/flan-t5-xxl . Para modelos além dos parâmetros 20B, o TRLX fornece treinadores de NVIDIA Nemo-Backed que alavancam técnicas de paralelismo eficientes para escalar efetivamente.
Atualmente, os seguintes algoritmos RL são implementados:
| Algoritmo | Acelere o treinador | Treinador Nemo |
|---|---|---|
| Otimização de política proximal (PPO) | ✅ | ✅ |
| Linguagem implícita q-learning (ILQL) | ✅ | ✅ |
Documentação
? O queijo coleta anotações humanas para o seu aplicativo RL com nossa biblioteca de coleta de dados humana no loop.
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . Para mais uso, consulte exemplos. Você também pode experimentar os notebooks do Colab abaixo:
| Descrição | Link |
|---|---|
| Simulacra (GPT2, ILQL) | |
| Sentimento (GPT2, ILQL) |
As últimas execuções dos exemplos estão em nossos pesos e preconceitos
Você pode treinar um modelo usando uma função de recompensa ou um conjunto de dados marcado com recompensa.
trainer = trlx . train ( 'gpt2' , reward_fn = lambda samples , ** kwargs : [ sample . count ( 'cats' ) for sample in samples ])Para o treinamento do modelo de recompensa , consulte nossa biblioteca de autócritas.
trainer = trlx . train ( 'EleutherAI/gpt-j-6B' , samples = [ 'dolphins' , 'geese' ], rewards = [ 1.0 , 100.0 ]) trainer = trlx . train ( 'gpt2' , samples = [[ 'Question: 1 + 2 Answer:' , '3' ], [ 'Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:' , '(pi ** 2)/ 6' ]]) trainer . generate ( ** tokenizer ( 'Q: Who rules the world? A:' , return_tensors = 'pt' ), do_sample = True ) from trlx . data . default_configs import default_ppo_config
config = default_ppo_config ()
config . model . model_path = 'EleutherAI/gpt-neox-20b'
config . tokenizer . tokenizer_path = 'EleutherAI/gpt-neox-20b'
config . train . seq_length = 2048
trainer = trlx . train ( config = config , reward_fn = lambda samples , ** kwargs : [ len ( sample ) for sample in samples ])Para reduzir o uso da memória (se você estiver experimentando CUDA fora dos erros de memória), primeiro tente a configuração mais baixa para os seguintes hiperparâmetros e, eventualmente, aumente -os:
# micro batch size per gpu
config . train . batch_size = 1
# freeze all transformer layers
config . model . num_layers_unfrozen = 0
# maximum sample length, prompts or samples longer than that will be truncated
config . train . seq_length = 128
# micro batch size for sampling (specific for PPO)
config . method . chunk_size = 1
# use an additional Q-head (specific for ILQL)
config . method . two_qs = False trainer . save_pretrained ( '/path/to/output/folder/' )accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.pySiga as instruções de configuração no Nemo ReadMe.
python examples/nemo_ilql_sentiments.pyPara mais uso, consulte o Nemo ReadMe
ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.pymain do TRLX python -m trlx.reference octocat/trlx-fork:fix-branch O TRLX usa a biblioteca logging Python padrão para registrar informações de treinamento no console. O logger padrão está definido para o nível INFO , o que significa que INFO , WARNING , ERROR e mensagens de nível CRITICAL serão impressas na saída padrão.
Para alterar o nível de log diretamente, você pode usar o conjunto de verbosidade. Por exemplo, para definir o nível de log para uso WARNING :
import trlx
trlx . logging . set_verbosity ( trlx . logging . WARNING ) Isso suprimirá mensagens de nível INFO , mas ainda imprimirá WARNING , ERROR e mensagens de nível CRITICAL .
Você também pode controlar a verbosidade do registro definindo a variável de ambiente TRLX_VERBOSITY em um dos nomes de nível de registro padrão:
CRITICAL ( trlx.logging.CRITICAL )ERROR ( trlx.logging.ERROR )WARNING ( trlx.logging.WARNING )INFO ( trlx.logging.INFO )DEBUG ( trlx.logging.DEBUG ) export TRLX_VERBOSITY=WARNING Por padrão, as barras de progresso tqdm são usadas para exibir o progresso do treinamento. Você pode desativá -los chamando trlx.logging.disable_progress_bar() , caso contrário, trlx.logging.enable_progress_bar() para ativar.
As mensagens podem ser formatadas com mais detalhes definindo trlx.logging.enable_explicit_format() . Isso injetará informações do site de chamada em cada log que podem ser úteis para a depuração.
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...Dica: para reduzir a quantidade de saída de log, você pode achar útil alterar os níveis de log de bibliotecas de terceiros usadas pelo TRLX. Por exemplo, tente adicionar
transformers.logging.set_verbosity_error()à parte superior dos seus scripts TRLX para silenciar mensagens detalhadas da bibliotecatransformers(consulte seus documentos de registro para obter mais detalhes).
Para o desenvolvimento, consulte essas diretrizes e também leia nossos documentos
@inproceedings{havrilla-etal-2023-trlx,
title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
author = "Havrilla, Alexander and
Zhuravinskyi, Maksym and
Phung, Duy and
Tiwari, Aman and
Tow, Jonathan and
Biderman, Stella and
Anthony, Quentin and
Castricato, Louis",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.530",
doi = "10.18653/v1/2023.emnlp-main.530",
pages = "8578--8595",
}
Muito obrigado a Leandro von Werra por contribuir com a TRL, uma biblioteca que inicialmente inspirou este repositório.