LibMTL est une bibliothèque open source construite sur Pytorch pour l'apprentissage multi-tâches (MTL). Voir la dernière documentation pour les introductions détaillées et les instructions de l'API.
Star nous sur Github - cela nous motive beaucoup!
LibMTL fournit une base de code unifiée à implémenter et une procédure d'évaluation cohérente comprenant le traitement des données, les objectifs métriques et les hyper-paramètres sur plusieurs ensembles de données de référence MTL représentatifs, ce qui permet des comparaisons quantitatives, équitables et cohérentes entre les différents algorithmes MTL.LibMTL prend en charge de nombreuses méthodes MTL de pointe, y compris 8 architectures et 16 stratégies d'optimisation. Pendant ce temps, LibMTL fournit une comparaison équitable de plusieurs ensembles de données de référence couvrant différents champs.LibMTL suit les principes de conception modulaires, qui permet aux utilisateurs d'ajouter des composants personnalisés avec flexion et commodément ou apporter des modifications personnalisées. Par conséquent, les utilisateurs peuvent facilement et rapidement développer de nouvelles stratégies et architectures d'optimisation ou appliquer les algorithmes MTL existants à de nouveaux scénarios d'application avec le support de LibMTL . 
Chaque module est introduit dans DOCS.
LibMTL prend actuellement en charge les algorithmes suivants:
| Stratégies d'optimisation | Salles | Arguments |
|---|---|---|
| Pondération égale (EW) | - | --weighting EW |
| Normalisation du gradient (diplômé) | ICML 2018 | --weighting GradNorm |
| Poids d'incertitude (UW) | CVPR 2018 | --weighting UW |
| MGDA (code officiel) | Neirips 2018 | --weighting MGDA |
| Moyenne de poids dynamique (DWA) (code officiel) | CVPR 2019 | --weighting DWA |
| Stratégie de perte géométrique (GLS) | Atelier CVPR 2019 | --weighting GLS |
| Projection de gradient contradictoire (pcgrad) | Neirips 2020 | --weighting PCGrad |
| Déprochage des panneaux de gradient (Graddrop) | Neirips 2020 | --weighting GradDrop |
| Apprentissage impartial multi-tâches (IMTL) | ICLR 2021 | --weighting IMTL |
| Vaccin à gradient (GradVAC) | ICLR 2021 | --weighting GradVac |
| Descendance à un gradient aux conflits (CAGRAD) (code officiel) | Neirips 2021 | --weighting CAGrad |
| Nash-MTL (code officiel) | ICML 2022 | --weighting Nash_MTL |
| Pondération de perte aléatoire (RLW) | TMLR 2022 | --weighting RLW |
| Moco | ICLR 2023 | --weighting MoCo |
| MTL aligné (code officiel) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (code officiel) | ICML 2024 | --weighting STCH |
| Excès (code officiel) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (code officiel) | ICML 2024 | --weighting FairGrad |
| Db-mtl | arxiv | --weighting DB_MTL |
| Architectures | Salles | Arguments |
|---|---|---|
| Partage des paramètres durs (HPS) | ICML 1993 | --arch HPS |
| Réseaux de crost-cratt (Cross_Stitch) | CVPR 2016 | --arch Cross_stitch |
| Mélange de gats multiples (MMOE) | KDD 2018 | --arch MMoE |
| Réseau d'attention multi-tâches (MTAN) (code officiel) | CVPR 2019 | --arch MTAN |
| Contrôle des portes personnalisé (CGC), extraction en couches progressive (PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| Apprendre à se ramifier (LTB) | ICML 2020 | --arch LTB |
| DSelect-K (code officiel) | Neirips 2021 | --arch DSelect_k |
| Ensembles de données | Problèmes | Numéro de tâche | Tâches | multi-entrées | Épine dorsale prise en charge |
|---|---|---|---|---|---|
| Nyuv2 | Compréhension de la scène | 3 | Segmentation sémantique + Estimation de la profondeur + Prédiction normale de surface | ✘ | Resnet50 / Borne |
| Paysages urbains | Compréhension de la scène | 2 | Segmentation sémantique + Estimation de la profondeur | ✘ | Resnet50 |
| Bureau-31 | Reconnaissance d'image | 3 | Classification | ✓ | Resnet18 |
| Maison de bureau | Reconnaissance d'image | 4 | Classification | ✓ | Resnet18 |
| QM9 | Prédiction des propriétés moléculaires | 11 (par défaut) | Régression | ✘ | Gnn |
| PAWS-X | Identification paraphrase | 4 (par défaut) | Classification | ✓ | Bert |
Créer un environnement virtuel
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlCloner le référentiel
git clone https://github.com/median-research-group/LibMTL.git Installer LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Nous utilisons l'ensemble de données NYUV2 comme exemple pour montrer comment utiliser LibMTL .
L'ensemble de données NYUV2 que nous avons utilisé est prétraité par MTAN. Vous pouvez télécharger cet ensemble de données ici.
Le code de formation complet de l'ensemble de données NYUV2 est fourni dans des exemples / NYU. Le fichier main.py est le fichier principal de formation sur l'ensemble de données NYUV2.
Vous pouvez trouver les arguments en ligne de commande en exécutant la commande suivante.
python main.py -hPar exemple, l'exécution de la commande suivante entraînera un modèle MTL avec EW et HPS sur l'ensemble de données NYUV2.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHPlus de détails sont représentés dans les documents.
Si vous trouvez LibMTL utile pour vos recherches ou développement, veuillez citer ce qui suit:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL est développé et maintenu par Baijiong Lin.
Si vous avez une question ou une suggestion, n'hésitez pas à nous contacter en soulevant un problème ou en envoyant un e-mail à [email protected] .
Nous tenons à remercier les auteurs qui publient les référentiels publics (répertoriés par ordre alphabétique): Cagrad, DSELECT_K_MOE, MultiobjectiveOptimization, MTAN, MTL, NASH-MTL, PYTORCH_GEOMETRIQUE et XTREME.
LibMTL est libéré sous la licence du MIT.