LibMTL
1.0.0
LibMTL是建立在用於多任務學習(MTL)的Pytorch上的開源庫。有關詳細的介紹和API說明,請參見最新文檔。
在Github上明星我們 - 它極大地激勵了我們!
LibMTL提供了一個統一的代碼庫,並在幾個代表性MTL基準數據集中進行了一致的評估程序,包括數據處理,指標目標和超參數,該數據集允許定量,公平和一致的MTL算法之間的定量,公平和一致的比較。LibMTL支持許多最新的MTL方法,包括8種架構和16種優化策略。同時, LibMTL對涵蓋不同字段的幾個基準數據集進行了公平的比較。LibMTL遵循模塊化設計原理,該原理允許用戶靈活,方便地添加自定義的組件或進行個性化修改。因此,用戶可以輕鬆,快速開發新穎的優化策略和體系結構,或在LibMTL的支持下將現有的MTL算法應用於新的應用程序方案。 
每個模塊都在文檔中引入。
LibMTL當前支持以下算法:
| 優化策略 | 場地 | 爭論 |
|---|---|---|
| 相等的加權(EW) | - | --weighting EW |
| 梯度歸一化(Gradnorm) | ICML 2018 | --weighting GradNorm |
| 不確定性重量(UW) | CVPR 2018 | --weighting UW |
| MGDA(官方代碼) | Neurips 2018 | --weighting MGDA |
| 動態重量平均值(DWA)(官方代碼) | CVPR 2019 | --weighting DWA |
| 幾何損失策略(GLS) | CVPR 2019研討會 | --weighting GLS |
| 投射衝突梯度(PCGRAD) | 神經2020 | --weighting PCGrad |
| 梯度標誌輟學(GradDrop) | 神經2020 | --weighting GradDrop |
| 公正的多任務學習(IMTL) | ICLR 2021 | --weighting IMTL |
| 梯度疫苗(GRADVAC) | ICLR 2021 | --weighting GradVac |
| 規避衝突的梯度下降(CAGRAD)(官方代碼) | 神經2021 | --weighting CAGrad |
| NASH-MTL(官方代碼) | ICML 2022 | --weighting Nash_MTL |
| 隨機減肥(RLW) | TMLR 2022 | --weighting RLW |
| moco | ICLR 2023 | --weighting MoCo |
| 對齊MTL(官方代碼) | CVPR 2023 | --weighting Aligned_MTL |
| STCH(官方代碼) | ICML 2024 | --weighting STCH |
| 多餘的MTL(官方代碼) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad(官方代碼) | ICML 2024 | --weighting FairGrad |
| DB-MTL | arxiv | --weighting DB_MTL |
| 體系結構 | 場地 | 爭論 |
|---|---|---|
| 硬參數共享(HPS) | ICML 1993 | --arch HPS |
| 交叉縫線網絡(cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| 多門混合物(MMOE) | KDD 2018 | --arch MMoE |
| 多任務注意網絡(MTAN)(官方代碼) | CVPR 2019 | --arch MTAN |
| 定制的門控制(CGC),進行性分層提取(PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| 學習分支(LTB) | ICML 2020 | --arch LTB |
| DSELECT-K(官方代碼) | 神經2021 | --arch DSelect_k |
| 數據集 | 問題 | 任務號 | 任務 | 多輸入 | 支撐的骨幹 |
|---|---|---|---|---|---|
| NYUV2 | 場景的理解 | 3 | 語義分割+ 深度估計+ 表面正常預測 | ✘ | RESNET50/ segnet |
| 城市景觀 | 場景的理解 | 2 | 語義分割+ 深度估計 | ✘ | RESNET50 |
| Office-31 | 圖像識別 | 3 | 分類 | ✓ | RESNET18 |
| 辦公室 | 圖像識別 | 4 | 分類 | ✓ | RESNET18 |
| QM9 | 分子性質預測 | 11(默認) | 回歸 | ✘ | Gnn |
| paws-x | 解釋識別 | 4(默認) | 分類 | ✓ | 伯特 |
創建虛擬環境
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html克隆存儲庫
git clone https://github.com/median-research-group/LibMTL.git安裝LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . 我們使用NYUV2數據集作為示例來展示如何使用LibMTL 。
我們使用的NYUV2數據集由MTAN預先處理。您可以在此處下載此數據集。
NYUV2數據集的完整培訓代碼在示例/NYU中提供。文件main.py是在NYUV2數據集上培訓的主要文件。
您可以通過運行以下命令找到命令行參數。
python main.py -h例如,運行以下命令將在NYUV2數據集上使用EW和HPS訓練MTL模型。
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATH文檔中有更多詳細信息表示。
如果您發現LibMTL對您的研究或開發有用,請引用以下內容:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
}LibMTL由Baijiong Lin開發和維護。
如果您有任何疑問或建議,請隨時通過提出問題或發送電子郵件至[email protected]與我們聯繫。
我們要感謝發佈公共存儲庫的作者(按字母順序列出):Cagrad,dselect_k_moe,MultiObjectiveOptimization,mtan,mtl,mtl,nash-mtl,pytorch_geometric和xtreme。
LibMTL根據MIT許可發布。