Dieses Repository enthält eine offizielle Pytorch -Implementierung für das folgende Papier
Lerne effiziente Faltungsnetzwerke durch Netzwerkschleidigung (ICCV 2017).
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, Changshui Zhang.
Originalimplementierung: in Taschenlampe abschneiden.
Der Code basiert auf Pytorch-Slimms. Wir fügen Unterstützung für Resnet und Densenet hinzu.
Zitat:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch v0.3.1, Torchvision v0.2.0
Wir führen channel selection ein, um das Beschneiden von ResNet und Densenet zu unterstützen. Diese Schicht ist einfach zu implementieren. Es speichert eine indexes , die in einen All-1-Vektor initialisiert wird. Während des Beschneidens werden einige Stellen auf 0 gesetzt, die den beschnittenen Kanälen entsprechen.
Das dataset -Argument gibt an, welcher Datensatz verwendet werden soll: cifar10 oder cifar100 . Das Argument arch gibt die zu verwendende Architektur an: vgg , resnet oder densenet . Die Tiefe wird so ausgewählt, dass sie die im Papier verwendeten Netzwerke sind.
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] Das beschnittene Modell heißt pruned.pth.tar .
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160Die Ergebnisse liegen ziemlich nahe am Originalpapier, dessen Ergebnisse durch Taschenlampe erzeugt werden. Beachten Sie, dass aufgrund verschiedener zufälliger Saatgut bis zu ~ 0,5%/1,5% Schwankungen bei CIFAR-10/100-Datensätzen in verschiedenen Läufen nach unseren Erfahrungen in verschiedenen Läufen ausmachen können.
| CIFAR10-VGG | Grundlinie | Sparsity (1E-4) | Prunk (70%) | Fein-Tune-160 (70%) |
|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 93.77 | 93.30 | 32.54 | 93.78 |
| Parameter | 20.04 m | 20.04 m | 2,25 m | 2,25 m |
| CIFAR10-RESNET-164 | Grundlinie | Sparsity (1E-5) | Pflaumen (40%) | Fein-Tune-160 (40%) | Pflaumen (60%) | Fein-Tune-160 (60%) |
|---|---|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| Parameter | 1,71 m | 1,73 m | 1,45 m | 1,45 m | 1,12 m | 1,12 m |
| CIFAR10-Densenet-40 | Grundlinie | Sparsity (1E-5) | Pflaumen (40%) | Fein-Tune-160 (40%) | Pflaumen (60%) | Fein-Tune-160 (60%) |
|---|---|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89,46 | 94.22 |
| Parameter | 1,07 m | 1,07 m | 0,69 m | 0,69 m | 0,49 m | 0,49 m |
| CIFAR100-VGG | Grundlinie | Sparsity (1E-4) | Prunk (50%) | Fein-Tune-160 (50%) |
|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 72.12 | 72.05 | 5.31 | 73,32 |
| Parameter | 20.04 m | 20.04 m | 4,93 m | 4,93 m |
| CIFAR100-RESNET-164 | Grundlinie | Sparsity (1E-5) | Pflaumen (40%) | Fein-Tune-160 (40%) | Pflaumen (60%) | Fein-Tune-160 (60%) |
|---|---|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 76,79 | 76,87 | 48.0 | 77,36 | --- | --- |
| Parameter | 1,73 m | 1,73 m | 1,49 m | 1,49 m | --- | --- |
Hinweis: Für die Ergebnisse des Beschneidens von 60% der Kanäle für ResNet164-CIFAR100 sind in dieser Implementierung manchmal einige Ebenen beschnitten und es würde Fehler geben. Wir bieten jedoch auch eine Maskenimplementierung, bei der wir eine Maske auf den Skalierungsfaktor in der BN -Schicht anwenden. Bei Mask Implementaion können wir beim Beschneiden von 60% der Kanäle in ResNet164-CIFAR100 auch das beschnittene Netzwerk trainieren.
| CIFAR100-Densenet-40 | Grundlinie | Sparsity (1E-5) | Pflaumen (40%) | Fein-Tune-160 (40%) | Pflaumen (60%) | Fein-Tune-160 (60%) |
|---|---|---|---|---|---|---|
| TOP1 -Genauigkeit (%) | 73,27 | 73,29 | 67,67 | 73,76 | 19.18 | 73.19 |
| Parameter | 1,10 m | 1,10 m | 0,71 m | 0,71 m | 0,50 m | 0,50 m |
Sunmj15 bei gmail.com liuzhuangthu unter gmail.com