Este es un patio de recreo para los principiantes de Pytorch, que contiene modelos predefinidos en el conjunto de datos popular. Actualmente apoyamos
Aquí hay un ejemplo para el conjunto de datos MNIST. Esto descargará el modelo de datos y el modelo previamente capacitado automáticamente.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
Además, si desea entrenar el modelo MLP en MNIST, simplemente ejecute python mnist/train.py
python3 setup.py develop --user
Proporcionamos un conjunto de datos de validación de Imagenet precomputado con un tamaño 224x224x3. Primero cambia el tamaño del tamaño más corto de la imagen a 256, luego recortamos la imagen 224x224 en el centro. Luego codificamos las imágenes recortadas en la cadena JPG y volcamos en Pickle.
cd scriptval224_compressed.pkl (Tsinghua / Google Drive)python convert.py (necesita memoria 48G, gracias @jnorwood)También proporcionamos una demostración simple para cuantificar estos modelos al ancho de bits especificado con varios métodos, incluido el método lineal, el método Minmax y el método no lineal.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
Evaluamos el rendimiento del conjunto de datos y los modelos populares con método cuantificado lineal. El ancho de bit de la media de ejecución y la varianza en ejecución en BN son 10 bits para todos los resultados. (Excepto por 32 pisos)
| Modelo | 32 pisos | De 12 bits | De 10 bits | De 8 bits | De 6 bits |
|---|---|---|---|---|---|
| Mnista | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| Svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| Cifar10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| Cifar100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| Stl10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| VGG16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| VGG19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| Resnet18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| Resnet34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| Resnet50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| Resnet101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| Resnet152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| Squeezenetv0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| Inceptionv3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
Nota: Los modelos de 32 flotos de Imagenet son directamente de TorchVision
Aquí damos una visión general de los argumentos seleccionados de quantize.py
| Bandera | Valor predeterminado | Descripción y opciones |
|---|---|---|
| tipo | cifar10 | Mnist, SVHN, CIFAR10, CIFAR100, STL10, Alexnet, VGG16, VGG16_BN, VGG19, VGG19_BN, Resent18, Resent34, ResNet50, ResNet101, ResNet152, Sideezenet_V0, Speezenet_V1, Inception_V3 |
| cuant_method | lineal | Método de cuantización: Lineal, Minmax, Log, Tanh |
| param_bits | 8 | ancho de bit de pesas y sesgo |
| FWD_BITS | 8 | ancho de bits de la activación |
| bn_bits | 32 | Bit-width de la vanguardia en funcionamiento |
| desbordamiento | 0.0 | Umbral de velocidad de desbordamiento para el método de cuantización lineal |
| n_samples | 20 | Número de muestras para hacer estadísticas para la activación |