ATUALIZAÇÃO 22/12/2021: Adicionado suporte para a versão Pytorch Lightning 1.5.6 e limpe o código.
Uma coleção de autoencoders variacionais (VAEs) implementados em Pytorch, com foco na reprodutibilidade. O objetivo deste projeto é fornecer um exemplo de trabalho rápido e simples para muitos dos modelos Cool VAE por aí. Todos os modelos são treinados no conjunto de dados Celeba para consistência e comparação. A arquitetura de todos os modelos é mantida o mais semelhante possível com as mesmas camadas, exceto nos casos em que o artigo original requer uma arquitetura radicalmente diferente (Ex. VQ VAE usa camadas residuais e sem norma em lote, diferentemente de outros modelos). Aqui estão os resultados de cada modelo.
$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt
$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
Modelo de arquivo de configuração
model_params :
name : " <name of VAE model> "
in_channels : 3
latent_dim :
. # Other parameters required by the model
.
.
data_params :
data_path : " <path to the celebA dataset> "
train_batch_size : 64 # Better to have a square number
val_batch_size : 64
patch_size : 64 # Models are designed to work for this size
num_workers : 4
exp_params :
manual_seed : 1265
LR : 0.005
weight_decay :
. # Other arguments required for training, like scheduler etc.
.
.
trainer_params :
gpus : 1
max_epochs : 100
gradient_clip_val : 1.5
.
.
.
logging_params :
save_dir : " logs/ "
name : " <experiment name> "View Tensorboard Logs
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .
Nota: O conjunto de dados padrão é Celeba. No entanto, houve muitos problemas com o download do conjunto de dados do Google Drive (devido a algumas alterações na estrutura de arquivos). Portanto, a recomendação é baixar o arquivo do Google Drive diretamente e extrair para o caminho de sua escolha. O caminho padrão assumido nos arquivos de configuração é `data/celebba/img_align_celeba '. Mas você pode mudá -lo de acordo com sua preferência.
| Modelo | Papel | Reconstrução | Amostras |
|---|---|---|---|
| VAE (código, configuração) | Link | ![]() | ![]() |
| VAE condicional (código, configuração) | Link | ![]() | ![]() |
| Wae - mmd (kernel rbf) (código, configuração) | Link | ![]() | ![]() |
| Wae - mmd (kernel IMQ) (código, configuração) | Link | ![]() | ![]() |
| Beta-vas (código, configuração) | Link | ![]() | ![]() |
| Beta-vasos de beta (código, configuração) | Link | ![]() | ![]() |
| Beta-tc-vrae (código, configuração) | Link | ![]() | ![]() |
| Iwae ( k = 5 ) (código, configuração) | Link | ![]() | ![]() |
| Miwae ( k = 5, m = 3 ) (código, configuração) | Link | ![]() | ![]() |
| Dfcvae (código, configuração) | Link | ![]() | ![]() |
| MSSIM VAE (código, configuração) | Link | ![]() | ![]() |
| VAE categórico (código, configuração) | Link | ![]() | ![]() |
| Vae conjunta (código, configuração) | Link | ![]() | ![]() |
| Info Vae (código, configuração) | Link | ![]() | ![]() |
| LOGCOSH VAE (Código, Config) | Link | ![]() | ![]() |
| Swae (200 projeções) (código, configuração) | Link | ![]() | ![]() |
| VQ-VAE ( k = 512, d = 64 ) (código, configuração) | Link | ![]() | N / D |
| DIP VAE (código, configuração) | Link | ![]() | ![]() |
Se você treinou um modelo melhor, usando essas implementações, ajustando as hiper-paramas no arquivo de configuração, ficaria feliz em incluir seu resultado (junto com o arquivo de configuração) neste repositório, citando seu nome?.
Além disso, se você deseja contribuir com alguns modelos, envie um PR.
Licença Apache 2.0
| Permissões | Limitações | Condições |
|---|---|---|
| Use Uso comercial | Uso da marca registrada | Ⓘ Aviso de licença e direitos autorais |
| ✔️ Modificação | Responsabilidade | Ⓘ Mudanças de estado |
| Distribution Distribuição | Garantia | |
| Use Uso da patente | ||
| Use Uso particular |
@misc{Subramanian2020,
author = {Subramanian, A.K},
title = {PyTorch-VAE},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {url{https://github.com/AntixK/PyTorch-VAE}}
}