Berlatih di CIFAR100 Menggunakan Pytorch
Ini adalah eksperimen saya menghindari
$ cd pytorch-cifar100Saya akan menggunakan dataset CIFAR100 dari TorchVision karena lebih nyaman, tetapi saya juga menyimpan kode sampel untuk menulis modul dataset Anda sendiri di folder dataset, sebagai contoh bagi orang tidak tahu cara menulisnya.
Instal Tensorboard
$ pip install tensorboard
$ mkdir runs
Run tensorboard
$ tensorboard --logdir= ' runs ' --port=6006 --host= ' localhost 'Anda perlu menentukan jaring yang ingin Anda latih menggunakan arg -net
# use gpu to train vgg16
$ python train.py -net vgg16 -gpu Kadang -kadang, Anda mungkin ingin menggunakan pelatihan pemanasan dengan mengatur -warm ke 1 atau 2, untuk mencegah jaringan yang berbeda selama fase pelatihan awal.
Arg Net yang didukung adalah:
squeezenet
mobilenet
mobilenetv2
shufflenet
shufflenetv2
vgg11
vgg13
vgg16
vgg19
densenet121
densenet161
densenet201
googlenet
inceptionv3
inceptionv4
inceptionresnetv2
xception
resnet18
resnet34
resnet50
resnet101
resnet152
preactresnet18
preactresnet34
preactresnet50
preactresnet101
preactresnet152
resnext50
resnext101
resnext152
attention56
attention92
seresnet18
seresnet34
seresnet50
seresnet101
seresnet152
nasnet
wideresnet
stochasticdepth18
stochasticdepth34
stochasticdepth50
stochasticdepth101
Biasanya, file bobot dengan akurasi terbaik akan ditulis ke disk dengan akhiran nama 'terbaik' (folder default di folder pos pemeriksaan).
Uji model menggunakan test.py
$ python test.py -net vgg16 -weights path_to_vgg16_weights_fileSaya tidak menggunakan trik pelatihan apa pun untuk meningkatkan Accuray, jika Anda ingin mempelajari lebih lanjut tentang trik pelatihan, silakan merujuk ke repo lain saya, berisi berbagai trik pelatihan umum dan implementasi Pytorch mereka.
Saya mengikuti pengaturan hyperparameter dalam kertas meningkatkan regularisasi jaringan saraf konvolusional dengan potongan, yang init lr = 0,1 membagi dengan 5 pada 60, 120, zaman ke-160, melatih untuk 200 zaman dengan kumpulan 128 dan berat badan 5E-4, momentum nesterov 0,9. Anda juga dapat menggunakan hyperparameter dari kertas yang mengatur jaringan saraf dengan menghukum distribusi output yang percaya diri dan augmentasi penghapusan data acak, yang merupakan LR awal = 0,1, LR divied dengan 10 pada zaman ke -150 dan ke -225, dan pelatihan untuk 300 zaman dengan Batchsize 128, ini lebih umum digunakan. Anda dapat mengurangi ukuran batch ke 64 atau apa pun yang cocok untuk Anda, jika Anda tidak memiliki cukup memori GPU.
Anda dapat memilih apakah akan menggunakan Tensorboard untuk memvisualisasikan prosedur pelatihan Anda
Hasil yang bisa saya dapatkan dari model tertentu, karena saya menggunakan hyperparameters yang sama untuk melatih semua jaringan, beberapa jaringan mungkin tidak mendapatkan hasil terbaik dari hyperparameters ini, Anda bisa mencoba diri Anda sendiri dengan finetuning hyperparameters untuk mendapatkan hasil yang lebih baik.
| dataset | jaringan | params | Top1 err | Top5 err | Epoch (LR = 0,1) | Epoch (LR = 0,02) | Epoch (LR = 0,004) | Epoch (LR = 0,0008) | Total zaman |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR100 | Mobilenet | 3.3m | 34.02 | 10.56 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | MobileNetv2 | 2.36m | 31.92 | 09.02 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Squeezenet | 0.78m | 30.59 | 8.36 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Shufflenet | 1.0m | 29.94 | 8.35 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | shufflenetv2 | 1.3m | 30.49 | 8.49 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | vgg11_bn | 28.5m | 31.36 | 11.85 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | VGG13_BN | 28.7m | 28.00 | 9.71 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | VGG16_BN | 34.0m | 27.07 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | VGG19_BN | 39.0m | 27.77 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet18 | 11.2m | 24.39 | 6.95 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet34 | 21.3m | 23.24 | 6.63 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet50 | 23.7m | 22.61 | 6.04 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet101 | 42.7m | 22.22 | 5.61 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnet152 | 58.3m | 22.31 | 5.81 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | preactresnet18 | 11.3m | 27.08 | 8.53 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Preactresnet34 | 21.5m | 24.79 | 7.68 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Preactresnet50 | 23.9m | 25.73 | 8.15 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | preactresnet101 | 42.9m | 24.84 | 7.83 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | preactresnet152 | 58.6m | 22.71 | 6.62 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Resnext50 | 14.8m | 22.23 | 6.00 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnext101 | 25.3m | 22.22 | 5.99 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | resnext152 | 33.3m | 22.40 | 5.58 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | perhatian59 | 55.7m | 33.75 | 12.90 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | perhatian92 | 102.5m | 36.52 | 11.47 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Densenet121 | 7.0m | 22.99 | 6.45 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Densenet161 | 26m | 21.56 | 6.04 | 60 | 60 | 60 | 40 | 200 |
| CIFAR100 | Densenet201 | 18m | 21.46 | 5.9 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Googlenet | 6.2m | 21.97 | 5.94 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | inceptionv3 | 22.3m | 22.81 | 6.39 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | inceptionv4 | 41.3m | 24.14 | 6.90 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | InceptionResNetv2 | 65.4m | 27.51 | 9.11 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | xception | 21.0m | 25.07 | 7.32 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet18 | 11.4m | 23.56 | 6.68 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet34 | 21.6m | 22.07 | 6.12 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet50 | 26.5m | 21.42 | 5.58 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet101 | 47.7m | 20.98 | 5.41 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Seresnet152 | 66.2m | 20.66 | 5.19 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | nasnet | 5.2m | 22.71 | 5.91 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | Wideresnet-40-10 | 55.9m | 21.25 | 5.77 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | stochasticdepth18 | 11.22m | 31.40 | 8.84 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | stochasticdepth34 | 21.36m | 27.72 | 7.32 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | StochasticDepth50 | 23.71m | 23.35 | 5.76 | 60 | 60 | 40 | 40 | 200 |
| CIFAR100 | stochasticdepth101 | 42.69m | 21.28 | 5.39 | 60 | 60 | 40 | 40 | 200 |