빠른 Text2image 생성을위한 변압기 기반 뮤즈 모델을 재현하기위한 개방형 재생 노력.
https://huggingface.co/spaces/openmuse/muse
이 repo는 뮤즈 모델의 재생산을위한 것입니다. 목표는 단순하고 확장 가능한 리포를 만들어 Muse를 재현하고 VQ + Transformers에 대한 Knowedge를 규모로 구축하는 것입니다. 우리는 교육을 위해 Deduped Laion-2B + Coyo-700m 데이터 세트를 사용합니다.
프로젝트 단계 :
이 프로젝트의 모든 인공물은 Huggingface Hub의 OpenMuse 조직에 업로드됩니다.
먼저 가상 환경을 만들고 다음을 사용하여 저장소를 설치하십시오.
git clone https://github.com/huggingface/muse
cd muse
pip install -e " .[extra] " PyTorch 및 torchvision 수동으로 설치해야합니다. 우리는 torch==1.13.1 사용하여 CUDA11.7 과 함께 훈련을 사용하고 있습니다.
분산 데이터 병렬 교육의 경우 향후에 변경 될 수 있지만 accelerate Library를 사용합니다. 데이터 세트로드의 경우 webdataset 라이브러리를 사용합니다. 따라서 데이터 세트는 webdataset 형식이어야합니다.
Momemnt에서 우리는 다음 모델을 지원합니다.
MaskGitTransformer 용지의 주요 변압기 모델.MaskGitVQGAN MaskGit Repo의 VQGAN 모델.VQGANModel Taming Transformers Repo의 VQGAN 모델. 모델은 muse Directory에서 구현됩니다. 모든 모델은 친숙한 transformers API를 구현합니다. 따라서 from_pretrained 및 save_pretrained 메소드를 사용하여 모델을로드하고 저장할 수 있습니다. Huggingface 허브에서 모델을 저장하고로드 할 수 있습니다.
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# encode and decode images using
encode_transform = = transforms . Compose (
[
transforms . Resize ( 256 , interpolation = transforms . InterpolationMode . BILINEAR ),
transforms . CenterCrop ( 256 ),
transforms . ToTensor (),
]
)
image = Image . open ( "..." ) #
pixel_values = encode_transform ( image ). unsqueeze ( 0 )
image_tokens = vq_model . encode ( pixel_values )
rec_image = vq_model . decode ( image_tokens )
# Convert to PIL images
rec_image = 2.0 * rec_image - 1.0
rec_image = torch . clamp ( rec_image , - 1.0 , 1.0 )
rec_image = ( rec_image + 1.0 ) / 2.0
rec_image *= 255.0
rec_image = rec_image . permute ( 0 , 2 , 3 , 1 ). cpu (). numpy (). astype ( np . uint8 )
pil_images = [ Image . fromarray ( image ) for image in rec_image ] import torch
from muse import MaskGitTransformer , MaskGitVQGAN
from muse . sampling import cosine_schedule
# Load the pre-trained vq model from the hub
vq_model = MaskGitVQGAN . from_pretrained ( "openMUSE/maskgit-vqgan-imagenet-f16-256" )
# Initialize the MaskGitTransformer model
maskgit_model = MaskGitTransformer (
vocab_size = 2025 , #(1024 + 1000 + 1 = 2025 -> Vq_tokens + Imagenet class ids + <mask>)
max_position_embeddings = 257 , # 256 + 1 for class token
hidden_size = 512 ,
num_hidden_layers = 8 ,
num_attention_heads = 8 ,
intermediate_size = 2048 ,
codebook_size = 1024 ,
num_vq_tokens = 256 ,
num_classes = 1000 ,
)
# prepare the input batch
images = torch . randn ( 4 , 3 , 256 , 256 )
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
# encode the images
image_tokens = vq_model . encode ( images )
batch_size , seq_len = image_tokens . shape
# Sample a random timestep for each image
timesteps = torch . rand ( batch_size , device = image_tokens . device )
# Sample a random mask probability for each image using timestep and cosine schedule
mask_prob = cosine_schedule ( timesteps )
mask_prob = mask_prob . clip ( min_masking_rate )
# creat a random mask for each image
num_token_masked = ( seq_len * mask_prob ). round (). clamp ( min = 1 )
batch_randperm = torch . rand ( batch_size , seq_len , device = image_tokens . device ). argsort ( dim = - 1 )
mask = batch_randperm < num_token_masked . unsqueeze ( - 1 )
# mask images and create input and labels
input_ids = torch . where ( mask , mask_id , image_tokens )
labels = torch . where ( mask , image_tokens , - 100 )
# shift the class ids by codebook size
class_ids = class_ids + vq_model . num_embeddings
# prepend the class ids to the image tokens
input_ids = torch . cat ([ class_ids . unsqueeze ( - 1 ), input_ids ], dim = - 1 )
# prepend -100 to the labels as we don't want to predict the class ids
labels = torch . cat ([ - 100 * torch . ones_like ( class_ids ). unsqueeze ( - 1 ), labels ], dim = - 1 )
# forward pass
logits , loss = maskgit_model ( input_ids , labels = labels )
loss . backward ()
# to generate images
class_ids = torch . randint ( 0 , 1000 , ( 4 ,)) # random class ids
generated_tokens = maskgit_model . generate ( class_ids = class_ids )
rec_images = vq_model . decode ( generated_tokens )메모 :
Maskgits
비난 과정이 수행되는 방식은 마스크 토큰 ID로 마스크를 마스킹하고 점차 거부하는 것입니다.
원래 구현에서 이것은 먼저 마지막 DIM에서 SoftMax를 사용하고 범주 형 분포로 무작위로 샘플링하여 수행됩니다. 이것은 각 마스크에 대한 예측 된 토큰을 제공합니다. 그런 다음 우리는 그 토큰을 선택할 확률을 얻습니다. 마지막으로, Gumbel*온도가 추가 될 때 최고 신뢰 확률을 얻습니다. Gumbel 분포는 극한 이벤트를 모델링하는 데 사용되는 0으로 이동 한 정규 분포와 같습니다. 극단적 인 시나리오에서는 기본값에서 다른 토큰이 선택되는 것을보고 싶습니다.
Lucidrian 구현의 경우 먼저 주어진 마스킹 비율로 마스킹하여 가장 높은 점수 (가장 낮은 확률) 토큰을 제거합니다. 그런 다음, 우리가 얻는 로그 중 가장 높은 10%를 제외하고, 우리는 그것을 -infinity로 설정하므로 Gumbel 분포를 할 때 무시됩니다. 그런 다음 입력 ID와 점수가 1 인 점수를 업데이트하십시오.
Class-Conditional ImageNet의 경우 DDP 교육 및 데이터 로딩 용 webdataset 용 accelerate 사용하고 있습니다. 교육 스크립트는 training/train_maskgit_imagenet.py 에서 사용할 수 있습니다.
우리는 구성 관리에 OMEGACONF를 사용합니다. configuration 템플릿은 configs/template_config.yaml 참조하십시오. 아래에서는 구성 매개 변수를 설명합니다.
wandb :
entity : ???
experiment :
name : ???
project : ???
output_dir : ???
max_train_examples : ???
save_every : 1000
eval_every : 500
generate_every : 1000
log_every : 50
log_grad_norm_every : 100
resume_from_checkpoint : latest
model :
vq_model :
pretrained : " openMUSE/maskgit-vqgan-imagenet-f16-256 "
transformer :
vocab_size : 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, use 2048 for even division by 8)
max_position_embeddings : 264 # (256 + 1 for class id, use 264 for even division by 8)
hidden_size : 768
num_hidden_layers : 12
num_attention_heads : 12
intermediate_size : 3072
codebook_size : 1024
num_vq_tokens : 256
num_classes : 1000
initializer_range : 0.02
layer_norm_eps : 1e-6
use_bias : False
use_normformer : True
use_encoder_layernorm : True
hidden_dropout : 0.0
attention_dropout : 0.0
gradient_checkpointing : True
enable_xformers_memory_efficient_attention : False
dataset :
params :
train_shards_path_or_url : ???
eval_shards_path_or_url : ???
batch_size : ${training.batch_size}
shuffle_buffer_size : ???
num_workers : ???
resolution : 256
pin_memory : True
persistent_workers : True
preprocessing :
resolution : 256
center_crop : True
random_flip : False
optimizer :
name : adamw # Can be adamw or lion or fused_adamw. Install apex for fused_adamw
params : # default adamw params
learning_rate : ???
scale_lr : False # scale learning rate by total batch size
beta1 : 0.9
beta2 : 0.999
weight_decay : 0.01
epsilon : 1e-8
lr_scheduler :
scheduler : " constant_with_warmup "
params :
learning_rate : ${optimizer.params.learning_rate}
warmup_steps : 500
training :
gradient_accumulation_steps : 1
batch_size : 128
mixed_precision : " no "
enable_tf32 : True
use_ema : False
seed : 42
max_train_steps : ???
overfit_one_batch : False
min_masking_rate : 0.0
label_smoothing : 0.0
max_grad_norm : null논쟁 ??? 필요합니다.
wandb :
wandb.entity : Wandb 엔티티는 로깅에 사용할 수 있습니다.실험 :
experiment.name : 실험의 이름.experiment.project : 로깅에 사용할 WANDB 프로젝트.experiment.output_dir : 검문소를 저장하는 디렉토리.experiment.max_train_examples : 사용할 최대 교육 예제.experiment.save_every : 모든 save_every 단계를 모두 검사 점을 저장하십시오.experiment.eval_every : 모든 eval_every 단계를 평가하십시오.experiment.generate_every : 모든 generate_every 단계를 생성하십시오.experiment.log_every : 모든 log_every 단계마다 훈련 메트릭을 기록하십시오.log_grad_norm_every : 모든 gradient 규범을 로그 log_grad_norm_every 단계.experiment.resume_from_checkpoint : 교육을 재개 할 검문소. 최신 체크 포인트 또는 저장된 체크 포인트로 latest 에서 재개 할 수 있습니다. None 경로가 존재하지 않으면 훈련은 처음부터 시작됩니다.모델 :
model.vq_model.pretrained : 사용할 예정인 VQ 모델. 저장된 체크 포인트 또는 포옹 페이스 모델 이름의 경로 일 수 있습니다.model.transformer : 변압기 모델 구성.model.gradient_checkpointing : 변압기 모델의 기울기 검사 점을 활성화합니다.enable_xformers_memory_efficient_attention : 변압기 모델에 대한 메모리 효율적인주의 또는 플래시주의를 활성화합니다. 플래시주의를 얻으려면 fp16 또는 bf16 사용해야합니다. Xformers를 설치하려면 설치해야합니다.데이터 세트 :
dataset.params.train_shards_path_or_url : webdataset 교육 파편의 경로 또는 URL.dataset.params.eval_shards_path_or_url : webdataset 평가 파편의 경로 또는 URL.dataset.params.batch_size : 훈련에 사용할 배치 크기.dataset.params.shuffle_buffer_size : 훈련에 사용할 셔플 버퍼 크기.dataset.params.num_workers : 데이터로드에 사용할 작업자 수.dataset.params.resolution : 훈련에 사용할 이미지의 해상도.dataset.params.pin_memory : 데이터로드를 위해 메모리를 고정합니다.dataset.params.persistent_workers : 데이터로드에 지속적인 작업자를 사용합니다.dataset.preprocessing.resolution : 전처리에 사용할 이미지의 해상도.dataset.preprocessing.center_crop : 중심을 자르는지 이미지를 자르십시오. False 이면 이미지가 resolution 에 무작위로 자릅니다.dataset.preprocessing.random_flip : 이미지를 무작위로 뒤집을 지 여부. False 이면 이미지가 뒤집어지지 않습니다.최적화 :
optimizer.name : 훈련에 사용할 최적화기.optimizer.params : Optimizer 매개 변수.LR_SCHEDULER :
lr_scheduler.scheduler : 교육에 사용할 학습 속도 스케줄러.lr_scheduler.params : 학습 속도 스케줄러 매개 변수.훈련 :
training.gradient_accumulation_steps : 교육에 사용할 구배 축적 단계 수.training.batch_size : 훈련에 사용할 배치 크기.training.mixed_precision : 교육에 사용할 혼합 정밀 모드. no , fp16 또는 bf16 일 수 있습니다.training.enable_tf32 : Ampere GPU에 대한 교육을 위해 TF32를 활성화합니다.training.use_ema : 교육을위한 EMA를 활성화합니다. 현재 지원되지 않습니다.training.seed : 훈련에 사용할 씨앗.training.max_train_steps : 최대 교육 단계 수.training.overfit_one_batch : 디버깅을 위해 하나의 배치를 과적으로 할 것인지 여부.training.min_masking_rate : 교육에 사용할 최소 마스킹 속도.training.label_smoothing : 교육에 사용할 레이블 스무딩 값.max_grad_norm : Max Gradient Norm.교육 및 데이터 세트에 대한 메모. :
우리는 훈련 실행을 재개/시작할 때마다 훈련을 위해 파편 (교체 포함)과 샘플 예제를 무작위로 샘플링합니다. 이것은 우리의 데이터 로딩이 결정성이 아니라는 것을 의미합니다. 우리는 또한 에포크 기반 교육을하지 않고 책을 유지하고 다른 데이터 세트/로더와 동일한 교육 루프를 재사용 할 수 있습니다.
지금까지 우리는 단일 노드에서 실험을 실행하고 있습니다. 단일 노드에서 훈련 실행을 시작하려면 다음 단계를 실행하십시오.
webdataset 형식으로 준비하십시오. scripts/convert_imagenet_to_wds.py 스크립트를 사용하여 imagenet 데이터 세트를 webdataset 형식으로 변환 할 수 있습니다.accelerate config 사용하여 교육 환경을 구성하십시오.config.yaml 파일을 만듭니다.accelerate launch 사용하여 훈련 실행을 시작하십시오. accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config OMEGACONF를 사용하면 명령 라인 오버라이드가 도트 계약 형식으로 수행됩니다. 예 : 데이터 세트 경로를 무시하려면 python -u train.py config=path/to/config dataset.params.path=path/to/dataset 사용합니다.
동일한 명령을 사용하여 로컬로 훈련을 시작할 수 있습니다.
├── README.md
├── configs -> All training config files.
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> All data related utils. Can create a data folder if needed.
│ ├── logging.py -> Misc logging utils.
| ├── lr_schedulers.py -> All lr scheduler related utils.
│ ├── modeling_maskgit_vqgan.py -> VQGAN model from maskgit repo.
│ ├── modeling_taming_vqgan.py -> VQGAN model from taming repo.
│ └── modeling_transformer.py -> The main transformer model.
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
│ ├── sampling.py -> Sampling/Generation utils.
│ ├── training_utils.py -> Common training utils.
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> All training scripts.
├── __init__.py
├── data.py -> All data related utils. Can create a data folder if needed.
├── optimizer.py -> All optimizer related utils and any new optimizer not available in PT.
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
이 프로젝트는 다음과 같은 오픈 소스 저장소를 기반으로합니다. 그들의 놀라운 작품에 대해 모든 저자들에게 감사합니다.
그리고이 놀라운 프레임 워크를 위해 Pytorch 팀에 대한 obivioulsy