repo นี้มีการใช้งาน pytorch สำหรับการสร้างแบบจำลองการสร้างคะแนนกระดาษผ่านสมการเชิงอนุพันธ์สุ่ม
โดย Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon และ Ben Poole
เราเสนอเฟรมเวิร์กแบบครบวงจรที่สรุปและปรับปรุงงานก่อนหน้านี้เกี่ยวกับแบบจำลองการกำเนิดที่ใช้คะแนนผ่านเลนส์ของสมการเชิงอนุพันธ์สุ่ม (SDEs) โดยเฉพาะอย่างยิ่งเราสามารถแปลงข้อมูลเป็นการกระจายเสียงรบกวนอย่างง่ายด้วยกระบวนการสุ่มเวลาต่อเนื่องที่อธิบายโดย SDE SDE นี้สามารถย้อนกลับสำหรับการสร้างตัวอย่างได้หากเรารู้คะแนนของการแจกแจงส่วนเพิ่มในแต่ละขั้นตอนเวลากลางซึ่งสามารถประเมินได้ด้วยการจับคู่คะแนน แนวคิดพื้นฐานถูกจับในรูปด้านล่าง:

งานของเราช่วยให้เข้าใจวิธีการที่มีอยู่อัลกอริธึมการสุ่มตัวอย่างใหม่การคำนวณความน่าจะเป็นที่แน่นอนการเข้ารหัสที่สามารถระบุตัวตนได้โดยเฉพาะการจัดการรหัสแฝงและนำความสามารถในการสร้างแบบมีเงื่อนไขใหม่ (รวมถึง แต่ไม่ จำกัด เฉพาะ
ทั้งหมดรวมกันเราได้รับ FID 2.20 และคะแนนเริ่มต้นที่ 9.89 สำหรับการสร้างที่ไม่มีเงื่อนไขใน CIFAR-10 รวมถึงการสร้างความเที่ยงตรงสูงของภาพ 1024px Celeba-HQ (ตัวอย่างด้านล่าง) นอกจากนี้เรายังได้รับค่าความน่าจะเป็น 2.99 บิต/สลัวในภาพ CIFAR-10 ที่ลดลงอย่างสม่ำเสมอ

นอกเหนือจากโมเดล NCSN ++ และ DDPM ++ ในกระดาษของเราแล้ว codebase นี้ยังนำเสนอโมเดลที่ใช้คะแนนก่อนหน้านี้อีกครั้ง ใน ที่เดียวรวมถึง NCSN จากการสร้างแบบจำลองการกำเนิดโดยการประเมินการไล่ระดับสีของการกระจายข้อมูล NCSNV2 จากเทคนิคที่ได้รับการปรับปรุง
รองรับการฝึกอบรมโมเดลใหม่ประเมินคุณภาพตัวอย่างและโอกาสในการใช้งานของโมเดลที่มีอยู่ เราออกแบบรหัสให้เป็นแบบแยกส่วนและขยายได้อย่างง่ายดายสำหรับ SDE ใหม่ตัวทำนายหรือตัวแก้ไข
ตอนนี้รุ่นส่วนใหญ่มีอยู่ในตอนนี้? diffusers และ accesible ผ่านไปป์ไลน์ Scoresdeve
diffusers ช่วยให้คุณทดสอบโมเดลที่ใช้คะแนน SDE ใน Pytorch ในรหัสเพียงไม่กี่บรรทัด
คุณสามารถติดตั้ง diffusers ดังนี้:
pip install diffusers torch accelerate
จากนั้นลองใช้โมเดลที่มีรหัสเพียงไม่กี่บรรทัด:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )รุ่นเพิ่มเติมสามารถพบได้โดยตรงบนฮับ
โปรดค้นหาการใช้งาน JAX ที่นี่ซึ่งรองรับการสร้างชั้นเรียนแบบคลาสด้วยตัวจําแนกที่ผ่านการฝึกอบรมมาก่อนและกลับมาดำเนินการประเมินผลหลังจากการจองล่วงหน้า
โดยทั่วไปรุ่น pytorch นี้ใช้หน่วยความจำน้อยกว่า แต่ทำงานช้ากว่า Jax นี่คือเกณฑ์มาตรฐานในการฝึกอบรม NCSN ++ ต่อ โมเดลกับ VE SDE ฮาร์ดแวร์คือ 4x Nvidia Tesla V100 GPU (32GB)
| กรอบ | เวลา (ที่สองต่อขั้นตอน) | การใช้หน่วยความจำทั้งหมด (GB) |
|---|---|---|
| pytorch | 0.56 | 20.6 |
jax ( n_jitted_steps=1 ) | 0.30 | 29.7 |
jax ( n_jitted_steps=5 ) | 0.20 | 74.8 |
รันต่อไปนี้เพื่อติดตั้งชุดย่อยของแพ็คเกจ Python ที่จำเป็นสำหรับรหัสของเรา
pip install -r requirements.txt เราให้ไฟล์สถิติสำหรับ CIFAR-10 คุณสามารถดาวน์โหลด cifar10_stats.npz และบันทึกลงใน assets/stats/ ตรวจสอบ #5 เกี่ยวกับวิธีการคำนวณไฟล์สถิตินี้สำหรับชุดข้อมูลใหม่
ฝึกอบรมและประเมินแบบจำลองของเราผ่าน main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config เป็นพา ธ ไปยังไฟล์ config ไฟล์กำหนดค่าที่กำหนดของเรามีให้ใน configs/ พวกเขาจะถูกจัดรูปแบบตาม ml_collections และควรอธิบายตนเอง
การตั้งชื่อการประชุมของไฟล์กำหนดค่า : พา ธ ของไฟล์กำหนดค่าคือการรวมกันของขนาดต่อไปนี้:
cifar10 , celeba , celebahq , celebahq_256 , ffhq_256 , celebahq , ffhqncsn , ncsnv2 , ncsnpp , ddpm , ddpmpp workdir เป็นเส้นทางที่เก็บสิ่งประดิษฐ์ทั้งหมดของการทดลองเดียวเช่นจุดตรวจตัวอย่างและผลการประเมินผล
eval_folder เป็นชื่อของโฟลเดอร์ย่อยใน workdir ที่เก็บสิ่งประดิษฐ์ทั้งหมดของกระบวนการประเมินเช่นจุดตรวจ Meta สำหรับการป้องกันการปล่อยล่วงหน้าตัวอย่างภาพและการทิ้งผลเชิงปริมาณ
mode คือ "รถไฟ" หรือ "ประเมิน" เมื่อตั้งค่าเป็น "รถไฟ" มันจะเริ่มการฝึกอบรมรุ่นใหม่หรือดำเนินการฝึกอบรมแบบจำลองเก่าหากมีการตรวจสอบเมตา (สำหรับการทำงานต่อหลังจากการชำระเงินล่วงหน้าในสภาพแวดล้อมคลาวด์) มีอยู่ใน workdir/checkpoints-meta เมื่อตั้งค่าเป็น "eval" มันสามารถทำการรวมกันโดยพลการของสิ่งต่อไปนี้
ประเมินฟังก์ชั่นการสูญเสียในชุดข้อมูลการทดสอบ / การตรวจสอบ
สร้างตัวอย่างจำนวนคงที่และคำนวณคะแนนเริ่มต้น FID หรือ KID ก่อนการประเมินผลไฟล์สถิติจะต้องได้รับการดาวน์โหลด/คำนวณและเก็บไว้ใน assets/stats แล้ว
คำนวณความน่าจะเป็นบันทึกการฝึกซ้อมหรือชุดข้อมูลทดสอบ
ฟังก์ชันเหล่านี้สามารถกำหนดค่าผ่านไฟล์ config หรือสะดวกกว่าผ่านการสนับสนุนบรรทัดคำสั่งของแพ็คเกจ ml_collections ตัวอย่างเช่นในการสร้างตัวอย่างและประเมินคุณภาพตัวอย่างให้จัดหา --config.eval.enable_sampling ธง ในการคำนวณบันทึกการบันทึกการจัดหา --config.eval.enable_bpd Flag และระบุ --config.eval.dataset=train/test เพื่อระบุว่าจะคำนวณความเป็นไปได้ในชุดข้อมูลการฝึกอบรมหรือการทดสอบ
sde_lib.SDE และใช้วิธีนามธรรมทั้งหมด วิธี discretize() เป็นทางเลือกและค่าเริ่มต้นคือการแยกส่วนออยเลอร์-มารุยามา วิธีการสุ่มตัวอย่างที่มีอยู่และการคำนวณความน่าจะเป็นจะทำงานได้โดยอัตโนมัติสำหรับ SDE ใหม่นี้sampling.Predictor คลาสบทคัดย่อ predictor ใช้วิธี update_fn บทคัดย่อและลงทะเบียนชื่อด้วย @register_predictor ตัวทำนายใหม่สามารถใช้โดยตรงใน sampling.get_pc_sampler สำหรับการสุ่มตัวอย่างแบบทำนายและวิธีการสร้างอื่น ๆ ทั้งหมดใน controllable_generation.pysampling.Corrector คลาสบทคัดย่อของคลาสใช้วิธี update_fn บทคัดย่อและลงทะเบียนชื่อด้วย @register_corrector ตัวแก้ไขใหม่สามารถใช้โดยตรงใน sampling.get_pc_sampler และวิธีการสร้างอื่น ๆ ที่ควบคุมได้ทั้งหมดใน controllable_generation.py จุดตรวจทั้งหมดมีให้ใน Google Drive นี้
คำแนะนำ : คุณอาจพบจุดตรวจสองจุดสำหรับบางรุ่น จุดตรวจสอบแรก (ที่มีจำนวนน้อยกว่า) เป็นจุดที่เรารายงานคะแนน FID ในตารางที่ 3 ของกระดาษของเรา (ยังสอดคล้องกับ FID และคอลัมน์ในตารางด้านล่าง) จุดตรวจสอบที่สอง (ที่มีจำนวนมากขึ้น) คือจุดที่เรารายงานค่าความน่าจะเป็นและ fids ของตัวอย่าง ODE กล่องดำในตารางที่ 2 (เช่น FID (ODE) และคอลัมน์ NNL (BITS/DIM) ในตารางด้านล่าง) อดีตสอดคล้องกับ FID ที่เล็กที่สุดในระหว่างการฝึกอบรม (ทุกการทำซ้ำ 50K) ในภายหลังเป็นจุดตรวจสุดท้ายระหว่างการฝึกอบรม
ตามนโยบายของ Google เราไม่สามารถปล่อยจุดตรวจ Celeba และ Celeba-HQ ดั้งเดิมของเราได้ ที่กล่าวว่าฉันได้ฝึกอบรมแบบจำลองใหม่ใน FFHQ 1024PX, FFHQ 256PX และ CELEBA-HQ 256PX ด้วยทรัพยากรส่วนบุคคลและพวกเขาได้รับประสิทธิภาพที่คล้ายกันกับจุดตรวจสอบภายในของเรา
นี่คือรายละเอียดของจุดตรวจสอบและผลลัพธ์ของพวกเขาที่รายงานไว้ในกระดาษ FID (ODE) สอดคล้องกับคุณภาพตัวอย่างของตัวแก้ปัญหา ODE กล่องดำที่ใช้กับ ODE การไหลของความน่าจะเป็น
| เส้นทางจุดตรวจ | บด | เป็น | FID (ODE) | NNL (บิต/สลัว) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| เส้นทางจุดตรวจ | ตัวอย่าง |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| การเชื่อมโยง | คำอธิบาย |
|---|---|
| โหลดจุดตรวจสอบที่ได้รับการฝึกฝนและเล่นด้วยการสุ่มตัวอย่างการคำนวณความน่าจะเป็นและการสังเคราะห์ที่ควบคุมได้ (Jax + Flax) | |
| โหลดจุดตรวจสอบที่ได้รับการฝึกฝนและเล่นด้วยการสุ่มตัวอย่างการคำนวณความน่าจะเป็นและการสังเคราะห์ที่ควบคุมได้ (Pytorch) | |
| การสอนโมเดลการกำเนิดที่ใช้คะแนนใน Jax + Flax | |
| การสอนแบบจำลองการกำเนิดที่ใช้คะแนนใน Pytorch |
config.training.n_jitted_steps สำหรับ CIFAR-10 เราขอแนะนำให้ใช้ config.training.n_jitted_steps=5 เมื่อ GPU/TPU ของคุณมีหน่วยความจำเพียงพอ มิฉะนั้นเราแนะนำให้ใช้ config.training.n_jitted_steps=1 การใช้งานปัจจุบันของเราต้องใช้ config.training.log_freq เพื่อแบ่งออกเป็น n_jitted_steps สำหรับการเข้าสู่ระบบและการตรวจสอบเพื่อทำงานตามปกติsnr (อัตราส่วนสัญญาณต่อเสียงรบกวน) ของ LangevinCorrector ค่อนข้างทำงานเหมือนพารามิเตอร์อุณหภูมิ โดยทั่วไปแล้ว snr ที่ใหญ่กว่าจะส่งผลให้ตัวอย่างที่ราบรื่นขึ้นในขณะที่ snr ที่เล็กกว่าจะให้ตัวอย่างที่มีความหลากหลาย แต่มีคุณภาพต่ำกว่า ค่าทั่วไปของ snr คือ 0.05 - 0.2 และต้องมีการปรับจูนเพื่อโจมตีจุดหวานconfig.model.sigma_max ให้เป็นระยะทางคู่สูงสุดระหว่างตัวอย่างข้อมูลในชุดข้อมูลการฝึกอบรม หากคุณพบรหัสที่เป็นประโยชน์สำหรับการวิจัยของคุณโปรดพิจารณาอ้างถึง
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}งานนี้สร้างขึ้นจากเอกสารก่อนหน้านี้ซึ่งอาจสนใจคุณ: