
Ce projet est partiellement pris en charge par Google TPU Research Cloud. Je tiens à remercier l'équipe de Google Cloud TPU de m'avoir fourni les ressources pour former les modèles plus importants en texte en texte dans des paramètres distribués à plusieurs hôtes.
Ces dernières années, les modèles multi-étapes basés sur la diffusion et les scores ont révolutionné le domaine AI génératif. Cependant, les dernières recherches dans ce domaine sont devenues très à forte intensité de mathématiques, ce qui rend difficile de comprendre comment les modèles de diffusion de pointe fonctionnent et génèrent des images aussi impressionnantes. La réplication de cette recherche en code peut être intimidante.
FlaxDiff est une bibliothèque d'outils (planificateurs, échantillonneurs, modèles, etc.) conçus et implémentés de manière facile à comprendre. L'accent est mis sur la compréhension et la lisibilité à la performance. J'ai commencé ce projet comme un passe-temps pour me familiariser avec le lin et le jax et en savoir plus sur la diffusion et les dernières recherches en AI génératrices.
J'ai d'abord commencé ce projet dans Keras, familiarisé avec TensorFlow 2.0, mais je suis passé au lin, propulsé par Jax, pour ses performances et sa facilité d'utilisation. Les anciens cahiers et modèles, y compris mes premiers modèles de lin, sont également fournis.
Le cahier Diffusion_flax_linen.ipynb est mon espace de travail principal pour les expériences. Plusieurs points de contrôle sont téléchargés dans le dossier pretrained avec une copie du cahier de travail associé à chaque point de contrôle. Vous devrez peut-être copier le cahier sur la racine de travail pour qu'il fonctionne correctement.
Dans le dossier example notebooks , vous trouverez des ordinateurs portables complets pour diverses techniques de diffusion, entièrement écrits à partir de zéro et sont indépendants de la bibliothèque FlaxDiff. Chaque cahier comprend des explications détaillées des mathématiques et des concepts sous-jacents, ce qui en fait des ressources inestimables pour apprendre et comprendre les modèles de diffusion.
Diffusion expliqué (lien nbViewer) (lien local)
EDM (élucider l'espace de conception des modèles génératifs basés sur la diffusion)
Ces cahiers visent à fournir un guide très facile à comprendre et étape par étape des différents modèles et techniques de diffusion. Ils sont conçus pour être adaptés aux débutants, et donc bien qu'ils puissent ne pas adhérer aux formulations et implémentations exactes des articles originaux pour les rendre plus compréhensibles et généralisables, j'ai fait de mon mieux pour les garder aussi précis que possible. Si vous trouvez des erreurs ou avez des suggestions, n'hésitez pas à ouvrir un problème ou une demande de traction.
Script de formation parallèle de données multi-hôtes à Jax
Utilitaires TPU pour faciliter la vie
J'ai travaillé comme chercheur à l'apprentissage automatique chez Hyperverge de 2019-2021, en me concentrant sur la vision par ordinateur, en particulier l'anticadrance faciale et la détection et la reconnaissance faciale. Depuis le passage à mon emploi actuel en 2021, je n'ai pas participé à autant de travaux de R&D, ce qui m'a amené à démarrer ce projet de TEP pour revoir et réapprendre les principes fondamentaux et me familiariser avec l'état de la technologie. Mon rôle actuel implique principalement l'ingénierie du système Golang avec un travail de ML appliqué qui vient de saupoudrer. Par conséquent, le code peut refléter mon parcours d'apprentissage. Veuillez pardonner toutes les erreurs et ouvrir un problème pour me le faire savoir.
De plus, peu de texte peut être généré à l'aide de Github Copilot, alors veuillez excuser toutes les erreurs dans le texte.
Implémenté dans flaxdiff.schedulers :
flaxdiff.schedulers.LinearNoiseSchedule ): un planificateur discret par-paramétré bêta.flaxdiff.schedulers.CosineNoiseSchedule ): un planificateur discret par-paramétré bêta.flaxdiff.schedulers.ExpNoiseSchedule ): un planificateur discret par-paramétré bêta.flaxdiff.schedulers.CosineContinuousNoiseScheduler ): un planificateur continu.flaxdiff.schedulers.CosineGeneralNoiseScheduler ): un planificateur de cosinus paramétré de sigma continu.flaxdiff.schedulers.KarrasVENoiseScheduler ): un planificateur continu paramétré sigma proposé par Karras et al. 2022, mieux adapté à l'inférence.flaxdiff.schedulers.EDMNoiseScheduler ): un planificateur continu paramétré sigma basé sur le modèle de diffusion exponentielle (EDM), le mieux adapté à la formation avec le KarraskarRasvenOiseScheler. Mise en œuvre dans flaxdiff.predictors :
flaxdiff.predictors.EpsilonPredictor ): prédit le bruit dans les données.flaxdiff.predictors.X0Predictor ): prédit les données d'origine des données bruyantes.flaxdiff.predictors.VPredictor ): prédit une combinaison linéaire des données et du bruit, couramment utilisées dans l'EDM.flaxdiff.predictors.KarrasEDMPredictor ): un prédicteur généralisé pour l'EDM, intégrant divers paramétrisations. Implémenté dans flaxdiff.samplers :
flaxdiff.samplers.DDPMSampler ): implémente le processus d'échantillonnage du modèle probabiliste de diffusion (DDPM) de la diffusion (DDPM).flaxdiff.samplers.DDIMSampler ): implémente le processus d'échantillonnage du modèle implicite de diffusion de diffusion (DDIM).flaxdiff.samplers.EulerSampler ): un échantillonneur ODE Solver utilisant la méthode d'Euler.flaxdiff.samplers.HeunSampler ): un échantillonneur de solveur ODE utilisant la méthode de Heun.flaxdiff.samplers.RK4Sampler ): un échantillonneur ODE Solver utilisant la méthode Runge-Kutta.flaxdiff.samplers.MultiStepDPM ): Implements a multi-step sampling method inspired by the Multistep DPM solver as presented here: tonyduan/diffusion) Implémenté dans flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): une classe conçue pour faciliter la formation des modèles de diffusion. Il gère la boucle de formation, le calcul des pertes et les mises à jour du modèle. Implémenté dans flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): un exemple d'architecture UNET pour les modèles de diffusion.flaxdiff.models.simple_unet.Upsample ), les downsampling ( flaxdiff.models.simple_unet.FouriedEmbedding ), flaxdiff.models.simple_unet.Downsample flaxdiff.models.simple_unet.AttentionBlock Blocs résiduels ( flaxdiff.models.simple_unet.ResidualBlock ). Pour installer FlaxDiff, vous devez avoir Python 3.10 ou plus. Installez les dépendances requises en utilisant:
pip install -r requirements.txtLes modèles ont été entraînés et testés avec JAX == 0,4,28 et lin == 0,8,4. Cependant, lorsque j'ai mis à jour le dernier JAX == 0.4.30 et le lin == 0,8,5, les modèles ont arrêté la formation. Il semble qu'il y ait eu un changement majeur pour briser la dynamique de la formation et donc je recommanderais de m'en tenir aux versions mentionnées dans les exigences.txt
Voici un exemple simplifié pour vous faire démarrer avec la formation d'un modèle de diffusion à l'aide de FlaxDiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )Voici un exemple simplifié pour générer des images à l'aide d'un modèle formé:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Modèle formé sur le laion-aesthésique 12m + CC12M + MS Coco + 1M Aesthetic 6+ Sous-ensemble de coyo-700m sur TPU-V4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Params : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images générées par les invites suivantes en utilisant des conseils gratuits du classificateur avec un facteur de guidage = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Paramètres : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images générées par les invites suivantes en utilisant des conseils gratuits au classificateur avec un facteur de guidage = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Paramètres : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Paramètres : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Paramètres : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Paramètres : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

N'hésitez pas à contribuer en ouvrant des problèmes ou en soumettant des demandes de traction. Rendons Flaxdiff mieux ensemble!
Ce projet est autorisé sous la licence du MIT.