knowledge distillation pytorch
1.0.0
克隆倉庫
git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
安裝依賴項(包括pytorch)
pip install -r requirements.txt
注意:所有超參數都可以在“ params.json”下找到並修改
- 訓練5層CNN,其知識從預先訓練的RESNET-18模型中蒸餾出來
python train.py --model_dir experiments/cnn_distill
- 培訓RESNET-18模型,其知識從預先訓練的Resnext-29老師中蒸餾出來
python train.py --model_dir experiments/resnet18_distill/resnext_teacher
- 超參數搜索指定的實驗('parent_dir/paramses.json')
python search_hyperparams.py --parent_dir experiments/cnn_distill_alpha_temp
- 最近的HyperSearch實驗的結果
python synthesize_results.py --parent_dir experiments/cnn_distill_alpha_temp
快速外賣(要添加更多詳細信息):
-從RESNET-18到5層CNN的知識蒸餾
| 模型 | 輟學= 0.5 | 沒有輟學 |
|---|---|---|
| 5層CNN | 83.51% | 84.74% |
| 5層CNN W/ resnet18 | 84.49% | 85.69% |
-從更深的模型到RESNET-18的知識蒸餾
| 模型 | 測試準確性 |
|---|---|
| 基線RESNET-18 | 94.175% |
| + KD WIDERESNET-28-10 | 94.333% |
| + KD Preesnet-11 | 94.531% |
| + KD Densenet-100 | 94.729% |
| + KD Resnext-29-8 | 94.788% |
H. Li,“探索深神網的知識蒸餾以進行高效硬件解決方案”,CS230報告,2018年
Hinton,Geoffrey,Oriol Vinyals和Jeff Dean。 “在神經網絡中提取知識。” ARXIV預印型ARXIV:1503.02531(2015)。
Romero,A.,Ballas,N.,Kahou,SE,Chassang,A. fitnets:薄網的提示。 ARXIV預印型ARXIV:1412.6550。
https://github.com/cs230-stanford/cs230-stanford.github.io
https://github.com/bearpaw/pytorch-classification