LibMTL es una biblioteca de código abierto construida en Pytorch para el aprendizaje de tareas múltiples (MTL). Consulte la última documentación para introducciones detalladas e instrucciones de API.
Estratas en Github, ¡nos motiva mucho!
LibMTL proporciona una base de código unificada para implementar y un procedimiento de evaluación consistente que incluye procesamiento de datos, objetivos métricos e hiperparametadores en varios conjuntos de datos de referencia MTL representativos, lo que permite comparaciones cuantitativas, justas y consistentes entre diferentes algoritmos MTL.LibMTL admite muchos métodos MTL de última generación, incluidas 8 arquitecturas y 16 estrategias de optimización. Mientras tanto, LibMTL proporciona una comparación justa de varios conjuntos de datos de referencia que cubren diferentes campos.LibMTL sigue los principios de diseño modular, que permite a los usuarios agregar componentes personalizados de manera flexible y conveniente o hacer modificaciones personalizadas. Por lo tanto, los usuarios pueden desarrollar fácil y rápidamente estrategias y arquitecturas de optimización novedosas o aplicar los algoritmos MTL existentes a nuevos escenarios de aplicación con el soporte de LibMTL . 
Cada módulo se introduce en Docs.
LibMTL actualmente admite los siguientes algoritmos:
| Estrategias de optimización | Lugares | Argumentos |
|---|---|---|
| Ponderación igual (EW) | - | --weighting EW |
| Normalización de gradiente (Gradnorm) | ICML 2018 | --weighting GradNorm |
| Pesos de incertidumbre (UW) | CVPR 2018 | --weighting UW |
| MGDA (código oficial) | Neurips 2018 | --weighting MGDA |
| Promedio de peso dinámico (DWA) (código oficial) | CVPR 2019 | --weighting DWA |
| Estrategia de pérdida geométrica (GLS) | Taller CVPR 2019 | --weighting GLS |
| Proyecto de gradiente conflictivo (PCGRAD) | Neurips 2020 | --weighting PCGrad |
| Descarga de signo de gradiente (Graddrop) | Neurips 2020 | --weighting GradDrop |
| Aprendizaje de tareas múltiples imparciales (IMTL) | ICLR 2021 | --weighting IMTL |
| Vacuna de gradiente (GradVAC) | ICLR 2021 | --weighting GradVac |
| Descendencia de gradiente de contragolpe de conflicto (CAGRAD) (Código oficial) | Neurips 2021 | --weighting CAGrad |
| Nash-mtl (código oficial) | ICML 2022 | --weighting Nash_MTL |
| Ponderación de pérdida aleatoria (RLW) | TMLR 2022 | --weighting RLW |
| Moco | ICLR 2023 | --weighting MoCo |
| Alineado-MTL (código oficial) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (código oficial) | ICML 2024 | --weighting STCH |
| Excessmtl (código oficial) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (código oficial) | ICML 2024 | --weighting FairGrad |
| Db-mtl | arxiv | --weighting DB_MTL |
| Arquitecturas | Lugares | Argumentos |
|---|---|---|
| Compartir parámetros duros (HPS) | ICML 1993 | --arch HPS |
| Networks de costura cruzada (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Mezcla de múltiples gases de gas (MMOE) | KDD 2018 | --arch MMoE |
| Red de atención de múltiples tareas (MTAN) (código oficial) | CVPR 2019 | --arch MTAN |
| Control de puerta personalizado (CGC), extracción progresiva en capas (PLE) | ACM RECSYS 2020 | --arch CGC , --arch PLE |
| Aprendiendo a ramificarse (LTB) | ICML 2020 | --arch LTB |
| DSelect-K (código oficial) | Neurips 2021 | --arch DSelect_k |
| Conjuntos de datos | Problemas | Número de tareas | Tareas | múltiple | Columna vertebral compatible |
|---|---|---|---|---|---|
| NYUV2 | Comprensión de la escena | 3 | Segmentación semántica+ Estimación de profundidad+ Predicción normal de la superficie | ✘ | Resnet50/ Segnet |
| Paisajes urbanos | Comprensión de la escena | 2 | Segmentación semántica+ Estimación de profundidad | ✘ | Resnet50 |
| Oficina 31 | Reconocimiento de imágenes | 3 | Clasificación | ✓ | Resnet18 |
| Domicilio | Reconocimiento de imágenes | 4 | Clasificación | ✓ | Resnet18 |
| QM9 | Predicción de propiedades moleculares | 11 (predeterminado) | Regresión | ✘ | GNN |
| Patas | Identificación de parafraseo | 4 (predeterminado) | Clasificación | ✓ | Bert |
Crear un entorno virtual
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlClonar el repositorio
git clone https://github.com/median-research-group/LibMTL.git Instalar LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Utilizamos el conjunto de datos NYUV2 como ejemplo para mostrar cómo usar LibMTL .
El conjunto de datos NYUV2 que utilizamos está preprocesado por MTAN. Puede descargar este conjunto de datos aquí.
El código de entrenamiento completo para el conjunto de datos NYUV2 se proporciona en ejemplos/NYU. El archivo main.py es el archivo principal para la capacitación en el conjunto de datos NYUV2.
Puede encontrar los argumentos de la línea de comandos ejecutando el siguiente comando.
python main.py -hPor ejemplo, ejecutar el siguiente comando entrenará un modelo MTL con EW y HPS en el conjunto de datos NYUV2.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHMás detalles se representan en Docs.
Si encuentra útil LibMTL para su investigación o desarrollo, cite lo siguiente:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL es desarrollado y mantenido por Baijiong Lin.
Si tiene alguna pregunta o sugerencia, no dude en contactarnos planteando un problema o enviando un correo electrónico a [email protected] .
Nos gustaría agradecer a los autores que liberan los repositorios públicos (enumerados alfabéticamente): CAGRAD, DSELECT_K_MOE, MultiobjectOptimization, MTAN, MTL, Nash-Mtl, Pytorch_Geometric y XTreme.
LibMTL se publica bajo la licencia MIT.