Implementasi Senet, yang diusulkan dalam jaringan pemerasan-dan-eksitasi oleh Jie Hu, Li Shen dan Gang Sun, yang merupakan pemenang Kompetisi Klasifikasi ILSVRC 2017.
Sekarang SE-RESNET (18, 34, 50, 101, 152/20, 32) dan SE-Inception-V3 diimplementasikan.
python cifar.py menjalankan SE-RESNET20 dengan dataset CIFAR10.
python imagenet.py dan python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} imagenet.py run se -resnet50 dengan imagenet (2012) dataset,
~/.torch/data atau mengatur variabel lingkungan IMAGENET_ROOT=${PATH_TO_YOUR_IMAGENET}CUDA_VISIBLE_DEVICES . (misalnya CUDA_VISIBLE_DEVICES=1,2 untuk menggunakan GPU 1 dan 2)Untuk SE-INCEPTION-V3, ukuran input harus 299x299 sebagai awal asli.
Basis kode diuji pada pengaturan berikut.
Untuk menjalankan cifar.py atau imagenet.py , Anda perlu
pip install git+https://github.com/moskomule/[email protected] Anda dapat menggunakan beberapa SE-RESNET ( se_resnet{20, 56, 50, 101} ) melalui torch.hub .
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet20' ,
num_classes = 10 )Juga, model SE-RESNET50 pretrained tersedia.
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet50' ,
pretrained = True ,) python cifar.py [--baseline]
Perhatikan bahwa dataset CIFAR-10 diharapkan berada di bawah ~/.torch/data .
| Resnet20 | SE-RESNET20 (Pengurangan 4 atau 8) | |
|---|---|---|
| Max. akurasi uji | 92% | 93% |
python [-m torch.distributed.launch --nproc_per_node=${NUM_GPUS}] imagenet.py
Opsi [-m ...] adalah untuk pelatihan terdistribusi. Perhatikan bahwa dataset ImageNet diharapkan berada di bawah ~/.torch/data atau ditentukan sebagai IMAGENET_ROOT=${PATH_TO_IMAGENET} .
Tingkat pembelajaran awal dan ukuran mini-batch berbeda dari versi asli karena sumber daya komputasi saya .
| Resnet | SE-RESNET | |
|---|---|---|
| Max. Akurasi Tes (Top1) | 76.15 %(*) | 77,06% (**) |
(*): Resnet-50 di obor
(**): Saat menggunakan imagenet.py dengan pengaturan --distributed pada 8 GPU. Beratnya tersedia.
# !wget https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl
senet = se_resnet50 ( num_classes = 1000 )
senet . load_state_dict ( torch . load ( "seresnet50-60a8950a85b2b.pkl" ))Saya tidak dapat mempertahankan repositori ini secara aktif, tetapi kontribusi apa pun dipersilakan. Jangan ragu untuk mengirim PR dan masalah.
kertas
Implementasi Caffe Penulis