このレポは、確率的微分方程式を介したペーパースコアベースの生成モデリングのためのPytorchの実装が含まれています
Yang Song、Jascha Sohl-Dickstein、Diederik P. Kingma、Abhishek Kumar、Stefano Ermon、およびBen Poole
確率的微分方程式(SDE)のレンズを介して、スコアベースの生成モデルに関する以前の作業を一般化および改善する統一されたフレームワークを提案します。特に、SDEによって記述された連続時間の確率プロセスを使用して、データを単純なノイズ分布に変換できます。このSDEは、スコアマッチングで推定できる各中間時間ステップでの限界分布のスコアを知っている場合、サンプル生成のために逆にすることができます。基本的なアイデアは、以下の図に記載されています。

私たちの作業により、既存のアプローチ、新しいサンプリングアルゴリズム、正確な尤度計算、ユニークに識別可能なエンコード、潜在的なコード操作、および新しい条件付き生成能力(クラス条件の生成、包括的生成、インデッティング、色付けを含むが、それに限定されない)をスコアベースの生成モデルの家族にもたらすことができます。
すべてを組み合わせて、CIFAR-10の無条件発電では2.20のFIDとインセプションスコア9.89を達成しました。さらに、均一に不安定なCIFAR-10画像で2.99ビット/DIMの尤度値を取得しました。

私たちの論文のNCSN ++およびDDPM ++モデルとは別に、このコードベースは、データ分布の勾配、トレーニングスコアベースの生成モデルの改善された技術からのNCSNV2 、および拡散性確率モデルからのDDPMのDDPMを推定することにより、生成モデリングからのNCSNを含む、1つの場所で多くの以前のスコアベースのモデルを再刻みます。
新しいモデルのトレーニングをサポートし、サンプルの品質と既存のモデルの可能性を評価します。コードを慎重に設計し、モジュール式であり、新しいSDE、予測子、または補正装置に簡単に拡張できるようにしました。
ほとんどのモデルも現在入手可能ですか?スコアデーブパイプラインを介してディフューザーとアクセス可能。
Diffusersを使用すると、わずか数行のコードでPytorchのSDEベースのモデルをテストできます。
次のようにディフューザーをインストールできます。
pip install diffusers torch accelerate
そして、コードの数行でモデルを試してみてください。
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )より多くのモデルをハブに直接見つけることができます。
ここでJAXの実装を見つけてください。これは、事前に訓練された分類器を使用したクラス条件の生成をサポートし、先制後に評価プロセスを再開します。
一般に、このPytorchバージョンは消費量が少なくなりますが、Jaxよりも遅くなります。 NCSN ++のトレーニングに関するベンチマークは次のとおりです。 ve sdeを使用したモデル。ハードウェアは4x Nvidia Tesla V100 GPU(32GB)です
| フレームワーク | 時間(ステップごとに秒) | 合計でのメモリ使用(GB) |
|---|---|---|
| Pytorch | 0.56 | 20.6 |
jax( n_jitted_steps=1 ) | 0.30 | 29.7 |
jax( n_jitted_steps=5 ) | 0.20 | 74.8 |
以下を実行して、コードに必要なPythonパッケージのサブセットをインストールします
pip install -r requirements.txtCIFAR-10の統計ファイルを提供します。 cifar10_stats.npzをダウンロードして、 assets/stats/に保存できます。新しいデータセットのこの統計ファイルを計算する方法については、#5をご覧ください。
main.pyを介してモデルをトレーニングおよび評価します。
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config 、構成ファイルへのパスです。規定された構成ファイルは、 configs/で提供されます。それらはml_collectionsに従ってフォーマットされており、非常に自明である必要があります。
構成ファイルの命名規則:構成ファイルのパスは、次の寸法の組み合わせです。
cifar10の1つ、 celeba 、 celebahq 、 celebahq_256 、 ffhq_256 、 celebahq 、 ffhq 。ncsn 、 ncsnv2 、 ncsnpp 、 ddpm 、 ddpmppの1つ。 workdir 、チェックポイント、サンプル、評価結果など、1つの実験のすべてのアーティファクトを保存するパスです。
eval_folder 、先制予防、画像サンプル、定量的結果のnumpyダンプなど、評価プロセスのすべてのアーティファクトを保存するworkdirのサブフォルダーの名前です。
modeは「トレーニング」または「評価」のいずれかです。 「トレーニング」に設定すると、新しいモデルのトレーニングを開始するか、メタチェックポイント(クラウド環境での先制後の実行を再開するため)がworkdir/checkpoints-metaに存在する場合、古いモデルのトレーニングを再開します。 「評価」に設定すると、次の任意の組み合わせを行うことができます
テスト /検証データセットの損失関数を評価します。
固定数のサンプルを生成し、その開始スコア、FID、またはKIDを計算します。評価の前に、統計ファイルはすでにダウンロード/計算され、 assets/statsに保存されている必要があります。
トレーニングまたはテストデータセットのログ尤度を計算します。
これらの機能は、 ml_collectionsパッケージのコマンドラインサポートを使用して、構成ファイルを使用して、またはより便利に構成できます。たとえば、サンプルを生成してサンプル品質を評価するには、 --config.eval.enable_samplingフラグを供給します。ログリケリを計算するには、 --config.eval.enable_bpdフラグを供給し、 --config.eval.dataset=train/testを指定して、トレーニングまたはテストデータセットの尤度を計算するかどうかを示します。
sde_lib.SDE抽象クラスに固有の抽象的なメソッドを実装します。 discretize()メソッドはオプションであり、デフォルトはEuler-Maruyama離散化です。既存のサンプリング方法と尤度計算は、この新しいSDEに対して自動的に機能します。sampling.Predictor Abstractクラスに固有のもの、 update_fn抽象メソッドを実装し、 @register_predictorにその名前を登録します。新しい予測因子は、 sampling.get_pc_samplerでは、Predictor-Correctorサンプリングのために、およびcontrollable_generation.pyの他のすべての制御可能な生成方法で直接使用できます。sampling.Corrector Abstractクラスに固有のもの、 update_fn抽象メソッドを実装し、 @register_correctorにその名前を登録します。新しい修正器は、 sampling.get_pc_sampler 、およびcontrollable_generation.pyの他のすべての制御可能な生成方法で直接使用できます。 すべてのチェックポイントは、このGoogleドライブで提供されています。
手順:一部のモデルの2つのチェックポイントを見つけることができます。最初のチェックポイント(小さい数字)は、論文の表3でFIDスコアを報告したものです(FIDにも対応し、下の表の列です)。 2番目のチェックポイント(数字が大きい)は、論文の表2(FID(ODE)およびNNL(Bits/DIM)列)にあるブラックボックスオードサンプラーの尤度値とFIDを報告したものです。前者は、トレーニングの過程で最小のFIDに対応しています(50kの反復ごと)。後者は、トレーニング中の最後のチェックポイントです。
Googleのポリシーに従って、元のCelebaとCeleba-HQのチェックポイントをリリースすることはできません。とはいえ、私はFFHQ 1024PX、FFHQ 256PX、および個人リソースを使用してCelleba-HQ 256PXでモデルを再訓練しました。
ここにチェックポイントの詳細なリストとその結果が論文で報告されています。 FID(ODE)は、確率フローODEに適用されるブラックボックスODEソルバーのサンプル品質に対応しています。
| チェックポイントパス | fid | は | fid(ode) | NNL(ビット/薄暗い) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| チェックポイントパス | サンプル |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| リンク | 説明 |
|---|---|
| 事前に守られたチェックポイントをロードし、サンプリング、尤度計算、および制御可能な合成(Jax + Flax)で再生します | |
| 事前に守られたチェックポイントをロードし、サンプリング、尤度計算、および制御可能な合成(Pytorch)で再生します | |
| Jax + Flaxのスコアベースの生成モデルのチュートリアル | |
| Pytorchのスコアベースの生成モデルのチュートリアル |
config.training.n_jitted_stepsを介して設定できます。 CIFAR-10の場合、gpu/tpuに十分なメモリがある場合、 config.training.n_jitted_steps=5を使用することをお勧めします。それ以外の場合はconfig.training.n_jitted_steps=1を使用することをお勧めします。現在の実装では、 config.training.log_freq n_jitted_stepsによってロギングとチェックポイントが正常に機能するために配分可能である必要があります。LangevinCorrectorのsnr (信号対雑音比)パラメーターは、温度パラメーターのように幾分動作します。通常、 snrが大きくなるとサンプルがスムーズになりますが、 snrが小さくなると、より多様だが低品質のサンプルが得られます。 snrの典型的な値は0.05 - 0.2であり、スイートスポットを打つためにチューニングする必要があります。config.model.sigma_maxを選択することをお勧めします。 コードが研究に役立つと思う場合は、引用することを検討してください
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}この作品は、以前のいくつかの論文の上に構築されています。