
Métriques d'apprentissage automatique pour les applications Pytorch distribuées et évolutives.
Qu'est-ce que TorchMetrics • Mise en œuvre d'une métrique • Métriques intégrées • Docs • Communauté • Licence

Installation simple de PYPI
pip install torchmetricsInstaller en utilisant conda
conda install -c conda-forge torchmetricsPip de Source
# with git
pip install git+https://github.com/Lightning-AI/torchmetrics.git@release/stablePip des archives
pip install https://github.com/Lightning-AI/torchmetrics/archive/refs/heads/release/stable.zipDépendances supplémentaires pour les mesures spécialisées:
pip install torchmetrics[audio]
pip install torchmetrics[image]
pip install torchmetrics[text]
pip install torchmetrics[all] # install all of the aboveInstallez la dernière version du développeur
pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zipTorchMetrics est une collection de plus de 100 implémentations de métriques Pytorch et une API facile à utiliser pour créer des mesures personnalisées. Il propose:
Vous pouvez utiliser TorchMetrics avec n'importe quel modèle Pytorch ou avec Pytorch Lightning pour profiter de fonctionnalités supplémentaires telles que:
Les métriques basées sur le module contiennent des états métriques internes (similaires aux paramètres du module Pytorch) qui automatisent l'accumulation et la synchronisation entre les appareils!
Cela peut être exécuté sur CPU, GPU unique ou multi-GPU!
Pour le cas GPU / CPU unique:
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# move the metric to device you want computations to take place
device = "cuda" if torch . cuda . is_available () else "cpu"
metric . to ( device )
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 ). to ( device )
target = torch . randint ( 5 , ( 10 ,)). to ( device )
# metric on current batch
acc = metric ( preds , target )
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches using custom accumulation
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } " )L'utilisation de la métrique du module reste la même lors de l'utilisation de plusieurs GPU ou de plusieurs nœuds.
import os
import torch
import torch . distributed as dist
import torch . multiprocessing as mp
from torch import nn
from torch . nn . parallel import DistributedDataParallel as DDP
import torchmetrics
def metric_ddp ( rank , world_size ):
os . environ [ "MASTER_ADDR" ] = "localhost"
os . environ [ "MASTER_PORT" ] = "12355"
# create default process group
dist . init_process_group ( "gloo" , rank = rank , world_size = world_size )
# initialize model
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# define a model and append your metric to it
# this allows metric states to be placed on correct accelerators when
# .to(device) is called on the model
model = nn . Linear ( 10 , 10 )
model . metric = metric
model = model . to ( rank )
# initialize DDP
model = DDP ( model , device_ids = [ rank ])
n_epochs = 5
# this shows iteration over multiple training epochs
for n in range ( n_epochs ):
# this will be replaced by a DataLoader with a DistributedSampler
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
# metric on current batch
acc = metric ( preds , target )
if rank == 0 : # print only for rank 0
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches and all accelerators using custom accumulation
# accuracy is same across both accelerators
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } , accelerator rank: { rank } " )
# Resetting internal state such that metric ready for new data
metric . reset ()
# cleanup
dist . destroy_process_group ()
if __name__ == "__main__" :
world_size = 2 # number of gpus to parallelize over
mp . spawn ( metric_ddp , args = ( world_size ,), nprocs = world_size , join = True ) La mise en œuvre de votre propre métrique est aussi simple que la sous-classe d'une torch.nn.Module . Simplement, sous-classe torchmetrics.Metric et implémentez simplement les méthodes update et compute :
import torch
from torchmetrics import Metric
class MyAccuracy ( Metric ):
def __init__ ( self ):
# remember to call super
super (). __init__ ()
# call `self.add_state`for every internal state that is needed for the metrics computations
# dist_reduce_fx indicates the function that should be used to reduce
# state from multiple processes
self . add_state ( "correct" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
self . add_state ( "total" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
def update ( self , preds : torch . Tensor , target : torch . Tensor ) -> None :
# extract predicted class index for computing accuracy
preds = preds . argmax ( dim = - 1 )
assert preds . shape == target . shape
# update metric states
self . correct += torch . sum ( preds == target )
self . total += target . numel ()
def compute ( self ) -> torch . Tensor :
# compute final result
return self . correct . float () / self . total
my_metric = MyAccuracy ()
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
print ( my_metric ( preds , target )) Semblable à torch.nn , la plupart des métriques ont à la fois une version basée sur des modules et fonctionnelle. Les versions fonctionnelles sont de simples fonctions Python qui, en tant qu'entrée, prennent Torch.tenseurs et renvoient la métrique correspondante en tant que Torch.tenseur.
import torch
# import our library
import torchmetrics
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
acc = torchmetrics . functional . classification . multiclass_accuracy (
preds , target , num_classes = 5
)Dans le total, les torchmetrics contient plus de plus de mesures, ce qui couvre les domaines suivants:
Chaque domaine peut nécessiter des dépendances supplémentaires qui peuvent être installées avec pip install torchmetrics[audio] , pip install torchmetrics['image'] etc.
La visualisation des mesures peut être importante pour aider à comprendre ce qui se passe avec vos algorithmes d'apprentissage automatique. TorchMetrics a une prise en charge du traçage intégrée (installez les dépendances avec pip install torchmetrics[visual] ) pour presque toutes les mesures modulaires via la méthode .plot . Appelez simplement la méthode pour obtenir une visualisation simple de toute métrique!
import torch
from torchmetrics . classification import MulticlassAccuracy , MulticlassConfusionMatrix
num_classes = 3
# this will generate two distributions that comes more similar as iterations increase
w = torch . randn ( num_classes )
target = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
preds = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
acc = MulticlassAccuracy ( num_classes = num_classes , average = "micro" )
acc_per_class = MulticlassAccuracy ( num_classes = num_classes , average = None )
confmat = MulticlassConfusionMatrix ( num_classes = num_classes )
# plot single value
for i in range ( 5 ):
acc_per_class . update ( preds ( i ), target ( i ))
confmat . update ( preds ( i ), target ( i ))
fig1 , ax1 = acc_per_class . plot ()
fig2 , ax2 = confmat . plot ()
# plot multiple values
values = []
for i in range ( 10 ):
values . append ( acc ( preds ( i ), target ( i )))
fig3 , ax3 = acc . plot ( values )
Pour des exemples de traçage de différentes métriques, essayez d'exécuter cet exemple de fichier.
L'équipe Lightning + TorchMetrics travaille dur pour ajouter encore plus de mesures. Mais nous recherchons des contributeurs incroyables comme vous pour soumettre de nouvelles mesures et améliorer celles existantes!
Rejoignez notre discorde pour obtenir de l'aide pour devenir contributeur!
Pour obtenir de l'aide ou des questions, rejoignez notre énorme communauté sur Discord!
Nous sommes ravis de poursuivre l'héritage solide des logiciels open source et nous avons été inspirés au fil des ans par Caffe, Theano, Keras, Pytorch, Torchbereer, Ignite, Sklearn et Fast.ai.
Si vous souhaitez citer ce framework, n'hésitez pas à utiliser l'option de citation intégrée de GitHub pour générer une citation Bibtex ou de style APA basé sur ce fichier (mais seulement si vous l'avez adoré?).
Veuillez observer la licence Apache 2.0 répertoriée dans ce référentiel. De plus, le framework Lightning est en attente de brevet.