

Ensemble Pytorch
Uma estrutura unificada para Pytorch melhorar facilmente o desempenho e a robustez do seu modelo de aprendizado profundo. O Ensemble-Pytorch faz parte do ecossistema Pytorch, que exige que o projeto seja bem mantido.
Instalação
pip install torchensemble
Exemplo
from torchensemble import VotingClassifier # voting is a classic ensemble strategy
# Load data
train_loader = DataLoader (...)
test_loader = DataLoader (...)
# Define the ensemble
ensemble = VotingClassifier (
estimator = base_estimator , # estimator is your pytorch model
n_estimators = 10 , # number of base estimators
)
# Set the optimizer
ensemble . set_optimizer (
"Adam" , # type of parameter optimizer
lr = learning_rate , # learning rate of parameter optimizer
weight_decay = weight_decay , # weight decay of parameter optimizer
)
# Set the learning rate scheduler
ensemble . set_scheduler (
"CosineAnnealingLR" , # type of learning rate scheduler
T_max = epochs , # additional arguments on the scheduler
)
# Train the ensemble
ensemble . fit (
train_loader ,
epochs = epochs , # number of training epochs
)
# Evaluate the ensemble
acc = ensemble . evaluate ( test_loader ) # testing accuracy Ensemble suportado
| Nome do conjunto | Tipo | Código -fonte | Problema |
|---|
| Fusão | Misturado | fusion.py | Classificação / Regressão |
| Votação [1] | Paralelo | votação.py | Classificação / Regressão |
| Floresta neural | Paralelo | votação.py | Classificação / Regressão |
| Equipamento [2] | Paralelo | Bagging.py | Classificação / Regressão |
| Aumentado de gradiente [3] | Sequencial | gradiente_boosting.py | Classificação / Regressão |
| Ensemble Snapshot [4] | Sequencial | snapshot_ensemble.py | Classificação / Regressão |
| Treinamento adversário [5] | Paralelo | adversas_training.py | Classificação / Regressão |
| Ensemble geométrico rápido [6] | Sequencial | fast_geometric.py | Classificação / Regressão |
| Aumentado de gradiente suave [7] | Paralelo | soft_gradient_boosting.py | Classificação / Regressão |
Dependências
- Scikit-Learn> = 0,23.0
- tocha> = 1.4.0
- Torchvision> = 0.2.2
Referência
| [1] | Zhou, Zhi-hua. Métodos de conjunto: fundações e algoritmos. CRC Press, 2012. |
| [2] | Breiman, Leo. Preditores de ensacamento. Machine Learning (1996): 123-140. |
| [3] | Friedman, Jerome H. Aproximação da função gananciosa: uma máquina de reforço de gradiente. Annals of Statistics (2001): 1189-1232. |
| [4] | Huang, Gao, et al. Conjuntos instantâneos: Trem 1, obtenha M de graça. ICLR, 2017. |
| [5] | Lakshminarayanan, Balaji, et al. Estimativa de incerteza preditiva simples e escalável usando conjuntos profundos. NIPS, 2017. |
| [6] | Garipov, Timur, et al. Superfícies de perda, conectividade de modo e conjunto rápido de DNNs. Neurips, 2018. |
| [7] | Feng, Ji, et al. Máquina de reforço de gradiente suave. Arxiv, 2020. |
Obrigado a todos os nossos colaboradores