Implementasi referensi untuk autoencoder variasional di TensorFlow dan Pytorch.
Saya merekomendasikan versi Pytorch. Ini termasuk contoh keluarga variasional yang lebih ekspresif, aliran autoregresif terbalik.
Inferensi variasional digunakan agar sesuai dengan model dengan gambar digit tulisan tangan mnist binarized. Jaringan inferensi (encoder) digunakan untuk mengamortisasi inferensi dan berbagi parameter di seluruh titik data. Kemungkinan diparameterisasi oleh jaringan generatif (decoder).
Posting Blog: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
(Lingkungan Anaconda ada di environment-jax.yml )
Sampling penting digunakan untuk memperkirakan kemungkinan marjinal pada dataset mnist biner Hugo Larochelle. Kemungkinan marjinal akhir pada set tes adalah -97.10 NATS sebanding dengan angka yang diterbitkan.
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0 Train ELBO estimate: -558.027 Validation ELBO estimate: -384.432 Validation log p(x) estimate: -355.430 Speed: 2.72e+06 examples/s
Step 10000 Train ELBO estimate: -111.323 Validation ELBO estimate: -109.048 Validation log p(x) estimate: -103.746 Speed: 2.64e+04 examples/s
Step 20000 Train ELBO estimate: -103.013 Validation ELBO estimate: -107.655 Validation log p(x) estimate: -101.275 Speed: 2.63e+04 examples/s
Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309
Total time: 2.49 minutes
Menggunakan bidang-bidang non-rata-rata, perkiraan posterior variasional yang lebih ekspresif (aliran autoregresif terbalik, https://arxiv.org/abs/1606.04934), tes marginal log-likelihood meningkat menjadi -95.33 NATS:
$ python train_variational_autoencoder_pytorch.py --variational flow
step: 0 train elbo: -578.35
step: 0 valid elbo: -407.06 valid log p(x): -367.88
step: 10000 train elbo: -106.63
step: 10000 valid elbo: -110.12 valid log p(x): -104.00
step: 20000 train elbo: -101.51
step: 20000 valid elbo: -105.02 valid log p(x): -99.11
step: 30000 train elbo: -98.70
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
Menggunakan JAX (Lingkungan Anaconda berada di environment-jax.yml ), untuk mendapatkan speedup 3x di atas Pytorch:
$ python train_variational_autoencoder_jax.py --variational mean-field
Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s
Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s
Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s
Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716
Total time: 0.810 minutes
Aliran Autoregresif Terbalik di Jax:
$ python train_variational_autoencoder_jax.py --variational flow
Step 0 Train ELBO estimate: -727.404 Validation ELBO estimate: -726.977 Validation log p(x) estimate: -713.389 Speed: 2.56e+11 examples/s
Step 10000 Train ELBO estimate: -100.093 Validation ELBO estimate: -106.985 Validation log p(x) estimate: -99.565 Speed: 2.57e+04 examples/s
Step 20000 Train ELBO estimate: -113.073 Validation ELBO estimate: -108.057 Validation log p(x) estimate: -98.841 Speed: 3.37e+04 examples/s
Step 29999 Test ELBO estimate: -106.803 Test log p(x) estimate: -97.620
Total time: 2.350 minutes
(Perbedaan antara bidang rata -rata dan aliran autoregresif terbalik mungkin disebabkan oleh beberapa faktor, kepala adalah kurangnya konvolusi dalam implementasi. Blok residu digunakan dalam https://arxiv.org/pdf/1606.04934.pdf untuk mendapatkan ELBO lebih dekat ke -80 NATS.)
python train_variational_autoencoder_tensorflow.pyconvert -delay 20 -loop 0 *.jpg latent-space.gif