staged training
1.0.0
在我们针对变压器语言模型的论文分阶段培训中,我们提出了一个分阶段的培训设置,该设置始于小型模型,并通过应用“增长操作员”来增加模型深度和宽度,从而增加了用于培训的计算量。通过使用上一个阶段的每个阶段初始化每个阶段,训练过程可以有效地从前阶段重新使用计算,并变得更有效。
我们在此处发布了增长操作员和评估脚本的可再现代码。
此存储库中的脚本需要Python 3.7或更新。拥有合适的Python环境后,首先根据官方说明首先安装Pytorch v1.9.0。然后运行
pip install -r requirements.txt
我们的增长操作员(宽度/深度)每个都将整个培训状态(包括模型参数,优化器状态,学习率计划等)作为输入,并输出一种培训的新培训状态。
请参阅scripts/cheatsheet.txt以获取有关如何使用相应脚本的更多示例。
例如,您可以使用以下方式应用宽度操作员
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py
--save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug
--gpu_count -1
--model gpt2
--tokenizer gpt2
--batch_size 4
--grad_accum 32
--lr 0.002006911598778545
--warmup_steps 3000
--train_steps 250000
--val_every 50
--val_batches 50
--fp16
--seqlen 1024
--log_rate 10
--num_workers 4
--size GPT2_large_div2_width
--random
--resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-xxx.ckpt
--doubling weights
或深度操作员:
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py
--save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug
--gpu_count -1
--model gpt2
--tokenizer gpt2
--batch_size 4
--grad_accum 32
--lr 0.002006911598778545
--warmup_steps 3000
--train_steps 250000
--val_every 50
--val_batches 50
--fp16
--seqlen 1024
--log_rate 10
--num_workers 4
--size GPT2_large_div2_depth
--random
--resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt
--doubling layers
使用evaluation/eval_wikitext.py或evaluation/eval_lambada.py来评估一个受支持的数据集上的gpt-2。例如:
python evaluation/eval_wikitext.py或使用Docker:
docker build -t evaluation:latest .
docker run --rm --gpus all evaluation:latest evaluation/eval_wikitext.py如果您在研究中使用分阶段的培训或希望参考此处发布的基线结果,请使用以下Bibtex条目。
@misc{shen2022staged,
title={Staged Training for Transformer Language Models},
author={Sheng Shen and Pete Walsh and Kurt Keutzer and Jesse Dodge and Matthew Peters and Iz Beltagy},
year={2022},
eprint={2203.06211},
archivePrefix={arXiv},
primaryClass={cs.CL}
}