Ce référentiel contient le code source de notre document de découvertes EMNLP 2020: le réglage adversaire du domaine en tant que régularisateur efficace.
Dans ce travail, nous proposons un nouveau type de régularisateur pour le processus de réglage fin des modèles de langue pré-entraînés (LMS). Nous identifions la perte de représentations du domaine général des LMS pré-entraînés lors du réglage fin comme une forme d' oubli catastrophique . Le terme contradictoire agit comme un régulariseur qui préserve la plupart des connaissances capturées par le LM pendant la pré-entraînement, empêchant l'oubli catastrophique.
Pour y remédier, nous étendons le processus de réglage fin standard de LMS pré-entraîné avec un objectif contradictoire. Ce terme de perte supplémentaire est lié à un classificateur contradictoire, qui discrimine les représentations de texte dans le domaine et hors du domaine .
Dans le domaine : ensemble de données étiqueté de la tâche ( principale ) à portée de main
Hors du domaine : données non marquées d'un domaine différent ( auxiliaire )
Nous minimisons la perte spécifique à la tâche et maximions en même temps la perte du classificateur de domaine à l'aide d'une couche d'inversion de gradient.
La fonction de perte que nous proposons est la suivante:
L après = l du domaine principal - λl
où L Main est la perte spécifique à la tâche et le domaine L une perte adversaire qui applique l'invariance des représentations de texte dans différents domaines, tout en réglant. λ est un hyperparamètre accordable.

Des expériences sur 4 ensembles de données de colle (COLA, MRPC, SST-2 et RTE) avec deux LMS pré-entraînés différents (Bert et XLNET) montrent des performances améliorées sur un réglage fin standard. Nous montrons empiriquement que le terme contradictoire agit comme un régulariseur qui préserve la plupart des connaissances capturées par le LM pendant la pré-formation, empêchant l'oubli catastrophique.
Créer un environnement (facultatif): Idéalement, vous devez créer un environnement pour le projet.
conda create -n after_env python=3.6
conda activate after_env
Installez Pytorch 1.1.0 avec la version CUDA souhaitée si vous souhaitez utiliser le GPU:
conda install pytorch==1.1.0 torchvision -c pytorch
Clone le projet:
git clone https://github.com/GeorgeVern/AFTERV1.0.git
cd AFTERV1.0
Installez ensuite le reste des exigences:
pip install -r requirements.txt
Pour télécharger les ensembles de données principaux , nous utilisons le script download_glue_data.py à partir d'ici. Vous pouvez choisir les ensembles de données utilisés dans le papier en exécutant la commande suivante:
python download_glue_data.py --data_dir './Datasets' --tasks 'CoLA,SST,RTE,MRPC
Le chemin par défaut pour les ensembles de données est après V1.0 / ensembles de données, mais tout autre chemin peut être utilisé (devrait être d'accord avec le chemin DATA_DIR spécifié dans le script sys_config )
En tant que données auxiliaires , nous utilisons les corpus à partir de divers domaines. Nous fournissons des scripts pour télécharger et prétraiter les corpus utilisés dans nos expériences, tandis que tout autre corporat peut également être utilisé.
Pour exécuter après Bert, vous avez besoin de la commande suivante:
python after_fine-tune.py -i afterBert_finetune_cola_europarl --lambd 0.1
lambd fait référence à Lambda, le poids de la fonction de perte articulaire que nous utilisons.
Dans configs/ , vous pouvez voir une liste de fichiers YAML que nous avons utilisés pour les expériences et peut également modifier leurs hyperparamètres.
Si vous utilisez ce dépôt dans votre recherche, veuillez citer le document:
@inproceedings{vernikos-etal-2020-domain,
title = "{D}omain {A}dversarial {F}ine-{T}uning as an {E}ffective {R}egularizer",
author = "Vernikos, Giorgos and
Margatina, Katerina and
Chronopoulou, Alexandra and
Androutsopoulos, Ion",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
year = "2020",
url = "https://www.aclweb.org/anthology/2020.findings-emnlp.278",
doi = "10.18653/v1/2020.findings-emnlp.278",
pages = "3103--3112",
}