Cet outil est conçu pour calculer la quantité théorique d'opérations multipliées dans les réseaux de neurones. Il peut également calculer le nombre de paramètres et imprimer le coût de calcul par couche d'un réseau donné.
ptflops a deux backends, pytorch et aten . Le backend pytorch est un héritage, il ne considère que nn.Modules uniquement. Cependant, il est toujours utile, car il fournit une meilleure analyse par couche pour CNNS. Dans tous les autres cas, il est recommandé d'utiliser un backend aten , qui considère les opérations Aten, et donc il couvre plus d'architectures de modèle (y compris les transformateurs). Le backend par défaut est aten . S'il vous plaît, n'utilisez pas le backend pytorch pour les architectures de transformateur.
atenverbose=True pour voir les opérations qui n'ont pas été prises en compte pendant le calcul de complexité.nn.Module . Les modules plus profonds au deuxième niveau de nidification ne sont pas indiqués dans les statistiques par couche.ignore_modules oblige ptflops à ignorer les modules répertoriés. Cela peut être utile à des fins de recherche. Par exemple, on peut supprimer toutes les convolutions du processus de comptage spécifiant ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution] . pytorchSupport expérimental:
torch.nn.functional.* Et tensor.* Opérations. Par conséquent, les opérations non soutenues ne contribuent pas à l'estimation finale de la complexité. Voir ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING pour vérifier les opérations prises en charge. Parfois, les crochets au niveau fonctionnel entrent en conflit avec les crochets pour nn.Module (par exemple, les personnalisés). Dans ce cas, le comptage avec ces OP peut être désactivé en passant backend_specific_config={"count_functional" : False} .ptflops lance un modèle donné sur un tenseur aléatoire et estime la quantité de calculs pendant l'inférence. Les modèles compliqués peuvent avoir plusieurs entrées, certains pourraient être facultatifs. Pour construire une entrée non triviale, on peut utiliser l'argument input_constructor du get_model_complexity_info . input_constructor est une fonction qui prend la résolution spatiale d'entrée en tant que tuple et renvoie un dict avec des arguments d'entrée nommés du modèle. Ensuite, ce dict serait transmis au modèle comme un mot-clé arguments.verbose permet d'obtenir des informations sur les modules qui ne contribuent pas aux nombres finaux.ignore_modules oblige ptflops à ignorer les modules répertoriés. Cela peut être utile à des fins de recherche. Par exemple, on peut supprimer toutes les convolutions du processus de comptage spécifiant ignore_modules=[torch.nn.Conv2d] . Pytorch> = 2.0. Utilisez pip install ptflops==0.7.2.2 pour travailler avec Torch 1.x.
De PYPI:
pip install ptflopsDe ce référentiel:
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git import torchvision . models as models
import torch
from ptflops import get_model_complexity_info
with torch . cuda . device ( 0 ):
net = models . densenet161 ()
macs , params = get_model_complexity_info ( net , ( 3 , 224 , 224 ), as_strings = True , backend = 'pytorch'
print_per_layer_stat = True , verbose = True )
print ( '{:<30} {:<8}' . format ( 'Computational complexity: ' , macs ))
print ( '{:<30} {:<8}' . format ( 'Number of parameters: ' , params ))
macs , params = get_model_complexity_info ( net , ( 3 , 224 , 224 ), as_strings = True , backend = 'aten'
print_per_layer_stat = True , verbose = True )
print ( '{:<30} {:<8}' . format ( 'Computational complexity: ' , macs ))
print ( '{:<30} {:<8}' . format ( 'Number of parameters: ' , params ))Si PTFlops était utile pour votre rapport papier ou technique, veuillez me citer:
@online{ptflops,
author = {Vladislav Sovrasov},
title = {ptflops: a flops counting tool for neural networks in pytorch framework},
year = 2018-2024,
url = {https://github.com/sovrasov/flops-counter.pytorch},
}
Merci à @WarmSpringWinds et Horace Il pour la version initiale du script.
| Modèle | Résolution d'entrée | Paramètres (m) | MACS (G) ( pytorch ) | Macs (g) ( aten ) |
|---|---|---|---|---|
| Alexnet | 224x224 | 61.10 | 0,72 | 0,71 |
| convnext_base | 224x224 | 88,59 | 15.43 | 15.38 |
| densenet121 | 224x224 | 7.98 | 2.90 | |
| EfficientNet_B0 | 224x224 | 5.29 | 0,41 | |
| EfficientNet_V2_M | 224x224 | 54.14 | 5.43 | |
| googlenet | 224x224 | 13h00 | 1.51 | |
| Inception_v3 | 224x224 | 27.16 | 5.75 | 5.71 |
| maxvit_t | 224x224 | 30,92 | 5.48 | |
| mnasnet1_0 | 224x224 | 4.38 | 0,33 | |
| mobilenet_v2 | 224x224 | 3,50 | 0,32 | |
| mobilenet_v3_large | 224x224 | 5.48 | 0,23 | |
| regnet_y_1_6gf | 224x224 | 11.20 | 1.65 | |
| resnet18 | 224x224 | 11.69 | 1.83 | 1.81 |
| resnet50 | 224x224 | 25.56 | 4.13 | 4.09 |
| resnext50_32x4d | 224x224 | 25.03 | 4.29 | |
| shufflenet_v2_x1_0 | 224x224 | 2.28 | 0,15 | |
| Squeezenet1_0 | 224x224 | 1.25 | 0,84 | 0,82 |
| VGG16 | 224x224 | 138.36 | 15.52 | 15.48 |
| vit_b_16 | 224x224 | 86.57 | 17.61 (mal) | 16.86 |
| wide_resnet50_2 | 224x224 | 68.88 | 11.45 |
Modèle | Résolution d'entrée | Params (M) | MACS (G)