Dieses Repo enthält die offizielle Pytorch-Implementierung für die papiersternförmigen Denoising-Diffusion-probabilistischen Modelle-Ansatz zur Erstellung nicht gaußischer Diffusionsmodelle, die für verschiedene nicht-euklidische Verteiler anwendbar sind.
Von Andrey Okhotin*, Dmitry Molchanov*, Vladimir Arkhipkin, Grigory Bartosh, Viktor Ohanesian, Aibek Alanov, Dmitry Vetrov
Assistent: Sergei Kholkin

Denoising diffusion probabilistische Modelle (DDPMS) bilden die Grundlage für die jüngsten Durchbrüche bei der generativen Modellierung. Ihre markovsche Struktur macht es schwierig, DDPMs mit anderen Verteilungen als Gaußschen oder diskreten Verteilungen zu definieren. In diesem Artikel stellen wir sternförmige DDPM (SS-DDPM) vor. Der sternförmige Diffusionsprozess ermöglicht es uns, die Notwendigkeit zu umgehen, die Übergangswahrscheinlichkeiten zu definieren oder Posterioren zu berechnen. Wir stellen die Dualität zwischen sternförmigen und spezifischen Markovschen Diffusionen für die exponentielle Verteilungsfamilie fest und leiten effiziente Algorithmen für das Training und die Stichprobe von SS-DDPMs ab. Bei Gaußschen Verteilungen entspricht SS-DDPM DDPM. SS-DDPMs liefern jedoch ein einfaches Rezept für das Entwerfen von Diffusionsmodellen mit Verteilungen wie Beta, von Mises-Fisher, Dirichlet, Wishart und anderen, was besonders nützlich sein kann, wenn Daten auf einer eingeschränkten Mannigfaltigkeit liegen. Wir bewerten das Modell in verschiedenen Einstellungen und finden es auch bei Bilddaten, wobei Beta SS-DDPM Ergebnisse erzielt, die mit einem Gaußschen DDPM vergleichbar sind.
Haupt-SS-DDPM-Logik im Verzeichnis "lib/diffusion". Dies kann genug sein, wenn Sie wollen
Außerdem finden Sie Beispiele für die Verwendung von SS-DDPM für geodätische und synthetische Daten in Verzeichnis-Notizbüchern. Wenn Sie unsere Ergebnisse reproduzieren möchten, finden Sie Beispiele für Befehlsausführungen für Experimente zu CIFAR10 und Text8.
Repo -Struktur:
Dieses Repo wurde mit Torch == 1.12.0+CU113 TORCHVISION == 0.13.0+CU113 getestet
git clone https://github.com/andrey-okhotin/star-shaped
cd star-shaped
pip install -r requirements.txt
# only if you don't have pytorch or your pytorch version < 1.11
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# only for experiments with synthetic data, otherwise you can just comment all 'import npeet'
git clone https://github.com/gregversteeg/NPEET.git && cd NPEET && pip install . && cd ../ && rm -rf NPEETHerunterladen des Inhalts des Datasets -Ordners - für alle Pipelines erforderlich. Dieser Befehl kann ungefähr 5 Minuten dauern.
pip install py7zr gdown
rm -rf star-shaped/datasets
gdown --fuzzy https://drive.google.com/file/d/1ndXOmbNXR6pwoJ5qs1gVP0eAKU_RAl6E/view ? usp=sharing
py7zr x datasets.7z && rm datasets.7z && mv datasets star-shaped/datasetsHerunterladen des Inhalts des Ordners voraberhielt_models - Nicht erforderlich für Trainingspipelines. Dieser Befehl kann ungefähr 3 Minuten dauern.
pip install py7zr gdown
rm -rf star-shaped/pretrained_models
gdown --fuzzy https://drive.google.com/file/d/1Lebmsti31CwOFg4LYJYlWmlS7rGYQfVi/view ? usp=sharing
py7zr x pretrained_models.7z && rm pretrained_models.7z && mv pretrained_models star-shaped/pretrained_modelsErhältlich zum Ausführen bei Jupyter-Notebook im Verzeichnis ss_ddpm/Notebooks. Dort finden Sie Beispiele für Training und Probenahme für
Verfügbar für das Laufen bei Bash im Verzeichnis ss_ddpm
Ausführungsbefehl:
python lib/run_pipeline -gpu < gpu0_idx > _ < gpu1_idx > _ < gpu2_idx > -pipeline < pipeline_name > -logs_file < name_of_txt_file_to_write_execution_info > -port < available_port_for_processes_sync > . . . " other_pipeline_arguments "Ein kurzes Nutzungsbeispiel für das Ausführen von 3 GPUs auf einem einzelnen Knoten:
python lib/run_pipeline -gpu 0_1_2 -pipeline train_cifar10 -logs_file logs_train_cifar10.txt -port 8890 . . . " other_pipeline_arguments " Training Beta SS-DDPM auf 4 Nvidia v100 (benötigen ~ 32 GB GPU-Speicher). Kontrollpunkte werden im Verzeichnis "Checkpoints/Train_beta_SS_CIFAR10" gespeichert. Verlustgrafiken werden im Verzeichnis "Ergebnisse/Train_beta_ss_cifar10" gespeichert.
python lib/run_pipeline.py -gpu 0_1_2_3 -port 8900 -pipeline training_cifar10 -diffusion beta_ss -loss KL_rescaled -save_folder train_beta_ss_cifar10 -logs_file logs_training_beta_ss_cifar10.txt
cp checkpoints/training_beta_ss_cifar10/NCSNpp_episode0_epoch1050_model.pt pretrained_models/ncsnpp-cifar10_beta-ss.ptProbenahme Beta SS-DDPM auf 2 Nvidia V100. Die Ergebnisse werden im Verzeichnis "Ergebnisse/Sample_Beta_SS_CIFAR10/generated_samples" gespeichert.
python lib/run_pipeline.py -gpu 0_1 -port 8900 -pipeline sampling_cifar10 -diffusion beta_ss -num_sampling_steps 1000 -pretrained_model ncsnpp-cifar10_beta-ss.pt -num_samples 50000 -save_folder sampling_beta_ss_cifar10 -logs_file logs_sampling_beta_ss.txt
python -m pytorch_fid datasets/FID_cifar10_pack50000 results/sampling_beta_ss_cifar10/generated_samplesWenn Sie genau die gleichen Befehle ausführen, erhalten Sie FID ~ 3.24.
Trainingskategoriale SS-DDPM auf 4 Nvidia A100 (benötigen ~ 150 GB GPU-Speicher). Kontrollpunkte werden im Verzeichnis "Checkpoints/Training_Categorical_SS_TEXT8" gespeichert. Verlustgrafiken werden im Verzeichnis "Ergebnisse/Training_Categorical_SS_Text8" gespeichert.
python lib/run_pipeline.py -gpu 0_1_2_3 -port 8900 -pipeline training_text8 -diffusion categorical_ss -loss KL -save_folder training_categorical_ss_text8 -logs_file logs_training_categorical_ss.txt
cp checkpoints/training_categorical_ss_text8/T5Encoder_episode0_epoch2016_model.pt pretrained_models/t5base-text8_categorical-ss_fully-trained.ptSchätzung der NLL in kategorialen SS-DDPM auf 3 NVIDIA A100. Die Ergebnisse werden im Verzeichnis "Ergebnisse/nll_estimations" gespeichert.
python lib/run_pipeline.py -gpu 0_1_2 -port 8900 -pipeline estimating_nll_text8 -diffusion categorical_ss -pretrained_model t5base-text8_categorical-ss_fully-trained.pt -num_samples -1 -batch_size 1536 -dataset_part test -num_iwae_trajectories 1 -save_folder nll_text8_categorical-ss -logs_file logs_nll_text8_categorical_ss.txtWenn Sie genau die gleichen Befehle ausführen, erhalten Sie NLL ~ 1.61.
@ inproceedings { okhotin2023star ,
author = { Andrey Okhotin , Dmitry Molchanov , Vladimir Arkhipkin , Grigory Bartosh , Viktor Ohanesian , Aibek Alanov and Dmitry Vetrov },
title = { Star - Shaped Denoising Diffusion Probabilistic Models },
booktitle = { Advances in Neural Information Processing Systems },
volume = { 36 },
year = { 2023 }
}