LibMTL 、マルチタスク学習(MTL)のためにPytorchに基づいて構築されたオープンソースライブラリです。詳細な紹介とAPIの指示については、最新のドキュメントを参照してください。
Githubで私たちを主演してください - それは私たちをたくさんやる気にさせます!
LibMTL 、実装する統一コードベースと、いくつかの代表的なMTLベンチマークデータセットのデータ処理、メトリック目標、およびハイパーパラメーターを含む一貫した評価手順を提供します。LibMTL 、8つのアーキテクチャと16の最適化戦略を含む多くの最先端のMTLメソッドをサポートしています。一方、 LibMTL 、異なるフィールドをカバーするいくつかのベンチマークデータセットの公正な比較を提供します。LibMTLモジュラー設計原則に従います。これにより、ユーザーはカスタマイズされたコンポーネントを柔軟かつ便利に追加したり、パーソナライズされた変更を加えたりできます。したがって、ユーザーは、新しい最適化戦略とアーキテクチャを簡単かつ高速に開発したり、既存のMTLアルゴリズムをLibMTLのサポートを受けて新しいアプリケーションシナリオに適用できます。 
各モジュールはドキュメントで紹介されます。
LibMTL現在、次のアルゴリズムをサポートしています。
| 最適化戦略 | 会場 | 議論 |
|---|---|---|
| 等しい重み付け(EW) | - | --weighting EW |
| グラジエント正規化(GradNorm) | ICML 2018 | --weighting GradNorm |
| 不確実性の重み(UW) | CVPR 2018 | --weighting UW |
| MGDA(公式コード) | Neurips 2018 | --weighting MGDA |
| 動的重量平均(DWA)(公式コード) | CVPR 2019 | --weighting DWA |
| 幾何学的損失戦略(GLS) | CVPR 2019ワークショップ | --weighting GLS |
| 対立する勾配を投影(PCGRAD) | ニューリップ2020 | --weighting PCGrad |
| グラデーションサインドロップアウト(グラデロップ) | ニューリップ2020 | --weighting GradDrop |
| 公平なマルチタスク学習(IMTL) | ICLR 2021 | --weighting IMTL |
| グラジエントワクチン(GradVac) | ICLR 2021 | --weighting GradVac |
| 紛争回避勾配降下(Cagrad)(公式コード) | ニューリップ2021 | --weighting CAGrad |
| nash-mtl(公式コード) | ICML 2022 | --weighting Nash_MTL |
| ランダム損失の重み付け(RLW) | TMLR 2022 | --weighting RLW |
| モコ | ICLR 2023 | --weighting MoCo |
| aligned-mtl(公式コード) | CVPR 2023 | --weighting Aligned_MTL |
| stch(公式コード) | ICML 2024 | --weighting STCH |
| sprecemtl(公式コード) | ICML 2024 | --weighting ExcessMTL |
| フェアグラード(公式コード) | ICML 2024 | --weighting FairGrad |
| db-mtl | arxiv | --weighting DB_MTL |
| アーキテクチャ | 会場 | 議論 |
|---|---|---|
| ハードパラメーター共有(HPS) | ICML 1993 | --arch HPS |
| クロスステッチネットワーク(Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| マルチゲートの混合物(MMOE) | KDD 2018 | --arch MMoE |
| マルチタスク注意ネットワーク(MTAN)(公式コード) | CVPR 2019 | --arch MTAN |
| カスタマイズされたゲートコントロール(CGC)、プログレッシブ層抽出(PLE) | ACM Recsys 2020 | --arch CGC 、 --arch PLE |
| 分岐する学習(LTB) | ICML 2020 | --arch LTB |
| DSELECT-K(公式コード) | ニューリップ2021 | --arch DSelect_k |
| データセット | 問題 | タスク番号 | タスク | マルチ入力 | サポートされたバックボーン |
|---|---|---|---|---|---|
| NYUV2 | シーンの理解 | 3 | セマンティックセグメンテーション+ 深さ推定+ 表面正常予測 | ✘ | resnet50/ segnet |
| 街並み | シーンの理解 | 2 | セマンティックセグメンテーション+ 深度推定 | ✘ | Resnet50 |
| オフィス-31 | 画像認識 | 3 | 分類 | ✓✓ | resnet18 |
| Office-Home | 画像認識 | 4 | 分類 | ✓✓ | resnet18 |
| QM9 | 分子特性予測 | 11(デフォルト) | 回帰 | ✘ | GNN |
| PAWS-X | 言い換え識別 | 4(デフォルト) | 分類 | ✓✓ | バート |
仮想環境を作成します
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlリポジトリをクローンします
git clone https://github.com/median-research-group/LibMTL.git LibMTLをインストールします
cd LibMTL
pip install -r requirements.txt
pip install -e . NYUV2データセットを例として使用して、 LibMTL使用方法を示しています。
使用したNYUV2データセットは、MTANによって前処理されます。このデータセットはこちらからダウンロードできます。
NYUV2データセットの完全なトレーニングコードは、例/NYUで提供されています。ファイルMain.pyは、NYUV2データセットのトレーニングのメインファイルです。
次のコマンドを実行することにより、コマンドライン引数を見つけることができます。
python main.py -hたとえば、次のコマンドを実行すると、NYUV2データセットでEWとHPSを使用してMTLモデルをトレーニングします。
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATH詳細については、ドキュメントに記載されています。
研究や開発に役立つLibMTLある場合は、以下を引用してください。
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
}LibMTLは、Baijiong Linによって開発および維持されています。
質問や提案がある場合は、問題を提起するか、 [email protected]にメールを送信して、お気軽にお問い合わせください。
公開リポジトリ(アルファベット順にリスト)をリリースした著者に感謝します:Cagrad、dselect_k_moe、multiopumsiveoptimization、mtan、mtl、nash-mtl、pytorch_geometric、xtreme。
LibMTLはMITライセンスの下でリリースされます。