
Metrik Pembelajaran Mesin untuk aplikasi Pytorch yang didistribusikan dan dapat diskalakan.
Apa itu Torchmetrics • Menerapkan metrik • Metrik bawaan • Dokumen • Komunitas • Lisensi

Instalasi sederhana dari PYPI
pip install torchmetricsInstal menggunakan conda
conda install -c conda-forge torchmetricsPip dari Sumber
# with git
pip install git+https://github.com/Lightning-AI/torchmetrics.git@release/stablePip dari arsip
pip install https://github.com/Lightning-AI/torchmetrics/archive/refs/heads/release/stable.zipKetergantungan tambahan untuk metrik khusus:
pip install torchmetrics[audio]
pip install torchmetrics[image]
pip install torchmetrics[text]
pip install torchmetrics[all] # install all of the aboveInstal versi pengembang terbaru
pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zipTorchmetrics adalah kumpulan 100+ implementasi metrik Pytorch dan API yang mudah digunakan untuk membuat metrik khusus. Itu menawarkan:
Anda dapat menggunakan Torchmetrics dengan model Pytorch apa pun atau dengan Pytorch Lightning untuk menikmati fitur tambahan seperti:
Metrik berbasis modul berisi status metrik internal (mirip dengan parameter modul pytorch) yang mengotomatiskan akumulasi dan sinkronisasi di seluruh perangkat!
Ini dapat dijalankan pada CPU, GPU tunggal atau multi-GPU!
Untuk kasus GPU/CPU tunggal:
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# move the metric to device you want computations to take place
device = "cuda" if torch . cuda . is_available () else "cpu"
metric . to ( device )
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 ). to ( device )
target = torch . randint ( 5 , ( 10 ,)). to ( device )
# metric on current batch
acc = metric ( preds , target )
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches using custom accumulation
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } " )Penggunaan metrik modul tetap sama saat menggunakan beberapa GPU atau beberapa node.
import os
import torch
import torch . distributed as dist
import torch . multiprocessing as mp
from torch import nn
from torch . nn . parallel import DistributedDataParallel as DDP
import torchmetrics
def metric_ddp ( rank , world_size ):
os . environ [ "MASTER_ADDR" ] = "localhost"
os . environ [ "MASTER_PORT" ] = "12355"
# create default process group
dist . init_process_group ( "gloo" , rank = rank , world_size = world_size )
# initialize model
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# define a model and append your metric to it
# this allows metric states to be placed on correct accelerators when
# .to(device) is called on the model
model = nn . Linear ( 10 , 10 )
model . metric = metric
model = model . to ( rank )
# initialize DDP
model = DDP ( model , device_ids = [ rank ])
n_epochs = 5
# this shows iteration over multiple training epochs
for n in range ( n_epochs ):
# this will be replaced by a DataLoader with a DistributedSampler
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
# metric on current batch
acc = metric ( preds , target )
if rank == 0 : # print only for rank 0
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches and all accelerators using custom accumulation
# accuracy is same across both accelerators
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } , accelerator rank: { rank } " )
# Resetting internal state such that metric ready for new data
metric . reset ()
# cleanup
dist . destroy_process_group ()
if __name__ == "__main__" :
world_size = 2 # number of gpus to parallelize over
mp . spawn ( metric_ddp , args = ( world_size ,), nprocs = world_size , join = True ) Menerapkan metrik Anda sendiri semudah subklassing torch.nn.Module . Sederhananya, torchmetrics.Metric subclass.metrik dan cukup terapkan metode update dan compute :
import torch
from torchmetrics import Metric
class MyAccuracy ( Metric ):
def __init__ ( self ):
# remember to call super
super (). __init__ ()
# call `self.add_state`for every internal state that is needed for the metrics computations
# dist_reduce_fx indicates the function that should be used to reduce
# state from multiple processes
self . add_state ( "correct" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
self . add_state ( "total" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
def update ( self , preds : torch . Tensor , target : torch . Tensor ) -> None :
# extract predicted class index for computing accuracy
preds = preds . argmax ( dim = - 1 )
assert preds . shape == target . shape
# update metric states
self . correct += torch . sum ( preds == target )
self . total += target . numel ()
def compute ( self ) -> torch . Tensor :
# compute final result
return self . correct . float () / self . total
my_metric = MyAccuracy ()
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
print ( my_metric ( preds , target )) Mirip dengan torch.nn , sebagian besar metrik memiliki versi berbasis modul dan fungsional. Versi fungsional adalah fungsi Python sederhana yang sebagai input mengambil torch.tensor dan mengembalikan metrik yang sesuai sebagai obor.tensor.
import torch
# import our library
import torchmetrics
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
acc = torchmetrics . functional . classification . multiclass_accuracy (
preds , target , num_classes = 5
)Secara total obor berisi 100+ metrik, yang mencakup domain berikut:
Setiap domain mungkin memerlukan beberapa dependensi tambahan yang dapat diinstal dengan pip install torchmetrics[audio] , pip install torchmetrics['image'] dll.
Visualisasi metrik dapat menjadi penting untuk membantu memahami apa yang terjadi dengan algoritma pembelajaran mesin Anda. Torchmetrics memiliki dukungan plot built-in (instal dependensi dengan pip install torchmetrics[visual] ) untuk hampir semua metrik modular melalui metode .plot . Cukup panggil metode untuk mendapatkan visualisasi sederhana dari metrik apa pun!
import torch
from torchmetrics . classification import MulticlassAccuracy , MulticlassConfusionMatrix
num_classes = 3
# this will generate two distributions that comes more similar as iterations increase
w = torch . randn ( num_classes )
target = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
preds = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
acc = MulticlassAccuracy ( num_classes = num_classes , average = "micro" )
acc_per_class = MulticlassAccuracy ( num_classes = num_classes , average = None )
confmat = MulticlassConfusionMatrix ( num_classes = num_classes )
# plot single value
for i in range ( 5 ):
acc_per_class . update ( preds ( i ), target ( i ))
confmat . update ( preds ( i ), target ( i ))
fig1 , ax1 = acc_per_class . plot ()
fig2 , ax2 = confmat . plot ()
# plot multiple values
values = []
for i in range ( 10 ):
values . append ( acc ( preds ( i ), target ( i )))
fig3 , ax3 = acc . plot ( values )
Untuk contoh merencanakan berbagai metrik, coba jalankan file contoh ini.
Tim Lightning + Torchmetrics bekerja keras menambahkan lebih banyak metrik. Tapi kami mencari kontributor luar biasa seperti Anda untuk mengirimkan metrik baru dan meningkatkan yang sudah ada!
Bergabunglah dengan perselisihan kami untuk mendapatkan bantuan untuk menjadi kontributor!
Untuk bantuan atau pertanyaan, bergabunglah dengan komunitas besar kami di Discord!
Kami senang untuk melanjutkan warisan kuat perangkat lunak open source dan telah terinspirasi selama bertahun -tahun oleh Caffe, Theano, Kera, Pytorch, Torchbearer, Ignite, Sklearn dan Fast.ai.
Jika Anda ingin mengutip kerangka kerja ini, jangan ragu untuk menggunakan opsi kutipan bawaan GitHub untuk menghasilkan kutipan bergaya Bibtex atau APA berdasarkan file ini (tetapi hanya jika Anda menyukainya?).
Harap amati lisensi Apache 2.0 yang tercantum dalam repositori ini. Selain itu, kerangka kerja petir adalah paten yang tertunda.