Ini adalah taman bermain untuk pemula Pytorch, yang berisi model yang telah ditentukan pada dataset populer. Saat ini kami mendukung
Berikut adalah contoh untuk dataset MNIST. Ini akan mengunduh dataset dan model pra-terlatih secara otomatis.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
Juga, jika ingin melatih model MLP di MNIST, cukup jalankan python mnist/train.py
python3 setup.py develop --user
Kami memberikan dataset validasi Imagenet yang sudah diaktifkan dengan ukuran 224x224x3. Kami pertama -tama mengubah ukuran gambar yang lebih pendek menjadi 256, kemudian kami memotong gambar 224x224 di tengah. Kemudian kami menyandikan gambar yang dipotong ke string JPG dan dibuang ke acar.
cd scriptval224_compressed.pkl (Tsinghua / Google Drive)python convert.py (membutuhkan memori 48g, terima kasih @jnorwood)Kami juga memberikan demo sederhana untuk mengukur model ini untuk spesifikasi bit-lebar dengan beberapa metode, termasuk metode linier, metode MinMax dan metode non-linear.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
Kami mengevaluasi kinerja dataset dan model populer dengan metode kuantisasi linier. Bit-lebar menjalankan rata-rata dan menjalankan varian dalam BN adalah 10 bit untuk semua hasil. (Kecuali 32-float)
| Model | 32-float | 12-bit | 10-bit | 8-bit | 6-bit |
|---|---|---|---|---|---|
| Mnist | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| Svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| CIFAR10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| CIFAR100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| Resnet18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| Resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| Resnet50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| Resnet101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| Resnet152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| Squeezenetv0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| Inceptionv3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
Catatan: Model Imagenet 32-Float langsung dari TorchVision
Di sini kami memberikan gambaran umum argumen terpilih dari quantize.py
| Bendera | Nilai default | Deskripsi & Opsi |
|---|---|---|
| jenis | CIFAR10 | mnist,svhn,cifar10,cifar100,stl10,alexnet,vgg16,vgg16_bn,vgg19,vgg19_bn,resent18,resent34,resnet50,resnet101,resnet152,squeezenet_v0,squeezenet_v1,inception_v3 |
| quant_method | linear | Metode Kuantisasi: Linear, Minmax, Log, Tanh |
| param_bits | 8 | bobot dan bias yang agak lebar |
| fwd_bits | 8 | Bit-lebar aktivasi |
| bn_bits | 32 | Sedikit lebar menjalankan rata-rata dan menjalankan vaAirance |
| overflow_rate | 0,0 | Ambang batas laju overflow untuk metode kuantisasi linier |
| n_samples | 20 | Jumlah sampel untuk membuat statistik untuk aktivasi |