Upaya reproduksi terbuka untuk mereproduksi model MUSE berbasis transformator untuk pembuatan Text2Image yang cepat.
https://huggingface.co/spaces/openmuse/muse
Repo ini untuk reproduksi model Muse. Tujuannya adalah untuk membuat repo yang sederhana dan dapat diskalakan, untuk mereproduksi Muse dan membangun KnowEdge tentang VQ + Transformers pada skala. Kami akan menggunakan dataset Laion-2b + Coyo-700m yang dikurung untuk pelatihan.
Tahap Proyek:
Semua artefak dari proyek ini akan diunggah ke organisasi OpenMuse di hub Huggingface.
Pertama -tama buat lingkungan virtual dan instal repo menggunakan:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Anda harus menginstal PyTorch dan torchvision secara manual. Kami menggunakan torch==1.13.1 dengan CUDA11.7 untuk pelatihan.
Untuk pelatihan paralel data terdistribusi, kami menggunakan Perpustakaan accelerate , meskipun ini dapat berubah di masa depan. Untuk pemuatan dataset, kami menggunakan Perpustakaan webdataset . Jadi dataset harus dalam format webdataset .
Di momemnt kami mendukung model berikut:
MaskGitTransformer - Model transformator utama dari kertas.MaskGitVQGAN - Model VQGAN dari Repo Maskgit.VQGANModel - Model VQGAN dari Repo Taming Transformers. Model -model tersebut diimplementasikan di bawah direktori muse . Semua model menerapkan API transformers yang akrab. Jadi Anda dapat menggunakan metode from_pretrained dan save_pretrained untuk memuat dan menyimpan model. Model dapat disimpan dan dimuat dari hub huggingface.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Catatan :
Maskgits adalah transformator yang mengeluarkan logit yang diberi urutan token dari kedua VQ dan token label yang dikondisikan kelas
Cara proses denoising dilakukan adalah dengan menutupi dengan ID token topeng dan secara bertahap denoise
Dalam implementasi asli, ini dilakukan dengan terlebih dahulu menggunakan softmax pada redup terakhir dan pengambilan sampel secara acak sebagai distribusi kategori. Ini akan memberikan token kami yang diprediksi untuk setiap maskid. Kemudian kami mendapatkan probabilitas agar token tersebut dipilih. Akhirnya, kami mendapatkan probabilitas kepercayaan tertinggi TOPK ketika gumbel*temp ditambahkan ke dalamnya. Distribusi Gumbel seperti distribusi normal yang bergeser ke 0 yang digunakan untuk memodelkan peristiwa ekstrem. Jadi dalam skenario ekstrem, kami ingin melihat token yang berbeda dipilih dari yang default
Untuk implementasi Lucidrian, pertama-tama menghapus token skor tertinggi (probabilitas terendah) dengan menutupi mereka dengan rasio masking yang diberikan. Kemudian, kecuali 10% tertinggi dari log yang kami dapatkan, kami mengaturnya ke -infinitas sehingga ketika kami melakukan distribusi gumbel di atasnya, mereka akan diabaikan. Kemudian perbarui ID input dan skor di mana skor hanya 1-probabilitas yang diberikan oleh softmax dari log pada ID yang diprediksi menarik
Untuk Imagenet kelas-Class kami menggunakan accelerate untuk pelatihan DDP dan webdataset untuk pemuatan data. Script pelatihan tersedia dalam training/train_maskgit_imagenet.py .
Kami menggunakan omegaconf untuk manajemen konfigurasi. Lihat configs/template_config.yaml untuk templat konfigurasi. Di bawah ini kami menjelaskan parameter konfigurasi.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullArgumen dengan ??? diperlukan.
Wandb :
wandb.entity : entitas wandb untuk digunakan untuk logging.Eksperimen :
experiment.name : Nama percobaan.experiment.project : Proyek WANDB untuk digunakan untuk logging.experiment.output_dir : Direktori untuk menyimpan pos pemeriksaan.experiment.max_train_examples : Jumlah maksimum contoh pelatihan untuk digunakan.experiment.save_every : Simpan pos pemeriksaan setiap langkah save_every .experiment.eval_every : Evaluasi model setiap langkah eval_every .experiment.generate_every : menghasilkan gambar setiap generate_every langkah.experiment.log_every : Catat metrik pelatihan setiap langkah log_every .log_grad_norm_every : Log norma gradien setiap langkah log_grad_norm_every .experiment.resume_from_checkpoint : Pos pemeriksaan untuk melanjutkan pelatihan dari. Dapat menjadi latest untuk dilanjutkan dari pos pemeriksaan atau jalur terbaru ke pos pemeriksaan yang disimpan. Jika None atau jalan yang tidak ada maka pelatihan dimulai dari awal.Model :
model.vq_model.pretrained : Model VQ pretrained untuk digunakan. Dapat menjadi jalur ke pos pemeriksaan yang disimpan atau nama model pelukan.model.transformer : Konfigurasi Model Transformer.model.gradient_checkpointing : Aktifkan pemeriksaan gradien untuk model transformator.enable_xformers_memory_efficient_attention : Mengaktifkan perhatian yang efisien memori atau perhatian flash untuk model transformator. Untuk perhatian kilat kita perlu menggunakan fp16 atau bf16 . Xformers perlu diinstal agar ini berfungsi.Dataset :
dataset.params.train_shards_path_or_url : Jalur atau url ke pecahan pelatihan webdataset .dataset.params.eval_shards_path_or_url : Jalur atau url ke pecahan evaluasi webdataset .dataset.params.batch_size : Ukuran batch untuk digunakan untuk pelatihan.dataset.params.shuffle_buffer_size : Ukuran buffer shuffle untuk digunakan untuk pelatihan.dataset.params.num_workers : Jumlah pekerja yang akan digunakan untuk pemuatan data.dataset.params.resolution : Resolusi gambar yang akan digunakan untuk pelatihan.dataset.params.pin_memory : Pin memori untuk pemuatan data.dataset.params.persistent_workers : Gunakan pekerja persisten untuk pemuatan data.dataset.preprocessing.resolution : Resolusi gambar yang akan digunakan untuk preprocessing.dataset.preprocessing.center_crop : Apakah akan memusatkan tanaman gambar. Jika False maka gambar dipotong secara acak ke resolution .dataset.preprocessing.random_flip : Apakah akan membalikkan gambar secara acak. Jika False maka gambar tidak terbalik.Pengoptimal :
optimizer.name : Pengoptimal untuk digunakan untuk pelatihan.optimizer.params : Parameter pengoptimal.lr_scheduler :
lr_scheduler.scheduler : Penjadwal tingkat pembelajaran untuk digunakan untuk pelatihan.lr_scheduler.params : parameter penjadwal tingkat pembelajaran.pelatihan :
training.gradient_accumulation_steps : Jumlah langkah akumulasi gradien untuk digunakan untuk pelatihan.training.batch_size : Ukuran batch untuk digunakan untuk pelatihan.training.mixed_precision : Mode presisi campuran untuk digunakan untuk pelatihan. no bisa, fp16 atau bf16 .training.enable_tf32 : Aktifkan TF32 untuk pelatihan tentang Ampere GPU.training.use_ema : Aktifkan EMA untuk pelatihan. Saat ini tidak didukung.training.seed Seed: Benih yang akan digunakan untuk pelatihan.training.max_train_steps : Jumlah langkah -langkah pelatihan maksimum.training.overfit_one_batch : apakah akan overfit satu batch untuk debugging.training.min_masking_rate : Tingkat masking minimum untuk digunakan untuk pelatihan.training.label_smoothing : Nilai label smoothing untuk digunakan untuk pelatihan.max_grad_norm : Norma gradien maks.Catatan tentang pelatihan dan dataset. :
Kami secara acak menyusun ulang pecahan (dengan penggantian) dan contoh sampel dalam buffer untuk pelatihan setiap kali kami melanjutkan/memulai menjalankan pelatihan. Ini berarti pemuatan data kami tidak ditentukan. Kami juga tidak melakukan pelatihan berbasis zaman tetapi hanya menggunakan ini untuk menjaga buku dan dapat menggunakan kembali lingkaran pelatihan yang sama dengan set data/loader lainnya.
Sejauh ini kami menjalankan eksperimen pada simpul tunggal. Untuk meluncurkan pelatihan yang dijalankan pada satu node, jalankan langkah -langkah berikut:
webdataset . Anda dapat menggunakan skrip scripts/convert_imagenet_to_wds.py untuk mengonversi dataset imagenet ke format webdataset .accelerate config .config.yaml untuk percobaan Anda.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Dengan omegaconf, override Commandline dilakukan dalam format notasi dot. Misalnya jika Anda ingin mengganti jalur dataset, Anda akan menggunakan perintah python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
Perintah yang sama dapat digunakan untuk meluncurkan pelatihan secara lokal.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Proyek ini didasarkan pada repo open-source berikut. Terima kasih kepada semua penulis atas pekerjaan luar biasa mereka.
Dan Obivioulsy to Pytorch Team untuk kerangka kerja yang luar biasa ini ❤️