LibMTL adalah perpustakaan open-source yang dibangun di atas Pytorch untuk pembelajaran multi-tugas (MTL). Lihat dokumentasi terbaru untuk perkenalan terperinci dan instruksi API.
Bintang kami di GitHub - ini banyak memotivasi kami!
LibMTL menyediakan basis kode terpadu untuk mengimplementasikan dan prosedur evaluasi yang konsisten termasuk pemrosesan data, tujuan metrik, dan hiper-parameter pada beberapa dataset benchmark MTL yang representatif, yang memungkinkan perbandingan kuantitatif, adil, dan konsisten antara berbagai algoritma MTL.LibMTL mendukung banyak metode MTL canggih termasuk 8 arsitektur dan 16 strategi optimisasi. Sementara itu, LibMTL memberikan perbandingan yang adil dari beberapa dataset benchmark yang mencakup berbagai bidang.LibMTL mengikuti prinsip -prinsip desain modular, yang memungkinkan pengguna untuk secara fleksibel dan mudah menambahkan komponen khusus atau membuat modifikasi yang dipersonalisasi. Oleh karena itu, pengguna dapat dengan mudah dan cepat mengembangkan strategi dan arsitektur optimasi baru atau menerapkan algoritma MTL yang ada ke skenario aplikasi baru dengan dukungan LibMTL . 
Setiap modul diperkenalkan dalam dokumen.
LibMTL saat ini mendukung algoritma berikut:
| Strategi optimasi | Tempat | Argumen |
|---|---|---|
| Bobot yang sama (ew) | - | --weighting EW |
| Normalisasi Gradien (Gradnorm) | ICML 2018 | --weighting GradNorm |
| Bobot ketidakpastian (UW) | CVPR 2018 | --weighting UW |
| MGDA (kode resmi) | Neurips 2018 | --weighting MGDA |
| Dynamic Weight Average (DWA) (Kode Resmi) | CVPR 2019 | --weighting DWA |
| Strategi Kehilangan Geometris (GLS) | Lokakarya CVPR 2019 | --weighting GLS |
| Proyeksi gradien yang bertentangan (pcgrad) | Neurips 2020 | --weighting PCGrad |
| Gradient Sign Dropout (Graddrop) | Neurips 2020 | --weighting GradDrop |
| Pembelajaran multi-tugas yang tidak memihak (IMTL) | ICLR 2021 | --weighting IMTL |
| Vaksin Gradien (GradVAC) | ICLR 2021 | --weighting GradVac |
| Descent Gradient (CAGRAD) (Kode Resmi) | Neurips 2021 | --weighting CAGrad |
| Nash-mtl (kode resmi) | ICML 2022 | --weighting Nash_MTL |
| Bobot kerugian acak (RLW) | TMLR 2022 | --weighting RLW |
| Moco | ICLR 2023 | --weighting MoCo |
| Aligned-MTL (kode resmi) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (kode resmi) | ICML 2024 | --weighting STCH |
| EXTESSMTL (Kode Resmi) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (kode resmi) | ICML 2024 | --weighting FairGrad |
| Db-mtl | arxiv | --weighting DB_MTL |
| Arsitektur | Tempat | Argumen |
|---|---|---|
| Berbagi Parameter Hard (HPS) | ICML 1993 | --arch HPS |
| Jaringan Cross-Stitch (Cross_Stitch) | CVPR 2016 | --arch Cross_stitch |
| Campuran multi-gate-of-experts (MMOE) | KDD 2018 | --arch MMoE |
| Jaringan Perhatian Multi-Task (MTAN) (Kode Resmi) | CVPR 2019 | --arch MTAN |
| Kontrol gerbang khusus (CGC), ekstraksi berlapis progresif (PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| Belajar ke cabang (LTB) | ICML 2020 | --arch LTB |
| DSelect-K (kode resmi) | Neurips 2021 | --arch DSelect_k |
| Kumpulan data | Masalah | Nomor tugas | Tugas | multi-input | Backbone yang didukung |
|---|---|---|---|---|---|
| NYUV2 | Pemahaman adegan | 3 | Segmentasi semantik+ Estimasi kedalaman+ Permukaan prediksi normal | ✘ | Resnet50/ Segnet |
| Cityscapes | Pemahaman adegan | 2 | Segmentasi semantik+ Estimasi Kedalaman | ✘ | Resnet50 |
| Office-31 | Pengenalan gambar | 3 | Klasifikasi | ✓ | Resnet18 |
| Rumah kantor | Pengenalan gambar | 4 | Klasifikasi | ✓ | Resnet18 |
| Qm9 | Prediksi properti molekuler | 11 (default) | Regresi | ✘ | GNN |
| Paws-X | Identifikasi parafrase | 4 (default) | Klasifikasi | ✓ | Bert |
Buat lingkungan virtual
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlKlon Repositori
git clone https://github.com/median-research-group/LibMTL.git Instal LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Kami menggunakan dataset NYUV2 sebagai contoh untuk menunjukkan cara menggunakan LibMTL .
Dataset NYUV2 yang kami gunakan telah diproses oleh MTAN. Anda dapat mengunduh dataset ini di sini.
Kode pelatihan lengkap untuk dataset NYUV2 disediakan dalam contoh/NYU. File Main.py adalah file utama untuk pelatihan pada dataset NYUV2.
Anda dapat menemukan argumen baris perintah dengan menjalankan perintah berikut.
python main.py -hMisalnya, menjalankan perintah berikut akan melatih model MTL dengan EW dan HPS pada dataset NYUV2.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHRincian lebih lanjut diwakili dalam dokumen.
Jika Anda menemukan LibMTL berguna untuk penelitian atau pengembangan Anda, silakan kutip yang berikut:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL dikembangkan dan dikelola oleh Baijiong Lin.
Jika Anda memiliki pertanyaan atau saran, jangan ragu untuk menghubungi kami dengan mengangkat masalah atau mengirim email ke [email protected] .
Kami ingin mengucapkan terima kasih kepada penulis yang merilis repositori publik (terdaftar secara abjad): CAGRAD, DSELECT_K_MOE, MultiobjectiveOptimization, Mtan, MTL, Nash-MTL, PyTorch_Geometric, dan Xtreme.
LibMTL dirilis di bawah lisensi MIT.