Eine Open-Reproduktion-Anstrengung zur Reproduktion des Transformator-Muse-Modells für schnelle Text2Image-Generation.
https://huggingface.co/spaces/openmuse/muse
Dieses Repo dient zur Reproduktion des Muse -Modells. Ziel ist es, ein einfaches und skalierbares Repo zu erstellen, Muse zu reproduzieren und wissen, dass VQ + Transformers im Maßstab wissen. Wir werden Deduped Laion-2B + Coyo-700m-Datensatz für das Training verwenden.
Projektphasen:
Alle Artefakte dieses Projekts werden in der OpenMuse -Organisation im Huggingface -Hub hochgeladen.
Erstellen Sie zunächst eine virtuelle Umgebung und installieren Sie das Repo mit:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Sie müssen PyTorch und torchvision manuell installieren. Wir verwenden torch==1.13.1 mit CUDA11.7 für das Training.
Für verteilte Daten parallele Schulungen verwenden wir accelerate , obwohl sich dies in Zukunft ändern kann. Für das Laden von Datensatz verwenden wir webdataset -Bibliothek. Der Datensatz sollte also im webdataset -Format enthalten sein.
Bei der Momemnt unterstützen wir die folgenden Modelle:
MaskGitTransformer - Das Haupttransformatormodell aus dem Papier.MaskGitVQGAN - Das Vqgan -Modell aus dem Maskgit Repo.VQGANModel - Das VQGAN -Modell aus dem Taming Transformers Repo. Die Modelle werden im Rahmen muse -Verzeichnisses implementiert. Alle Modelle implementieren die bekannte transformers -API. Sie können also von den Methoden from_pretrained und save_pretrained -Methoden verwenden, um die Modelle zu laden und zu speichern. Das Modell kann gespeichert und aus dem Hubface -Hub geladen werden.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Notiz :
Maskgits ist ein Transformator, der Logits ausgibt, die eine Abfolge von Token von VQ- und Klassenkonditions-Label-Token erhalten
Die Art und Weise, wie der Denoising -Prozess durchgeführt wird
In der ursprünglichen Implementierung erfolgt dies, indem zunächst ein Softmax für die letzte schwache und zufällige Stichprobe als kategoriale Verteilung verwendet wird. Dies gibt unseren vorhergesagten Token für jede Maskid. Dann erhalten wir die Wahrscheinlichkeiten, dass diese Token ausgewählt werden. Schließlich erhalten wir die höchsten Vertrauenswahrscheinlichkeiten, wenn Gumbel*Temperatur hinzugefügt wird. Die Gumbel -Verteilung ist wie eine verschobene Normalverteilung in Richtung 0, die zum Modellieren von extremen Ereignissen verwendet wird. In extremen Szenarien möchten wir also sehen, dass ein anderes Token aus dem Standard ausgewählt wird
Für die lucidrische Implementierung entfernt es zunächst die Token mit der höchsten Punktzahl (niedrigste Wahrscheinlichkeit), indem sie sie mit einem gegebenen Maskierungsverhältnis maskiert. Mit Ausnahme der höchsten 10% der Logits, die wir erhalten, setzen wir es auf eine Infinity. Wenn wir die Gumbel -Verteilung darauf ausführen, werden sie ignoriert. Aktualisieren Sie dann die Eingabe-IDs und die Bewertungen, bei denen die Bewertungen nur 1-die Wahrscheinlichkeit sind
Für klassenkonditionelle ImagEnet verwenden wir accelerate für DDP-Training und webdataset für das Laden von Daten. Das Trainingsskript ist in training/train_maskgit_imagenet.py verfügbar.
Wir verwenden Omegaconf für die Konfigurationsverwaltung. Siehe configs/template_config.yaml für die Konfigurationsvorlage. Nachfolgend erläutern wir die Konfigurationsparameter.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullArgumente mit ??? sind erforderlich.
Wandb :
wandb.entity : Die WANDB -Einheit, die zur Protokollierung verwendet werden muss.Experiment :
experiment.name : Der Name des Experiments.experiment.project : Das WANDB -Projekt zur Protokollierung.experiment.output_dir : Das Verzeichnis zum Speichern der Kontrollpunkte.experiment.max_train_examples : Die maximale Anzahl der zu verwendenden Trainingsbeispiele.experiment.save_every : Speichern Sie einen Checkpoint bei jeder save_every -Schritten.experiment.eval_every : Bewerten Sie das Modell bei jeder eval_every -Schritten.experiment.generate_every : Generieren Sie Bilder jede generate_every -Schritte.experiment.log_every : Protokollieren Sie die Trainingsmetriken bei jeder Schritten log_every .log_grad_norm_every : log die gradientennorm jeder log_grad_norm_every -Schritte an.experiment.resume_from_checkpoint : Der Checkpoint zur Wiederaufnahme von Training von. Kann von dem neuesten Checkpoint oder Pfad zu einem gespeicherten Checkpoint latest sein. Wenn None oder der Pfad nicht vorhanden ist, beginnt das Training von vorne.Modell :
model.vq_model.pretrained : Das zu verwendende VQ -Modell. Kann ein Pfad zu einem gespeicherten Checkpoint oder einem Modellnamen mit dem Umarmungsface sein.model.transformer : Die Transformatormodellkonfiguration.model.gradient_checkpointing : Aktivieren Sie Gradientenprüfungen für das Transformatormodell.enable_xformers_memory_efficient_attention : Aktivieren Sie die Aufmerksamkeit oder Flash -Aufmerksamkeit für das Transformatormodell. Für die Aufmerksamkeit von Flash müssen wir fp16 oder bf16 verwenden. Xformers müssen dafür installiert werden, dass dies funktioniert.Datensatz :
dataset.params.train_shards_path_or_url : Der Pfad oder die URL zu den webdataset -Trainings -Scherben.dataset.params.eval_shards_path_or_url : Der Pfad oder die URL zu den webdataset -Bewertungsschards.dataset.params.batch_size : Die Stapelgröße für das Training.dataset.params.shuffle_buffer_size : Die zum Training verwendete Shuffle -Puffergröße.dataset.params.num_workers : Die Anzahl der Arbeitnehmer, die für das Laden von Daten verwendet werden sollen.dataset.params.resolution : Die Auflösung der für das Training zu verwendenden Bilder.dataset.params.pin_memory : Pin den Speicher für das Laden von Daten.dataset.params.persistent_workers : Verwenden Sie persistente Arbeitnehmer für das Laden von Daten.dataset.preprocessing.resolution : Die Auflösung der für die Vorverarbeitung verwendeten Bilder.dataset.preprocessing.center_crop : Ob Sie die Bilder zentrieren sollen. Wenn False , werden die Bilder zufällig zur resolution beschnitten.dataset.preprocessing.random_flip : Ob Sie die Bilder zufällig umdrehen. Wenn False , werden die Bilder nicht umgedreht.Optimierer :
optimizer.name : Der Optimierer für das Training.optimizer.params : Die Optimierer -Parameter.lr_scheduler :
lr_scheduler.scheduler : Der Lernrate -Scheduler für das Training.lr_scheduler.params : Die Parameter der Lernrate Scheduler.Ausbildung :
training.gradient_accumulation_steps : Die Anzahl der für das Training zu verwendenden Gradientenakkumulationsschritte.training.batch_size : Die Stapelgröße für das Training.training.mixed_precision : Der gemischte Präzisionsmodus für das Training. Kann no , fp16 oder bf16 sein.training.enable_tf32 .training.use_ema : EMA für das Training aktivieren. Derzeit nicht unterstützt.training.seed : Der Saatgut für das Training.training.max_train_steps : Die maximale Anzahl von Trainingsschritten.training.overfit_one_batch : Ob Sie eine Stapel zum Debuggen überwinden sollen.training.min_masking_rate : Die minimale Maskierungsrate für das Training.training.label_smoothing .max_grad_norm : Max -Gradientennorm.Anmerkungen zu Training und Datensatz. :
Wir haben die Scherben (mit Ersatz) und Beispiele in Puffer für das Training bei jedem Schulungsaufnahme nach dem Training nach dem Zufallsprinzip erneut versehen. Dies bedeutet, dass unsere Datenbelastung nicht bestimmt ist. Wir machen auch kein epoch -basiertes Training, sondern verwenden dies nur für die Aufbewahrung von Buch und können dieselbe Trainingsschleife mit anderen Datensätzen/Ladern wiederverwenden können.
Bisher führen wir Experimente am einzelnen Knoten durch. Führen Sie die folgenden Schritte aus:
webdataset -Format vor. Sie können das scripts/convert_imagenet_to_wds.py -Skript verwenden, um das ImageNet -Datensatz in das webdataset -Format zu konvertieren.accelerate config .config.yaml -Datei für Ihr Experiment.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Mit Omegaconf werden Befehlszeilenüberschüsse im Dot-NOTATION-Format durchgeführt. Wenn Sie den Dataset -Pfad überschreiben möchten, würden Sie den Befehl python -u train.py config=path/to/config dataset.params.path=path/to/dataset verwenden.
Der gleiche Befehl kann verwendet werden, um das Training lokal zu starten.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Dieses Projekt basiert auf den folgenden Open-Source-Repos. Vielen Dank an alle Autoren für ihre erstaunliche Arbeit.
Und obivioulsy zum Pytorch -Team für diesen erstaunlichen Rahmen ❤️ ❤️