
Learn2Learn es una biblioteca de software para la investigación del meta-learning.
Learn2Learn se basa en la parte superior de Pytorch para acelerar dos aspectos del ciclo de investigación del meta-aprendizaje:
Learn2Learn proporciona utilidades de bajo nivel e interfaz unificada para crear nuevos algoritmos y dominios, junto con implementaciones de alta calidad de algoritmos existentes y puntos de referencia estandarizados. Conserva la compatibilidad con TorchVision, Torchaudio, TorchText, Cherry y cualquier otra biblioteca basada en Pytorch que pueda estar utilizando.
Para obtener más información, vea nuestro documento técnico: ARXIV: 2008.12284
Descripción general
learn2learn.data : Taskset y se transforma para crear tareas de pocos disparos a partir de cualquier conjunto de datos de Pytorch.learn2learn.vision : modelos, conjuntos de datos y puntos de referencia para la visión por computadora y el aprendizaje de pocos disparos.learn2learn.gym : medio ambiente y servicios públicos para el aprendizaje de meta-refuerzo.learn2learn.algorithms : envoltorios de alto nivel para algoritmos existentes de meta-aprendizaje.learn2learn.optim : utilidades y algoritmos para optimización diferenciable y metadescente.Recursos
pip install learn2learnLos siguientes fragmentos proporcionan un adelanto de las funcionalidades de Learn2Learn.
Para obtener más algoritmos (protonetos, ANIL, meta-SGD, reptil, meta-curvatura, KFO), consulte la carpeta de ejemplos. La mayoría de ellos se pueden implementar con el envoltorio GBML . (documentación).
maml = l2l . algorithms . MAML ( model , lr = 0.1 )
opt = torch . optim . SGD ( maml . parameters (), lr = 0.001 )
for iteration in range ( 10 ):
opt . zero_grad ()
task_model = maml . clone () # torch.clone() for nn.Modules
adaptation_loss = compute_loss ( task_model )
task_model . adapt ( adaptation_loss ) # computes gradient, update task_model in-place
evaluation_loss = compute_loss ( task_model )
evaluation_loss . backward () # gradients w.r.t. maml.parameters()
opt . step () Aprenda cualquier tipo de algoritmo de optimización con LearnableOptimizer . (ejemplo y documentación)
linear = nn . Linear ( 784 , 10 )
transform = l2l . optim . ModuleTransform ( l2l . nn . Scale )
metaopt = l2l . optim . LearnableOptimizer ( linear , transform , lr = 0.01 ) # metaopt has .step()
opt = torch . optim . SGD ( metaopt . parameters (), lr = 0.001 ) # metaopt also has .parameters()
metaopt . zero_grad ()
opt . zero_grad ()
error = loss ( linear ( X ), y )
error . backward ()
opt . step () # update metaopt
metaopt . step () # update linear Muchos conjuntos de datos estandarizados (Omniglot, mini/nivelado-Imagenet, FC100, CIFAR-FS) están fácilmente disponibles en learn2learn.vision.datasets . (documentación)
dataset = l2l . data . MetaDataset ( MyDataset ()) # any PyTorch dataset
transforms = [ # Easy to define your own transform
l2l . data . transforms . NWays ( dataset , n = 5 ),
l2l . data . transforms . KShots ( dataset , k = 1 ),
l2l . data . transforms . LoadData ( dataset ),
]
taskset = Taskset ( dataset , transforms , num_tasks = 20000 )
for task in taskset :
X , y = task
# Meta-train on the task Paralelice sus propios meta-ambes con AsyncVectorEnv , o use los estandarizados. (documentación)
def make_env ():
env = l2l . gym . HalfCheetahForwardBackwardEnv ()
env = cherry . envs . ActionSpaceScaler ( env )
return env
env = l2l . gym . AsyncVectorEnv ([ make_env for _ in range ( 16 )]) # uses 16 threads
for task_config in env . sample_tasks ( 20 ):
env . set_task ( task ) # all threads receive the same task
state = env . reset () # use standard Gym API
action = my_policy ( env )
env . step ( action )Aprenda y diferencie a través de actualizaciones de módulos Pytorch. (documentación)
model = MyModel ()
transform = l2l . optim . KroneckerTransform ( l2l . nn . KroneckerLinear )
learned_update = l2l . optim . ParameterUpdate ( # learnable update function
model . parameters (), transform )
clone = l2l . clone_module ( model ) # torch.clone() for nn.Modules
error = loss ( clone ( X ), y )
updates = learned_update ( # similar API as torch.autograd.grad
error ,
clone . parameters (),
create_graph = True ,
)
l2l . update_module ( clone , updates = updates )
loss ( clone ( X ), y ). backward () # Gradients w.r.t model.parameters() and learned_update.parameters() Un cambio de cambios legible por humanos está disponible en el archivo ChangeLog.md.
Para citar el repositorio learn2learn en sus publicaciones académicas, utilice la siguiente referencia.
Arnold, Sebastien MR, Praateek Mahajan, Debajyoti Datta, Ian Bunner y Konstantinos Saitas Zarkias. 2020. "Learn2Learn: una biblioteca para la investigación del meta-learning". arxiv [cs.lg]. http://arxiv.org/abs/2008.12284.
También puede usar la siguiente entrada de Bibtex.
@article { Arnold2020-ss ,
title = " learn2learn: A Library for {Meta-Learning} Research " ,
author = " Arnold, S{'e}bastien M R and Mahajan, Praateek and Datta,
Debajyoti and Bunner, Ian and Zarkias, Konstantinos Saitas " ,
month = aug,
year = 2020 ,
url = " http://arxiv.org/abs/2008.12284 " ,
archivePrefix = " arXiv " ,
primaryClass = " cs.LG " ,
eprint = " 2008.12284 "
}
nn.Module estarán apátrate, Learn2Learn conserva el estado de apariencia de Pytorch. Para obtener más información, consulte su documento ARXIV.