Implémentation de référence pour un autoencoder variationnel dans TensorFlow et Pytorch.
Je recommande la version Pytorch. Il comprend un exemple d'une famille variationnelle plus expressive, le flux autorégressif inverse.
L'inférence variationnelle est utilisée pour ajuster le modèle aux images binarisées des chiffres manuscrits MNIST. Un réseau d'inférence (encodeur) est utilisé pour amortir les paramètres d'inférence et de partage sur les points de données. La probabilité est paramétrée par un réseau génératif (décodeur).
Post de blog: https://jaan.io/what-is-variational-autoencoder-vae-titorial/
(L'environnement Anaconda est dans environment-jax.yml )
L'échantillonnage d'importance est utilisé pour estimer la probabilité marginale de l'ensemble de données binaire MNIST de Hugo LaRochelle. La probabilité marginale finale sur l'ensemble de tests était -97.10 NATS est comparable aux numéros publiés.
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0 Train ELBO estimate: -558.027 Validation ELBO estimate: -384.432 Validation log p(x) estimate: -355.430 Speed: 2.72e+06 examples/s
Step 10000 Train ELBO estimate: -111.323 Validation ELBO estimate: -109.048 Validation log p(x) estimate: -103.746 Speed: 2.64e+04 examples/s
Step 20000 Train ELBO estimate: -103.013 Validation ELBO estimate: -107.655 Validation log p(x) estimate: -101.275 Speed: 2.63e+04 examples/s
Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309
Total time: 2.49 minutes
En utilisant une approximation postérieure variationnelle non expressive et plus expressive (flux autorégressif inverse, https://arxiv.org/abs/1606.04934), le log-likelihain marginal de test s'améliore à -95.33 NATS:
$ python train_variational_autoencoder_pytorch.py --variational flow
step: 0 train elbo: -578.35
step: 0 valid elbo: -407.06 valid log p(x): -367.88
step: 10000 train elbo: -106.63
step: 10000 valid elbo: -110.12 valid log p(x): -104.00
step: 20000 train elbo: -101.51
step: 20000 valid elbo: -105.02 valid log p(x): -99.11
step: 30000 train elbo: -98.70
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
L'utilisation de Jax (ANAConda Environment est dans environment-jax.yml ), pour obtenir une accélération 3X sur Pytorch:
$ python train_variational_autoencoder_jax.py --variational mean-field
Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s
Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s
Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s
Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716
Total time: 0.810 minutes
Débit autorégressif inverse dans Jax:
$ python train_variational_autoencoder_jax.py --variational flow
Step 0 Train ELBO estimate: -727.404 Validation ELBO estimate: -726.977 Validation log p(x) estimate: -713.389 Speed: 2.56e+11 examples/s
Step 10000 Train ELBO estimate: -100.093 Validation ELBO estimate: -106.985 Validation log p(x) estimate: -99.565 Speed: 2.57e+04 examples/s
Step 20000 Train ELBO estimate: -113.073 Validation ELBO estimate: -108.057 Validation log p(x) estimate: -98.841 Speed: 3.37e+04 examples/s
Step 29999 Test ELBO estimate: -106.803 Test log p(x) estimate: -97.620
Total time: 2.350 minutes
(La différence entre un champ moyen et un flux autorégressif inverse peut être due à plusieurs facteurs, le chef étant le manque de convolutions dans la mise en œuvre. Les blocs résiduels sont utilisés dans https://arxiv.org/pdf/1606.04934.pdf pour rapprocher l'ELBO plus proche de -80 NATS.)
python train_variational_autoencoder_tensorflow.pyconvert -delay 20 -loop 0 *.jpg latent-space.gif