นี่คือสนามเด็กเล่นสำหรับผู้เริ่มต้น Pytorch ซึ่งมีโมเดลที่กำหนดไว้ล่วงหน้าในชุดข้อมูลยอดนิยม ขณะนี้เราสนับสนุน
นี่คือตัวอย่างสำหรับชุดข้อมูล MNIST สิ่งนี้จะดาวน์โหลดชุดข้อมูลและรุ่นที่ผ่านการฝึกอบรมมาแล้วโดยอัตโนมัติ
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
นอกจากนี้หากต้องการฝึกอบรมรุ่น MLP ใน MNIST เพียงแค่เรียกใช้ python mnist/train.py
python3 setup.py develop --user
เราให้บริการชุดข้อมูลการตรวจสอบความถูกต้องของ Imagenet ที่คาดการณ์ล่วงหน้าด้วยขนาด 224x224x3 ก่อนอื่นเราจะปรับขนาดภาพที่สั้นลงเป็น 256 จากนั้นเราครอบตัดภาพ 224x224 ที่ตรงกลาง จากนั้นเราเข้ารหัสภาพที่ถูกครอบตัดไปที่สตริง JPG และ Dump เพื่อดอง
cd scriptval224_compressed.pkl (tsinghua / Google ไดรฟ์)python convert.py (ต้องการหน่วยความจำ 48G, ขอบคุณ @jnorwood)นอกจากนี้เรายังให้การสาธิตอย่างง่ายเพื่อหาปริมาณแบบจำลองเหล่านี้เพื่อระบุความกว้างบิตที่ระบุด้วยวิธีการหลายวิธีรวมถึงวิธีการเชิงเส้นวิธี minmax และวิธีที่ไม่ใช่เชิงเส้น
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
เราประเมินประสิทธิภาพของชุดข้อมูลยอดนิยมและรุ่นด้วยวิธีเชิงปริมาณเชิงเส้น ความกว้างบิตของค่าเฉลี่ยการทำงานและความแปรปรวนการทำงานใน BN คือ 10 บิตสำหรับผลลัพธ์ทั้งหมด (ยกเว้น 32-float)
| แบบอย่าง | 32-float | 12 บิต | 10 บิต | 8 บิต | 6 บิต |
|---|---|---|---|---|---|
| MNIST | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| CIFAR10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| CIFAR100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| resnet18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| resnet50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| resnet101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| Resnet152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| Squeezenetv0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| inceptionv3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
หมายเหตุ: ImageNet 32-Float รุ่นโดยตรงจาก Torchvision
ที่นี่เราให้ภาพรวมของอาร์กิวเมนต์ที่เลือกของ quantize.py
| ธง | ค่าเริ่มต้น | คำอธิบายและตัวเลือก |
|---|---|---|
| พิมพ์ | CIFAR10 | MNIST, SVHN, CIFAR10, CIFAR100, STL10, ALEXNET, VGG16, VGG16_BN, VGG19, VGG19_BN, resent18, resent34, resnet50, resnet101, resnet152, squeezenet_v0, squeezenet_v0 |
| quant_method | เป็นเส้นตรง | วิธีการ Quantization: เชิงเส้น, minmax, log, tanh |
| param_bits | 8 | ความกว้างของน้ำหนักและอคติ |
| fwd_bits | 8 | ความกว้างของการเปิดใช้งาน |
| bn_bits | 32 | ความกว้างบิตของการวิ่งเฉลี่ยและรันร่องรอย |
| Overflow_rate | 0.0 | เกณฑ์อัตราการล้นสำหรับวิธีการหาปริมาณเชิงเส้น |
| n_samples | 20 | จำนวนตัวอย่างที่จะสร้างสถิติสำหรับการเปิดใช้งาน |