Il s'agit d'un analyseur de réseau neuronal léger basé sur Pytorch. Il est conçu pour rendre la construction de vos réseaux rapidement et facile, avec la possibilité de les déboguer. Remarque : ce référentiel est actuellement en cours de développement. Par conséquent, certaines API pourraient être modifiées.
Ces outils peuvent montrer
Il y a deux façons d'installer TorchStat dans votre environnement.
$ pip install torchstat$ python3 setup.py installSi vous souhaitez exécuter le TorchStat dès que possible, vous pouvez l'appeler comme un outil CLI si votre réseau existe dans un script. Sinon, vous devez importer TorchStat en tant que module.
$ torchstat masato$ torchstat -f example.py -m Net
[MAdd]: Dropout2d is not supported !
[Flops]: Dropout2d is not supported !
[Memory]: Dropout2d is not supported !
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 conv1 3 224 224 10 220 220 760.0 1.85 72,600,000.0 36,784,000.0 605152.0 1936000.0 57.49% 2541152.0
1 conv2 10 110 110 20 106 106 5020.0 0.86 112,360,000.0 56,404,720.0 504080.0 898880.0 26.62% 1402960.0
2 conv2_drop 20 106 106 20 106 106 0.0 0.86 0.0 0.0 0.0 0.0 4.09% 0.0
3 fc1 56180 50 2809050.0 0.00 5,617,950.0 2,809,000.0 11460920.0 200.0 11.58% 11461120.0
4 fc2 50 10 510.0 0.00 990.0 500.0 2240.0 40.0 0.22% 2280.0
total 2815340.0 3.56 190,578,940.0 95,998,220.0 2240.0 40.0 100.00% 15407512.0
===============================================================================================================================================
Total params: 2,815,340
-----------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 3.56MB
Total MAdd: 190.58MMAdd
Total Flops: 96.0MFlops
Total MemR+W: 14.69MBSi vous ne savez pas comment utiliser une commande spécifique, exécutez la commande avec les commutateurs -h ou –Help. Vous verrez des informations d'utilisation et une liste d'options que vous pouvez utiliser avec la commande.
from torchstat import stat
import torchvision . models as models
model = models . resnet18 ()
stat ( model , ( 3 , 224 , 224 ))Remarque : ces fonctionnalités ne fonctionnent que nn.module. Les modules dans TORCH.nn.functional ne sont pas encore pris en charge.
Pour les calques pris en charge, consultez les détails.
Merci à @Sovrasov pour la version initiale de Flops Computation, @CeyKMC pour l'épine dorsale des scripts.