การดำเนินการของ Senet ที่เสนอใน เครือข่ายการบีบและการกระตุ้น โดย Jie Hu, Li Shen และ Gang Sun ซึ่งเป็นผู้ชนะของการแข่งขันการจำแนกประเภท ILSVRC 2017
ตอนนี้ SE-RESNET (18, 34, 50, 101, 152/20, 32) และ SE-Inception-V3 ถูกนำมาใช้
python cifar.py เรียกใช้ SE-RESNET20 พร้อมชุดข้อมูล CIFAR10
python imagenet.py และ python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} imagenet.py run se -resnet50 พร้อม Imagenet (2012) ชุดข้อมูล
~/.torch/data หรือตั้งค่าตัวแปรสภาพแวดล้อม IMAGENET_ROOT=${PATH_TO_YOUR_IMAGENET}CUDA_VISIBLE_DEVICES (เช่น CUDA_VISIBLE_DEVICES=1,2 เพื่อใช้ GPU 1 และ 2)สำหรับ SE-Inception-V3 จำเป็นต้องมีขนาดอินพุตเป็น 299x299 เป็นจุดเริ่มต้นดั้งเดิม
codebase ถูกทดสอบในการตั้งค่าต่อไปนี้
ในการเรียกใช้ cifar.py หรือ imagenet.py คุณต้องการ
pip install git+https://github.com/moskomule/[email protected] คุณสามารถใช้ SE-RESNET ( se_resnet{20, 56, 50, 101} ) ผ่าน torch.hub
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet20' ,
num_classes = 10 )นอกจากนี้ยังมีรุ่น SE-RESNET50 ที่ได้รับการฝึกฝนไว้
import torch . hub
hub_model = torch . hub . load (
'moskomule/senet.pytorch' ,
'se_resnet50' ,
pretrained = True ,) python cifar.py [--baseline]
โปรดทราบว่าชุดข้อมูล CIFAR-10 คาดว่าจะอยู่ภายใต้ ~/.torch/data
| resnet20 | SE-RESNET20 (ลด 4 หรือ 8) | |
|---|---|---|
| สูงสุด ทดสอบความแม่นยำ | 92% | 93% |
python [-m torch.distributed.launch --nproc_per_node=${NUM_GPUS}] imagenet.py
ตัวเลือก [-m ... ] ใช้สำหรับการฝึกอบรมแบบกระจาย โปรดทราบว่าชุดข้อมูล Imagenet คาดว่าจะอยู่ภายใต้ ~/.torch/data หรือระบุเป็น IMAGENET_ROOT=${PATH_TO_IMAGENET}
อัตราการเรียนรู้เริ่มต้นและขนาดมินิแบทช์นั้นแตกต่างจากรุ่นดั้งเดิมเนื่องจากทรัพยากรการคำนวณของฉัน
| resnet | Se-Resnet | |
|---|---|---|
| สูงสุด ความแม่นยำในการทดสอบ (Top1) | 76.15 %(*) | 77.06% (**) |
(*): resnet-50 ใน Torchvision
(**): เมื่อใช้ imagenet.py กับการตั้งค่า --distributed ใน 8 GPU น้ำหนักมีอยู่
# !wget https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl
senet = se_resnet50 ( num_classes = 1000 )
senet . load_state_dict ( torch . load ( "seresnet50-60a8950a85b2b.pkl" ))ฉันไม่สามารถรักษาพื้นที่เก็บข้อมูลนี้ได้อย่างแข็งขัน แต่ยินดีต้อนรับการมีส่วนร่วมใด ๆ อย่าลังเลที่จะส่ง PRS และปัญหา
กระดาษ
การใช้งานคาเฟอีนของผู้เขียน