Dieses Repo enthält eine Pytorch-Implementierung für die papierbewertungsbasierte generative Modellierung durch stochastische Differentialgleichungen
von Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon und Ben Poole
Wir schlagen einen einheitlichen Rahmen vor, der frühere Arbeiten an bewertungsbasierten Generativmodellen durch die Linse stochastischer Differentialgleichungen (SDEs) verallgemeinert und verbessert. Insbesondere können wir Daten in eine einfache Rauschverteilung mit einem stochastischen Prozess mit kontinuierlicher Zeit umwandeln, der von einem SDE beschrieben wird. Diese SDE kann für die Stichprobenerzeugung umgekehrt werden, wenn wir die Punktzahl der Grenzverteilungen bei jedem Zwischenzeitschritt kennen, der mit einer Score -Matching geschätzt werden kann. Die Grundidee wird in der folgenden Abbildung erfasst:

Unsere Arbeit ermöglicht ein besseres Verständnis bestehender Ansätze, neuen Stichprobenalgorithmen, exakte Wahrscheinlichkeitsberechnung, einzigartig identifizierbarer Codierung, latenter Code-Manipulation und bringt neue Fähigkeiten zur bedingten Erzeugung (einschließlich, aber nicht beschränkt auf die Erzeugung der Klassenkondition, Inpackung und Farbkolorisierung).
Alle kombiniert haben wir eine FID von 2,20 und einen Inception-Score von 9,89 für die bedingungslose Generation auf CIFAR-10 sowie eine hohe Erzeugung von 1024px Celeba-HQ-Bildern (Beispiele unten) erreicht. Zusätzlich haben wir einen Wahrscheinlichkeitswert von 2,99 Bit/Dim auf gleichmäßig dequantisierte CIFAR-10-Bildern erhalten.

Abgesehen von den NCSN ++- und DDPM ++- Modellen in unserem Artikel wird diese Codebasis auch viele frühere Score-basierte Modelle an einem Ort erneut implementiert, einschließlich NCSN aus generativen Modellierung, indem sie Gradienten der Datenverteilung, NCSNV2 aus verbesserten Techniken für Trainingsbewertungsmodelle, und DDPM aus demosing diffusions-probabilitätsmodelle schätzen.
Es unterstützt Schulungen neue Modelle und bewertet die Stichprobenqualität und die Wahrscheinlichkeit vorhandener Modelle. Wir haben den Code sorgfältig so konzipiert, dass sie für neue SDEs, Prädiktoren oder Korrektoren modular und leicht erweiterbar sind.
Die meisten Modelle sind jetzt auch erhältlich? Diffusoren und über die Scoresdeve -Pipeline assozibar.
Mit Diffusoren können Sie SDE -basierte Modelle in Pytorch in nur wenigen Codezeilen testen.
Sie können Diffusoren wie folgt installieren:
pip install diffusers torch accelerate
Probieren Sie die Modelle mit nur ein paar Codezeilen aus:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )Weitere Modelle finden Sie direkt am Hub.
Hier finden Sie eine JAX-Implementierung, die zusätzlich die klassenkonditionelle Erzeugung mit einem vorgeborenen Klassifikator unterstützt und nach der Vormeldung einen Bewertungsprozess wieder aufnimmt.
Im Allgemeinen verbraucht diese Pytorch -Version weniger Speicher, läuft jedoch langsamer als Jax. Hier ist ein Maßstab zum Training eines NCSN ++ Cont. Modell mit VE SDE. Hardware ist 4x Nvidia Tesla V100 GPUs (32 GB)
| Rahmen | Zeit (zweiter pro Schritt) | Speicherverbrauch insgesamt (GB) |
|---|---|---|
| Pytorch | 0,56 | 20.6 |
JAX ( n_jitted_steps=1 ) | 0,30 | 29.7 |
JAX ( n_jitted_steps=5 ) | 0,20 | 74,8 |
Führen Sie Folgendes aus, um eine Teilmenge der erforderlichen Python -Pakete für unseren Code zu installieren
pip install -r requirements.txt Wir stellen die Statistikdatei für CIFAR-10 an. Sie können cifar10_stats.npz herunterladen und auf assets/stats/ speichern. Schauen Sie sich #5, wie Sie diese Statistikdatei für neue Datensätze berechnen.
Trainieren und bewerten Sie unsere Modelle über main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config ist der Pfad zur Konfigurationsdatei. Unsere vorgeschriebenen Konfigurationsdateien finden Sie in configs/ . Sie sind gemäß ml_collections formatiert und sollten ziemlich selbsterklärend sein.
Benennung von Konventionen von Konfigurationsdateien : Der Pfad einer Konfigurationsdatei ist eine Kombination der folgenden Abmessungen:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq .ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp . workdir ist der Weg, der alle Artefakte eines Experiments wie Kontrollpunkte, Stichproben und Bewertungsergebnisse speichert.
eval_folder ist der Name eines Unterordners in workdir , der alle Artefakte des Bewertungsprozesses speichert, wie Meta-Checkpoints für die Prävention vor der Emption, Bildproben und Numpy-Dumps quantitativer Ergebnisse.
mode ist entweder "Zug" oder "eval". Wenn es auf "Train" eingestellt ist, startet es das Training eines neuen Modells oder fördert das Training eines alten Modells, wenn es in workdir/checkpoints-meta wieder aufgenommen wird (um nach der Voraussetzung in einer Cloud-Umgebung wieder zu laufen). Wenn Sie auf "eval" eingestellt sind, kann es eine willkürliche Kombination aus den folgenden Durchführungen durchführen
Bewerten Sie die Verlustfunktion im Test- / Validierungsdatensatz.
Generieren Sie eine feste Anzahl von Proben und berechnen Sie seinen Aufnahmestunden, FID oder Kind. Vor der Bewertung müssen Statistikdateien bereits heruntergeladen/berechnet und in assets/stats gespeichert worden sein.
Berechnen Sie die Protokoll-Likelihood im Trainings- oder Testdatensatz.
Diese Funktionen können über Konfigurationsdateien oder bequemer über die Befehlszeilenunterstützung des ml_collections Pakets konfiguriert werden. Um beispielsweise Proben zu generieren und die Stichprobenqualität zu bewerten, liefern Sie die Flagge --config.eval.enable_sampling ; Um die Protokoll-Likelihoods zu berechnen, liefern Sie das Flag --config.eval.enable_bpd -Flag und geben Sie --config.eval.dataset=train/test um anzugeben, ob die Wahrscheinlichkeiten für das Training oder das Testen berechnet werden sollen.
sde_lib.SDE und implementieren alle abstrakten Methoden. Die methode discretize() ist optional und die Standardeinstellung ist die Diskretisierung von Euler-Maruyama. Bestehende Stichprobenmethoden und Wahrscheinlichkeitsberechnung funktionieren automatisch für diese neue SDE.sampling.Predictor Abstract -Klasse, implementieren Sie die abstrakte Methode update_fn und registrieren Sie ihren Namen mit @register_predictor . Der neue Prädiktor kann direkt in sampling.get_pc_sampler für die Prädiktor-Corrector-Sampling und alle anderen steuerbaren Generierungsmethoden in controllable_generation.py verwendet werden.sampling.Corrector Abstract -Klasse, implementieren Sie die Abstract -Methode update_fn und registrieren Sie ihren Namen mit @register_corrector . Der neue Korrektor kann direkt in sampling.get_pc_sampler und allen anderen steuerbaren Erzeugungsmethoden in controllable_generation.py verwendet werden. Alle Kontrollpunkte finden Sie in diesem Google Drive.
Anweisungen : Möglicherweise finden Sie zwei Kontrollpunkte für einige Modelle. Der erste Checkpoint (mit einer kleineren Zahl) ist derjenige, den wir in der Tabelle 3 (auch der FID entsprechen und Spalten in der folgenden Tabelle iS -Spalten) gemeldet haben. Der zweite Checkpoint (mit einer größeren Zahl) ist derjenige, den wir in den Spalten Tabelle 2 (auch FID (ODE) und NNL (Bits/Dim) in der folgenden Tabelle Wahrscheinlichkeitswerte und Fids von Schwarzbox-ODE-Sampler in der Tabelle 2 (auch FID (ODE) und NNL (Bits/Dim) gemeldet haben. Ersteres entspricht der kleinsten FID im Verlauf des Trainings (alle 50.000 Iterationen). Der spätere ist der letzte Kontrollpunkt während des Trainings.
Gemäß den Richtlinien von Google können wir unsere ursprünglichen Celeba- und Celeba-HQ-Kontrollpunkte nicht veröffentlichen. Trotzdem habe ich Modelle auf FFHQ 1024PX, FFHQ 256PX und Celeba-HQ 256PX mit persönlichen Ressourcen neu ausgestattet und sie haben eine ähnliche Leistung wie unsere internen Kontrollpunkte erzielt.
Hier finden Sie eine detaillierte Liste der Kontrollpunkte und deren Ergebnisse, die im Papier angegeben sind. FID (ODE) entspricht der Stichprobenqualität von Black-Box-ODE-Solver, die auf die Wahrscheinlichkeitsfluss-ODE angewendet wird.
| Checkpoint -Pfad | Fid | IST | FID (ODE) | NNL (Bits/Dim) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - - | - - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - - | - - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - - | - - |
vp/cifar10_ddpm/ | 3.24 | - - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - - | - - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - - | - - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - - | - - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| Checkpoint -Pfad | Proben |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| Link | Beschreibung |
|---|---|
| Laden Sie unsere vorbereiteten Checkpoints und spielen Sie mit Stichproben, Wahrscheinlichkeitsberechnung und kontrollierbarer Synthese (JAX + Flachs) | |
| Laden Sie unsere vorbereiteten Kontrollpunkte und spielen Sie mit Stichproben, Wahrscheinlichkeitsberechnung und steuerbarer Synthese (Pytorch) | |
| Tutorial für Score-basierte generative Modelle in Jax + Flachs | |
| Tutorial für Score-basierte generative Modelle in Pytorch |
config.training.n_jitted_steps festgelegt werden. Für CIFAR-10 empfehlen wir die Verwendung config.training.n_jitted_steps=5 Wenn Ihre GPU/TPU über ausreichendem Speicher verfügt. Andernfalls empfehlen wir die Verwendung von config.training.n_jitted_steps=1 . Unsere aktuelle Implementierung erfordert, dass config.training.log_freq durch n_jitted_steps zum Abmelden und Überprüfen der normalen Arbeiten teilnahmeberechtigt sein kann.snr (Signal-Rausch-Verhältnis) von LangevinCorrector verhält sich etwas wie ein Temperaturparameter. Größerer snr führt typischerweise zu glatteren Proben, während kleinerer snr vielfältigere, aber niedrigere Proben von niedrigerer Qualität verleiht. Die typischen Werte von snr beträgt 0.05 - 0.2 und es erfordert eine Stimmung, um den Sweet Spot zu treffen.config.model.sigma_max als maximale paarweise Abstand zwischen Datenproben im Trainingsdatensatz zu wählen. Wenn Sie den Code für Ihre Forschung nützlich finden, sollten Sie sich angeben
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}Diese Arbeit basiert auf einigen früheren Papieren, die Sie auch interessieren könnten: