Zhengxuan Wu *, Atticus Geiger *, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman
Il s'agit d'une mise en œuvre de notre distillation causale préalable pour les modèles de langue. L'approche standard de la distillation forme un modèle étudiant contre deux objectifs: un objectif spécifique à la tâche (par exemple, la modélisation de la langue) et un objectif d'imitation qui encourage les états cachés du modèle étudiant comme similaires à ceux du modèle d'enseignant plus large. Dans cet article, nous montrons qu'il est avantageux d'augmenter la distillation avec un troisième objectif qui encourage l'élève à imiter le processus de calcul causal de l'enseignant par l'intervention d'échange (IIT). Nous nommons notre méthode l'objectif de formation d'intervention de l'échange de distillation (DIITO) .
Nous constatons que DIITO est utile dans un cadre à faible ressource. DIITO effectue des parts sur la pAR avec (97%) de distillation standard mais une formation avec 97% de données moins.
Nous débordons notre base de code principale à partir de l'interface de distillation Huggingface.
✅ 12/02/2021 Notre article sur la formation à l'intervention d'échange (IIT) est publié! Lisez ceci pour une définition plus formelle de la méthode.
✅ 12/06/2021 a libéré la base de code de distillation causale avec la préimpression.
✅ 12/06/2021 Résultats de l'évaluation publiés sur Distily Tiny-Bert (3 couches) avec l'ensemble de données Wiki-Text 103M.
✅ 14/01/2022 a publié une nouvelle version de DIITO et ses résultats d'évaluation. Vous pouvez afficher notre préimprimée mise à jour partagée en privé pour plus de détails.
✅ 21/02/2022 a publié la base de code pour les diito-xx qui applique un idem pour distiller les modèles spécifiques à la tâche dans NLP en mettant l'accent sur la support de distillation du modèle dans un cadre à faible ressource. Consultez le dépôt pour plus d'informations!
⬜️ Modèle de diito (6 couches) a été formé avec l'anglais wikipedia + bookcorpus.
Si vous rencontrez des problèmes ou avez des suggestions, veuillez me contacter soit la page des problèmes ou à [email protected].
Voici les résultats sur les développeurs de colle:
| Modèle | # de jetons de formation | Score moyen | Cola | MNLI | MRPC | QNLI | QQP | Rte | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert (6 couches) Devlin et al., 2019 | 3.3b | 79.59 | 51.30 | 82.10 | 87.50 | 89.20 | 88.50 | 59,90 | 91.30 | 86.90 |
| Distilbert (6 couches) | 0.1b | 75.80 | 40.43 | 78,95 | 87.45 | 84.76 | 84.96 | 60.10 | 89.38 | 80.40 |
| Diito (6 couches) | 0.1b | 77.14 | 45.17 | 79.68 | 88.18 | 85,83 | 85.31 | 60,94 | 90.32 | 81,69 |
| Diito (6 couches) | 3.3b | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) |
Si vous utilisez ce référentiel, veuillez citer les deux articles suivants: papier pour la formation d'intervention d'échange et papier pour la méthode de notre distillation.
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Après l'interface de distillation HuggingFace, nous devons prétraiter les ensembles de données avant de faire la distillation. Vous pouvez vous référer à leur dépôt pour plus de détails. Nous adaptons leurs scripts de prétraitement et mettons à jour avec quelques améliorations. Par exemple, nous pouvons désormais binariser des ensembles de données à partir du centre de jeu de données à partir de HuggingFace directement.
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48Après avoir préparé les ensembles de données, vous devez également générer des comptages de jetons.
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522Avant la formation, nous vous recommandons d'initialiser votre modèle d'élève avec des poids extraits du modèle enseignant.
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3Maintenant, voici un exemple pour vous distiller avec notre objectif de distillation causale ou sans,
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40 Notez que vous pouvez simplement tourner notre objectif de distillation causal activé / désactivé en définissant les arguments. Par exemple, nous ajoutons récemment cet argument --alpha_causal_cos pour soutenir la perte de causalité sur le terme de perte de cosinus. Notez que la taille efficace du lot dans notre paramètre est définie sur 240.
Après avoir obtenu vos modèles distillés, vous devez les affiner et les évaluer avec des tâches en aval. Nous vous fournissons tous les scripts que vous devez exécuter.
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/