Zhengxuan Wu*,Atticus Geiger*,Josh Rozner,Elisa Kreiss,Hanson Lu,Thomas Icard,Christopher Potts,Noah D. Goodman
这是我们对语言模型的预印因蒸馏的实现。蒸馏的标准方法训练学生模型针对两个目标:特定于任务的目标(例如,语言建模)和鼓励学生模型的隐藏状态的模仿目标与较大的教师模型相似。在本文中,我们表明,以第三个目标来增加蒸馏是有益的,该目标鼓励学生通过互换干预培训(IIT)模仿教师的因果计算过程。我们将我们的方法命名为蒸馏互换干预培训目标(Diito) 。
我们发现Diito在低资源环境中很有帮助。 Diito用(97%)标准蒸馏进行了PAR,但数据少97%。
我们从拥抱面蒸馏界面拨出主代码库。
✅2012年12月2日,我们的有关交换干预培训(IIT)的论文已发布!阅读此方法以进行该方法的更正式的定义。
✅12/06/2021用预印本发布了因果蒸馏代码库。
✅12/06/2021与Wiki-Text 103M数据集发布了蒸馏式Tiny-Bert(3层)的评估结果。
✅01/14/2022发布了Diito的新版本及其评估结果。您可以查看我们私人共享的更新预印本以获取更多详细信息。
✅02/21/2022发布了Diito-XXS的代码库,该代码库将同上用于NLP中的特定于任务的模型,重点是在低资源设置中支持模型蒸馏。查看存储库以获取更多信息!
⬜️发行了Diito(6层)模型,该模型接受了英语Wikipedia + BookCorpus训练。
如果您遇到任何问题或有建议,请与我联系,请与我联系,或通过[email protected]与我联系。
这是胶水开发集的结果:
| 模型 | 训练令牌的# | 平均得分 | 可乐 | mnli | MRPC | Qnli | QQP | rte | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert(6层)Devlin等人,2019年 | 3.3b | 79.59 | 51.30 | 82.10 | 87.50 | 89.20 | 88.50 | 59.90 | 91.30 | 86.90 |
| Distilbert(6层) | 0.1b | 75.80 | 40.43 | 78.95 | 87.45 | 84.76 | 84.96 | 60.10 | 89.38 | 80.40 |
| Diito(6层) | 0.1b | 77.14 | 45.17 | 79.68 | 88.18 | 85.83 | 85.31 | 60.94 | 90.32 | 81.69 |
| Diito(6层) | 3.3b | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) |
如果使用此存储库,请引用以下两篇论文:用于交换干预培训的纸张,以及我们的蒸馏方法的纸张。
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
在拥抱面蒸馏接口之后,我们需要在进行蒸馏之前对数据集进行预处理。您可以参考他们的仓库以获取详细信息。我们调整他们的预处理脚本,并进行一些改进。例如,我们现在可以直接从数据集集线器中直接从数据集集线器进行二进制数据集。
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48准备好数据集后,您也需要生成令牌计数。
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522在培训之前,我们建议您使用从教师模型中提取的权重初始化学生模型。
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3现在,这是您以我们的因果蒸馏目标或没有的,
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40请注意,您可以简单地通过设置参数来打开/关闭我们的因果蒸馏目标。例如,我们最近添加了此参数--alpha_causal_cos ,以支持余弦损失项的因果损失。请注意,我们的设置中的有效批量尺寸设置为240。
获得蒸馏型型号后,您需要对它们进行微调并通过下游任务对其进行评估。我们为您提供您需要运行的所有脚本。
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/