LibMTL -это библиотека с открытым исходным кодом, построенную на Pytorch для многозадачного обучения (MTL). Смотрите последнюю документацию для подробных введений и инструкций API.
Сыграйте нас на GitHub - это очень мотивирует нас!
LibMTL предоставляет единую кодовую базу для реализации и последовательную процедуру оценки, включая обработку данных, метрические цели и гиперпараметры в нескольких репрезентативных наборах данных MTL, которые позволяют количественно, справедливо и последовательно сравнить между различными алгоритмами MTL.LibMTL поддерживает многие современные методы MTL, включая 8 архитектур и 16 стратегий оптимизации. Между тем, LibMTL обеспечивает справедливое сравнение нескольких контрольных наборов данных, охватывающих различные области.LibMTL следует принципам модульного проектирования, что позволяет пользователям гибко и удобно добавлять индивидуальные компоненты или вносить персонализированные модификации. Следовательно, пользователи могут легко и быстро разрабатывать новые стратегии и архитектуры оптимизации или применять существующие алгоритмы MTL к новым сценариям приложений при поддержке LibMTL . 
Каждый модуль представлен в документах.
LibMTL в настоящее время поддерживает следующие алгоритмы:
| Стратегии оптимизации | Места | Аргументы |
|---|---|---|
| Равное взвешивание (EW) | - | --weighting EW |
| Нормализация градиента (Gradnorm) | ICML 2018 | --weighting GradNorm |
| Вес неопределенности (UW) | CVPR 2018 | --weighting UW |
| MGDA (официальный код) | Neurips 2018 | --weighting MGDA |
| Динамический средний вес (DWA) (официальный код) | CVPR 2019 | --weighting DWA |
| Стратегия геометрических потерь (GLS) | CVPR 2019 семинар | --weighting GLS |
| Проецируя конфликтующий градиент (PCGrad) | Neurips 2020 | --weighting PCGrad |
| Градиент знак отсека (gradgrop) | Neurips 2020 | --weighting GradDrop |
| Беспристрастное многозадачное обучение (IMTL) | ICLR 2021 | --weighting IMTL |
| Градиентная вакцина (Gradvac) | ICLR 2021 | --weighting GradVac |
| Градиент-конфликтный спуск (Cagrad) (Официальный код) | Neurips 2021 | --weighting CAGrad |
| Nash-Mtl (официальный код) | ICML 2022 | --weighting Nash_MTL |
| Случайное взвешивание потерь (RLW) | TMLR 2022 | --weighting RLW |
| Моко | ICLR 2023 | --weighting MoCo |
| Выровнен-mtl (официальный код) | CVPR 2023 | --weighting Aligned_MTL |
| STCH (официальный код) | ICML 2024 | --weighting STCH |
| Excessmtl (официальный код) | ICML 2024 | --weighting ExcessMTL |
| Fairgrad (официальный код) | ICML 2024 | --weighting FairGrad |
| DB-MTL | arxiv | --weighting DB_MTL |
| Архитектуры | Места | Аргументы |
|---|---|---|
| Обмен жестким параметром (HPS) | ICML 1993 | --arch HPS |
| Сети кросс-сшивания (Cross_stitch) | CVPR 2016 | --arch Cross_stitch |
| Смеси с несколькими воротами (MMOE) | KDD 2018 | --arch MMoE |
| Сеть внимания с несколькими задачами (MTAN) (Официальный код) | CVPR 2019 | --arch MTAN |
| Индивидуальное управление затворами (CGC), прогрессивная слоистая экстракция (PLE) | ACM Recsys 2020 | --arch CGC , --arch PLE |
| Учиться в ветвь (LTB) | ICML 2020 | --arch LTB |
| Dselect-k (официальный код) | Neurips 2021 | --arch DSelect_k |
| Наборы данных | Проблемы | Номер задачи | Задачи | много вход | Поддерживается магистраль |
|---|---|---|---|---|---|
| NYUV2 | Понимание сцены | 3 | Семантическая сегментация+ Оценка глубины+ Поверхностный нормальный прогноз | ✘ | Resnet50/ Segnet |
| Городские пейзажи | Понимание сцены | 2 | Семантическая сегментация+ Оценка глубины | ✘ | Resnet50 |
| Офис-31 | Распознавание изображения | 3 | Классификация | ✓ | Resnet18 |
| Офис-дома | Распознавание изображения | 4 | Классификация | ✓ | Resnet18 |
| QM9 | Прогноз молекулярного свойства | 11 (по умолчанию) | Регрессия | ✘ | Г -н |
| Paws-X | Перефразирование идентификации | 4 (по умолчанию) | Классификация | ✓ | Берт |
Создать виртуальную среду
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.htmlКлонировать репозиторий
git clone https://github.com/median-research-group/LibMTL.git Установите LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e . Мы используем набор данных NYUV2 в качестве примера, чтобы показать, как использовать LibMTL .
Набор данных NYUV2, который мы использовали, предварительно обрабатывается MTAN. Вы можете скачать этот набор данных здесь.
Полный учебный код для набора данных NYUV2 представлен в примерах/NYU. File main.py является основным файлом для обучения в наборе данных NYUV2.
Вы можете найти аргументы командной строки, выполнив следующую команду.
python main.py -hНапример, запуск следующей команды будет обучать модель MTL с EW и HPS на наборе данных NYUV2.
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATHБолее подробная информация представлена в документах.
Если вы считаете LibMTL полезным для исследования или разработки, пожалуйста, укажите следующее:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
} LibMTL разрабатывается и поддерживается Baijiong Lin.
Если у вас есть какие -либо вопросы или предложения, пожалуйста, не стесняйтесь обращаться к нам, подняв проблему или отправив электронное письмо на [email protected] .
Мы хотели бы поблагодарить авторов, которые выпускают публичные репозитории (перечисленные в алфавитном порядке): Cagrad, dselect_k_moe, MultiObjectiveOptimization, MTAN, MTL, Nash-Mtl, Pytorch_geometric и Xtreme.
LibMTL выпускается по лицензии MIT.