Un esfuerzo de reproducción abierta para reproducir el modelo MUSE basado en transformadores para la generación rápida de text2Image.
https://huggingface.co/spaces/openmuse/muse
Este repositorio es para la reproducción del modelo Muse. El objetivo es crear un repositorio simple y escalable, reproducir a Muse y construir Knowedge sobre VQ + Transformers a escala. Usaremos el conjunto de datos LAion-2B + COYO-700M dedupado para capacitación.
Etapas del proyecto:
Todos los artefactos de este proyecto se cargarán en la organización OpenMuse en el Hub Huggingface.
Primero cree un entorno virtual e instale el repositorio usando:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Deberá instalar PyTorch y torchvision manualmente. Estamos usando torch==1.13.1 con CUDA11.7 para el entrenamiento.
Para la capacitación paralela de datos distribuidos utilizamos la biblioteca accelerate , aunque esto puede cambiar en el futuro. Para la carga del conjunto de datos, utilizamos la biblioteca webdataset . Entonces, el conjunto de datos debe estar en el formato webdataset .
En el Momemnt apoyamos los siguientes modelos:
MaskGitTransformer : el modelo de transformador principal del papel.MaskGitVQGAN : el modelo VQGAN del repositorio de Maskgit.VQGANModel - El modelo VQGAN del repositorio de transformadores de domesticación. Los modelos se implementan en el directorio muse . Todos los modelos implementan la API transformers familiar. Por lo tanto, puede usar los métodos from_pretrained y save_pretrained para cargar y guardar los modelos. El modelo se puede guardar y cargar desde el Hub Huggingface.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Nota :
Maskgits es un transformador que genera registros dada una secuencia de tokens de vQ y token de etiqueta acondicionado de clase
La forma en que se realiza el proceso de renovación es enmascarar con ID de token de máscara y gradualmente DENOISE
En la implementación original, esto se hace primero utilizando un Softmax en el último Dim y el muestreo aleatorio como una distribución categórica. Esto dará nuestras fichas predichas para cada Maskid. Luego obtenemos las probabilidades de elegir esas fichas. Finalmente, obtenemos las probabilidades de confianza más alta de TOPK cuando se agrega la temperatura de Gumbel*. La distribución de Gumbel es como una distribución normal cambiada hacia 0 que se utiliza para modelar eventos extremos. Entonces, en escenarios extremos, nos gustaría ver que se elige un token diferente de la predeterminada
Para la implementación lucidriana, primero elimina los tokens de mayor puntuación (probabilidad más baja) enmascarándolos con una relación de enmascaramiento dada. Luego, a excepción del 10% más alto de los logits que obtenemos, lo establecemos en Infinity, por lo que cuando hagamos la distribución de Gumbel en él, serán ignorados. Luego, actualice los ID de entrada y los puntajes donde los puntajes son solo 1-la probabilidad dada por el Softmax de los logits en las ID predichas curiosamente
Para el Imagenet condicional de clase estamos utilizando accelerate para la capacitación DDP y webdataset para la carga de datos. El script de entrenamiento está disponible en training/train_maskgit_imagenet.py .
Utilizamos omegaconf para la gestión de la configuración. Consulte configs/template_config.yaml para la plantilla de configuración. A continuación explicamos los parámetros de configuración.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullArgumentos con ??? son necesarios.
Wandb :
wandb.entity : la entidad Wandb para usar para el registro.experimento :
experiment.name . Nombre: el nombre del experimento.experiment.project : el proyecto WandB para usar para el registro.experiment.output_dir : el directorio para guardar los puntos de control.experiment.max_train_examples : el número máximo de ejemplos de entrenamiento a usar.experiment.save_every : guarde un punto de control cada save_every pasos.experiment.eval_every : evalúe el modelo cada paso eval_every .experiment.generate_every : Genere imágenes cada uno generate_every los pasos.experiment.log_every : registre las métricas de entrenamiento en todos los pasos log_every .log_grad_norm_every : registre la norma de gradiente cada log_grad_norm_every pasos.experiment.resume_from_checkpoint : el punto de control para reanudar el entrenamiento. Puede ser latest para reanudar desde el último punto de control o ruta a un punto de control guardado. Si None o el camino no existe, entonces el entrenamiento comienza desde cero.modelo :
model.vq_model.pretrained : el modelo VQ previsto para usar. Puede ser una ruta hacia un punto de control guardado o un nombre del modelo de Huggingface.model.transformer : la configuración del modelo de transformador.model.gradient_checkpointing : habilitar el punto de control de gradiente para el modelo de transformador.enable_xformers_memory_efficient_attention : habilite la atención eficiente de memoria o flash de atención para el modelo Transformer. Para la atención de flash necesitamos usar fp16 o bf16 . Se necesita instalar Xformers para que esto funcione.conjunto de datos :
dataset.params.train_shards_path_or_url : la ruta o URL a los fragmentos de entrenamiento webdataset .dataset.params.eval_shards_path_or_url : la ruta o URL a los fragmentos de evaluación webdataset .dataset.params.batch_size : el tamaño de lotes para usar para el entrenamiento.dataset.params.shuffle_buffer_size : el tamaño del búfer Shuffle para usar para el entrenamiento.dataset.params.num_workers : el número de trabajadores a usar para la carga de datos.dataset.params.resolution : La resolución de las imágenes a usar para la capacitación.dataset.params.pin_memory : fija la memoria para la carga de datos.dataset.params.persistent_workers : use trabajadores persistentes para la carga de datos.dataset.preprocessing.resolution : La resolución de las imágenes a usar para el preprocesamiento.dataset.preprocessing.center_crop : si se debe centrar las imágenes. Si False , las imágenes se recortan al azar a la resolution .dataset.preprocessing.random_flip : si debe voltear al azar las imágenes. Si False , las imágenes no se voltean.optimizador :
optimizer.name : el optimizador a usar para el entrenamiento.optimizer.params : los parámetros de optimizador.LR_SCHEDULER :
lr_scheduler.scheduler : El planificador de tarifas de aprendizaje a usar para la capacitación.lr_scheduler.params : los parámetros del programador de tasa de aprendizaje.capacitación :
training.gradient_accumulation_steps : el número de pasos de acumulación de gradiente para usar para el entrenamiento.training.batch_size : el tamaño de lotes para usar para el entrenamiento.training.mixed_precision : el modo de precisión mixto para usar para el entrenamiento. Puede ser no , fp16 o bf16 .training.enable_tf32 : Habilite TF32 para el entrenamiento en GPU de amperios.training.use_ema : habilitar EMA para el entrenamiento. Actualmente no es compatible.training.seed : La semilla para usar para el entrenamiento.training.max_train_steps : el número máximo de pasos de entrenamiento.training.overfit_one_batch : si se debe sobrevisar un lote para la depuración.training.min_masking_rate : la tasa de enmascaramiento mínima para usar para el entrenamiento.training.label_smoothing : el valor de suavizado de la etiqueta para usar para el entrenamiento.max_grad_norm : Norma de gradiente máximo.Notas sobre capacitación y conjunto de datos. :
Reamitamos aleatoriamente los fragmentos (con reemplazo) y ejemplos de muestras en buffer para entrenamiento cada vez que reanudamos/iniciamos la ejecución de entrenamiento. Esto significa que nuestra carga de datos no es determinada. Tampoco hacemos capacitación basada en una época, pero solo usamos esto para el mantenimiento de libros y poder reutilizar el mismo bucle de entrenamiento con otros conjuntos de datos/cargadores.
Hasta ahora estamos ejecutando experimentos en un solo nodo. Para lanzar una ejecución de entrenamiento en un solo nodo, ejecute los siguientes pasos:
webdataset . Puede usar el scripts/convert_imagenet_to_wds.py script para convertir el conjunto de datos de ImageNet en formato webdataset .accelerate config .config.yaml para su experimento.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Con Omegaconf, las anulaciones de línea de comandos se realizan en formato de notación de puntos. Por ejemplo, si desea anular la ruta del conjunto de datos, utilizaría el comando python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
El mismo comando se puede usar para iniciar la capacitación localmente.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Este proyecto se basa en los siguientes reposadores de código abierto. Gracias a todos los autores por su increíble trabajo.
Y obivioulsy al equipo de Pytorch para este increíble marco ❤️