
Proyek ini sebagian didukung oleh Google TPU Research Cloud. Saya ingin mengucapkan terima kasih kepada tim TPU Google Cloud karena telah memberikan saya sumber daya untuk melatih model-model kondisi teks yang lebih besar dalam pengaturan yang didistribusikan multi-host.
Dalam beberapa tahun terakhir, difusi dan model multi-langkah berbasis skor telah merevolusi domain AI generatif. Namun, penelitian terbaru di bidang ini telah menjadi sangat intensif matematika, membuatnya menantang untuk memahami bagaimana model difusi canggih bekerja dan menghasilkan gambar yang mengesankan. Mereplikasi penelitian ini dalam kode bisa menakutkan.
Flaxdiff adalah perpustakaan alat (penjadwal, sampler, model, dll.) Dirancang dan diimplementasikan dengan cara yang mudah dipahami. Fokusnya adalah pada pemahaman dan keterbacaan atas kinerja. Saya memulai proyek ini sebagai hobi untuk membiasakan diri dengan Flax dan Jax dan untuk belajar tentang difusi dan penelitian terbaru dalam AI generatif.
Saya awalnya memulai proyek ini di Keras, terbiasa dengan TensorFlow 2.0, tetapi beralih ke rami, ditenagai oleh Jax, untuk kinerjanya dan kemudahan penggunaannya. Notebook dan model lama, termasuk model rami pertama saya, juga disediakan.
Notebook Diffusion_flax_linen.ipynb adalah ruang kerja utama saya untuk percobaan. Beberapa pos pemeriksaan diunggah ke folder pretrained bersama dengan salinan notebook kerja yang terkait dengan setiap pos pemeriksaan. Anda mungkin perlu menyalin notebook ke root yang berfungsi agar berfungsi dengan baik.
Dalam example notebooks , Anda akan menemukan buku catatan komprehensif untuk berbagai teknik difusi, ditulis sepenuhnya dari awal dan tidak tergantung pada perpustakaan Flaxdiff. Setiap notebook mencakup penjelasan terperinci tentang matematika dan konsep yang mendasari, menjadikannya sumber daya yang sangat berharga untuk belajar dan memahami model difusi.
Difusi dijelaskan (tautan NBViewer) (tautan lokal)
EDM (menjelaskan ruang desain model generatif berbasis difusi)
Buku catatan ini bertujuan untuk memberikan panduan yang sangat mudah dipahami dan langkah demi langkah untuk berbagai model dan teknik difusi. Mereka dirancang untuk menjadi ramah-pemula, dan dengan demikian meskipun mereka mungkin tidak mematuhi formulasi dan implementasi makalah asli yang tepat untuk membuatnya lebih dapat dimengerti dan digeneralisasikan, saya telah mencoba yang terbaik untuk menjaga mereka seakurat mungkin. Jika Anda menemukan kesalahan atau memiliki saran, jangan ragu untuk membuka masalah atau permintaan tarik.
Data multi-host skrip pelatihan paralel di jax
Utilitas TPU untuk membuat hidup lebih mudah
Saya bekerja sebagai peneliti pembelajaran mesin di Hyperverge dari 2019-2021, dengan fokus pada visi komputer, khususnya wajah anti-spoofing dan deteksi & pengakuan wajah. Sejak beralih ke pekerjaan saya saat ini pada tahun 2021, saya belum melakukan pekerjaan R&D sebanyak mungkin, membuat saya memulai proyek hewan peliharaan ini untuk mengunjungi kembali dan mempelajari kembali fundamental dan menjadi terbiasa dengan canggih. Peran saya saat ini terutama melibatkan rekayasa sistem Golang dengan beberapa pekerjaan ML terapan yang hanya ditaburkan. Oleh karena itu, kode tersebut dapat mencerminkan perjalanan belajar saya. Mohon maafkan kesalahan apa pun dan buka masalah untuk memberi tahu saya.
Juga, beberapa teks dapat dihasilkan dengan bantuan github copilot, jadi mohon maafkan kesalahan dalam teks.
Diimplementasikan di flaxdiff.schedulers :
flaxdiff.schedulers.LinearNoiseSchedule ): Sebuah penjadwal diskrit beta-parameterized.flaxdiff.schedulers.CosineNoiseSchedule ): Sebuah penjadwal diskrit beta-parameterized.flaxdiff.schedulers.ExpNoiseSchedule ): Penjadwal diskrit beta-parameterisasi.flaxdiff.schedulers.CosineContinuousNoiseScheduler ): penjadwal berkelanjutan.flaxdiff.schedulers.CosineGeneralNoiseScheduler ): Sigma kontinu parameter penjadwal kosinus.flaxdiff.schedulers.KarrasVENoiseScheduler ): Penjadwal kontinu yang diparameterisasi sigma yang diusulkan oleh Karras et al. 2022, paling cocok untuk inferensi.flaxdiff.schedulers.EDMNoiseScheduler ): Penjadwal kontinu yang diparameterisasi sigma berdasarkan pada model difusi eksponensial (EDM), paling cocok untuk pelatihan dengan Karraskarrasvenoisescheduler. Diterapkan di flaxdiff.predictors :
flaxdiff.predictors.EpsilonPredictor ): Memprediksi kebisingan dalam data.flaxdiff.predictors.X0Predictor ): Memprediksi data asli dari data bising.flaxdiff.predictors.VPredictor ): Memprediksi kombinasi linier dari data dan kebisingan, yang biasa digunakan dalam EDM.flaxdiff.predictors.KarrasEDMPredictor ): Prediktor umum untuk EDM, mengintegrasikan berbagai parameterisasi. Diimplementasikan di flaxdiff.samplers :
flaxdiff.samplers.DDPMSampler ): Mengimplementasikan proses pengambilan sampel Denoising Difusion Probabilistic Model (DDPM).flaxdiff.samplers.DDIMSampler ): mengimplementasikan proses pengambilan sampel Denoising Difusion Implicit Model (DDIM).flaxdiff.samplers.EulerSampler ): Sampler pemecah ode menggunakan metode Euler.flaxdiff.samplers.HeunSampler ): Sampler pemecah ode menggunakan metode Heun.flaxdiff.samplers.RK4Sampler ): Sampler pemecah ODE menggunakan metode Runge-Kutta.flaxdiff.samplers.MultiStepDPM ): mengimplementasikan metode pengambilan sampel multi-langkah yang terinspirasi oleh pemecah DPM multistep seperti yang disajikan di sini: tonyduan/difusi) Diimplementasikan di flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): Kelas yang dirancang untuk memfasilitasi pelatihan model difusi. Ini mengelola loop pelatihan, perhitungan kerugian, dan pembaruan model. Diimplementasikan di flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): Sampel arsitektur unet untuk model difusi.flaxdiff.models.simple_unet.Upsample ), downsampling ( flaxdiff.models.simple_unet.Downsample ), Time embeddings ( flaxdiff.models.simple_unet.FouriedEmbedding ), attention ( flaxdiff.models.simple_unet.AttentionBlock ), dan blok residu ( flaxdiff.models.simple_unet.ResidualBlock ). Untuk memasang Flaxdiff, Anda harus memiliki Python 3.10 atau lebih tinggi. Instal dependensi yang diperlukan menggunakan:
pip install -r requirements.txtModel dilatih dan diuji dengan JAX == 0.4.28 dan Flax == 0.8.4. Namun, ketika saya memperbarui ke JAX terbaru == 0.4.30 dan Flax == 0.8.5, model menghentikan pelatihan. Tampaknya ada beberapa perubahan besar melanggar dinamika pelatihan dan oleh karena itu saya akan merekomendasikan tetap pada versi yang disebutkan dalam persyaratan.txt
Berikut adalah contoh yang disederhanakan untuk memulai dengan melatih model difusi menggunakan Flaxdiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )Berikut adalah contoh yang disederhanakan untuk menghasilkan gambar menggunakan model terlatih:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Params : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier free guidance with guidance factor = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Jangan ragu untuk berkontribusi dengan membuka masalah atau mengirimkan permintaan tarik. Mari kita membuat Flaxdiff lebih baik bersama!
Proyek ini dilisensikan di bawah lisensi MIT.