Открытые усилия по воспроизведению модели Muse на основе трансформатора для быстрого генерации Text2image.
https://huggingface.co/spaces/openmuse/muse
Это репо предназначено для воспроизведения модели Muse. Цель состоит в том, чтобы создать простое и масштабируемое репо, воспроизвести Muse и построить знание о трансформаторах VQ + в масштабе. Мы будем использовать набор данных DEDUPED LAION-2B + COYO-700M для обучения.
Стадии проекта:
Все артефакты этого проекта будут загружены в организацию OpenMuse в центре Huggingface.
Сначала создайте виртуальную среду и установите репо, используя:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " Вам нужно будет установить PyTorch и torchvision вручную. Мы используем torch==1.13.1 с CUDA11.7 для обучения.
Для распределенной параллельной обучения мы используем библиотеку accelerate , хотя это может измениться в будущем. Для загрузки набора данных мы используем библиотеку webdataset . Таким образом, набор данных должен быть в формате webdataset .
В Momemnt мы поддерживаем следующие модели:
MaskGitTransformer - основная модель трансформатора из бумаги.MaskGitVQGAN - модель VQGAN из Maskgit Repo.VQGANModel - модель VQGAN из Taming Transformers Repo. Модели реализованы в каталоге muse . Все модели реализуют знакомые transformers API. Таким образом, вы можете использовать методы from_pretrained и save_pretrained для загрузки и сохранения моделей. Модель может быть сохранена и загружена из центра Huggingface.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )Примечание :
Maskgits-это трансформатор, который выводит Logits с учетом последовательности токенов как VQ, так и в токене, подготовленном в классе,
То, как осуществляется процесс, - это замаскировать с помощью идентификаторов токенов маски и постепенно денузировать
В исходной реализации это делается сначала с использованием Softmax на последней DIM и случайном отборе отбора проб в качестве категорического распределения. Это даст наши прогнозируемые токены для каждой маскид. Затем мы получаем вероятности, чтобы эти токены были выбраны. Наконец, мы получаем самые высокие вероятности уверенности в Topk, когда к нему добавляется температура Gumbel*. Распределение Gumbel подобно изменяемому нормальному распределению по отношению к 0, которое используется для моделирования экстремальных событий. Таким образом, в экстремальных сценариях мы хотели бы видеть другой токен, выбранное из по умолчанию.
Для реализации Lucidrian она сначала удаляет самые высокие токены (самая низкая вероятность), маскируя их с данным коэффициентом маскировки. Затем, за исключением самых высоких 10% логитов, которые мы получаем, мы устанавливаем его на -инфинтность, поэтому, когда мы делаем распределение Gumbel на нем, их игнорируют. Затем обновите входные идентификаторы и оценки, где оценки составляют всего 1-вероятность, определяемая Softmax of Logits на прогнозируемых идентификаторах.
Для Class-Conditional ImageNet мы используем accelerate для обучения DDP и webdataset для загрузки данных. Сценарий обучения доступен в training/train_maskgit_imagenet.py .
Мы используем Omegaconf для управления конфигурацией. См. configs/template_config.yaml для шаблона конфигурации. Ниже мы объясняем параметры конфигурации.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullАргументы с ??? требуются.
Wandb :
wandb.entity : сущность Wandb для использования для регистрации.Эксперимент :
experiment.name : имя эксперимента.experiment.project : проект WANDB для использования для регистрации.experiment.output_dir : каталог для сохранения контрольных точек.experiment.max_train_examples : максимальное количество обучающих примеров для использования.experiment.save_every : Сохраните контрольную точку каждые шаги save_every .experiment.eval_every . Eval_every: Оцените модель каждые шаги eval_every .experiment.generate_every : генерируйте изображения Каждое generate_every Steps.experiment.log_every log_everylog_grad_norm_every : войдите в log_grad_norm_every .experiment.resume_from_checkpoint : контрольная точка для возобновления обучения. Может быть latest , чтобы возобновить с последней контрольной точки или пути к сохраненной контрольной точке. Если None или путь не существует, обучение начинается с нуля.Модель :
model.vq_model.pretrained : предварительно подготовленная модель VQ для использования. Может быть пути к сохраненной контрольной точке или имени модели HuggingFice.model.transformer : конфигурация модели трансформатора.model.gradient_checkpointing : включить градиент контрольно -пропускной пункт для модели трансформатора.enable_xformers_memory_efficient_attention : включить эффективное внимание памяти или внимания к модели трансформатора. Для вспышки мы должны использовать fp16 или bf16 . Xformers необходимо установить для этого для работы.Набор данных :
dataset.params.train_shards_path_or_url : путь или URL -адрес для обучения webdataset .dataset.params.eval_shards_path_or_url : путь или URL в осколки оценки webdataset .dataset.params.batch_size : размер партии для обучения.dataset.params.shuffle_buffer_size : размер буфера для перерыва для обучения.dataset.params.num_workers : количество работников, которые можно использовать для загрузки данных.dataset.params.resolution : разрешение изображений для использования для обучения.dataset.params.pin_memory : Прижмите память для загрузки данных.dataset.params.persistent_workers : Используйте постоянных работников для загрузки данных.dataset.preprocessing.resolution : разрешение изображений для использования для предварительной обработки.dataset.preprocessing.center_crop : Центр обрезка изображений. Если False , то изображения случайно обрезаны до resolution .dataset.preprocessing.random_flip : следует ли случайным образом перевернуть изображения. Если False , то изображения не перевернуты.оптимизатор :
optimizer.name : оптимизатор для использования для обучения.optimizer.params : параметры оптимизатора.lr_scheduler :
lr_scheduler.scheduler : планировщик обучения для обучения.lr_scheduler.params : параметры планировщика обучения.обучение :
training.gradient_accumulation_steps : количество этапов накопления градиента для обучения.training.batch_size : размер партии для использования для обучения.training.mixed_precision : смешанный режим точности для использования для обучения. Может быть no , fp16 или bf16 .training.enable_tf32 : включить TF32 для обучения на графических процессорах Ampere.training.use_ema : включить EMA для обучения. В настоящее время не поддерживается.training.seed : семена для использования для обучения.training.max_train_steps : максимальное количество этапов обучения.training.overfit_one_batch : переполнять одну партию для отладки.training.min_masking_rate : минимальная скорость маскировки для использования для обучения.training.label_smoothing : значение сглаживания метки для использования для обучения.max_grad_norm : MAX GRADIENT NORM.Примечания об обучении и наборе данных. :
Мы случайным образом повторно воспринимаем осколки (с заменой) и примеры примеров в буфере для обучения каждый раз, когда возобновим/начинаем тренировочный запуск. Это означает, что наша загрузка данных не является определяющей. Мы также не проводим обучение на основе эпохи, а просто используем это для хранения книг и возможности повторно использовать ту же цикл обучения с другими наборами данных/погрузчиками.
До сих пор мы проводим эксперименты на одном узле. Чтобы запустить тренировочный запуск на одном узле, выполните следующие шаги:
webdataset . Вы можете использовать скрипт scripts/convert_imagenet_to_wds.py , чтобы преобразовать набор данных ImageNet в формат webdataset .accelerate config .config.yaml для вашего эксперимента.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config С Omegaconf переопределения командной линии выполняются в формате точечного нота. Например, если вы хотите переопределить путь набора данных, вы бы использовали команду python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
Та же команда может быть использована для запуска обучения на местном уровне.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
Этот проект неоднократно основан на следующих репо с открытым исходным кодом. Спасибо всем авторам за их удивительную работу.
И Obivioulsy в команду Pytorch для этой удивительной рамки ❤