
Introducción | Instalación | Empiece | Documentación | ? Problemas de informes
Inglés | 简体中文











V0.10.5 se lanzó el 2024-9-11.
Reflejos:
artifact_location personalizado en mlflowvisbackend #1505exclude_frozen_parameters para DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517Lea ChangeLog para más detalles.
MMEngine es una biblioteca fundamental para capacitar modelos de aprendizaje profundo basados en Pytorch. Sirve como motor de entrenamiento de todas las bases de código OpenMMLab, que admiten cientos de algoritmos en diversas áreas de investigación. Además, MMEngine también es genérico que se aplicará a proyectos no abiertos. Sus aspectos más destacados son los siguientes:
Integrar marcos de capacitación de modelos a gran escala en los principales
Apoya una variedad de estrategias de capacitación
Proporciona un sistema de configuración fácil de usar
Cubre plataformas de monitoreo de capacitación convencional
| Motor mm | Pytorch | Pitón |
|---|---|---|
| principal | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
| > = 0.9.0, <= 0.10.4 | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
Antes de instalar MMEngine, asegúrese de que Pytorch haya sido instalado con éxito después de la guía oficial.
Instalar mmEngine
pip install -U openmim
mim install mmengineVerificar la instalación
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' Tomando la capacitación de un modelo RESNET-50 en el conjunto de datos CIFAR-10 como ejemplo, utilizaremos MMEngine para construir un proceso completo y de validación configurable en menos de 80 líneas de código.
Primero, necesitamos definir un modelo que 1) hereda de BaseModel y 2) acepta un mode de argumento adicional en el método forward , además de los argumentos relacionados con el conjunto de datos.
mode es "pérdida", y el método forward debería devolver un dict que contenga la "pérdida" clave.mode se "predice", y el método de avance debería devolver los resultados que contienen predicciones y etiquetas. import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labelsA continuación, necesitamos crear conjuntos de datos y dataloader s para capacitación y validación. En este caso, simplemente utilizamos conjuntos de datos integrados admitidos en TorchVision.
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
]))) Para validar y probar el modelo, necesitamos definir una métrica llamada precisión para evaluar el modelo. Esta métrica necesita heredar de BaseMetric e implementar los métodos process y compute_metrics .
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size ) Finalmente, podemos construir un corredor con Model previamente definido, DataLoader y Metrics , con algunas otras configuraciones, como se muestra a continuación.
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Apreciamos todas las contribuciones para mejorar MMEngine. Consulte CONTRIGIARSE.MD para la guía contribuyente.
Si encuentra útil este proyecto en su investigación, considere citar:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
Este proyecto se publica bajo la licencia Apache 2.0.