pytorch_classification
1.0.0
実装されたトーチに基づくPytorch画像分類関数。
更新2022.11.05
2022.10.29に更新されたコードリファクタリング、基本的な機能は基本的に同じです。
前のバージョンに慣れている場合は、V1バージョンのコード:V1バージョンをご覧ください。
主な機能:
Pytorchを使用して、トーチに基づいて画像分類を実現すると、DensityNet、ResNext、MobileNet、EfficientNet、Swin Transformerなどの画像分類ネットワークを拡張できます。
便利なら、スターにようこそ
データセットの組織形式は、sample_files/imgs/listfile.txtを参照してください
run.shのパラメーターを変更し、run.shを直接実行します。
主な修正パラメーター:
OUTPUT_PATH 模型保存和log文件的路径
TRAIN_LIST 训练数据集的list文件
VAL_LIST 测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size batch的大小
j dataloader的num_workers的大小
num_classes 类别数
コードはcpp_inferenceフォルダーに保存されます。
cpp_inference/traced_model/trace_model.pyを使用して、訓練されたモデルをエクスポートします。
必要なopencvおよびlibtorchコードをcpp_inference/third_party_libraryにコンパイルします
編集
sh compile.sh
./bin/imgCls imgpath