開放式製作的工作,旨在復制基於變壓器的繆斯模型,以快速文本圖像生成。
https://huggingface.co/spaces/openmuse/muse
此存儲庫是為了複製繆斯模型。目的是創建一個簡單且可擴展的回購,以大規模複製有關VQ +變形金剛的Muse並建立知識。我們將使用塗上的LAION-2B + COYO-700M數據集進行培訓。
項目階段:
該項目的所有工件都將在Huggingface Hub上上傳到OpenMuse組織。
首先創建虛擬環境並使用:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] "您需要手動安裝PyTorch和torchvision 。我們將torch==1.13.1與CUDA11.7一起進行培訓。
對於分佈式數據並行培訓,我們使用accelerate庫,儘管將來可能會發生變化。對於數據集加載,我們使用webdataset庫。因此,數據集應為webdataset格式。
在媽媽,我們支持以下模型:
MaskGitTransformer紙上的主要變壓器模型。MaskGitVQGAN來自MaskGit Repo的VQGAN模型。VQGANModel來自Taming Transformers Repo的VQGAN模型。這些模型是在muse目錄下實施的。所有模型都實現了熟悉的transformers API。因此,您可以使用from_pretrained和save_pretrained方法加載和保存模型。該模型可以從HuggingFace Hub中保存並加載。
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )筆記:
maskgits是一個變壓器,它輸出logits,給定一系列VQ和類條件標籤令牌的令牌
Denoising過程的完成方式是用面具令牌ID掩蓋並逐漸denoise
在原始實現中,這是通過首先在最後一個昏暗的和隨機抽樣作為分類分佈上使用軟關係來完成的。這將為每個Maskid提供我們預測的令牌。然後,我們獲得了選擇這些令牌的概率。最後,當添加Gumbel*溫度時,我們獲得了TOPK的最高置信度概率。 Gumbel分佈就像向0的正態分佈轉移,用於建模極端事件。因此,在極端情況下,我們希望從默認一個中選擇一個不同的令牌
對於Lucidrian實施,它首先通過以給定的掩蔽比掩蓋了得分最高(最低的概率)令牌。然後,除了我們獲得的最高10%的邏輯之外,我們將其設置為-Infinity,因此當我們在其上進行牙齦分佈時,它們將被忽略。然後更新輸入ID和分數僅1- logits of logits在預測ID上給出的概率
對於類條件成像網,我們正在使用accelerate進行DDP培訓和webdataset進行數據加載。培訓腳本可在training/train_maskgit_imagenet.py中獲得。
我們將Omegaconf用於配置管理。有關配置模板,請參見configs/template_config.yaml 。下面我們說明配置參數。
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : null與? ? ?需要。
wandb :
wandb.entity :用於記錄的wandb實體。實驗:
experiment.name :實驗的名稱。experiment.project :項目:用於記錄的WANDB項目。experiment.output_dir :保存檢查點的目錄。experiment.max_train_examples :要使用的最大培訓示例數量。experiment.save_every :保存每個save_every步驟的檢查站。experiment.eval_every eval_everyexperiment.generate_every :生成圖像每個generate_every步驟。experiment.log_every :記錄訓練指標,每個log_every步驟。log_grad_norm_every :log梯度規範每個log_grad_norm_every步驟。experiment.resume_from_checkpoint :從中恢復培訓的檢查點。可以是從最新檢查點恢復的latest信息,也可以是保存的檢查點的路徑。如果None或不存在路徑,則訓練從頭開始。模型:
model.vq_model.pretrained 。預先:要使用的驗證的VQ模型。可以是通往保存檢查點或擁抱面模型名稱的途徑。model.transformer :變壓器模型配置。model.gradient_checkpointing :啟用變壓器模型的梯度檢查點。enable_xformers_memory_efficient_attention :為變壓器模型啟用內存有效注意或閃爍注意力。為了閃光,我們需要使用fp16或bf16 。需要安裝Xformers才能正常工作。數據集:
dataset.params.train_shards_path_or_url : webdataset訓練碎片的路徑或URL。dataset.params.eval_shards_path_or_url : webdataset評估碎片的路徑或URL。dataset.params.batch_size :用於培訓的批次大小。dataset.params.shuffle_buffer_size :用於培訓的洗牌緩衝區大小。dataset.params.num_workers :用於數據加載的工人數量。dataset.params.resolution :用於培訓的圖像的分辨率。dataset.params.pin_memory :固定存儲器以進行數據加載。dataset.params.persistent_workers :使用持久工人進行數據加載。dataset.preprocessing.resolution :用於預處理的圖像的分辨率。dataset.preprocessing.center_crop :是否將圖像中心。如果是False ,則將圖像隨機裁剪為resolution 。dataset.preprocessing.random_flip :是否隨機翻轉圖像。如果是False ,則圖像不會翻轉。優化器:
optimizer.name :用於培訓的優化器。optimizer.params :params:優化器參數。lr_scheduler :
lr_scheduler.scheduler :用於培訓的學習率調度程序。lr_scheduler.params :學習率調度程序參數。訓練:
training.gradient_accumulation_steps :用於培訓的梯度積累步驟的數量。training.batch_size :用於培訓的批次大小。training.mixed_precision :用於培訓的混合精度模式。可以no , fp16或bf16 。training.enable_tf32 :啟用TF32用於安培GPU的培訓。training.use_ema :啟用EMA進行培訓。目前不支持。training.seed :用於培訓的種子。training.max_train_steps :訓練步驟的最大數量。training.overfit_one_batch :是否要過度擬合一批調試。training.min_masking_rate :用於培訓的最小掩蔽率。training.label_smoothing :用於培訓的標籤平滑值。max_grad_norm :最大梯度標準。有關培訓和數據集的註釋。 :
我們每次恢復/開始訓練運行時,我們將碎片(用更換)和样本示例隨機重新採樣。這意味著我們的數據加載不是確定性的。我們也不進行基於Epoch的培訓,而只是將其用於書籍保存,並能夠與其他數據集/加載程序重複使用相同的培訓循環。
到目前為止,我們正在單個節點上運行實驗。要在單個節點上啟動訓練運行,請運行以下步驟:
webdataset格式準備數據集。您可以使用scripts/convert_imagenet_to_wds.py腳本將Imagenet數據集轉換為webdataset格式。accelerate config配置培訓env。config.yaml文件。accelerate launch啟動訓練運行。 accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config使用Omegaconf,命令行替代以點通用格式完成。例如,如果要覆蓋數據集路徑,則將使用命令python -u train.py config=path/to/config dataset.params.path=path/to/dataset 。
同一命令可用於本地啟動培訓。
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
該項目基於以下開源存儲庫。感謝所有作者的出色工作。
和obivioulsy到Pytorch團隊為這個驚人的框架❤️