Это реализация Pytorch квантового вариационного автоэкодора (https://arxiv.org/abs/1711.00937).
Вы можете найти оригинальную реализацию автора в Tensorflow здесь с примером, который вы можете запустить в ноутбуке Jupyter.
Чтобы установить зависимости, создайте Conda или виртуальную среду с Python 3, а затем запустите pip install -r requirements.txt .
Чтобы запустить vq-vae, просто запустите python3 main.py Обязательно включите флаг -save , если вы хотите сохранить свою модель. Вы также можете добавить параметры в командной строке. Значения по умолчанию указаны ниже:
parser . add_argument ( "--batch_size" , type = int , default = 32 )
parser . add_argument ( "--n_updates" , type = int , default = 5000 )
parser . add_argument ( "--n_hiddens" , type = int , default = 128 )
parser . add_argument ( "--n_residual_hiddens" , type = int , default = 32 )
parser . add_argument ( "--n_residual_layers" , type = int , default = 2 )
parser . add_argument ( "--embedding_dim" , type = int , default = 64 )
parser . add_argument ( "--n_embeddings" , type = int , default = 512 )
parser . add_argument ( "--beta" , type = float , default = .25 )
parser . add_argument ( "--learning_rate" , type = float , default = 3e-4 )
parser . add_argument ( "--log_interval" , type = int , default = 50 )VQ VAE имеет следующие фундаментальные компоненты модели:
Encoder , который определяет карту x -> z_eVectorQuantizer , который преобразует вывод энкодера в дискретный вектор с одним горячим, который является индексом ближайшего вектора встраивания z_e -> z_qDecoder , который определяет карту z_q -> x_hat и реконструирует исходное изображение Классы энкодера / декодера представляют собой сверточные и обратные сверточные стеки, которые включают остаточные блоки в их архитектуре, см. Resnet Paper. Остатовые модели определяются классами ResidualLayer и ResidualStack .
Эти компоненты организованы в следующей структуре папок:
models/
- decoder.py -> Decoder
- encoder.py -> Encoder
- quantizer.py -> VectorQuantizer
- residual.py -> ResidualLayer, ResidualStack
- vqvae.py -> VQVAE
Чтобы попробовать из скрытого пространства, мы устанавливаем Pixelcnn по скрытым значениям пикселя z_ij . Хитрость здесь заключается в том, что VQ VQ VAE отображает изображение в скрытое пространство, которое имеет ту же структуру, что и изображение 1 канала. Например, если вы запустите параметры VQ VQ VQ VQ VQ VAE, вы будете карты RGB изображения формы (32,32,3) в скрытое пространство с формой (8,8,1) , что эквивалентно изображению серого масштаба 8x8. Следовательно, вы можете использовать Pixelcnn, чтобы соответствовать распределению по значениям «пикселя» 1-канального скрытого пространства 8x8.
Чтобы тренировать Pixelcnn на скрытые представления, вам сначала нужно выполнить эти шаги:
np.save API. В quantizer.py это переменная min_encoding_indices .utils.load_latent_block FUNCTION.Чтобы запустить Pixelcnn, просто введите
python pixelcnn/gated_pixelcnn.py
а также любые параметры (см. Заявления Argparse). Набор данных по умолчанию - LATENT_BLOCK , который будет работать только в том случае, если вы обучили свой vq vae и сохранили скрытые представления.