このリポジトリには、テキスト生成のための紙の公式実装が修復された離散拡散モデルが含まれています。
コードベースはFairSeqで実装されています。依存関係をインストールするには、次のコマンドを実行します(仮想環境で推奨)。
pip install -r requirements.txt
# install our package of discrete diffusion models
pip install -e discrete_diffusion
# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..注環境は、Python 3.8.10、Pytorch 1.10.0/1.12.0、およびCuda 11.3でテストされています。また、FAIRSEQのフォークは、元のコードベースのいくつかのファイルを変更することに注意してください。 FairSeqのより最近のバージョンを使用すると、予期しない依存関係の競合につながる可能性があります。
一般的な使用のために、自己完結型ライブラリdiscrete_diffusionに離散拡散モデルを実装します。ライブラリは、で構成されるさまざまな典型的な離散拡散モデルの実装を提供します
(Vanilla/Reparameterized) multinomial diffusion :トークンシーケンスにuniformノイズを注入する拡散プロセス。バニラの多項拡散の実装は、元の論文のコードベースに密接に従います。(Vanilla/Reparameterized) absorbing diffusion :D3PMペーパーで説明されているように、シーケンス内のトークンがmasking状態に吸収される可能性がある拡散プロセス。これらの拡散モデルは、外部用途を可能にする同じインターフェイスのセットを共有します。特に、それらはDiscreteDiffusionクラスのサブクラスとして定義され、次の形式を取得します。
class DiscreteDiffusion ( nn . Module ):
"""
The parent class for discrete denoising diffusion probabilistic models.
It supports the following methods:
- q_sample()
Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
- compute_losses()
Compute the loss L_t = KL(q||p) at t-th time step.
- sample_step()
Sample x_t ~ p(x_{t-1} | x_t, x_0) at t-th time step.
"""
def __init__ ( self , num_timesteps ):
super (). __init__ ()
self . num_timesteps = num_timesteps
def q_sample ( self , x_0 , t , ** kwargs ):
"""
Sample from q(x_t | x_0), which is used as the model inputs.
Args:
x_0: token ids with shape [B, N]
t: current time step, tensor with shape [B]
Returns:
return a dict of relevant outputs including x_t.
"""
def compute_losses ( self , inputs , ** kwargs ):
"""
Compute the loss objective KL(q||p) to train our generative process.
Args:
inputs: a dict that contains input types specific to different diffusion processes, containing
- x_t: token ids with shape [B, N]
- t: scalar timesteps, with shape [B]
Returns:
possibly return a dict of relevant outputs, including the loss used for training.
"""
def sample_step ( self , decoder_out , denoising_fn , ** kwargs ):
"""
Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
Args:
decoder_out: a namedtuple that contains decoding info, including
- x_t: token ids with shape [B, N]
- t: scalar timesteps
- max_steps: the maximum number of decoding steps
- ...
denoising_fn: a function that takes in x_t and t and returns model logits
kwargs: other arguments that are used to control decoding.
Returns:
return a new decoder_out namedtuple.
""" DiscreteDiffusionモデルは、以下を構成することでインスタンス化できます。
--num-diffusion-timesteps <int>拡散時間ステップの総数を指定します(デフォルト:50)--diffusion-type <str>拡散モデルタイプを指定します(選択: {absorbing, multinomial, reparam-absorbing, reparam-multinomial} ))--noise-scheduler-type <str>ノイズスケジュールをバニラ/レパラムの多項拡散でのみ指定します(典型的な選択: {linear, cosine} ;デフォルト: cosine )q_sample()のフォワードサンプリングルーチンに固有の重要な引数。--q-sample-mode <str>サンプリング戦略を指定します(choices: {default, coupled, multi-step, multi-sample} ;デフォルト: default )。サンプリングのためのさまざまな選択肢を提供しますdefault :単一のサンプルが描画されますmulti-step :2つのIIDタイムステップをサンプリングしますmulti-sample :2つのIIDサンプルをサンプリングしますcoupled :条件付きトレーニングとも呼ばれます。これは、論文の付録Fに詳述されています。これは、2つのIIDタイムステップをサンプリングすることから始まりますcoupledサンプリングモードは、バニラ多項/吸収拡散の両方に大幅な改善をもたらすことがわかりましたが、レパラメーター化されたバリアントではゲインは一貫して実質的ではありません。--not-diffusing-special-sym (デフォルト:false)compute_losses()の損失客観的計算に固有の重要な引数--reweighting-type <str>私たちの再送信ファミリーの再重み付けスキームを指定します(選択: {linear, reciprocal, none} ;デフォルト: linear )--label-smoothing <float>ラベルスムージングの速度を指定します(デフォルト:0.1)sample_step()のデコードルーチンに固有の重要な引数--argmax-decoding除去された変圧器出力にargmaxデコードを使用するかどうかを示します--temperature <float>温度を指定します--decoding-strategy <str>バニラ( default ) / reparameterized( {default, reparam-<options>} reparam-<options> ;詳細を参照)の使用を指定しますdefault--load-ema-weights生成にEMAモデルの重みをロードするかどうかを示します(デフォルト:false)--iter-decode-max-iter <int>デコード用のタイムステップの最大数を指定します(デフォルト:10)--iter-decode-with-beam <int>並列の異なる長さの複数のシーケンスを解読するためのビームサイズを指定します(デフォルト:1)--iter-decode-force-max-iter反復デコードが指定された数の反復数を実行し、終了しないことを示します。このフラグをtrueに設定することをお勧めします。より包括的な議論リストについては、こちらをご覧ください。
--decoding-strategy defaultを渡すことにより、バニラサンプリングスキーム(各離散拡散プロセスに固有)が使用されます。
より高度なデコードアプローチは--decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule>を通過することで呼び出すことができます。このアプローチは、私たちの論文で提案されている修復に基づいており、より効果的なデコード手順を可能にします。オプションは、介してデコードアルゴリズムを指定します
<conditioning-of-v> : uncondまたはcond (デフォルトのuncond ):ルーティング変数を生成するかどうか<topk_mode> : stochastic<float>またはdeterministic (デフォルトのdeterministic ):確率的または決定論的なTop $ k $の選択を使用するかどうか。 stochastic<float>のフロート値は、確率上のトップ$ k $選択のランダム性の程度を指定します。<schedule> : linearまたはcosine (デフォルトのcosine ):のスケジュールオプションの詳細については、実装を参照してください。
詳細については、以下のスクリプトをご覧ください。
注記
- この作業で考慮されるすべてのタスクは、元のデータで動作し、知識蒸留(KD)を採用しないことに注意してください。
FairSeq/Examplesの標準前処理に従って、二等式のデータを準備します。
# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en
--trainpref $TEXT /train --validpref $TEXT /valid --testpref $TEXT /test
--destdir data-bin/iwslt14.tokenized.de-en
--workers 20FairSeq/Examplesでリリースされたデータを使用して、データセットを準備します。
wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip
TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang de
--trainpref $TEXT /train.en-de --validpref $TEXT /valid.en-de --testpref $TEXT /test.en-de
--destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0
--workers 20このデータセットでは、このリポジトリで前処理されたRAWデータWMT16.tar.gzを使用します。
tar xzvf wmt16.tar.gz
TEXT=wmt16/en-ro
# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT /train/corpus.bpe.en $TEXT /train.bpe.en
mv $TEXT /train/corpus.bpe.ro $TEXT /train.bpe.ro
mv $TEXT /dev/dev.bpe.en $TEXT /dev.bpe.en
mv $TEXT /dev/dev.bpe.ro $TEXT /dev.bpe.ro
mv $TEXT /test/test.bpe.en $TEXT /test.bpe.en
mv $TEXT /test/test.bpe.ro $TEXT /test.bpe.ro
# binarize the data
fairseq-preprocess --joined-dictionary
--source-lang en --target-lang ro
--trainpref $TEXT /train.bpe --validpref $TEXT /dev.bpe --testpref $TEXT /test.bpe
--destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0
--workers 20最初にfairseqフォルダーに入り、次に次のコマンドを実行してモデルをトレーニングします。
# ####### training scripts for IWSLT'14 , WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14 and 2 GPUs for WMT'16 datasets respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d < iwslt/wmt14/wmt 16> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d < iwslt/wmt14/wmt 16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear注記
-s <str>、実験の名前を指定するために使用されます。-e Trueの後にそれらを追加することにより、トレーニングに固有のカスタム引数を渡すことができます。
評価パイプラインはexperiments/mt_generate.shによって処理されます。スクリプトは翻訳結果を生成し、BLEUスコアを評価します。
# ########## IWLS'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c < checkpoint_path > -d < iwslt/wmt14/wmt 16> 議論:
-a :複数のチェックポイントを平均するかどうか-c :チェックポイントの位置を示します。 -a false (平均チェックポイントではない)の場合、チェックポイントパスを渡します。 -a trueの場合、平均化のために異なるトレーニング手順で複数のチェックポイントを保存するディレクトリを渡します。-d :データセット名また、訓練されたモデルのチェックポイントも提供します。
| データセット | モデル | チェックポイントリンク |
|---|---|---|
| IWSLT'14 | 多項 | リンク |
| IWSLT'14 | 吸収 | リンク |
| IWSLT'14 | Reparam-Multinomial | リンク |
| IWSLT'14 | レパラム吸収 | リンク |
| WMT'14 | 多項 | リンク |
| WMT'14 | 吸収 | リンク |
| WMT'14 | Reparam-Multinomial | リンク |
| WMT'14 | レパラム吸収 | リンク |
| WMT'16 | 多項 | リンク |
| WMT'16 | 吸収 | リンク |
| WMT'16 | Reparam-Multinomial | リンク |
| WMT'16 | レパラム吸収 | リンク |
質問生成と言い換えタスクのために、diffuseqの実験セットアップに従います。
これらの2つのタスクの生データは、元のDiffuseQリポジトリから取得できます。次に、提供されたスクリプトを介してデータを双方向させます。
# put the raw data in the directory ``diffuseq_data/QG``
# Preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG
# put the raw data in the directory ``diffuseq_data/QQP``
# Preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP # QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d < qqp/qg > -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d < qqp/qg > -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d < qqp/qg > -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear diffuseqのように、生成および評価プロトコルに密接に従って、直接比較を確保します。パイプライン全体は、 fairseq/diffusion_mt/scripts/decode_diffuseq.pyおよびfairseq/diffusion_mt/scripts/eval_diffuseq.pyで再実装されています。次のコマンドを実行します。
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c < checkpoint_path > -d < qqp/qg > 議論:
-a :複数のチェックポイントを平均するかどうか-b :MBRデコードに複数のサンプルを使用するかどうか-c :チェックポイントの位置を示します。 -a false (平均チェックポイントではない)の場合、チェックポイントパスを渡します。 -a trueの場合、平均化のために異なるトレーニング手順で複数のチェックポイントを保存するディレクトリを渡します。-d :データセット名また、訓練されたモデルのチェックポイントも提供します。
| データセット | モデル | チェックポイントリンク |
|---|---|---|
| QG | 多項 | リンク |
| QG | 吸収 | リンク |
| QG | Reparam-Multinomial | リンク |
| QG | レパラム吸収 | リンク |
| QQP | 多項 | リンク |
| QQP | 吸収 | リンク |
| QQP | Reparam-Multinomial | リンク |
| QQP | レパラム吸収 | リンク |
@article { zheng2023rdm ,
title = { A Reparameterized Discrete Diffusion Model for Text Generation } ,
author = { Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng } ,
journal = { arXiv preprint arXiv:2302.05737 } ,
year = { 2023 }
}