Ce référentiel contient la mise en œuvre officielle du papier un modèle de diffusion discrète réparatérisé pour la génération de texte.
La base de code est implémentée avec Fairseq. Pour installer les dépendances, exécutez (recommandé dans un environnement virtuel) les commandes suivantes:
pip install -r requirements.txt
# install our package of discrete diffusion models
pip install -e discrete_diffusion
# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..Remarque L'environnement est testé avec Python 3.8.10, Pytorch 1.10.0 / 1.12.0 et CUDA 11.3. Notez également que notre fourche de Fairseq modifie plusieurs fichiers dans la base de code d'origine; L'utilisation de versions plus récentes de Fairseq pourrait conduire à des conflits de dépendance inattendus.
Nous mettons en œuvre des modèles de diffusion discrets dans une bibliothèque autonome discrete_diffusion pour une utilisation générale. La bibliothèque fournit des implémentations de divers modèles de diffusion discrètes typiques, composé de
(Vanilla/Reparameterized) multinomial diffusion : processus de diffusion qui injectent un bruit uniform à la séquence de jeton. La mise en œuvre de la diffusion multinomiale de vanille suit de près la base de code de l'article d'origine;(Vanilla/Reparameterized) absorbing diffusion : processus de diffusion où les jetons dans la séquence peuvent être absorbés à l'état masking , comme décrit dans le papier D3pm. Ces modèles de diffusion partagent le même ensemble d'interfaces permettant des utilisations externes. En particulier, ils sont définis comme des sous-classes de la classe DiscreteDiffusion , prenant la forme suivante:
class DiscreteDiffusion ( nn . Module ):
"""
The parent class for discrete denoising diffusion probabilistic models.
It supports the following methods:
- q_sample()
Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
- compute_losses()
Compute the loss L_t = KL(q||p) at t-th time step.
- sample_step()
Sample x_t ~ p(x_{t-1} | x_t, x_0) at t-th time step.
"""
def __init__ ( self , num_timesteps ):
super (). __init__ ()
self . num_timesteps = num_timesteps
def q_sample ( self , x_0 , t , ** kwargs ):
"""
Sample from q(x_t | x_0), which is used as the model inputs.
Args:
x_0: token ids with shape [B, N]
t: current time step, tensor with shape [B]
Returns:
return a dict of relevant outputs including x_t.
"""
def compute_losses ( self , inputs , ** kwargs ):
"""
Compute the loss objective KL(q||p) to train our generative process.
Args:
inputs: a dict that contains input types specific to different diffusion processes, containing
- x_t: token ids with shape [B, N]
- t: scalar timesteps, with shape [B]
Returns:
possibly return a dict of relevant outputs, including the loss used for training.
"""
def sample_step ( self , decoder_out , denoising_fn , ** kwargs ):
"""
Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
Args:
decoder_out: a namedtuple that contains decoding info, including
- x_t: token ids with shape [B, N]
- t: scalar timesteps
- max_steps: the maximum number of decoding steps
- ...
denoising_fn: a function that takes in x_t and t and returns model logits
kwargs: other arguments that are used to control decoding.
Returns:
return a new decoder_out namedtuple.
""" Un modèle DiscreteDiffusion peut être instancié en configurant ce qui suit:
--num-diffusion-timesteps <int> spécifie le nombre entier de pas de temps de diffusion (par défaut: 50)--diffusion-type <str> Spécifie le type de modèle de diffusion (choix: {absorbing, multinomial, reparam-absorbing, reparam-multinomial} )--noise-scheduler-type <str> Spécifie le calendrier de bruit uniquement dans la diffusion multinomiale vanille / réparam (choix typiques: {linear, cosine} ; par défaut: cosine )q_sample() , y compris--q-sample-mode <str> Spécifie la stratégie d'échantillonnage (choix: {default, coupled, multi-step, multi-sample} ; par défaut: default ). Nous fournissons divers choix pour l'échantillonnage de default : un seul échantillon est tiré comme multi-step : exemple de deux pas de temps IID multi-sample : échantillonnez deux échantillons IID coupled : également connu sous le nom de formation conditionnée, qui est détaillée à l'annexe F du document. Cela commence par l'échantillonnage de deux pas de temps coupled apporte des améliorations significatives pour la diffusion multinomiale / absorbante de vanille, mais le gain n'est pas systématiquement substantiel dans les variantes réparamétrées.--not-diffusing-special-sym indique s'il faut inclure des symboles spéciaux pendant le processus de diffusion (par défaut: false)compute_losses() , y compris--reweighting-type <str> Spécifie le schéma de réhabilitation dans notre famille ré-paramétrée (choix: {linear, reciprocal, none} ; par défaut: linear )--label-smoothing <float> Spécifie le taux de lissage de l'étiquette (par défaut: 0,1)sample_step() , y compris--argmax-decoding indique s'il faut utiliser le décodage Argmax pour la sortie du transformateur débrouillé --temperature <float> Spécifie la température --decoding-strategy <str> Spécifie l'utilisation de la vanille ( default ) / réparamètre ( reparam-<options> ; voir les détails) stratégie de décodage (choix: {default, reparam-<options>} ; par défaut: default )--load-ema-weights indique s'il faut charger les poids du modèle EMA pour la génération (par défaut: false)--iter-decode-max-iter <int> spécifie le nombre maximal de temps pour le décodage (par défaut: 10)--iter-decode-with-beam <int> spécifie la taille du faisceau pour le décodage de plusieurs séquences avec différentes longueurs en parallèle (par défaut: 1)--iter-decode-force-max-iter indique que le décodage itératif doit exécuter le nombre spécifié d'itérations et ne pas sortir. Recommandé de définir ce drapeau sur true.Voir ici pour une liste plus complète d'arguments.
En passant par --decoding-strategy default , le schéma d'échantillonnage de vanille (spécifique à chaque processus de diffusion discret) est utilisé.
Une approche de décodage plus avancée peut être invoquée en passant --decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule> . Cette approche est basée sur la réparamétrisation proposée dans notre article et permet des procédures de décodage plus efficaces. Les options spécifient l'algorithme de décodage via
<conditioning-of-v> : uncond ou cond (par défaut uncond <topk_mode> : stochastic<float> ou deterministic ( deterministic par défaut): Que ce soit pour utiliser la sélection stochastique ou déterministe. La valeur flottante dans stochastic<float> spécifie le degré d'aléatoire dans la sélection stochastique top- $ k $;<schedule> : linear ou cosine ( cosine par défaut): le calendrier pour Voir l'implémentation pour plus de détails sur les options.
Veuillez consulter les scripts ci-dessous pour plus de détails.
Note
- Notez que toutes les tâches considérées dans ce travail fonctionnent sur les données d'origine et n'adoptent pas de distillation de connaissances (KD).
Nous suivons le prétraitement standard à Fairseq / Exemples pour préparer les données binarisées:
# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en
--trainpref $TEXT /train --validpref $TEXT /valid --testpref $TEXT /test
--destdir data-bin/iwslt14.tokenized.de-en
--workers 20Nous utilisons les données publiées dans Fairseq / Exemples pour préparer l'ensemble de données:
wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip
TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang de
--trainpref $TEXT /train.en-de --validpref $TEXT /valid.en-de --testpref $TEXT /test.en-de
--destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0
--workers 20Pour cet ensemble de données, nous utilisons les données brutes wmt16.tar.gz comme prétraitée dans ce référentiel.
tar xzvf wmt16.tar.gz
TEXT=wmt16/en-ro
# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT /train/corpus.bpe.en $TEXT /train.bpe.en
mv $TEXT /train/corpus.bpe.ro $TEXT /train.bpe.ro
mv $TEXT /dev/dev.bpe.en $TEXT /dev.bpe.en
mv $TEXT /dev/dev.bpe.ro $TEXT /dev.bpe.ro
mv $TEXT /test/test.bpe.en $TEXT /test.bpe.en
mv $TEXT /test/test.bpe.ro $TEXT /test.bpe.ro
# binarize the data
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang ro
--trainpref $TEXT /train.bpe --validpref $TEXT /dev.bpe --testpref $TEXT /test.bpe
--destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0
--workers 20 Nous entrons d'abord dans le dossier fairseq , puis exécutons les commandes suivantes pour former les modèles.
# ####### training scripts for IWSLT'14 , WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14 and 2 GPUs for WMT'16 datasets respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linearNote
-s <str>est utilisé pour spécifier le nom de l'expérience.- Nous pourrions transmettre des arguments personnalisés qui pourraient être spécifiques à la formation en les ajoutant après
-e True.
Le pipeline d'évaluation est géré par experiments/mt_generate.sh . Le script générera les résultats de traduction et évaluera le score BLEU.
# ########## IWLS'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c < checkpoint_path > -d < iwslt/wmt14/wmt 16> Arguments:
-a : s'il faut en moyenne plusieurs points de contrôle-c : indique l'emplacement du point de contrôle. Si -a false (pas aux points de contrôle moyens), passez le chemin de contrôle; Si -a true , passez le répertoire qui stocke plusieurs points de contrôle à différentes étapes de formation pour la moyenne.-d : le nom de l'ensemble de donnéesNous fournissons également les points de contrôle de nos modèles qualifiés.
| Ensemble de données | Modèle | Lien de point de contrôle |
|---|---|---|
| Iwslt'14 | Multination | lien |
| Iwslt'14 | Absorbant | lien |
| Iwslt'14 | Reparam-multinomial | lien |
| Iwslt'14 | Reparam absorbant | lien |
| WMT'14 | Multination | lien |
| WMT'14 | Absorbant | lien |
| WMT'14 | Reparam-multinomial | lien |
| WMT'14 | Reparam absorbant | lien |
| WMT'16 | Multination | lien |
| WMT'16 | Absorbant | lien |
| WMT'16 | Reparam-multinomial | lien |
| WMT'16 | Reparam absorbant | lien |
Nous suivons la configuration expérimentale dans Diffuseq pour la génération de questions et les tâches de paraphrase .
Les données brutes de ces deux tâches peuvent être récupérées à partir du référentiel diffuseq d'origine. Nous binarisons ensuite les données via le script fourni.
# put the raw data in the directory ``diffuseq_data/QG``
# Preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG
# put the raw data in the directory ``diffuseq_data/QQP``
# Preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP # QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d < qqp/qg > -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d < qqp/qg > -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear Nous suivons étroitement les protocoles de génération et d'évaluation comme dans DiFfuseq pour assurer une comparaison principale. L'ensemble du pipeline est réimplémenté dans fairseq/diffusion_mt/scripts/decode_diffuseq.py et fairseq/diffusion_mt/scripts/eval_diffuseq.py respectivement pour être compatible avec Fairseq. Exécutez les commandes suivantes:
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c < checkpoint_path > -d < qqp/qg > Arguments:
-a : s'il faut en moyenne plusieurs points de contrôle-b : s'il faut utiliser plusieurs échantillons pour le décodage MBR-c : indique l'emplacement du point de contrôle. Si -a false (pas aux points de contrôle moyens), passez le chemin de contrôle; Si -a true , passez le répertoire qui stocke plusieurs points de contrôle à différentes étapes de formation pour la moyenne.-d : le nom de l'ensemble de donnéesNous fournissons également les points de contrôle de nos modèles qualifiés.
| Ensemble de données | Modèle | Lien de point de contrôle |
|---|---|---|
| Fm | Multination | lien |
| Fm | Absorbant | lien |
| Fm | Reparam-multinomial | lien |
| Fm | Reparam absorbant | lien |
| QQP | Multination | lien |
| QQP | Absorbant | lien |
| QQP | Reparam-multinomial | lien |
| QQP | Reparam absorbant | lien |
@article { zheng2023rdm ,
title = { A Reparameterized Discrete Diffusion Model for Text Generation } ,
author = { Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng } ,
journal = { arXiv preprint arXiv:2302.05737 } ,
year = { 2023 }
}