Este repo contém uma implementação de Pytorch para a modelagem generativa baseada em pontuação em papel por meio de equações diferenciais estocásticas
Por Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon e Ben Poole
Propomos uma estrutura unificada que generaliza e melhora o trabalho anterior em modelos generativos baseados em pontuação através das lentes de equações diferenciais estocásticas (SDEs). Em particular, podemos transformar dados em uma simples distribuição de ruído com um processo estocástico de tempo contínuo descrito por um SDE. Este SDE pode ser revertido para geração de amostras se soubermos a pontuação das distribuições marginais em cada etapa intermediária de tempo, que pode ser estimada com a correspondência de pontuação. A ideia básica é capturada na figura abaixo:

Nosso trabalho permite uma melhor compreensão das abordagens existentes, novos algoritmos de amostragem, computação de probabilidade exata, codificação exclusivamente identificável, manipulação de código latente e traz novas habilidades de geração condicional (incluindo, mas não limitadas à geração condicional de classe, insignificante e colorização) para a família de modelos generativos baseados em pontuação.
Tudo combinado, alcançamos um FID de 2,20 e uma pontuação inicial de 9,89 para geração incondicional no CIFAR-10, bem como geração de alta fidelidade de imagens de 1024px celebrba-HQ (amostras abaixo). Além disso, obtivemos um valor de probabilidade de 2,99 bits/dim em imagens CIFAR-10 uniformemente desquantalizadas.

Além dos modelos NCSN ++ e DDPM ++ em nosso artigo, esta base de código também reimplementa muitos modelos anteriores baseados em pontuação em um só lugar, incluindo o NCSN da modelagem generativa estimando gradientes da distribuição de dados, NCSNV2 de modelos de probabilidade melhorados para treinamento de lentidão em treinamento e DDPM de denoising difusão de difusão.
Ele suporta o treinamento de novos modelos, avaliando a qualidade da amostra e as probabilidades dos modelos existentes. Projetamos cuidadosamente o código para ser modular e facilmente extensível para novos SDEs, preditores ou corretores.
A maioria dos modelos agora também está disponível? Difusores e acessíveis através do pipeline de pontuação.
Os difusores permitem testar modelos baseados em SDE em Pytorch em apenas algumas linhas de código.
Você pode instalar difusores da seguinte maneira:
pip install diffusers torch accelerate
E então experimente os modelos com apenas algumas linhas de código:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )Mais modelos podem ser encontrados diretamente no hub.
Encontre uma implementação do JAX aqui, que também suporta a geração condicional de classe com um classificador pré-treinado e retomando um processo de avaliação após a pré-emissão.
Em geral, esta versão pytorch consome menos memória, mas funciona mais devagar que o JAX. Aqui está uma referência no treinamento de um ncsn ++ cont. Modelo com ve SDE. Hardware é 4x nvidia tesla v100 gpus (32 GB)
| Estrutura | Tempo (segundo por etapa) | Uso de memória no total (GB) |
|---|---|---|
| Pytorch | 0,56 | 20.6 |
Jax ( n_jitted_steps=1 ) | 0,30 | 29.7 |
JAX ( n_jitted_steps=5 ) | 0,20 | 74.8 |
Execute o seguinte para instalar um subconjunto de pacotes python necessários para o nosso código
pip install -r requirements.txt Fornecemos o arquivo de estatísticas para o CIFAR-10. Você pode baixar cifar10_stats.npz e salvá -lo em assets/stats/ . Confira o #5 sobre como calcular este arquivo estatísticas para novos conjuntos de dados.
Treine e avalie nossos modelos através main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config é o caminho para o arquivo de configuração. Nossos arquivos de configuração prescritos são fornecidos em configs/ . Eles são formatados de acordo com ml_collections e devem ser bastante auto-explicativos.
Convenções de nomeação de arquivos de configuração : O caminho de um arquivo de configuração é uma combinação das seguintes dimensões:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq .ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp . workdir é o caminho que armazena todos os artefatos de um experimento, como pontos de verificação, amostras e resultados de avaliação.
eval_folder é o nome de uma subpasta no workdir que armazena todos os artefatos do processo de avaliação, como meta-controle para prevenção de pré-emenda, amostras de imagem e despejos de resultados quantitativos.
mode é "trem" ou "avaliação". Quando definido como "treinar", ele inicia o treinamento de um novo modelo ou retoma o treinamento de um modelo antigo se seus meta-verificações (para retomar a execução após a preferência em um ambiente em nuvem) existirem no workdir/checkpoints-meta . Quando definido como "avaliar", ele pode fazer uma combinação arbitrária do seguinte
Avalie a função de perda no conjunto de dados de teste / validação.
Gere um número fixo de amostras e calcule sua pontuação de início, fid ou criança. Antes da avaliação, os arquivos de estatísticas já devem ter sido baixados/calculados e armazenados em assets/stats .
Calcule a probabilidade de log no conjunto de dados de treinamento ou teste.
Essas funcionalidades podem ser configuradas através de arquivos de configuração, ou mais convenientemente, através do suporte da linha de comando do pacote ml_collections . Por exemplo, para gerar amostras e avaliar a qualidade da amostra, forneça o sinalizador --config.eval.enable_sampling ; Para calcular as probabilidades de log, forneça o sinalizador --config.eval.enable_bpd e especifique --config.eval.dataset=train/test para indicar se deve calcular as probabilidades no conjunto de dados de treinamento ou teste.
sde_lib.SDE Resumo e implemente todos os métodos abstratos. O método discretize() é opcional e o padrão é a discretização de Euler-Maruyama. Os métodos de amostragem existentes e a computação de probabilidade funcionarão automaticamente para este novo SDE.@register_predictor preditores : update_fn the sampling.Predictor . O novo preditor pode ser usado diretamente no sampling.get_pc_sampler para amostragem preditora-corretor e todos os outros métodos de geração controlável em controllable_generation.py .sampling.Corrector Abstract, implemente o método de abstração update_fn e registre seu nome com @register_corrector . O novo corretor pode ser usado diretamente no sampling.get_pc_sampler e todos os outros métodos de geração controlável em controllable_generation.py . Todos os pontos de verificação são fornecidos nesta unidade do Google.
Instruções : Você pode encontrar dois pontos de verificação para alguns modelos. O primeiro ponto de verificação (com um número menor) é o que relatamos as pontuações do FID na Tabela 3 do nosso papel (também correspondendo ao FID e é colunas na tabela abaixo). O segundo ponto de verificação (com um número maior) é o que relatamos valores de probabilidade e fids de amostradores de Ode de Black Box na Tabela 2 do nosso artigo (também FID (ODE) e NNL (BITS/DIM) na tabela abaixo). O primeiro corresponde ao menor FID durante o curso do treinamento (a cada 50k iterações). O posterior é o último ponto de verificação durante o treinamento.
De acordo com a política do Google, não podemos lançar nossos pontos de verificação originais de Celeba e Celeba-HQ. Dito isto, eu treinei os modelos no FFHQ 1024PX, FFHQ 256PX e CELEBA-HQ 256PX com recursos pessoais, e eles alcançaram desempenho semelhante aos nossos pontos de verificação internos.
Aqui está uma lista detalhada de pontos de verificação e seus resultados relatados no artigo. O FID (ODE) corresponde à qualidade da amostra do solucionador de Ode de Black Box aplicado à Ode de Fluxo de Probabilidade.
| Caminho do ponto de verificação | Fid | É | Fid (ODE) | Nnl (bits/dim) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| Caminho do ponto de verificação | Amostras |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| Link | Descrição |
|---|---|
| Carregue nossos pontos de verificação pré -ridicularizados e brinque com amostragem, computação de probabilidade e síntese controlável (Jax + Flax) | |
| Carregue nossos pontos de verificação pré -treinamento e brinque com amostragem, computação de probabilidade e síntese controlável (Pytorch) | |
| Tutorial de modelos generativos baseados em pontuação em Jax + linho | |
| Tutorial de modelos generativos baseados em pontuação em Pytorch |
config.training.n_jitted_steps . Para o CIFAR-10, recomendamos o uso de config.training.n_jitted_steps=5 quando o seu GPU/TPU tem memória suficiente; Caso contrário, recomendamos o uso de config.training.n_jitted_steps=1 . Nossa implementação atual exige que config.training.log_freq seja dividido por n_jitted_steps para registro e check -inging para funcionar normalmente.snr (relação sinal-ruído) do LangevinCorrector se comporta um pouco como um parâmetro de temperatura. snr maior geralmente resulta em amostras mais suaves, enquanto snr menor fornece amostras de qualidade mais diversas, mas de menor qualidade. Os valores típicos do snr são de 0.05 - 0.2 , e requer ajuste para atingir o ponto ideal.config.model.sigma_max para ser a distância máxima em pares entre as amostras de dados no conjunto de dados de treinamento. Se você achar o código útil para sua pesquisa, considere citar
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}Este trabalho é construído sobre alguns documentos anteriores que também podem lhe interessar: