このリポジトリには、次の論文の公式Pytorchの実装が含まれています
ネットワークスリミングを介した効率的な畳み込みネットワークの学習(ICCV 2017)。
Zhuang Liu、Jianguo Li、Zhiqiang Shen、Gao Huang、Shoumeng Yan、Changshui Zhang。
オリジナルの実装:トーチのスリミング。
このコードは、Pytorch-Slimmingに基づいています。 ResNetとDensenetのサポートを追加します。
引用:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch V0.3.1、Torchvision V0.2.0
channel selectionレイヤーを紹介して、ResNetとDensenetの剪定を支援します。このレイヤーは簡単に実装できます。 All-1ベクトルに初期化されたパラメーターindexesを保存します。剪定中、剪定されたチャネルに対応する場所にいくつかの場所を設定します。
datasetセット引数は、使用するデータセットを指定します: cifar10またはcifar100 。 arch引数は、使用するアーキテクチャを指定します: vgg 、 resnet 、またはdensenet 。深さは、論文で使用されているネットワークと同じであるように選択されます。
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]剪定されたモデルは、 pruned.pth.tarという名前です。
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160結果は元の論文にかなり近いもので、その結果はトーチによって作成されています。さまざまなランダムシードのため、私たちの経験によると、CIFAR-10/100のデータセットで最大0.5%/1.5%の変動がある可能性があることに注意してください。
| CIFAR10-VGG | ベースライン | スパース(1E-4) | プルーン(70%) | 微調整-160(70%) |
|---|---|---|---|---|
| TOP1精度(%) | 93.77 | 93.30 | 32.54 | 93.78 |
| パラメーター | 20.04m | 20.04m | 2.25m | 2.25m |
| CIFAR10-RESNET-164 | ベースライン | スパース(1E-5) | プルーン(40%) | 微調整-160(40%) | プルーン(60%) | 微調整-160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| パラメーター | 1.71m | 1.73m | 1.45m | 1.45m | 1.12m | 1.12m |
| CIFAR10-DENSENET-40 | ベースライン | スパース(1E-5) | プルーン(40%) | 微調整-160(40%) | プルーン(60%) | 微調整-160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| パラメーター | 1.07m | 1.07m | 0.69m | 0.69m | 0.49m | 0.49m |
| CIFAR100-VGG | ベースライン | スパース(1E-4) | プルーン(50%) | 微調整-160(50%) |
|---|---|---|---|---|
| TOP1精度(%) | 72.12 | 72.05 | 5.31 | 73.32 |
| パラメーター | 20.04m | 20.04m | 4.93m | 4.93m |
| CIFAR100-RESNET-164 | ベースライン | スパース(1E-5) | プルーン(40%) | 微調整-160(40%) | プルーン(60%) | 微調整-160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| パラメーター | 1.73m | 1.73m | 1.49m | 1.49m | --- | --- |
注:ResNet164-Cifar100のチャネルの60%を剪定する結果、この実装では、一部のレイヤーがすべて剪定され、エラーが発生する場合があります。ただし、BNレイヤーのスケーリング係数にマスクを適用するマスクの実装も提供します。 Mask実装の場合、ResNet164-Cifar100のチャネルの60%を剪定する場合、剪定されたネットワークをトレーニングすることもできます。
| CIFAR100-DENSENET-40 | ベースライン | スパース(1E-5) | プルーン(40%) | 微調整-160(40%) | プルーン(60%) | 微調整-160(60%) |
|---|---|---|---|---|---|---|
| TOP1精度(%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| パラメーター | 1.10m | 1.10m | 0.71m | 0.71m | 0.50m | 0.50m |
gmail.comのsunmj15 liuzhuangthu at gmail.com