
Этот проект частично поддерживается Google TPU Research Cloud. Я хотел бы поблагодарить команду Google Cloud TPU за предоставление мне ресурсов для обучения более крупных моделей текстовых кондиционированных моделей в распределенных настройках с несколькими хостами.
В последние годы многоэтапные модели диффузии и баллов произвели революцию в генеративной области ИИ. Тем не менее, последнее исследование в этой области стало высокой математической интенсивностью, что затрудняет понимание того, как работают современные модели диффузии и генерируют такие впечатляющие образы. Репликация этого исследования в коде может быть пугающим.
Flaxdiff-это библиотека инструментов (планировщики, пробоотборники, модели и т. Д.), Разработанная и внедренная и реализованная способом. Основное внимание уделяется понятности и читабельности по сравнению с производительностью. Я начал этот проект как хобби, чтобы ознакомиться с леном и JAX и узнать о диффузии и последних исследованиях в области генеративного искусственного интеллекта.
Первоначально я начал этот проект в Керас, знакомый с Tensorflow 2.0, но перешел на льнь, работающий от JAX, за его производительность и простоту использования. Также предоставляются старые ноутбуки и модели, в том числе мои первые модели льна.
Ноутбук Diffusion_flax_linen.ipynb - это мое основное рабочее пространство для экспериментов. Несколько контрольных точек загружаются в pretrained папку вместе с копией рабочей ноутбука, связанной с каждой контрольной точкой. Возможно, вам придется скопировать ноутбук в рабочий корень, чтобы он функционировал должным образом.
В example notebooks вы найдете всеобъемлющие ноутбуки для различных методов диффузии, написанных полностью с нуля и не зависят от библиотеки льна. Каждый блокнот включает в себя подробные объяснения базовой математики и концепций, что делает их бесценными ресурсами для обучения и понимания моделей распространения.
Диффузия объяснена (ссылка NBViewer) (локальная ссылка)
EDM (выяснение пространства дизайна генеративных моделей на основе диффузии)
Эти записные книжки направлены на то, чтобы обеспечить очень простое для понимания и пошаговое руководство по различным диффузионным моделям и методам. Они предназначены для того, чтобы быть для начинающих, и, таким образом, хотя они могут не придерживаться точных составов и реализаций оригинальных работ, чтобы сделать их более понятными и обобщаемыми, я старался изо всех сил, чтобы сохранить их как можно более точными. Если вы найдете какие -либо ошибки или у вас есть какие -либо предложения, пожалуйста, не стесняйтесь открыть проблему или запрос на тягу.
Параллельный учебный скрипт данных с несколькими хостами в JAX
Утилиты TPU для облегчения жизни
Я работал исследователем машинного обучения в Hyperverge в 2019-2021 годах, сосредоточившись на компьютерном зрении, в частности, анти-спорофы для лица и обнаружения и распознавания лица. С момента перехода на мою текущую работу в 2021 году я не занимался таким большим количеством исследований и разработок, заставляя меня начать этот Pet-проект, чтобы вернуться и заново изучать основы и познакомиться с самым современным. Моя нынешняя роль включает в себя в первую очередь Golang System Engineering с некоторой прикладной работой ML, только что посыпанной. Следовательно, код может отражать мое учебное путешествие. Пожалуйста, прощайте любые ошибки и откройте проблему, чтобы сообщить мне.
Кроме того, немногие из текстов могут быть сгенерированы с помощью GitHub Copilot, поэтому, пожалуйста, извините за любые ошибки в тексте.
Внедрены в flaxdiff.schedulers
flaxdiff.schedulers.LinearNoiseSchedule ): бета-параметрированный дискретный планировщик.flaxdiff.schedulers.CosineNoiseSchedule ): бета-параметрированный дискретный планировщик.flaxdiff.schedulers.ExpNoiseSchedule ): бета-параметрированный дискретный планировщик.flaxdiff.schedulers.CosineContinuousNoiseScheduler ): непрерывный планировщик.flaxdiff.schedulers.CosineGeneralNoiseScheduler ): непрерывный Sigma, параметризованный косинус.flaxdiff.schedulers.KarrasVENoiseScheduler ): сигма-параметрированный непрерывный планировщик, предложенный Karras et al. 2022, лучше всего подходит для вывода.flaxdiff.schedulers.EDMNoiseScheduler ): непрерывный планирующий Sigma-параметированный на основе модели экспоненциальной диффузии (EDM), который лучше всего подходит для обучения с Karraskarrasvenoiseisheduler. Реализовано в flaxdiff.predictors
flaxdiff.predictors.EpsilonPredictor ): прогнозирует шум в данных.flaxdiff.predictors.X0Predictor ): прогнозирует исходные данные из шумных данных.flaxdiff.predictors.VPredictor ): прогнозирует линейную комбинацию данных и шума, обычно используемой в EDM.flaxdiff.predictors.KarrasEDMPredictor ): обобщенный предиктор EDM, интегрирующий различные параметризации. Реализовано в flaxdiff.samplers
flaxdiff.samplers.DDPMSampler ): реализует процесс отбора проб диффузии диффузии (DDPM).flaxdiff.samplers.DDIMSampler ): реализует процесс отбора проб неявной модели (DDIM).flaxdiff.samplers.EulerSampler ): пробоотборщик решателя ODE с использованием метода Эйлера.flaxdiff.samplers.HeunSampler ): пробоотборник Ode Solver с использованием метода Heun's.flaxdiff.samplers.RK4Sampler ): пробоотборник решателя ODE с использованием метода Runge-Kutta.flaxdiff.samplers.MultiStepDPM ): реализует многоэтапный метод отбора проб, вдохновленный многоэтапным решателем DPM, как представлено здесь: Tonyduan/Diffusion) Реализовано в flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): класс, предназначенный для облегчения обучения диффузионных моделей. Он управляет петлей обучения, расчетом потерь и обновлениями моделей. Реализовано в flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): образец Unet Architecture для диффузионных моделей.flaxdiff.models.simple_unet.Upsample ), Downsampling ( flaxdiff.models.simple_unet.Downsample ), Вторжение времени ( flaxdiff.models.simple_unet.FouriedEmbedding ), articeDiff.models.simple_unet.fouriedembedding), artity.models.simple_Unet.FouredEmbedding), artity.models.simple_unet. и flaxdiff.models.simple_unet.AttentionBlock ) Остаточные блоки ( flaxdiff.models.simple_unet.ResidualBlock ). Чтобы установить Flaxdiff, вам нужно иметь Python 3.10 или выше. Установите требуемые зависимости, используя:
pip install -r requirements.txtМодели были обучены и протестированы с JAX == 0,4,28 и льном == 0,8,4. Однако, когда я обновился до последнего JAX == 0.4.30 и льна == 0,8,5, модели прекратили тренировку. Похоже, было какое -то серьезные изменения, нарушающие динамику обучения, и поэтому я бы порекомендовал придерживаться версий, упомянутых в требованиях.txt.
Вот упрощенный пример, чтобы вы начали с обучения диффузионной модели, используя Flaxdiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )Вот упрощенный пример для генерации изображений с использованием обученной модели:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Параметры : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Парамы : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Парамы : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

PARAMS : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

PARAMS : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

PARAMS : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Не стесняйтесь вносить свой вклад, открывая вопросы или отправляя запросы на привлечение. Давайте сделаем Flaxdiff лучше вместе!
Этот проект лицензирован по лицензии MIT.