开放式制作的工作,旨在复制基于变压器的缪斯模型,以快速文本图像生成。
https://huggingface.co/spaces/openmuse/muse
此存储库是为了复制缪斯模型。目的是创建一个简单且可扩展的回购,以大规模复制有关VQ +变形金刚的Muse并建立知识。我们将使用涂上的LAION-2B + COYO-700M数据集进行培训。
项目阶段:
该项目的所有工件都将在Huggingface Hub上上传到OpenMuse组织。
首先创建虚拟环境并使用:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] "您需要手动安装PyTorch和torchvision 。我们将torch==1.13.1与CUDA11.7一起进行培训。
对于分布式数据并行培训,我们使用accelerate库,尽管将来可能会发生变化。对于数据集加载,我们使用webdataset库。因此,数据集应为webdataset格式。
在妈妈,我们支持以下模型:
MaskGitTransformer纸上的主要变压器模型。MaskGitVQGAN来自MaskGit Repo的VQGAN模型。VQGANModel来自Taming Transformers Repo的VQGAN模型。这些模型是在muse目录下实施的。所有模型都实现了熟悉的transformers API。因此,您可以使用from_pretrained和save_pretrained方法加载和保存模型。该模型可以从HuggingFace Hub中保存并加载。
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )笔记:
maskgits是一个变压器,它输出logits,给定一系列VQ和类条件标签令牌的令牌
Denoising过程的完成方式是用面具令牌ID掩盖并逐渐denoise
在原始实现中,这是通过首先在最后一个昏暗的和随机抽样作为分类分布上使用软关系来完成的。这将为每个Maskid提供我们预测的令牌。然后,我们获得了选择这些令牌的概率。最后,当添加Gumbel*温度时,我们获得了TOPK的最高置信度概率。 Gumbel分布就像向0的正态分布转移,用于建模极端事件。因此,在极端情况下,我们希望从默认一个中选择一个不同的令牌
对于Lucidrian实施,它首先通过以给定的掩蔽比掩盖了得分最高(最低的概率)令牌。然后,除了我们获得的最高10%的逻辑之外,我们将其设置为-Infinity,因此当我们在其上进行牙龈分布时,它们将被忽略。然后更新输入ID和分数仅1- logits of logits在预测ID上给出的概率
对于类条件成像网,我们正在使用accelerate进行DDP培训和webdataset进行数据加载。培训脚本可在training/train_maskgit_imagenet.py中获得。
我们将Omegaconf用于配置管理。有关配置模板,请参见configs/template_config.yaml 。下面我们说明配置参数。
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : null与???需要。
wandb :
wandb.entity :用于记录的wandb实体。实验:
experiment.name :实验的名称。experiment.project :项目:用于记录的WANDB项目。experiment.output_dir :保存检查点的目录。experiment.max_train_examples :要使用的最大培训示例数量。experiment.save_every :保存每个save_every步骤的检查站。experiment.eval_every eval_everyexperiment.generate_every :生成图像每个generate_every步骤。experiment.log_every :记录训练指标,每个log_every步骤。log_grad_norm_every :log梯度规范每个log_grad_norm_every步骤。experiment.resume_from_checkpoint :从中恢复培训的检查点。可以是从最新检查点恢复的latest信息,也可以是保存的检查点的路径。如果None或不存在路径,则训练从头开始。模型:
model.vq_model.pretrained 。预先:要使用的验证的VQ模型。可以是通往保存检查点或拥抱面模型名称的途径。model.transformer :变压器模型配置。model.gradient_checkpointing :启用变压器模型的梯度检查点。enable_xformers_memory_efficient_attention :为变压器模型启用内存有效注意或闪烁注意力。为了闪光,我们需要使用fp16或bf16 。需要安装Xformers才能正常工作。数据集:
dataset.params.train_shards_path_or_url : webdataset训练碎片的路径或URL。dataset.params.eval_shards_path_or_url : webdataset评估碎片的路径或URL。dataset.params.batch_size :用于培训的批次大小。dataset.params.shuffle_buffer_size :用于培训的洗牌缓冲区大小。dataset.params.num_workers :用于数据加载的工人数量。dataset.params.resolution :用于培训的图像的分辨率。dataset.params.pin_memory :固定存储器以进行数据加载。dataset.params.persistent_workers :使用持久工人进行数据加载。dataset.preprocessing.resolution :用于预处理的图像的分辨率。dataset.preprocessing.center_crop :是否将图像中心。如果是False ,则将图像随机裁剪为resolution 。dataset.preprocessing.random_flip :是否随机翻转图像。如果是False ,则图像不会翻转。优化器:
optimizer.name :用于培训的优化器。optimizer.params :params:优化器参数。lr_scheduler :
lr_scheduler.scheduler :用于培训的学习率调度程序。lr_scheduler.params :学习率调度程序参数。训练:
training.gradient_accumulation_steps :用于培训的梯度积累步骤的数量。training.batch_size :用于培训的批次大小。training.mixed_precision :用于培训的混合精度模式。可以no , fp16或bf16 。training.enable_tf32 :启用TF32用于安培GPU的培训。training.use_ema :启用EMA进行培训。目前不支持。training.seed :用于培训的种子。training.max_train_steps :训练步骤的最大数量。training.overfit_one_batch :是否要过度拟合一批调试。training.min_masking_rate :用于培训的最小掩蔽率。training.label_smoothing :用于培训的标签平滑值。max_grad_norm :最大梯度标准。有关培训和数据集的注释。 :
我们每次恢复/开始训练运行时,我们将碎片(用更换)和样本示例随机重新采样。这意味着我们的数据加载不是确定性的。我们也不进行基于Epoch的培训,而只是将其用于书籍保存,并能够与其他数据集/加载程序重复使用相同的培训循环。
到目前为止,我们正在单个节点上运行实验。要在单个节点上启动训练运行,请运行以下步骤:
webdataset格式准备数据集。您可以使用scripts/convert_imagenet_to_wds.py脚本将Imagenet数据集转换为webdataset格式。accelerate config配置培训env。config.yaml文件。accelerate launch启动训练运行。 accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config使用Omegaconf,命令行替代以点通用格式完成。例如,如果要覆盖数据集路径,则将使用命令python -u train.py config=path/to/config dataset.params.path=path/to/dataset 。
同一命令可用于本地启动培训。
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
该项目基于以下开源存储库。感谢所有作者的出色工作。
和obivioulsy到Pytorch团队为这个惊人的框架❤️