LibMTL é uma biblioteca de código aberto construído no Pytorch para aprendizado de várias tarefas (MTL). Consulte a documentação mais recente para obter introduções detalhadas e instruções da API.
Estrear -nos no Github - isso nos motiva muito!
LibMTL fornece uma base de código unificada para implementar e um procedimento de avaliação consistente, incluindo processamento de dados, objetivos métricos e hiper-parâmetros em vários conjuntos de dados de referência MTL representativos, que permitem comparações quantitativas, justas e consistentes entre diferentes algoritmos MTL.LibMTL suporta muitos métodos MTL de última geração, incluindo 8 arquiteturas e 16 estratégias de otimização. Enquanto isso, LibMTL fornece uma comparação justa de vários conjuntos de dados de referência que cobrem diferentes campos.LibMTL segue os princípios de design modular, que permitem que os usuários adicionem componentes personalizados de maneira flexível e conveniente ou façam modificações personalizadas. Portanto, os usuários podem desenvolver facilmente e rapidamente novas estratégias e arquiteturas de otimização ou aplicar os algoritmos MTL existentes a novos cenários de aplicativos com o suporte do LibMTL . 
Cada módulo é introduzido nos documentos.
Atualmente, LibMTL suporta os seguintes algoritmos:
| Estratégias de otimização | Locais | Argumentos |
|---|---|---|
| Ponderação igual (ew) | - | --weighting EW |
| Normalização do gradiente (gradnorm) | ICML 2018 | --weighting GradNorm |
| Pesos de incerteza (UW) | CVPR 2018 | --weighting UW |
| MGDA (código oficial) | Neurips 2018 | --weighting MGDA |
| Média de peso dinâmico (DWA) (código oficial) | CVPR 2019 | --weighting DWA |
| Estratégia de perda geométrica (GLS) | Workshop CVPR 2019 | --weighting GLS |
| Projetando gradiente conflitante (PCGrad) | Neurips 2020 | --weighting PCGrad |
| GRADIIO DE SIGNOUT (GRADDROP) | Neurips 2020 | --weighting GradDrop |
| Learning imparcial de várias tarefas (IMTL) | ICLR 2021 | --weighting IMTL |
| Vacina de gradiente (gradvac) | ICLR 2021 | --weighting GradVac |
| Descendência de gradiente avesso a conflito (CAGRAD) (código oficial) | Neurips 2021 | --weighting CAGrad |
| Nash-MTL (código oficial) | ICML 2022 | --weighting Nash_MTL |
| Ponderação de perda aleatória (RLW) | TMLR 2022 | --weighting RLW |
| MOCO | ICLR 2023 | --weighting MoCo |
| Alinhado-mtl (código oficial) | CVPR 2023 | --weighting Aligned_MTL |
| Stch (código oficial) | ICML 2024 | --weighting STCH |
| Excessmtl (código oficial) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (código oficial) | ICML 2024 | --weighting FairGrad |
| Db-mtl | arxiv | --weighting DB_MTL |
| Arquiteturas | Locais | Argumentos |
|---|---|---|
| Compartilhamento de parâmetros difíceis (HPS) | ICML 1993 | --arch HPS |
| Redes de ponto cruz (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Mistura de Expperts de Multi-Gate (MMOE) | KDD 2018 | --arch MMoE |
| Rede de Atenção de Múltiplas Tarefas (MTAN) (Código Oficial) | CVPR 2019 | --arch MTAN |
| Controle de portão personalizado (CGC), extração progressiva em camadas (PLE) | ACM RECSYS 2020 | --arch CGC , --arch PLE |
| Aprendendo a ramificar (LTB) | ICML 2020 | --arch LTB |
| DSelect-k (código oficial) | Neurips 2021 | --arch DSelect_k |
| Conjuntos de dados | Problemas | Número da tarefa | Tarefas | Multi-entrada | Backbone suportado |
|---|---|---|---|---|---|
| NYUV2 | Entendimento da cena | 3 | Segmentação semântica+ Estimativa de profundidade+ Previsão normal de superfície | ✘ | Resnet50/ Segnet |
| Paisagens da cidade | Entendimento da cena | 2 | Segmentação semântica+ Estimativa de profundidade | ✘ | Resnet50 |
| Office-31 | Reconhecimento de imagem | 3 | Classificação | ✓ | Resnet18 |
| Escritório em casa | Reconhecimento de imagem | 4 | Classificação | ✓ | Resnet18 |
| QM9 | Previsão de propriedades moleculares | 11 (padrão) | Regressão | ✘ | Gnn |
| PAWS-X | Identificação parafraseada | 4 (padrão) | Classificação | ✓ | Bert |
Crie um ambiente virtual
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlClone o repositório
git clone https://github.com/median-research-group/LibMTL.git Instale o LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Usamos o conjunto de dados NYUV2 como exemplo para mostrar como usar LibMTL .
O conjunto de dados NYUV2 que usamos é pré-processado pelo MTAN. Você pode baixar este conjunto de dados aqui.
O código de treinamento completo para o conjunto de dados NYUV2 é fornecido em exemplos/NYU. O arquivo main.py é o arquivo principal para treinamento no conjunto de dados NYUV2.
Você pode encontrar os argumentos da linha de comando executando o seguinte comando.
python main.py -hPor exemplo, a execução do comando a seguir treinará um modelo MTL com EW e HPS no conjunto de dados NYUV2.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHMais detalhes são representados nos documentos.
Se você achar LibMTL útil para sua pesquisa ou desenvolvimento, cite o seguinte:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL é desenvolvido e mantido por Baijiong Lin.
Se você tiver alguma dúvida ou sugestão, não hesite em entrar em contato conosco levantando um problema ou enviando um email para [email protected] .
Gostaríamos de agradecer aos autores que lançam os repositórios públicos (listados em ordem alfabética): CAGRAD, DSELECT_K_MOE, MultioBjectiveOtimization, MTAN, MTL, Nash-MTL, Pytorch_geométrico e Xtreme.
LibMTL é liberado sob a licença do MIT.