該回購包含通過隨機微分方程的基於紙得分的生成模型的Pytorch實現
Yang Song,Jascha Sohl-Dickstein,Diederik P. Kingma,Abhishek Kumar,Stefano Ermon和Ben Poole
我們提出了一個統一的框架,該框架通過隨機微分方程(SDE)的鏡頭概括並改善了對基於得分的生成模型的先前工作。特別是,我們可以通過SDE描述的連續時間隨機過程將數據轉換為簡單的噪聲分佈。如果我們知道每個中間時間步驟的邊際分佈的得分,則可以將此SDE逆轉以生成樣本,這可以通過分數匹配來估計。基本思想在以下圖中捕獲:

我們的工作使人們可以更好地了解現有方法,新的採樣算法,確切的可能性計算,獨特的可識別編碼,潛在的代碼操縱以及帶來新的條件生成能力(包括但不限於課堂條件生成,內置和著色)。
總的來說,我們在CIFAR-10上的無條件產生以及高保真生成的1024px Celeba-HQ圖像(下面的樣本)的無條件生成中獲得了2.20的FID和9.89的成立分數。此外,我們在均勻去除的CIFAR-10圖像上獲得了2.99位/昏暗的可能性值。

除我們的論文中的NCSN ++和DDPM ++模型外,該代碼庫還重新實現了一個以前的基於得分的模型,包括通過估算數據分佈的梯度,從培訓基於訓練分數的改進技術中的NCSNV2來估算NCSNV2的NCSN ,用於基於訓練的分數模型,以及來自demo noto ddpm的ddpm ,以及來自demofusion diffusion diffusion probififusivusion probific probilistic模型。
它支持培訓新模型,評估現有模型的樣本質量和可能性。我們仔細設計了代碼是模塊化的,易於新的SDE,預測變量或更正器。
現在大多數型號都可以使用嗎?擴散器和通過分數管道可加入。
擴散器允許您僅使用幾行代碼測試Pytorch中基於SCONS的模型。
您可以按以下方式安裝擴散器:
pip install diffusers torch accelerate
然後僅使用幾行代碼嘗試模型:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )可以直接在集線器上找到更多型號。
請在此處找到JAX實施,該實現還支持具有預培訓的分類器的課堂條件生成,並在先發製人的情況下恢復評估過程。
通常,此Pytorch版本消耗的內存較少,但運行速度比JAX慢。這是培訓NCSN ++續的基準。具有VE SDE的模型。硬件是4x nvidia tesla V100 GPU(32GB)
| 框架 | 時間(第二步第二) | 總記憶使用(GB) |
|---|---|---|
| Pytorch | 0.56 | 20.6 |
jax( n_jitted_steps=1 ) | 0.30 | 29.7 |
jax( n_jitted_steps=5 ) | 0.20 | 74.8 |
運行以下內容,為我們的代碼安裝必要的Python軟件包子集
pip install -r requirements.txt我們提供CIFAR-10的統計文件。您可以下載cifar10_stats.npz並將其保存到assets/stats/ 。請查看有關如何計算新數據集此統計文件的#5。
訓練並通過main.py評估我們的模型。
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config是配置文件的路徑。我們的規定配置文件以configs/ 。它們是根據ml_collections格式化的,應該是相當不言自明的。
配置文件的命名約定:配置文件的路徑是以下維度的組合:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq之一。ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp之一。 workdir是存儲一個實驗的所有工件的路徑,例如檢查點,樣本和評估結果。
eval_folder是workdir中的子文件夾的名稱,該名稱存儲了評估過程的所有工件,例如用於預先避免的元檢查點,圖像樣本和定量結果的Numpy轉儲。
mode是“火車”或“評估”。當設置“火車”時,它會開始對新模型進行培訓,或者如果在workdir/checkpoints-meta 。設置為“評估”時,它可以任意組合以下
評估測試 /驗證數據集中的損失函數。
生成固定數量的樣品併計算其成立評分,FID或KID。在評估之前,必須已經在assets/stats中下載/計算並存儲統計文件。
計算培訓或測試數據集中的日誌樣品。
可以通過ml_collections軟件包的命令行支持來通過配置文件或更方便地配置這些功能。例如,要生成樣品並評估樣品質量,請提供--config.eval.enable_sampling標誌;要計算log-likelioness,請提供--config.eval.enable_bpd標誌,並指定--config.eval.dataset=train/test以指示是否要計算培訓或測試數據集中的可能性。
sde_lib.SDE抽像類並實現所有抽象方法。 discretize()方法是可選的,默認值是Euler-Maruyama離散化。現有的採樣方法和可能性計算將自動適用於此新的SDE。sampling.Predictor摘要類,實現update_fn摘要方法,然後在@register_predictor中註冊其名稱。新的預測變量可以直接用於sampling.get_pc_sampler用於預測器 - 校准採樣,以及所有其他可控生成方法中的所有其他controllable_generation.py生成方法。sampling.Corrector摘要類,實現update_fn抽象方法,然後在@register_corrector中註冊其名稱。新校正器可以直接用於sampling.get_pc_sampler ,以及所有其他可控的生成方法中的controllable_generation.py中。 所有檢查點均在此Google驅動器中提供。
說明:您可能會為某些型號找到兩個檢查點。第一個檢查點(數字較小)是我們在表3的表3中報告的FID分數(也對應於FID,是下表中的列)。第二個檢查點(具有較大數字)是我們在論文的表2中報告了黑盒ode採樣器的可能性值和FID,下表2(也是FID(ODE)和NNL(bits/dim)列)。前者對應於訓練過程中最小的FID(每50k迭代)。稍後是培訓期間的最後一個檢查站。
根據Google的政策,我們無法發布我們的原始Celeba和Celeba-HQ檢查點。也就是說,我已經在FFHQ 1024PX,FFHQ 256PX和Celeba-HQ 256PX上重新訓練了模型,並具有個人資源,它們的性能與我們的內部檢查站相似。
這是檢查點的詳細列表及其在論文中報告的結果。 FID(ode)對應於應用於概率流ode的黑框求解器的樣品質量。
| 檢查點路徑 | fid | 是 | FID(ode) | NNL(位/昏暗) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| 檢查點路徑 | 樣品 |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| 關聯 | 描述 |
|---|---|
| 加載我們驗證的檢查點,並使用採樣,可能性計算和可控合成(JAX + Flax)進行播放 | |
| 加載我們驗證的檢查點,並使用採樣,可能性計算和可控合成(Pytorch)進行播放 | |
| JAX + Flax中基於得分的生成模型的教程 | |
| Pytorch中基於分數的生成模型的教程 |
config.training.n_jitted_steps設置。對於CIFAR-10,我們建議使用config.training.n_jitted_steps=5當您的GPU/TPU具有足夠的內存時;否則,我們建議使用config.training.n_jitted_steps=1 。我們當前的實現需要config.training.log_freq由n_jitted_steps分開,以進行記錄和檢查點才能正常工作。LangevinCorrector的snr (信噪比)參數在某種程度上像溫度參數。較大的snr通常會產生更平滑的樣品,而較小的snr可提供更多樣化但質量較低的樣本。 snr的典型值為0.05 - 0.2 ,需要調整才能擊中最佳位置。config.model.sigma_max是訓練數據集中數據樣本之間的最大成對距離。 如果您發現該代碼對您的研究有用,請考慮引用
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}這項工作建立在一些以前的論文基礎上,這些論文也可能感興趣: