Zhengxuan Wu*,Atticus Geiger*,Josh Rozner,Elisa Kreiss,Hanson Lu,Thomas Icard,Christopher Potts,Noah D. Goodman
這是我們對語言模型的預印因蒸餾的實現。蒸餾的標準方法訓練學生模型針對兩個目標:特定於任務的目標(例如,語言建模)和鼓勵學生模型的隱藏狀態的模仿目標與較大的教師模型相似。在本文中,我們表明,以第三個目標來增加蒸餾是有益的,該目標鼓勵學生通過互換干預培訓(IIT)模仿教師的因果計算過程。我們將我們的方法命名為蒸餾互換干預培訓目標(Diito) 。
我們發現Diito在低資源環境中很有幫助。 Diito用(97%)標準蒸餾進行了PAR,但數據少97%。
我們從擁抱面蒸餾界面撥出主代碼庫。
✅2012年12月2日,我們的有關交換干預培訓(IIT)的論文已發布!閱讀此方法以進行該方法的更正式的定義。
✅12/06/2021用預印本發布了因果蒸餾代碼庫。
✅12/06/2021與Wiki-Text 103M數據集發布了蒸餾式Tiny-Bert(3層)的評估結果。
✅01/14/2022發布了Diito的新版本及其評估結果。您可以查看我們私人共享的更新預印本以獲取更多詳細信息。
✅02/21/2022發布了Diito-XXS的代碼庫,該代碼庫將同上用於NLP中的特定於任務的模型,重點是在低資源設置中支持模型蒸餾。查看存儲庫以獲取更多信息!
⬜️發行了Diito(6層)模型,該模型接受了英語Wikipedia + BookCorpus訓練。
如果您遇到任何問題或有建議,請與我聯繫,請與我聯繫,或通過[email protected]與我聯繫。
這是膠水開發集的結果:
| 模型 | 訓練令牌的# | 平均得分 | 可樂 | mnli | MRPC | Qnli | QQP | rte | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert(6層)Devlin等人,2019年 | 3.3b | 79.59 | 51.30 | 82.10 | 87.50 | 89.20 | 88.50 | 59.90 | 91.30 | 86.90 |
| Distilbert(6層) | 0.1b | 75.80 | 40.43 | 78.95 | 87.45 | 84.76 | 84.96 | 60.10 | 89.38 | 80.40 |
| Diito(6層) | 0.1b | 77.14 | 45.17 | 79.68 | 88.18 | 85.83 | 85.31 | 60.94 | 90.32 | 81.69 |
| Diito(6層) | 3.3b | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) |
如果使用此存儲庫,請引用以下兩篇論文:用於交換干預培訓的紙張,以及我們的蒸餾方法的紙張。
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
在擁抱面蒸餾接口之後,我們需要在進行蒸餾之前對數據集進行預處理。您可以參考他們的倉庫以獲取詳細信息。我們調整他們的預處理腳本,並進行一些改進。例如,我們現在可以直接從數據集集線器中直接從數據集集線器進行二進制數據集。
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48準備好數據集後,您也需要生成令牌計數。
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522在培訓之前,我們建議您使用從教師模型中提取的權重初始化學生模型。
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3現在,這是您以我們的因果蒸餾目標或沒有的,
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40請注意,您可以簡單地通過設置參數來打開/關閉我們的因果蒸餾目標。例如,我們最近添加了此參數--alpha_causal_cos ,以支持餘弦損失項的因果損失。請注意,我們的設置中的有效批量尺寸設置為240。
獲得蒸餾型型號後,您需要對它們進行微調並通過下游任務對其進行評估。我們為您提供您需要運行的所有腳本。
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/