Uma implementação da Senet, proposta em redes de aperto e excitação por Jie Hu, Li Shen e Gang Sun, que são os vencedores do concurso de classificação do ILSVRC 2017.
Agora, são implementados se-resnnet (18, 34, 50, 101, 152/20, 32) e SE-Inception-V3.
python cifar.py é executado se-resseNT20 com o conjunto de dados CIFAR10.
python imagenet.py e python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} imagenet.py execute se -resnnet50 com o conjunto de dados ImageNet (2012),
~/.torch/data ou definir uma variável enx da imagem IMAGENET_ROOT=${PATH_TO_YOUR_IMAGENET}CUDA_VISIBLE_DEVICES . (por exemplo, CUDA_VISIBLE_DEVICES=1,2 para usar a GPU 1 e 2)Para SE-CECPEIÇÃO-V3, o tamanho da entrada é necessário para ser 299x299 como o início original.
A base de código é testada na configuração a seguir.
Para executar cifar.py ou imagenet.py , você precisa
pip install git+https://github.com/moskomule/[email protected] Você pode usar algumas segundas se_resnet{20, 56, 50, 101} ) via torch.hub .
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet20' ,
num_classes = 10 )Além disso, um modelo SE-RESNET50 pré-treinamento está disponível.
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet50' ,
pretrained = True ,) python cifar.py [--baseline]
Observe que o conjunto de dados CIFAR-10 deve estar em ~/.torch/data .
| Resnet20 | SE-RESNET20 (Redução 4 ou 8) | |
|---|---|---|
| máx. precisão do teste | 92% | 93% |
python [-m torch.distributed.launch --nproc_per_node=${NUM_GPUS}] imagenet.py
A opção [-m ...] é para treinamento distribuído. Observe que o conjunto de dados ImageNet deve estar em ~/.torch/data ou especificado como IMAGENET_ROOT=${PATH_TO_IMAGENET} .
A taxa de aprendizado inicial e o tamanho do mini-lote são diferentes da versão original devido ao meu recurso computacional .
| Resnet | SE-RESNET | |
|---|---|---|
| máx. precisão do teste (TOP1) | 76,15 %(*) | 77,06% (**) |
(*): Resnet-50 na Torchvision
(**): Ao usar imagenet.py com a configuração --distributed em 8 GPUs. O peso está disponível.
# !wget https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl
senet = se_resnet50 ( num_classes = 1000 )
senet . load_state_dict ( torch . load ( "seresnet50-60a8950a85b2b.pkl" ))Não posso manter esse repositório ativamente, mas quaisquer contribuições são bem -vindas. Sinta -se à vontade para enviar PRs e problemas.
papel
Implementação de Caffe dos autores