ฝึกฝน CIFAR100 โดยใช้ pytorch
นี่คือการทดลองของฉัน eviroument
$ cd pytorch-cifar100ฉันจะใช้ชุดข้อมูล CIFAR100 จาก Torchvision เนื่องจากสะดวกกว่า แต่ฉันยังเก็บรหัสตัวอย่างไว้สำหรับการเขียนโมดูลชุดข้อมูลของคุณเองในโฟลเดอร์ชุดข้อมูลเป็นตัวอย่างสำหรับคนที่ไม่รู้วิธีเขียน
ติดตั้ง Tensorboard
$ pip install tensorboard
$ mkdir runs
Run tensorboard
$ tensorboard --logdir= ' runs ' --port=6006 --host= ' localhost 'คุณต้องระบุเน็ตที่คุณต้องการฝึกอบรมโดยใช้ arg -net
# use gpu to train vgg16
$ python train.py -net vgg16 -gpu บางครั้งคุณอาจต้องการใช้การฝึกอบรมการอุ่นเครื่องโดยตั้ง -warm ที่ 1 หรือ 2 เพื่อป้องกันไม่ให้เครือข่ายแตกต่างกันในช่วงระยะเวลาการฝึกอบรม
net args ที่รองรับคือ:
squeezenet
mobilenet
mobilenetv2
shufflenet
shufflenetv2
vgg11
vgg13
vgg16
vgg19
densenet121
densenet161
densenet201
googlenet
inceptionv3
inceptionv4
inceptionresnetv2
xception
resnet18
resnet34
resnet50
resnet101
resnet152
preactresnet18
preactresnet34
preactresnet50
preactresnet101
preactresnet152
resnext50
resnext101
resnext152
attention56
attention92
seresnet18
seresnet34
seresnet50
seresnet101
seresnet152
nasnet
wideresnet
stochasticdepth18
stochasticdepth34
stochasticdepth50
stochasticdepth101
โดยปกติแล้วไฟล์น้ำหนักที่มีความแม่นยำที่ดีที่สุดจะถูกเขียนลงในดิสก์ด้วยชื่อคำต่อท้าย 'ดีที่สุด' (ค่าเริ่มต้นในโฟลเดอร์จุดตรวจสอบ)
ทดสอบแบบจำลองโดยใช้ test.py
$ python test.py -net vgg16 -weights path_to_vgg16_weights_fileฉันไม่ได้ใช้กลเม็ดการฝึกอบรมใด ๆ เพื่อปรับปรุง Accuray หากคุณต้องการเรียนรู้เพิ่มเติมเกี่ยวกับเทคนิคการฝึกอบรมโปรดดูที่ repo อื่นของฉันมีเทคนิคการฝึกอบรมทั่วไปและการใช้งาน pytorch ของพวกเขา
ฉันติดตามการตั้งค่า hyperparameter ในกระดาษปรับปรุงการทำให้เป็นมาตรฐานของเครือข่ายประสาท convolutional ด้วย cutout ซึ่งเป็น init lr = 0.1 หารด้วย 5 ที่ 60, 120, 160th ages, รถไฟสำหรับ 200 ยุคที่มี batchsize 128 และ decay 5e-4, nesterov momentum 0.9 นอกจากนี้คุณยังสามารถใช้พารามิเตอร์ hyperparameters จากการทำให้เป็นมาตรฐานของเครือข่ายประสาทโดยการลงโทษการแจกแจงเอาท์พุทที่มั่นใจและการเพิ่มการลบข้อมูลแบบสุ่มซึ่งเป็น LR เริ่มต้น = 0.1, LR หารด้วย 10 ตอนที่ 150 และ 225 และการฝึกอบรมสำหรับ 300 Epochs คุณสามารถลด batchsize เป็น 64 หรืออะไรก็ตามที่เหมาะสมกับคุณถ้าคุณไม่มีหน่วยความจำ GPU เพียงพอ
คุณสามารถเลือกได้ว่าจะใช้ Tensorboard เพื่อให้เห็นภาพขั้นตอนการฝึกอบรมของคุณ
ผลลัพธ์ที่ฉันได้รับจากรุ่นที่แน่นอนเนื่องจากฉันใช้พารามิเตอร์ไฮเปอร์พารามิเตอร์เดียวกันเพื่อฝึกอบรมเครือข่ายทั้งหมดเครือข่ายบางแห่งอาจไม่ได้รับผลลัพธ์ที่ดีที่สุดจากพารามิเตอร์ไฮเปอร์พารามิเตอร์เหล่านี้คุณสามารถลองด้วยตัวเองด้วยการปรับพารามิเตอร์ hyperparameters เพื่อให้ได้ผลลัพธ์ที่ดีขึ้น
| ชุดข้อมูล | เครือข่าย | พารามิเตอร์ | top1 err | top5 err | ยุค (LR = 0.1) | ยุค (LR = 0.02) | ยุค (LR = 0.004) | ยุค (LR = 0.0008) | ยุคทั้งหมด |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR100 | Mobilenet | 3.3m | 34.02 | 10.56 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | mobilenetv2 | 2.36m | 31.92 | 09.02 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | บอบบาง | 0.78m | 30.59 | 8.36 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | สลับ | 1.0m | 29.94 | 8.35 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Shufflenetv2 | 1.3m | 30.49 | 8.49 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | vgg11_bn | 28.5m | 31.36 | 11.85 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | vgg13_bn | 28.7m | 28.00 | 9.71 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | vgg16_bn | 34.0m | 27.07 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | vgg19_bn | 39.0m | 27.77 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet18 | 11.2m | 24.39 | 6.95 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet34 | 21.3m | 23.24 | 6.63 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet50 | 23.7m | 22.61 | 6.04 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet101 | 42.7m | 22.22 | 5.61 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Resnet152 | 58.3m | 22.31 | 5.81 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | PreactresNet18 | 11.3m | 27.08 | 8.53 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | preactresNet34 | 21.5m | 24.79 | 7.68 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | preactresNet50 | 23.9m | 25.73 | 8.15 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | PreactresNet101 | 42.9m | 24.84 | 7.83 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | PreactresNet152 | 58.6m | 22.71 | 6.62 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnext50 | 14.8m | 22.23 | 6.00 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnext101 | 25.3m | 22.22 | 5.99 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnext152 | 33.3m | 22.40 | 5.58 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | ความสนใจ 59 | 55.7m | 33.75 | 12.90 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | ความสนใจ 92 | 102.5m | 36.52 | 11.47 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Densenet121 | 7.0m | 22.99 | 6.45 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Densenet161 | 26m | 21.56 | 6.04 | 60 | 60 | 60 | 40 | 200 |
| CIFAR100 | Densenet201 | 18m | 21.46 | 5.9 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Googlenet | 6.2m | 21.97 | 5.94 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | inceptionv3 | 22.3m | 22.81 | 6.39 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | inceptionv4 | 41.3m | 24.14 | 6.90 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | InceptionResNetv2 | 65.4m | 27.51 | 9.11 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Xception | 21.0m | 25.07 | 7.32 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet18 | 11.4m | 23.56 | 6.68 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | seresnet34 | 21.6m | 22.07 | 6.12 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | seresnet50 | 26.5m | 21.42 | 5.58 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | seresnet101 | 47.7m | 20.98 | 5.41 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | SeresNet152 | 66.2m | 20.66 | 5.19 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | nasnet | 5.2m | 22.71 | 5.91 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | WideresNet-40-10 | 55.9m | 21.25 | 5.77 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | StochasticDepth18 | 11.22m | 31.40 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | StochasticDepth34 | 21.36m | 27.72 | 7.32 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | StochasticDepth50 | 23.71m | 23.35 | 5.76 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | StochasticDepth101 | 42.69m | 21.28 | 5.39 | 60 | 60 | 40 | 40 | 200 |