Analizar y mejorar la dinámica de entrenamiento de los modelos de difusión
Tero Karras, Miika Aittala, Jaakko Lehtinen, Janne Hellsten, Timo Aila, Samuli Laine
https://arxiv.org/abs/2312.02696
Contribuciones clave:
conda env create -f environment.yml -n edmconda activate edmPara reproducir los principales resultados de nuestro artículo, simplemente ejecute:
python example.pyEste es un script independiente mínimo que carga el mejor modelo previamente capacitado para cada conjunto de datos y genera una cuadrícula aleatoria de 8x8 de imágenes utilizando la configuración de muestras óptima. Resultados esperados:
| Conjunto de datos | Tiempo de ejecución | Imagen de referencia |
|---|---|---|
| Cifar-10 | ~ 6 segundos | cifar10-32x32.png |
| Ffhq | ~ 28 segundos | ffhq-64x64.png |
| AFHQV2 | ~ 28 segundos | afhqv2-64x64.png |
| Imagenet | ~ 5 min | imagenet-64x64.png |
La forma más fácil de explorar diferentes estrategias de muestreo es modificar example.py directamente. También puede incorporar los modelos previamente capacitados y/o nuestra muestra EDM propuesta en su propio código simplemente copiando los bits relevantes. Tenga en cuenta que las definiciones de clase para los modelos previamente capacitados se almacenan dentro de los encurtidos y se cargan automáticamente durante la falta de torch_utils.persistence . Para usar los modelos en scripts de pitón externos, solo asegúrese de que torch_utils y dnnlib sean accesibles a través de PYTHONPATH .
Docker : puede ejecutar el script de ejemplo usando Docker de la siguiente manera:
# Build the edm:latest image
docker build --tag edm:latest .
# Run the generate.py script using Docker:
docker run --gpus all -it --rm --user $( id -u ) : $( id -g )
-v ` pwd ` :/scratch --workdir /scratch -e HOME=/scratch
edm:latest
python example.py Nota: La imagen Docker requiere la versión del controlador NVIDIA r520 o posterior.
La invocación de docker run puede parecer desalentador, así que desempaquemos su contenido aquí:
--gpus all -it --rm --user $(id -u):$(id -g) : con todas las GPU habilitadas, ejecute una sesión interactiva con UID/GID del usuario actual para evitar que Docker escriba archivos como root.-v `pwd`:/scratch --workdir /scratch : monte la corriente de corriente (por ejemplo, la parte superior de este repositorio de git en su máquina host) a /scratch en el contenedor y úselo como el directorio de trabajo actual.-e HOME=/scratch : especifique dónde almacenar en caché los archivos temporales. Nota: Si desea más control de grano fino, en su lugar puede configurar DNNLIB_CACHE_DIR (para la caché de descarga de modelo previamente capacitado). Desea que estos directivos de caché residan en volúmenes persistentes para que su contenido se retenga en múltiples invocaciones docker run . Proporcionamos modelos previamente capacitados para nuestra configuración de capacitación propuesta (config f), así como la configuración de línea de base (config a):
Para generar un lote de imágenes utilizando un modelo y una muestra determinada, ejecute:
# Generate 64 images and save them as out/*.png
python generate.py --outdir=out --seeds=0-63 --batch=64
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl Generar una gran cantidad de imágenes puede llevar mucho tiempo; La carga de trabajo se puede distribuir a través de múltiples GPU iniciando el comando anterior usando torchrun :
# Generate 1024 images using 2 GPUs
torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl La configuración de la muestra se puede controlar a través de opciones de línea de comandos; Consulte python generate.py --help para obtener más información. Para obtener los mejores resultados, recomendamos usar la siguiente configuración para cada conjunto de datos:
# For CIFAR-10 at 32x32, use deterministic sampling with 18 steps (NFE = 35)
python generate.py --outdir=out --steps=18
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
# For FFHQ and AFHQv2 at 64x64, use deterministic sampling with 40 steps (NFE = 79)
python generate.py --outdir=out --steps=40
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl
# For ImageNet at 64x64, use stochastic sampling with 256 steps (NFE = 511)
python generate.py --outdir=out --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl Además de nuestra muestra EDM propuesta, generate.py también se puede usar para reproducir las ablaciones de muestras de la Sección 3 de nuestro documento. Por ejemplo:
# Figure 2a, "Our reimplementation"
python generate.py --outdir=out --steps=512 --solver=euler --disc=vp --schedule=vp --scaling=vp
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
# Figure 2a, "+ Heun & our {t_i}"
python generate.py --outdir=out --steps=128 --solver=heun --disc=edm --schedule=vp --scaling=vp
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
# Figure 2a, "+ Our sigma(t) & s(t)"
python generate.py --outdir=out --steps=18 --solver=heun --disc=edm --schedule=linear --scaling=none
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl Para calcular la distancia de inicio de Fréchet (FID) para un modelo y muestreador determinados, primero genere 50,000 imágenes aleatorias y luego compárelas con las estadísticas de referencia del conjunto de datos usando fid.py :
# Generate 50000 images and save them as fid-tmp/*/*.png
torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs
--network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
# Calculate FID
torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp
--ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz Ambos comandos anteriores pueden ser paralelos en múltiples GPU ajustando --nproc_per_node . El segundo comando generalmente toma 1-3 minutos en la práctica, pero el primero a veces puede tomar varias horas, dependiendo de la configuración. Consulte python fid.py --help para la lista completa de opciones.
Tenga en cuenta que el valor numérico de FIF varía en diferentes semillas aleatorias y es altamente sensible al número de imágenes. Por defecto, fid.py siempre usará 50,000 imágenes generadas; Proporcionar menos imágenes dará como resultado un error, mientras que proporcionar más utilizará un subconjunto aleatorio. Para reducir el efecto de --seeds=50000-99999 variación aleatoria, recomendamos repetir el cálculo varias veces con diferentes semillas, --seeds=0-49999 . --seeds=100000-149999 . En nuestro artículo, calculamos cada FID tres veces e informamos el mínimo.
También tenga en cuenta que es importante comparar las imágenes generadas con el mismo conjunto de datos con el que el modelo fue entrenado originalmente. Para facilitar la evaluación, proporcionamos las estadísticas de referencia exactas que corresponden a nuestros modelos previamente capacitados:
Para ImageNet, proporcionamos dos conjuntos de estadísticas de referencia para habilitar la comparación de manzanas a manzanas: imagenet-64x64.npz debe usarse al evaluar el modelo EDM ( edm-imagenet-64x64-cond-adm.pkl ), por lo cual imagenet-64x64-baseline.npz se debe usar cuando se evalúa el modelo de base (((Basea (((base (((basura (((base ((base ((base ((base ((basura ((basura (((base ((línea baseline-imagenet-64x64-cond-adm.pkl ); Este último fue entrenado originalmente por Dhariwal y Nichol utilizando datos de entrenamiento ligeramente diferentes.
Puede calcular las estadísticas de referencia para sus propios conjuntos de datos de la siguiente manera:
python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz Los conjuntos de datos se almacenan en el mismo formato que en Stylegan: Archivos ZIP sin comprimir que contienen archivos PNG sin comprimir y un dataset.json de archivos de metadatos. Json para etiquetas. Se pueden crear conjuntos de datos personalizados a partir de una carpeta que contiene imágenes; Consulte python dataset_tool.py --help para obtener más información.
CIFAR-10: Descargue la versión CIFAR-10 Python y conviértase en el archivo ZIP:
python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz
--dest=datasets/cifar10-32x32.zip
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npzFFHQ: Descargue el conjunto de datos Flickr-Faces-HQ como imágenes 1024x1024 y conviértase en el archivo ZIP a la resolución 64x64:
python dataset_tool.py --source=downloads/ffhq/images1024x1024
--dest=datasets/ffhq-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz AFHQV2: Descargue el conjunto de datos de HQ de Animal Faces-HQ actualizado ( afhq-v2-dataset ) y conviértase en el archivo ZIP a la resolución 64x64:
python dataset_tool.py --source=downloads/afhqv2
--dest=datasets/afhqv2-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npzImageNet: descargue el desafío de localización de objetos de ImageNet y convierta en el archivo ZIP en la resolución 64x64:
python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train
--dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz Puede entrenar nuevos modelos con train.py . Por ejemplo:
# Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs
--data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp El ejemplo anterior utiliza el tamaño de lote predeterminado de 512 imágenes (controladas por --batch ) que se divide uniformemente entre 8 GPU (controladas por --nproc_per_node ) para producir 64 imágenes por GPU. Entrenar modelos grandes puede quedarse sin memoria GPU; La mejor manera de evitar esto es limitar el tamaño de lote de por GPU, por ejemplo, --batch-gpu=32 . Esto emplea la acumulación de gradiente para producir los mismos resultados que usar lotes por GPU completos. Consulte python train.py --help para la lista completa de opciones.
Los resultados de cada ejecución de capacitación se guardan en un directorio recién creado, por ejemplo training-runs/00000-cifar10-cond-ddpmpp-edm-gpus8-batch64-fp32 . El bucle de entrenamiento exporta instantáneas de red ( network-snapshot-*.pkl ) y estados de entrenamiento ( training-state-*.pt ) a intervalos regulares (controlados por --snap y --dump ). Las instantáneas de la red se pueden usar para generar imágenes con generate.py , y los estados de entrenamiento se pueden usar para reanudar la capacitación más adelante ( --resume ). Otra información útil se registra en log.txt y stats.jsonl . Para monitorear la convergencia de entrenamiento, recomendamos observar la pérdida de entrenamiento ( "Loss/loss" en stats.jsonl ), así como evaluar periódicamente FID para network-snapshot-*.pkl usando generate.py y fid.py
La siguiente tabla enumera las configuraciones de entrenamiento exactas que utilizamos para obtener nuestros modelos previamente capacitados:
| Modelo | GPU | Tiempo | Opción |
|---|---|---|---|
| CIFAR10‑32X32 -COND -VP | 8xv100 | ~ 2 días | --cond=1 --arch=ddpmpp |
| CIFAR10‑32X32 -COND -VE | 8xv100 | ~ 2 días | --cond=1 --arch=ncsnpp |
| CIFAR10‑32X32 - ACOND - VP | 8xv100 | ~ 2 días | --cond=0 --arch=ddpmpp |
| cifar10‑32x32 -abond -ve | 8xv100 | ~ 2 días | --cond=0 --arch=ncsnpp |
| FFHQ -64X64 - ACOND - VP | 8xv100 | ~ 4 días | --cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 |
| ffhq -64x64 -ovond -ve | 8xv100 | ~ 4 días | --cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 |
| AFHQV2‑64X64 - ANCOND - VP | 8xv100 | ~ 4 días | --cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 |
| AFHQV2‑64X64 - ANCOND -VE | 8xv100 | ~ 4 días | --cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 |
| Imagenet -64x64 -acond -adm | 32xa100 | ~ 13 días | --cond=1 --arch=adm --duration=2500 --batch=4096 --lr=1e-4 --ema=50 --dropout=0.10 --augment=0 --fp16=1 --ls=100 --tick=200 |
Para ImageNet-64, ejecutamos el entrenamiento en cuatro nodos NVIDIA DGX A100, cada uno con 8 amperios GPU con 80 GB de memoria. Para reducir los requisitos de memoria de GPU, recomendamos capacitar el modelo con más GPU o limitar el tamaño de lote de por GPU con --batch-gpu . Para configurar la capacitación en múltiples nodos, consulte la documentación de Torchrun.
Copyright © 2022, Nvidia Corporation & Affiliates. Reservados todos los derechos.
Todo el material, incluido el código fuente y los modelos previamente capacitados, tiene licencia bajo la atribución de Creative Commons no comercial-sharealike 4.0 International.
baseline-cifar10-32x32-uncond-vp.pkl y baseline-cifar10-32x32-uncond-ve.pkl se derivan de los modelos previamente capacitados por Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon y Ben Poole. Los modelos se compartieron originalmente bajo la licencia Apache 2.0.
baseline-imagenet-64x64-cond-adm.pkl se deriva del modelo previamente entrenado por Prafulla Dhariwal y Alex Nichol. El modelo se compartió originalmente bajo la licencia MIT.
imagenet-64x64-baseline.npz se deriva de las estadísticas de referencia precomputadas por Prafulla Dhariwal y Alex Nichol. Las estadísticas se compartieron originalmente bajo la licencia MIT.
@inproceedings{Karras2022edm,
author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
title = {Elucidating the Design Space of Diffusion-Based Generative Models},
booktitle = {Proc. NeurIPS},
year = {2022}
}
Esta es una implementación de referencia de investigación y se trata como una caída de código única. Como tal, no aceptamos contribuciones de código externo en forma de solicitudes de extracción.
Agradecemos a Jaakko Lehtinen, Ming-yu Liu, Tuomas Kynkäänniemi, Axel Sauer, Aash Vahdat y Janne Hellsten por sus discusiones y comentarios, y Tero Kuosmanen, Samuel Klenberg y Janne Hellsten por mantener nuestra infraestructura computada.