이 저장소에는 다음 논문의 공식 Pytorch 구현이 포함되어 있습니다.
네트워크 슬리밍 (ICCV 2017)을 통한 효율적인 컨볼 루션 네트워크 학습.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, Changshui Zhang.
원래 구현 : 횃불에서 슬리밍.
코드는 Pytorch-Slimming을 기반으로합니다. RESNET 및 Densenet에 대한 지원이 추가됩니다.
소환:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch V0.3.1, Torchvision V0.2.0
RESNET 및 Densenet의 가지 치기를 돕기 위해 channel selection 계층을 소개합니다. 이 레이어는 구현하기 쉽습니다. All-1 벡터에 초기화되는 매개 변수 indexes 저장합니다. 가지 치기 중에는 일부 장소를 가지 치기 채널에 해당하는 일부 장소를 0으로 설정합니다.
dataset 인수는 사용할 데이터 세트를 지정합니다 : cifar10 또는 cifar100 . arch 인수는 vgg , resnet 또는 densenet 사용할 아키텍처를 지정합니다. 깊이는 용지에 사용 된 네트워크와 동일하도록 선택됩니다.
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] 가지 치기 모델의 이름은 pruned.pth.tar 입니다.
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160결과는 원래 용지에 상당히 가깝습니다. 그 결과는 Torch에 의해 생성됩니다. 우리의 경험에 따르면 다른 무작위 종자로 인해 CIFAR-10/100 데이터 세트에 대해 최대 ~ 0.5%/1.5% 변동이있을 수 있습니다.
| cifar10-vgg | 기준선 | 희소성 (1E-4) | 자두 (70%) | 미세 조정 -160 (70%) |
|---|---|---|---|---|
| Top1 정확도 (%) | 93.77 | 93.30 | 32.54 | 93.78 |
| 매개 변수 | 20.04m | 20.04m | 2.25m | 2.25m |
| Cifar10-Resnet-164 | 기준선 | 희소성 (1E-5) | 자두 (40%) | 미세 조정 -160 (40%) | 자두 (60%) | 미세 조정 -160 (60%) |
|---|---|---|---|---|---|---|
| Top1 정확도 (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| 매개 변수 | 1.71m | 1.73m | 1.45m | 1.45m | 1.12m | 1.12m |
| Cifar10-densenet-40 | 기준선 | 희소성 (1E-5) | 자두 (40%) | 미세 조정 -160 (40%) | 자두 (60%) | 미세 조정 -160 (60%) |
|---|---|---|---|---|---|---|
| Top1 정확도 (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| 매개 변수 | 1.07m | 1.07m | 0.69m | 0.69m | 0.49m | 0.49m |
| CIFAR100-VGG | 기준선 | 희소성 (1E-4) | 자두 (50%) | 미세 조정 -160 (50%) |
|---|---|---|---|---|
| Top1 정확도 (%) | 72.12 | 72.05 | 5.31 | 73.32 |
| 매개 변수 | 20.04m | 20.04m | 4.93m | 4.93m |
| CIFAR100-RESNET-164 | 기준선 | 희소성 (1E-5) | 자두 (40%) | 미세 조정 -160 (40%) | 자두 (60%) | 미세 조정 -160 (60%) |
|---|---|---|---|---|---|---|
| Top1 정확도 (%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| 매개 변수 | 1.73m | 1.73m | 1.49m | 1.49m | --- | --- |
참고 : RESNET164-CIFAR100에 대한 채널의 60%를 가지 치기 결과의 경우이 구현에서는 일부 레이어가 모두 정리되어 오류가 발생합니다. 그러나 우리는 BN 층의 스케일링 계수에 마스크를 적용하는 마스크 구현도 제공합니다. 마스크 구현의 경우, RESNET164-CIFAR100에서 채널의 60%를 가지 치기 할 때 가지 치기 네트워크를 훈련시킬 수도 있습니다.
| Cifar100-densenet-40 | 기준선 | 희소성 (1E-5) | 자두 (40%) | 미세 조정 -160 (40%) | 자두 (60%) | 미세 조정 -160 (60%) |
|---|---|---|---|---|---|---|
| Top1 정확도 (%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| 매개 변수 | 1.10m | 1.10m | 0.71m | 0.71m | 0.50m | 0.50m |
gmail.com의 sunmj15 gmail.com의 liuzhuangthu