This repository is a simple codebase for experimenting with the original diffusion approaches (DDIM and DDPM) rather than Stable Diffusion, which can easily be added. In particular, the code is meant to be easy to read and flexible enough to modify for your own use. Unlike many JAX-based diffusion repositories, the DDIM and DDPM samplers are written with JAX primitives so that they can be JIT-compiled.

```bash
pip install -r requirements.txt
```

If you run into problems installing JAX, please refer to the JAX documentation.


```bash
python train_unet.py \
  --loss_type pred_v \
  --min_snr_gamma 5.0 \
  --timesteps 1000 \
  --sampling_steps 250 \
  --seed 3867 \
  --save_every_k 5 \
  --max_to_keep 5 \
  --epochs 1000 \
  --batch_size 128 \
  --num_workers 0 \
  --gradient_accummulation_steps 1 \
  --pin_memory True \
  --learning_rate 0.0001 \
  --weight_decay 0.0001 \
  --max_ema_decay 0.9999 \
  --min_ema_decay 0.0 \
  --ema_decay_power 0.66667 \
  --ema_inv_gamma 1 \
  --start_ema_update_after 100 \
  --update_ema_every 10 \
  --result_path ./unet_cifar10 \
  --dataset cifar10 \
  --root_folder ../data \
  --beta_schedule sigmoid \
  --dim 64 \
  --dim_mults 1,2,4,8 \
  --resnet_block_groups 8 \
  --learned_variance False \
  --clear_gpu_cache False
```
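Several of the flags above (`--max_ema_decay`, `--min_ema_decay`, `--ema_decay_power`, `--ema_inv_gamma`, `--start_ema_update_after`, `--update_ema_every`) control an exponential moving average (EMA) of the weights. As a rough sketch only, assuming they map onto the warmup schedule used by Hugging Face diffusers' `EMAModel` (the exact equation used here may differ; check `src/utils`), the decay and the update could look like this. `ema_decay` and `ema_update` are hypothetical names, not functions from `src`.

```python
def ema_decay(step, max_decay=0.9999, min_decay=0.0, power=0.66667,
              inv_gamma=1.0, start_after=100):
    """Warmup EMA decay (assumed formulation, modeled on diffusers' EMAModel)."""
    # Only count the optimizer steps taken after EMA updates begin.
    step = max(0, step - start_after - 1)
    if step <= 0:
        return 0.0
    # The decay rises from 0 toward 1 as training progresses...
    decay = 1.0 - (1.0 + step / inv_gamma) ** -power
    # ...and is clipped to [min_decay, max_decay].
    return min(max(decay, min_decay), max_decay)


def ema_update(ema_params, params, decay):
    # ema <- decay * ema + (1 - decay) * params, applied per parameter.
    # With JAX pytrees this would typically be jax.tree_util.tree_map instead.
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}


# Hypothetical usage inside a training loop, updating every 10 steps:
# if step % 10 == 0:
#     ema_params = ema_update(ema_params, params, ema_decay(step))
```

A small early decay lets the EMA weights track the rapidly changing weights at the start of training, while the cap keeps the average smooth later on.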
Note: checkpoints are saved during training, so if training stops unexpectedly you can resume it from the latest checkpoint.
For more information, see `train.py` in `src/utils` and the Jupyter notebooks provided in `nbs`.
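`--loss_type pred_v` trains the network to predict the velocity $v$ rather than the noise or the clean image, and `--min_snr_gamma` reweights the loss per timestep. The numpy sketch below assumes the standard v-parameterization ($x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, $v = \sqrt{\bar\alpha_t}\,\epsilon - \sqrt{1-\bar\alpha_t}\,x_0$) and the usual Min-SNR-γ weight for v-targets, $\min(\mathrm{SNR}, \gamma) / (\mathrm{SNR} + 1)$. All function names are illustrative; the repo's loss may be organized differently.

```python
import numpy as np


def q_sample(x0, eps, alpha_bar):
    # Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps


def v_target(x0, eps, alpha_bar):
    # Velocity target: v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0
    return np.sqrt(alpha_bar) * eps - np.sqrt(1.0 - alpha_bar) * x0


def x0_from_v(x_t, v, alpha_bar):
    # Recover the clean image from a v prediction:
    # x0 = sqrt(a_bar) * x_t - sqrt(1 - a_bar) * v
    return np.sqrt(alpha_bar) * x_t - np.sqrt(1.0 - alpha_bar) * v


def min_snr_weight_v(alpha_bar, gamma=5.0):
    # SNR(t) = a_bar / (1 - a_bar); clipping at gamma stops easy,
    # low-noise timesteps from dominating the loss.
    snr = alpha_bar / (1.0 - alpha_bar)
    return np.minimum(snr, gamma) / (snr + 1.0)


def weighted_v_loss(v_pred, x0, eps, alpha_bar, gamma=5.0):
    # alpha_bar has shape (batch,); images have shape (batch, H, W, C).
    target = v_target(x0, eps, alpha_bar[:, None, None, None])
    mse = np.mean((v_pred - target) ** 2, axis=(1, 2, 3))
    return np.mean(min_snr_weight_v(alpha_bar, gamma) * mse)
```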
```python
import jax
import numpy as np
from src import Unet

# instantiate the model; fill in the constructor arguments (e.g. dim, dim_mults)
model = Unet(...)
# create dummy inputs for initialization; note the first dimension is being vmapped over
x = np.ones([1, 32, 32, 3])
t = np.ones([1])
# seed
key = jax.random.PRNGKey(42)
# initialize model
params = model.init(key, x, t)['params']
# usage
score = model.apply({'params': params}, x, t)
```
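The samplers take variance parameters derived from a beta schedule (`linear_schedule` below, or `--beta_schedule sigmoid` in the training command). As a sketch of what such helpers typically compute, assuming the common DDPM linear schedule and the sigmoid schedule popularized by `denoising-diffusion-pytorch`, the following numpy code builds both and derives the cumulative products. `linear_beta_schedule`, `sigmoid_beta_schedule` and `variance_params` are hypothetical stand-ins for the repo's `linear_schedule`/`get_var_params`, whose exact outputs may differ.

```python
import numpy as np


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    # DDPM linear schedule, rescaled so other T values behave like T = 1000.
    scale = 1000 / timesteps
    return np.linspace(beta_start * scale, beta_end * scale, timesteps, dtype=np.float64)


def sigmoid_beta_schedule(timesteps, start=-3.0, end=3.0, tau=1.0):
    # alpha_bar follows a flipped sigmoid over t in [0, 1].
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    t = np.linspace(0, timesteps, timesteps + 1, dtype=np.float64) / timesteps
    v_start, v_end = sigmoid(start / tau), sigmoid(end / tau)
    alphas_cumprod = (-sigmoid((t * (end - start) + start) / tau) + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped for numerical stability.
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, 0.999)


def variance_params(betas):
    # Per-step quantities that DDPM/DDIM samplers usually read.
    alphas_cumprod = np.cumprod(1.0 - betas)
    return {
        "alphas_cumprod": alphas_cumprod,
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
    }
```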



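DDIM sampling walks a shortened, descending sequence of `(t, t_prev)` pairs and applies one update per pair, with η (`eta`) controlling how much fresh noise is injected. The numpy sketch below assumes evenly spaced pairs that end at `-1` (treated as $\bar\alpha = 1$, i.e. a clean image) and the standard DDIM update from a noise prediction. `make_time_pairs` and `ddim_step` are hypothetical and only illustrate the idea behind `get_time_pairs` and `ddim_sample`.

```python
import numpy as np


def make_time_pairs(timesteps, sampling_steps):
    # Evenly spaced, descending timesteps; -1 marks the final clean step.
    times = np.linspace(-1, timesteps - 1, sampling_steps + 1).round().astype(int)[::-1]
    return [(int(a), int(b)) for a, b in zip(times[:-1], times[1:])]


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta, rng):
    # Estimate the clean image from the noise prediction.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # eta = 0 gives deterministic DDIM; eta = 1 gives DDPM-like stochasticity.
    sigma = (eta * np.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t))
             * np.sqrt(1.0 - alpha_bar_t / alpha_bar_prev))
    # Deterministic "direction pointing to x_t" term.
    direction = np.sqrt(1.0 - alpha_bar_prev - sigma ** 2) * eps_pred
    noise = sigma * rng.standard_normal(x_t.shape) if eta > 0 else 0.0
    return np.sqrt(alpha_bar_prev) * x0_pred + direction + noise
```

With 1000 training timesteps and 250 sampling steps this yields pairs `(999, 995), (995, 991), ..., (3, -1)`; at the last pair the update returns the predicted clean image.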
```python
import numpy as np
import equinox as eqx
from src import linear_schedule, ddim_sample, ddpm_sample, get_var_params, get_pred_fn, get_time_pairs

# reuses model, params, x and key from the initialization snippet above
loss_type = "pred_v"
timesteps = 1000
sampling_steps = 250
# variance scheduler
betas = linear_schedule(timesteps)
var_params = get_var_params(betas)
# prediction function for the chosen target: x0, noise, or v
pred_fn = get_pred_fn(loss_type)
# JIT-compile the samplers
ddpm_sample_fn = eqx.filter_jit(ddpm_sample)
ddim_sample_fn = eqx.filter_jit(ddim_sample)
# DDPM sampling
x_ddpm = ddpm_sample_fn(
    params,
    model.apply,
    pred_fn,
    x,
    np.arange(timesteps),
    key,
    var_params,
    timesteps,
)
# DDIM sampling
time_pairs = get_time_pairs(timesteps, sampling_steps)
x_ddim = ddim_sample_fn(
    params,
    model.apply,
    pred_fn,
    x,
    time_pairs,
    key,
    var_params,
    sampling_steps,
    0.0,  # eta: 0.0 gives deterministic DDIM sampling
)
```

A commonly used JAX utility is `jax.pmap`. For ease of use, this library relies on Equinox's filtered utilities (`eqx.filter_jit`, `eqx.filter_pmap`). Keep in mind that JIT compilation is also applied when pmap is used.
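`pmap` expects every mapped input to carry a leading device axis, which is what `shard` and `unshard` provide. A minimal numpy sketch of that reshaping, assuming `shard` splits the batch evenly across devices (as the shape comment in the snippet that follows suggests); `shard_batch` and `unshard_batch` are hypothetical and not the repo's implementations. In practice `num_devices` would come from `jax.local_device_count()`.

```python
import numpy as np


def shard_batch(x, num_devices):
    # (batch, ...) -> (num_devices, batch // num_devices, ...)
    if x.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    return x.reshape((num_devices, x.shape[0] // num_devices) + x.shape[1:])


def unshard_batch(x):
    # (num_devices, per_device_batch, ...) -> (batch, ...)
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
```

Because both functions are plain reshapes, `unshard_batch(shard_batch(x, n))` returns the original batch in the original order.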
```python
import numpy as np
import equinox as eqx
from flax.jax_utils import replicate
from src import shard, unshard, ddim_sample

# reuses model, params, pred_fn, time_pairs, key, var_params and sampling_steps from above
# batch_size x image_shape (Gaussian noise as the starting point)
x = np.random.normal(0, 1, size=[256, 32, 32, 3])
# num_gpus x (256 // num_gpus) x 32 x 32 x 3
x = shard(x)
# replicate params for each GPU
replicate_params = replicate(params)
# None means the argument is not mapped (it is shared by every device);
# an integer gives the axis being mapped over.
ddim_sample_fn = eqx.filter_pmap(ddim_sample, in_axes=(0, None, None, 0, None, None, None, None, None))
x_ddim = ddim_sample_fn(
    replicate_params,
    model.apply,
    pred_fn,
    x,
    time_pairs,
    key,
    var_params,
    sampling_steps,
    0.0,
)
# batch_size x image_shape
x_ddim = unshard(x_ddim)
```

Model checkpointing and checkpoint management are handled with Orbax.
```python
from src import create_checkpoint_manager, restore_model

FOLDER = "./unet_pred_v/ckpts"
# create a manager for reloading checkpoints
ckpt_manager = create_checkpoint_manager(FOLDER)
# pass None to load the latest step
ckpt = restore_model(ckpt_manager, latest_step=None)
# available keys; ckpt is a dict
ckpt['config']
ckpt['params']
ckpt['ema_params']
ckpt['opt_state']
# you can also pass a target dict so the values are restored with the right structure
# (params, opt_state, ema_params and config here are templates you already have in memory)
my_ckpt = {
    'params': params,
    'opt_state': opt_state,
    'ema_params': ema_params,
    'config': config,
}
my_ckpt = restore_model(ckpt_manager, target=my_ckpt, latest_step=None)
```

```python
import numpy as np
import equinox as eqx
from src import ddim_sample_visual, create_gifs, Unet, get_time_pairs, get_pred_fn

# initial Gaussian noise: 16 images of 32 x 32 x 3
x = np.random.normal(0, 1, size=[16, 32, 32, 3])
# use the checkpoint restored above
config = my_ckpt['config']
pred_fn = get_pred_fn(config['loss_type'])
key = config['key']
timesteps = config['timesteps']
sampling_steps = config['sampling_steps']
# instantiate the model with the arguments stored in config
model = Unet(...)
# time_pairs for DDIM
time_pairs = get_time_pairs(timesteps, sampling_steps)
ddim_sample_fn = eqx.filter_jit(ddim_sample_visual)
# run generation; the intermediate x at each timestep is returned as well
x_ddim, x_over_time = ddim_sample_fn(
    my_ckpt['ema_params'],
    model.apply,
    pred_fn,
    x,
    time_pairs,
    key,
    config['var_params'],
    steps=sampling_steps,
    eta=0.0,  # or config['eta']
)
# create animations; image_size lets you upscale the frames
frames = 100
create_gifs(
    x_over_time,
    duration=1 / frames,
    folder="./example_gifs/",
    image_size=(256, 256, 3),
    num_images=5,
)
```

The EMA update, adapted from Yiyixuxu, is performed using the following equation: