
はじめに|インストール|始めましょう|ドキュメント| ?報告の問題
英語| 简体中文











V0.10.5は2024-9-11にリリースされました。
ハイライト:
artifact_locationサポートしますDeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517のexclude_frozen_parameters有効にします詳細については、changelogをお読みください。
Mmengineは、Pytorchに基づいた深い学習モデルをトレーニングするための基礎ライブラリです。これは、すべてのOpenMMLabコードベースのトレーニングエンジンとして機能し、さまざまな研究分野の何百ものアルゴリズムをサポートしています。さらに、MMENGINEは、OpenMMLab以外のプロジェクトに適用される一般的なものでもあります。そのハイライトは次のとおりです。
主流の大規模モデルトレーニングフレームワークを統合します
さまざまなトレーニング戦略をサポートします
ユーザーフレンドリーな構成システムを提供します
主流のトレーニング監視プラットフォームをカバーしています
| Mmengine | Pytorch | Python |
|---|---|---|
| 主要 | > = 1.6 <= 2.1 | > = 3.8、<= 3.11 |
| > = 0.9.0、<= 0.10.4 | > = 1.6 <= 2.1 | > = 3.8、<= 3.11 |
Mmengineを設置する前に、Pytorchが公式ガイドに従って正常にインストールされていることを確認してください。
Mmengineをインストールします
pip install -U openmim
mim install mmengineインストールを確認します
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' CIFAR-10データセットでのResNet-50モデルのトレーニングを例にとると、Mmengineを使用して、80行未満のコードで完全で構成可能なトレーニングと検証プロセスを構築します。
まず、データセットに関連するこれらの引数に加えて、1) BaseModelから継承し、2) forwardメソッドで追加の引数modeを受け入れるモデルを定義する必要があります。
modeの値は「損失」であり、 forwardメソッドはキー「損失」を含むdictを返す必要があります。modeの値は「予測」であり、フォワードメソッドは予測とラベルの両方を含む結果を返す必要があります。 import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labels次に、トレーニングと検証のためにデータセットとデータローダーを作成する必要があります。この場合、TorchVisionでサポートされている組み込みデータセットを使用するだけです。
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))モデルを検証してテストするには、モデルを評価するために精度と呼ばれるメトリックを定義する必要があります。このメトリックは、 BaseMetricから継承し、 processとcompute_metricsメソッドを実装する必要があります。
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size )最後に、以下に示すように、以前に定義されたModel 、 DataLoader 、およびMetricsを備えたランナーを構築できます。
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Mmengineを改善するためのすべての貢献に感謝しています。貢献ガイドラインについては、converting.mdを参照してください。
このプロジェクトがあなたの研究で役立つと思う場合は、引用を検討してください。
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
このプロジェクトは、Apache 2.0ライセンスの下でリリースされます。