

合奏Pytorch
Pytorch輕鬆提高深度學習模型的性能和魯棒性的統一合奏框架。 Ensemble-Pytorch是Pytorch生態系統的一部分,該生態系統要求該項目必須得到很好的維護。
安裝
pip install torchensemble
例子
from torchensemble import VotingClassifier # voting is a classic ensemble strategy
# Load data
train_loader = DataLoader (...)
test_loader = DataLoader (...)
# Define the ensemble
ensemble = VotingClassifier (
estimator = base_estimator , # estimator is your pytorch model
n_estimators = 10 , # number of base estimators
)
# Set the optimizer
ensemble . set_optimizer (
"Adam" , # type of parameter optimizer
lr = learning_rate , # learning rate of parameter optimizer
weight_decay = weight_decay , # weight decay of parameter optimizer
)
# Set the learning rate scheduler
ensemble . set_scheduler (
"CosineAnnealingLR" , # type of learning rate scheduler
T_max = epochs , # additional arguments on the scheduler
)
# Train the ensemble
ensemble . fit (
train_loader ,
epochs = epochs , # number of training epochs
)
# Evaluate the ensemble
acc = ensemble . evaluate ( test_loader ) # testing accuracy支持的合奏
| 合奏名稱 | 類型 | 原始碼 | 問題 |
|---|
| 融合 | 混合 | fusion.py | 分類 /回歸 |
| 投票[1] | 平行線 | 投票 | 分類 /回歸 |
| 神經森林 | 平行線 | 投票 | 分類 /回歸 |
| 袋[2] | 平行線 | 袋裝 | 分類 /回歸 |
| 梯度提升[3] | 順序 | gradient_boosting.py | 分類 /回歸 |
| 快照合奏[4] | 順序 | snapshot_ensemble.py | 分類 /回歸 |
| 對抗訓練[5] | 平行線 | versarial_training.py | 分類 /回歸 |
| 快速幾何合奏[6] | 順序 | fast_ geometric.py | 分類 /回歸 |
| 軟梯度提升[7] | 平行線 | soft_gradient_boosting.py | 分類 /回歸 |
依賴性
- Scikit-Learn> = 0.23.0
- 火炬> = 1.4.0
- 火炬> = 0.2.2
參考
| [1] | 週,Zhi-hua。集合方法:基礎和算法。 CRC出版社,2012年。 |
| [2] | 布雷曼,獅子座。裝袋預測變量。機器學習(1996):123-140。 |
| [3] | 弗里德曼(Jerome H.)貪婪的功能近似:梯度提昇機。統計年鑑(2001):1189-1232。 |
| [4] | Huang,Gao等。快照合奏:火車1,免費獲得M。 ICLR,2017年。 |
| [5] | Lakshminarayanan,Balaji等。簡單可擴展的預測性不確定性使用深層合奏。 NIPS,2017年。 |
| [6] | Garipov,Timur等。 DNNS的損耗表面,模式連接性和快速結合。神經,2018年。 |
| [7] | 馮,吉等。軟梯度提昇機。 Arxiv,2020年。 |
感謝我們所有的貢獻者