LibMTL
1.0.0
LibMTL是建立在用于多任务学习(MTL)的Pytorch上的开源库。有关详细的介绍和API说明,请参见最新文档。
在Github上明星我们 - 它极大地激励了我们!
LibMTL提供了一个统一的代码库,并在几个代表性MTL基准数据集中进行了一致的评估程序,包括数据处理,指标目标和超参数,该数据集允许定量,公平和一致的MTL算法之间的定量,公平和一致的比较。LibMTL支持许多最新的MTL方法,包括8种架构和16种优化策略。同时, LibMTL对涵盖不同字段的几个基准数据集进行了公平的比较。LibMTL遵循模块化设计原理,该原理允许用户灵活,方便地添加自定义的组件或进行个性化修改。因此,用户可以轻松,快速开发新颖的优化策略和体系结构,或在LibMTL的支持下将现有的MTL算法应用于新的应用程序方案。 
每个模块都在文档中引入。
LibMTL当前支持以下算法:
| 优化策略 | 场地 | 争论 |
|---|---|---|
| 相等的加权(EW) | - | --weighting EW |
| 梯度归一化(Gradnorm) | ICML 2018 | --weighting GradNorm |
| 不确定性重量(UW) | CVPR 2018 | --weighting UW |
| MGDA(官方代码) | Neurips 2018 | --weighting MGDA |
| 动态重量平均值(DWA)(官方代码) | CVPR 2019 | --weighting DWA |
| 几何损失策略(GLS) | CVPR 2019研讨会 | --weighting GLS |
| 投射冲突梯度(PCGRAD) | 神经2020 | --weighting PCGrad |
| 梯度标志辍学(GradDrop) | 神经2020 | --weighting GradDrop |
| 公正的多任务学习(IMTL) | ICLR 2021 | --weighting IMTL |
| 梯度疫苗(GRADVAC) | ICLR 2021 | --weighting GradVac |
| 规避冲突的梯度下降(CAGRAD)(官方代码) | 神经2021 | --weighting CAGrad |
| NASH-MTL(官方代码) | ICML 2022 | --weighting Nash_MTL |
| 随机减肥(RLW) | TMLR 2022 | --weighting RLW |
| moco | ICLR 2023 | --weighting MoCo |
| 对齐MTL(官方代码) | CVPR 2023 | --weighting Aligned_MTL |
| STCH(官方代码) | ICML 2024 | --weighting STCH |
| 多余的MTL(官方代码) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad(官方代码) | ICML 2024 | --weighting FairGrad |
| DB-MTL | arxiv | --weighting DB_MTL |
| 体系结构 | 场地 | 争论 |
|---|---|---|
| 硬参数共享(HPS) | ICML 1993 | --arch HPS |
| 交叉缝线网络(cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| 多门混合物(MMOE) | KDD 2018 | --arch MMoE |
| 多任务注意网络(MTAN)(官方代码) | CVPR 2019 | --arch MTAN |
| 定制的门控制(CGC),进行性分层提取(PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| 学习分支(LTB) | ICML 2020 | --arch LTB |
| DSELECT-K(官方代码) | 神经2021 | --arch DSelect_k |
| 数据集 | 问题 | 任务号 | 任务 | 多输入 | 支撑的骨干 |
|---|---|---|---|---|---|
| NYUV2 | 场景的理解 | 3 | 语义分割+ 深度估计+ 表面正常预测 | ✘ | RESNET50/ segnet |
| 城市景观 | 场景的理解 | 2 | 语义分割+ 深度估计 | ✘ | RESNET50 |
| Office-31 | 图像识别 | 3 | 分类 | ✓ | RESNET18 |
| 办公室 | 图像识别 | 4 | 分类 | ✓ | RESNET18 |
| QM9 | 分子性质预测 | 11(默认) | 回归 | ✘ | Gnn |
| paws-x | 解释识别 | 4(默认) | 分类 | ✓ | 伯特 |
创建虚拟环境
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html克隆存储库
git clone https://github.com/median-research-group/LibMTL.git安装LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . 我们使用NYUV2数据集作为示例来展示如何使用LibMTL 。
我们使用的NYUV2数据集由MTAN预先处理。您可以在此处下载此数据集。
NYUV2数据集的完整培训代码在示例/NYU中提供。文件main.py是在NYUV2数据集上培训的主要文件。
您可以通过运行以下命令找到命令行参数。
python main.py -h例如,运行以下命令将在NYUV2数据集上使用EW和HPS训练MTL模型。
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATH文档中有更多详细信息表示。
如果您发现LibMTL对您的研究或开发有用,请引用以下内容:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
}LibMTL由Baijiong Lin开发和维护。
如果您有任何疑问或建议,请随时通过提出问题或发送电子邮件至[email protected]与我们联系。
我们要感谢发布公共存储库的作者(按字母顺序列出):Cagrad,dselect_k_moe,MultiObjectiveOptimization,mtan,mtl,mtl,nash-mtl,pytorch_geometric和xtreme。
LibMTL根据MIT许可发布。