Dieses Tool wurde entwickelt, um die theoretische Menge an Multiply-ADD-Operationen in neuronalen Netzwerken zu berechnen. Es kann auch die Anzahl der Parameter berechnen und die Rechenkosten pro Layer-Rechenkosten eines bestimmten Netzwerks drucken.
ptflops hat zwei Backends, pytorch und aten . pytorch Backend ist ein Vermächtnis, das nur nn.Modules betrachtet. Es ist jedoch immer noch nützlich, da es eine bessere Par-Schicht-Analyse für CNNs bietet. In allen anderen Fällen wird empfohlen, aten -Backend zu verwenden, das AM -Operationen berücksichtigt, und deckt daher mehr Modellarchitekturen (einschließlich Transformatoren) ab. Das Standard -Backend wird aten . Bitte verwenden Sie kein pytorch -Backend für Transformer -Architekturen.
aten Backendverbose=True um die Operationen zu sehen, die während der Komplexitätsberechnung nicht berücksichtigt wurden.nn.Module verschachtelt sind. Tiefere Module auf der zweiten Nistebene sind in der Statistik pro Schicht nicht gezeigt.ignore_modules Option erzwingt ptflops , die aufgeführten Module zu ignorieren. Dies kann für Forschungszwecke nützlich sein. Beispielsweise kann man alle Konvolutionen aus dem Zählprozess abgeben, in ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution] angegeben wird. pytorch -BackendExperimentelle Unterstützung:
torch.nn.functional.* Und tensor.* Operationen. Daher tragen nicht unterstützte Operationen nicht zur endgültigen Komplexitätsschätzung bei. Siehe ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING um unterstützte OPs zu überprüfen. Manchmal Konflikten mit Funktionsebene mit Hooks für nn.Module (z. B. benutzerdefinierte). In diesem Fall kann das Zählen mit diesen OPs deaktiviert werden, indem er backend_specific_config={"count_functional" : False} übergeben wird.ptflops startet ein bestimmtes Modell auf einen zufälligen Tensor und schätzt die Menge an Berechnungen während der Inferenz. Komplizierte Modelle können mehrere Eingänge haben, einige von ihnen könnten optional sein. Um nicht-triviale Eingaben zu konstruieren, kann man das Argument input_constructor des get_model_complexity_info verwenden. input_constructor ist eine Funktion, die die räumliche Auflösung der Eingabe als Tupel nimmt und ein DICT mit benannten Eingabemargumenten des Modells zurückgibt. Als nächstes würde dieses Diktat als Keyword -Argumente an das Modell übergeben.verbose Parameter ermöglicht es, Informationen über Module zu erhalten, die nicht zu den endgültigen Zahlen beitragen.ignore_modules Option erzwingt ptflops , die aufgeführten Module zu ignorieren. Dies kann für Forschungszwecke nützlich sein. Beispielsweise kann man alle Konvolutionen aus dem Zählprozess abgeben, in dem ignore_modules=[torch.nn.Conv2d] angegeben wird. Pytorch> = 2.0. Verwenden Sie pip install ptflops==0.7.2.2 um mit Torch 1.x zu arbeiten.
Von pypi:
pip install ptflopsAus diesem Repository:
pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git import torchvision . models as models
import torch
from ptflops import get_model_complexity_info
with torch . cuda . device ( 0 ):
net = models . densenet161 ()
macs , params = get_model_complexity_info ( net , ( 3 , 224 , 224 ), as_strings = True , backend = 'pytorch'
print_per_layer_stat = True , verbose = True )
print ( '{:<30} {:<8}' . format ( 'Computational complexity: ' , macs ))
print ( '{:<30} {:<8}' . format ( 'Number of parameters: ' , params ))
macs , params = get_model_complexity_info ( net , ( 3 , 224 , 224 ), as_strings = True , backend = 'aten'
print_per_layer_stat = True , verbose = True )
print ( '{:<30} {:<8}' . format ( 'Computational complexity: ' , macs ))
print ( '{:<30} {:<8}' . format ( 'Number of parameters: ' , params ))Wenn PTFlops für Ihren Papier- oder Tech -Bericht nützlich war, zitieren Sie mich bitte:
@online{ptflops,
author = {Vladislav Sovrasov},
title = {ptflops: a flops counting tool for neural networks in pytorch framework},
year = 2018-2024,
url = {https://github.com/sovrasov/flops-counter.pytorch},
}
Vielen Dank an @WarmspringWinds und Horace HE für die erste Version des Skripts.
| Modell | Eingabeauflösung | Parameter (m) | Macs (g) ( pytorch ) | Macs (g) ( aten ) |
|---|---|---|---|---|
| Alexnet | 224x224 | 61.10 | 0,72 | 0,71 |
| überrevnext_base | 224x224 | 88.59 | 15.43 | 15.38 |
| Densenet121 | 224x224 | 7.98 | 2.90 | |
| effizientnetnetz_b0 | 224x224 | 5.29 | 0,41 | |
| effizientnet_v2_m | 224x224 | 54.14 | 5.43 | |
| Googlenet | 224x224 | 13.00 | 1,51 | |
| Inception_v3 | 224x224 | 27.16 | 5.75 | 5.71 |
| maxvit_t | 224x224 | 30.92 | 5.48 | |
| mnasnet1_0 | 224x224 | 4.38 | 0,33 | |
| Mobilenet_v2 | 224x224 | 3.50 | 0,32 | |
| mobilenet_v3_large | 224x224 | 5.48 | 0,23 | |
| Regnet_y_1_6gf | 224x224 | 11.20 | 1.65 | |
| resnet18 | 224x224 | 11.69 | 1.83 | 1.81 |
| resnet50 | 224x224 | 25.56 | 4.13 | 4.09 |
| Resnext50_32x4d | 224x224 | 25.03 | 4.29 | |
| SHUFFLENET_V2_X1_0 | 224x224 | 2.28 | 0,15 | |
| Squeezenet1_0 | 224x224 | 1.25 | 0,84 | 0,82 |
| VGG16 | 224x224 | 138.36 | 15.52 | 15.48 |
| Vit_b_16 | 224x224 | 86,57 | 17.61 (falsch) | 16.86 |
| Wide_resnet50_2 | 224x224 | 68,88 | 11.45 |
Modell | Eingabeauflösung | Parameter (m) | Macs (g)