هذا ملعب للمبتدئين في Pytorch ، والذي يحتوي على نماذج محددة مسبقًا على مجموعة البيانات الشهيرة. نحن ندعم حاليا
فيما يلي مثال على مجموعة بيانات MNIST. سيؤدي ذلك إلى تنزيل مجموعة البيانات والنموذج الذي تم تدريبه مسبقًا تلقائيًا.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
أيضًا ، إذا أردت تدريب طراز MLP على Mnist ، ما عليك سوى تشغيل python mnist/train.py
python3 setup.py develop --user
نحن نقدم مجموعة بيانات التحقق من صحة ImageNet مع حجم 224x224x3. قمنا أولاً بتغيير حجم الحجم الأقصر للصورة إلى 256 ، ثم نقضي صورة 224 × 224 في المركز. ثم نشفير الصور المزروعة على سلسلة JPG ونفوز على المخلل.
cd scriptval224_compressed.pkl (Tsinghua / Google Drive)python convert.py (يحتاج 48G ذاكرة ، شكرًا @jnorwood)نحن نقدم أيضًا عرضًا تجريبيًا بسيطًا لتحديد هذه النماذج إلى عرض البتات المحدد مع عدة طرق ، بما في ذلك الطريقة الخطية وطريقة Minmax والطريقة غير الخطية.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
نقوم بتقييم أداء مجموعة البيانات والنماذج الشائعة ذات الأسلوب الكمي الخطية. عرض البتات في التشغيل المتوسط والتباين الجري في BN هو 10 بت لجميع النتائج. (باستثناء 32 حلو)
| نموذج | 32 حلو | 12 بت | 10 بت | 8 بت | 6 بت |
|---|---|---|---|---|---|
| mnist | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| SVHN | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| CIFAR10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| CIFAR100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| اليكسنيت | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| RESNET18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| RESNET50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| RESNET101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| RESNET152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| Squeezenetv0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| InceptionV3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
ملاحظة: نماذج ImageNet 32-float مباشرة من Torchvision
هنا نقدم نظرة عامة على وسيطات محددة من quantize.py
| علَم | القيمة الافتراضية | الوصف والخيارات |
|---|---|---|
| يكتب | CIFAR10 | Mnist ، Svhn ، Cifar10 ، Cifar100 ، STL10 ، Alexnet ، VGG16 ، VGG16_BN ، VGG19 ، VGG19_BN ، RESENT18 ، RESENT34 ، RESNET50 ، RESNET101 ، RESNET152 ، VISEEZENET_V0 ، SIPEEZENET_V1 |
| Quant_Method | خطي | طريقة القياس: الخطية ، minmax ، log ، tanh |
| param_bits | 8 | عرض البت من الأوزان والتحيز |
| FWD_BITS | 8 | عرض بت من التنشيط |
| bn_bits | 32 | عرض البتات من الجري المتوسط والجري فيرانس |
| overflow_rate | 0.0 | عتبة معدل الفائض لطريقة القياس الخطي |
| n_samples | 20 | عدد العينات لتقديم إحصائيات للتنشيط |