
Introduction | Installation | Commencez | Documentation | ? Reporter des problèmes
Anglais | 简体中文











V0.10.5 a été publié le 2024-9-11.
Points forts:
artifact_location personnalisés dans mlflowvisbackend # 1505exclude_frozen_parameters pour DeepSpeedEngine._zero3_consolidated_16bit_state_dict # 1517Lisez Changelog pour plus de détails.
Mmengine est une bibliothèque fondamentale pour la formation de modèles d'apprentissage en profondeur basés sur Pytorch. Il sert de moteur de formation de toutes les bases de code OpenMMLAB, qui prennent en charge des centaines d'algorithmes dans divers domaines de recherche. De plus, Mmengine est également générique à appliquer à des projets non openmmLab. Ses points forts sont les suivants:
Intégrer les cadres de formation des modèles à grande échelle grand public
Soutient une variété de stratégies de formation
Fournit un système de configuration convivial
Couvre les plateformes de surveillance de la formation traditionnelle
| Mmengine | Pytorch | Python |
|---|---|---|
| principal | > = 1,6 <= 2.1 | > = 3,8, <= 3.11 |
| > = 0,9.0, <= 0,10,4 | > = 1,6 <= 2.1 | > = 3,8, <= 3.11 |
Avant d'installer MMENGINE, veuillez vous assurer que Pytorch a été installé avec succès après le guide officiel.
Installer MMenne
pip install -U openmim
mim install mmengineVérifiez l'installation
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' En prenant la formation d'un modèle RESNET-50 sur l'ensemble de données CIFAR-10 à titre d'exemple, nous utiliserons MMenne pour créer un processus de formation et de validation configurable complet dans moins de 80 lignes de code.
Premièrement, nous devons définir un modèle qui 1) hérite de BaseModel et 2) accepte un mode d'argument supplémentaire dans la méthode forward , en plus des arguments liés à l'ensemble de données.
mode est la "perte" et la méthode forward doit renvoyer un dict contenant la "perte" clé.mode est «prédire» et la méthode avant doit renvoyer les résultats contenant à la fois des prédictions et des étiquettes. import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labelsEnsuite, nous devons créer des ensembles de données et des dataloder pour la formation et la validation. Dans ce cas, nous utilisons simplement des ensembles de données intégrés pris en charge dans TorchVision.
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
]))) Pour valider et tester le modèle, nous devons définir une métrique appelée précision pour évaluer le modèle. Cette métrique doit hériter de BaseMetric et met en œuvre les méthodes process et compute_metrics .
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size ) Enfin, nous pouvons construire un coureur avec Model , DataLoader et Metrics précédemment définis, avec quelques autres configurations, comme indiqué ci-dessous.
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Nous apprécions toutes les contributions pour améliorer Mmen en. Veuillez vous référer à contribution.md pour la directive contributive.
Si vous trouvez ce projet utile dans vos recherches, veuillez envisager citer:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
Ce projet est publié sous la licence Apache 2.0.