La fonction de classification de l'image Pytorch basée sur la torchision implémentée.
Mise à jour 2022.11.05
Mis à jour en 2022.10.29, refactoring de code, les fonctions de base sont fondamentalement les mêmes.
Si vous êtes habitué à la version précédente, veuillez consulter le code de la version V1: V1.
Fonctions principales:
En utilisant Pytorch pour réaliser la classification des images, en fonction de la torchision, il peut étendre les réseaux de classification d'images tels que DensityNet, Resnext, MobileNet, EfficientNet, Swin Transformer, etc.
Si utile, bienvenue à Star
La forme d'organisation de l'ensemble de données, reportez-vous à Sample_Files / IMGS / ListFile.txt
Modifiez les paramètres dans run.sh et run run.sh directement.
Paramètres principaux modifiés:
OUTPUT_PATH 模型保存和log文件的路径
TRAIN_LIST 训练数据集的list文件
VAL_LIST 测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size batch的大小
j dataloader的num_workers的大小
num_classes 类别数
Le code est stocké dans le dossier cpp_inference .
Utilisez cpp_inference / traced_model / trace_model.py pour exporter le modèle formé.
Compilez le code OpenCV et libtorch requis à cpp_inference/third_party_library
Compilation
sh compile.sh
./bin/imgCls imgpath