piepline
1.0.0
Neural network training pipeline based on PyTorch. Designed to standardize the training process and speed up experimentation.
import torch

from neural_pipeline.builtin.monitors.tensorboard import TensorboardMonitor
from neural_pipeline.monitoring import LogMonitor
from neural_pipeline import (DataProducer, TrainConfig, TrainStage,
                             ValidationStage, Trainer, FileStructManager)

# MyNet and MyDataset are placeholders for your own model and dataset
from something import MyNet, MyDataset
# Manages the experiment directory ('data'); is_continue=False starts a new run
fsm = FileStructManager(base_dir='data', is_continue=False)
model = MyNet().cuda()

# Data producers batch the datasets for the training and validation stages
train_dataset = DataProducer([MyDataset()], batch_size=4, num_workers=2)
validation_dataset = DataProducer([MyDataset()], batch_size=4, num_workers=2)

# Model, stages (train, then validate), loss function and optimizer
train_config = TrainConfig(model, [TrainStage(train_dataset),
                                   ValidationStage(validation_dataset)],
                           torch.nn.NLLLoss(),
                           torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.5))

trainer = Trainer(train_config, fsm, torch.device('cuda:0')).set_epoch_num(50)
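# Visualize training in TensorBoard and log metrics
trainer.monitor_hub.add_monitor(TensorboardMonitor(fsm, is_continue=False)) \
                   .add_monitor(LogMonitor(fsm))
trainer.train()

This example trains MyNet on MyDataset, visualizes training in TensorBoard, and logs metrics so that experiments can be compared later.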
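The example above hides the training loop inside `Trainer.train()`. As a rough mental model only, the pure-Python sketch below assumes that each epoch runs every stage in order, that only the training stage updates the model, and that the epoch's metrics are then handed to every registered monitor. `SketchTrainer`, `SketchStage`, `MonitorHub` and `step` are hypothetical stand-ins, not piepline classes. The sketch also shows why the chained `.add_monitor(...).add_monitor(...)` call works: here `add_monitor` returns the hub itself.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-ins for illustration only; these are NOT piepline classes.
Batch = Sequence[Tuple[float, float]]  # a batch of (input, target) pairs
Monitor = Callable[[int, Dict[str, float]], None]


class MonitorHub:
    """Passes each epoch's metrics to every registered monitor."""

    def __init__(self) -> None:
        self.monitors: List[Monitor] = []

    def add_monitor(self, monitor: Monitor) -> "MonitorHub":
        self.monitors.append(monitor)
        return self  # returning the hub is what makes chained add_monitor calls work

    def update(self, epoch: int, metrics: Dict[str, float]) -> None:
        for monitor in self.monitors:
            monitor(epoch, metrics)


@dataclass
class SketchStage:
    name: str
    batches: List[Batch]
    is_train: bool  # only training stages update the model


class SketchTrainer:
    def __init__(self, stages: List[SketchStage], step: Callable[[Batch, bool], float]) -> None:
        self.stages = stages
        self.step = step  # runs one batch, returns its loss, updates weights if asked
        self.epoch_num = 1
        self.monitor_hub = MonitorHub()

    def set_epoch_num(self, epoch_num: int) -> "SketchTrainer":
        self.epoch_num = epoch_num
        return self  # allows Trainer(...).set_epoch_num(n) in one expression

    def train(self) -> None:
        for epoch in range(self.epoch_num):
            metrics: Dict[str, float] = {}
            for stage in self.stages:  # e.g. train first, then validate
                losses = [self.step(batch, stage.is_train) for batch in stage.batches]
                metrics[stage.name + "/loss"] = sum(losses) / len(losses)
            self.monitor_hub.update(epoch, metrics)


# Toy "model": y_hat = w * x, fitted to y = 2x with plain gradient descent.
state = {"w": 0.0}


def step(batch: Batch, is_train: bool) -> float:
    errors = [(state["w"] * x - y, x) for x, y in batch]
    loss = sum(e * e for e, _ in errors) / len(batch)  # mean squared error
    if is_train:
        grad = sum(2 * e * x for e, x in errors) / len(batch)
        state["w"] -= 0.05 * grad  # gradient step only during training
    return loss


history: List[Dict[str, float]] = []
epochs_seen: List[int] = []

trainer = SketchTrainer([SketchStage("train", [[(1, 2), (2, 4)], [(3, 6), (4, 8)]], True),
                         SketchStage("validation", [[(5, 10)]], False)], step).set_epoch_num(20)
trainer.monitor_hub.add_monitor(lambda epoch, m: history.append(m)) \
                   .add_monitor(lambda epoch, m: epochs_seen.append(epoch))
trainer.train()
print(round(state["w"], 4), history[-1])
```

Keeping `step` separate from the loop loosely mirrors how `TrainConfig` groups the model, loss and optimizer apart from the stages; that split is a design choice of this sketch, not a statement about piepline's internals.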
Install from PyPI:
pip install piepline
Dependencies for the builtin modules:
pip install tensorboardX matplotlib
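If you are not sure whether the optional dependencies for the builtin modules are present, a small check like the following can tell you before you import them. It is a minimal sketch that only assumes the import names match the package names in the pip command above (`tensorboardX`, `matplotlib`); `missing_dependencies` is a hypothetical helper, not part of piepline.

```python
import importlib.util
from typing import Iterable, List

# Assumption: the builtin modules need these import names (taken from the pip command).
BUILTIN_DEPENDENCIES = ("tensorboardX", "matplotlib")


def missing_dependencies(names: Iterable[str] = BUILTIN_DEPENDENCIES) -> List[str]:
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_dependencies()
if missing:
    print("Missing optional dependencies, run: pip install " + " ".join(missing))
else:
    print("All builtin-module dependencies are installed.")
```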
Install the latest version from GitHub:
pip install -U git+https://github.com/PiePline/piepline
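Because the PyPI release and the GitHub version can differ, it may help to confirm which version actually ended up installed. This sketch uses the standard library's `importlib.metadata` (Python 3.8+) and assumes the distribution name is `piepline`, as in the pip commands above; `installed_version` is a hypothetical helper.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "piepline") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print(installed_version() or "piepline is not installed")
```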