Deep Learning Project Template
1.0.0
A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributions to PyTorch projects, this is a PyTorch project template that combines simplicity, a sensible folder structure, and good OOP design. The main idea is that you rewrite a lot of the same boilerplate every time you start a PyTorch project, so wrapping all of that shared code lets you focus on changing the core idea each time you start a new one.
So this is a simple PyTorch template that helps you get into your main project faster and focus on the core parts (model architecture, training flow, etc.).
To reduce repeated code, we recommend using a high-level library. You can write your own, or use a third-party one such as Ignite, fastai, MMCV, etc. These help you write a compact but full-featured training loop in a few lines of code. Here we use Ignite to train MNIST.
In a nutshell, here is how to use this template. Suppose, for example, you want to implement ResNet-18 to train on MNIST; you would do the following:
In the `modeling` folder, create a Python file named whatever you like; here we name it `example_model.py`. In `modeling/__init__.py`, build a function named `build_model` to call your model:

```python
from .example_model import ResNet18

def build_model(cfg):
    model = ResNet18(cfg.MODEL.NUM_CLASSES)
    return model
```

In the `engine` folder, create the trainer and inference functions for your model. In the trainer function, write the logic of the training process; you can use a third-party library to reduce repeated code.

```python
# trainer
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    """
    implement the logic of an epoch:
    - loop over the number of iterations in the config and call the train step
    - add any summaries you want using your logger
    """
    pass
```
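The trainer above is left as a stub. As a rough plain-PyTorch sketch of what its body could look like (`cfg.SOLVER.MAX_EPOCHS` and the device handling are illustrative assumptions, not part of the template):

```python
import torch

def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    # sketch only: cfg.SOLVER.MAX_EPOCHS and the device choice are assumptions
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    for epoch in range(cfg.SOLVER.MAX_EPOCHS):
        model.train()
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), targets)
            loss.backward()
            optimizer.step()
        # step the scheduler once per epoch, if one was passed in
        if scheduler is not None:
            scheduler.step()
```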
```python
# inference
def inference(cfg, model, val_loader):
    """
    implement the logic of inference:
    - run the model on the validation data
    - return any metrics you need to summarize
    """
    pass
```

In the `tools` folder, create `train.py`. In this file, you get instances of the model, the data loaders, the optimizer, and the config, then call `do_train` to start training:

```python
# create an instance of the model you want
model = build_model(cfg)

# create your data loaders
train_loader = make_data_loader(cfg, is_train=True)
val_loader = make_data_loader(cfg, is_train=False)

# create your model's optimizer
optimizer = make_optimizer(cfg, model)

# here you train your model
do_train(cfg, model, train_loader, val_loader, optimizer, None, F.cross_entropy)
```

You will find a template file and a simple example in the `modeling` and `engine` folders that show you how to try your first model simply.
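The `inference` function can be sketched the same way — here as a minimal accuracy evaluation, assuming a classification model (the metric choice is just an illustration):

```python
import torch

def inference(cfg, model, val_loader):
    # evaluate the model on the validation loader and return a metric
    device = next(model.parameters()).device
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.numel()
    return correct / total
```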
├── config
│ └── defaults.py - here's the default config file.
│
│
├── configs
│ └── train_mnist_softmax.yml - here's the specific config file for specific model or dataset.
│
│
├── data
│   ├── datasets - here's the datasets folder that is responsible for all data handling.
│   ├── transforms - here's the data preprocessing folder that is responsible for all data augmentation.
│   ├── build.py - here's the file to make the dataloader.
│   └── collate_batch.py - here's the file that is responsible for merging a list of samples to form a mini-batch.
│
│
├── engine
│ ├── trainer.py - this file contains the train loops.
│ └── inference.py - this file contains the inference process.
│
│
├── layers - this folder contains any custom layers of your project.
│ └── conv_layer.py
│
│
├── modeling - this folder contains any model of your project.
│ └── example_model.py
│
│
├── solver - this folder contains the optimizer and lr scheduler of your project.
│   ├── build.py
│   └── lr_scheduler.py
│
│
├── tools - here are the train/test scripts of your project.
│   └── train_net.py - here's an example training script that is responsible for the whole pipeline.
│
│
├── utils
│   ├── logger.py
│   └── any_other_utils_you_need
│
│
└── tests - this folder contains the unit tests of your project.
    └── test_data_sampler.py
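To illustrate what `collate_batch.py` is responsible for, a minimal collate function might look like this (a sketch; the actual file's contents may differ):

```python
import torch

def collate_batch(samples):
    # merge a list of (image, label) samples into a mini-batch:
    # stack images into one tensor and collect labels into another
    images, labels = zip(*samples)
    return torch.stack(images, dim=0), torch.tensor(labels)
```

You would pass such a function to your `DataLoader` via its `collate_fn` argument.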
Any kind of enhancement or contribution is welcome.