
Learn2Learn adalah perpustakaan perangkat lunak untuk penelitian meta-pembelajaran.
Learn2Learn dibangun di atas Pytorch untuk mempercepat dua aspek siklus penelitian meta-pembelajaran:
Learn2Learn menyediakan utilitas tingkat rendah dan antarmuka terpadu untuk membuat algoritma dan domain baru, bersama dengan implementasi berkualitas tinggi dari algoritma yang ada dan tolok ukur standar. Ini mempertahankan kompatibilitas dengan TorchVision, Torchaudio, Torchtext, Cherry, dan perpustakaan berbasis Pytorch lainnya yang mungkin Anda gunakan.
Untuk mempelajari lebih lanjut, lihat whitepaper kami: arxiv: 2008.12284
Ringkasan
learn2learn.data : Taskset dan Transforms untuk membuat beberapa tugas dari dataset Pytorch mana pun.learn2learn.vision : Model, Dataset, dan Tolok Ukur untuk Visi Komputer dan Pembelajaran Beberapa-Tempat.learn2learn.gym : Lingkungan dan Utilitas untuk Pembelajaran Meta-Reinforcement.learn2learn.algorithms : Pembungkus tingkat tinggi untuk algoritma meta-learning yang ada.learn2learn.optim : Utilitas dan algoritma untuk optimasi yang dapat dibedakan dan meta-descent.Sumber daya
pip install learn2learnCuplikan berikut memberikan intip pada fungsionalitas Learn2Learn.
Untuk lebih banyak algoritma (Protonets, Anil, Meta-SgD, Reptile, Meta-Curvature, KFO) merujuk ke folder contoh. Sebagian besar dari mereka dapat diimplementasikan dengan pembungkus GBML . (dokumentasi).
maml = l2l . algorithms . MAML ( model , lr = 0.1 )
opt = torch . optim . SGD ( maml . parameters (), lr = 0.001 )
for iteration in range ( 10 ):
opt . zero_grad ()
task_model = maml . clone () # torch.clone() for nn.Modules
adaptation_loss = compute_loss ( task_model )
task_model . adapt ( adaptation_loss ) # computes gradient, update task_model in-place
evaluation_loss = compute_loss ( task_model )
evaluation_loss . backward () # gradients w.r.t. maml.parameters()
opt . step () Pelajari segala jenis algoritma optimasi dengan LearnableOptimizer . (Contoh dan dokumentasi)
linear = nn . Linear ( 784 , 10 )
transform = l2l . optim . ModuleTransform ( l2l . nn . Scale )
metaopt = l2l . optim . LearnableOptimizer ( linear , transform , lr = 0.01 ) # metaopt has .step()
opt = torch . optim . SGD ( metaopt . parameters (), lr = 0.001 ) # metaopt also has .parameters()
metaopt . zero_grad ()
opt . zero_grad ()
error = loss ( linear ( X ), y )
error . backward ()
opt . step () # update metaopt
metaopt . step () # update linear Banyak dataset standar (omniglot, mini-tier-imagenet, fc100, cifar-fs) sudah tersedia di learn2learn.vision.datasets . (dokumentasi)
dataset = l2l . data . MetaDataset ( MyDataset ()) # any PyTorch dataset
transforms = [ # Easy to define your own transform
l2l . data . transforms . NWays ( dataset , n = 5 ),
l2l . data . transforms . KShots ( dataset , k = 1 ),
l2l . data . transforms . LoadData ( dataset ),
]
taskset = Taskset ( dataset , transforms , num_tasks = 20000 )
for task in taskset :
X , y = task
# Meta-train on the task Paralelisasi meta-lingkungan Anda sendiri dengan AsyncVectorEnv , atau gunakan yang terstandarisasi. (dokumentasi)
def make_env ():
env = l2l . gym . HalfCheetahForwardBackwardEnv ()
env = cherry . envs . ActionSpaceScaler ( env )
return env
env = l2l . gym . AsyncVectorEnv ([ make_env for _ in range ( 16 )]) # uses 16 threads
for task_config in env . sample_tasks ( 20 ):
env . set_task ( task ) # all threads receive the same task
state = env . reset () # use standard Gym API
action = my_policy ( env )
env . step ( action )Pelajari dan bedakan melalui pembaruan modul Pytorch. (dokumentasi)
model = MyModel ()
transform = l2l . optim . KroneckerTransform ( l2l . nn . KroneckerLinear )
learned_update = l2l . optim . ParameterUpdate ( # learnable update function
model . parameters (), transform )
clone = l2l . clone_module ( model ) # torch.clone() for nn.Modules
error = loss ( clone ( X ), y )
updates = learned_update ( # similar API as torch.autograd.grad
error ,
clone . parameters (),
create_graph = True ,
)
l2l . update_module ( clone , updates = updates )
loss ( clone ( X ), y ). backward () # Gradients w.r.t model.parameters() and learned_update.parameters() Changelog yang dapat dibaca manusia tersedia di file Changelog.md.
Untuk mengutip repositori learn2learn dalam publikasi akademik Anda, silakan gunakan referensi berikut.
Arnold, Sebastien Mr, Praateek Mahajan, Debajyoti Datta, Ian Bunner, dan Konstantinos Saitas Zarkia. 2020. "Learn2Learn: Perpustakaan untuk Penelitian Meta-pembelajaran." arxiv [cs.lg]. http://arxiv.org/abs/2008.12284.
Anda juga dapat menggunakan entri Bibtex berikut.
@article { Arnold2020-ss ,
title = " learn2learn: A Library for {Meta-Learning} Research " ,
author = " Arnold, S{'e}bastien M R and Mahajan, Praateek and Datta,
Debajyoti and Bunner, Ian and Zarkias, Konstantinos Saitas " ,
month = aug,
year = 2020 ,
url = " http://arxiv.org/abs/2008.12284 " ,
archivePrefix = " arXiv " ,
primaryClass = " cs.LG " ,
eprint = " 2008.12284 "
}
nn.Module menjadi tanpa kewarganegaraan, pelajari2Learn mempertahankan tampilan dan perasaan pytorch yang stateful. Untuk informasi lebih lanjut, lihat makalah ArXIV mereka.