Une collection d'extensions et de chargeurs de données pour l'apprentissage et la méta-apprentissage à quelques coups à Pytorch. Torchmeta contient des repères de méta-d'apprentissage populaires, entièrement compatibles avec torchvision et DataLoader de Pytorch.
Module de Pytorch, appelé MetaModule , qui simplifie la création de certains modèles de méta-apprentissage (par exemple, les méthodes de méta-d'apprentissage basées sur le gradient). Voir l'exemple MAML pour un exemple en utilisant MetaModule . Vous pouvez installer Torchmeta à l'aide de Python's Package Manager PIP, soit à partir de Source. Pour éviter tout conflit avec votre configuration Python existante, il est suggéré de fonctionner dans un environnement virtuel avec virtualenv . Pour installer virtualenv :
pip install --upgrade virtualenv
virtualenv venv
source venv/bin/activateC'est le moyen recommandé d'installer Torchmeta:
pip install torchmetaVous pouvez également installer Torchmeta à partir de la source. Ceci est recommandé si vous souhaitez contribuer à Torchmeta.
git clone https://github.com/tristandeleu/pytorch-meta.git
cd pytorch-meta
python setup.py installCet exemple minimal ci-dessous montre comment créer un dataloader pour l'ensemble de données omniglot à 5 coups avec Torchmeta. Le dataloader charge un lot de tâches générées de manière aléatoire, et tous les échantillons sont concaténés en un seul tenseur. Pour plus d'exemples, consultez le dossier Exemples.
from torchmeta . datasets . helpers import omniglot
from torchmeta . utils . data import BatchMetaDataLoader
dataset = omniglot ( "data" , ways = 5 , shots = 5 , test_shots = 15 , meta_train = True , download = True )
dataloader = BatchMetaDataLoader ( dataset , batch_size = 16 , num_workers = 4 )
for batch in dataloader :
train_inputs , train_targets = batch [ "train" ]
print ( 'Train inputs shape: {0}' . format ( train_inputs . shape )) # (16, 25, 1, 28, 28)
print ( 'Train targets shape: {0}' . format ( train_targets . shape )) # (16, 25)
test_inputs , test_targets = batch [ "test" ]
print ( 'Test inputs shape: {0}' . format ( test_inputs . shape )) # (16, 75, 1, 28, 28)
print ( 'Test targets shape: {0}' . format ( test_targets . shape )) # (16, 75) Les fonctions d'assistance ne sont disponibles que pour certains des ensembles de données disponibles. Cependant, tous sont disponibles via l'interface unifiée fournie par Torchmeta. L' dataset variable défini ci-dessus est équivalent à ce qui suit
from torchmeta . datasets import Omniglot
from torchmeta . transforms import Categorical , ClassSplitter , Rotation
from torchvision . transforms import Compose , Resize , ToTensor
from torchmeta . utils . data import BatchMetaDataLoader
dataset = Omniglot ( "data" ,
# Number of ways
num_classes_per_task = 5 ,
# Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)
transform = Compose ([ Resize ( 28 ), ToTensor ()]),
# Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
target_transform = Categorical ( num_classes = 5 ),
# Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
class_augmentations = [ Rotation ([ 90 , 180 , 270 ])],
meta_train = True ,
download = True )
dataset = ClassSplitter ( dataset , shuffle = True , num_train_per_class = 5 , num_test_per_class = 15 )
dataloader = BatchMetaDataLoader ( dataset , batch_size = 16 , num_workers = 4 )Notez que le dataloader, recevant l'ensemble de données, reste le même.
Tristan Deleu, Tobias Würfl, Mandana Samiei, Joseph Paul Cohen et Yoshua Bengio. Torchmeta: A Meta-Learning Library for Pytorch, 2019 [Arxiv]
Si vous souhaitez citer Torchmeta, utilisez l'entrée Bibtex suivante:
@misc{deleu2019torchmeta,
title={{Torchmeta: A Meta-Learning library for PyTorch}},
author={Deleu, Tristan and W"urfl, Tobias and Samiei, Mandana and Cohen, Joseph Paul and Bengio, Yoshua},
year={2019},
url={https://arxiv.org/abs/1909.06576},
note={Available at: https://github.com/tristandeleu/pytorch-meta}
}