knowledge distillation pytorch
1.0.0
克隆仓库
git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
安装依赖项(包括pytorch)
pip install -r requirements.txt
注意:所有超参数都可以在“ params.json”下找到并修改
- 训练5层CNN,其知识从预先训练的RESNET-18模型中蒸馏出来
python train.py --model_dir experiments/cnn_distill
- 培训RESNET-18模型,其知识从预先训练的Resnext-29老师中蒸馏出来
python train.py --model_dir experiments/resnet18_distill/resnext_teacher
- 超参数搜索指定的实验('parent_dir/paramses.json')
python search_hyperparams.py --parent_dir experiments/cnn_distill_alpha_temp
- 最近的HyperSearch实验的结果
python synthesize_results.py --parent_dir experiments/cnn_distill_alpha_temp
快速外卖(要添加更多详细信息):
-从RESNET-18到5层CNN的知识蒸馏
| 模型 | 辍学= 0.5 | 没有辍学 |
|---|---|---|
| 5层CNN | 83.51% | 84.74% |
| 5层CNN W/ resnet18 | 84.49% | 85.69% |
-从更深的模型到RESNET-18的知识蒸馏
| 模型 | 测试准确性 |
|---|---|
| 基线RESNET-18 | 94.175% |
| + KD WIDERESNET-28-10 | 94.333% |
| + KD Preesnet-11 | 94.531% |
| + KD Densenet-100 | 94.729% |
| + KD Resnext-29-8 | 94.788% |
H. Li,“探索深神网的知识蒸馏以进行高效硬件解决方案”,CS230报告,2018年
Hinton,Geoffrey,Oriol Vinyals和Jeff Dean。 “在神经网络中提取知识。” ARXIV预印型ARXIV:1503.02531(2015)。
Romero,A.,Ballas,N.,Kahou,SE,Chassang,A. fitnets:薄网的提示。 ARXIV预印型ARXIV:1412.6550。
https://github.com/cs230-stanford/cs230-stanford.github.io
https://github.com/bearpaw/pytorch-classification