
โครงการนี้ได้รับการสนับสนุนบางส่วนโดย Google TPU Research Cloud ฉันขอขอบคุณทีม Google Cloud TPU ที่ให้ทรัพยากรแก่ฉันในการฝึกอบรมโมเดลข้อความที่ใหญ่กว่าในการตั้งค่าแบบกระจายหลายโฮสต์
ในช่วงไม่กี่ปีที่ผ่านมาโมเดลการแพร่กระจายและหลายขั้นตอนได้ปฏิวัติโดเมน AI Generative อย่างไรก็ตามการวิจัยล่าสุดในสาขานี้ได้กลายเป็นความเข้มงวดทางคณิตศาสตร์สูงทำให้มันท้าทายที่จะเข้าใจว่าแบบจำลองการแพร่กระจายที่ล้ำสมัยทำงานอย่างไรและสร้างภาพที่น่าประทับใจเช่นนี้ การทำซ้ำการวิจัยนี้ในรหัสอาจเป็นเรื่องที่น่ากลัว
Flaxdiff เป็นไลบรารีเครื่องมือ (ตารางเวลาตัวอย่างรุ่น ฯลฯ ) ออกแบบและนำไปใช้ในวิธีที่ง่ายต่อการทำความเข้าใจ การมุ่งเน้นไปที่ความเข้าใจและการอ่านได้มากกว่าประสิทธิภาพ ฉันเริ่มโครงการนี้เป็นงานอดิเรกเพื่อทำความคุ้นเคยกับ Flax และ Jax และเรียนรู้เกี่ยวกับการแพร่กระจายและการวิจัยล่าสุดใน AI Generative
ตอนแรกฉันเริ่มโครงการนี้ใน Keras ซึ่งคุ้นเคยกับ Tensorflow 2.0 แต่เปลี่ยนเป็น Flax ขับเคลื่อนโดย Jax เพื่อประสิทธิภาพและความสะดวกในการใช้งาน ยังมีโน้ตบุ๊กและรุ่นเก่ารวมถึงรุ่นผ้าลินินตัวแรกของฉัน
สมุดบันทึก Diffusion_flax_linen.ipynb เป็นพื้นที่ทำงานหลักของฉันสำหรับการทดลอง มีการอัปโหลดจุดตรวจหลายแห่งไปยังโฟลเดอร์ pretrained พร้อมกับสำเนาสมุดบันทึกที่เกี่ยวข้องกับแต่ละจุดตรวจสอบ คุณอาจต้องคัดลอกโน้ตบุ๊กไปยังรูทที่ใช้งานเพื่อให้ทำงานได้อย่างถูกต้อง
ในโฟลเดอร์ example notebooks คุณจะพบสมุดบันทึกที่ครอบคลุมสำหรับเทคนิคการแพร่กระจายที่หลากหลายเขียนตั้งแต่เริ่มต้นและเป็นอิสระจากไลบรารี Flaxdiff โน้ตบุ๊กแต่ละเล่มมีคำอธิบายโดยละเอียดเกี่ยวกับคณิตศาสตร์และแนวคิดพื้นฐานทำให้พวกเขามีทรัพยากรที่มีค่าสำหรับการเรียนรู้และทำความเข้าใจแบบจำลองการแพร่กระจาย
อธิบายการแพร่กระจาย (ลิงก์ NBViewer) (ลิงค์ท้องถิ่น)
EDM (อธิบายพื้นที่การออกแบบของแบบจำลองการแพร่กระจายที่ใช้การแพร่กระจาย)
สมุดบันทึกเหล่านี้มีจุดมุ่งหมายเพื่อให้คำแนะนำที่เข้าใจง่ายและเป็นขั้นตอนเป็นขั้นตอนเกี่ยวกับรูปแบบการแพร่กระจายและเทคนิคต่างๆ พวกเขาได้รับการออกแบบให้เป็นมิตรกับผู้เริ่มต้นและแม้ว่าพวกเขาจะไม่ปฏิบัติตามสูตรที่แน่นอนและการใช้งานของเอกสารต้นฉบับเพื่อให้พวกเขาเข้าใจและสรุปได้ทั่วไปมากขึ้นฉันได้พยายามอย่างดีที่สุดเพื่อให้พวกเขาแม่นยำที่สุดเท่าที่จะทำได้ หากคุณพบข้อผิดพลาดใด ๆ หรือมีข้อเสนอแนะโปรดอย่าลังเลที่จะเปิดปัญหาหรือคำขอดึง
สคริปต์การฝึกอบรมแบบขนานข้อมูลแบบหลายโฮสต์ใน JAX
สาธารณูปโภค TPU เพื่อทำให้ชีวิตง่ายขึ้น
ฉันทำงานเป็นนักวิจัยการเรียนรู้ของเครื่องจักรที่ Hyperverge ตั้งแต่ปี 2562-2564 โดยมุ่งเน้นไปที่การมองเห็นคอมพิวเตอร์โดยเฉพาะการต่อต้านการตบหน้าและการตรวจจับใบหน้าและการรับรู้ ตั้งแต่เปลี่ยนไปใช้งานปัจจุบันของฉันในปี 2021 ฉันไม่ได้ทำงานด้านการวิจัยและพัฒนามากทำให้ฉันเริ่มโครงการสัตว์เลี้ยงนี้เพื่อกลับมาอีกครั้งและเรียนรู้พื้นฐานใหม่และทำความคุ้นเคยกับความทันสมัย บทบาทปัจจุบันของฉันเกี่ยวข้องกับวิศวกรรมระบบ Golang เป็นหลักกับงาน ML ที่ประยุกต์ใช้บางส่วนเพียงแค่โรยเข้ามาดังนั้นรหัสอาจสะท้อนถึงการเรียนรู้ของฉัน โปรดให้อภัยข้อผิดพลาดใด ๆ และเปิดปัญหาเพื่อแจ้งให้เราทราบ
นอกจากนี้ข้อความบางส่วนอาจถูกสร้างขึ้นด้วยความช่วยเหลือของ GitHub Copilot ดังนั้นโปรดแก้ตัวข้อผิดพลาดใด ๆ ในข้อความ
ดำเนินการใน flaxdiff.schedulers :
flaxdiff.schedulers.LinearNoiseSchedule ): ตัวกำหนดตารางเวลาที่ไม่ต่อเนื่องเบต้าพารามิเตอร์flaxdiff.schedulers.CosineNoiseSchedule ): ตัวกำหนดเวลาที่ไม่ต่อเนื่องเบต้าพารามิเตอร์flaxdiff.schedulers.ExpNoiseSchedule ): ตัวกำหนดเวลาที่ไม่ต่อเนื่องเบต้าพารามิเตอร์flaxdiff.schedulers.CosineContinuousNoiseScheduler ): ตัวกำหนดเวลาอย่างต่อเนื่องflaxdiff.schedulers.CosineGeneralNoiseScheduler ): ตัวกำหนดตารางเวลา Sigma แบบต่อเนื่องflaxdiff.schedulers.KarrasVENoiseScheduler ): Sigma-Parameterized Scheduler อย่างต่อเนื่องที่เสนอโดย Karras และคณะ 2022 เหมาะที่สุดสำหรับการอนุมานflaxdiff.schedulers.EDMNoiseScheduler ): signuous scheduler sigma-parameterized ตามรูปแบบการแพร่กระจายแบบเอ็กซ์โปเนนเชียล (EDM) เหมาะที่สุดสำหรับการฝึกอบรมกับ Karraskarrasvenoisescheduler นำไปใช้ใน flaxdiff.predictors :
flaxdiff.predictors.EpsilonPredictor ): ทำนายเสียงในข้อมูลflaxdiff.predictors.X0Predictor ): ทำนายข้อมูลต้นฉบับจากข้อมูลที่มีเสียงดังflaxdiff.predictors.VPredictor ): ทำนายการรวมกันเชิงเส้นของข้อมูลและเสียงรบกวนที่ใช้กันทั่วไปใน EDMflaxdiff.predictors.KarrasEDMPredictor ): ตัวทำนายทั่วไปสำหรับ EDM รวมการรวมพารามิเตอร์ต่างๆ นำไปใช้ใน flaxdiff.samplers :
flaxdiff.samplers.DDPMSampler ): ใช้โมเดลการสุ่มตัวอย่างแบบ denoising diffusion (DDPM)flaxdiff.samplers.DDIMSampler ): ใช้กระบวนการสุ่มตัวอย่างแบบจำลองการแพร่กระจายของ denoising (DDIM)flaxdiff.samplers.EulerSampler ): ตัวอย่าง Solver Ode โดยใช้วิธีการของออยเลอร์flaxdiff.samplers.HeunSampler ): ตัวอย่าง Solver Ode โดยใช้วิธีการของ Heunflaxdiff.samplers.RK4Sampler ): ตัวอย่าง Solver ODE โดยใช้วิธี Runge-Kuttaflaxdiff.samplers.MultiStepDPM ): ใช้วิธีการสุ่มตัวอย่างแบบหลายขั้นตอนที่ได้รับแรงบันดาลใจจากตัวแก้ปัญหา DPM แบบหลายขั้นตอนดังแสดงที่นี่: Tonyduan/Diffusion) นำไปใช้ใน flaxdiff.trainer :
flaxdiff.trainer.DiffusionTrainer ): ชั้นเรียนที่ออกแบบมาเพื่ออำนวยความสะดวกในการฝึกอบรมแบบจำลองการแพร่กระจาย มันจัดการลูปการฝึกอบรมการคำนวณการสูญเสียและการอัปเดตแบบจำลอง นำไปใช้ใน flaxdiff.models :
flaxdiff.models.simple_unet.SimpleUNet ): ตัวอย่างสถาปัตยกรรม UNET สำหรับแบบจำลองการแพร่กระจายflaxdiff.models.simple_unet.Upsample ), downsampling ( flaxdiff.models.simple_unet.Downsample ), การฝังเวลา ( flaxdiff.models.simple_unet.AttentionBlock flaxdiff.models.simple_unet.FouriedEmbedding ) และบล็อกที่เหลือ ( flaxdiff.models.simple_unet.ResidualBlock ) ในการติดตั้ง Flaxdiff คุณต้องมี Python 3.10 หรือสูงกว่า ติดตั้งการพึ่งพาที่ต้องการโดยใช้:
pip install -r requirements.txtแบบจำลองได้รับการฝึกฝนและทดสอบด้วย JAX == 0.4.28 และผ้าลินิน == 0.8.4 อย่างไรก็ตามเมื่อฉันอัปเดตเป็น Jax ล่าสุด == 0.4.30 และ Flax == 0.8.5 รุ่นหยุดการฝึกอบรม ดูเหมือนว่าจะมีการเปลี่ยนแปลงที่สำคัญบางอย่างทำลายพลวัตการฝึกอบรมและดังนั้นฉันขอแนะนำให้ติดกับเวอร์ชันที่กล่าวถึงในข้อกำหนด. txt
นี่คือตัวอย่างที่ง่ายขึ้นเพื่อให้คุณเริ่มต้นด้วยการฝึกอบรมแบบจำลองการแพร่กระจายโดยใช้ Flaxdiff:
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )นี่คือตัวอย่างที่ง่ายสำหรับการสร้างภาพโดยใช้โมเดลที่ผ่านการฝึกอบรม:
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) แบบจำลองที่ได้รับการฝึกฝนเกี่ยวกับ Laion-Aesthetics 12M + CC12M + MS Coco + 1M สุนทรียศาสตร์ 6+ ชุดย่อยของ Coyo-700m บน TPU-V4-32: a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
Params : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

ภาพที่สร้างขึ้นโดยการแจ้งต่อไปนี้โดยใช้คำแนะนำฟรีตัวจําแนกพร้อมปัจจัยคำแนะนำ = 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

ภาพที่สร้างขึ้นโดยการแจ้งต่อไปนี้โดยใช้คำแนะนำฟรีตัวจําแนกพร้อมปัจจัยคำแนะนำ = 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

Params : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

อย่าลังเลที่จะมีส่วนร่วมโดยการเปิดปัญหาหรือส่งคำขอดึง มาทำให้ Flaxdiff ดีขึ้นด้วยกัน!
โครงการนี้ได้รับใบอนุญาตภายใต้ใบอนุญาต MIT