pytorch playground
1.0.0
这是Pytorch初学者的游乐场,该游乐设施包含流行数据集中的预定义型号。目前,我们支持
这是MNIST数据集的示例。这将自动下载数据集和预训练模型。
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
另外,如果想在MNIST上训练MLP模型,只需运行python mnist/train.py
python3 setup.py develop --user
我们提供具有224x224x3尺寸的预算成像网验证数据集。我们首先将较短的图像大小调整到256,然后在中心裁剪224x224图像。然后,我们将裁剪的图像编码为JPG字符串,然后将其转储到泡菜中。
cd scriptval224_compressed.pkl (tsinghua / google drive)python convert.py (需要48克记忆,谢谢@jnorwood)我们还提供了一个简单的演示,可以用几种方法将这些模型量化为指定的位宽度,包括线性方法,Minmax方法和非线性方法。
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
我们使用线性量化方法评估流行数据集和模型的性能。 BN的均值均值和运行差异的位宽度为10位。 (除32条浮动外)
| 模型 | 32流动 | 12位 | 10位 | 8位 | 6位 |
|---|---|---|---|---|---|
| mnist | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| SVHN | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| CIFAR10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| CIFAR100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| RESNET18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| Resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| RESNET50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| RESNET101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| RESNET152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| SqueezenetV0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| Inceptionv3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
注意:Imagenet 32浮动模型直接来自Torchvision
在这里,我们概述了quantize.py的选定参数。
| 旗帜 | 默认值 | 描述和选项 |
|---|---|---|
| 类型 | CIFAR10 | Mnist,SVHN,CIFAR10,CIFAR100,STL10,ALEXNET,VGG16,VGG16_BN,VGG19,VGG19_BN,RESENT18,RESENT18,RESENT34,RESNET34,RESNET50,RESNET101,RESNET101,RESNET152,RESNET152,SECEEEZEEZEEZENET_V0,SENKEEZENET_V0,SENGEEEEEEZENET___ESPEND___ESPECT______ESTICT__V1 |
| Quant_Method | 线性 | 量化方法:线性,minmax,log,tanh |
| param_bits | 8 | 重量和偏见的位宽度 |
| fwd_bits | 8 | 激活的位宽度 |
| bn_bits | 32 | 跑步的位宽度和跑步vairance |
| Overflow_rate | 0.0 | 线性量化方法的溢流率阈值 |
| n_samples | 20 | 制作激活统计的样本数量 |