network slimming
1.0.0
該存儲庫包含以下論文的官方Pytorch實施
通過網絡減少學習有效的捲積網絡(ICCV 2017)。
Zhuang Liu,Jianguo Li,Zhiqiang Shen,Gao Huang,Shoumeng Yan,Changshui Zhang。
原始實施:在火炬中減肥。
該代碼基於Pytorch-Smlimming。我們增加了對Resnet和Densenet的支持。
引用:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch v0.3.1,Torchvision V0.2.0
我們引入了channel selection層,以幫助修剪Resnet和Densenet。該層易於實現。它存儲一個參數indexes該參數索引初始化為All-1向量。在修剪過程中,它將設置一些與修剪通道相對應的位置。
dataset集參數指定要使用的數據集: cifar10或cifar100 。 arch參數指定要使用的體系結構: vgg , resnet或densenet 。選擇深度與論文中使用的網絡相同。
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]修剪模型將被命名為pruned.pth.tar 。
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160結果非常接近原始論文,其結果是由火炬產生的。請注意,根據我們的經驗,由於不同的隨機種子,CIFAR-10/100數據集的波動可能高達約0.5%/1.5%。
| CIFAR10-VGG | 基線 | 稀疏(1E-4) | 修剪(70%) | 微調160(70%) |
|---|---|---|---|---|
| TOP1精度(%) | 93.77 | 93.30 | 32.54 | 93.78 |
| 參數 | 20.04m | 20.04m | 225萬 | 225萬 |
| CIFAR10-RESNET-164 | 基線 | 稀疏(1E-5) | 修剪(40%) | 微調160(40%) | 修剪(60%) | 微調160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| 參數 | 171m | 173萬 | 145m | 145m | 11.12m | 11.12m |
| CIFAR10-DENSENET-40 | 基線 | 稀疏(1E-5) | 修剪(40%) | 微調160(40%) | 修剪(60%) | 微調160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| 參數 | 1.07m | 1.07m | 0.69m | 0.69m | 4.49m | 4.49m |
| cifar100-vgg | 基線 | 稀疏(1E-4) | 修剪(50%) | 微調160(50%) |
|---|---|---|---|---|
| TOP1精度(%) | 72.12 | 72.05 | 5.31 | 73.32 |
| 參數 | 20.04m | 20.04m | 4.93m | 4.93m |
| CIFAR100-RESNET-164 | 基線 | 稀疏(1E-5) | 修剪(40%) | 微調160(40%) | 修剪(60%) | 微調160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| 參數 | 173萬 | 173萬 | 149m | 149m | --- | --- |
注意:對於將60%的RESNET164-CIFAR100通道修剪的結果,在此實現中,有時有些層都被修剪,並且會出現錯誤。但是,我們還提供了掩模實現,在其中將掩碼應用於BN層中的縮放係數。對於Mask enimallaion,在Resnet164-Cifar100中修剪60%的頻道時,我們還可以訓練修剪的網絡。
| CIFAR100-DENSENET-40 | 基線 | 稀疏(1E-5) | 修剪(40%) | 微調160(40%) | 修剪(60%) | 微調160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| 參數 | 1.10m | 1.10m | 0.71m | 0.71m | 0.50m | 0.50m |
sunmj15在gmail.com上liuzhuangthu at gmail.com