Este repositório contém uma implementação oficial do Pytorch para o seguinte artigo
Aprendendo redes convolucionais eficientes através do Slimming de Rede (ICCV 2017).
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan, Changshui Zhang.
Implementação original: Slimming in Torch.
O código é baseado em pytorch que lança. Adicionamos suporte para Resnet e Densenet.
Citação:
@InProceedings{Liu_2017_ICCV,
author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
title = {Learning Efficient Convolutional Networks Through Network Slimming},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
Torch v0.3.1, Torchvision v0.2.0
Introduzimos a camada channel selection para ajudar a poda do Resnet e Densenet. Essa camada é fácil de implementar. Ele armazena um indexes parâmetros que é inicializado em um vetor All-1. Durante a poda, ele definirá alguns lugares para 0 que correspondem aos canais podados.
O argumento dataset especifica qual conjunto de dados usar: cifar10 ou cifar100 . O argumento arch especifica a arquitetura a ser usada: vgg , resnet ou densenet . A profundidade é escolhida para ser a mesma das redes usadas no papel.
python main.py --dataset cifar10 --arch vgg --depth 19
python main.py --dataset cifar10 --arch resnet --depth 164
python main.py --dataset cifar10 --arch densenet --depth 40python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT]
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model [PATH TO THE MODEL] --save [DIRECTORY TO STORE RESULT] O modelo podado será nomeado pruned.pth.tar .
python main.py --refine [PATH TO THE PRUNED MODEL] --dataset cifar10 --arch vgg --depth 19 --epochs 160Os resultados estão bastante próximos do artigo original, cujos resultados são produzidos pela tocha. Observe que, devido a diferentes sementes aleatórias, pode haver uma flutuação de até 0,5%/1,5% nos conjuntos de dados CIFAR-10/100 em diferentes execuções, de acordo com nossas experiências.
| Cifar10-VGG | Linha de base | Sparsidade (1E-4) | Poda (70%) | Fine-tune-160 (70%) |
|---|---|---|---|---|
| TOP1 Precisão (%) | 93.77 | 93.30 | 32.54 | 93.78 |
| Parâmetros | 20.04M | 20.04M | 2,25m | 2,25m |
| CIFAR10-RESNET-164 | Linha de base | Sparsidade (1E-5) | Poda (40%) | Fine-tune-160 (40%) | Poda (60%) | Tune-160 fino (60%) |
|---|---|---|---|---|---|---|
| TOP1 Precisão (%) | 94.75 | 94.76 | 94.58 | 95.05 | 47.73 | 93.81 |
| Parâmetros | 1,71m | 1,73m | 1,45m | 1,45m | 1.12m | 1.12m |
| Cifar10-densenet-40 | Linha de base | Sparsidade (1E-5) | Poda (40%) | Fine-tune-160 (40%) | Poda (60%) | Tune-160 fino (60%) |
|---|---|---|---|---|---|---|
| TOP1 Precisão (%) | 94.11 | 94.17 | 94.16 | 94.32 | 89.46 | 94.22 |
| Parâmetros | 1,07m | 1,07m | 0,69m | 0,69m | 0,49m | 0,49m |
| CIFAR100-VGG | Linha de base | Sparsidade (1E-4) | Poda (50%) | Fine-tune-160 (50%) |
|---|---|---|---|---|
| TOP1 Precisão (%) | 72.12 | 72.05 | 5.31 | 73.32 |
| Parâmetros | 20.04M | 20.04M | 4,93m | 4,93m |
| CIFAR100-RESNET-164 | Linha de base | Sparsidade (1E-5) | Poda (40%) | Fine-tune-160 (40%) | Poda (60%) | Tune-160 fino (60%) |
|---|---|---|---|---|---|---|
| TOP1 Precisão (%) | 76.79 | 76.87 | 48.0 | 77.36 | ---- | ---- |
| Parâmetros | 1,73m | 1,73m | 1,49m | 1,49m | ---- | ---- |
NOTA: Para obter resultados de poda 60% dos canais para resnet164-cifar100, nesta implementação, às vezes algumas camadas são podadas e haveria erro. No entanto, também fornecemos uma implementação de máscara na qual aplicamos uma máscara ao fator de escala na camada BN. Para a implementação de máscara, ao podar 60% dos canais no Resnet164-CIFAR100, também podemos treinar a rede podada.
| CIFAR100-DENSENET-40 | Linha de base | Sparsidade (1E-5) | Poda (40%) | Fine-tune-160 (40%) | Poda (60%) | Tune-160 fino (60%) |
|---|---|---|---|---|---|---|
| TOP1 Precisão (%) | 73.27 | 73.29 | 67.67 | 73.76 | 19.18 | 73.19 |
| Parâmetros | 1.10m | 1.10m | 0,71m | 0,71m | 0,50m | 0,50m |
Sunmj15 em gmail.com Liuzhuangthu em gmail.com