knowledge distillation pytorch
1.0.0
レポをクローンします
git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
依存関係をインストールする(pytorchを含む)
pip install -r requirements.txt
注:すべてのハイパーパラメーターは、「model_dir」の下の「params.json」で見つけることができます。
- 事前に訓練されたResNet-18モデルから蒸留された知識を備えた5層CNNを訓練する
python train.py --model_dir experiments/cnn_distill
- 事前に訓練されたresnext-29教師から蒸留された知識を備えたResnet-18モデルをトレーニングする
python train.py --model_dir experiments/resnet18_distill/resnext_teacher
- 指定された実験を検索するハイパーパラメーター( 'parent_dir/params.json')
python search_hyperparams.py --parent_dir experiments/cnn_distill_alpha_temp
- 最近のHyperSearch実験の結果
python synthesize_results.py --parent_dir experiments/cnn_distill_alpha_temp
クイックテイクアウト(追加する詳細):
-ResNet-18から5層CNNへの知識の蒸留
| モデル | ドロップアウト= 0.5 | ドロップアウトはありません |
|---|---|---|
| 5層CNN | 83.51% | 84.74% |
| 5層CNN w/ resnet18 | 84.49% | 85.69% |
-より深いモデルからResnet-18への知識の蒸留
| モデル | テスト精度 |
|---|---|
| ベースラインResNet-18 | 94.175% |
| + kd wideresnet-28-10 | 94.333% |
| + kd preresnet-110 | 94.531% |
| + kd densenet-100 | 94.729% |
| + kd resnext-29-8 | 94.788% |
H. Li、「効率的なハードウェアソリューションのための深い神経ネットの知識の蒸留の探求」、CS230 Report、2018
ヒントン、ジェフリー、オリオールヴィンヴァルズ、ジェフディーン。 「ニューラルネットワークで知識を蒸留します。」 Arxiv Preprint arxiv:1503.02531(2015)。
Romero、A.、Ballas、N.、Kahou、SE、Chassang、A.、Gatta、C。、&Bengio、Y。(2014)。 Fitlets:薄いディープネットのヒント。 arxiv preprint arxiv:1412.6550。
https://github.com/cs230-stanford/cs230-stanford.github.io
https://github.com/bearpaw/pytorch-classification