Ce repo contient une implémentation pytorch pour la modélisation générative basée sur le score papier à travers des équations différentielles stochastiques
par Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon et Ben Poole
Nous proposons un cadre unifié qui généralise et améliore les travaux antérieurs sur les modèles génératifs basés sur les scores à travers la lentille des équations différentielles stochastiques (SDE). En particulier, nous pouvons transformer les données en une distribution de bruit simple avec un processus stochastique à temps continu décrit par un SDE. Ce SDE peut être inversé pour la génération d'échantillons si nous connaissons le score des distributions marginales à chaque pas de temps intermédiaire, qui peut être estimée avec la correspondance des scores. L'idée de base est capturée dans la figure ci-dessous:

Notre travail permet une meilleure compréhension des approches existantes, de nouveaux algorithmes d'échantillonnage, du calcul exact de vraisemblance, du codage uniquement identifiable, de la manipulation du code latent et apporte de nouvelles capacités de génération conditionnelle (y compris, mais sans s'y limiter, la génération de la classe, l'instruction et la colorisation) à la famille de modèles génératifs basés sur les scores.
Tous combinés, nous avons réalisé un FID de 2,20 et un score de création de 9,89 pour la génération inconditionnelle sur CIFAR-10, ainsi qu'une génération à haute fidélité d'images 1024px Celeba-HQ (échantillons ci-dessous). De plus, nous avons obtenu une valeur de vraisemblance de 2,99 bits / DIM sur des images CIFAR-10 désactivées uniformément.

Outre les modèles NCSN ++ et DDPM ++ dans notre article, cette base de code réimplique également de nombreux modèles basés sur les scores précédents en un seul endroit, y compris NCSN de la modélisation générative en estimant les gradients de la distribution des données, NCSNV2 à partir de techniques améliorées pour les modèles génératifs de score de formation et DDPM à partir de modèles probabilistes de diffusion.
Il soutient la formation de nouveaux modèles, évaluant la qualité de l'échantillon et la probabilité des modèles existants. Nous avons soigneusement conçu le code pour être modulaire et facilement extensible aux nouveaux SDE, prédicteurs ou correcteurs.
La plupart des modèles sont maintenant également disponibles? Diffuseurs et accessibles via le pipeline ScoresDeve.
Les diffuseurs vous permettent de tester les modèles basés sur le score SDE dans Pytorch en quelques lignes de code.
Vous pouvez installer des diffuseurs comme suit:
pip install diffusers torch accelerate
Puis essayez les modèles avec seulement quelques lignes de code:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )Plus de modèles peuvent être trouvés directement sur le moyeu.
Veuillez trouver une implémentation JAX ici, qui prend en plus prend en charge la génération conditionnelle de classe avec un classificateur pré-formé, et reprendre un processus d'évaluation après la préemption.
En général, cette version Pytorch consomme moins de mémoire mais fonctionne plus lentement que Jax. Voici une référence sur la formation d'un NCSN ++ Cont. modèle avec VE SDE. Le matériel est 4x Nvidia Tesla V100 GPU (32 Go)
| Cadre | Temps (deuxième par étape) | Utilisation de la mémoire au total (GB) |
|---|---|---|
| Pytorch | 0,56 | 20.6 |
Jax ( n_jitted_steps=1 ) | 0,30 | 29.7 |
Jax ( n_jitted_steps=5 ) | 0.20 | 74.8 |
Exécutez ce qui suit pour installer un sous-ensemble de packages Python nécessaires pour notre code
pip install -r requirements.txt Nous fournissons le fichier de statistiques pour CIFAR-10. Vous pouvez télécharger cifar10_stats.npz et l'enregistrer sur assets/stats/ . Consultez # 5 sur la façon de calculer ce fichier de statistiques pour de nouveaux ensembles de données.
Former et évaluer nos modèles via main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config est le chemin d'accès au fichier de configuration. Nos fichiers de configuration prescrits sont fournis en configs/ . Ils sont formatés selon ml_collections et devraient être assez explicites.
CONVENTIONS DE NOMMANDE DES FILEURS DE CONFIG : Le chemin d'accès d'un fichier de configuration est une combinaison des dimensions suivantes:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq .ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp . workdir est le chemin qui stocke tous les artefacts d'une seule expérience, comme les points de contrôle, les échantillons et les résultats d'évaluation.
eval_folder est le nom d'un sous-dossier dans workdir qui stocke tous les artefacts du processus d'évaluation, comme les points de contrôle Meta pour la prévention de la préemption, les échantillons d'image et les décharges numpy de résultats quantitatifs.
mode est "Train" ou "EVAL". Lorsqu'il est réglé pour «s'entraîner», il démarre la formation d'un nouveau modèle ou reprend la formation d'un ancien modèle si ses points de méta-chère (pour reprendre la course après la préemption dans un environnement cloud) existent dans workdir/checkpoints-meta . Lorsqu'il est défini sur "EVAL", il peut faire une combinaison arbitraire de ce qui suit
Évaluez la fonction de perte sur l'ensemble de données de test / validation.
Générez un nombre fixe d'échantillons et calculez son score de création, FID ou Kid. Avant l'évaluation, les fichiers de statistiques doivent avoir déjà été téléchargés / calculés et stockés en assets/stats .
Calculez le journal logarithmique sur l'ensemble de données de formation ou de test.
Ces fonctionnalités peuvent être configurées via des fichiers de configuration, ou plus commodément, via la prise en charge de la ligne de commande du package ml_collections . Par exemple, pour générer des échantillons et évaluer la qualité de l'échantillon, fournissez le drapeau --config.eval.enable_sampling ; Pour calculer le journal de log, alimentez l'indicateur --config.eval.enable_bpd et spécifiez --config.eval.dataset=train/test pour indiquer s'il faut calculer les probabilités sur l'ensemble de données de formation ou de test.
sde_lib.SDE et implémentez toutes les méthodes abstraites. La méthode discretize() est facultative et la valeur par défaut est la discrétisation Euler-Maruyama. Les méthodes d'échantillonnage existantes et le calcul de vraisemblance fonctionneront automatiquement pour ce nouveau SDE.sampling.Predictor , implémentez la méthode Résumé update_fn et enregistrez son nom avec @register_predictor . Le nouveau prédicteur peut être directement utilisé dans sampling.get_pc_sampler pour l'échantillonnage de prédicteur-corrécteur, et toutes les autres méthodes de génération contrôlable dans controllable_generation.py .sampling.Corrector , implémentez la méthode Résumé update_fn et enregistrez son nom avec @register_corrector . Le nouveau correcteur peut être directement utilisé dans sampling.get_pc_sampler et toutes les autres méthodes de génération contrôlable dans controllable_generation.py . Tous les points de contrôle sont fournis dans ce Google Drive.
Instructions : Vous pouvez trouver deux points de contrôle pour certains modèles. Le premier point de contrôle (avec un nombre plus petit) est celui que nous avons signalé des scores FID dans le tableau 3 de notre article (correspondant également au FID et est des colonnes du tableau ci-dessous). Le deuxième point de contrôle (avec un nombre plus grand) est celui que nous avons signalé des valeurs de vraisemblance et des fidés d'échantillonneurs ODE Black-Box dans les colonnes du tableau 2 de notre article (FID (ODE) et NNL (bits / dim) dans le tableau ci-dessous). Le premier correspond au plus petit FID au cours de la formation (toutes les 50 000 itérations). Le dernier est le dernier point de contrôle pendant la formation.
Selon la politique de Google, nous ne pouvons pas publier nos points de contrôle CELEBA et CELEBA-HQ d'origine. Cela dit, j'ai recommandé des modèles sur FFHQ 1024PX, FFHQ 256PX et CELEBA-HQ 256PX avec des ressources personnelles, et ils ont atteint des performances similaires à nos points de contrôle internes.
Voici une liste détaillée des points de contrôle et leurs résultats rapportés dans le document. FID (ODE) correspond à la qualité de l'échantillon du solveur ODE à boîte noire appliquée à l'ODE du flux de probabilité.
| Chemin de point de contrôle | Fid | EST | FID (ODE) | NNL (bits / dim) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| Chemin de point de contrôle | Échantillons |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| Lien | Description |
|---|---|
| Chargez nos points de contrôle pré-entraînés et jouez avec l'échantillonnage, le calcul de vraisemblance et la synthèse contrôlable (Jax + Flax) | |
| Chargez nos points de contrôle pré-entraînés et jouez avec l'échantillonnage, le calcul de vraisemblance et la synthèse contrôlable (pytorch) | |
| Tutoriel de modèles génératifs basés sur les scores dans Jax + Flax | |
| Tutoriel des modèles génératifs basés sur les scores à Pytorch |
config.training.n_jitted_steps . Pour CIFAR-10, nous vous recommandons d'utiliser config.training.n_jitted_steps=5 Lorsque votre GPU / TPU a une mémoire suffisante; Sinon, nous vous recommandons d'utiliser config.training.n_jitted_steps=1 . Notre implémentation actuelle nécessite config.training.log_freq pour être dividable par n_jitted_steps pour la journalisation et le point de contrôle pour fonctionner normalement.snr (rapport signal / bruit) de LangevinCorrector se comporte un peu comme un paramètre de température. snr plus important se traduit généralement par des échantillons plus lisses, tandis que snr plus petit donne des échantillons de qualité plus diversifiés mais plus faibles. Les valeurs typiques du snr sont 0.05 - 0.2 , et elle nécessite un accord pour frapper le point idéal.config.model.sigma_max pour être la distance maximale par paire entre les échantillons de données dans l'ensemble de données de formation. Si vous trouvez le code utile pour vos recherches, veuillez envisager de citer
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}Ce travail est construit sur certains articles précédents qui pourraient également vous intéresser: