
Learn2Learn est une bibliothèque de logiciels pour la recherche sur la méta-apprentissage.
Learn2Learn s'appuie sur Pytorch pour accélérer deux aspects du cycle de recherche de méta-apprentissage:
Learn2Learn fournit des utilitaires de bas niveau et une interface unifiée pour créer de nouveaux algorithmes et domaines, ainsi que des implémentations de haute qualité des algorithmes existants et des références standardisées. Il conserve la compatibilité avec TorchVision, Torchaudio, TorchText, Cherry et toute autre bibliothèque basée sur Pytorch que vous pourriez utiliser.
Pour en savoir plus, voir notre livre blanc: Arxiv: 2008.12284
Aperçu
learn2learn.data : Taskset et se transforme en tâches à quelques tirs à partir de tout ensemble de données Pytorch.learn2learn.vision : modèles, ensembles de données et repères pour la vision par ordinateur et l'apprentissage à quelques coups.learn2learn.gym : Environnement et services publics pour l'apprentissage des méta-renforts.learn2learn.algorithms : Wrappers de haut niveau pour les algorithmes de méta-apprentissage existants.learn2learn.optim : Utilitaires et algorithmes pour l'optimisation différenciable et le méta-décembre.Ressources
pip install learn2learnLes extraits suivants fournissent un aperçu des fonctionnalités de Learn2Learn.
Pour plus d'algorithmes (Protonets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO), référez-vous au dossier Exemples. La plupart d'entre eux peuvent être mis en œuvre avec l'emballage GBML . (documentation).
maml = l2l . algorithms . MAML ( model , lr = 0.1 )
opt = torch . optim . SGD ( maml . parameters (), lr = 0.001 )
for iteration in range ( 10 ):
opt . zero_grad ()
task_model = maml . clone () # torch.clone() for nn.Modules
adaptation_loss = compute_loss ( task_model )
task_model . adapt ( adaptation_loss ) # computes gradient, update task_model in-place
evaluation_loss = compute_loss ( task_model )
evaluation_loss . backward () # gradients w.r.t. maml.parameters()
opt . step () Apprenez tout type d'algorithme d'optimisation avec l' LearnableOptimizer . (exemple et documentation)
linear = nn . Linear ( 784 , 10 )
transform = l2l . optim . ModuleTransform ( l2l . nn . Scale )
metaopt = l2l . optim . LearnableOptimizer ( linear , transform , lr = 0.01 ) # metaopt has .step()
opt = torch . optim . SGD ( metaopt . parameters (), lr = 0.001 ) # metaopt also has .parameters()
metaopt . zero_grad ()
opt . zero_grad ()
error = loss ( linear ( X ), y )
error . backward ()
opt . step () # update metaopt
metaopt . step () # update linear De nombreux ensembles de données standardisés (omniglot, mini-imagenet de niveau / kilométrique, FC100, CIFAR-FS) sont facilement disponibles dans learn2learn.vision.datasets . (documentation)
dataset = l2l . data . MetaDataset ( MyDataset ()) # any PyTorch dataset
transforms = [ # Easy to define your own transform
l2l . data . transforms . NWays ( dataset , n = 5 ),
l2l . data . transforms . KShots ( dataset , k = 1 ),
l2l . data . transforms . LoadData ( dataset ),
]
taskset = Taskset ( dataset , transforms , num_tasks = 20000 )
for task in taskset :
X , y = task
# Meta-train on the task Parallilisez vos propres méta-environnements avec AsyncVectorEnv , ou utilisez les standardisés. (documentation)
def make_env ():
env = l2l . gym . HalfCheetahForwardBackwardEnv ()
env = cherry . envs . ActionSpaceScaler ( env )
return env
env = l2l . gym . AsyncVectorEnv ([ make_env for _ in range ( 16 )]) # uses 16 threads
for task_config in env . sample_tasks ( 20 ):
env . set_task ( task ) # all threads receive the same task
state = env . reset () # use standard Gym API
action = my_policy ( env )
env . step ( action )Apprenez et différenciez les mises à jour des modules Pytorch. (documentation)
model = MyModel ()
transform = l2l . optim . KroneckerTransform ( l2l . nn . KroneckerLinear )
learned_update = l2l . optim . ParameterUpdate ( # learnable update function
model . parameters (), transform )
clone = l2l . clone_module ( model ) # torch.clone() for nn.Modules
error = loss ( clone ( X ), y )
updates = learned_update ( # similar API as torch.autograd.grad
error ,
clone . parameters (),
create_graph = True ,
)
l2l . update_module ( clone , updates = updates )
loss ( clone ( X ), y ). backward () # Gradients w.r.t model.parameters() and learned_update.parameters() Un ChangeLog lisible par l'homme est disponible dans le fichier Changelog.md.
Pour citer le référentiel learn2learn dans vos publications académiques, veuillez utiliser la référence suivante.
Arnold, Sébastien MR, Praateek Mahajan, Debajyoti Datta, Ian Bunner et Konstantinos Saitas Zarkias. 2020. «Learn2Learn: une bibliothèque pour la recherche de méta-apprentissage.» arXiv [cs.lg]. http://arxiv.org/abs/2008.12284.
Vous pouvez également utiliser l'entrée Bibtex suivante.
@article { Arnold2020-ss ,
title = " learn2learn: A Library for {Meta-Learning} Research " ,
author = " Arnold, S{'e}bastien M R and Mahajan, Praateek and Datta,
Debajyoti and Bunner, Ian and Zarkias, Konstantinos Saitas " ,
month = aug,
year = 2020 ,
url = " http://arxiv.org/abs/2008.12284 " ,
archivePrefix = " arXiv " ,
primaryClass = " cs.LG " ,
eprint = " 2008.12284 "
}
nn.Module pour être apatrides, Learn2Learn conserve le look et la feel du Pytorch. Pour plus d'informations, reportez-vous à leur article Arxiv.