Una implementación de Senet, propuesta en redes Squeeze and-Excitation por Jie Hu, Li Shen y Gang Sun, quienes son los ganadores de la competencia de clasificación ILSVRC 2017.
Ahora se implementan SE-Resnet (18, 34, 50, 101, 152/20, 32) y SE-Inception-V3.
python cifar.py ejecuta SE-Resnet20 con el conjunto de datos CIFAR10.
python imagenet.py y python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} imagenet.py ejecutar se -resnet50 con el conjunto de datos de ImageNet (2012),
~/.torch/data o establecer una variable de entorno IMAGENET_ROOT=${PATH_TO_YOUR_IMAGENET}CUDA_VISIBLE_DEVICES Variable. (por ejemplo, CUDA_VISIBLE_DEVICES=1,2 para usar GPU 1 y 2)Para SE-Inception-V3, se requiere que el tamaño de entrada sea 299x299 como el inicio original.
La base de código se prueba en la siguiente configuración.
Para ejecutar cifar.py o imagenet.py , necesita
pip install git+https://github.com/moskomule/[email protected] Puede usar algunos SE-Resnet ( se_resnet{20, 56, 50, 101} ) a través de torch.hub .
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet20' ,
num_classes = 10 )Además, hay un modelo SE-Resnet50 previamente previamente disponible.
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet50' ,
pretrained = True ,) python cifar.py [--baseline]
Tenga en cuenta que el conjunto de datos CIFAR-10 esperaba estar bajo ~/.torch/data .
| Resnet20 | SE-Resnet20 (reducción 4 u 8) | |
|---|---|---|
| Max. precisión de la prueba | 92% | 93% |
python [-m torch.distributed.launch --nproc_per_node=${NUM_GPUS}] imagenet.py
La opción [-m ...] es para capacitación distribuida. Tenga en cuenta que se espera que el conjunto de datos ImageNet esté en ~/.torch/data o especificado como IMAGENET_ROOT=${PATH_TO_IMAGENET} .
La tasa de aprendizaje inicial y el tamaño de mini lotes son diferentes de la versión original debido a mi recurso computacional .
| Resnet | Reinita SE | |
|---|---|---|
| Max. Prueba de prueba (Top1) | 76.15 %(*) | 77.06% (**) |
(*): Resnet-50 en atorchvision
(**): cuando se usa imagenet.py con la configuración --distributed en 8 GPU. El peso está disponible.
# !wget https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl
senet = se_resnet50 ( num_classes = 1000 )
senet . load_state_dict ( torch . load ( "seresnet50-60a8950a85b2b.pkl" ))No puedo mantener este repositorio activamente, pero cualquier contribución es bienvenida. No dude en enviar PRS y problemas.
papel
Implementación de cafe de autores