Este repositorio contiene la implementación oficial de Pytorch para los modelos probabilísticos de difusión de difusión en forma de estrella en forma de estrella, enfoque para crear modelos de difusión no gaussianos aplicables a varios colectores no euclidianos.
Por Andrey Okhotin*, Dmitry Molchanov*, Vladimir Arkhipkin, Grigory Bartosh, Viktor Ohanesian, Aibek Alanov, Dmitry Vetrov
Asistente: Sergei Kholkin

Los modelos probabilísticos de difusión de Denoising (DDPMS) proporcionan la base para los avances recientes en el modelado generativo. Su estructura de Markovian dificulta la definición de DDPM con distribuciones distintas de las gaussianas o discretas. En este artículo, presentamos DDPM en forma de estrella (SS-DDPM). Su proceso de difusión en forma de estrella nos permite evitar la necesidad de definir las probabilidades de transición o calcular los posteriores. Establecemos la dualidad entre las difusiones de Markovian en forma de estrella y específicas para la familia exponencial de distribuciones, y obtenemos algoritmos eficientes para el entrenamiento y el muestreo de SS-DDPMS. En el caso de las distribuciones gaussianas, SS-DDPM es equivalente a DDPM. Sin embargo, los SS-DDPM proporcionan una receta simple para diseñar modelos de difusión con distribuciones como Beta, Von Mises: Fisher, Dirichlet, Wishart y otros, que pueden ser especialmente útiles cuando los datos se encuentran en un colector limitado. Evaluamos el modelo en diferentes entornos y lo encontramos competitivo incluso en los datos de imágenes, donde Beta SS-DDPM logra resultados comparables a un DDPM gaussiano.
La lógica principal SS-DDPM descrita en el directorio "lib/difusión". Esto puede ser suficiente si quieres
También puede encontrar ejemplos del uso de SS-DDPM en datos geodésicos y sintéticos en "cuadernos" del directorio. Si desea reproducir nuestros resultados, puede encontrar ejemplos de ejecuciones de comandos para experimentos en CIFAR10 y Text8.
Estructura de repo:
Este repositorio probado con antorcha == 1.12.0+CU113 TorchVision == 0.13.0+CU113
git clone https://github.com/andrey-okhotin/star-shaped
cd star-shaped
pip install -r requirements.txt
# only if you don't have pytorch or your pytorch version < 1.11
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# only for experiments with synthetic data, otherwise you can just comment all 'import npeet'
git clone https://github.com/gregversteeg/NPEET.git && cd NPEET && pip install . && cd ../ && rm -rf NPEETDescarga de contenido de la carpeta de conjuntos de datos : necesario para todas las tuberías. Este comando puede tomar unos 5 minutos.
pip install py7zr gdown
rm -rf star-shaped/datasets
gdown --fuzzy https://drive.google.com/file/d/1ndXOmbNXR6pwoJ5qs1gVP0eAKU_RAl6E/view ? usp=sharing
py7zr x datasets.7z && rm datasets.7z && mv datasets star-shaped/datasetsDescarga de contenido de la carpeta Pretureny_Models : no es necesario para capacitar tuberías. Este comando puede tomar unos 3 minutos.
pip install py7zr gdown
rm -rf star-shaped/pretrained_models
gdown --fuzzy https://drive.google.com/file/d/1Lebmsti31CwOFg4LYJYlWmlS7rGYQfVi/view ? usp=sharing
py7zr x pretrained_models.7z && rm pretrained_models.7z && mv pretrained_models star-shaped/pretrained_modelsDisponible para correr desde Jupyter-Notebook en el directorio SS_DDPM/Notebooks. Allí puede encontrar ejemplos de capacitación y muestreo para
Disponible para correr desde Bash en el directorio SS_DDPM
Ejecutando el comando:
python lib/run_pipeline -gpu < gpu0_idx > _ < gpu1_idx > _ < gpu2_idx > -pipeline < pipeline_name > -logs_file < name_of_txt_file_to_write_execution_info > -port < available_port_for_processes_sync > . . . " other_pipeline_arguments "Ejemplo de uso corto para ejecutar en 3 GPU en un solo nodo:
python lib/run_pipeline -gpu 0_1_2 -pipeline train_cifar10 -logs_file logs_train_cifar10.txt -port 8890 . . . " other_pipeline_arguments " Entrenamiento Beta SS-DDPM en 4 NVIDIA V100 (NECESITA ~ 32 GB de memoria GPU). Los puntos de control se guardarán en el directorio "Puntos de control/TRAIN_BETA_SS_CIFAR10". Los gráficos de pérdida se guardarán en el directorio "Resultados/trenes_beta_ss_cifar10".
python lib/run_pipeline.py -gpu 0_1_2_3 -port 8900 -pipeline training_cifar10 -diffusion beta_ss -loss KL_rescaled -save_folder train_beta_ss_cifar10 -logs_file logs_training_beta_ss_cifar10.txt
cp checkpoints/training_beta_ss_cifar10/NCSNpp_episode0_epoch1050_model.pt pretrained_models/ncsnpp-cifar10_beta-ss.ptMuestreo beta ss-ddpm en 2 nvidia v100. Los resultados se guardarán en el directorio "Resultados/muestras_beta_ss_cifar10/generado_samples".
python lib/run_pipeline.py -gpu 0_1 -port 8900 -pipeline sampling_cifar10 -diffusion beta_ss -num_sampling_steps 1000 -pretrained_model ncsnpp-cifar10_beta-ss.pt -num_samples 50000 -save_folder sampling_beta_ss_cifar10 -logs_file logs_sampling_beta_ss.txt
python -m pytorch_fid datasets/FID_cifar10_pack50000 results/sampling_beta_ss_cifar10/generated_samplesSi ejecuta exactamente los mismos comandos, obtendrá FID ~ 3.24.
Capacitación SS-DDPM categórica en 4 NVIDIA A100 (NECESITA ~ 150GB MEMORIA DE GPU). Los puntos de control se guardarán en el directorio "Puntos de control/entrenamiento_categorical_ss_text8". Los gráficos de pérdida se guardarán en el directorio "Resultados/entrenamiento_categorical_ss_text8".
python lib/run_pipeline.py -gpu 0_1_2_3 -port 8900 -pipeline training_text8 -diffusion categorical_ss -loss KL -save_folder training_categorical_ss_text8 -logs_file logs_training_categorical_ss.txt
cp checkpoints/training_categorical_ss_text8/T5Encoder_episode0_epoch2016_model.pt pretrained_models/t5base-text8_categorical-ss_fully-trained.ptEstimación de NLL en SS-DDPM categórico en 3 NVIDIA A100. Los resultados se guardarán en el directorio "Resultados/NLL_estimations".
python lib/run_pipeline.py -gpu 0_1_2 -port 8900 -pipeline estimating_nll_text8 -diffusion categorical_ss -pretrained_model t5base-text8_categorical-ss_fully-trained.pt -num_samples -1 -batch_size 1536 -dataset_part test -num_iwae_trajectories 1 -save_folder nll_text8_categorical-ss -logs_file logs_nll_text8_categorical_ss.txtSi ejecuta exactamente los mismos comandos, obtendrá NLL ~ 1.61.
@ inproceedings { okhotin2023star ,
author = { Andrey Okhotin , Dmitry Molchanov , Vladimir Arkhipkin , Grigory Bartosh , Viktor Ohanesian , Aibek Alanov and Dmitry Vetrov },
title = { Star - Shaped Denoising Diffusion Probabilistic Models },
booktitle = { Advances in Neural Information Processing Systems },
volume = { 36 },
year = { 2023 }
}