TRLX는 제공된 보상 기능 또는 보상 표시 데이터 세트를 사용하여 강화 학습을 통해 대형 언어 모델을 미세 조정하는 데 중점을두기 위해 처음부터 설계된 분산 교육 프레임 워크입니다.
교육 지원? Hugging Face 모델은 Accelerate 지원 트레이너가 제공하여 사용자가 facebook/opt-6.7b , EleutherAI/gpt-neox-20b 및 google/flan-t5-xxl 과 같은 최대 20B 매개 변수의 인과 및 T5 기반 언어 모델을 미세 조정할 수 있습니다. 20B 매개 변수 이외의 모델의 경우 TRLX는 효율적인 병렬 처리 기술을 활용하여 효과적으로 확장하는 NVIDIA NEMO 지원 트레이너를 제공합니다.
다음 RL 알고리즘이 현재 구현되었습니다.
| 연산 | 트레이너를 가속화하십시오 | NEMO 트레이너 |
|---|---|---|
| 근위 정책 최적화 (PPO) | ✅ | ✅ |
| 암시 적 언어 Q- 러닝 (ILQL) | ✅ | ✅ |
선적 서류 비치
? 치즈는 인간의 루프 데이터 수집 라이브러리를 사용하여 RL 애플리케이션을위한 사람의 주석을 수집합니다.
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . 더 많은 사용법은 예를 참조하십시오. 아래의 Colab 노트북을 사용해 볼 수도 있습니다.
| 설명 | 링크 |
|---|---|
| Simulacra (GPT2, ILQL) | |
| 감정 (GPT2, ILQL) |
예제의 최신 실행은 우리의 웨이트 및 바이어스에 있습니다.
보상 함수 또는 보상 표시 데이터 세트를 사용하여 모델을 교육 할 수 있습니다.
trainer = trlx . train ( 'gpt2' , reward_fn = lambda samples , ** kwargs : [ sample . count ( 'cats' ) for sample in samples ])보상 모델 교육은 Autocrit Library를 참조하십시오.
trainer = trlx . train ( 'EleutherAI/gpt-j-6B' , samples = [ 'dolphins' , 'geese' ], rewards = [ 1.0 , 100.0 ]) trainer = trlx . train ( 'gpt2' , samples = [[ 'Question: 1 + 2 Answer:' , '3' ], [ 'Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:' , '(pi ** 2)/ 6' ]]) trainer . generate ( ** tokenizer ( 'Q: Who rules the world? A:' , return_tensors = 'pt' ), do_sample = True ) from trlx . data . default_configs import default_ppo_config
config = default_ppo_config ()
config . model . model_path = 'EleutherAI/gpt-neox-20b'
config . tokenizer . tokenizer_path = 'EleutherAI/gpt-neox-20b'
config . train . seq_length = 2048
trainer = trlx . train ( config = config , reward_fn = lambda samples , ** kwargs : [ len ( sample ) for sample in samples ])메모리 사용량을 줄이려면 (CUDA가 메모리 오류에서 발생하는 경우) 먼저 다음과 피파리메이터에서 가장 낮은 설정을 시도하고 결국 증가시킵니다.
# micro batch size per gpu
config . train . batch_size = 1
# freeze all transformer layers
config . model . num_layers_unfrozen = 0
# maximum sample length, prompts or samples longer than that will be truncated
config . train . seq_length = 128
# micro batch size for sampling (specific for PPO)
config . method . chunk_size = 1
# use an additional Q-head (specific for ILQL)
config . method . two_qs = False trainer . save_pretrained ( '/path/to/output/folder/' )accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.pyNemo Readme의 설정 지침을 따르십시오.
python examples/nemo_ilql_sentiments.py더 많은 사용은 Nemo Readme를 참조하십시오
ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.pymain 브랜치에 대해 TRLX 포크를 벤치마킹하십시오 python -m trlx.reference octocat/trlx-fork:fix-branch TRLX는 표준 파이썬 logging 라이브러리를 사용하여 콘솔에 교육 정보를 기록합니다. 기본 로거가 INFO 수준으로 설정되어 있으므로 INFO , WARNING , ERROR 및 CRITICAL 레벨 메시지가 표준 출력으로 인쇄됩니다.
로그 레벨을 직접 변경하려면 Verbosity Setter를 사용할 수 있습니다. 예를 들어 로그 레벨을 WARNING 사용으로 설정하려면 다음과 같습니다.
import trlx
trlx . logging . set_verbosity ( trlx . logging . WARNING ) 이것은 INFO 수준 메시지를 억제하지만 여전히 WARNING , ERROR 및 CRITICAL 레벨 메시지를 인쇄합니다.
TRLX_VERBOSITY 환경 변수를 표준 로깅 레벨 이름 중 하나로 설정하여 로깅 제어를 제어 할 수도 있습니다.
CRITICAL ( trlx.logging.CRITICAL )ERROR ( trlx.logging.ERROR )WARNING ( trlx.logging.WARNING )INFO ( trlx.logging.INFO )DEBUG ( trlx.logging.DEBUG ) export TRLX_VERBOSITY=WARNING 기본적으로 tqdm 진행 막대는 교육 진행 상황을 표시하는 데 사용됩니다. trlx.logging.disable_progress_bar() trlx.logging.enable_progress_bar() 하여 비활성화 할 수 있습니다.
trlx.logging.enable_explicit_format() 설정하여 메시지를 더 자세히 서식 할 수 있습니다. 이렇게하면 콜 사이트 정보를 각 로그에 주입하여 디버깅에 도움이 될 수 있습니다.
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...팁 : 로깅 출력량을 줄이려면 TRLX에서 사용하는 타사 라이브러리의 로그 레벨을 변경하는 데 도움이 될 수 있습니다. 예를 들어,
transformersLibrary의 Verbose 메시지를 침묵시키기 위해 TRLX 스크립트 상단에transformers.logging.set_verbosity_error()추가하십시오 (자세한 내용은 로깅 문서 참조).
개발을 위해이 지침을 확인하고 문서도 읽으십시오.
@inproceedings{havrilla-etal-2023-trlx,
title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
author = "Havrilla, Alexander and
Zhuravinskyi, Maksym and
Phung, Duy and
Tiwari, Aman and
Tow, Jonathan and
Biderman, Stella and
Anthony, Quentin and
Castricato, Louis",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.530",
doi = "10.18653/v1/2023.emnlp-main.530",
pages = "8578--8595",
}
이 저장소에 처음 영감을 준 라이브러리 인 TRL에 기여한 Leandro von Werra에게 감사드립니다.