التنفيذ البديل البديل للشبكات النموذجية لعدد قليل من التعلم اللقطة (ورقة ، رمز) في Pytorch.
كما هو موضح في الشبكات النموذجية الورقية المرجعية ، يتم تدريب على تضمين ميزات العينات في مساحة متجهية ، c وجه الخصوص ، في كل حلقة ( n_support ) يمكن تقليل عينات n_query و Class Barycentre.

بعد التدريب ، يمكنك حساب T-SNE للميزات التي تم إنشاؤها بواسطة النموذج (لا يتم في هذا الريبو ، والمزيد من Infos حول T-SNE هنا) ، هذه عينة كما هو موضح في الورقة.

مجد إلى ludc لمساهمته: Pytorch/Vision#46. سنستخدم مجموعة البيانات الرسمية عندما تتم إضافتها إلى TorchVision إذا لم يكن ذلك يعني تغييرات كبيرة على الرمز.
قمنا بتنفيذ طريقة تقسيم Vynials كما في [شبكات مطابقة لتعلم لقطة واحدة]. سيكون هذا هو نفس الطريقة المستخدمة في الورقة (في الواقع أقوم بتنزيل ملفات الانقسام من الريبو "غير الرسمي"). ثم نطبق نفس الدورات الموضحة هناك. وبهذه الطريقة ، يجب أن نكون قادرين على مقارنة النتائج التي تم الحصول عليها عن طريق تشغيل هذا الرمز مع النتائج الموضحة في الورقة المرجعية.
كما هو موضح في PYDOC ، يتم استخدام هذه الفئة لإنشاء فهارس كل دفعة لخوارزمية التدريب النموذجية.
على وجه الخصوص ، يتم إنشاء مثيل للكائن من خلال تمرير قائمة الملصقات الخاصة بمجموعة البيانات ، ثم يقوم العينة بإصدار العدد الإجمالي للفئات وإنشاء مجموعة من الفهارس لكل فئة NI مجموعة البيانات. في كل حلقة ، يحدد Sampler فئات عشوائية n_classes ويعيد رقمًا ( n_support + n_query ) من فهارس العينات لكل فئة من الفئات المحددة.
حساب الخسارة كما في الورقة المذكورة ، مستوحاة من هذا الرمز من قبل أحد مؤلفيها.
في prototypical_loss.py يتم تنفيذ كل من وظيفة الخسارة وفئة الخسارة à la pytorch.
تأخذ الوظيفة إدخال مدخلات الدُفعات من النموذج ، وحقائق الأرض ، ورقم n_suppport للعينات لاستخدامها كعينات دعم. يتم استنتاج فصول الحلقة من القائمة المستهدفة ، يتم استخراج عينات n_support بشكل عشوائي لكل فئة ، ويتم حساب فئة BaryCentres الخاصة بهم ، وكذلك مسافات كل عينات متبقية من كل فئة Barycentre واحتمال كل عينة من الانتماء إلى كل فئة من فئة حلقة يتم حسابها ؛ ثم يتم حساب الخسارة من احتمالات التنبؤات الخاطئة (لعينات الاستعلام) كالمعتاد في مشاكل التصنيف.
يرجى ملاحظة أن رمز التدريب موجود هنا فقط لأغراض العرض التوضيحي.
لتدريب البروتونيت على هذه المهمة ، قرص مضغوط في مجلد الجذر src لهذا الريبو والتنفيذ:
$ python train.py
يأخذ البرنامج النصي خيارات سطر الأوامر التالية:
dataset_root : دليل الجذر حيث يتم تخزين مجموعة بيانات THA ، افتراضيًا إلى '../dataset'
nepochs : عدد الحدث للتدريب ، الافتراضي إلى 100
learning_rate : معدل التعلم للنموذج ، الافتراضي إلى 0.001
lr_scheduler_step : خطوة جدولة معدل التعلم Steplr ، الافتراضي إلى 20
lr_scheduler_gamma : جدولة معدل التعلم Steplr Gamma ، افتراضي إلى 0.5
iterations : عدد الحلقات لكل فترة. الافتراضي إلى 100
classes_per_it_tr : عدد الفئات العشوائية لكل حلقة للتدريب. الافتراضي إلى 60
num_support_tr : عدد العينات لكل فصل لاستخدامها كدعم للتدريب. الافتراضي إلى 5
num_query_tr : nnumber من العينات لكل فصل لاستخدامها كاستعلام للتدريب. الافتراضي إلى 5
classes_per_it_val : عدد الفئات العشوائية لكل حلقة للتحقق من الصحة. الافتراضي إلى 5
num_support_val : عدد العينات لكل فئة لاستخدامها كدعم للتحقق من الصحة. الافتراضي إلى 5
num_query_val : عدد العينات لكل فئة لاستخدامها كاستعلام للتحقق من الصحة. الافتراضي إلى 15
manual_seed : إدخال لتهيئة البذور اليدوية ، الافتراضي إلى 7
cuda : تمكين CUDA (متجر True )
سيقوم تشغيل الأمر بدون وسيط بتدريب النماذج بقيم Hyperparamters الافتراضية (إنتاج نتائج موضحة أعلاه).
نحن نحاول إعادة إنتاج الأداء الورقي المرجعي ، سنقوم بتحديث أفضل نتائجنا هنا.
| نموذج | 1 طلقة (5 اتجاهات acc.) | 5 طلقة (5 اتجاهات acc.) | 1 -shot (20 اتجاه acc.) | 5-Shot (20 اتجاه ACC.) |
|---|---|---|---|---|
| ورقة مرجعية | 98.8 ٪ | 99.7 ٪ | 96.0 ٪ | 98.9 ٪ |
| هذا الريبو | 98.5 ٪ ** | 99.6 ٪* | 95.1 ٪ ° | 98.6 ٪ ° |
* تم تحقيقه باستخدام المعلمات الافتراضية (باستخدام --cuda )
** حقق تشغيل python train.py --cuda -nsTr 1 -nsVa 1
حقق ° تشغيل python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20
حقق ° ° تشغيل python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20
استشهد بالورقة على النحو التالي (تم نسخها من Arxiv بالنسبة لك):
@article{DBLP:journals/corr/SnellSZ17,
author = {Jake Snell and
Kevin Swersky and
Richard S. Zemel},
title = {Prototypical Networks for Few-shot Learning},
journal = {CoRR},
volume = {abs/1703.05175},
year = {2017},
url = {http://arxiv.org/abs/1703.05175},
archivePrefix = {arXiv},
eprint = {1703.05175},
timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
biburl = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
bibsource = {dblp computer science bibliography, http://dblp.org}
}
هذا المشروع مرخص بموجب ترخيص معهد ماساتشوستس للتكنولوجيا
حقوق الطبع والنشر (C) 2018 Daniele E. Ciriello ، Orobix SRL (www.orobix.com).