
Este proyecto está parcialmente compatible con Google TPU Research Cloud. Me gustaría agradecer al equipo de Google Cloud TPU por proporcionarme los recursos para capacitar a los modelos condicionales de texto más grandes en la configuración distribuida de múltiples host.
En los últimos años, la difusión y los modelos de múltiples pasos basados en puntajes han revolucionado el dominio generativo de IA. Sin embargo, la última investigación en este campo se ha vuelto muy intensiva en matemáticas, lo que hace que sea difícil comprender cómo funcionan los modelos de difusión de vanguardia y generan imágenes tan impresionantes. Replicar esta investigación en código puede ser desalentadora.
FlaxDiff es una biblioteca de herramientas (programadores, muestreadores, modelos, etc.) diseñada e implementada de una manera fácil de entender. La atención se centra en la comprensión y la legibilidad sobre el rendimiento. Comencé este proyecto como un pasatiempo para familiarizarme con Flax y Jax y aprender sobre la difusión y las últimas investigaciones en IA generativa.
Inicialmente comencé este proyecto en Keras, familiarizado con Tensorflow 2.0, pero hacía la transición al lino, impulsado por Jax, por su rendimiento y facilidad de uso. También se proporcionan los viejos cuadernos y modelos, incluidos mis primeros modelos de lino.
El cuaderno Diffusion_flax_linen.ipynb es mi espacio de trabajo principal para experimentos. Se cargan varios puntos de control en la carpeta pretrained junto con una copia del cuaderno de trabajo asociado con cada punto de control. Es posible que deba copiar el cuaderno a la raíz de trabajo para que funcione correctamente.
En la carpeta de example notebooks , encontrará cuadernos completos para varias técnicas de difusión, escritas completamente desde cero y son independientes de la biblioteca FlaxDiff. Cada cuaderno incluye explicaciones detalladas de las matemáticas y conceptos subyacentes, haciéndolos recursos invaluables para aprender y comprender los modelos de difusión.
Difusión explicada (enlace NBViewer) (enlace local)
EDM (dilucidar el espacio de diseño de los modelos generativos basados en difusión)
Estos cuadernos tienen como objetivo proporcionar una guía muy fácil de entender y paso a paso para los diversos modelos y técnicas de difusión. Están diseñados para ser amigables para principiantes y, por lo tanto, aunque no se adhieran a las formulaciones e implementaciones exactas de los documentos originales para que sean más comprensibles y generalizables, he hecho todo lo posible para mantenerlos lo más precisos posible. Si encuentra algún error o tiene alguna sugerencia, no dude en abrir un problema o una solicitud de extracción.
Script de entrenamiento paralelo de datos multi-host en Jax
Utilidades de TPU para facilitar la vida
Trabajé como investigador de aprendizaje automático en Hyperverge de 2019-2021, centrándome en la visión por computadora, específicamente anti-depósito facial y detección y reconocimiento facial. Desde que cambié a mi trabajo actual en 2021, no he participado en tanto trabajo de I + D, llevándome a comenzar este proyecto de mascotas para volver a visitar y volver a aprender los fundamentos y familiarizarme con el estado del arte. Mi rol actual implica principalmente ingeniería de sistemas de Golang con algunos trabajos de ML aplicados acaba de esparcirse. Por lo tanto, el código puede reflejar mi viaje de aprendizaje. Perdona cualquier error y abra un problema para avisarme.
Además, pocos de los textos pueden generarse con la ayuda de GitHub Copilot, así que disculpe cualquier error en el texto.
Implementado en flaxdiff.schedulers :
flaxdiff.schedulers.LinearNoiseSchedule ): un programador discreto beta-parameterizado.flaxdiff.schedulers.CosineNoiseSchedule ): un programador discreto beta-parametrizado.flaxdiff.schedulers.ExpNoiseSchedule ): un programador discreto beta-parameterizado.flaxdiff.schedulers.CosineContinuousNoiseScheduler ): un programador continuo.flaxdiff.schedulers.CosineGeneralNoiseScheduler ): un programador de coseno parametrizado continuo de Sigma.flaxdiff.schedulers.KarrasVENoiseScheduler ): un programador continuo parametrizado por sigma propuesto por Karras et al. 2022, más adecuado para la inferencia.flaxdiff.schedulers.EDMNoiseScheduler ): un programador continuo parametrizado por sigma basado en el Modelo de Difusión Exponencial (EDM), más adecuado para entrenar con el Karraskarrasoisescheduler. Implementado en flaxdiff.predictors :
flaxdiff.predictors.EpsilonPredictor ): predice el ruido en los datos.flaxdiff.predictors.X0Predictor ): predice los datos originales de los datos ruidosos.flaxdiff.predictors.VPredictor ): predice una combinación lineal de los datos y el ruido, comúnmente utilizado en el EDM.flaxdiff.predictors.KarrasEDMPredictor ): un predictor generalizado para el EDM, que integra varias parametrizaciones. Implementado en flaxdiff.samplers :
flaxdiff.samplers.DDPMSampler ): implementa el proceso de muestreo del modelo de difusión de difusión (DDPM).flaxdiff.samplers.DDIMSampler ): implementa el proceso de muestreo del modelo implícito de difusión (DDIM).flaxdiff.samplers.EulerSampler ): un muestreador de solucionadores ODE usando el método de Euler.flaxdiff.samplers.HeunSampler ): un muestreador de solucionadores ODE usando el método de Heun.flaxdiff.samplers.RK4Sampler ): un muestreador de solucionadores ODE usando el método Runge-Kutta.flaxdiff.samplers.MultiStepDPM ): implementa un método de muestreo de múltiples pasos inspirado en el solucionador DPM de varios artes, como se presenta aquí: Tonyduan/difusión) Implementado en flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): una clase diseñada para facilitar el entrenamiento de los modelos de difusión. Gestiona el bucle de entrenamiento, el cálculo de pérdidas y las actualizaciones del modelo. Implementado en flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): una muestra de arquitectura de unlo para modelos de difusión.flaxdiff.models.simple_unet.Upsample ), downsampling ( flaxdiff.models.simple_unet.Downsample ), Time embeddings ( flaxdiff.models.simple_unet.FouriedEmbedding ), attention ( flaxdiff.models.simple_unet.AttentionBlock ), and Bloques residuales ( flaxdiff.models.simple_unet.ResidualBlock ). Para instalar FlaxDiff, debe tener Python 3.10 o más. Instale las dependencias requeridas usando:
pip install -r requirements.txtLos modelos fueron entrenados y probados con Jax == 0.4.28 y FLax == 0.8.4. Sin embargo, cuando actualicé al último Jax == 0.4.30 y FLax == 0.8.5, los modelos dejaron de entrenamiento. Parece que ha habido un cambio importante que rompa la dinámica de entrenamiento y, por lo tanto, recomendaría seguir con las versiones mencionadas en los requisitos.
Aquí hay un ejemplo simplificado para comenzar con el entrenamiento de un modelo de difusión usando Flaxdiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )Aquí hay un ejemplo simplificado para generar imágenes utilizando un modelo entrenado:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Modelo entrenado en laion-Aesthetics 12m + CC12m + Ms Coco + 1M Subesthetic 6+ subconjunto de Coyo-700m en TPU-V4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Parámetros : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Parámetros : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Parámetros : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Parámetros : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Parámetros : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Parámetros : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Siéntase libre de contribuir abriendo problemas o enviando solicitudes de extracción. ¡Hagamos mejor a Flaxdiff juntos!
Este proyecto tiene licencia bajo la licencia MIT.