ความพยายามในการเปิดตัวเพื่อทำซ้ำโมเดล Muse ตามหม้อแปลงสำหรับการสร้าง Text2Image ที่รวดเร็ว
https://huggingface.co/spaces/openmuse/muse
repo นี้มีไว้สำหรับการทำซ้ำโมเดล Muse เป้าหมายคือการสร้าง repo ที่เรียบง่ายและปรับขนาดได้เพื่อสร้าง Muse และสร้างความรู้เกี่ยวกับ VQ + Transformers ในระดับ เราจะใช้ชุดข้อมูล Deduped LAION-2B + Coyo-700M สำหรับการฝึกอบรม
ขั้นตอนโครงการ:
สิ่งประดิษฐ์ทั้งหมดของโครงการนี้จะถูกอัปโหลดไปยังองค์กร OpenMuse บน HuggingFace Hub
ก่อนอื่นสร้างสภาพแวดล้อมเสมือนจริงและติดตั้ง repo โดยใช้:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " คุณจะต้องติดตั้ง PyTorch และ torchvision ด้วยตนเอง เรากำลังใช้ torch==1.13.1 กับ CUDA11.7 สำหรับการฝึกอบรม
สำหรับการฝึกอบรมข้อมูลแบบกระจายแบบคู่ขนานเราใช้ accelerate Library แม้ว่าสิ่งนี้อาจเปลี่ยนแปลงได้ในอนาคต สำหรับการโหลดชุดข้อมูลเราใช้ webdataset Library ดังนั้นชุดข้อมูลควรอยู่ในรูปแบบ webdataset
ที่ momemnt เราสนับสนุนโมเดลต่อไปนี้:
MaskGitTransformer - รุ่นหม้อแปลงหลักจากกระดาษMaskGitVQGAN - โมเดล vqgan จาก repo maskgitVQGANModel - โมเดล VQGAN จาก Repo Taming Transformers แบบจำลองนี้ถูกนำไปใช้ภายใต้ไดเรกทอรี muse ทุกรุ่นใช้ API transformers ที่คุ้นเคย ดังนั้นคุณสามารถใช้วิธีการ from_pretrained และ save_pretrained เพื่อโหลดและบันทึกโมเดล โมเดลสามารถบันทึกและโหลดได้จากฮับ HuggingFace
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )บันทึก :
Maskgits เป็นหม้อแปลงที่เอาต์พุตบันทึกตามลำดับของโทเค็นของโทเค็นฉลากทั้ง VQ และชั้นเรียน
วิธีที่กระบวนการ denoising เสร็จสิ้นคือการปิดบังด้วย Mask Token ID และค่อยๆ denoise
ในการใช้งานดั้งเดิมสิ่งนี้จะทำโดยครั้งแรกโดยใช้ softmax ในการสุ่มตัวอย่างสลัวครั้งสุดท้ายและสุ่มเป็นการแจกแจงแบบจัดหมวดหมู่ สิ่งนี้จะให้โทเค็นที่คาดการณ์ของเราสำหรับแต่ละ maskid จากนั้นเราจะได้รับความน่าจะเป็นสำหรับโทเค็นเหล่านั้นที่จะเลือก ในที่สุดเราจะได้รับความมั่นใจสูงสุดของ Topk เมื่อเพิ่มอุณหภูมิ การแจกแจง Gumbel เป็นเหมือนการกระจายปกติที่เปลี่ยนไปสู่ 0 ซึ่งใช้ในการจำลองเหตุการณ์ที่รุนแรง ดังนั้นในสถานการณ์ที่รุนแรงเราต้องการเห็นโทเค็นอื่นที่ถูกเลือกจากค่าเริ่มต้น
สำหรับการใช้งาน Lucidrian ก่อนอื่นจะลบโทเค็นที่ให้คะแนนสูงสุด (ความน่าจะเป็นต่ำสุด) โดยการปิดบังพวกเขาด้วยอัตราส่วนการปิดบังที่กำหนด จากนั้นยกเว้น 10% สูงสุดของ logits ที่เราได้รับเราตั้งค่าเป็น -infinity ดังนั้นเมื่อเราทำการแจกจ่าย Gumbel บนมันพวกเขาจะถูกเพิกเฉย จากนั้นอัปเดต ID อินพุตและคะแนนที่คะแนนเป็นเพียง 1- ความน่าจะเป็นที่ได้รับจาก softmax ของ logits ที่ ID ที่คาดการณ์ไว้น่าสนใจ
สำหรับ ImageNet ระดับชั้นเรียนเรากำลังใช้ accelerate สำหรับการฝึกอบรม DDP และ webdataset สำหรับการโหลดข้อมูล สคริปต์การฝึกอบรมมีอยู่ใน training/train_maskgit_imagenet.py
เราใช้ Omegaconf สำหรับการจัดการการกำหนดค่า ดู configs/template_config.yaml สำหรับเทมเพลตการกำหนดค่า ด้านล่างเราอธิบายพารามิเตอร์การกำหนดค่า
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullข้อโต้แย้งกับ ??? จำเป็นต้องมี
Wandb :
wandb.entity : เอนทิตีของ Wandb ที่จะใช้สำหรับการบันทึกการทดลอง :
experiment.name : ชื่อของการทดลองexperiment.project : โครงการ Wandb ที่จะใช้สำหรับการบันทึกexperiment.output_dir : ไดเรกทอรีเพื่อบันทึกจุดตรวจexperiment.max_train_examples : จำนวนตัวอย่างการฝึกอบรมสูงสุดที่จะใช้experiment.save_every : บันทึกจุดตรวจสอบทุกขั้นตอน save_everyexperiment.eval_every : ประเมินแบบจำลองทุกขั้นตอน eval_everyexperiment.generate_every : สร้างภาพทุกขั้นตอน generate_everyexperiment.log_every : บันทึกตัวชี้วัดการฝึกอบรมทุกขั้นตอน log_everylog_grad_norm_every : บันทึกบรรทัดฐานการไล่ระดับสีทุกขั้นตอน log_grad_norm_everyexperiment.resume_from_checkpoint : จุดตรวจสอบเพื่อกลับมาฝึกอบรมต่อจาก สามารถเป็น latest ที่จะกลับมาทำงานจากจุดตรวจสอบล่าสุดหรือเส้นทางไปยังจุดตรวจสอบที่บันทึกไว้ หาก None หรือเส้นทางไม่มีการฝึกอบรมเริ่มต้นจากศูนย์แบบอย่าง :
model.vq_model.pretrained : โมเดล VQ ที่ผ่านการฝึกอบรมที่จะใช้ สามารถเป็นเส้นทางไปยังจุดตรวจสอบที่บันทึกไว้หรือชื่อโมเดล HuggingFacemodel.transformer : การกำหนดค่าโมเดลหม้อแปลงmodel.gradient_checkpointing : เปิดใช้งานการไล่ระดับสีการไล่ระดับสีสำหรับรุ่นหม้อแปลงenable_xformers_memory_efficient_attention : เปิดใช้งานหน่วยความจำที่มีประสิทธิภาพหรือให้ความสนใจกับโมเดลหม้อแปลง สำหรับความสนใจของแฟลชเราต้องใช้ fp16 หรือ bf16 Xformers จำเป็นต้องติดตั้งเพื่อให้ทำงานนี้ชุดข้อมูล :
dataset.params.train_shards_path_or_url : เส้นทางหรือ URL ไปยังการฝึกอบรม webdatasetdataset.params.eval_shards_path_or_url : เส้นทางหรือ URL ไปยังแผ่นประเมิน webdatasetdataset.params.batch_size : ขนาดแบทช์ที่จะใช้สำหรับการฝึกอบรมdataset.params.shuffle_buffer_size : ขนาดบัฟเฟอร์ Shuffle ที่จะใช้สำหรับการฝึกอบรมdataset.params.num_workers : จำนวนคนงานที่ใช้สำหรับการโหลดข้อมูลdataset.params.resolution : ความละเอียดของภาพที่จะใช้สำหรับการฝึกอบรมdataset.params.pin_memory : PIN หน่วยความจำสำหรับการโหลดข้อมูลdataset.params.persistent_workers : ใช้คนงานถาวรสำหรับการโหลดข้อมูลdataset.preprocessing.resolution : ความละเอียดของภาพที่จะใช้สำหรับการประมวลผลล่วงหน้าdataset.preprocessing.center_crop : ไม่ว่าจะเป็นศูนย์กลางของการครอบตัดภาพหรือไม่ หาก False ภาพจะถูกตัดแบบสุ่มไปที่ resolutiondataset.preprocessing.random_flip : ไม่ว่าจะแบบสุ่มพลิกรูปภาพหรือไม่ ถ้า False แสดงว่าภาพจะไม่พลิกเครื่องมือเพิ่มประสิทธิภาพ :
optimizer.name : เครื่องมือเพิ่มประสิทธิภาพที่จะใช้สำหรับการฝึกอบรมoptimizer.params : พารามิเตอร์ OptimizerLR_SCHEDULER :
lr_scheduler.scheduler : ตัวกำหนดค่าอัตราการเรียนรู้ที่จะใช้สำหรับการฝึกอบรมlr_scheduler.params : พารามิเตอร์กำหนดค่าอัตราการเรียนรู้การฝึกอบรม :
training.gradient_accumulation_steps : จำนวนขั้นตอนการสะสมการไล่ระดับสีเพื่อใช้สำหรับการฝึกอบรมtraining.batch_size : ขนาดแบทช์ที่จะใช้สำหรับการฝึกอบรมtraining.mixed_precision : โหมดความแม่นยำผสมที่จะใช้สำหรับการฝึกอบรม no สามารถเป็น fp16 หรือ bf16training.enable_tf32 : เปิดใช้งาน TF32 สำหรับการฝึกอบรมเกี่ยวกับ AMPERE GPUtraining.use_ema : เปิดใช้งาน EMA สำหรับการฝึกอบรม ปัจจุบันไม่รองรับtraining.seed : เมล็ดพันธุ์ที่จะใช้สำหรับการฝึกอบรมtraining.max_train_steps : จำนวนขั้นตอนการฝึกอบรมสูงสุดtraining.overfit_one_batch : ไม่ว่าจะเกินค่าหนึ่งชุดสำหรับการดีบักหรือไม่training.min_masking_rate : อัตราการปิดบังขั้นต่ำที่จะใช้สำหรับการฝึกอบรมtraining.label_smoothing : ค่าการปรับให้เรียบของฉลากเพื่อใช้สำหรับการฝึกอบรมmax_grad_norm : บรรทัดฐานการไล่ระดับสีสูงสุดหมายเหตุเกี่ยวกับการฝึกอบรมและชุดข้อมูล -
เราสุ่มตัวอย่างเศษ (พร้อมแทนที่) และตัวอย่างตัวอย่างในบัฟเฟอร์สำหรับการฝึกอบรมทุกครั้งที่เรากลับมา/เริ่มการฝึกอบรม ซึ่งหมายความว่าการโหลดข้อมูลของเราไม่ได้กำหนด นอกจากนี้เรายังไม่ได้ทำการฝึกอบรมตามยุค แต่เพียงแค่ใช้สิ่งนี้เพื่อการเก็บหนังสือและความสามารถในการใช้ลูปการฝึกอบรมแบบเดียวกันกับชุดข้อมูล/โหลดอื่น ๆ
จนถึงตอนนี้เรากำลังทำการทดลองบนโหนดเดียว หากต้องการเปิดการฝึกอบรมบนโหนดเดียวให้เรียกใช้ขั้นตอนต่อไปนี้:
webdataset คุณสามารถใช้ scripts/convert_imagenet_to_wds.py เพื่อแปลงชุดข้อมูล Imagenet เป็นรูปแบบ webdatasetaccelerate configconfig.yaml สำหรับการทดสอบของคุณaccelerate launch accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config ด้วย Omegaconf การแทนที่คำสั่งจะทำในรูปแบบ dot-notation เช่นหากคุณต้องการแทนที่เส้นทางชุดข้อมูลคุณจะใช้คำสั่ง python -u train.py config=path/to/config dataset.params.path=path/to/dataset
คำสั่งเดียวกันสามารถใช้เพื่อเปิดการฝึกอบรมในพื้นที่
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
โครงการนี้ขึ้นอยู่กับ repos โอเพนซอร์ซต่อไปนี้ ขอบคุณผู้เขียนทุกคนสำหรับงานที่น่าทึ่งของพวกเขา
และ Obivioulsy ถึงทีม Pytorch สำหรับกรอบที่น่าทึ่งนี้❤