
이 프로젝트는 Google TPU Research Cloud에서 부분적으로 지원됩니다. Google Cloud TPU 팀에게 다중 호스트 분산 설정에서 더 큰 텍스트 조건 모델을 훈련시킬 수있는 리소스를 제공해 주셔서 감사합니다.
최근에는 확산 및 점수 기반 다중 단계 모델이 생성 AI 도메인에 혁명을 일으켰습니다. 그러나이 분야의 최신 연구는 수학 집약적이어서 최첨단 확산 모델이 어떻게 작동하고 그러한 인상적인 이미지를 생성하는지 이해하기가 어렵습니다. 코드 에서이 연구를 복제하는 것은 어려울 수 있습니다.
Flaxdiff는 이해하기 쉬운 방식으로 설계 및 구현 된 도구 라이브러리 (스케줄러, 샘플러, 모델 등)입니다. 초점은 성능에 대한 이해 성과 가독성에 있습니다. 나는이 프로젝트를 아마와 Jax에 익숙해지고 확산 및 생성 AI의 최신 연구에 대해 배우는 취미로 시작했습니다.
나는 처음에 Keras 에서이 프로젝트를 시작하여 Tensorflow 2.0에 익숙하지만 성능과 사용 편의성을 위해 Jax에 의해 구동되는 Flax로 전환되었습니다. 첫 번째 아마 모델을 포함한 오래된 노트와 모델도 제공됩니다.
Diffusion_flax_linen.ipynb 노트북은 실험을위한 나의 주요 작업 공간입니다. 각 체크 포인트와 관련된 작업 노트북의 사본과 함께 몇 가지 체크 포인트가 pretrained 폴더에 업로드됩니다. 제대로 작동하기 위해 노트북을 작업 루트에 복사해야 할 수도 있습니다.
example notebooks 폴더에서는 처음부터 처음부터 작성되었으며 Flaxdiff 라이브러리와 독립적 인 다양한 확산 기술에 대한 포괄적 인 노트북을 찾을 수 있습니다. 각 노트북에는 기본 수학 및 개념에 대한 자세한 설명이 포함되어있어 확산 모델을 학습하고 이해하는 데 귀중한 자원이 있습니다.
확산 설명 (NBViewer Link) (로컬 링크)
EDM (확산 기반 생성 모델의 설계 공간을 설명)
이 노트북은 다양한 확산 모델과 기술에 대한 이해하기 쉽고 단계별 가이드를 제공하는 것을 목표로합니다. 그들은 초보자에게 친숙하게 설계되었으므로 원래 논문의 정확한 제형과 구현을 고수하여 더 이해하기 쉽고 일반화 할 수있게 만들지는 않지만 가능한 한 정확하게 유지하기 위해 최선을 다했습니다. 실수가 있거나 제안이 있으면 문제 또는 풀 요청을 자유롭게 열어주십시오.
JAX의 다중 호스트 데이터 병렬 교육 스크립트
삶을 더 편하게 만드는 TPU 유틸리티
저는 2019-2021 년부터 Hyperverge에서 기계 학습 연구원으로 일하면서 컴퓨터 비전, 특히 안면 스푸핑 및 안면 탐지 및 인식에 중점을 둡니다. 2021 년에 현재 직장으로 전환 한 이후로, 나는 많은 R & D 작업에 참여하지 않았으며,이 애완 동물 프로젝트를 시작하여 기초를 다시 방문하고 배우고 최첨단에 익숙해졌습니다. 내 현재의 역할에는 주로 일부 적용된 ML 작업이 방금 뿌려진 Golang 시스템 엔지니어링과 관련이 있습니다. 따라서 코드는 내 학습 여정을 반영 할 수 있습니다. 실수를 용서하고 제게 알려주기 위해 문제를 열어주세요.
또한 Github Copilot의 도움으로 텍스트 중 일부는 생성 될 수 있으므로 텍스트의 실수를 실례합니다.
flaxdiff.schedulers 에서 구현 :
flaxdiff.schedulers.LinearNoiseSchedule ) : 베타 가라 모터 이산 스케줄러.flaxdiff.schedulers.CosineNoiseSchedule ) : 베타 패러 아메리카 화 된 이산 스케줄러.flaxdiff.schedulers.ExpNoiseSchedule ) : 베타 가라 모터 이산 스케줄러.flaxdiff.schedulers.CosineContinuousNoiseScheduler ) : 연속 스케줄러.flaxdiff.schedulers.CosineGeneralNoiseScheduler ) : 연속 Sigma 매개 변수화 된 코사인 스케줄러.flaxdiff.schedulers.KarrasVENoiseScheduler ) : Karras et al. 2022, 추론에 가장 적합합니다.flaxdiff.schedulers.EDMNoiseScheduler ) : 지수 확산 모델 (EDM)을 기반으로 한 시그마-파라미터 연속 스케줄러, Karraskarrasvenoisescheduler와의 훈련에 가장 적합합니다. flaxdiff.predictors 에서 구현 :
flaxdiff.predictors.EpsilonPredictor ) : 데이터의 노이즈를 예측합니다.flaxdiff.predictors.X0Predictor ) : 시끄러운 데이터에서 원래 데이터를 예측합니다.flaxdiff.predictors.VPredictor ) : EDM에 일반적으로 사용되는 데이터와 노이즈의 선형 조합을 예측합니다.flaxdiff.predictors.KarrasEDMPredictor ) : EDM의 일반화 된 예측 변수, 다양한 매개 변수화를 통합합니다. flaxdiff.samplers 에서 구현 :
flaxdiff.samplers.DDPMSampler ) : dedoising 확산 확률 모델 (DDPM) 샘플링 프로세스를 구현합니다.flaxdiff.samplers.DDIMSampler ) : DDIM (Denoising Fiffusion Complicit Model) 샘플링 프로세스를 구현합니다.flaxdiff.samplers.EulerSampler ) : Euler의 방법을 사용하는 ODE 솔버 샘플러.flaxdiff.samplers.HeunSampler ) : Heun 's Method를 사용하는 ODE 솔버 샘플러.flaxdiff.samplers.RK4Sampler ) : runge-kutta 메소드를 사용하는 ODE 솔버 샘플러.flaxdiff.samplers.MultiStepDPM ) : 여기에 제시된대로 MultiStep DPM 솔버에서 영감을 얻은 멀티 단계 샘플링 방법을 구현하십시오 : Tonyduan/Diffusion) flaxdiff.trainer 에서 구현 :
flaxdiff.trainer.DiffusionTrainer ) : 확산 모델의 훈련을 용이하게하도록 설계된 클래스. 교육 루프, 손실 계산 및 모델 업데이트를 관리합니다. flaxdiff.models 에서 구현 :
flaxdiff.models.simple_unet.SimpleUNet ) : 확산 모델을위한 샘플 UNET 아키텍처.flaxdiff.models.simple_unet.Upsample ), 다운 샘플링 ( flaxdiff.models.simple_unet.Downsample ), 시간 내장 ( flaxdiff.models.simple_unet.FouriedEmbedding ),주의 ( flaxdiff.models.simple_unet.AttentionBlock )를 포함한 레이어 라이브러리. 잔여 블록 ( flaxdiff.models.simple_unet.ResidualBlock ). Flaxdiff를 설치하려면 Python 3.10 이상이 필요합니다. 다음을 사용하여 필요한 종속성을 설치하십시오.
pip install -r requirements.txt모델은 Jax == 0.4.28 및 flax == 0.8.4로 훈련 및 테스트되었습니다. 그러나 최신 JAX == 0.4.30 및 Flax == 0.8.5로 업데이트했을 때 모델은 훈련을 중단했습니다. 훈련 역학을 깨는 데 큰 변화가 있었으므로 요구 사항에 언급 된 버전을 고수하는 것이 좋습니다.
다음은 FlaxDiff를 사용하여 확산 모델을 훈련시키는 것을 시작하기위한 단순화 된 예입니다.
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )다음은 훈련 된 모델을 사용하여 이미지를 생성하기위한 단순화 된 예입니다.
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Laion-Aesthetics 12m + CC12m + MS Coco + 1M 미학적 6+ 6+ 코요 -700m의 코요 -700m의 하위 집합 : a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
매개 변수 : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

다음과 같은 프롬프트에 의해 생성 된 이미지가 지침 요소를 사용하여 무료 지침을 사용하여 다음과 같은 프롬프트 = 2 : 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
매개 변수 : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

다음과 같은 프롬프트에 의해 생성 된 이미지가 지침 요소를 사용하여 무료 지침 = 4 : 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
매개 변수 : Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

매개 변수 : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

매개 변수 : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

매개 변수 : Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

문제를 열거 나 풀 요청을 제출하여 자유롭게 기여하십시오. Flaxdiff를 더 잘 만들어 보자!
이 프로젝트는 MIT 라이센스에 따라 라이센스가 부여됩니다.