pytorch_classification
1.0.0
基於torchision實現的pytorch圖像分類功能。
2022.11.05更新
2022.10.29更新,進行代碼重構,基本的功能基本一致。
習慣之前版本的請看v1版本的代碼:V1版本。
主要功能:
利用pytorch實現圖像分類,基於torchision可以擴展使用densenet,resnext,mobilenet,efficientnet,swin transformer等圖像分類網絡
如果有用歡迎star
數據集的組織形式,參考sample_files/imgs/listfile.txt
修改run.sh中的參數,直接運行run.sh即可運行
主要修改的參數:
OUTPUT_PATH 模型保存和log文件的路径
TRAIN_LIST 训练数据集的list文件
VAL_LIST 测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size batch的大小
j dataloader的num_workers的大小
num_classes 类别数
代碼存儲在cpp_inference文件夾中。
利用cpp_inference/traced_model/trace_model.py將訓練好的模型導出。
編譯所需的opencv和libtorch代碼到cpp_inference/third_party_library
編譯
sh compile.sh
./bin/imgCls imgpath