Zhengxuan Wu*、Atticus Geiger*、Josh Rozner、Elisa Kreiss、Hanson Lu、Thomas Icard、Christopher Potts、Noah D. Goodman
これは、言語モデルのプリプリント因果蒸留の実装です。蒸留への標準的なアプローチは、学生モデルを2つの目的に対して訓練します。タスク固有の目的(言語モデリングなど)と、学生モデルの隠された状態がより大きな教師モデルの状態と類似することを奨励する模倣目標です。この論文では、インターチェンジ介入トレーニング(IIT)を通じて教師の因果計算プロセスを模倣することを生徒が奨励する第三の目的で蒸留を増強することが有益であることを示します。蒸留インターチェンジ介入トレーニング目標(DIITO)と名付けます。
Diitoは、リソースの低い設定で役立つことがわかります。 Diitoは、(97%)の標準蒸留でParを実行しますが、97%のデータでトレーニングします。
Huggingface Distillation Interfaceからメインコードベースをフォークします。
✅12/02/2021インターチェンジ介入トレーニング(IIT)に関する私たちの論文がリリースされました!この方法のより正式な定義については、これを読んでください。
✅12/06/2021は、プリプリントを使用して因果蒸留コードベースをリリースしました。
✅12/06/2021 Wiki-Text 103Mデータセットを使用して、蒸留型のTiny-Bert(3層)で評価結果をリリースしました。
✅01/14/2022 Diitoの新しいバージョンとその評価結果をリリースしました。詳細については、個人的に共有された更新前のプレリントを表示できます。
✅02/21/2022は、低リソース設定でのモデル蒸留をサポートすることに焦点を当てて、NLPのタスク固有のモデルを蒸留するためにDittoを適用するDiito-XXSのコードベースをリリースしました。詳細については、リポジトリをご覧ください!
⬜イギリスのウィキペディア + bookcorpusで訓練されたDiito(6層)モデルをリリースしました。
問題が発生したり、提案がある場合は、問題ページまたは[email protected]のThourghに連絡してください。
接着剤の開発セットの結果は次のとおりです。
| モデル | トレーニングトークンの# | 平均スコア | コーラ | mnli | MRPC | Qnli | QQP | rte | SST-2 | sts-b |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert(6層)Devlin et al。、2019 | 3.3b | 79.59 | 51.30 | 82.10 | 87.50 | 89.20 | 88.50 | 59.90 | 91.30 | 86.90 |
| Distilbert(6層) | 0.1b | 75.80 | 40.43 | 78.95 | 87.45 | 84.76 | 84.96 | 60.10 | 89.38 | 80.40 |
| diito(6層) | 0.1b | 77.14 | 45.17 | 79.68 | 88.18 | 85.83 | 85.31 | 60.94 | 90.32 | 81.69 |
| diito(6層) | 3.3b | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) | ( - ) |
このリポジトリを使用する場合は、次の2つの論文を引用してください。インターチェンジ介入トレーニングのための論文、蒸留方法の論文。
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Huggingface Distillation Interfaceに従って、蒸留を行う前にデータセットを前処理する必要があります。詳細については、リポジトリを参照できます。前処理スクリプトを適応させ、いくつかの改善で更新します。たとえば、Huggingfaceのデータセットハブからのデータセットを直接バイナリングできるようになりました。
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48データセットを準備したら、トークンカウントも生成する必要があります。
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522トレーニングの前に、教師モデルから抽出されたウェイトで生徒モデルを初期化することをお勧めします。
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3さて、ここにあなたが私たちの因果的蒸留目標で蒸留する例がありますか、それなしです。
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40引数を設定して、因果蒸留目的をオン/オフにするだけで簡単に変えることができることに注意してください。たとえば、最近、この引数を追加して、 --alpha_causal_cosを追加して、コサインの損失項での因果損失をサポートします。設定の有効なバッチサイズは240に設定されていることに注意してください。
蒸留モデルを取得した後、それらを微調整し、ダウンストリームタスクで評価する必要があります。実行する必要があるすべてのスクリプトを提供します。
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/