pytorch playground
1.0.0
이것은 Pytorch 초보자를위한 놀이터로, 인기있는 데이터 세트에 사전 정의 된 모델이 포함되어 있습니다. 현재 우리는 지원합니다
다음은 MNIST 데이터 세트의 예입니다. 이것은 데이터 세트와 미리 훈련 된 모델을 자동으로 다운로드합니다.
import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
data = Variable(torch.FloatTensor(data)).cuda()
output = model_raw(data)
또한 MNIST에서 MLP 모델을 훈련 시키려면 python mnist/train.py 실행하십시오.
python3 setup.py develop --user
우리는 224x224x3 크기의 사전 계산 된 ImageNet Validation Dataset를 제공합니다. 먼저 이미지의 짧은 크기를 256으로 크기를 조정 한 다음 중앙에서 224x224 이미지를 자릅니다. 그런 다음 자른 이미지를 JPG 문자열로 인코딩하고 덤프하여 피클로 덤프합니다.
cd scriptval224_compressed.pkl (Tsinghua / Google Drive) 다운로드python convert.py (48g 메모리 필요, 감사합니다 @jnorwood)또한 선형 방법, Minmax 방법 및 비선형 방법을 포함한 여러 방법으로 이러한 모델을 지정된 비트 폭으로 정량화하는 간단한 데모를 제공합니다.
quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
선형 양자화 된 방법으로 인기있는 데이터 세트 및 모델의 성능을 평가합니다. BN에서 실행 평균 및 실행 분산의 비트 폭은 모든 결과에 대해 10 비트입니다. (32 플로트 제외)
| 모델 | 32 플로트 | 12 비트 | 10 비트 | 8 비트 | 6 비트 |
|---|---|---|---|---|---|
| mnist | 98.42 | 98.43 | 98.44 | 98.44 | 98.32 |
| svhn | 96.03 | 96.03 | 96.04 | 96.02 | 95.46 |
| cifar10 | 93.78 | 93.79 | 93.80 | 93.58 | 90.86 |
| cifar100 | 74.27 | 74.21 | 74.19 | 73.70 | 66.32 |
| STL10 | 77.59 | 77.65 | 77.70 | 77.59 | 73.40 |
| Alexnet | 55.70/78.42 | 55.66/78.41 | 55.54/78.39 | 54.17/77.29 | 18.19/36.25 |
| vgg16 | 70.44/89.43 | 70.45/89.43 | 70.44/89.33 | 69.99/89.17 | 53.33/76.32 |
| vgg19 | 71.36/89.94 | 71.35/89.93 | 71.34/89.88 | 70.88/89.62 | 56.00/78.62 |
| RESNET18 | 68.63/88.31 | 68.62/88.33 | 68.49/88.25 | 66.80/87.20 | 19.14/36.49 |
| RESNET34 | 72.50/90.86 | 72.46/90.82 | 72.45/90.85 | 71.47/90.00 | 32.25/55.71 |
| RESNET50 | 74.98/92.17 | 74.94/92.12 | 74.91/92.09 | 72.54/90.44 | 2.43/5.36 |
| RESNET101 | 76.69/93.30 | 76.66/93.25 | 76.22/92.90 | 65.69/79.54 | 1.41/1.18 |
| RESNET152 | 77.55/93.59 | 77.51/93.62 | 77.40/93.54 | 74.95/92.46 | 9.29/16.75 |
| Squeezenetv0 | 56.73/79.39 | 56.75/79.40 | 56.70/79.27 | 53.93/77.04 | 14.21/29.74 |
| Squeezenetv1 | 56.52/79.13 | 56.52/79.15 | 56.24/79.03 | 54.56/77.33 | 17.10/32.46 |
| inceptionv3 | 76.41/92.78 | 76.43/92.71 | 76.44/92.73 | 73.67/91.34 | 1.50/4.82 |
참고 : ImageNet 32 플로트 모델은 Torchvision에서 직접 있습니다
여기서 우리는 quantize.py 의 선택된 인수에 대한 개요를 제공합니다.
| 깃발 | 기본값 | 설명 및 옵션 |
|---|---|---|
| 유형 | cifar10 | MNIST, SVHN, CIFAR10, CIFAR100, STL10, Alexnet, VGG16, VGG16_BN, VGG19, VGG19_BN, FENENT18, FENENT18, RESNET50, RESNET101, RESNET152, SQUEEZENET_V0, SQUEEZENET_V1, Incence_v3 |
| Quant_Method | 선의 | 양자화 방법 : 선형, Minmax, log, tanh |
| param_bits | 8 | 비트의 무게와 편견 |
| fwd_bits | 8 | 활성화의 비트 폭 |
| bn_bits | 32 | 실행중인 비트 폭과 실행 vairance |
| Overflow_rate | 0.0 | 선형 양자화 방법에 대한 오버 플로우 속도 임계 값 |
| n_samples | 20 | 활성화 통계를 만들기위한 샘플 수 |