이 repo에는 확률 적 미분 방정식을 통한 종이 점수 기반 생성 모델링에 대한 Pytorch 구현이 포함되어 있습니다.
양 노래, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon 및 Ben Poole
우리는 확률 론적 미분 방정식 (SDE)의 렌즈를 통해 점수 기반 생성 모델에 대한 이전 작업을 일반화하고 개선하는 통합 프레임 워크를 제안합니다. 특히, 우리는 SDE가 설명하는 연속 시간 확률 프로세스로 데이터를 간단한 노이즈 분포로 변환 할 수 있습니다. 이 SDE는 각 중간 시간 단계에서 한계 분포의 점수를 알면 샘플 생성을 위해 반전 될 수 있으며, 이는 점수 일치로 추정 될 수 있습니다. 기본 아이디어는 아래 그림에서 캡처됩니다.

우리의 작업은 기존 접근법, 새로운 샘플링 알고리즘, 정확한 가능성 계산, 고유하게 식별 가능한 인코딩, 잠재적 코드 조작 및 새로운 조건부 생성 능력 (클래스 조건 생성, 수입 및 채색에 국한되지 않음)을 점수 기반 생성 모델의 패밀리에 대한 이해를 높일 수 있습니다.
모두 CIFAR-10에서 무조건 생성에 대해 2.20 의 FID와 Inception 점수 9.89 , 1024px eleba-HQ 이미지의 고 충실도 생성 (아래 샘플)을 달성했습니다. 또한, 우리는 균일하게 해제 된 CIFAR-10 이미지에서 2.99 비트/딤의 우도 값을 얻었습니다.

이 논문의 NCSN ++ 및 DDPM ++ 모델 외에도,이 코드베이스는 또한 데이터 배포의 그라디언트를 추정하여 NCSNV2를 비난하여 DDPM 에서 NCSNV2 를 추정하여 생성 모델링에서 NCSN을 포함하여 많은 이전 점수 기반 모델을 다시 이식하고, 확산 확산 모델을 비난합니다.
기존 모델의 샘플 품질과 가능성을 평가하는 새로운 모델 교육을 지원합니다. 우리는 새로운 SDE, 예측 변수 또는 교정기가 모듈화되고 쉽게 확장 가능하도록 코드를 신중하게 설계했습니다.
대부분의 모델도 사용할 수 있습니까? Scoresdeve 파이프 라인을 통해 디퓨저 및 적합성.
Diffusers를 사용하면 Pytorch의 SDE 기반 모델을 몇 줄의 코드로 테스트 할 수 있습니다.
다음과 같이 디퓨저를 설치할 수 있습니다.
pip install diffusers torch accelerate
그런 다음 몇 줄의 코드로 모델을 사용해보십시오.
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )더 많은 모델이 허브에서 직접 찾을 수 있습니다.
여기에서 JAX 구현을 찾으십시오. 여기에서 예비 훈련 된 분류기를 사용한 클래스 조건 생성을 추가로 지원하고 선점 후 평가 프로세스를 재개하십시오.
일반적 으로이 Pytorch 버전은 메모리를 적게 소비하지만 JAX보다 느리게 실행됩니다. 다음은 NCSN ++ CONT를 훈련시키는 벤치 마크입니다. VE SDE를 가진 모델. 하드웨어는 4X NVIDIA TESLA V100 GPUS (32GB)입니다.
| 뼈대 | 시간 (단계당 두 번째) | 총계의 메모리 사용 (GB) |
|---|---|---|
| Pytorch | 0.56 | 20.6 |
jax ( n_jitted_steps=1 ) | 0.30 | 29.7 |
jax ( n_jitted_steps=5 ) | 0.20 | 74.8 |
코드에 필요한 Python 패키지의 서브 세트를 설치하려면 다음을 실행하십시오.
pip install -r requirements.txt CIFAR-10 용 통계 파일을 제공합니다. cifar10_stats.npz 다운로드하여 assets/stats/ 에 저장할 수 있습니다. 새 데이터 세트에 대한이 통계 파일을 계산하는 방법에 대해서는 #5를 확인하십시오.
main.py 통해 모델을 훈련시키고 평가하십시오.
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config 구성 파일의 경로입니다. 처방 된 구성 파일은 configs/ 에서 제공됩니다. 그것들은 ml_collections 에 따라 형식화되며 상당히 자기 설명이어야합니다.
구성 파일의 명명 : 구성 파일의 경로는 다음 차원의 조합입니다.
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhq 중 하나입니다.ncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp 중 하나입니다. workdir 체크 포인트, 샘플 및 평가 결과와 같은 한 실험의 모든 인공물을 저장하는 경로입니다.
eval_folder 선점 예방, 이미지 샘플 및 양적 결과의 Numpy 덤프를위한 메타 검문소와 같은 평가 프로세스의 모든 인공물을 저장하는 workdir 의 하위 폴더의 이름입니다.
mode 는 "Train"또는 "Eval"입니다. "Train"으로 설정되면 새로운 모델의 교육을 시작하거나 메타 체크 포인트 (클라우드 환경에서 선점 후 재개)가 workdir/checkpoints-meta 에 존재하는 경우 이전 모델의 교육을 재개합니다. "평가"로 설정하면 다음을 임의의 조합을 수행 할 수 있습니다.
테스트 / 검증 데이터 세트에서 손실 함수를 평가하십시오.
고정 된 수의 샘플을 생성하고 시작 점수, FID 또는 KID를 계산합니다. 평가 전에 STATS 파일은 이미 assets/stats 에 다운로드/계산 및 저장되어야합니다.
훈련 또는 테스트 데이터 세트에서 로그 유효성을 계산하십시오.
이러한 기능은 ml_collections 패키지의 명령 줄 지원을 통해 구성 파일을 통해 구성 할 수 있습니다. 예를 들어, 샘플을 생성하고 샘플 품질을 평가하려면 --config.eval.enable_sampling 플래그를 공급하십시오. 로그 likelihoods를 계산하려면 --config.eval.enable_bpd 플래그를 제공하고 --config.eval.dataset=train/test 지정하여 교육 또는 테스트 데이터 세트의 가능성을 계산할지 여부를 지정하십시오.
sde_lib.SDE 초록 클래스 내재 및 모든 추상 방법을 구현합니다. discretize() 메소드는 선택 사항이고 기본값은 Euler-Maruyama 이산화입니다. 기존 샘플링 방법과 가능성 계산은이 새로운 SDE에 자동으로 작동합니다.sampling.Predictor 내재 포기 요약 클래스, update_fn 초록 메소드를 구현하고 @register_predictor 에 이름을 등록하십시오. 새로운 예측 변수는 sampling.get_pc_sampler 에 직접 사용할 수 있으며, 예측 변수-조정기 샘플링을위한 샘플링 .get_pc_sampler 및 controllable_generation.py 의 기타 모든 제어 가능한 생성 방법.sampling.Corrector 내재 된 Corrector Abstract 클래스, update_fn Abstract 메소드를 구현하고 @register_corrector 에 이름을 등록하십시오. 새로운 교정기는 sampling.get_pc_sampler 에 직접 사용할 수 있으며 controllable_generation.py 에서 기타 모든 제어 가능한 생성 방법. 모든 체크 포인트는이 Google 드라이브에서 제공됩니다.
지침 : 일부 모델에 대해 두 개의 체크 포인트를 찾을 수 있습니다. 첫 번째 체크 포인트 (더 적은 숫자)는 논문의 표 3에서 FID 점수를보고 한 것입니다 (FID에도 해당하고 아래 표에있는 열). 두 번째 체크 포인트 (더 큰 숫자)는 논문의 표 2 (또한 아래 표에있는 FID (ODE) 및 NNL (비트/딤) 열에서 블랙 박스 ODE 샘플러의 가능성 값과 FID를보고 한 것입니다. 전자는 훈련 과정에서 가장 작은 FID에 해당합니다 (50k 반복마다). 이후는 훈련 중 마지막 체크 포인트입니다.
Google의 정책에 따라 원래 Celeba 및 Celeba-HQ 체크 포인트를 공개 할 수 없습니다. 즉, 개인 리소스를 갖춘 FFHQ 1024PX, FFHQ 256PX 및 Celeba-HQ 256PX에서 모델을 재 훈련 한 상태에서 내부 체크 포인트와 유사한 성능을 달성했습니다.
다음은 검문소에 대한 자세한 목록과 논문에보고 된 결과입니다. FID (ODE)는 확률 흐름에 적용되는 블랙 박스 ODE 솔버의 샘플 품질에 해당합니다.
| 체크 포인트 경로 | 버팀대 | 이다 | FID (ODE) | NNL (비트/딤) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| 체크 포인트 경로 | 샘플 |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| 링크 | 설명 |
|---|---|
| 사전 간 체크 포인트를로드하고 샘플링, 가능성 계산 및 제어 가능한 합성 (JAX + Flax)으로 재생하십시오. | |
| 사전에 미리 검사 점을로드하고 샘플링, 가능성 계산 및 제어 가능한 합성 (Pytorch)으로 재생하십시오. | |
| Jax + Flax의 점수 기반 생성 모델 튜토리얼 | |
| Pytorch의 점수 기반 생성 모델 튜토리얼 |
config.training.n_jitted_steps 통해 설정할 수 있습니다. CIFAR-10의 경우 GPU/TPU가 충분한 메모리를 가지고있을 때 config.training.n_jitted_steps=5 사용하는 것이 좋습니다. 그렇지 않으면 config.training.n_jitted_steps=1 사용하는 것이 좋습니다. 현재 구현에는 config.training.log_freq 정상적으로 작동하도록 n_jitted_steps 에 의해 나눌 수 있어야합니다.LangevinCorrector 의 snr (신호 대 잡음비) 매개 변수는 온도 매개 변수처럼 다소 동작합니다. 더 큰 snr 일반적으로 더 부드러운 샘플을 초래하는 반면, 더 작은 snr 더 다양하지만 품질이 낮은 샘플을 제공합니다. snr 의 일반적인 값은 0.05 - 0.2 이며 달콤한 지점을 치기 위해 튜닝이 필요합니다.config.model.sigma_max 교육 데이터 세트의 데이터 샘플 사이의 최대 쌍별 거리로 선택하는 것이 좋습니다. 연구에 유용한 코드를 찾으면 인용을 고려하십시오.
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}이 작업은 귀하에게 관심을 가질 수있는 이전 논문에 기반을두고 있습니다.