يحتوي هذا الريبو على تطبيق Pytorch للنمذجة التوليدية القائمة على النقاط من خلال المعادلات التفاضلية العشوائية
بقلم Yang Song ، Jascha Sohl-Dickstein ، Diverik P. Kingma ، Abhishek Kumar ، Stefano Ermon ، and Ben Poole
نقترح إطار عمل موحد يعتمد ويحسن العمل السابق على النماذج التوليدية القائمة على النقاط من خلال عدسة المعادلات التفاضلية العشوائية (SDEs). على وجه الخصوص ، يمكننا تحويل البيانات إلى توزيع ضوضاء بسيط من خلال عملية ستوكاستيك في الوقت المستمر الموصوفة بواسطة SDE. يمكن عكس هذا SDE لتوليد العينات إذا عرفنا درجة التوزيعات الهامشية في كل خطوة زمنية وسيطة ، والتي يمكن تقديرها بمطابقة الدرجات. يتم التقاط الفكرة الأساسية في الشكل أدناه:

يتيح عملنا فهمًا أفضل للنهج الحالية ، وخوارزميات أخذ العينات الجديدة ، وحساب الاحتمال الدقيق ، والترميز الفريد ، والتلاعب بالدولة الكامنة ، ويجلب قدرات توليد مشروطة جديدة (بما في ذلك على سبيل المثال لا الحصر ، التوليد الفاصلة ، والتعبير والتلوين) إلى عائلة النماذج التوليدية القائمة على النتيجة.
جميعها مجتمعة ، حققنا FID قدره 2.20 ودرجة البداية من 9.89 لتوليد غير مشروط على CIFAR-10 ، بالإضافة إلى توليد عالي الدقة من صور 1024 بكسل CELEBA-HQ (عينات أدناه). بالإضافة إلى ذلك ، حصلنا على قيمة احتمال قدرها 2.99 بت/خافت على صور CIFAR-10 موحدة.

بصرف النظر عن نماذج NCSN ++ و DDPM ++ في ورقتنا ، فإن قاعدة البيانات هذه تعيد أيضًا تنفيذ العديد من النماذج القائمة على الدرجات السابقة في مكان واحد ، بما في ذلك NCSN من النمذجة التوليدية عن طريق تقدير التدرجات لتوزيع البيانات ، و NCSNV2 من تقنيات محسّنة لنماذج التدريب المستندة إلى نقاط الدرجات ، و DDPM من تقويم نماذج النشر.
وهو يدعم تدريب نماذج جديدة ، وتقييم جودة عينة واحتمال النماذج الحالية. لقد صممنا الرمز بعناية ليكون وحدات وسهلة التوسيع إلى SDEs الجديدة أو المتنبئين أو المصححات.
معظم النماذج متوفرة الآن أيضًا؟ الناشرون و accesible عبر خط أنابيب الدرج.
يتيح لك الناشرون اختبار النماذج المستندة إلى SDE في Pytorch في خطين من الكود فقط.
يمكنك تثبيت الناشرين على النحو التالي:
pip install diffusers torch accelerate
ثم جرب النماذج مع خطين فقط من الكود:
from diffusers import DiffusionPipeline
model_id = "google/ncsnpp-ffhq-1024"
# load model and scheduler
sde_ve = DiffusionPipeline . from_pretrained ( model_id )
# run pipeline in inference (sample random noise and denoise)
image = sde_ve (). images [ 0 ]
# save image
image [ 0 ]. save ( "sde_ve_generated_image.png" )يمكن العثور على المزيد من النماذج مباشرة على المحور.
يرجى العثور على تطبيق Jax هنا ، والذي يدعم بالإضافة إلى ذلك الجيل الشرطي مع مصنف تم تدريبه مسبقًا ، واستئناف عملية التقييم بعد الاستباق.
بشكل عام ، يستهلك إصدار Pytorch ذاكرة أقل ولكنه يعمل أبطأ من Jax. فيما يلي معيار لتدريب NCSN ++ CONT. نموذج مع VE SDE. الأجهزة هي 4x nvidia tesla v100 وحدات معالجة الرسومات (32 جيجابايت)
| نطاق | الوقت (الثاني لكل خطوة) | استخدام الذاكرة في المجموع (GB) |
|---|---|---|
| Pytorch | 0.56 | 20.6 |
jax ( n_jitted_steps=1 ) | 0.30 | 29.7 |
jax ( n_jitted_steps=5 ) | 0.20 | 74.8 |
قم بتشغيل ما يلي لتثبيت مجموعة فرعية من حزم Python اللازمة للرمز الخاص بنا
pip install -r requirements.txt نحن نقدم ملف الإحصائيات لـ CIFAR-10. يمكنك تنزيل cifar10_stats.npz وحفظه على assets/stats/ . تحقق من #5 حول كيفية حساب ملف الإحصائيات لمجموعات البيانات الجديدة.
تدريب وتقييم نماذجنا من خلال main.py
main.py:
--config: Training configuration.
(default: ' None ' )
--eval_folder: The folder name for storing evaluation results
(default: ' eval ' )
--mode: < train | eval > : Running mode: train or eval
--workdir: Working directory config هو المسار إلى ملف التكوين. يتم توفير ملفات التكوين المقررة لدينا في configs/ . يتم تنسيقها وفقًا لـ ml_collections وينبغي أن تكون محسوسة تمامًا.
اتفاقيات تسمية ملفات التكوين : مسار ملف التكوين هو مزيج من الأبعاد التالية:
cifar10 ، celeba ، celebahq ، celebahq_256 ، ffhq_256 ، celebahq ، ffhq .ncsn ، ncsnv2 ، ncsnpp ، ddpm ، ddpmpp . workdir هو المسار الذي يخزن جميع القطع الأثرية لتجربة واحدة ، مثل نقاط التفتيش والعينات ونتائج التقييم.
eval_folder هو اسم مقلع فرعي في workdir الذي يخزن جميع القطع الأثرية لعملية التقييم ، مثل نقاط التفتيش الوصفية للوقاية من ما بعد التعبئة ، وعينات الصور ، ومكبرات من النتائج الكمية.
mode هو إما "قطار" أو "تقييم". عند تعيينه على "Train" ، يبدأ تدريب نموذج جديد ، أو يستأنف تدريب نموذج قديم في حالة وجود نقاط الفحص الوصفية (لاستئناف التشغيل بعد الاستبعاد في بيئة سحابة) في workdir/checkpoints-meta . عند ضبطه على "eval" ، يمكنه القيام بمزيج تعسفي لما يلي
تقييم وظيفة الخسارة على مجموعة بيانات الاختبار / التحقق من الصحة.
قم بإنشاء عدد ثابت من العينات وحساب درجة بدءها أو FID أو KID. قبل التقييم ، يجب بالفعل تنزيل/حساب ملفات الإحصائيات وتخزينها في assets/stats .
حساب احتمالية سجل على مجموعة بيانات التدريب أو الاختبار.
يمكن تكوين هذه الوظائف من خلال ملفات التكوين ، أو بشكل أكثر ملاءمة ، من خلال دعم سطر الأوامر لحزمة ml_collections . على سبيل المثال ، لإنشاء عينات وتقييم جودة العينة ، قم بتزويد العلم --config.eval.enable_sampling ؛ لحساب سجلات السجل ، قم بتزويد العلم --config.eval.enable_bpd ، وحدد --config.eval.dataset=train/test للإشارة إلى ما إذا كان سيتم حساب الاحتمالات على بيانات التدريب أو الاختبار.
sde_lib.SDE وتنفيذ جميع الأساليب التجريدية. طريقة discretize() اختياري والافتراضي هو تقديري Euler-Maruyama. ستعمل طرق أخذ العينات الحالية وحساب الاحتمالية تلقائيًا لهذا SDE الجديد.@register_predictor جديدة : متأصلة في update_fn sampling.Predictor . يمكن استخدام المتنبئ الجديد مباشرة في sampling.get_pc_sampler controllable_generation.pysampling.Corrector Abstract ، قم بتنفيذ طريقة update_fn الملخص ، وتسجيل اسمها باستخدام @register_corrector . يمكن استخدام المصحح الجديد مباشرة في sampling.get_pc_sampler ، وجميع طرق التوليد الأخرى التي يمكن التحكم فيها في controllable_generation.py . يتم توفير جميع نقاط التفتيش في محرك Google هذا.
التعليمات : قد تجد نقطتين تفتيش لبعض النماذج. نقطة التفتيش الأولى (برقم أصغر) هي تلك التي أبلغنا عنها في درجات FID في الجدول 3 لورقةنا (المقابلة أيضًا لـ FID وهي أعمدة في الجدول أدناه). نقطة التفتيش الثانية (برقم أكبر) هي تلك التي أبلغنا عنها قيم الاحتمالية و FIDs لأعينة قصيدة الصندوق الأسود في الجدول 2 من ورقنا (أيضًا FID (ODE) و NNL (بت/قاتمة) في الجدول أدناه). السابق يتوافق مع أصغر FID أثناء التدريب (كل 50 ألف تكرار). وفي وقت لاحق هو نقطة التفتيش الأخيرة أثناء التدريب.
وفقًا لسياسة Google ، لا يمكننا إصدار نقاط تفتيش Celeba و Celeba-HQ الأصلية. ومع ذلك ، قمت بإعادة تدريب نماذج على FFHQ 1024px و FFHQ 256PX و Celeba-HQ 256px مع الموارد الشخصية ، وحققت أداءً مشابهًا لنقاط التفتيش الداخلية لدينا.
فيما يلي قائمة مفصلة بنقاط التفتيش ونتائجها المبلغ عنها في الورقة. FID (ODE) يتوافق مع جودة عينة من حلقة ODE BLACK المطبقة على قصيدة تدفق الاحتمال.
| مسار نقطة التفتيش | fid | يكون | FID (قصيدة) | NNL (بت/قاتمة) |
|---|---|---|---|---|
ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| مسار نقطة التفتيش | عينات |
|---|---|
ve/bedroom_ncsnpp_continuous | ![]() |
ve/church_ncsnpp_continuous | ![]() |
ve/ffhq_1024_ncsnpp_continuous | ![]() |
ve/ffhq_256_ncsnpp_continuous | ![]() |
ve/celebahq_256_ncsnpp_continuous | ![]() |
| وصلة | وصف |
|---|---|
| قم بتحميل نقاط التفتيش الخاصة بنا واللعب مع أخذ العينات وحساب الاحتمالية والتوليف القابل للتحكم (Jax + Flax) | |
| قم بتحميل نقاط التفتيش المسبقة لدينا واللعب مع أخذ العينات ، وحساب الاحتمالية ، والتوليف القابل للتحكم (Pytorch) | |
| البرنامج التعليمي للنماذج التوليدية القائمة على الدرجات في Jax + Flax | |
| البرنامج التعليمي للنماذج التوليدية القائمة على الدرجات في Pytorch |
config.training.n_jitted_steps . بالنسبة لـ CIFAR-10 ، نوصي باستخدام config.training.n_jitted_steps=5 عندما يكون لدى GPU/TPU ذاكرة كافية ؛ وإلا فإننا نوصي باستخدام config.training.n_jitted_steps=1 . يتطلب تطبيقنا الحالي config.training.log_freq أن يكون قابلاً للقسمة بواسطة n_jitted_steps لتسجيلها والتحقق من العمل بشكل طبيعي.snr (نسبة الإشارة إلى الضوضاء) من LangevinCorrector إلى حد ما مثل معلمة درجة الحرارة. عادةً ما ينتج عن snr أكبر عينات أكثر سلاسة ، في حين أن snr الأصغر يعطي عينات أكثر تنوعًا ولكن أقل جودة. القيم النموذجية لـ snr هي 0.05 - 0.2 ، وتتطلب ضبط لضرب البقعة الحلوة.config.model.sigma_max ليكون أقصى مسافة زوجية بين عينات البيانات في مجموعة بيانات التدريب. إذا وجدت الرمز مفيدًا لبحثك ، فيرجى التفكير في الإشارة إلى
@inproceedings {
song2021scorebased,
title = { Score-Based Generative Modeling through Stochastic Differential Equations } ,
author = { Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole } ,
booktitle = { International Conference on Learning Representations } ,
year = { 2021 } ,
url = { https://openreview.net/forum?id=PxTIG12RRHS }
}هذا العمل مبني على بعض الأوراق السابقة التي قد تهمك أيضًا: