
分散型のスケーラブルなPytorchアプリケーションの機械学習メトリック。
トーチメトリックとは何ですか•メトリックの実装•内蔵メトリック•ドキュメント•コミュニティ•ライセンス

Pypiからの簡単なインストール
pip install torchmetricsCondaを使用してインストールします
conda install -c conda-forge torchmetricsソースからのピップ
# with git
pip install git+https://github.com/Lightning-AI/torchmetrics.git@release/stableアーカイブからのピップ
pip install https://github.com/Lightning-AI/torchmetrics/archive/refs/heads/release/stable.zip特殊なメトリックの追加依存関係:
pip install torchmetrics[audio]
pip install torchmetrics[image]
pip install torchmetrics[text]
pip install torchmetrics[all] # install all of the above最新の開発者バージョンをインストールします
pip install https://github.com/Lightning-AI/torchmetrics/archive/master.zipTorchmetricsは、100以上のPytorchメトリックの実装と使いやすいAPIのコレクションであり、カスタムメトリックを作成します。それは提供します:
PytorchモデルまたはPytorch LightningでTorchmetricsを使用して、次のような追加機能を楽しむことができます。
モジュールベースのメトリックには、デバイス全体で蓄積と同期を自動化する内部メトリック状態(Pytorchモジュールのパラメーターと同様)が含まれています!
これは、CPU、シングルGPU、またはマルチGPUで実行できます!
単一のGPU/CPUケースの場合:
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# move the metric to device you want computations to take place
device = "cuda" if torch . cuda . is_available () else "cpu"
metric . to ( device )
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 ). to ( device )
target = torch . randint ( 5 , ( 10 ,)). to ( device )
# metric on current batch
acc = metric ( preds , target )
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches using custom accumulation
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } " )複数のGPUまたは複数のノードを使用する場合、モジュールメトリック使用量は同じままです。
import os
import torch
import torch . distributed as dist
import torch . multiprocessing as mp
from torch import nn
from torch . nn . parallel import DistributedDataParallel as DDP
import torchmetrics
def metric_ddp ( rank , world_size ):
os . environ [ "MASTER_ADDR" ] = "localhost"
os . environ [ "MASTER_PORT" ] = "12355"
# create default process group
dist . init_process_group ( "gloo" , rank = rank , world_size = world_size )
# initialize model
metric = torchmetrics . classification . Accuracy ( task = "multiclass" , num_classes = 5 )
# define a model and append your metric to it
# this allows metric states to be placed on correct accelerators when
# .to(device) is called on the model
model = nn . Linear ( 10 , 10 )
model . metric = metric
model = model . to ( rank )
# initialize DDP
model = DDP ( model , device_ids = [ rank ])
n_epochs = 5
# this shows iteration over multiple training epochs
for n in range ( n_epochs ):
# this will be replaced by a DataLoader with a DistributedSampler
n_batches = 10
for i in range ( n_batches ):
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
# metric on current batch
acc = metric ( preds , target )
if rank == 0 : # print only for rank 0
print ( f"Accuracy on batch { i } : { acc } " )
# metric on all batches and all accelerators using custom accumulation
# accuracy is same across both accelerators
acc = metric . compute ()
print ( f"Accuracy on all data: { acc } , accelerator rank: { rank } " )
# Resetting internal state such that metric ready for new data
metric . reset ()
# cleanup
dist . destroy_process_group ()
if __name__ == "__main__" :
world_size = 2 # number of gpus to parallelize over
mp . spawn ( metric_ddp , args = ( world_size ,), nprocs = world_size , join = True )独自のメトリックを実装することは、 torch.nn.Moduleをサブクラス化するのと同じくらい簡単です。簡単には、サブクラスtorchmetrics.Metricで、 updateおよびcompute方法を実装するだけです。
import torch
from torchmetrics import Metric
class MyAccuracy ( Metric ):
def __init__ ( self ):
# remember to call super
super (). __init__ ()
# call `self.add_state`for every internal state that is needed for the metrics computations
# dist_reduce_fx indicates the function that should be used to reduce
# state from multiple processes
self . add_state ( "correct" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
self . add_state ( "total" , default = torch . tensor ( 0 ), dist_reduce_fx = "sum" )
def update ( self , preds : torch . Tensor , target : torch . Tensor ) -> None :
# extract predicted class index for computing accuracy
preds = preds . argmax ( dim = - 1 )
assert preds . shape == target . shape
# update metric states
self . correct += torch . sum ( preds == target )
self . total += target . numel ()
def compute ( self ) -> torch . Tensor :
# compute final result
return self . correct . float () / self . total
my_metric = MyAccuracy ()
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
print ( my_metric ( preds , target ))torch.nnと同様に、ほとんどのメトリックにはモジュールベースと機能バージョンの両方があります。機能バージョンは、入力としてtorch.tensorを使用し、対応するメトリックをTorch.tensorとして返す単純なPython関数です。
import torch
# import our library
import torchmetrics
# simulate a classification problem
preds = torch . randn ( 10 , 5 ). softmax ( dim = - 1 )
target = torch . randint ( 5 , ( 10 ,))
acc = torchmetrics . functional . classification . multiclass_accuracy (
preds , target , num_classes = 5
)総トーチメトリクスには、次のドメインをカバーする100以上のメトリックが含まれています。
各ドメインにはpip install torchmetrics[audio] 、 pip install torchmetrics['image']などでインストールできる追加の依存関係が必要になる場合があります。
メトリックの視覚化は、機械学習アルゴリズムで何が起こっているかを理解するのに役立つことが重要です。 Torchmetricsには、 .plotメソッドを介してほぼすべてのモジュラーメトリックに対して、プロットサポートが組み込まれています( pip install torchmetrics[visual] )。方法を呼び出して、任意のメトリックの簡単な視覚化を取得してください!
import torch
from torchmetrics . classification import MulticlassAccuracy , MulticlassConfusionMatrix
num_classes = 3
# this will generate two distributions that comes more similar as iterations increase
w = torch . randn ( num_classes )
target = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
preds = lambda it : torch . multinomial (( it * w ). softmax ( dim = - 1 ), 100 , replacement = True )
acc = MulticlassAccuracy ( num_classes = num_classes , average = "micro" )
acc_per_class = MulticlassAccuracy ( num_classes = num_classes , average = None )
confmat = MulticlassConfusionMatrix ( num_classes = num_classes )
# plot single value
for i in range ( 5 ):
acc_per_class . update ( preds ( i ), target ( i ))
confmat . update ( preds ( i ), target ( i ))
fig1 , ax1 = acc_per_class . plot ()
fig2 , ax2 = confmat . plot ()
# plot multiple values
values = []
for i in range ( 10 ):
values . append ( acc ( preds ( i ), target ( i )))
fig3 , ax3 = acc . plot ( values )
さまざまなメトリックをプロットする例については、この例ファイルを実行してみてください。
Lightning + Torchmetricsチームは、さらに多くのメトリックを追加するのに苦労しています。しかし、私たちはあなたのような信じられないほどの貢献者を探して、新しいメトリックを提出し、既存のメトリックを改善しています!
私たちの不一致に参加して、貢献者になることを助けてください!
助けや質問については、Discordで巨大なコミュニティに参加してください!
オープンソースソフトウェアの強力な遺産を継続できることを楽しみにしており、Caffe、Theano、Keras、Pytorch、Torchbearer、Ignite、Sklearn、Fast.aiに長年にわたってインスピレーションを受けてきました。
このフレームワークを引用したい場合は、Githubの組み込みの引用オプションを使用して、このファイルに基づいてBibtexまたはAPAスタイルの引用を生成してください(ただし、気に入った場合のみですか?)。
このリポジトリにリストされているApache 2.0ライセンスを遵守してください。さらに、Lightningフレームワークは保留中の特許です。