
このプロジェクトは、Google TPU Research Cloudによって部分的にサポートされています。 Google Cloud TPUチームに、マルチホスト分散設定でより大きなテキストコンディショナルモデルをトレーニングするためのリソースを提供してくれたことに感謝します。
近年、拡散とスコアベースのマルチステップモデルが生成AIドメインに革命をもたらしました。しかし、この分野での最新の研究は非常に数学集約型であり、最先端の拡散モデルがどのように機能し、そのような印象的な画像を生成するかを理解することは困難です。この研究をコードで複製するのは気が遠くなる可能性があります。
FlaxDiffは、わかりやすい方法で設計および実装されたツール(スケジューラー、サンプラー、モデルなど)のライブラリです。焦点は、パフォーマンスに対する理解可能性と読みやすさに焦点を当てています。私はこのプロジェクトを、亜麻とJaxに精通し、拡散と生成AIの最新の研究について学ぶための趣味として始めました。
私は最初にこのプロジェクトをKerasで開始しました。Tensorflow2.0に精通していましたが、そのパフォーマンスと使いやすさのためにJaxを搭載したFlaxに移行しました。私の最初の亜麻モデルを含む古いノートブックとモデルも提供されています。
Diffusion_flax_linen.ipynbノートブックは、実験用の私の主なワークスペースです。いくつかのチェックポイントは、各チェックポイントに関連付けられている作業ノートブックのコピーとともに、 pretrainedフォルダーにアップロードされます。ノートブックを正常に機能させるために作業ルートにコピーする必要がある場合があります。
example notebooksには、完全にゼロから書かれ、FlaxDiffライブラリから独立しているさまざまな拡散技術の包括的なノートブックがあります。各ノートブックには、基礎となる数学と概念の詳細な説明が含まれており、拡散モデルを学習および理解するための非常に貴重なリソースになります。
拡散説明(nbviewerリンク)(ローカルリンク)
EDM(拡散ベースの生成モデルの設計空間の解明)
これらのノートブックは、さまざまな拡散モデルとテクニックを理解しやすく、段階的なガイドを提供することを目的としています。彼らは初心者に優しいように設計されているため、元の論文の正確な定式化と実装を遵守して、より理解しやすく一般化できるようにすることはできませんが、可能な限り正確に保つように最善を尽くしました。間違いを見つけたり、提案がある場合は、お気軽に問題やプルリクエストを開いてください。
JAXのマルチホストデータ並列トレーニングスクリプト
ライフを楽にするためのTPUユーティリティ
私は、2019-2021のHypervergeで機械学習研究者として働いており、コンピュータービジョン、特に顔のアンチスポーフィングと顔の検出と認識に焦点を当てていました。 2021年に現在の仕事に切り替えて以来、私はそれほど多くのR&Dの仕事に従事していないので、このペットプロジェクトを開始して基本を再訪して再学習し、最先端に精通しています。私の現在の役割には、主にGolangシステムエンジニアリングが含まれ、いくつかの適用されたML作業が散らばっています。したがって、コードは私の学習の旅を反映している可能性があります。間違いを許して、問題を開いて私に知らせてください。
また、Github Copilotの助けを借りて生成されるテキストはほとんどないので、テキストの間違いを弁解してください。
flaxdiff.schedulersで実装:
flaxdiff.schedulers.LinearNoiseSchedule ):ベータパラメータ化された離散スケジューラ。flaxdiff.schedulers.CosineNoiseSchedule ):ベータパラメータ化された離散スケジューラ。flaxdiff.schedulers.ExpNoiseSchedule ):ベータパラメータ化された離散スケジューラ。flaxdiff.schedulers.CosineContinuousNoiseScheduler ):連続スケジューラー。flaxdiff.schedulers.CosineGeneralNoiseScheduler ):連続したSigmaパラメーター化されたコサインスケジューラ。flaxdiff.schedulers.KarrasVENoiseScheduler ):Karras et al。 2022年、推論に最適です。flaxdiff.schedulers.EDMNoiseScheduler ):Karraskarrasvenoiseschedulerでのトレーニングに最適な指数関数拡散モデル(EDM)に基づいたSigmaパラメータ化された連続スケジューラー。flaxdiff.predictorsで実装:
flaxdiff.predictors.EpsilonPredictor ):データのノイズを予測します。flaxdiff.predictors.X0Predictor ):騒々しいデータから元のデータを予測します。flaxdiff.predictors.VPredictor ):EDMで一般的に使用されるデータとノイズの線形組み合わせを予測します。flaxdiff.predictors.KarrasEDMPredictor ):EDMの一般化された予測因子、さまざまなパラメーター化を統合します。flaxdiff.samplersで実装:
flaxdiff.samplers.DDPMSampler ):拡散拡散確率モデル(DDPM)サンプリングプロセスを除去します。flaxdiff.samplers.DDIMSampler ):拡散拡散暗黙モデル(DDIM)サンプリングプロセスを実装します。flaxdiff.samplers.EulerSampler ):eulerの方法を使用したODEソルバーサンプラー。flaxdiff.samplers.HeunSampler ):Heunの方法を使用したODEソルバーサンプラー。flaxdiff.samplers.RK4Sampler ):Runge-Kuttaメソッドを使用したODEソルバーサンプラー。flaxdiff.samplers.MultiStepDPM ):ここに示されている多段階DPMソルバーに触発されたマルチステップサンプリング方法を実装:Tonyduan/diffusion)flaxdiff.trainerで実装:
flaxdiff.trainer.DiffusionTrainer ):拡散モデルのトレーニングを促進するように設計されたクラス。トレーニングループ、損失計算、およびモデルの更新を管理します。flaxdiff.modelsで実装:
flaxdiff.models.simple_unet.SimpleUNet ):拡散モデル用のサンプルUNETアーキテクチャ。flaxdiff.models.simple_unet.Upsample ), downsampling ( flaxdiff.models.simple_unet.Downsample ), Time embeddings ( flaxdiff.models.simple_unet.FouriedEmbedding ), attention ( flaxdiff.models.simple_unet.AttentionBlock ), and残留ブロック( flaxdiff.models.simple_unet.ResidualBlock )。 FlaxDiffをインストールするには、Python 3.10以降が必要です。以下を使用して、必要な依存関係をインストールします。
pip install -r requirements.txtモデルを訓練し、JAX == 0.4.28およびFlax == 0.8.4でテストしました。ただし、最新のjax == 0.4.30およびFlax == 0.8.5に更新したとき、モデルはトレーニングを停止しました。トレーニングのダイナミクスを破るいくつかの大きな変化があったようです。したがって、要件に記載されているバージョンに固執することをお勧めします。
FlaxDiffを使用して拡散モデルのトレーニングを開始するための簡略化された例を次に示します。
from flaxdiff . schedulers import EDMNoiseScheduler
from flaxdiff . predictors import KarrasPredictionTransform
from flaxdiff . models . simple_unet import SimpleUNet as UNet
from flaxdiff . trainer import DiffusionTrainer
import jax
import optax
from datetime import datetime
BATCH_SIZE = 16
IMAGE_SIZE = 64
# Define noise scheduler
edm_schedule = EDMNoiseScheduler ( 1 , sigma_max = 80 , rho = 7 , sigma_data = 0.5 )
# Define model
unet = UNet ( emb_features = 256 ,
feature_depths = [ 64 , 128 , 256 , 512 ],
attention_configs = [{ "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }, { "heads" : 4 }],
num_res_blocks = 2 ,
num_middle_res_blocks = 1 )
# Load dataset
data , datalen = get_dataset ( "oxford_flowers102" , batch_size = BATCH_SIZE , image_scale = IMAGE_SIZE )
batches = datalen // BATCH_SIZE
# Define optimizer
solver = optax . adam ( 2e-4 )
# Create trainer
trainer = DiffusionTrainer ( unet , optimizer = solver ,
noise_schedule = edm_schedule ,
rngs = jax . random . PRNGKey ( 4 ),
name = "Diffusion_SDE_VE_" + datetime . now (). strftime ( "%Y-%m-%d_%H:%M:%S" ),
model_output_transform = KarrasPredictionTransform ( sigma_data = edm_schedule . sigma_data ))
# Train the model
final_state = trainer . fit ( data , batches , epochs = 2000 )訓練されたモデルを使用して画像を生成するための簡略化された例を次に示します。
from flaxdiff . samplers import DiffusionSampler
class EulerSampler ( DiffusionSampler ):
def take_next_step ( self , current_samples , reconstructed_samples , pred_noise , current_step , state , next_step = None ):
current_alpha , current_sigma = self . noise_schedule . get_rates ( current_step )
next_alpha , next_sigma = self . noise_schedule . get_rates ( next_step )
dt = next_sigma - current_sigma
x_0_coeff = ( current_alpha * next_sigma - next_alpha * current_sigma ) / dt
dx = ( current_samples - x_0_coeff * reconstructed_samples ) / current_sigma
next_samples = current_samples + dx * dt
return next_samples , state
# Create sampler
sampler = EulerSampler ( trainer . model , trainer . state . ema_params , edm_schedule , model_output_transform = trainer . model_output_transform )
# Generate images
samples = sampler . generate_images ( num_images = 64 , diffusion_steps = 100 , start_step = 1000 , end_step = 0 )
plotImages ( samples , dpi = 300 ) Laion-Aesthetics 12m + CC12M + MS Coco + 1Mの美学6+サブセットでトレーニングされたモデルはa beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden -V4-32でコヨ-700mのコヨ-700mのサブセットです。 a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden
PARAMS : Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M Batch size: 256 Image Size: 128 Training Epochs: 5 Steps per epoch: 74573 Model Configurations: feature_depths=[128, 256, 512, 1024]
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

ガイダンスファクターを使用した分類器フリーガイダンスを使用して次のプロンプトによって生成された画像= 2: 'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'
パラメージ: Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

ガイダンスファクターを使用した分類器フリーガイダンスを使用して次のプロンプトによって生成された画像= 4: 'water tulip, a water lily, a water lily, a photo of a rose, a photo of a rose, a water lily, a water lily, a photo of a marigold, a photo of a marigold, a photo of a marigold, a water lily, a photo of a sunflower, a photo of a lotus, columbine, columbine, an orchid, an orchid, an orchid, a water lily, a water lily, a water lily, columbine, columbine, a photo of a sunflower, a photo of a sunflower, a photo of a sunflower, a photo of a lotus, a photo of a lotus, a photo of a marigold, a photo of a marigold, a photo of a rose, a photo of a rose, a photo of a rose, orange dahlia, orange dahlia, a lenten rose, a lenten rose, a water lily, a water lily, a water lily, a water lily, an orchid, an orchid, an orchid, hard-leaved pocket orchid, bird of paradise, bird of paradise, a photo of a lovely rose, a photo of a lovely rose, a photo of a globe-flower, a photo of a globe-flower, a photo of a lovely rose, a photo of a lovely rose, a photo of a ruby-lipped cattleya, a photo of a ruby-lipped cattleya, a photo of a lovely rose, a water lily, a osteospermum, a osteospermum, a water lily, a water lily, a water lily, a red rose, a red rose'
パラメージ: Dataset: oxford_flowers102 Batch size: 16 Image Size: 128 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor

パラメージ: Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

パラメージ: Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: CosineNoiseSchedule Inference Noise Schedule: CosineNoiseSchedule
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

パラメージ: Dataset: oxford_flowers102 Batch size: 16 Image Size: 64 Training Epochs: 1000 Steps per epoch: 511
Training Noise Schedule: EDMNoiseScheduler Inference Noise Schedule: KarrasEDMPredictor
Model: UNet(emb_features=256, feature_depths=[64, 128, 256, 512], attention_configs=[{"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}, {"heads":4}], num_res_blocks=2, num_middle_res_blocks=1)

問題を開始したり、プルリクエストを提出したりすることで、お気軽に貢献してください。 flaxdiffをより良くしましょう!
このプロジェクトは、MITライセンスの下でライセンスされています。