Un effort de représentation ouverte pour reproduire le modèle de muse basé sur le transformateur pour la génération rapide de texte en texte.
https://huggingface.co/spaces/openmuse/muse
Ce repo est pour la reproduction du modèle Muse. L'objectif est de créer un repo simple et évolutif, de reproduire Muse et de construire des transformateurs VQ + à grande échelle. Nous utiliserons un ensemble de données LAION-2B + COYOO-700M pour la formation.
Étapes du projet:
Tous les artefacts de ce projet seront téléchargés sur l'organisation OpenMuse sur le HuggingFace Hub.
Créez d'abord un environnement virtuel et installez le dépôt en utilisant:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Vous devrez installer manuellement PyTorch et torchvision . Nous utilisons torch==1.13.1 avec CUDA11.7 pour la formation.
Pour la formation parallèle de données distribuées, nous utilisons une bibliothèque accelerate , bien que cela puisse changer à l'avenir. Pour le chargement de l'ensemble de données, nous utilisons la bibliothèque webdataset . L'ensemble de données doit donc être au format webdataset .
Au momemnt, nous soutenons les modèles suivants:
MaskGitTransformer - Le modèle de transformateur principal du papier.MaskGitVQGAN - Le modèle VQGAN du Maskgit Repo.VQGANModel - Le modèle VQGAN du repo Taming Transformers. Les modèles sont mis en œuvre dans le répertoire muse . Tous les modèles implémentent l'API transformers familiers. Ainsi, vous pouvez utiliser des méthodes from_pretrained et save_pretrained pour charger et enregistrer les modèles. Le modèle peut être enregistré et chargé à partir du hub étreint.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Note :
Maskgits est un transformateur qui publie des logits compte tenu d'une séquence de jetons de VQ et de jeton d'étiquette conditionné en classe
La façon dont le processus de débarras est de masquer avec des identifiants de jeton de masque et un dénoise progressivement
Dans l'implémentation originale, cela se fait d'abord en utilisant un softmax sur le dernier échantillonnage aléatoire et aléatoire comme distribution catégorique. Cela donnera nos jetons prévus pour chaque masque. Ensuite, nous obtenons les probabilités pour que ces jetons soient choisis. Enfin, nous obtenons les probabilités de confiance les plus élevées de TopK lorsque la température de Gumbel * y est ajoutée. La distribution de Gumbel est comme une distribution normale décalée vers 0 qui est utilisée pour modéliser des événements extrêmes. Donc, dans des scénarios extrêmes, nous aimerons voir un jeton différent en cours de choix parmi le par défaut
Pour la mise en œuvre de lucidrien, il supprime d'abord les jetons les plus scores (probabilité le plus faible) en les masquant avec un rapport de masquage donné. Ensuite, à l'exception des 10% les plus élevés des logits que nous obtenons, nous le définissons sur -Infinity, donc lorsque nous faisons la distribution de Gumbel dessus, ils seront ignorés. Ensuite, mettez à jour les ID d'entrée et les scores où les scores ne sont que 1 - la probabilité donnée par le softmax des logits aux ID prévus de manière intéressante
Pour l'imageNet conditionnel de classe, nous utilisons accelerate pour la formation DDP et webdataset pour le chargement des données. Le script de formation est disponible dans training/train_maskgit_imagenet.py .
Nous utilisons Omegaconf pour la gestion de la configuration. Voir configs/template_config.yaml pour le modèle de configuration. Ci-dessous, nous expliquons les paramètres de configuration.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullArguments avec ??? sont nécessaires.
Wandb :
wandb.entity : l'entité Wandb à utiliser pour la journalisation.expérience :
experiment.name : le nom de l'expérience.experiment.project : Le projet WANDB à utiliser pour l'exploitation forestière.experiment.output_dir : le répertoire pour enregistrer les points de contrôle.experiment.max_train_examples : Le nombre maximum d'exemples de formation à utiliser.experiment.save_every : Enregistrez un point de contrôle à chaque étape save_every .experiment.eval_every : Évaluez le modèle toutes les étapes eval_every .experiment.generate_every : Générez des images toutes les étapes generate_every .experiment.log_every : enregistrez les mesures de formation toutes les étapes log_every .log_grad_norm_every : Loguez la norme de gradient à chaque log_grad_norm_every .experiment.resume_from_checkpoint : le point de contrôle pour reprendre la formation. Peut être latest à reprendre du dernier point de contrôle ou un chemin vers un point de contrôle enregistré. Si None ou le chemin n'existe, l'entraînement commence à partir de zéro.modèle :
model.vq_model.pretrained : le modèle VQ prétrainé à utiliser. Peut être un chemin vers un point de contrôle enregistré ou un nom de modèle HuggingFace.model.transformer : la configuration du modèle de transformateur.model.gradient_checkpointing : Activer le point de contrôle du gradient pour le modèle de transformateur.enable_xformers_memory_efficient_attention : Activer l'attention efficace de la mémoire ou l'attention du flash pour le modèle de transformateur. Pour l'attention du flash, nous devons utiliser fp16 ou bf16 . XFormors doit être installé pour que cela fonctionne.ensemble de données :
dataset.params.train_shards_path_or_url : le chemin d'accès ou l'URL vers les éclats de formation webdataset .dataset.params.eval_shards_path_or_url : le chemin d'accès ou l'URL vers les éclats d'évaluation webdataset .dataset.params.batch_size : la taille du lot à utiliser pour la formation.dataset.params.shuffle_buffer_size : la taille du tampon Shuffle à utiliser pour la formation.dataset.params.num_workers : le nombre de travailleurs à utiliser pour le chargement des données.dataset.params.resolution : la résolution des images à utiliser pour la formation.dataset.params.pin_memory : épinglez la mémoire pour le chargement des données.dataset.params.persistent_workers : Utilisez des travailleurs persistants pour le chargement des données.dataset.preprocessing.resolution : la résolution des images à utiliser pour le prétraitement.dataset.preprocessing.center_crop : s'il faut centrer les images. Si False , les images sont recadrées au hasard à la resolution .dataset.preprocessing.random_flip : s'il faut retourner au hasard les images. Si False , les images ne sont pas retournées.Optimiseur :
optimizer.name : l'optimiseur à utiliser pour la formation.optimizer.params : les paramètres Optimizer.LR_SCHEDULER :
lr_scheduler.scheduler : Le planificateur de taux d'apprentissage à utiliser pour la formation.lr_scheduler.params : Les paramètres du planificateur de taux d'apprentissage.entraînement :
training.gradient_accumulation_steps : Le nombre d'étapes d'accumulation de gradient à utiliser pour la formation.training.batch_size : la taille du lot à utiliser pour la formation.training.mixed_precision : le mode de précision mixte à utiliser pour la formation. Peut être no , fp16 ou bf16 .training.enable_tf32 : Activez TF32 pour la formation sur les GPU AMPERE.training.use_ema : Activer EMA pour la formation. Actuellement non pris en charge.training.seed : La graine à utiliser pour la formation.training.max_train_steps : le nombre maximum d'étapes de formation.training.overfit_one_batch : s'il faut surfiler un lot pour le débogage.training.min_masking_rate : le taux de masquage minimum à utiliser pour la formation.training.label_smoothing : la valeur de lissage de l'étiquette à utiliser pour la formation.max_grad_norm : MAX Gradient Norm.Notes sur la formation et l'ensemble de données. :
Nous rééchapons au hasard les éclats (avec remplacement) et échantillons des exemples dans le tampon pour l'entraînement chaque fois que nous reprenons / commençons la course d'entraînement. Cela signifie que notre chargement de données n'est pas déterminant. Nous ne faisons pas non plus une formation basée sur l'époque, mais nous utilisons simplement cela pour la tenue de livres et la possibilité de réutiliser la même boucle de formation avec d'autres ensembles de données / chargeurs.
Jusqu'à présent, nous exécutons des expériences sur un seul nœud. Pour lancer une course d'entraînement sur un seul nœud, exécutez les étapes suivantes:
webdataset . Vous pouvez utiliser le script scripts/convert_imagenet_to_wds.py pour convertir l'ensemble de données ImageNet au format webdataset .accelerate config .config.yaml pour votre expérience.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Avec Omegaconf, les remplacements de ligne de commande sont effectués au format de note de point. Par exemple, si vous souhaitez remplacer le chemin du jeu de données, vous utiliseriez la commande python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
La même commande peut être utilisée pour lancer la formation localement.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Ce projet est basé sur les références open source suivantes. Merci à tous les auteurs pour leur travail incroyable.
Et Obivioulsy à l'équipe Pytorch pour ce cadre incroyable ❤️