Este repositorio contiene una implementación de Pytorch para el modelado generativo basado en puntaje en papel a través de ecuaciones diferenciales estocásticas
Por Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon y Ben Poole
Proponemos un marco unificado que generaliza y mejora el trabajo previo en modelos generativos basados en puntaje a través de la lente de ecuaciones diferenciales estocásticas (SDE). En particular, podemos transformar los datos en una distribución de ruido simple con un proceso estocástico de tiempo continuo descrito por un SDE. Este SDE se puede revertir para la generación de muestras si conocemos la puntuación de las distribuciones marginales en cada paso de tiempo intermedio, que se puede estimar con la coincidencia de puntaje. La idea básica se captura en la figura a continuación:

Nuestro trabajo permite una mejor comprensión de los enfoques existentes, los nuevos algoritmos de muestreo, el cálculo de probabilidad exacta, la codificación de identificación única, la manipulación de código latente y trae nuevas habilidades de generación condicional (incluidas, entre otros, la generación condicional de clase, la entrada y la colorización) a la familia de los modelos generativos basados en puntaje.
Todo combinado, logramos un FID de 2.20 y una puntuación de inicio de 9.89 para la generación incondicional en CIFAR-10, así como la generación de alta fidelidad de imágenes de 1024px Celeba-HQ (muestras a continuación). Además, obtuvimos un valor de probabilidad de 2.99 bits/tenues en imágenes CIFAR-10 uniformemente descantadas.

Además de los modelos NCSN ++ y DDPM ++ en nuestro documento, esta base de código también vuelve a implementar muchos modelos de puntaje anteriores en un lugar, incluido NCSN del modelado generativo al estimar los gradientes de la distribución de datos, NCSNV2 a partir de técnicas mejoradas para los modelos generativos basados en puntajes de entrenamiento y DDPM a partir de modelos probabilísticos de difusión de descenso.
Admite capacitar nuevos modelos, evaluando la calidad de la muestra y las probabilidades de los modelos existentes. Diseñamos cuidadosamente el código para que sea modular y fácilmente extensible para nuevos SDE, predictores o correctores.
¿La mayoría de los modelos ahora también están disponibles en? Difusores y accesibles a través de la tubería de puntuación.
Los difusores le permiten probar modelos basados en SDE en Pytorch en solo un par de líneas de código.
Puede instalar difusores de la siguiente manera:
pip install diffusers torch accelerate
Y luego pruebe los modelos con solo un par de líneas de código:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )Se pueden encontrar más modelos directamente en el centro.
Encuentre una implementación de JAX aquí, que además admite la generación condicional de clase con un clasificador previamente capacitado y reanude un proceso de evaluación después de la preferencia.
En general, esta versión de Pytorch consume menos memoria pero funciona más lento que Jax. Aquí hay un punto de referencia sobre el entrenamiento de un NCSN ++ cont. Modelo con VE SDE. El hardware es 4x NVIDIA TESLA V100 GPU (32 GB)
| Estructura | Tiempo (segundo por paso) | Uso de la memoria en total (GB) |
|---|---|---|
| Pytorch | 0.56 | 20.6 |
Jax ( n_jitted_steps=1 ) | 0.30 | 29.7 |
Jax ( n_jitted_steps=5 ) | 0.20 | 74.8 |
Ejecute lo siguiente para instalar un subconjunto de los paquetes de Python necesarios para nuestro código
pip install -r requirements.txt Proporcionamos el archivo de estadísticas para CIFAR-10. Puede descargar cifar10_stats.npz y guardarlo en assets/stats/ . Consulte el #5 sobre cómo calcular este archivo de estadísticas para nuevos conjuntos de datos.
Entrena y evalúa nuestros modelos a través de main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config es la ruta al archivo de configuración. Nuestros archivos de configuración prescritos se proporcionan en configs/ . Están formateados de acuerdo con ml_collections y deberían explicarse bastante por sí mismos.
Convenciones de nombres de archivos de configuración : la ruta de un archivo de configuración es una combinación de las siguientes dimensiones:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq .ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp . workdir es el camino que almacena todos los artefactos de un experimento, como puntos de control, muestras y resultados de evaluación.
eval_folder es el nombre de una subcarpeta en workdir que almacena todos los artefactos del proceso de evaluación, como los puntos de control meta para la prevención previa a la emisión, las muestras de imágenes y los vertederos de resultados cuantitativos.
mode es "tren" o "eval". Cuando se establece en "entrenar", comienza la capacitación de un nuevo modelo, o reanuda la capacitación de un modelo antiguo si sus meta-checkpoints (para reanudar la ejecución después de la prevención en un entorno de la nube) existen en workdir/checkpoints-meta . Cuando se establece en "eval", puede hacer una combinación arbitraria de lo siguiente
Evalúe la función de pérdida en el conjunto de datos de prueba / validación.
Genere un número fijo de muestras y calcule su puntaje de inicio, FID o Kid. Antes de la evaluación, los archivos de estadísticas ya deben haberse descargado/calculado y almacenado en assets/stats .
Calcule la probabilidad de registro en el conjunto de datos de capacitación o prueba.
Estas funcionalidades se pueden configurar a través de archivos de configuración, o más convenientemente, a través del soporte de línea de comandos del paquete ml_collections . Por ejemplo, para generar muestras y evaluar la calidad de la muestra, suministre el indicador --config.eval.enable_sampling ; Para calcular las versiones de registro log, suministre el indicador --config.eval.enable_bpd y especifique --config.eval.dataset=train/test para indicar si calcula las probabilidades en el conjunto de datos de capacitación o prueba.
sde_lib.SDE e implementan todos los métodos abstractos. El método discretize() es opcional y el valor predeterminado es la discretización de Euler-Maruyama. Los métodos de muestreo existentes y el cálculo de probabilidad funcionarán automáticamente para este nuevo SDE.sampling.Predictor Predictor, implementan el método abstracto update_fn y registre su nombre con @register_predictor . El nuevo predictor se puede utilizar directamente en sampling.get_pc_sampler para el muestreo predictor-corrector, y todos los demás métodos de generación controlable en controllable_generation.py .update_fn @register_corrector sampling.Corrector . El nuevo corrector se puede usar directamente en sampling.get_pc_sampler , y todos los demás métodos de generación controlable en controllable_generation.py . Todos los puntos de control se proporcionan en esta unidad de Google.
Instrucciones : puede encontrar dos puntos de control para algunos modelos. El primer punto de control (con un número más pequeño) es el que informamos los puntajes de FID en la Tabla 3 de nuestro documento (también corresponde al FID y es columnas en la tabla a continuación). El segundo punto de control (con un número más grande) es el que informamos valores de probabilidad y FID de muestras de oda de caja negra en las columnas de la Tabla 2 de nuestro artículo (también FID (ODE) y NNL (bits/dim) en la tabla a continuación). El primero corresponde al FID más pequeño durante el curso del entrenamiento (cada 50k iteraciones). El último es el último punto de control durante el entrenamiento.
Según la política de Google, no podemos lanzar nuestros puntos de control Celeba y Celeba-HQ originales. Dicho esto, he vuelto a capacitar modelos en FFHQ 1024px, FFHQ 256px y Celeba-HQ 256px con recursos personales, y lograron un rendimiento similar a nuestros puntos de control internos.
Aquí hay una lista detallada de puntos de control y sus resultados reportados en el documento. FID (ODE) corresponde a la calidad de muestra del solucionador de oda de caja negra aplicada a la ODE de flujo de probabilidad.
| Ruta de punto de control | DEFENSOR | ES | FID (ODE) | NNL (bits/dim) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| Ruta de punto de control | Muestras |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| Enlace | Descripción |
|---|---|
| Cargue nuestros puntos de control previos al detenido y juegue con muestreo, cálculo de probabilidad y síntesis controlable (Jax + Flax) | |
| Cargue nuestros puntos de control previos a la aparición y juegue con muestreo, computación de probabilidad y síntesis controlable (Pytorch) | |
| Tutorial de modelos generativos basados en puntaje en Jax + Flax | |
| Tutorial de modelos generativos basados en puntaje en Pytorch |
config.training.n_jitted_steps . Para CIFAR-10, recomendamos usar config.training.n_jitted_steps=5 cuando su GPU/TPU tiene suficiente memoria; De lo contrario, recomendamos usar config.training.n_jitted_steps=1 . Nuestra implementación actual requiere que config.training.log_freq sea dividible por n_jitted_steps para registrar y verificar que el punto de control funcione normalmente.snr (relación señal / ruido) del LangevinCorrector se comporta de alguna manera como un parámetro de temperatura. snr más grande generalmente da como resultado muestras más suaves, mientras que snr más pequeña proporciona muestras más diversas pero de menor calidad. Los valores típicos de snr son 0.05 - 0.2 , y requiere ajuste para atacar el punto dulce.config.model.sigma_max para que sea la distancia máxima por pares entre muestras de datos en el conjunto de datos de entrenamiento. Si encuentra el código útil para su investigación, considere citar
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}Este trabajo se basa en algunos documentos anteriores que también podrían interesarle: