
Einführung | Installation | Fangen Sie an | Dokumentation | Meldungsprobleme
Englisch | 简体中文











V0.10.5 wurde am 2024-9-11 veröffentlicht.
Highlights:
artifact_location in mlflowvisbackend #1505exclude_frozen_parameters für DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517Lesen Sie ChangeLog für weitere Details.
Mmengine ist eine grundlegende Bibliothek für das Training von Deep -Learning -Modellen, die auf Pytorch basieren. Es dient als Trainingsmotor aller OpenMMLab -Codebasen, die Hunderte von Algorithmen in verschiedenen Forschungsbereichen unterstützen. Darüber hinaus ist Mmengine auch generisch, um auf Nicht-OpenMMMLab-Projekte angewendet zu werden. Die Höhepunkte sind wie folgt:
Integrieren
Unterstützt eine Vielzahl von Schulungsstrategien
Bietet ein benutzerfreundliches Konfigurationssystem
Deckt die Mainstream -Trainingsüberwachungsplattformen ab
| Mmengine | Pytorch | Python |
|---|---|---|
| hauptsächlich | > = 1,6 <= 2,1 | > = 3,8, <= 3,11 |
| > = 0.9.0, <= 0,10,4 | > = 1,6 <= 2,1 | > = 3,8, <= 3,11 |
Stellen Sie vor dem Installieren von Mmengine bitte sicher, dass Pytorch nach dem offiziellen Leitfaden erfolgreich installiert wurde.
Mmengine installieren
pip install -U openmim
mim install mmengineÜberprüfen Sie die Installation
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' Wenn wir das Training eines Resnet-50-Modells auf dem CIFAR-10-Datensatz als Beispiel nutzen, werden wir mit MMengine einen vollständigen, konfigurierbaren Schulungs- und Validierungsprozess in weniger als 80 Codezeilen erstellen.
Zunächst müssen wir ein Modell definieren, das 1) von BaseModel erbt und 2) zusätzlich zu den Argumenten im Zusammenhang mit dem Datensatz einen zusätzlichen mode in der forward akzeptiert.
mode "Verlust", und die forward sollte ein dict zurückgeben, das den Schlüssel "Verlust" enthält.mode "vorherzusagen", und die Vorwärtsmethode sollte die Ergebnisse zurückgeben, die sowohl Vorhersagen als auch Etiketten enthalten. import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labelsAls nächstes müssen wir Datensatz und Dataloader für Schulungen und Validierung erstellen. In diesem Fall verwenden wir einfach integrierte Datensätze, die in Torchvision unterstützt werden.
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
]))) Um das Modell zu validieren und zu testen, müssen wir eine Metrik definieren, die als Genauigkeit bezeichnet wird, um das Modell zu bewerten. Diese Metrik muss von BaseMetric Erben erben und implementiert die process und die Methoden compute_metrics .
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size ) Schließlich können wir einen Läufer mit zuvor definiertem Model , DataLoader und Metrics mit einigen anderen Konfigurationen erstellen, wie unten gezeigt.
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Wir schätzen alle Beiträge zur Verbesserung von Mmengine. Weitere Informationen finden Sie in der beitragenden Richtlinie.
Wenn Sie dieses Projekt in Ihrer Forschung nützlich finden, sollten Sie zitieren:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
Dieses Projekt wird unter der Apache 2.0 -Lizenz veröffentlicht.