LibMTL ist eine Open-Source-Bibliothek, die auf Pytorch für Multi-Task Learning (MTL) basiert. In der neuesten Dokumentation finden Sie detaillierte Einführungen und API -Anweisungen.
Sterne uns auf Github - es motiviert uns sehr!
LibMTL bietet eine einheitliche Code-Basis für die Implementierung und ein konsistentes Bewertungsverfahren, einschließlich Datenverarbeitung, metrischen Ziele und Hyperparametern bei mehreren repräsentativen MTL-Benchmark-Datensätzen, die quantitative, faire und konsistente Vergleiche zwischen verschiedenen MTL-Algorithmen ermöglichen.LibMTL unterstützt viele hochmoderne MTL-Methoden, einschließlich 8 Architekturen und 16 Optimierungsstrategien. In der Zwischenzeit bietet LibMTL einen fairen Vergleich mehrerer Benchmark -Datensätze, die verschiedene Felder abdecken.LibMTL folgt den modularen Designprinzipien, mit denen Benutzer kundenspezifische Komponenten flexibel und bequem hinzufügen oder personalisierte Änderungen vornehmen können. Daher können Benutzer mit Unterstützung von LibMTL einfach neue Optimierungsstrategien und Architekturen entwickeln oder die vorhandenen MTL -Algorithmen auf neue Anwendungsszenarien anwenden. 
Jedes Modul wird in DOCS eingeführt.
LibMTL unterstützt derzeit die folgenden Algorithmen:
| Optimierungsstrategien | Veranstaltungsorte | Argumente |
|---|---|---|
| Gleiche Gewichtung (EW) | - - | --weighting EW |
| Gradientennormalisierung (GradNorm) | ICML 2018 | --weighting GradNorm |
| Unsicherheitsgewichte (UW) | CVPR 2018 | --weighting UW |
| MGDA (offizieller Code) | Neurips 2018 | --weighting MGDA |
| Dynamic Gewicht Durchschnitt (DWA) (offizieller Code) | CVPR 2019 | --weighting DWA |
| Geometrische Verluststrategie (GLS) | CVPR 2019 Workshop | --weighting GLS |
| Projizierter widersprüchlicher Gradienten (PCgrad) | Neurips 2020 | --weighting PCGrad |
| Gradientenzeichen Dropout (Graddrop) | Neurips 2020 | --weighting GradDrop |
| Unparteiisches Multitasking-Lernen (IMTL) | ICLR 2021 | --weighting IMTL |
| Gradientenimpfstoff (Gradvac) | ICLR 2021 | --weighting GradVac |
| Konfliktaverse-Gradientenabstieg (CAGRAD) (Offizieller Kodex) | Neurips 2021 | --weighting CAGrad |
| Nash-Mtl (offizieller Code) | ICML 2022 | --weighting Nash_MTL |
| Zufällige Verlustgewichtung (RLW) | TMLR 2022 | --weighting RLW |
| Moco | ICLR 2023 | --weighting MoCo |
| Ausgerichtetem MTL (offizieller Code) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (offizieller Code) | ICML 2024 | --weighting STCH |
| ExessMTL (offizieller Code) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (offizieller Code) | ICML 2024 | --weighting FairGrad |
| Db-mtl | Arxiv | --weighting DB_MTL |
| Architekturen | Veranstaltungsorte | Argumente |
|---|---|---|
| Hard Parameter Sharing (HPS) | ICML 1993 | --arch HPS |
| Cross-Stitch-Netzwerke (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Multi-Gate-Mischung aus Experten (Mmoe) | KDD 2018 | --arch MMoE |
| Multi-Task-Aufmerksamkeitsnetzwerk (MTAN) (offizieller Code) | CVPR 2019 | --arch MTAN |
| Customized Gate Control (CGC), Progressive Layered Extraction (PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| Lernen zum Zweig (LTB) | ICML 2020 | --arch LTB |
| Dselect-k (offizieller Code) | Neurips 2021 | --arch DSelect_k |
| Datensätze | Probleme | Aufgabenummer | Aufgaben | Mehreingang | Unterstütztes Rückgrat |
|---|---|---|---|---|---|
| NYUV2 | Szenenverständnis | 3 | Semantische Segmentierung+ Tiefenschätzung+ Oberflächennormale Vorhersage | ✘ | Resnet50/ Segnet |
| Stadtlandschaften | Szenenverständnis | 2 | Semantische Segmentierung+ Tiefenschätzung | ✘ | Resnet50 |
| Office-31 | Bilderkennung | 3 | Einstufung | ✓ | Resnet18 |
| Bürohaus | Bilderkennung | 4 | Einstufung | ✓ | Resnet18 |
| QM9 | Vorhersage der molekularen Eigenschaft | 11 (Standard) | Regression | ✘ | Gnn |
| Pfoten-x | Paraphrase Identifikation | 4 (Standard) | Einstufung | ✓ | Bert |
Erstellen Sie eine virtuelle Umgebung
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlKlonen Sie das Repository
git clone https://github.com/median-research-group/LibMTL.git Installieren Sie LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Wir verwenden den NYUV2 -Datensatz als Beispiel, um zu zeigen, wie LibMTL verwendet wird.
Der von uns verwendete NYUV2-Datensatz wird von MTAN vorverarbeitet. Sie können diesen Datensatz hier herunterladen.
Der vollständige Trainingscode für den NYUV2 -Datensatz ist in Beispielen/NYU bereitgestellt. Die Datei main.py ist die Hauptdatei für das Training im NYUV2 -Datensatz.
Sie können die Befehlszeilenargumente finden, indem Sie den folgenden Befehl ausführen.
python main.py -hWenn Sie beispielsweise den folgenden Befehl ausführen, trainiert ein MTL -Modell mit EW und HPS im NYUV2 -Datensatz.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHWeitere Details sind in Dokumenten dargestellt.
Wenn Sie LibMTL für Ihre Forschung oder Entwicklung nützlich finden, geben Sie Folgendes an:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL wird von baijiong lin entwickelt und aufrechterhalten.
Wenn Sie Fragen oder Vorschläge haben, kontaktieren Sie uns bitte, indem Sie ein Problem ansprechen oder eine E -Mail an [email protected] senden.
Wir möchten den Autoren danken, die die öffentlichen Repositorys veröffentlichen (alphabetisch gelistet): Cagrad, dselect_k_moe, MultiObjectiveOptimization, MTAN, MTL, Nash-MTL, Pytorch_Geometric und Xtreme.
LibMTL wird unter der MIT -Lizenz veröffentlicht.