Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman
Es una implementación de nuestra destilación causal de preimpresión para los modelos de idiomas. El enfoque estándar para la destilación entrena a un modelo de estudiante con dos objetivos: un objetivo específico de la tarea (por ejemplo, modelado de idiomas) y un objetivo de imitación que fomenta que los estados ocultos del modelo de estudiante sean similares a los del modelo de maestro más grande. En este documento, mostramos que es beneficioso aumentar la destilación con un tercer objetivo que alienta al estudiante a imitar el proceso de cálculo causal del maestro a través de la capacitación de intervención de intercambio (IIT). Nombre nuestro método El objetivo de entrenamiento de intervención de intercambio de destilación (DIITO) .
Encontramos que Diito es útil en una configuración de baja recursos. Diito funciona en PAR con (97%) destilación estándar pero entrenando con un 97% menos de datos.
Bifurcamos nuestra base de código principal de la interfaz de destilación de Huggingface.
✅ 12/02/2021 ¡Se publica nuestro artículo sobre el entrenamiento de intervención de intercambio (IIT)! Lea esto para una definición más formal del método.
✅ 06/12/2021 lanzó la base de código de destilación causal con la preimpresión.
✅ 06/12/2021 Los resultados de la evaluación liberados en Tiny-Bert (3 capas) destilados con el conjunto de datos Wiki-Text 103M.
✅ 14/01/2022 lanzó una versión más nueva de Diito y sus resultados de evaluación. Puede ver nuestra preimpresión actualizada privada para obtener más detalles.
✅ 21/02/2022 lanzó la base de código para DIITO-XX que aplica ídem para destilar modelos específicos de tareas en PNL con un enfoque en la destilación de modelos de soporte en una configuración de bajo contenido de recursos. ¡Mira el repositorio para más información!
⬜️ Lanzado modelo Diito (6 capas) entrenado con Wikipedia en inglés + bookcorpus.
Si experimenta algún problema o tiene sugerencias, contácteme Thourgh la página de problemas o en [email protected].
Aquí están los resultados en los conjuntos de Glue:
| Modelo | # de tokens de entrenamiento | Puntaje promedio | Reajuste salarial | Mnli | MRPC | Qnli | QQP | RTE | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|
| Distilbert (6 capas) Devlin et al., 2019 | 3.3b | 79.59 | 51.30 | 82.10 | 87.50 | 89.20 | 88.50 | 59.90 | 91.30 | 86.90 |
| Distilbert (6 capas) | 0.1b | 75.80 | 40.43 | 78.95 | 87.45 | 84.76 | 84.96 | 60.10 | 89.38 | 80.40 |
| Diito (6 capas) | 0.1b | 77.14 | 45.17 | 79.68 | 88.18 | 85.83 | 85.31 | 60.94 | 90.32 | 81.69 |
| Diito (6 capas) | 3.3b | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) | (-) |
Si usa este repositorio, cita los siguientes dos documentos: documento para el entrenamiento de intervención de intercambio y el documento para el método de destilación de nuestro Destilación.
@article{geiger-etal-2021-iit,
title={Inducing Causal Structure for Interpretable Neural Networks},
author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
year={2021},
eprint={2112.00826},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wu-etal-2021-distill,
title={Causal Distillation for Language Models},
author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
year={2021},
eprint={2112.02505},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Después de la interfaz de destilación de Huggingface, necesitamos preprocesar los conjuntos de datos antes de hacer destilación. Puede consultar su repositorio para más detalles. Adaptamos sus scripts de preprocesamiento y actualizamos con algunas mejoras. Por ejemplo, ahora podemos binarizar los conjuntos de datos desde el centro de datos de Huggingface directamente.
# preprocessing from disk
python script/binarized_data.py
--file_path ../../bert-mid-tuning/data-files/wikitext-15M
--split train
--field_name text
--max_parsing_example 1000
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file ./data/binarized_text
# preprocessing from huggingface.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext-dataset/binarized_text
--cache_dir ./distill_cache/
python scripts/binarized_data.py
--dataset_name wikitext+bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
# helper scripts to combine two binarized data files
python scripts/data_combinator.py
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--split train
--tokenizer_name bert-base-uncased
--dump_file wikitext+bookcorpus-dataset/binarized_text
# multiprocessing preprocessor.
python scripts/binarized_data.py
--dataset_name bookcorpus
--split train
--field_name text
--tokenizer_type bert
--tokenizer_name bert-base-uncased
--dump_file bookcorpus-dataset/binarized_text
--cache_dir ./distill_cache/
--fast_process
--preprocessing_num_workers 48Después de preparar los conjuntos de datos, también debe generar recuentos de tokens.
python scripts/token_counts.py
--data_file data/binarized_text.train.bert-base-uncased.pickle
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle
--vocab_size 30522Antes de la capacitación, le recomendamos que inicialice su modelo de estudiante con pesos extraídos del modelo de maestro.
python scripts/extract_distilbert.py
--model_type bert
--model_name bert-base-uncased
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth
--num_layers 3Ahora, aquí hay un ejemplo para que usted se destile con nuestro objetivo de destilación causal o sin,
CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py
--force
--n_gpu 4
--log_interval 10
--student_type distilbert
--student_config ./training_configs/distilbert-base-uncased-large.json
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth
--teacher_type bert
--teacher_name bert-base-uncased
--neuron_mapping ./training_configs/single_middle_layer_6.nm
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal_ce 0.25 --alpha_causal_cos 0.0
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only
--freeze_pos_embs
--dump_path ./results/
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle
--seed 42
--n_epoch 3
--gradient_accumulation_steps 6
--batch_size 40 Tenga en cuenta que simplemente puede encender nuestro objetivo de destilación causal a través de la configuración de los argumentos. Por ejemplo, recientemente agregamos este argumento --alpha_causal_cos para respaldar la pérdida causal en el término de pérdida de coseno. Tenga en cuenta que el tamaño de lote efectivo en nuestra configuración se establece en 240.
Después de obtener sus modelos destilados, debe ajustarlos y evaluarlos con tareas aguas abajo. Le proporcionamos todos los scripts que necesita ejecutar.
CUDA_VISIBLE_DEVICES=0 python run_mlm.py
--model_name_or_path ./path_to_your_model/
--dataset_dir ../path_to_your_data/
--tokenizer_name bert-base-uncased
--do_eval
--output_dir /tmp/test-mlm
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_glue.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--task_name sst2
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 32
--learning_rate 2e-5
--num_train_epochs 3
--output_dir ./results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_ner.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name conll2003
--do_train
--do_eval
--output_dir ./ner_results/
--save_total_limit 1
--cache_dir ./distill_cache/CUDA_VISIBLE_DEVICES=0,1,2,3 python run_qa.py
--model_name_or_path ./path_to_your_model/
--tokenizer_name bert-base-uncased
--dataset_name squad
--do_train
--do_eval
--per_device_train_batch_size 12
--learning_rate 3e-5
--num_train_epochs 2
--max_seq_length 384
--doc_stride 128
--save_total_limit 1
--output_dir ./qa_results/