Kumpulan ekstensi dan pemuat data untuk pembelajaran beberapa shot & meta-belajar di Pytorch. Torchmeta berisi tolok ukur meta-pembelajaran populer, sepenuhnya kompatibel dengan torchvision dan DataLoader Pytorch.
Module Pytorch, yang disebut MetaModule , yang menyederhanakan penciptaan model pembelajaran meta tertentu (misalnya metode pembelajaran meta berbasis gradien). Lihat contoh MAML untuk contoh menggunakan MetaModule . Anda dapat menginstal Torchmeta baik menggunakan Pypack Manager Pip Python, atau dari Sumber. Untuk menghindari konflik dengan pengaturan Python Anda yang ada, disarankan untuk bekerja di lingkungan virtual dengan virtualenv . Untuk menginstal virtualenv :
pip install --upgrade virtualenv
virtualenv venv
source venv/bin/activateIni adalah cara yang disarankan untuk menginstal Torchmeta:
pip install torchmetaAnda juga dapat menginstal Torchmeta dari Source. Ini disarankan jika Anda ingin berkontribusi pada Torchmeta.
git clone https://github.com/tristandeleu/pytorch-meta.git
cd pytorch-meta
python setup.py installContoh minimal di bawah ini menunjukkan cara membuat dataloader untuk dataset omniglot 5-arah 5-shot dengan TorchMeta. Dataloader memuat batch tugas yang dihasilkan secara acak, dan semua sampel digabungkan menjadi tensor tunggal. Untuk contoh lebih lanjut, periksa folder contoh.
from torchmeta . datasets . helpers import omniglot
from torchmeta . utils . data import BatchMetaDataLoader
dataset = omniglot ( "data" , ways = 5 , shots = 5 , test_shots = 15 , meta_train = True , download = True )
dataloader = BatchMetaDataLoader ( dataset , batch_size = 16 , num_workers = 4 )
for batch in dataloader :
train_inputs , train_targets = batch [ "train" ]
print ( 'Train inputs shape: {0}' . format ( train_inputs . shape )) # (16, 25, 1, 28, 28)
print ( 'Train targets shape: {0}' . format ( train_targets . shape )) # (16, 25)
test_inputs , test_targets = batch [ "test" ]
print ( 'Test inputs shape: {0}' . format ( test_inputs . shape )) # (16, 75, 1, 28, 28)
print ( 'Test targets shape: {0}' . format ( test_targets . shape )) # (16, 75) Fungsi helper hanya tersedia untuk beberapa dataset yang tersedia. Namun, semuanya tersedia melalui antarmuka terpadu yang disediakan oleh Torchmeta. dataset variabel yang didefinisikan di atas setara dengan yang berikut ini
from torchmeta . datasets import Omniglot
from torchmeta . transforms import Categorical , ClassSplitter , Rotation
from torchvision . transforms import Compose , Resize , ToTensor
from torchmeta . utils . data import BatchMetaDataLoader
dataset = Omniglot ( "data" ,
# Number of ways
num_classes_per_task = 5 ,
# Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)
transform = Compose ([ Resize ( 28 ), ToTensor ()]),
# Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
target_transform = Categorical ( num_classes = 5 ),
# Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
class_augmentations = [ Rotation ([ 90 , 180 , 270 ])],
meta_train = True ,
download = True )
dataset = ClassSplitter ( dataset , shuffle = True , num_train_per_class = 5 , num_test_per_class = 15 )
dataloader = BatchMetaDataLoader ( dataset , batch_size = 16 , num_workers = 4 )Perhatikan bahwa Dataloader, menerima dataset, tetap sama.
Tristan Deleu, Tobias Würfl, Mandana Samiei, Joseph Paul Cohen, dan Yoshua Bengio. Torchmeta: Perpustakaan meta-pembelajaran untuk Pytorch, 2019 [Arxiv]
Jika Anda ingin mengutip Torchmeta, gunakan entri Bibtex berikut:
@misc{deleu2019torchmeta,
title={{Torchmeta: A Meta-Learning library for PyTorch}},
author={Deleu, Tristan and W"urfl, Tobias and Samiei, Mandana and Cohen, Joseph Paul and Bengio, Yoshua},
year={2019},
url={https://arxiv.org/abs/1909.06576},
note={Available at: https://github.com/tristandeleu/pytorch-meta}
}