Dieses Repository enthält die offizielle Implementierung von Papier, ein reparametriatisiertes diskretes Diffusionsmodell für die Textgenerierung.
Die Codebasis wird mit Fairseq implementiert. Um die Abhängigkeiten zu installieren, werden die folgenden Befehle ausgeführt (in einer virtuellen Umgebung empfohlen):
pip install -r requirements.txt
# install our package of discrete diffusion models
pip install -e discrete_diffusion
# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..Beachten Sie, dass die Umgebung mit Python 3.8.10, Pytorch 1.10.0/1.12.0 und CUDA 11.3 getestet wird. Beachten Sie auch, dass unsere Gabel von Fairseq mehrere Dateien in der ursprünglichen Codebasis verändert. Die Verwendung neuerer Versionen von Fairseq kann zu unerwarteten Abhängigkeitskonflikten führen.
Wir implementieren diskrete Diffusionsmodelle in einer in sich geschlossenen Bibliothek discrete_diffusion für die allgemeine Verwendung. Die Bibliothek liefert Implementierungen verschiedener typischer diskreter Diffusionsmodelle, die bestehen aus
(Vanilla/Reparameterized) multinomial diffusion : Diffusionsprozesse, die der Token -Sequenz uniform Rauschen injizieren. Die Implementierung der Vanille -multinomialen Diffusion folgt genau der Codebasis des Originalpapiers.(Vanilla/Reparameterized) absorbing diffusion : Diffusionsprozesse, bei denen Token innerhalb der Sequenz in den masking absorbiert werden können, wie im D3pm -Papier beschrieben. Diese Diffusionsmodelle haben den gleichen Satz von Schnittstellen, die externe Verwendungen ermöglichen. Insbesondere werden sie als Unterklassen der DiscreteDiffusion definiert, die das folgende Formular annehmen:
class DiscreteDiffusion ( nn . Module ):
"""
The parent class for discrete denoising diffusion probabilistic models.
It supports the following methods:
- q_sample()
Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
- compute_losses()
Compute the loss L_t = KL(q||p) at t-th time step.
- sample_step()
Sample x_t ~ p(x_{t-1} | x_t, x_0) at t-th time step.
"""
def __init__ ( self , num_timesteps ):
super (). __init__ ()
self . num_timesteps = num_timesteps
def q_sample ( self , x_0 , t , ** kwargs ):
"""
Sample from q(x_t | x_0), which is used as the model inputs.
Args:
x_0: token ids with shape [B, N]
t: current time step, tensor with shape [B]
Returns:
return a dict of relevant outputs including x_t.
"""
def compute_losses ( self , inputs , ** kwargs ):
"""
Compute the loss objective KL(q||p) to train our generative process.
Args:
inputs: a dict that contains input types specific to different diffusion processes, containing
- x_t: token ids with shape [B, N]
- t: scalar timesteps, with shape [B]
Returns:
possibly return a dict of relevant outputs, including the loss used for training.
"""
def sample_step ( self , decoder_out , denoising_fn , ** kwargs ):
"""
Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
Args:
decoder_out: a namedtuple that contains decoding info, including
- x_t: token ids with shape [B, N]
- t: scalar timesteps
- max_steps: the maximum number of decoding steps
- ...
denoising_fn: a function that takes in x_t and t and returns model logits
kwargs: other arguments that are used to control decoding.
Returns:
return a new decoder_out namedtuple.
""" Ein DiscreteDiffusion -Modell kann durch Konfiguration Folgendes instanziiert werden:
--num-diffusion-timesteps <int> Gibt die gesamte Anzahl der Diffusionszeitschritte an (Standard: 50)--diffusion-type <str> Gibt den Diffusionsmodelltyp an (Auswahl: {absorbing, multinomial, reparam-absorbing, reparam-multinomial} )--noise-scheduler-type <str> Gibt den Rauschplan nur in Vanille/Reparam-Multinomialdiffusion an (Typische Auswahlmöglichkeiten: {linear, cosine} ; Standard: cosine )q_sample() , einschließlich--q-sample-mode <str> Gibt die Stichprobenstrategie an (Auswahl: {default, coupled, multi-step, multi-sample} ; Standard: default ). Wir bieten verschiedene Auswahlmöglichkeiten für die Probenahme von default : Eine einzelne Probe wird als gezeichnet als multi-step : Proben Sie zwei Zeitschritte in der IID multi-sample : Proben Sie zwei IID-Proben coupled : auch als konditioniertes Training bezeichnet, das in Anhang F des Papiers detailliert ist. Dies beginnt mit der Probenahme von zwei IID -Zeitschritten coupled Probenahmemodus signifikante Verbesserungen sowohl für Vanille -Multinomial-/absorbierende Diffusionen verleiht, aber die Verstärkung ist in reparameterisierten Varianten nicht konsequent wesentlich.--not-diffusing-special-sym Zeigt an, ob während des Diffusionsprozesses spezielle Symbole einbezogen werden sollen (Standard: Falsch)compute_losses() , einschließlich--reweighting-type <str> Gibt das Wiederbelebungsschema in unserer reparametrisierten Familie an (Auswahl: {linear, reciprocal, none} ; Standard: linear )--label-smoothing <float> Gibt die Rate der Etikettenglättung an (Standardeinstellung: 0.1)sample_step() , einschließlich--argmax-decoding zeigt an --temperature <float> Gibt die Temperatur an --decoding-strategy <str> Gibt die Verwendung von Vanille ( default ) / Reparameterized ( reparam-<options> an; siehe die Details) Decodierungsstrategie (Auswahl: {default, reparam-<options>} ; Standard: Standard: default )--load-ema-weights geben an, ob die EMA-Modellgewichte für die Generation geladen werden sollen (Standard: Falsch)--iter-decode-max-iter <int> Gibt die maximale Anzahl von Zeitschritten für die Dekodierung an (Standardeinstellung: 10)--iter-decode-with-beam <int> Gibt die Strahlgröße für die Dekodierung mehrerer Sequenzen mit unterschiedlichen Längen parallel an (Standardeinstellung: 1)--iter-decode-force-max-iter zeigt an, dass die iterative Dekodierung die angegebene Anzahl von Iterationen ausführen und nicht beendet. Empfohlen, diese Flagge auf True zu setzen.Eine umfassendere Liste von Argumenten finden Sie hier.
Durch das Bestehen --decoding-strategy default wird das Vanille-Stichprobenschema (spezifisch für jeden diskreten Diffusionsprozess) verwendet.
Ein fortschrittlicherer Dekodierungsansatz kann durch Passieren aufgerufen werden --decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule> . Dieser Ansatz basiert auf der vorgeschlagenen Reparametrisierung in unserem Artikel und ermöglicht effektivere Decodierungsverfahren. Die Optionen geben den Dekodierungsalgorithmus über
<conditioning-of-v> : uncond oder cond (Standard uncond ): Ob die Routing-Variable generiert werden soll <topk_mode> : stochastic<float> oder deterministic ( deterministic ): Ob Sie stochastische oder deterministische Top-k $ -Selection verwenden sollen. Der Float-Wert in stochastic<float> Gibt den Grad der Zufälligkeit in der stochastischen Top-K $ -Selection an.<schedule> : linear oder cosine (Standard cosine ): Der Zeitplan für Weitere Informationen zu den Optionen finden Sie in der Implementierung.
Weitere Informationen finden Sie in den folgenden Skripten.
Notiz
- Beachten Sie, dass alle in dieser Arbeit berücksichtigten Aufgaben auf den Originaldaten arbeiten und keine Wissensdestillation (KD) übernehmen.
Wir folgen der Standardvorverarbeitung in Fairseq/Beispielen zur Erstellung der binärisierten Daten:
# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en
--trainpref $TEXT /train --validpref $TEXT /valid --testpref $TEXT /test
--destdir data-bin/iwslt14.tokenized.de-en
--workers 20Wir verwenden die in fairseq/Beispiele veröffentlichten Daten, um den Datensatz vorzubereiten:
wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip
TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang de
--trainpref $TEXT /train.en-de --validpref $TEXT /valid.en-de --testpref $TEXT /test.en-de
--destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0
--workers 20Für diesen Datensatz verwenden wir die Rohdaten wmt16.tar.gz als in diesem Repository vorverarbeiteten.
tar xzvf wmt16.tar.gz
TEXT=wmt16/en-ro
# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT /train/corpus.bpe.en $TEXT /train.bpe.en
mv $TEXT /train/corpus.bpe.ro $TEXT /train.bpe.ro
mv $TEXT /dev/dev.bpe.en $TEXT /dev.bpe.en
mv $TEXT /dev/dev.bpe.ro $TEXT /dev.bpe.ro
mv $TEXT /test/test.bpe.en $TEXT /test.bpe.en
mv $TEXT /test/test.bpe.ro $TEXT /test.bpe.ro
# binarize the data
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang ro
--trainpref $TEXT /train.bpe --validpref $TEXT /dev.bpe --testpref $TEXT /test.bpe
--destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0
--workers 20 Wir kommen zuerst in den fairseq -Ordner und führen dann die folgenden Befehle aus, um die Modelle zu trainieren.
# ####### training scripts for IWSLT'14 , WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14 and 2 GPUs for WMT'16 datasets respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linearNotiz
-s <str>wird verwendet, um den Namen des Experiments anzugeben.- Wir könnten benutzerdefinierte Argumente bestehen, die für das Training spezifisch sein könnten, indem wir sie nach
-e Trueanhängen.
Die Bewertungspipeline wird mit experiments/mt_generate.sh behandelt. Das Skript generiert die Übersetzungsergebnisse und bewertet den BLEU -Wert.
# ########## IWLS'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c < checkpoint_path > -d < iwslt/wmt14/wmt 16> Argumente:
-a : Ob Sie durchschnittlich mehrere Checkpoints durchschnittlich sind-c : Zeigt den Ort des Checkpoint an. Wenn -a false (nicht zu durchschnittlichen Kontrollpunkten), übergeben Sie den Checkpoint -Pfad ; Wenn -a true , geben Sie das Verzeichnis über, das mehrere Kontrollpunkte an verschiedenen Trainingsschritten zur Mittelung speichert.-d : Der DatensatznameWir bieten auch die Kontrollpunkte unserer geschulten Modelle an.
| Datensatz | Modell | Checkpoint -Link |
|---|---|---|
| IWSLT'14 | Multinomial | Link |
| IWSLT'14 | Absorbierend | Link |
| IWSLT'14 | Reparam-Multinom | Link |
| IWSLT'14 | Reparam-Absorbing | Link |
| Wmt'14 | Multinomial | Link |
| Wmt'14 | Absorbierend | Link |
| Wmt'14 | Reparam-Multinom | Link |
| Wmt'14 | Reparam-Absorbing | Link |
| Wmt'16 | Multinomial | Link |
| Wmt'16 | Absorbierend | Link |
| Wmt'16 | Reparam-Multinom | Link |
| Wmt'16 | Reparam-Absorbing | Link |
Wir folgen dem experimentellen Setup in Diffuseq zur Erzeugung von Fragen und zur Paraphrasierung von Aufgaben.
Die Rohdaten dieser beiden Aufgaben können aus dem ursprünglichen Diffuseq -Repository abgerufen werden. Anschließend binarisieren wir die Daten über das bereitgestellte Skript.
# put the raw data in the directory ``diffuseq_data/QG``
# Preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG
# put the raw data in the directory ``diffuseq_data/QQP``
# Preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP # QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d < qqp/qg > -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d < qqp/qg > -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear Wir folgen den Generations- und Evaluierungsprotokollen wie in Diffuseq genau, um einen Kopf-an-Kopf-Vergleich zu gewährleisten. Die gesamte Pipeline wird in fairseq/diffusion_mt/scripts/decode_diffuseq.py und fairseq/diffusion_mt/scripts/eval_diffuseq.py neu implementiert, um mit FairSeq kompatibel zu sein. Führen Sie die folgenden Befehle aus:
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c < checkpoint_path > -d < qqp/qg > Argumente:
-a : Ob Sie durchschnittlich mehrere Checkpoints durchschnittlich sind-b : Ob Sie mehrere Proben für die MBR -Dekodierung verwenden sollen-c : Zeigt den Ort des Checkpoint an. Wenn -a false (nicht zu durchschnittlichen Kontrollpunkten), übergeben Sie den Checkpoint -Pfad ; Wenn -a true , geben Sie das Verzeichnis über, das mehrere Kontrollpunkte an verschiedenen Trainingsschritten zur Mittelung speichert.-d : Der DatensatznameWir bieten auch die Kontrollpunkte unserer geschulten Modelle an.
| Datensatz | Modell | Checkpoint -Link |
|---|---|---|
| Qg | Multinomial | Link |
| Qg | Absorbierend | Link |
| Qg | Reparam-Multinom | Link |
| Qg | Reparam-Absorbing | Link |
| QQP | Multinomial | Link |
| QQP | Absorbierend | Link |
| QQP | Reparam-Multinom | Link |
| QQP | Reparam-Absorbing | Link |
@article { zheng2023rdm ,
title = { A Reparameterized Discrete Diffusion Model for Text Generation } ,
author = { Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng } ,
journal = { arXiv preprint arXiv:2302.05737 } ,
year = { 2023 }
}