بقلم ألكساندر كولسنيكوف ، لوكاس باير ، شياووا تشاي ، جوان بويغريبفر ، جيسيكا يونغ ، سيلفان جيلي ، نيل هولسبي
تحديث 18/06/2021: نصدر نماذج جديدة عالية الأداء Bit-R50x1 ، والتي تم تقطيرها من Bit-M-R152x2 ، انظر هذا القسم. مزيد من التفاصيل في ورقتنا "تقطير المعرفة: المعلم الجيد هو صبور ومتسق".
تحديث 08/02/2021: نقوم أيضًا بإصدار جميع نماذج Bit-M التي تم ضبطها على جميع مجموعات بيانات VTAB-1K 19 ، انظر أدناه.
في هذا المستودع ، نقوم بإصدار نماذج متعددة من النقل الكبير (BIT): ورقة تعلم التمثيل البصري العام والتي تم تدريبها مسبقًا على مجموعات بيانات ILSVRC-2012 و IMAMENET-21K. نحن نقدم الرمز لضبط النماذج التي تم إصدارها في أطر التعلم العميق الرئيسية TensorFlow 2 و Pytorch و Jax/Flax.
نأمل أن يستفيد مجتمع رؤية الكمبيوتر من خلال توظيف نماذج ImageNet-21k PretRied أكثر قوة بدلاً من النماذج التقليدية التي تم تدريبها مسبقًا على مجموعة بيانات ILSVRC-2012.
نحن نقدم أيضًا كولابس لاستخدام تفاعلي أكثر استكشافًا: A TensorFlow 2 Colab ، و Pytorch Colab ، و Jax Colab.
تأكد من تثبيت Python>=3.6 على جهازك.
لإعداد TensorFlow 2 ، Pytorch أو Jax ، اتبع الإرشادات الواردة في المستودع المقابل المرتبط هنا.
بالإضافة إلى ذلك ، قم بتثبيت تبعيات Python عن طريق التشغيل (يرجى تحديد tf2 أو pytorch أو jax في الأمر أدناه):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
أولاً ، قم بتنزيل نموذج بت. نحن نقدم نماذج تم تدريبها مسبقًا على ILSVRC-2012 (BIT-S) أو ImageNet-21K (BIT-M) لـ 5 بنيات مختلفة: RESNET-50X1 ، RESNET-101X1 ، RESNET-50X3 ، RESNET-101X3 ، و RESNET-152X4.
على سبيل المثال ، إذا كنت ترغب في تنزيل Resnet-50x1 تم تدريبه مسبقًا على ImageNet-21K ، قم بتشغيل الأمر التالي:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
يمكن تنزيل النماذج الأخرى وفقًا لذلك عن طريق توصيل اسم النموذج (Bit-S أو Bit-M) والهندسة المعمارية في الأمر أعلاه. لاحظ أننا نقدم نماذج في تنسيقين: npz (ل Pytorch و Jax) و h5 (ل TF2). بشكل افتراضي ، نتوقع أن يتم تخزين الأوزان النموذجية في المجلد الجذر لهذا المستودع.
بعد ذلك ، يمكنك تشغيل صقل النموذج الذي تم تنزيله على مجموعة البيانات الخاصة بك في أي من الأطر الثلاثة. تشترك جميع الأطر في واجهة سطر الأوامر
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
حالياً. ستقوم جميع الأطر تلقائيًا بتنزيل مجموعات بيانات CIFAR-10 و CIFAR-100. يمكن دمج مجموعات البيانات العامة أو المخصصة بسهولة: في TF2 و Jax ، نعتمد على مكتبة مجموعات بيانات TensorFlow القابلة للتمديد. في Pytorch ، نستخدم خط أنابيب إدخال بيانات TorchVision.
لاحظ أن التعليمات البرمجية الخاصة بنا تستخدم جميع وحدات معالجة الرسومات المتاحة للضبط.
نحن ندعم أيضًا التدريب في نظام البيانات المنخفضة: The --examples_per_class <K> سوف يرسم بشكل عشوائي عينات K لكل فصل للتدريب.
للاطلاع على قائمة مفصلة بجميع الأعلام المتاحة ، قم بتشغيل python3 -m bit_{pytorch|jax|tf2}.train --help .
للراحة ، نقدم نماذج Bit-M التي تم ضبطها بالفعل على مجموعة بيانات ILSVRC-2012. يمكن تنزيل النماذج عن طريق إضافة -ILSVRC2012 postfix ، على سبيل المثال
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
نصدر جميع البنى المذكورة في الورقة ، بحيث يمكنك الاختيار بين الدقة أو السرعة: R50x1 ، R101x1 ، R50x3 ، R101x3 ، R152x4. في المسار أعلاه إلى ملف النموذج ، ما عليك سوى استبدال R50x1 بواسطة بنية الاختيار الخاصة بك.
لقد بحثنا كذلك في المزيد من البنى بعد نشر الورقة ووجدنا R152x2 للحصول على مفاضلة لطيفة بين السرعة والدقة ، وبالتالي ندرج هذا أيضًا في الإصدار ونقدم بعض الأرقام أدناه.
نقوم أيضًا بإصدار النماذج التي تم ضبطها لكل من المهام الـ 19 المدرجة في معيار VTAB-1K. ركضنا كل طراز ثلاث مرات ونطلق كل من هذه الأشواط. هذا يعني أننا ننشر ما مجموعه 5 × 19x3 = 285 نماذج ، ونأمل أن تكون هذه مفيدة في مزيد من التحليل لتعلم النقل.
يمكن تنزيل الملفات عبر النمط التالي:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
لم نقم بتحويل هذه النماذج إلى TF2 (وبالتالي لا يوجد ملف .h5 المقابل) ، ومع ذلك ، قمنا أيضًا بتحميل نماذج TFHUB التي يمكن استخدامها في TF1 و TF2. مثال على تسلسل أوامر لتنزيل نموذج واحد من هذا القبيل هو:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
بالنسبة للاستنساخ ، يستخدم البرنامج النصي التدريبي الخاص بنا المفردات المفرطة (البتات) التي تم استخدامها في الورقة الأصلية. ومع ذلك ، لاحظ أنه تم تدريب نماذج البتات واستخدامها باستخدام أجهزة Cloud TPU ، لذلك بالنسبة لإعداد GPU النموذجي ، قد يتطلب المعلمات المفرطة الافتراضية لدينا الكثير من الذاكرة أو تؤدي إلى تقدم بطيء للغاية. علاوة على ذلك ، تم تصميم Bit-hyperrule لتعميمها عبر العديد من مجموعات البيانات ، لذلك من الممكن عادةً وضع معلمات عالية الكفاءة في التطبيق. وبالتالي ، فإننا نشجع المستخدم على تجربة المزيد من إعدادات الوزن الخفيف ، لأنها تتطلب موارد أقل بكثير وغالبًا ما تؤدي إلى دقة مماثلة.
على سبيل المثال ، اختبرنا الكود الخاص بنا باستخدام جهاز GPU 8xv100 على مجموعات بيانات CIFAR-10 و CIFAR-100 ، مع تقليل حجم الدُفعة من 512 إلى 128 ومعدل التعلم من 0.003 إلى 0.001. نتج عن هذا الإعداد أداء متطابقًا تقريبًا (انظر النتائج المتوقعة أدناه) مقارنةً بالبيبرول ، على الرغم من كونه أقل تطلبًا حسابية.
أدناه ، نقدم المزيد من الاقتراحات حول كيفية تحسين إعداد ورقةنا.
تم تطوير البت الافتراضي على السحابة على السحابة وهو متعطش للذاكرة. ويرجع ذلك بشكل أساسي إلى حجم الدفعات الكبيرة (512) ودقة الصورة (حتى 480 × 480). فيما يلي بعض النصائح إذا نفدت الذاكرة:
bit_hyperrule.py نحدد دقة الإدخال. من خلال تقليله ، يمكن للمرء أن ينقذ الكثير من الذاكرة والحساب ، على حساب الدقة.--batch_split . على سبيل المثال ، يقلل تشغيل عملية التكييف مع- --batch_split 8 متطلبات الذاكرة بعامل 8. لقد تحققنا من أنه عند استخدام bit-hyperrule ، فإن الكود في هذا المستودع يعيد إنتاج نتائج الورقة.
بالنسبة لهذه المعايير الشائعة ، تؤدي التغييرات المذكورة أعلاه إلى hyperrule ( --batch 128 --base_lr 0.001 ) إلى نتائج مماثلة متشابهة للغاية. يوضح الجدول النتيجة المتوسطة ← الحد الأقصى لما لا يقل عن خمسة أشواط. ملاحظة : هذه ليست مقارنة بين الأطر ، فقط دليل على أن جميع قواعد التعليمات البرمجية يمكن الوثوق بها لإعادة إنتاج النتائج.
| مجموعة البيانات | السابقين/CLS | TF2 | جاكس | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52.5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
| CIFAR10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
| CIFAR10 | ممتلىء | 98.5 | 98.4 | 98.5 ← 98.6 → 98.6 |
| CIFAR100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
| CIFAR100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
| CIFAR100 | ممتلىء | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| مجموعة البيانات | السابقين/CLS | جاكس | Pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
| CIFAR10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
| CIFAR10 | ممتلىء | 98.5 | 98.5 ← 98.5 → 98.6 |
| CIFAR100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
| CIFAR100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
| CIFAR100 | ممتلىء | 91.2 | 91.2 ← 91.3 → 91.4 |
(نماذج TF2 غير متوفرة بعد.)
| مجموعة البيانات | السابقين/CLS | TF2 | جاكس | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49.9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
| CIFAR10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
| CIFAR10 | ممتلىء | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
| CIFAR100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
| CIFAR100 | ممتلىء | 86.5 | 86.4 | 86.6 |
تم الحصول على هذه النتائج باستخدام bit-hyperrule. ومع ذلك ، نظرًا لأن هذا ينتج عنه حجم دفعة كبيرة ودقة كبيرة ، يمكن أن تكون الذاكرة مشكلة. يدعم رمز Pytorch تقسيم الدُفعات ، وبالتالي لا يزال بإمكاننا تشغيل الأشياء هناك دون اللجوء إلى Cloud TPUs عن طريق إضافة أمر- --batch_split N حيث N هي قوة اثنين. على سبيل المثال ، ينتج الأمر التالي دقة التحقق من صحة 80.68 على جهاز مع 8 V100 وحدات معالجة الرسومات:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
زيادة إضافية إلى --batch_split 8 عند التشغيل مع 4 V100 وحدات معالجة الرسومات ، إلخ.
وكانت النتائج الكاملة التي تحققت بهذه الطريقة في بعض عمليات الاختبار:
| السابقين/CLS | R50x1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| ممتلىء | 80.68 | 85.15 | WIP |
هذه هي إعادة تشغيل وليس النماذج الورقية الدقيقة. درجات VTAB المتوقعة لاثنين من النماذج هي:
| نموذج | ممتلىء | طبيعي | منظم | متخصص |
|---|---|---|---|---|
| Bit-M-R152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
| Bit-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
في التذييل G من ورقتنا ، نتحقق مما إذا كان BIT يحسن متانة خارج السياق. للقيام بذلك ، أنشأنا مجموعة بيانات تضم كائنات مقدمة مقابلة لـ 21 فئة ILSVRC-2012 التي تم لصقها على 41 خلفيات متنوعة.
لتنزيل مجموعة البيانات ، قم بتشغيل
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
يتم الاحتفاظ بالصور من كل فئة من الفصول الـ 21 في دليل يحمل اسم الفصل.
نقوم بإطلاق نماذج بت مضغوطة من أعلى الأداء من الورقة "تقطير المعرفة: المعلم الجيد هو صبور ومتسق" على تقطير knoweldge. على وجه الخصوص ، نقوم بتقطير نموذج Bit-M-R152x2 (الذي تم تدريبه مسبقًا على صور ImageNet-21k) إلى نماذج Bit-R50x1. نتيجة لذلك ، نحصل على نماذج مضغوطة بأداء تنافسي للغاية.
| نموذج | الرابط تنزيل | دقة | ImageNet Top-1 ACC. (ورق) |
|---|---|---|---|
| Bit-R50x1 | وصلة | 224 | 82.8 |
| Bit-R50x1 | وصلة | 160 | 80.5 |
من أجل الاستنساخ ، نقوم أيضًا بإصدار أوزان من نماذج المعلمين Bit-M-R152x2: PretRained في القرار 224 والقرار 384. راجع الورقة للحصول على تفاصيل حول كيفية استخدام هؤلاء المعلمين.
ليس لدينا أي خطط ملموسة لنشر رمز التقطير ، حيث أن الوصفة بسيطة ونتخيل أن معظم الناس سوف يدمجونه في رمز التدريب الحالي. ومع ذلك ، قام ساياك بول بإعادة تنفيذ إعداد التقطير في Tensorflow بشكل مستقل واستنسخ نتائجنا تقريبًا في العديد من الإعدادات.