Это игровая площадка для начинающих Pytorch, которая содержит предопределенные модели на популярном наборе данных. В настоящее время мы поддерживаем
Вот пример для набора данных MNIST. Это автоматически загрузит набор данных и предварительно обученную модель.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
Кроме того, если хотите обучить модель MLP на Mnist, просто запустите python mnist/train.py
python3 setup.py develop --user
Мы предоставляем набор данных проверки ImageNet с размером 224x224x3. Сначала мы изменили размер более короткого размера изображения до 256, затем мы обрезаем изображение 224x224 в центре. Затем мы кодируем обрезанные изображения в строку JPG и сбрасываем в маринованную.
cd scriptval224_compressed.pkl (tsinghua / Google Drive)python convert.py (нуждается в памяти 48 г, спасибо @jnorwood)Мы также предоставляем простую демонстрацию для квантования этих моделей для указанной битовой ширины несколькими методами, включая линейный метод, метод MINMAX и нелинейный метод.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
Мы оцениваем производительность популярного набора данных и моделей с линейным квантовым методом. Битовая ширина бега средней и бегущей дисперсии в BN составляет 10 бит для всех результатов. (За исключением 32-флоата)
| Модель | 32-й | 12-битный | 10-битный | 8-битный | 6-битный |
|---|---|---|---|---|---|
| Мнист | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| Svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95,46 |
| Cifar10 | 93,78 | 93,79 | 93,80 | 93,58 | 90.86 |
| CIFAR100 | 74,27 | 74.21 | 74.19 | 73,70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73,40 |
| Алекснет | 55,70/78,42 | 55,66/78,41 | 55,54/78,39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70,44/89,43 | 70,45/89,43 | 70.44/89,33 | 69,99/89,17 | 53,33/76,32 |
| VGG19 | 71.36/89,94 | 71,35/89,93 | 71.34/89,88 | 70,88/89,62 | 56.00/78.62 |
| Resnet18 | 68,63/88,31 | 68,62/88,33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| Resnet34 | 72.50/90,86 | 72,46/90,82 | 72,45/90,85 | 71.47/90.00 | 32,25/55,71 |
| Resnet50 | 74,98/92,17 | 74,94/92,12 | 74,91/92,09 | 72,54/90,44 | 2.43/5.36 |
| Resnet101 | 76.69/93.30 | 76,66/93,25 | 76.22/92,90 | 65,69/79,54 | 1.41/1.18 |
| Resnet152 | 77,55/93,59 | 77,51/93,62 | 77,40/93,54 | 74,95/92,46 | 9.29/16,75 |
| Squeezenetv0 | 56,73/79,39 | 56,75/79,40 | 56,70/79,27 | 53,93/77.04 | 14.21/29,74 |
| Squeezenetv1 | 56,52/79,13 | 56,52/79,15 | 56.24/79.03 | 54,56/77,33 | 17.10/32.46 |
| Началов3 | 76,41/92,78 | 76.43/92,71 | 76.44/92,73 | 73,67/91,34 | 1,50/4,82 |
Примечание: модели ImageNet 32-Float находятся непосредственно из Tourchvision
Здесь мы даем обзор выбранных аргументов quantize.py
| Флаг | Значение по умолчанию | Описание и параметры |
|---|---|---|
| тип | cifar10 | MNIST, SVHN, CIFAR10, CIFAR100, STL10, Alexnet, VGG16, VGG16_BN, VGG19, VGG19_BN, REVENT18, RESENT34, RESNET50, RESNET101, RESNET152, Squeezenet_V0, Squeezenet_v1, unception_v3 |
| QUANT_METHOD | линейный | Метод квантования: линейный, minmax, log, tanh |
| param_bits | 8 | битовая ширина весов и предвзятости |
| fwd_bits | 8 | битовая ширина активации |
| bn_bits | 32 | битовая ширина бега подлую и бегущую Vairance |
| overflow_rate | 0,0 | Порог скорости переполнения для линейного метода квантования |
| n_samples | 20 | количество образцов для создания статистики для активации |