Um esforço de reprodução aberta para reproduzir o modelo de musa baseado em transformadores para geração rápida de text2image.
https://huggingface.co/spaces/openmuse/muse
Este repositório é para reprodução do modelo de musa. O objetivo é criar um repositório simples e escalável, reproduzir o Muse e construir o Knowedge sobre os transformadores VQ + em escala. Usaremos o conjunto de dados Dedused Laion-2b + Coyo-700M para treinamento.
Estágios do projeto:
Todos os artefatos deste projeto serão enviados para a organização OpenMuse no HUGGINGFACE HUB.
Primeiro, crie um ambiente virtual e instale o repo usando:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Você precisará instalar PyTorch e torchvision manualmente. Estamos usando torch==1.13.1 com CUDA11.7 para treinamento.
Para o treinamento paralelo de dados distribuídos, usamos a biblioteca accelerate , embora isso possa mudar no futuro. Para carga do conjunto de dados, usamos a biblioteca webdataset . Portanto, o conjunto de dados deve estar no formato webdataset .
No Momemnt, apoiamos os seguintes modelos:
MaskGitTransformer - O modelo de transformador principal do papel.MaskGitVQGAN - O modelo VQGAN do repositório MaskGit.VQGANModel - O modelo VQGAN do repo Taming Transformers. Os modelos são implementados no diretório muse . Todos os modelos implementam a API familiar transformers . Assim, você pode usar os métodos dos Métodos from_pretrained e save_pretrained para carregar e salvar os modelos. O modelo pode ser salvo e carregado no hub do Huggingface.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Observação :
Maskgits é um transformador que produz logits, dada uma sequência de tokens de vq e token de etiqueta com classe de classe
A maneira como o processo de denoising é feito é mascarar com máscara IDs de token e gradualmente denoise
Na implementação original, isso é feito primeiro usando um softmax no último escuro e a amostragem aleatoriamente como uma distribuição categórica. Isso dará nossos tokens previstos para cada Maskid. Em seguida, obtemos as probabilidades para que esses tokens sejam escolhidos. Finalmente, obtemos as mais altas probabilidades de confiança do Topk quando a temperatura do Gumbel*é adicionada a ele. A distribuição de góbulos é como uma distribuição normal deslocada em direção a 0, que é usada para modelar eventos extremos. Então, em cenários extremos, gostaremos de ver um token diferente sendo escolhido
Para a implementação da Lucidrian, ele primeiro remove os tokens de maior pontuação (menor probabilidade), mascarando-os com uma determinada taxa de mascaramento. Em seguida, exceto os 10% mais altos dos logits que obtemos, configuramos -o como -infinity, para que, quando fizermos a distribuição de gumbel, eles serão ignorados. Em seguida, atualize os IDs de entrada e as pontuações em que as pontuações são apenas 1-a probabilidade fornecida pelo softmax dos logits nos IDs previstos interessantes
Para o ImageNet de classe Condicional, estamos usando accelerate para treinamento DDP e webdataset para carregamento de dados. O script de treinamento está disponível em training/train_maskgit_imagenet.py .
Usamos o Omegaconf para gerenciamento de configuração. Consulte configs/template_config.yaml para o modelo de configuração. Abaixo, explicamos os parâmetros de configuração.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullArgumentos com ??? são necessários.
wandb :
wandb.entity : a entidade wandb a ser usada no log.experimento :
experiment.name : o nome do experimento.experiment.project : O projeto Wandb a ser usado para registro.experiment.output_dir : o diretório para salvar os pontos de verificação.experiment.max_train_examples : o número máximo de exemplos de treinamento a serem usados.experiment.save_every : salve um ponto de verificação a cada etapas save_every .experiment.eval_every : Avalie o modelo todas as etapas eval_every .experiment.generate_every : Gere imagens a cada generate_every os passos.experiment.log_every : registre as métricas de treinamento todas as etapas log_every .log_grad_norm_every : registre a norma de gradiente cada log_grad_norm_every etapas.experiment.resume_from_checkpoint : o ponto de verificação para retomar o treinamento. Pode ser latest a retomar do ponto de verificação mais recente ou caminho para um ponto de verificação salvo. Se None ou o caminho não existir, o treinamento começará do zero.modelo :
model.vq_model.pretrained : O modelo VQ pré -treinado a ser usado. Pode ser um caminho para um ponto de verificação salvo ou um nome de modelo Huggingface.model.transformer : A configuração do modelo do transformador.model.gradient_checkpointing : Ativar verificação de gradiente para o modelo do transformador.enable_xformers_memory_efficient_attention : ative atenção eficiente da memória ou flash Atenção para o modelo do transformador. Para atenção flash, precisamos usar fp16 ou bf16 . O Xformers precisa ser instalado para que isso funcione.conjunto de dados :
dataset.params.train_shards_path_or_url : o caminho ou URL para os shards de treinamento webdataset .dataset.params.eval_shards_path_or_url webdatasetdataset.params.batch_size : o tamanho do lote a ser usado para treinamento.dataset.params.shuffle_buffer_sizedataset.params.num_workers : o número de trabalhadores a serem usados para carregar dados.dataset.params.resolution : a resolução das imagens a serem usadas para treinamento.dataset.params.pin_memory : fixar a memória para carregamento de dados.dataset.params.persistent_workers : use trabalhadores persistentes para carregar dados.dataset.preprocessing.resolution : A resolução das imagens a serem usadas para pré -processamento.dataset.preprocessing.center_crop : se deve centrar as imagens. Se False , as imagens são cortadas aleatoriamente para a resolution .dataset.preprocessing.random_flip : se deve girar aleatoriamente as imagens. Se False , as imagens não são invertidas.Otimizador :
optimizer.name : o otimizador a ser usado para treinamento.optimizer.params : os parâmetros do otimizador.lr_scheduler :
lr_scheduler.scheduler : o agendador de taxas de aprendizagem usar para treinamento.lr_scheduler.params : os parâmetros do agendador de taxas de aprendizado.treinamento :
training.gradient_accumulation_steps : o número de etapas de acumulação de gradiente a serem usadas para o treinamento.training.batch_size : o tamanho do lote a ser usado para treinamento.training.mixed_precision : o modo de precisão mista a ser usada para treinamento. Pode ser no , fp16 ou bf16 .training.enable_tf32 : Ativar TF32 para treinamento em GPUs Ampere.training.use_ema : Ativar EMA para treinamento. Atualmente não suportado.training.seed : a semente a ser usada para treinamento.training.max_train_steps : o número máximo de etapas de treinamento.training.overfit_one_batch : se deve ao excesso de um lote para depuração.training.min_masking_rate : A taxa de mascaramento mínima a ser usada para treinamento.training.label_smoothing : o valor de suavização da etiqueta a ser usado para treinamento.max_grad_norm : norma de gradiente máximo.Notas sobre treinamento e conjunto de dados. :
Resolemos aleatoriamente os fragmentos (com substituição) e exemplos de amostra em buffer para treinamento toda vez que retomamos/iniciamos a execução de treinamento. Isso significa que nosso carregamento de dados não é determinado. Também não fazemos treinamento baseado em Epoch, mas apenas usando isso para manutenção de livros e ser capaz de reutilizar o mesmo loop de treinamento com outros conjuntos de dados/carregadores.
Até agora, estamos executando experimentos no nó único. Para iniciar uma execução de treinamento em um único nó, execute as seguintes etapas:
webdataset . Você pode usar os scripts/convert_imagenet_to_wds.py para converter o conjunto de dados ImageNet em formato webdataset .accelerate config .config.yaml para o seu experimento.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config Com o omegaconf, as substituições de linha de comando são feitas no formato de notação de pontos. Por exemplo, se você deseja substituir o caminho do conjunto de dados, você usaria o comando python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
O mesmo comando pode ser usado para iniciar o treinamento localmente.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Este projeto é baseado nos seguintes repositórios de código aberto. Obrigado a todos os autores por seu incrível trabalho.
E obivioulsy à equipe de Pytorch para esta incrível estrutura ❤️