Este repositorio contiene la implementación oficial del documento un modelo de difusión discreta reparameterizada para la generación de texto.
La base de código se implementa con Fairseq. Para instalar las dependencias, ejecute (recomendado en un entorno virtual) los siguientes comandos:
pip install -r requirements.txt
# install our package of discrete diffusion models
pip install -e discrete_diffusion
# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..Nota El entorno se prueba con Python 3.8.10, Pytorch 1.10.0/1.12.0 y CUDA 11.3. También tenga en cuenta que nuestra bifurcación de FairSeq modifica varios archivos en la base de código original; El uso de versiones más recientes de Fairseq podría conducir a conflictos de dependencia inesperados.
Implementamos modelos de difusión discretos en una biblioteca autónoma discrete_diffusion para uso general. La biblioteca proporciona implementaciones de varios modelos de difusión discretos típicos, que consisten en
(Vanilla/Reparameterized) multinomial diffusion : procesos de difusión que inyectan ruido uniform en la secuencia del token. La implementación de la difusión multinomial de vainilla sigue de cerca la base de código del documento original;(Vanilla/Reparameterized) absorbing diffusion : procesos de difusión donde los tokens dentro de la secuencia podrían ser absorbidos al estado de masking , como se describe en el papel de la D3PM. Estos modelos de difusión comparten el mismo conjunto de interfaces que permiten usos externos. En particular, se definen como subclases de clase DiscreteDiffusion , tomando la siguiente forma:
class DiscreteDiffusion ( nn . Module ):
"""
The parent class for discrete denoising diffusion probabilistic models.
It supports the following methods:
- q_sample()
Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
- compute_losses()
Compute the loss L_t = KL(q||p) at t-th time step.
- sample_step()
Sample x_t ~ p(x_{t-1} | x_t, x_0) at t-th time step.
"""
def __init__ ( self , num_timesteps ):
super (). __init__ ()
self . num_timesteps = num_timesteps
def q_sample ( self , x_0 , t , ** kwargs ):
"""
Sample from q(x_t | x_0), which is used as the model inputs.
Args:
x_0: token ids with shape [B, N]
t: current time step, tensor with shape [B]
Returns:
return a dict of relevant outputs including x_t.
"""
def compute_losses ( self , inputs , ** kwargs ):
"""
Compute the loss objective KL(q||p) to train our generative process.
Args:
inputs: a dict that contains input types specific to different diffusion processes, containing
- x_t: token ids with shape [B, N]
- t: scalar timesteps, with shape [B]
Returns:
possibly return a dict of relevant outputs, including the loss used for training.
"""
def sample_step ( self , decoder_out , denoising_fn , ** kwargs ):
"""
Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
Args:
decoder_out: a namedtuple that contains decoding info, including
- x_t: token ids with shape [B, N]
- t: scalar timesteps
- max_steps: the maximum number of decoding steps
- ...
denoising_fn: a function that takes in x_t and t and returns model logits
kwargs: other arguments that are used to control decoding.
Returns:
return a new decoder_out namedtuple.
""" Se puede instanciar un modelo DiscreteDiffusion configurando lo siguiente:
--num-diffusion-timesteps <int> Especifica el número completo de pasos de tiempo de difusión (predeterminado: 50)--diffusion-type <str> Especifica el tipo de modelo de difusión (opciones: {absorbing, multinomial, reparam-absorbing, reparam-multinomial} )--noise-scheduler-type <str> especifica el programa de ruido solo en difusión multinomial de vainilla/reparam (opciones típicas: {linear, cosine} ; predeterminado: cosine )q_sample() , incluido--q-sample-mode <str> Especifica la estrategia de muestreo (opciones: {default, coupled, multi-step, multi-sample} ; predeterminado: default ). Brindamos varias opciones para el muestreo de default : se dibuja una sola muestra como multi-step : muestra dos pasos de tiempo IID multi-sample : muestra dos muestras de IID coupled : también conocido como entrenamiento condicionado, que se detalla en el Apéndice F del documento. Esto comienza con el muestreo de dos pasos de tiempo IID coupled trae mejoras significativas tanto para la difusión multinomial/absorbente de vainilla, pero la ganancia no es consistentemente sustancial en las variantes reparameterizadas.--not-diffusing-special-sym indica si se debe incluir símbolos especiales durante el proceso de difusión (predeterminado: falso)compute_losses() , incluido--reweighting-type <str> Especifica el esquema de rewe-weighting en nuestra familia reparameterizada (opciones: {linear, reciprocal, none} ; predeterminado: linear )--label-smoothing <float> Especifica la velocidad de suavizado de la etiqueta (predeterminado: 0.1)sample_step() , incluido--argmax-decoding indica si usar ArgMax Decoding para la salida del transformador desocionado --temperature <float> especifica la temperatura --decoding-strategy <str> especifica el uso de vainilla ( default ) / reparameterized ( reparam-<options> ; consulte la estrategia de decodificación de detalles (opciones: {default, reparam-<options>} ; predeterminado: default )--load-ema-weights indica si cargar los pesos del modelo EMA para la generación (predeterminado: falso)--iter-decode-max-iter <int> Especifica el número máximo de times de decodificación (predeterminado: 10)--iter-decode-with-beam <int> Especifica el tamaño del haz para decodificar múltiples secuencias con diferentes longitudes en paralelo (predeterminado: 1)--iter-decode-force-max-iter indica que la decodificación iterativa debe ejecutar el número especificado de iteraciones y no salir. Recomendado para establecer esta bandera en verdad.Vea aquí para obtener una lista más completa de argumentos.
Al pasar --decoding-strategy default , se utiliza el esquema de muestreo de vainilla (específico para cada proceso de difusión discreta).
Se puede invocar un enfoque de decodificación más avanzado pasando --decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule> . Este enfoque se basa en la reparametrización propuesta en nuestro documento y permite procedimientos de decodificación más efectivos. Las opciones especifican el algoritmo de decodificación a través de
<conditioning-of-v> : uncond o cond ( uncond predeterminados): si se debe generar la variable de enrutamiento <topk_mode> : stochastic<float> o deterministic ( deterministic predeterminado): si se debe usar selección estocástica o determinista Top- $ K $. El valor flotante en stochastic<float> especifica el grado de aleatoriedad en la selección estocástica superior $ k $;<schedule> : linear o cosine ( cosine predeterminado): el cronograma para Consulte la implementación para obtener más detalles sobre las opciones.
Consulte los scripts a continuación para más detalles.
Nota
- Tenga en cuenta que todas las tareas consideradas en este trabajo operan en los datos originales y no adoptan la destilación de conocimiento (KD).
Seguimos el preprocesamiento estándar en FairSeq/Ejemplos para preparar los datos binarizados:
# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en
--trainpref $TEXT /train --validpref $TEXT /valid --testpref $TEXT /test
--destdir data-bin/iwslt14.tokenized.de-en
--workers 20Utilizamos los datos publicados en FairSeq/Ejemplos para preparar el conjunto de datos:
wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip
TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang de
--trainpref $TEXT /train.en-de --validpref $TEXT /valid.en-de --testpref $TEXT /test.en-de
--destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0
--workers 20Para este conjunto de datos, utilizamos los datos sin procesar WMT16.tar.gz como procesado en este repositorio.
tar xzvf wmt16.tar.gz
TEXT=wmt16/en-ro
# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT /train/corpus.bpe.en $TEXT /train.bpe.en
mv $TEXT /train/corpus.bpe.ro $TEXT /train.bpe.ro
mv $TEXT /dev/dev.bpe.en $TEXT /dev.bpe.en
mv $TEXT /dev/dev.bpe.ro $TEXT /dev.bpe.ro
mv $TEXT /test/test.bpe.en $TEXT /test.bpe.en
mv $TEXT /test/test.bpe.ro $TEXT /test.bpe.ro
# binarize the data
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang ro
--trainpref $TEXT /train.bpe --validpref $TEXT /dev.bpe --testpref $TEXT /test.bpe
--destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0
--workers 20 Primero entramos en la carpeta fairseq y luego ejecutamos los siguientes comandos para entrenar los modelos.
# ####### training scripts for IWSLT'14 , WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14 and 2 GPUs for WMT'16 datasets respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linearNota
-s <str>se usa para especificar el nombre del experimento.- Podríamos aprobar argumentos personalizados que podrían ser específicos para la capacitación al agregarlos después de
-e True.
La tubería de evaluación se maneja por experiments/mt_generate.sh . El script generará los resultados de la traducción y evaluará la puntuación BLUU.
# ########## IWLS'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c < checkpoint_path > -d < iwslt/wmt14/wmt 16> Argumentos:
-a : si promedia múltiples puntos de control-c : indica la ubicación del punto de control. Si -a false (no para promedio de puntos de control), pase la ruta del punto de control; Si -a true , pase el directorio que almacena múltiples puntos de control en diferentes pasos de entrenamiento para promediar.-d : el nombre del conjunto de datosTambién proporcionamos los puntos de control de nuestros modelos capacitados.
| Conjunto de datos | Modelo | Enlace de punto de control |
|---|---|---|
| Iwslt'14 | Multinomial | enlace |
| Iwslt'14 | Absorbente | enlace |
| Iwslt'14 | Reparador-multinomial | enlace |
| Iwslt'14 | Reparam-absorbente | enlace |
| WMT'14 | Multinomial | enlace |
| WMT'14 | Absorbente | enlace |
| WMT'14 | Reparador-multinomial | enlace |
| WMT'14 | Reparam-absorbente | enlace |
| WMT'16 | Multinomial | enlace |
| WMT'16 | Absorbente | enlace |
| WMT'16 | Reparador-multinomial | enlace |
| WMT'16 | Reparam-absorbente | enlace |
Seguimos la configuración experimental en DiffuseQ para la generación de preguntas y las tareas de parafraseo .
Los datos sin procesar de estas dos tareas se pueden obtener del repositorio original de Diffuseq. Luego binarizamos los datos a través del script proporcionado.
# put the raw data in the directory ``diffuseq_data/QG``
# Preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG
# put the raw data in the directory ``diffuseq_data/QQP``
# Preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP # QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d < qqp/qg > -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d < qqp/qg > -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear Seguimos de cerca los protocolos de generación y evaluación como en Diffuseq para garantizar una comparación cara a cara. Toda la tubería se vuelve a implementar en fairseq/diffusion_mt/scripts/decode_diffuseq.py y fairseq/diffusion_mt/scripts/eval_diffuseq.py respectivamente para ser compatible con Fairseq. Ejecute los siguientes comandos:
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c < checkpoint_path > -d < qqp/qg > Argumentos:
-a : si promedia múltiples puntos de control-b : si usar múltiples muestras para la decodificación de MBR-c : indica la ubicación del punto de control. Si -a false (no para promedio de puntos de control), pase la ruta del punto de control; Si -a true , pase el directorio que almacena múltiples puntos de control en diferentes pasos de entrenamiento para promediar.-d : el nombre del conjunto de datosTambién proporcionamos los puntos de control de nuestros modelos capacitados.
| Conjunto de datos | Modelo | Enlace de punto de control |
|---|---|---|
| QG | Multinomial | enlace |
| QG | Absorbente | enlace |
| QG | Reparador-multinomial | enlace |
| QG | Reparam-absorbente | enlace |
| QQP | Multinomial | enlace |
| QQP | Absorbente | enlace |
| QQP | Reparador-multinomial | enlace |
| QQP | Reparam-absorbente | enlace |
@article { zheng2023rdm ,
title = { A Reparameterized Discrete Diffusion Model for Text Generation } ,
author = { Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng } ,
journal = { arXiv preprint arXiv:2302.05737 } ,
year = { 2023 }
}