
Este projeto é parcialmente suportado pelo Google TPU Research Cloud. Gostaria de agradecer à equipe do Google Cloud TPU por me fornecer os recursos para treinar os maiores modelos condicionais de texto em configurações distribuídas de vários host.
Nos últimos anos, modelos de várias etapas baseadas em difusão e pontuação revolucionaram o domínio generativo de IA. No entanto, as pesquisas mais recentes nesse campo se tornaram altamente intensivas em matemática, tornando-o desafiador entender como os modelos de difusão de última geração funcionam e gerar imagens tão impressionantes. Replicar esta pesquisa em código pode ser assustador.
Flaxdiff é uma biblioteca de ferramentas (agendadores, amostradores, modelos etc.) projetada e implementada de uma maneira fácil de entender. O foco está na compreensão e na legibilidade sobre o desempenho. Comecei esse projeto como um hobby para me familiarizar com Flax e Jax e aprender sobre difusão e as pesquisas mais recentes em IA generativa.
Inicialmente, iniciei esse projeto em Keras, familiarizado com o TensorFlow 2.0, mas fiz fiz o linho, alimentado por Jax, por seu desempenho e facilidade de uso. Os cadernos e modelos antigos, incluindo meus primeiros modelos de linho, também são fornecidos.
O Notebook Diffusion_flax_linen.ipynb é o meu principal espaço de trabalho para experimentos. Vários pontos de verificação são enviados para a pasta pretrained , juntamente com uma cópia do caderno de trabalho associado a cada ponto de verificação. Pode ser necessário copiar o notebook para a raiz de trabalho para que ela funcione corretamente.
Na pasta example notebooks , você encontrará notebooks abrangentes para várias técnicas de difusão, escritas inteiramente do zero e são independentes da Biblioteca Flaxdiff. Cada caderno inclui explicações detalhadas da matemática e conceitos subjacentes, tornando -os recursos inestimáveis para aprender e entender os modelos de difusão.
Difusão explicada (link nbViewer) (link local)
EDM (elucidando o espaço de design de modelos generativos baseados em difusão)
Esses notebooks visam fornecer um guia muito fácil de entender e passo a passo para os vários modelos e técnicas de difusão. Eles foram projetados para serem adequados para iniciantes e, portanto, embora possam não aderir às formulações e implementações exatas dos artigos originais para torná-los mais compreensíveis e generalizáveis, tentei o meu melhor para mantê-los o mais preciso possível. Se você encontrar algum erro ou tiver alguma sugestão, sinta -se à vontade para abrir um problema ou uma solicitação de tração.
Script de treinamento paralelo de dados de vários hosts em Jax
Utilitários de TPU para facilitar a vida
Eu trabalhei como pesquisador de aprendizado de máquina na Hyperverge de 2019-2021, com foco na visão computacional, especificamente anti-span-spoofing e detecção facial e reconhecimento facial. Desde que mudei para o meu emprego atual em 2021, não envolvi tanto trabalho em pesquisa e desenvolvimento, levando-me a iniciar este projeto de estimação para revisitar e reaprender os fundamentos e me familiarizar com o estado da arte. Minha função atual envolve principalmente a engenharia do sistema de Golang com algum trabalho de ML aplicado. Portanto, o código pode refletir minha jornada de aprendizado. Por favor, perdoe todos os erros e abre um problema para me avisar.
Além disso, poucos do texto podem ser gerados com a ajuda do GitHub Copilot, por isso, desculpe quaisquer erros no texto.
Implementado em flaxdiff.schedulers :
flaxdiff.schedulers.LinearNoiseSchedule ): um agendador discreto beta-parametizado.flaxdiff.schedulers.CosineNoiseSchedule ): um agendador discreto beta-parameterizado.flaxdiff.schedulers.ExpNoiseSchedule ): um agendador discreto beta-parametizado.flaxdiff.schedulers.CosineContinuousNoiseScheduler ): um agendador contínuo.flaxdiff.schedulers.CosineGeneralNoiseScheduler ): um agendador de cosina parametrizado sigma contínuo.flaxdiff.schedulers.KarrasVENoiseScheduler ): um agendador contínuo de Sigma-Parameterized proposto por Karras et al. 2022, mais adequado para inferência.flaxdiff.schedulers.EDMNoiseScheduler ): um agendador contínuo parametorado da Sigma com base no modelo de difusão exponencial (EDM), mais adequado para treinamento com o Karraskarrasvenoscheduler. Implementado em flaxdiff.predictors :
flaxdiff.predictors.EpsilonPredictor ): prediz o ruído nos dados.flaxdiff.predictors.X0Predictor ): prevê os dados originais dos dados ruidosos.flaxdiff.predictors.VPredictor ): Prevendo uma combinação linear dos dados e ruídos, comumente usados no EDM.flaxdiff.predictors.KarrasEDMPredictor ): um preditor generalizado para o EDM, integrando várias parametrizações. Implementado em flaxdiff.samplers :
flaxdiff.samplers.DDPMSampler ): implementa o processo de amostragem do modelo probabilístico de difusão de denoising (DDPM).flaxdiff.samplers.DDIMSampler ): implementa o processo de amostragem de modelo implícito de difusão de denoising (DDIM).flaxdiff.samplers.EulerSampler ): um amostrador de solucionador de ODE usando o método de Euler.flaxdiff.samplers.HeunSampler ): um amostrador de solucionador de Ode usando o método de Heun.flaxdiff.samplers.RK4Sampler ): um amostrador de solucionador de ODE usando o método Runge-Kutta.flaxdiff.samplers.MultiStepDPM ): implementa um método de amostragem em várias etapas inspirado no solucionador de DPM de várias etapas, conforme apresentado aqui: Tonyduan/Difusão) Implementado em flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): Uma classe projetada para facilitar o treinamento de modelos de difusão. Ele gerencia o loop de treinamento, o cálculo de perdas e as atualizações de modelos. Implementado em flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): uma arquitetura de amostra para modelos de difusão.flaxdiff.models.simple_unet.Upsample ), downsampling ( flaxdiff.models.simple_unet.Downsample ), time incorpeddings ( flaxdiff.models.simple_unet.FouriedEmbedding ), flaxdingdings (flaxdiff.models.sImiffEnTeLet.FouLETMEddingding), flaxdiff.models.simple_unet.AttentionBlock . e blocos residuais ( flaxdiff.models.simple_unet.ResidualBlock ). Para instalar o Flaxdiff, você precisa ter o Python 3.10 ou superior. Instale as dependências necessárias usando:
pip install -r requirements.txtOs modelos foram treinados e testados com Jax == 0.4.28 e linho == 0,8.4. No entanto, quando atualizei para o mais recente Jax == 0.4.30 e linho == 0.8.5, os modelos pararam de treinar. Parece ter havido uma grande mudança quebrando a dinâmica de treinamento e, portanto, eu recomendaria manter as versões mencionadas nos requisitos.txt
Aqui está um exemplo simplificado para você começar com o treinamento de um modelo de difusão usando o FlaxDiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )Aqui está um exemplo simplificado para gerar imagens usando um modelo treinado:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Parâmetros : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Sinta -se à vontade para contribuir abrindo questões ou enviando solicitações de tração. Vamos melhorar o Flaxdiff juntos!
Este projeto está licenciado sob a licença do MIT.