
Introdução | Instalação | Comece | Documentação | ? Problemas de relatório
Inglês | 简体中文











V0.10.5 foi lançado em 2024-9-11.
Destaques:
artifact_location personalizado em MLFlowVisbackend #1505exclude_frozen_parameters for DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517Leia Changelog para mais detalhes.
O MMEngine é uma biblioteca fundamental para o treinamento de modelos de aprendizado profundo baseados em Pytorch. Serve como mecanismo de treinamento de todas as bases de código OpenMmlab, que suportam centenas de algoritmos em várias áreas de pesquisa. Além disso, o MMEngine também é genérico para ser aplicado a projetos não openmmlab. Seus destaques são os seguintes:
Integrar estruturas de treinamento modelo em larga escala
Suporta uma variedade de estratégias de treinamento
Fornece um sistema de configuração fácil de usar
Cobre plataformas de monitoramento de treinamento convencional
| Mmengine | Pytorch | Python |
|---|---|---|
| principal | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
| > = 0.9.0, <= 0,10.4 | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
Antes de instalar o MMEngine, verifique se o Pytorch foi instalado com sucesso seguindo o guia oficial.
Instale o mmengine
pip install -U openmim
mim install mmengineVerifique a instalação
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' Tomando o treinamento de um modelo RESNET-50 no conjunto de dados CIFAR-10 como exemplo, usaremos o MMEngine para criar um processo completo e configurável de treinamento e validação em menos de 80 linhas de código.
Primeiro, precisamos definir um modelo que 1) herda do BaseModel e 2) aceita um mode de argumento adicional no método forward , além dos argumentos relacionados ao conjunto de dados.
mode é "perda" e o método forward deve retornar um dict que contém a chave "perda".mode é "previsto" e o método avançado deve retornar os resultados contendo previsões e rótulos. import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labelsEm seguida, precisamos criar conjuntos de dados e dados para treinamento e validação. Nesse caso, simplesmente usamos conjuntos de dados embutidos suportados no Torchvision.
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
]))) Para validar e testar o modelo, precisamos definir uma métrica chamada precisão para avaliar o modelo. Essa métrica precisa herdar a partir de BaseMetric e implementar o process e os métodos compute_metrics .
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size ) Por fim, podemos construir um corredor com Model , DataLoader e Metrics previamente definidos, com algumas outras configurações, como mostrado abaixo.
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Agradecemos todas as contribuições para melhorar o mmengine. Consulte Contribuindo.md para obter a diretriz contribuinte.
Se você achar este projeto útil em sua pesquisa, considere citar:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
Este projeto é lançado sob a licença Apache 2.0.