جهود إعادة الإنتاج المفتوحة لإعادة إنتاج نموذج Muse المستند إلى المحول لتوليد Text2Image السريع.
https://huggingface.co/spaces/openmuse/muse
هذا الريبو هو استنساخ نموذج موسى. الهدف من ذلك هو إنشاء ريبو بسيط وقابل للتطوير ، لإعادة إنتاج Muse وبناء knowedge حول محولات VQ + على نطاق واسع. سوف نستخدم مجموعة بيانات DEDUPED LAION-2B + COYO-700M للتدريب.
مراحل المشروع:
سيتم تحميل جميع القطع الأثرية لهذا المشروع إلى منظمة OpenMuse على مركز Huggingface.
قم أولاً بإنشاء بيئة افتراضية وقم بتثبيت الريبو باستخدام:
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " ستحتاج إلى تثبيت PyTorch و torchvision يدويًا. نحن نستخدم torch==1.13.1 مع CUDA11.7 للتدريب.
بالنسبة للبيانات الموزعة ، فإن التدريب المتوازي ، نستخدم مكتبة accelerate ، على الرغم من أن هذا قد يتغير في المستقبل. لتحميل مجموعة البيانات ، نستخدم مكتبة webdataset . لذلك يجب أن تكون مجموعة البيانات في تنسيق webdataset .
في Momemnt ندعم النماذج التالية:
MaskGitTransformer - نموذج المحول الرئيسي من الورقة.MaskGitVQGAN - نموذج VQGAN من Maskgit Repo.VQGANModel - نموذج VQGAN من Taming Transformers repo. يتم تنفيذ النماذج تحت دليل muse . جميع النماذج تنفذ واجهة برمجة تطبيقات transformers المألوفة. حتى تتمكن من استخدام أساليب from_pretrained و save_pretrained لتحميل النماذج وحفظها. يمكن حفظ النموذج وتحميله من مركز Huggingface.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )ملحوظة :
Maskgits هو محول يخرج سجلات أعطى تسلسلًا من الرموز المميزة لكل من VQ ورمز الملصق المكيف
الطريقة التي تتم بها عملية تقليل الإلغاء هي الإخفاء مع معرفات رمز القناع و denoise تدريجيا
في التنفيذ الأصلي ، يتم ذلك عن طريق أولاً باستخدام softmax على آخر قاتمة وأخذ عينات عشوائية كتوزيع فئوي. هذا سوف يعطي الرموز المتوقعة لدينا لكل قناع. ثم نحصل على احتمالات اختيار هذه الرموز. أخيرًا ، نحصل على أعلى احتمالات ثقة TopK عند إضافة Temp Gumbel*. يشبه توزيع Gumbel التوزيع الطبيعي المتغير نحو 0 والذي يتم استخدامه لنمذجة الأحداث المتطرفة. لذلك في السيناريوهات المتطرفة ، نود أن نرى رمزًا مختلفًا يتم اختياره من واحد افتراضي
بالنسبة للتنفيذ لوسيدوريان ، فإنه يزيل أولاً الرموز المميزة (أدنى احتمال) عن طريق إخفاءها بنسبة تقنيع معينة. بعد ذلك ، باستثناء أعلى 10 ٪ من السجلات التي نحصل عليها ، قمنا بتعيينها على -infinity ، لذلك عندما نقوم بتوزيع Gumbel عليها ، سيتم تجاهلها. ثم قم بتحديث معرفات الإدخال والنتائج التي تكون فيها الدرجات فقط 1-الاحتمال الذي تمنحه softmax من السجلات في المعرفات المتوقعة بشكل مثير للاهتمام
بالنسبة إلى التصوير الفاصولي ، فإننا نستخدم accelerate لتدريب DDP و webdataset لتحميل البيانات. البرنامج النصي التدريبي متاح في training/train_maskgit_imagenet.py .
نستخدم أوميجاكونف لإدارة التكوين. راجع configs/template_config.yaml للحصول على قالب التكوين. نوضح أدناه معلمات التكوين.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : nullالحجج مع ؟؟؟ مطلوب.
واند :
wandb.entity : كيان WANDB لاستخدامه في التسجيل.تجربة :
experiment.name : اسم التجربة.experiment.project .experiment.output_dir : الدليل لحفظ نقاط التفتيش.experiment.max_train_examples : الحد الأقصى لعدد أمثلة التدريب للاستخدام.experiment.save_every save_everyexperiment.eval_every eval_everyexperiment.generate_every generate_everyexperiment.log_every log_everylog_grad_norm_every : قم بتسجيل قاعدة التدرج كل خطوات log_grad_norm_every .experiment.resume_from_checkpoint : نقطة التفتيش لاستئناف التدريب من. يمكن أن يكون latest لاستئناف من أحدث نقطة تفتيش أو مسار إلى نقطة تفتيش محفوظة. إذا لم يكن هناك None أو مسار ، فسيبدأ التدريب من الصفر.نموذج :
model.vq_model.pretrained : نموذج VQ pretRained للاستخدام. يمكن أن يكون طريقًا إلى نقطة تفتيش محفوظة أو اسم طراز Huggingface.model.transformer : تكوين نموذج المحولات.model.gradient_checkpointing : تمكين تحديد التدرج لطراز المحول.enable_xformers_memory_efficient_attention : تمكين الاهتمام الفعال للذاكرة أو اهتمام وميض لنموذج المحول. للحصول على اهتمام فلاش ، نحتاج إلى استخدام fp16 أو bf16 . يجب تثبيت Xformers لهذا العمل.مجموعة البيانات :
dataset.params.train_shards_path_or_url : المسار أو عنوان URL إلى شظايا تدريب webdataset .dataset.params.eval_shards_path_or_url : المسار أو عنوان URL إلى شظايا تقييم webdataset .dataset.params.batch_size : حجم الدُفعة لاستخدامه في التدريب.dataset.params.shuffle_buffer_size : حجم المخزن المؤقت للاختلاق لاستخدامه في التدريب.dataset.params.num_workers : عدد العمال الذين يجب استخدامهم لتحميل البيانات.dataset.params.resolution : دقة الصور لاستخدامها في التدريب.dataset.params.pin_memory : قم بتثبيت الذاكرة لتحميل البيانات.dataset.params.persistent_workers : استخدم العمال المستمرين لتحميل البيانات.dataset.preprocessing.resolution : دقة الصور لاستخدامها في المعالجة المسبقة.dataset.preprocessing.center_crop : ما إذا كان لتوسيط محصول الصور. إذا كان False ، فسيتم اقتصاص الصور بشكل عشوائي إلى resolution .dataset.preprocessing.random_flip : ما إذا كان يجب قلب الصور بشكل عشوائي. إذا False ، فلن يتم قلب الصور.المحسن :
optimizer.name : المحسن لاستخدامه في التدريب.optimizer.params : معلمات المحسن.lr_scheduler :
lr_scheduler.scheduler : جدولة معدل التعلم لاستخدامها في التدريب.lr_scheduler.params : معلمات جدولة معدل التعلم.تمرين :
training.gradient_accumulation_stepstraining.batch_size : حجم الدُفعة لاستخدامه في التدريب.training.mixed_precision mixed_precision: وضع الدقة المختلطة لاستخدامه في التدريب. يمكن أن يكون no ، fp16 أو bf16 .training.enable_tf32 .training.use_ema : تمكين EMA للتدريب. حاليا غير مدعوم.training.seed .training.max_train_steps max_train_steps: الحد الأقصى لعدد خطوات التدريب.training.overfit_one_batch .training.min_masking_rate .training.label_smoothing : قيمة تجانس التسمية لاستخدامها في التدريب.max_grad_norm : Max Lradient Norm.ملاحظات حول التدريب ومجموعة البيانات. :
نقوم بتشكيل شظايا بشكل عشوائي (مع استبدال) وأمثلة عينة في المخزن المؤقت للتدريب في كل مرة نستأنف فيها/بدء تشغيل التدريب. هذا يعني أن تحميل البيانات لدينا ليس محددًا. نحن أيضًا لا نقوم بالتدريب القائم على الحقبة ولكن فقط نستخدم هذا للحفاظ على الكتب والقدرة على إعادة استخدام حلقة التدريب نفسها مع مجموعات البيانات/اللوادر الأخرى.
حتى الآن نقوم بإجراء تجارب على عقدة واحدة. لإطلاق تشغيل تدريب على عقدة واحدة ، قم بتشغيل الخطوات التالية:
webdataset . يمكنك استخدام scripts/convert_imagenet_to_wds.py البرنامج النصي لتحويل مجموعة بيانات ImageNet إلى تنسيق webdataset .accelerate config .config.yaml لتجربتك.accelerate launch . accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config مع أوميغاكونف ، تتم تجاوزات خط الأوامر بتنسيق نقاط النقطة. على سبيل المثال ، إذا كنت ترغب في تجاوز مسار مجموعة البيانات ، فستستخدم الأمر python -u train.py config=path/to/config dataset.params.path=path/to/dataset .
يمكن استخدام نفس الأمر لإطلاق التدريب محليًا.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
هذا المشروع يعتمد على hevily على repos المفتوح المصدر التالية. شكرا لجميع المؤلفين على عملهم المذهل.
و Obivioulsy لفريق Pytorch لهذا الإطار المذهل ❤