
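This project is partially supported by the Google TPU Research Cloud. I would like to thank the Google Cloud TPU team for providing the resources to train the larger text-conditional models in multi-host distributed settings.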
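In recent years, diffusion and score-based multi-step models have revolutionized the field of generative AI. However, the latest research in this area has become highly math-intensive, which makes it hard to understand how state-of-the-art diffusion models work and produce such impressive images. Replicating this research in code can be daunting.
FlaxDiff is a library of tools (schedulers, samplers, models, etc.) designed and implemented in an easy-to-understand way. The focus is on understandability and readability over performance. I started this project as a hobby to get familiar with Flax and JAX, and to learn about diffusion and the latest research in generative AI.
I initially started this project in Keras, since I was familiar with TensorFlow 2.0, but moved to Flax, powered by JAX, for its performance and ease of use. The old notebooks and models, including my first Flax models, are also provided.
The Diffusion_flax_linen.ipynb notebook is my main workspace for experiments. Several checkpoints are uploaded to the pretrained folder, along with a copy of the working notebook associated with each checkpoint. You may need to copy the notebook to the working root for it to run properly.
In the example notebooks folder, you will find comprehensive notebooks on various diffusion techniques, written entirely from scratch and independent of the FlaxDiff library. Each notebook includes detailed explanations of the underlying mathematics and concepts, making them a useful resource for learning and understanding diffusion models.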
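Diffusion explained (nbviewer link) (local link)
EDM (Elucidating the Design Space of Diffusion-Based Generative Models)
These notebooks aim to provide very easy-to-understand, step-by-step guides to various diffusion models and techniques. They are designed to be beginner-friendly, so they may not follow the exact formulations and implementations of the original papers, in order to be easier to understand and generalize. I have still tried my best to keep them as accurate as possible. If you find any mistakes or have suggestions, feel free to open an issue or a pull request.
Multi-host data-parallel training script in JAX
TPU utilities to make life easier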
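From 2019 to 2021 I worked as a Machine Learning Researcher at Hyperverge, focusing on computer vision, specifically face anti-spoofing and face detection and recognition. Since switching to my current job in 2021, I have not done much R&D work, which led me to start this pet project to revisit and relearn the fundamentals and to get familiar with the state of the art. My current role mainly involves Golang systems engineering, with some applied ML work sprinkled in. The code may therefore reflect my learning journey. Please forgive any mistakes and open an issue to let me know.
Also, some of the text was generated with the help of GitHub Copilot, so please excuse any mistakes in the text.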
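Implemented in flaxdiff.schedulers:
flaxdiff.schedulers.LinearNoiseSchedule: A beta-parameterized discrete scheduler.
flaxdiff.schedulers.CosineNoiseSchedule: A beta-parameterized discrete scheduler.
flaxdiff.schedulers.ExpNoiseSchedule: A beta-parameterized discrete scheduler.
flaxdiff.schedulers.CosineContinuousNoiseScheduler: A continuous scheduler.
flaxdiff.schedulers.CosineGeneralNoiseScheduler: A continuous sigma-parameterized cosine scheduler.
flaxdiff.schedulers.KarrasVENoiseScheduler: A sigma-parameterized continuous scheduler proposed by Karras et al. (2022), best suited for inference.
flaxdiff.schedulers.EDMNoiseScheduler: A sigma-parameterized continuous scheduler based on EDM (Elucidating the Design Space of Diffusion-Based Generative Models), best suited for training with the KarrasVENoiseScheduler.

Implemented in flaxdiff.predictors:
flaxdiff.predictors.EpsilonPredictor: Predicts the noise in the data.
flaxdiff.predictors.X0Predictor: Predicts the original data from the noisy data.
flaxdiff.predictors.VPredictor: Predicts a linear combination of the data and noise, commonly used in EDM.
flaxdiff.predictors.KarrasEDMPredictor: A generalized predictor for EDM that integrates various parameterizations.

Implemented in flaxdiff.samplers:
flaxdiff.samplers.DDPMSampler: Implements the Denoising Diffusion Probabilistic Model (DDPM) sampling process.
flaxdiff.samplers.DDIMSampler: Implements the Denoising Diffusion Implicit Model (DDIM) sampling process.
flaxdiff.samplers.EulerSampler: An ODE solver sampler using Euler's method.
flaxdiff.samplers.HeunSampler: An ODE solver sampler using Heun's method.
flaxdiff.samplers.RK4Sampler: An ODE solver sampler using the Runge-Kutta method.
flaxdiff.samplers.MultiStepDPM: Implements a multi-step sampling method inspired by the multistep DPM solver as presented in tonyduan/diffusion.

Implemented in flaxdiff.trainer:
flaxdiff.trainer.DiffusionTrainer: A class designed to facilitate the training of diffusion models. It manages the training loop, loss calculation, and model updates.

Implemented in flaxdiff.models:
flaxdiff.models.simple_unet.SimpleUNet: A sample UNet architecture for diffusion models.
A library of layers, including upsampling (flaxdiff.models.simple_unet.Upsample), downsampling (flaxdiff.models.simple_unet.Downsample), time embeddings (flaxdiff.models.simple_unet.FouriedEmbedding), attention (flaxdiff.models.simple_unet.AttentionBlock), and residual blocks (flaxdiff.models.simple_unet.ResidualBlock).

To install FlaxDiff, you need Python 3.10 or higher. Install the required dependencies using:
pip install -r requirements.txt
The models were trained and tested with jax==0.4.28 and flax==0.8.4. However, when I updated to the latest jax==0.4.30 and flax==0.8.5, the models stopped training. There seem to have been some major changes that break the training dynamics, so I suggest sticking to the versions listed in requirements.txt.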
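For readers new to the sigma-parameterized schedulers listed above, the following is a minimal sketch of the noise-level spacing proposed by Karras et al. (2022). It is not FlaxDiff's KarrasVENoiseScheduler implementation; the function name `karras_sigmas` is hypothetical, and `sigma_min=0.002` is an assumed default taken from the EDM paper, while `sigma_max=80` and `rho=7` match the values used in this README's training example.

```python
import jax.numpy as jnp


def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: noise levels spaced as in Karras et al. (2022).

    Returns `num_steps` sigmas, starting at sigma_max and decreasing to sigma_min.
    sigma_min=0.002 is an assumption (the EDM paper default), not a FlaxDiff value.
    """
    ramp = jnp.linspace(0.0, 1.0, num_steps)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then map back to sigma.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


sigmas = karras_sigmas(10)
print(sigmas)
```

A larger `rho` concentrates more of the steps at low noise levels, which is the motivation given in the paper for this spacing at inference time.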
Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
from flaxdiff.schedulers import EDMNoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import SimpleUNet as UNet
from flaxdiff.trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime

BATCH_SIZE = 16
IMAGE_SIZE = 64

# Define noise scheduler
edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Define model
unet = UNet(emb_features=256,
            feature_depths=[64, 128, 256, 512],
            attention_configs=[{"heads": 4}, {"heads": 4}, {"heads": 4}, {"heads": 4}, {"heads": 4}],
            num_res_blocks=2,
            num_middle_res_blocks=1)

# Load dataset
# NOTE: get_dataset is not imported in this snippet; it is assumed to be a
# dataset-loading helper that you define yourself or take from the notebooks.
data, datalen = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
batches = datalen // BATCH_SIZE

# Define optimizer
solver = optax.adam(2e-4)

# Create trainer
trainer = DiffusionTrainer(unet, optimizer=solver,
                           noise_schedule=edm_schedule,
                           rngs=jax.random.PRNGKey(4),
                           name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                           model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data))

# Train the model
final_state = trainer.fit(data, batches, epochs=2000)

Here is a simplified example for generating images using a trained model:
from flaxdiff.samplers import DiffusionSampler

class EulerSampler(DiffusionSampler):
    def take_next_step(self, current_samples, reconstructed_samples, pred_noise, current_step, state, next_step=None):
        # Signal (alpha) and noise (sigma) rates at the current and next steps
        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
        dt = next_sigma - current_sigma
        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt
        # Slope of the sample with respect to sigma, then one Euler step of size dt
        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
        next_samples = current_samples + dx * dt
        return next_samples, state

# Create sampler
sampler = EulerSampler(trainer.model, trainer.state.ema_params, edm_schedule, model_output_transform=trainer.model_output_transform)

# Generate images
samples = sampler.generate_images(num_images=64, diffusion_steps=100, start_step=1000, end_step=0)
# NOTE: plotImages is not imported in this snippet; it is assumed to be a plotting helper you provide.
plotImages(samples, dpi=300)

Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Params:
Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M
Batch size: 256
Image Size: 128
Training Epochs: 5
Steps per epoch: 74573
Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler
Inference Noise Schedule: KarrasEDMPredictor

Images generated by the following prompts using classifier-free guidance with guidance factor = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Params:
Dataset: oxford_flowers102
Batch size: 16
Image Size: 128
Training Epochs: 1000
Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler
Inference Noise Schedule: KarrasEDMPredictor
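The guidance factor mentioned for these samples refers to classifier-free guidance. As a minimal sketch of one common formulation (not necessarily how FlaxDiff wires it in), the model is evaluated once with the prompt and once without it, and the two predictions are combined. The function name `classifier_free_guidance` is hypothetical, and the toy arrays stand in for real model outputs; it is also an assumption that the guidance factor above corresponds to `guidance_scale` in this convention.

```python
import jax.numpy as jnp


def classifier_free_guidance(cond_pred, uncond_pred, guidance_scale):
    """Hypothetical helper: combine conditional and unconditional predictions.

    guidance_scale = 0 gives the unconditional prediction, 1 gives the
    conditional one, and values above 1 extrapolate past the conditional one.
    """
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)


# Toy stand-ins for the two model outputs on the same noisy input
cond = jnp.array([1.0, 2.0])
uncond = jnp.array([0.0, 1.0])
guided = classifier_free_guidance(cond, uncond, guidance_scale=2.0)
print(guided)
```

Some papers write the same idea with a weight `w` where the scale is `1 + w`, so guidance values from different codebases are not always directly comparable.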

Images generated by the following prompts using classifier-free guidance with guidance factor = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Params:
Dataset: oxford_flowers102
Batch size: 16
Image Size: 128
Training Epochs: 1000
Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler
Inference Noise Schedule: KarrasEDMPredictor

Params:
Dataset: oxford_flowers102
Batch size: 16
Image Size: 64
Training Epochs: 1000
Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule
Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params:
Dataset: oxford_flowers102
Batch size: 16
Image Size: 64
Training Epochs: 1000
Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule
Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params:
Dataset: oxford_flowers102
Batch size: 16
Image Size: 64
Training Epochs: 1000
Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler
Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)
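Several of the runs above use the EDM setup, and the training example passes `sigma_data` to KarrasPredictionTransform. As background, here is a minimal sketch of the preconditioning coefficients from the EDM paper (Karras et al., 2022). It is an illustration of the published formulas, not FlaxDiff's KarrasPredictionTransform code; the names `edm_preconditioning`, `denoise`, and `raw_network` are hypothetical.

```python
import jax.numpy as jnp


def edm_preconditioning(sigma, sigma_data=0.5):
    """Hypothetical helper: EDM scaling coefficients for noise level sigma."""
    sigma = jnp.asarray(sigma, dtype=jnp.float32)
    total_var = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / total_var                # weight of the noisy input
    c_out = sigma * sigma_data / jnp.sqrt(total_var)  # scale of the network output
    c_in = 1.0 / jnp.sqrt(total_var)                  # scale of the network input
    c_noise = jnp.log(sigma) / 4.0                    # noise-level conditioning
    return c_skip, c_out, c_in, c_noise


def denoise(raw_network, noisy, sigma, sigma_data=0.5):
    """Wrap a raw network F(x, c_noise) into a denoiser D(x; sigma)."""
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    return c_skip * noisy + c_out * raw_network(c_in * noisy, c_noise)


# Toy check with a stand-in network that always outputs zeros
noisy = jnp.ones((2, 2))
out = denoise(lambda x, t: jnp.zeros_like(x), noisy, sigma=0.5)
print(out)
```

The coefficients keep the network's input and training target at roughly unit variance across noise levels, which is the rationale the paper gives for this parameterization.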

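Feel free to contribute by opening an issue or submitting a pull request. Let's make FlaxDiff better together!
This project is licensed under the MIT License.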