Deep Learning Project Template
1.0.0
A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributing to PyTorch projects, here is a PyTorch project template that combines simplicity, a sensible folder structure, and good OOP design. The main idea is that there is a lot you do every time you start a PyTorch project, so wrapping all of that shared boilerplate lets you focus on the core idea each time you start a new one.
So this is a simple PyTorch template that helps you get into your main project faster and focus on the core parts (model architecture, training flow, etc.).
To reduce repeated boilerplate, we recommend using a high-level library. You can write your own, or use a third-party one such as ignite, fastai, mmcv, etc. These help you write a compact but full-featured training loop in a few lines of code. Here we use ignite to train on MNIST.
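To illustrate what such a library buys you, here is a minimal, dependency-free sketch of an ignite-style engine that owns the epoch/iteration loop while you supply only a single step function. The class and names are illustrative stand-ins, not the real ignite API:

```python
class Engine:
    """Tiny sketch of an ignite-style engine: you supply one step
    function, and the engine owns the epoch/iteration loop."""

    def __init__(self, step_fn):
        self.step_fn = step_fn
        self.history = []

    def run(self, data, max_epochs=1):
        for epoch in range(max_epochs):
            for batch in data:
                # the step function holds the per-batch logic
                # (forward pass, loss, backward, optimizer step)
                output = self.step_fn(batch)
                self.history.append(output)
        return self.history


# with a real model the step would compute a loss; here we fake one
def train_step(batch):
    return sum(batch) / len(batch)  # stand-in "loss"


engine = Engine(train_step)
losses = engine.run([[1, 2, 3], [4, 5, 6]], max_epochs=2)
```

With the real ignite you get the same shape of code plus event handlers, metrics, and checkpointing for free.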
In short, here is how to use this template. For example, suppose you want to implement ResNet-18 to train on MNIST; you should do the following:
In the `modeling` folder, create a Python file named whatever you like; here we name it `example_model.py`. In `modeling/__init__.py`, build a function named `build_model` to call your model:

```python
from .example_model import ResNet18


def build_model(cfg):
    model = ResNet18(cfg.MODEL.NUM_CLASSES)
    return model
```

In the `engine` folder, create a trainer function and an inference function. In the trainer function you write the logic of the training process; you can use a third-party library to reduce the repeated boilerplate:

```python
# trainer
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    """
    Implement the logic of one epoch:
    - loop over the number of iterations in the config and call the train step
    - add any summaries you want to log
    """
    pass


# inference
def inference(cfg, model, val_loader):
    """
    Implement the logic of evaluation:
    - run the model on the validation data
    - return any metrics you need to summarize
    """
    pass
```

In the `tools` folder, create `train.py`. In this file you get instances of the following objects: "model", "dataloader", "optimizer", and config:

```python
# create an instance of the model you want
model = build_model(cfg)

# create your data loaders
train_loader = make_data_loader(cfg, is_train=True)
val_loader = make_data_loader(cfg, is_train=False)

# create your model optimizer
optimizer = make_optimizer(cfg, model)
```

Pass all these objects into `do_train`, and start your training:

```python
# here you train your model
do_train(cfg, model, train_loader, val_loader, optimizer, None, F.cross_entropy)
```

You will find a template file and a simple example in the modeling and engine folders that show you how to try your first model.
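The `cfg` object threaded through these functions is a hierarchical config with attribute access (the code reads values such as `cfg.MODEL.NUM_CLASSES`). As a rough sketch of the idea, assuming a yacs-style layout as suggested by `config/defaults.py`, a minimal stand-in (not the real yacs library) could look like:

```python
class CfgNode(dict):
    """Minimal attribute-access config node; an illustrative
    stand-in for a yacs-style CfgNode, not the real library."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value


# config/defaults.py would define the defaults...
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.NUM_CLASSES = 10
# ...and a per-experiment file such as configs/train_mnist_softmax.yml
# would override a subset of them before training starts
```

Keeping defaults in one file and per-experiment overrides in small YAML files is what lets one `train.py` drive many experiments.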
├── config
│ └── defaults.py - here's the default config file.
│
│
├── configs
│ └── train_mnist_softmax.yml - here's the specific config file for specific model or dataset.
│
│
├── data
│ └── datasets - here's the datasets folder that is responsible for all data handling.
│ └── transforms - here's the data preprocess folder that is responsible for all data augmentation.
│ └── build.py - here's the file to make dataloader.
│ └── collate_batch.py - here's the file responsible for merging a list of samples to form a mini-batch.
│
│
├── engine
│ ├── trainer.py - this file contains the train loops.
│ └── inference.py - this file contains the inference process.
│
│
├── layers - this folder contains any custom layers of your project.
│ └── conv_layer.py
│
│
├── modeling - this folder contains any model of your project.
│ └── example_model.py
│
│
├── solver - this folder contains the optimizer and lr scheduler of your project.
│ └── build.py
│ └── lr_scheduler.py
│
│
├── tools - here are the train/test scripts of your project.
│ └── train_net.py - here's an example training script that is responsible for the whole pipeline.
│
│
├── utils
│ ├── logger.py
│ └── any_other_utils_you_need
│
│
└── tests - this folder contains the unit tests of your project.
├── test_data_sampler.py
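As an example of what lives in `data/collate_batch.py`: its job is to merge a list of individual samples into one mini-batch. A pure-Python sketch of the idea (a real implementation would stack torch tensors instead of returning plain lists):

```python
def collate_batch(samples):
    """Merge a list of (image, label) pairs into one mini-batch.
    With torch you would stack the images into a single tensor;
    plain lists are used here to keep the sketch dependency-free."""
    images, labels = zip(*samples)
    return list(images), list(labels)


# two samples, each an (image, label) pair
batch = collate_batch([([0.1, 0.2], 7), ([0.3, 0.4], 1)])
```

The data loader built in `data/build.py` passes a function like this as its collate step, so batching logic stays in one place.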
Any kind of enhancement or contribution is welcome.