LLM-RLHF-Tuning
This project implements the three-stage RLHF training pipeline (SFT, reward modeling, PPO) from scratch and documents the implementation details thoroughly. Everyone is welcome to join the WeChat group to discuss.
Main features:
- Supports instruction fine-tuning an Alpaca model (SFT)
- Supports reward model (RM) training
- Supports training an RL model with the PPO algorithm
- Supports PPO with two base models and two LoRA adapters, loading all four models (RM, SFT, Actor, Critic) at the same time, with accelerate distributed training (PPO implementation details)
- Supports PPO with one base model and two LoRA adapters, loading all four models (RM, SFT, Actor, Critic) at the same time, with accelerate and deepspeed training
- Supports PPO with one base model and one LoRA adapter, where the Actor and Critic share the base model and all four roles (RM, SFT, Actor, Critic) are served by it, with accelerate and deepspeed training (see the sketch after this list)
- Supports training a model with the DPO algorithm
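Below is a minimal sketch of the shared-base-model idea behind the single-base-model configurations: two LoRA adapters attached to one base LLM, with the Actor and Critic as separate adapters and the frozen base weights standing in for the SFT/reference model. This is an illustrative example only, not the project's actual code; the checkpoint name and LoRA hyperparameters are placeholders.

```python
# Illustrative only: one base model + two LoRA adapters sharing its weights.
# The checkpoint name and LoRA hyperparameters below are placeholders.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder base checkpoint
    torch_dtype=torch.float16,
)

lora_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]
)

# Attach an "actor" adapter, then add a second "critic" adapter; both share the
# frozen base weights, so roughly one 7B model is resident in memory.
model = get_peft_model(base, lora_cfg, adapter_name="actor")
model.add_adapter("critic", lora_cfg)

# Switch roles by activating the corresponding adapter.
model.set_adapter("actor")   # policy (Actor) forward passes
model.set_adapter("critic")  # value (Critic) forward passes (a value head is added in practice)

# Disabling all adapters recovers the frozen base, which can serve as the
# SFT/reference model when computing the KL penalty.
with model.disable_adapter():
    pass  # reference-model log-probs would be computed here
```

Sharing the base weights this way is what keeps the PPO memory footprint close to a single 7B model, as reflected in the comparison table below.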
Updates
- [23/8/23] Supports LLaMA2 model training; supports DPO training; supports PPO training with one base model and either one or two LoRA adapters; supports accelerate and deepspeed training
- [23/8/13] Supports LLaMA model training; supports PPO training with two base models and two LoRA adapters; supports accelerate distributed training
Features
Feature comparison with open-source RLHF training frameworks
| Framework | SFT Train | RM Train | PPO Train | DPO Train |
|---|---|---|---|---|
| Ours | ✅ | ✅ | ✅ | ✅ |
| DeepSpeed-Chat | ✅ | ✅ | ✅ | |
| trl | ✅ | ✅ | ✅ | ✅ |
| MOSS-RLHF | | | ✅ | |
PPO Training
| Framework | Accelerate | DeepSpeed | Multi-LoRA | Minimum model parameters (7B base as an example) |
|---|---|---|---|---|
| Ours | ✅ | ✅ | ✅ | single model size ~ 7B |
| DeepSpeed-Chat | | ✅ | | sft + rm + actor + critic ~ 28B |
| trl | ✅ | | | single model size (without a ref model) ~ 7B |
| MOSS-RLHF | actor, critic models | sft, rm models | | sft + rm + actor + critic ~ 28B |
Usage guide
Environment setup
```
accelerate==0.21.0
datasets==2.13.1
scikit-learn==1.3.0
sentencepiece==0.1.99
tqdm==4.65.0
transformers==4.31.0
wandb==0.15.8
peft==0.4.0
torch==2.0.1
trl==0.5.0
deepspeed==0.10.0
```
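As a quick sanity check (a hypothetical helper, not part of this repository), the pinned versions above can be compared against the current Python environment:

```python
# Hypothetical helper (not part of this repo): report whether the installed
# package versions match the pinned versions listed above.
import importlib.metadata as md

pinned = {
    "accelerate": "0.21.0",
    "datasets": "2.13.1",
    "scikit-learn": "1.3.0",
    "sentencepiece": "0.1.99",
    "tqdm": "4.65.0",
    "transformers": "4.31.0",
    "wandb": "0.15.8",
    "peft": "0.4.0",
    "torch": "2.0.1",
    "trl": "0.5.0",
    "deepspeed": "0.10.0",
}

for pkg, want in pinned.items():
    try:
        have = md.version(pkg)
    except md.PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED (pinned {want})")
        continue
    print(f"{pkg}: {have}" + ("" if have == want else f" (pinned {want})"))
```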
Supported models
Supported training methods
Training details
Instruction fine-tuning the model
Training the reward model
PPO training
DPO training
TODO
Welcome to join the WeChat group for discussion.