Este é um playground para o Pytorch Beginners, que contém modelos predefinidos no conjunto de dados populares. Atualmente apoiamos
Aqui está um exemplo para o conjunto de dados MNIST. Isso baixará o conjunto de dados e o modelo pré-treinado automaticamente.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
Além disso, se quiser treinar o modelo MLP no MNIST, basta executar python mnist/train.py
python3 setup.py develop --user
Fornecemos um conjunto de dados de validação do ImageNet pré -computado com tamanho 224x224x3. Primeiro, redimensionamos o tamanho mais curto da imagem para 256 e depois cortamos a imagem 224x224 no centro. Em seguida, codificamos as imagens cortadas na corda JPG e no despejo para picarble.
cd scriptval224_compressed.pkl (Tsinghua / Google Drive)python convert.py (precisa de 48g de memória, obrigado @jnorwood)Também fornecemos uma demonstração simples para quantizar esses modelos para largura de bits especificada com vários métodos, incluindo método linear, método Minmax e método não linear.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
Avaliamos o desempenho do conjunto de dados populares e dos modelos com método quantizado linear. A largura de bits da média de execução e a variação em execução no BN são 10 bits para todos os resultados. (exceto 32-float)
| Modelo | 32-FLOAT | 12 bits | 10 bits | 8 bits | 6 bits |
|---|---|---|---|---|---|
| Mnist | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| Svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| Cifar10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| Cifar100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| Resnet18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| Resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| Resnet50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| Resnet101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| Resnet152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| SqueeZeNetv0 | 56,73/79.39 | 56,75/79.40 | 56,70/79.27 | 53.93/77.04 | 14.21/29.74 |
| SqueeZeNetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| EMCCONTROV3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
Nota: Os modelos ImageNet 32-Float são diretamente da Torchvision
Aqui damos uma visão geral dos argumentos selecionados de quantize.py
| Bandeira | Valor padrão | Descrição e opções |
|---|---|---|
| tipo | Cifar10 | mnist, svhn, cifar10, cifar100, stl10, alexnet, vgg16, vgg16_bn, vgg19, vgg19_bn, ressent18, ressent34, resmet50, resmnet101, resset152, squeezenet_v0, squeezenet5, squeezenT1, squeezenT1, squeezenT1, squienT1, venet152, ressent34, ressenT15, ressent152, ressent34, ressenT1, squeezenT1, squeezenT1, squeezen |
| quant_method | linear | Método de quantização: linear, minmax, log, Tanh |
| param_bits | 8 | largura de pesos e preconceitos |
| fwd_bits | 8 | largura de bit de ativação |
| bn_bits | 32 | largura de bits de corrida média e vairraça |
| Overflow_rate | 0,0 | limiar de taxa de transbordamento para o método de quantização linear |
| n_samples | 20 | Número de amostras para fazer estatísticas para ativação |