高速Text2imageの生成のためのトランスベースのMuseモデルを再現するためのオープンリプロダクションの取り組み。
https://huggingface.co/spaces/openmuse/muse
このレポは、Museモデルの複製用です。目標は、シンプルでスケーラブルなリポジトリを作成し、ミューズを再現し、VQ + Transformersについての知識を大規模に構築することです。トレーニングには、deduped laion-2b + coyo-700mデータセットを使用します。
プロジェクト段階:
このプロジェクトのすべてのアーティファクトは、HuggingfaceハブのOpenMuse組織にアップロードされます。
最初に仮想環境を作成し、以下を使用してレポをインストールします。
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " PyTorchとtorchvision手動でインストールする必要があります。トレーニングには、 CUDA11.7でtorch==1.13.1を使用しています。
分散データ並列トレーニングの場合、 accelerate Libraryを使用しますが、これは将来変化する可能性があります。データセットの読み込みには、 webdatasetライブラリを使用します。したがって、データセットはwebdataset形式である必要があります。
Momemntでは、次のモデルをサポートしています。
MaskGitTransformerペーパーからのメイントランスモデル。MaskGitVQGAN -MaskGit RepoのVQGANモデル。VQGANModel -Taming Transformers RepoのVQGANモデル。モデルはmuse Directoryの下に実装されています。すべてのモデルは、おなじみのtransformers APIを実装します。そのため、モデルをロードおよび保存するために、 from_pretrainedおよびsave_pretrainedメソッドを使用できます。モデルは、Huggingfaceハブから保存してロードできます。
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )注記:
MaskGitsは、VQとクラス条件付きのラベルトークンの両方のトークンのシーケンスを与えられたロジットを出力するトランスです
除去プロセスが行われる方法は、マスクトークンIDをマスクし、徐々に除去することです
元の実装では、これは最初に最後のDIMでソフトマックスを使用して、カテゴリの分布としてランダムにサンプリングすることによって行われます。これにより、各マスキッドの予測トークンが得られます。次に、それらのトークンが選択される確率を取得します。最後に、Gumbel*Tempが追加されると、Topk最高の信頼確率が得られます。 Gumbel分布は、極端なイベントをモデル化するために使用される0への正規分布のシフトのようなものです。したがって、極端なシナリオでは、デフォルトのトークンから別のトークンが選択されているのを見たいと思います
Lucidrianの実装では、最初に、特定のマスキング比でそれらをマスキングすることにより、最高スコア(最低確率)トークンを削除します。次に、私たちが取得するロジットの最高10%を除いて、それを-infinityに設定しているので、ガンベル分布を行うと、それらは無視されます。次に、入力IDとスコアがわずか1であるスコアを更新します。
クラス条件付きイメージネットの場合、DDPトレーニングにaccelerate 、データ読み込みにwebdataset使用しています。トレーニングスクリプトはtraining/train_maskgit_imagenet.pyで利用できます。
構成管理にはOmegaconfを使用します。構成テンプレートについては、 configs/template_config.yamlを参照してください。以下に、構成パラメーターについて説明します。
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullとの議論?必要です。
wandb :
wandb.entity :ロギングに使用するwandbエンティティ。実験:
experiment.name :実験の名前。experiment.project :ロギングに使用するWANDBプロジェクト。experiment.output_dir :チェックポイントを保存するディレクトリ。experiment.max_train_examples :使用するトレーニング例の最大数。experiment.save_every :すべてのsave_everyステップごとにチェックポイントを保存します。experiment.eval_every :すべてのeval_everyステップをすべてモデルを評価します。experiment.generate_every :すべてのgenerate_everyステップを生成します。experiment.log_every : log_everyすべての手順ごとにトレーニングメトリックをログに記録します。log_grad_norm_every :gradient normを記録してlog_grad_norm_every 。experiment.resume_from_checkpoint :トレーニングを再開するチェックポイント。保存されたチェックポイントまでの最新のチェックポイントまたはパスから再開するlatest場合があります。 None場合、またはパスが存在しない場合、トレーニングはゼロから始まります。モデル:
model.vq_model.pretrained :使用する前提条件のVQモデル。保存されたチェックポイントへのパスまたはハギングフェイスモデル名にすることができます。model.transformer :トランスモデルの構成。model.gradient_checkpointing :トランスモデルのグラデーションチェックポイントを有効にします。enable_xformers_memory_efficient_attention :トランスモデルのメモリ効率的な注意またはフラッシュ注意を有効にします。 Flashの注意のために、 fp16またはbf16使用する必要があります。これを機能させるには、Xformersをインストールする必要があります。データセット:
dataset.params.train_shards_path_or_url : webdatasetトレーニングシャードへのパスまたはURL。dataset.params.eval_shards_path_or_url : webdataset評価シャードへのパスまたはURL。dataset.params.batch_size :トレーニングに使用するバッチサイズ。dataset.params.shuffle_buffer_size :トレーニングに使用するシャッフルバッファサイズ。dataset.params.num_workers :データロードに使用する労働者の数。dataset.params.resolution :トレーニングに使用する画像の解像度。dataset.params.pin_memory :データ読み込みのメモリをピン留めします。dataset.params.persistent_workers :データ読み込みには、持続的なワーカーを使用します。dataset.preprocessing.resolution :前処理に使用する画像の解像度。dataset.preprocessing.center_crop :画像のトリミングを中心にするかどうか。 Falseの場合、画像は解像度にresolutionにトリミングされます。dataset.preprocessing.random_flip :画像をランダムにフリップするかどうか。 Falseの場合、画像は反転しません。オプティマイザー:
optimizer.name :トレーニングに使用するオプティマイザー。optimizer.params :オプティマイザーパラメーター。lr_scheduler :
lr_scheduler.scheduler :トレーニングに使用する学習率スケジューラ。lr_scheduler.params :学習率スケジューラパラメーター。トレーニング:
training.gradient_accumulation_steps :トレーニングに使用する勾配蓄積手順の数。training.batch_size :トレーニングに使用するバッチサイズ。training.mixed_precision :トレーニングに使用する混合精度モード。 no 、 fp16またはbf16にすることができます。training.enable_tf32 :アンペアGPUでのトレーニングにTF32を有効にします。training.use_ema :トレーニングにEMAを有効にします。現在サポートされていません。training.seed :トレーニングに使用する種子。training.max_train_steps :トレーニングステップの最大数。training.overfit_one_batch :デバッグのために1つのバッチをオーバーフィットするかどうか。training.min_masking_rate :トレーニングに使用する最小マスキングレート。training.label_smoothing :トレーニングに使用するラベルのスムージング値。max_grad_norm :Max Gradient Norm。トレーニングとデータセットに関するメモ。 :
再開/トレーニングの実行を再開/開始するたびに、トレーニングのためにバッファーのサンプルとサンプルの例をランダムに(交換します)。これは、データの読み込みが決定的ではないことを意味します。また、エポックベースのトレーニングではありませんが、これを帳簿保持に使用して、他のデータセット/ローダーと同じトレーニングループを再利用できるようにします。
これまでのところ、単一ノードで実験を実行しています。単一のノードでトレーニング実行を開始するには、次の手順を実行します。
webdataset形式でデータセットを準備します。 scripts/convert_imagenet_to_wds.pyスクリプトを使用して、imagenetデータセットをwebdataset形式に変換できます。accelerate configを使用してトレーニングEnvを構成します。config.yamlファイルを作成します。accelerate launchを使用してトレーニングランを起動します。 accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Omegaconfを使用すると、コマンドラインオーバーライドはドットノーテーション形式で行われます。たとえば、データセットパスをオーバーライドする場合は、コマンドpython -u train.py config=path/to/config dataset.params.path=path/to/datasetを使用します。
同じコマンドを使用して、ローカルでトレーニングを開始できます。
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
このプロジェクトは、次のオープンソースリポジトリに激しく基づいています。驚くべき仕事をしてくれたすべての著者に感謝します。
そして、この驚くべきフレームワークのためにPytorchチームにobivioulsyを