
This project was partially supported by the Google TPU Research Cloud. I would like to thank the Google Cloud TPU team for providing the resources to train the larger text-conditional models in a multi-host distributed setup.
In recent years, diffusion and score-based multi-step models have revolutionized the generative AI space. However, recent research in this area has become highly math-heavy, making it challenging to understand how state-of-the-art diffusion models actually work and produce such impressive images. Reproducing this research in code can also be daunting.
FlaxDiff is a library of tools (schedulers, samplers, models, etc.) designed and implemented in an easy-to-understand way. The focus is on understandability and readability over raw performance. I started this project as a hobby, to familiarize myself with Flax and JAX and to catch up on recent research in diffusion and generative AI.
I originally started this project in Keras, being familiar with TensorFlow 2.0, but transitioned to Flax, powered by JAX, for its performance and ease of use. The old notebooks and models, including my very first Flax models, are also provided.
The Diffusion_flax_linen.ipynb notebook is my main workspace for experiments. Several checkpoints are uploaded to the pretrained folder, along with a copy of the working notebook associated with each checkpoint. You may need to copy the notebook to the working root for it to run properly.
In the example notebooks folder, you will find comprehensive notebooks on various diffusion techniques, written entirely from scratch and independent of the FlaxDiff library. Each notebook includes a detailed explanation of the underlying math and concepts, making them a valuable resource for learning and understanding diffusion models.
Diffusion explained (NBViewer link) (local link)
EDM (Elucidating the Design Space of Diffusion-Based Generative Models)
These notebooks aim to provide a very accessible, step-by-step guide to the various diffusion models and techniques. They are designed to be beginner-friendly, so while they may not follow the exact formulations and implementations of the original papers (in order to make them easier to understand and generalize), I have tried to keep them as accurate as possible. If you find any mistakes or have suggestions, please feel free to open an issue or a pull request.
Multi-host data-parallel training script in JAX (a rough setup sketch follows below)
TPU utilities to make life easier
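As a rough illustration of what the multi-host data-parallel setup involves (this is not the repository's actual training script), the sketch below initializes JAX's distributed runtime and reports the device topology:

```python
import jax

# On Cloud TPU pods, initialize() picks up the coordinator address and
# the process index from the environment automatically.
jax.distributed.initialize()

# Each host drives its local devices, while pmapped/sharded computations
# span the global device mesh across all hosts.
print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {jax.local_device_count()}, "
      f"global devices: {jax.device_count()}")
```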
From 2019 to 2021, I worked at Hyperverge as a machine learning researcher, focusing on computer vision, specifically face anti-spoofing and face detection and recognition. Since switching to my current job in 2021, I haven't done much R&D work, which led me to start this pet project to revisit and relearn the fundamentals and get familiar with the state of the art. My current role mostly involves Golang systems engineering, with only a little applied ML work sprinkled in. The code may therefore reflect my learning journey; please forgive any mistakes and open an issue to let me know.
Also, parts of the text were generated with the help of GitHub Copilot, so please forgive any errors in the writing.
Noise schedulers, implemented in flaxdiff.schedulers:
- flaxdiff.schedulers.LinearNoiseSchedule: a discrete, beta-parameterized scheduler.
- flaxdiff.schedulers.CosineNoiseSchedule: a discrete, beta-parameterized scheduler.
- flaxdiff.schedulers.ExpNoiseSchedule: a discrete, beta-parameterized scheduler.
- flaxdiff.schedulers.CosineContinuousNoiseScheduler: a continuous scheduler.
- flaxdiff.schedulers.CosineGeneralNoiseScheduler: a continuous, sigma-parameterized cosine scheduler.
- flaxdiff.schedulers.KarrasVENoiseScheduler: a continuous, sigma-parameterized scheduler proposed by Karras et al. (2022), best suited for inference.
- flaxdiff.schedulers.EDMNoiseScheduler: a continuous, sigma-parameterized scheduler based on the EDM formulation, best suited for training models meant to be sampled with the KarrasVENoiseScheduler.
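As a rough sketch of how these schedules are used in the forward (noising) process, the snippet below scales a clean sample and injected Gaussian noise by the rates a schedule provides. It reuses the get_rates method that appears in the custom sampler example further down; the exact FlaxDiff API may differ slightly:

```python
import jax

def forward_noise(x_0, t, noise_schedule, rng):
    # alpha_t scales the clean signal, sigma_t scales the injected noise;
    # get_rates mirrors the method used in the EulerSampler example below.
    alpha_t, sigma_t = noise_schedule.get_rates(t)
    noise = jax.random.normal(rng, x_0.shape)
    x_t = alpha_t * x_0 + sigma_t * noise
    return x_t, noise
```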
Predictors, implemented in flaxdiff.predictors:

- flaxdiff.predictors.EpsilonPredictor: predicts the noise added to the data.
- flaxdiff.predictors.X0Predictor: predicts the original data from the noisy data.
- flaxdiff.predictors.VPredictor: predicts a linear combination of the data and the noise (v-prediction), commonly used in EDM.
- flaxdiff.predictors.KarrasEDMPredictor: a generalized predictor for EDM, integrating various parameterizations.
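For reference, the v-prediction target mentioned above is usually defined as a rotation of the data and the noise (Salimans & Ho, 2022); whether FlaxDiff's VPredictor follows exactly this convention is an assumption here:

```python
def v_target(x_0, noise, alpha_t, sigma_t):
    # Standard v-parameterization: v = alpha_t * noise - sigma_t * x_0.
    return alpha_t * noise - sigma_t * x_0
```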
Samplers, implemented in flaxdiff.samplers:

- flaxdiff.samplers.DDPMSampler: implements the Denoising Diffusion Probabilistic Models (DDPM) sampling process.
- flaxdiff.samplers.DDIMSampler: implements the Denoising Diffusion Implicit Models (DDIM) sampling process.
- flaxdiff.samplers.EulerSampler: an ODE-solver sampler using the Euler method.
- flaxdiff.samplers.HeunSampler: an ODE-solver sampler using Heun's method.
- flaxdiff.samplers.RK4Sampler: an ODE-solver sampler using the Runge-Kutta method.
- flaxdiff.samplers.MultiStepDPM: implements a multi-step sampling method inspired by the multistep DPM solver, as in tonyduan/diffusion.
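In principle, any of the built-in samplers can be swapped in where the custom EulerSampler is constructed in the generation example further down; the sketch below assumes they share the DiffusionSampler constructor and generate_images signature shown there, which is an assumption rather than documented API:

```python
from flaxdiff.samplers import DDIMSampler  # or HeunSampler, RK4Sampler, ...

# Mirrors the constructor arguments used in the EulerSampler example below;
# treat this as a sketch, not the exact API.
sampler = DDIMSampler(trainer.model, trainer.state.ema_params, edm_schedule,
                      model_output_transform=trainer.model_output_transform)
samples = sampler.generate_images(num_images=16, diffusion_steps=100,
                                  start_step=1000, end_step=0)
```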
Training utilities, implemented in flaxdiff.trainer:

- flaxdiff.trainer.DiffusionTrainer: a class designed to facilitate the training of diffusion models. It manages the training loop, loss computation, and model updates.
Models, implemented in flaxdiff.models:

- flaxdiff.models.simple_unet.SimpleUNet: a sample UNet architecture for diffusion models.
- Common building blocks: upsampling (flaxdiff.models.simple_unet.Upsample), downsampling (flaxdiff.models.simple_unet.Downsample), time embeddings (flaxdiff.models.simple_unet.FouriedEmbedding), attention (flaxdiff.models.simple_unet.AttentionBlock), and residual blocks (flaxdiff.models.simple_unet.ResidualBlock).

To install FlaxDiff, you need Python 3.10 or higher. Install the required dependencies with:
```
pip install -r requirements.txt
```

The models were trained and tested with jax==0.4.28 and flax==0.8.4. However, when I updated to the newer jax==0.4.30 and flax==0.8.5, the models stopped training. There appear to have been some breaking changes that affect the training dynamics, so I recommend sticking to the versions pinned in requirements.txt.
Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
```python
from flaxdiff.schedulers import EDMNoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import SimpleUNet as UNet
from flaxdiff.trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime

BATCH_SIZE = 16
IMAGE_SIZE = 64

# Define the noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Define the model
unet = UNet(emb_features=256,
            feature_depths=[64, 128, 256, 512],
            attention_configs=[{"heads": 4}, {"heads": 4}, {"heads": 4}, {"heads": 4}, {"heads": 4}],
            num_res_blocks=2,
            num_middle_res_blocks=1)

# Load the dataset (get_dataset is a helper returning a data iterator and the dataset length)
data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
batches = datalen // BATCH_SIZE

# Define the optimizer
solver = optax.adam(2e-4)

# Create the trainer
trainer = DiffusionTrainer(unet, optimizer=solver,
                           noise_schedule=edm_schedule,
                           rngs=jax.random.PRNGKey(4),
                           name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                           model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))

# Train the model
final_state = trainer.fit(data, batches, epochs=2000)
```

Here is a simplified example of generating images with a trained model:
```python
from flaxdiff.samplers import DiffusionSampler

class EulerSampler(DiffusionSampler):
    def take_next_step(self, current_samples, reconstructed_samples, pred_noise, current_step, state, next_step=None):
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
        dt = next_sigma - current_sigma
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state

# Create the sampler
sampler = EulerSampler(trainer.model, trainer.state.ema_params, edm_schedule, model_output_transform=trainer.model_output_transform)

# Generate images
samples = sampler.generate_images(num_images=64, diffusion_steps=100, start_step=1000, end_step=0)
plotImages(samples, dpi=300)
```

Model trained on a TPU-v4-32 on LAION-Aesthetics 12M + CC12M + MS COCO + a 1M aesthetics 6+ subset of COYO-700M:

Generated samples for the prompts 'a beautiful landscape with a river with mountains', 'a beautiful forest with a river and sunlight', and 'a big mansion with a garden'.
Params:

- Dataset: LAION-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M
- Batch size: 256
- Image size: 128
- Training epochs: 5
- Steps per epoch: 74573
- Model configuration: feature_depths=[128, 256, 512, 1024]
- Training noise schedule: EDMNoiseScheduler
- Inference noise schedule: KarrasVENoiseScheduler

Images generated with classifier-free guidance (guidance factor = 2) for the prompts: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Params:

- Dataset: oxford_flowers102
- Batch size: 16
- Image size: 128
- Training epochs: 1000
- Steps per epoch: 511
- Training noise schedule: EDMNoiseScheduler
- Inference noise schedule: KarrasVENoiseScheduler
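For context on the guidance factor used in these results: classifier-free guidance runs the model both with and without the text conditioning and pushes the prediction away from the unconditional output. A minimal sketch of the standard combination (how FlaxDiff wires this internally is not shown here) is:

```python
def apply_cfg(uncond_pred, cond_pred, guidance_factor):
    # Standard classifier-free guidance: guidance_factor = 1 recovers the
    # conditional prediction; larger values strengthen the conditioning.
    return uncond_pred + guidance_factor * (cond_pred - uncond_pred)
```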

Images generated with classifier-free guidance (guidance factor = 4) for the prompts: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Params:

- Dataset: oxford_flowers102
- Batch size: 16
- Image size: 128
- Training epochs: 1000
- Steps per epoch: 511
- Training noise schedule: EDMNoiseScheduler
- Inference noise schedule: KarrasVENoiseScheduler

Params:

- Dataset: oxford_flowers102
- Batch size: 16
- Image size: 64
- Training epochs: 1000
- Steps per epoch: 511
- Training noise schedule: CosineNoiseSchedule
- Inference noise schedule: CosineNoiseSchedule
- Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params:

- Dataset: oxford_flowers102
- Batch size: 16
- Image size: 64
- Training epochs: 1000
- Steps per epoch: 511
- Training noise schedule: CosineNoiseSchedule
- Inference noise schedule: CosineNoiseSchedule
- Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params:

- Dataset: oxford_flowers102
- Batch size: 16
- Image size: 64
- Training epochs: 1000
- Steps per epoch: 511
- Training noise schedule: EDMNoiseScheduler
- Inference noise schedule: KarrasVENoiseScheduler
- Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Feel free to contribute by opening issues or submitting pull requests. Let's make FlaxDiff better together!
This project is licensed under the MIT License.