

合奏Pytorch
Pytorch轻松提高深度学习模型的性能和鲁棒性的统一合奏框架。 Ensemble-Pytorch是Pytorch生态系统的一部分,该生态系统要求该项目必须得到很好的维护。
安装
pip install torchensemble
例子
from torchensemble import VotingClassifier # voting is a classic ensemble strategy
# Load data
train_loader = DataLoader (...)
test_loader = DataLoader (...)
# Define the ensemble
ensemble = VotingClassifier (
estimator = base_estimator , # estimator is your pytorch model
n_estimators = 10 , # number of base estimators
)
# Set the optimizer
ensemble . set_optimizer (
"Adam" , # type of parameter optimizer
lr = learning_rate , # learning rate of parameter optimizer
weight_decay = weight_decay , # weight decay of parameter optimizer
)
# Set the learning rate scheduler
ensemble . set_scheduler (
"CosineAnnealingLR" , # type of learning rate scheduler
T_max = epochs , # additional arguments on the scheduler
)
# Train the ensemble
ensemble . fit (
train_loader ,
epochs = epochs , # number of training epochs
)
# Evaluate the ensemble
acc = ensemble . evaluate ( test_loader ) # testing accuracy支持的合奏
| 合奏名称 | 类型 | 源代码 | 问题 |
|---|
| 融合 | 混合 | fusion.py | 分类 /回归 |
| 投票[1] | 平行线 | 投票 | 分类 /回归 |
| 神经森林 | 平行线 | 投票 | 分类 /回归 |
| 袋[2] | 平行线 | 袋装 | 分类 /回归 |
| 梯度提升[3] | 顺序 | gradient_boosting.py | 分类 /回归 |
| 快照合奏[4] | 顺序 | snapshot_ensemble.py | 分类 /回归 |
| 对抗训练[5] | 平行线 | versarial_training.py | 分类 /回归 |
| 快速几何合奏[6] | 顺序 | fast_ geometric.py | 分类 /回归 |
| 软梯度提升[7] | 平行线 | soft_gradient_boosting.py | 分类 /回归 |
依赖性
- Scikit-Learn> = 0.23.0
- 火炬> = 1.4.0
- 火炬> = 0.2.2
参考
| [1] | 周,Zhi-hua。集合方法:基础和算法。 CRC出版社,2012年。 |
| [2] | 布雷曼,狮子座。装袋预测变量。机器学习(1996):123-140。 |
| [3] | 弗里德曼(Jerome H.)贪婪的功能近似:梯度提升机。统计年鉴(2001):1189-1232。 |
| [4] | Huang,Gao等。快照合奏:火车1,免费获得M。 ICLR,2017年。 |
| [5] | Lakshminarayanan,Balaji等。简单可扩展的预测性不确定性使用深层合奏。 NIPS,2017年。 |
| [6] | Garipov,Timur等。 DNNS的损耗表面,模式连接性和快速结合。神经,2018年。 |
| [7] | 冯,吉等。软梯度提升机。 Arxiv,2020年。 |
感谢我们所有的贡献者