
소개 | 설치 | 시작 | 문서 | ?보고 문제
영어 | 简体中文











V0.10.5는 2024-9-11에 풀려났다.
하이라이트:
artifact_location 지원합니다DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517 용으로 exclude_frozen_parameters 활성화하십시오자세한 내용은 ChangElog를 읽으십시오.
Mmengine은 Pytorch를 기반으로 한 딥 러닝 모델을 교육하기위한 기초 도서관입니다. 다양한 연구 분야에서 수백 개의 알고리즘을 지원하는 모든 OpenMMLAB 코드베이스의 교육 엔진 역할을합니다. 또한, Mmengine은 또한 비 Openmmlab 프로젝트에 적용되기에 일반적입니다. 하이라이트는 다음과 같습니다.
주류 대규모 모델 교육 프레임 워크를 통합합니다
다양한 교육 전략을 지원합니다
사용자 친화적 인 구성 시스템을 제공합니다
주류 교육 모니터링 플랫폼을 다룹니다
| mmengine | Pytorch | 파이썬 |
|---|---|---|
| 기본 | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
| > = 0.9.0, <= 0.10.4 | > = 1.6 <= 2.1 | > = 3.8, <= 3.11 |
Mmengine을 설치하기 전에 공식 안내서에 따라 Pytorch가 성공적으로 설치되었는지 확인하십시오.
mmengine을 설치하십시오
pip install -U openmim
mim install mmengine설치를 확인하십시오
python -c ' from mmengine.utils.dl_utils import collect_env;print(collect_env()) ' CIFAR-10 데이터 세트에서 RESNET-50 모델의 교육을 받으면 Mmengine을 사용하여 80 줄 미만의 코드로 완전하고 구성 가능한 교육 및 유효성 검사 프로세스를 구축합니다.
먼저, 우리는 1) BaseModel 에서 상속하는 모델을 정의하고 2) 데이터 세트와 관련된 인수 외에도 forward 방법에서 추가 인수 mode 수락해야합니다.
mode 의 값은 "손실"이며, forward 방법은 "손실"키가 포함 된 dict 반환해야합니다.mode 값은 "예측"이며, 전진 방법은 예측과 레이블이 모두 포함 된 결과를 반환해야합니다. import torch . nn . functional as F
import torchvision
from mmengine . model import BaseModel
class MMResNet50 ( BaseModel ):
def __init__ ( self ):
super (). __init__ ()
self . resnet = torchvision . models . resnet50 ()
def forward ( self , imgs , labels , mode ):
x = self . resnet ( imgs )
if mode == 'loss' :
return { 'loss' : F . cross_entropy ( x , labels )}
elif mode == 'predict' :
return x , labels다음으로 교육 및 검증을 위해 데이터 세트 및 Dataloader 를 만들어야합니다. 이 경우, 우리는 단순히 Torchvision에서 지원되는 내장 데이터 세트를 사용합니다.
import torchvision . transforms as transforms
from torch . utils . data import DataLoader
norm_cfg = dict ( mean = [ 0.491 , 0.482 , 0.447 ], std = [ 0.202 , 0.199 , 0.201 ])
train_dataloader = DataLoader ( batch_size = 32 ,
shuffle = True ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = True ,
download = True ,
transform = transforms . Compose ([
transforms . RandomCrop ( 32 , padding = 4 ),
transforms . RandomHorizontalFlip (),
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
])))
val_dataloader = DataLoader ( batch_size = 32 ,
shuffle = False ,
dataset = torchvision . datasets . CIFAR10 (
'data/cifar10' ,
train = False ,
download = True ,
transform = transforms . Compose ([
transforms . ToTensor (),
transforms . Normalize ( ** norm_cfg )
]))) 모델을 검증하고 테스트하려면 모델을 평가하기 위해 정확도라는 메트릭을 정의해야합니다. 이 메트릭은 BaseMetric 에서 상속해야하며 process 및 compute_metrics 방법을 구현해야합니다.
from mmengine . evaluator import BaseMetric
class Accuracy ( BaseMetric ):
def process ( self , data_batch , data_samples ):
score , gt = data_samples
# Save the results of a batch to `self.results`
self . results . append ({
'batch_size' : len ( gt ),
'correct' : ( score . argmax ( dim = 1 ) == gt ). sum (). cpu (),
})
def compute_metrics ( self , results ):
total_correct = sum ( item [ 'correct' ] for item in results )
total_size = sum ( item [ 'batch_size' ] for item in results )
# Returns a dictionary with the results of the evaluated metrics,
# where the key is the name of the metric
return dict ( accuracy = 100 * total_correct / total_size ) 마지막으로, 아래에 표시된 것처럼 이전에 정의 된 Model , DataLoader 및 Metrics 사용하여 러너를 구성 할 수 있습니다.
from torch . optim import SGD
from mmengine . runner import Runner
runner = Runner (
model = MMResNet50 (),
work_dir = './work_dir' ,
train_dataloader = train_dataloader ,
# a wrapper to execute back propagation and gradient update, etc.
optim_wrapper = dict ( optimizer = dict ( type = SGD , lr = 0.001 , momentum = 0.9 )),
# set some training configs like epochs
train_cfg = dict ( by_epoch = True , max_epochs = 5 , val_interval = 1 ),
val_dataloader = val_dataloader ,
val_cfg = dict (),
val_evaluator = dict ( type = Accuracy ),
) runner . train ()Mmengine을 개선하기위한 모든 기여에 감사드립니다. 기고 가이드 라인은 Contributing.md를 참조하십시오.
이 프로젝트가 연구에 유용하다고 생각되면 인용을 고려하십시오.
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
이 프로젝트는 Apache 2.0 라이센스에 따라 릴리스됩니다.