Image classification implemented in PyTorch, based on torchvision.
Updated 2022.11.05
Updated 2022.10.29: code refactored; the core functionality is essentially unchanged.
If you prefer the previous version, see the v1 code: V1 version.
Main features:
Image classification with PyTorch, built on torchvision. It can be extended to other image classification networks such as DenseNet, ResNeXt, MobileNet, EfficientNet, and Swin Transformer.
If you find it useful, a star is welcome.
For the dataset format, refer to sample_files/imgs/listfile.txt.
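The exact layout of listfile.txt is not reproduced here. The following is a minimal sketch of a list-file reader under the ASSUMPTION that each non-empty line holds an image path followed by an integer class label, separated by whitespace; the function name `read_list_file` is hypothetical. Check the sample file before relying on this format.

```python
from typing import List, Tuple


def read_list_file(text: str) -> List[Tuple[str, int]]:
    """Parse list-file contents into (image_path, class_index) pairs.

    ASSUMPTION: each non-empty line is "<image_path> <integer_label>".
    """
    samples = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last run of whitespace so paths containing spaces still work.
        parts = line.rsplit(None, 1)
        if len(parts) != 2:
            raise ValueError(f"line {line_no}: expected '<path> <label>', got {line!r}")
        path, label = parts
        samples.append((path, int(label)))  # int() raises ValueError on a bad label
    return samples
```

Typical use would be `read_list_file(open(TRAIN_LIST).read())`, with the resulting pairs fed to a PyTorch `Dataset` that opens each image and applies the transforms.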
Modify the parameters in run.sh, then run run.sh directly.
Main parameters to modify:
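OUTPUT_PATH path where the model and log files are saved
TRAIN_LIST list file for the training set
VAL_LIST list file for the test set
model_name defaults to resnet50
lr learning rate
epochs total number of training epochs
batch-size batch size
j number of DataLoader workers (num_workers)
num_classes number of classes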
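The parameters above map naturally onto command-line flags. This is a minimal argparse sketch, not the repository's training script: the flag spellings for the three path variables are HYPOTHETICAL, and every default except `model_name` (documented as resnet50) is an illustrative placeholder.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # HYPOTHETICAL: flag names mirror the run.sh parameters listed above;
    # the real training script may spell them differently.
    p = argparse.ArgumentParser(description="image classification training (sketch)")
    p.add_argument("--output_path", required=True, help="where the model and logs are saved")
    p.add_argument("--train_list", required=True, help="list file for the training set")
    p.add_argument("--val_list", required=True, help="list file for the test set")
    p.add_argument("--model_name", default="resnet50", help="torchvision model name")
    # Placeholder defaults below are illustrative only.
    p.add_argument("--lr", type=float, default=0.01, help="learning rate")
    p.add_argument("--epochs", type=int, default=30, help="total number of training epochs")
    p.add_argument("--batch-size", type=int, default=32, help="batch size")
    p.add_argument("-j", "--workers", type=int, default=4, help="DataLoader num_workers")
    p.add_argument("--num_classes", type=int, required=True, help="number of classes")
    return p
```

Note that argparse turns `--batch-size` into the attribute `args.batch_size`, and `-j` is stored as `args.workers` because of the long alias.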
The C++ inference code is stored in the cpp_inference folder.
Use cpp_inference/traced_model/trace_model.py to export the trained model.
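The contents of trace_model.py are not shown here. As a rough sketch of what TorchScript export with `torch.jit.trace` typically looks like, the hypothetical helper below traces a model with a dummy input and saves it. It ASSUMES the C++ side feeds a 1x3xHxW float tensor and that 224 is the input size; loading the trained checkpoint is left to the caller because its format is not documented here.

```python
import torch


def export_traced(model: torch.nn.Module, out_path: str, input_size: int = 224) -> None:
    """Hypothetical helper: trace a model and save it as TorchScript.

    ASSUMPTION: inference uses a single 3-channel image of size input_size x input_size.
    """
    model.eval()  # disable dropout and use running BatchNorm statistics
    example = torch.randn(1, 3, input_size, input_size)
    with torch.no_grad():
        # Tracing records the operations executed on the example input.
        traced = torch.jit.trace(model, example)
    traced.save(out_path)  # the saved file can be loaded in libtorch with torch::jit::load
```

Because tracing only records the path taken for the example input, a model with data-dependent control flow would need `torch.jit.script` instead.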
Build the required OpenCV and libtorch libraries and place them in cpp_inference/third_party_library.
Compilation
sh compile.sh
./bin/imgCls imgpath