Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman
Das ist eine Implementierung unserer vorgeprinkten kausalen Destillation für Sprachmodelle. Der Standardansatz zur Destillation schult ein Schülermodell gegen zwei Ziele: ein aufgabenspezifisches Ziel (z. B. Sprachmodellierung) und ein Nachahmungsziel, das die verborgenen Zustände des Schülermodells ermutigt, denen des größeren Lehrermodells ähnlich zu sein. In diesem Artikel zeigen wir, dass es vorteilhaft ist, die Destillation mit einem dritten Ziel zu erweitern, das den Schüler dazu ermutigt, den kausalen Berechnungsprozess des Lehrers durch Austauschinterventionstraining (IIT) nachzuahmen. Wir nennen unsere Methode das Destillation Interchange Intervention Training Ziel (Diito) .
Wir finden, dass Diito in einer Umgebung mit niedrigem Ressourcen hilfreich ist. Diito führt eine On-Par-Destillation mit (97%) Standarddestillation durch, trainiert jedoch mit 97% weniger Daten.
Wir haben unsere Haupt -Codebasis von der Destillationsschnittstelle für die Destillation von Huggingface aus.
✅ 12/02/2021 Unser Papier zum Austauschinterventionstraining (IIT) wird veröffentlicht! Lesen Sie dies für eine formalere Definition der Methode.
✅ 12/06/2021 Die kausale Destillations -Codebasis mit dem Präprint veröffentlicht.
✅ 12/06/2021 Freigegebene Bewertungsergebnisse auf destilliertem Tiny-Bert (3 Schichten) mit dem Wiki-Text 103M-Datensatz.
✅ 14.01.2022 veröffentlichte eine neuere Version von Diito und seine Bewertungsergebnisse. Weitere Informationen finden Sie in unserem privat gemeinsam genutzten aktualisierten Vordruck.
✅ 02.02.2022 veröffentlichte die Codebasis für Diito-XXs , die dito auf taskspezifische Modelle in NLP anwendet und sich auf die Unterstützung der Modelldestillation in einer Ressourceneinstellung mit niedrigem Ressourcen konzentriert. Weitere Informationen finden Sie im Repo!
⬜️ Veröffentlichtes Diito (6 Layers) Modell, das mit englischer Wikipedia + bookcorpus trainiert wurde.
Wenn Sie Probleme haben oder Vorschläge haben, kontaktieren Sie mich bitte entweder auf der Seite der Probleme oder unter [email protected].
Hier sind die Ergebnisse der Entwicklersätze von Klebstoff:
| Modell | Anzahl der Trainingstoken | Durchschnittsschrei | Cola | Mnli | MRPC | Qnli | QQP | Rte | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert (6 Schichten) Devlin et al., 2019 | 3.3b | 79,59 | 51.30 | 82.10 | 87,50 | 89.20 | 88,50 | 59,90 | 91.30 | 86.90 |
| Distilbert (6 Schichten) | 0,1b | 75,80 | 40.43 | 78,95 | 87,45 | 84.76 | 84,96 | 60.10 | 89,38 | 80.40 |
| Diito (6 Schichten) | 0,1b | 77.14 | 45.17 | 79,68 | 88.18 | 85,83 | 85.31 | 60.94 | 90.32 | 81.69 |
| Diito (6 Schichten) | 3.3b | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) |
Wenn Sie dieses Repository verwenden, zitieren Sie bitte die folgenden zwei Papiere: Papier für Interventionsinterventionstraining und Papier für unsere Destillationsmethode.
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Nach der Destillationsschnittstelle für die Destillation von Huggingface müssen wir die Datensätze vor verarbeiten, bevor wir die Destillation durchführen. Sie können sich auf ihr Repo beziehen, um Einzelheiten zu erhalten. Wir passen ihre Vorverarbeitungsskripte an und aktualisieren mit einigen Verbesserungen. Zum Beispiel können wir jetzt Datensätze aus dem Dataset -Hub direkt aus dem Umarmungsface brügen.
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48Nachdem Sie die Datensätze vorbereitet haben, müssen Sie auch Token -Zählungen generieren.
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522Vor dem Training empfehlen wir Ihnen, Ihr Schülermodell mit Gewichten aus dem Lehrermodell zu initialisieren.
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3Hier ist hier ein Beispiel für Sie, um mit unserem kausalen Destillationsziel oder ohne zu destillieren, ohne.
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40 Beachten Sie, dass Sie unser kausales Destillationsziel einfach durch Festlegen der Argumente ein- und ausschalten können. Zum Beispiel fügen wir kürzlich dieses Argument hinzu --alpha_causal_cos um den kausalen Verlust des Cosinus -Verlustdauers zu unterstützen. Beachten Sie, dass die effektive Chargengröße in unserer Einstellung auf 240 festgelegt ist.
Nachdem Sie Ihre destillierten Modelle erhalten haben, müssen Sie sie Feinabstimmung und bewerten sie mit nachgeschalteten Aufgaben. Wir stellen Ihnen alle Skripte zur Verfügung, die Sie ausführen müssen.
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/