Este repositorio contiene una implementación oficial de Pytorch para el siguiente documento
Aprender redes convolucionales eficientes a través de la adelgazamiento de la red (ICCV 2017).
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, Changshui Zhang.
Implementación original: adelgazamiento en la antorcha.
El código se basa en la bola de pytorch. Agregamos soporte para resnet y densenet.
Citación:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch V0.3.1, TorchVision v0.2.0
Introducimos la capa channel selection para ayudar a la poda de resnet y densenet. Esta capa es fácil de implementar. Almacena los indexes parámetros que se inicializan en un vector all-1. Durante la poda, establecerá algunos lugares a 0 que correspondan a los canales podados.
El argumento dataset especifica qué conjunto de datos usar: cifar10 o cifar100 . El argumento arch especifica la arquitectura que usará: vgg , resnet o densenet . La profundidad se elige para ser la misma que las redes utilizadas en el papel.
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] El modelo podado se nombrará pruned.pth.tar .
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160Los resultados están bastante cerca del artículo original, cuyos resultados son producidos por la antorcha. Tenga en cuenta que debido a diferentes semillas aleatorias, puede haber hasta ~ 0.5%/1.5% de fluctuación en conjuntos de datos CIFAR-10/100 en diferentes ejecuciones, según nuestras experiencias.
| Cifar10-vgg | Base | Escasez (1e-4) | Prune (70%) | FIN-TUNE-160 (70%) |
|---|---|---|---|---|
| Top1 precisión (%) | 93.77 | 93.30 | 32.54 | 93.78 |
| Parámetros | 20.04m | 20.04m | 2.25m | 2.25m |
| CIFAR10-RESNET-164 | Base | Escasez (1e-5) | Poda (40%) | Fina-160 (40%) | Poda (60%) | Fina-160 (60%) |
|---|---|---|---|---|---|---|
| Top1 precisión (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| Parámetros | 1.71m | 1.73m | 1.45m | 1.45m | 1.12m | 1.12m |
| Cifar10-densenet-40 | Base | Escasez (1e-5) | Poda (40%) | Fina-160 (40%) | Poda (60%) | Fina-160 (60%) |
|---|---|---|---|---|---|---|
| Top1 precisión (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| Parámetros | 1.07m | 1.07m | 0.69m | 0.69m | 0.49m | 0.49m |
| Cifar100-vgg | Base | Escasez (1e-4) | Poda (50%) | Fina-160 (50%) |
|---|---|---|---|---|
| Top1 precisión (%) | 72.12 | 72.05 | 5.31 | 73.32 |
| Parámetros | 20.04m | 20.04m | 4.93m | 4.93m |
| CIFAR100-RESNET-164 | Base | Escasez (1e-5) | Poda (40%) | Fina-160 (40%) | Poda (60%) | Fina-160 (60%) |
|---|---|---|---|---|---|---|
| Top1 precisión (%) | 76.79 | 76.87 | 48.0 | 77.36 | --- | --- |
| Parámetros | 1.73m | 1.73m | 1.49m | 1.49m | --- | --- |
Nota: Para los resultados de la poda del 60% de los canales para ResNet164-CIFAR100, en esta implementación, a veces algunas capas están podadas y habría error. Sin embargo, también proporcionamos una implementación de máscara donde aplicamos una máscara al factor de escala en la capa BN. Para la implementación de máscara, al podar el 60% de los canales en Resnet164-CIFAR100, también podemos entrenar la red podada.
| Cifar100-densenet-40 | Base | Escasez (1e-5) | Poda (40%) | Fina-160 (40%) | Poda (60%) | Fina-160 (60%) |
|---|---|---|---|---|---|---|
| Top1 precisión (%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| Parámetros | 1.10m | 1.10m | 0.71m | 0.71m | 0.50m | 0.50m |
SunMJ15 en Gmail.com Liuzhuangthu en Gmail.com