
Learn2Learn은 메타 학습 연구를위한 소프트웨어 라이브러리입니다.
Learn2Learn은 메타 학습 연구주기의 두 가지 측면을 가속화하기 위해 Pytorch 위에 구축됩니다.
Learn2Learn은 기존 알고리즘 및 표준화 된 벤치 마크의 고품질 구현과 함께 새로운 알고리즘 및 도메인을 만들기 위해 저수준 유틸리티 및 통합 인터페이스를 제공합니다. Torchvision, Torchaudio, Torchtext, Cherry 및 다른 Pytorch 기반 라이브러리와의 호환성을 유지합니다.
자세한 내용은 백서 : ARXIV : 2008.12284를 참조하십시오
개요
learn2learn.data : Taskset 및 변환 모든 Pytorch 데이터 세트에서 몇 가지 샷 작업을 생성합니다.learn2learn.vision : 컴퓨터 비전 및 소수의 학습을위한 모델, 데이터 세트 및 벤치 마크.learn2learn.gym : 메타 반복 학습을위한 환경 및 유틸리티.learn2learn.algorithms : 기존 메타 학습 알고리즘에 대한 고급 랩퍼.learn2learn.optim : 차별화 가능한 최적화 및 메타 디스크를위한 유틸리티 및 알고리즘.자원
pip install learn2learn다음 스 니펫은 Learn2Learn의 기능을 엿볼 수 있습니다.
더 많은 알고리즘 (Protonets, Anil, Meta-SGD, Reptile, Meta-Curvature, KFO)에 대해서는 예제 폴더를 참조하십시오. 대부분은 GBML 래퍼와 함께 구현할 수 있습니다. (선적 서류 비치).
maml = l2l . algorithms . MAML ( model , lr = 0.1 )
opt = torch . optim . SGD ( maml . parameters (), lr = 0.001 )
for iteration in range ( 10 ):
opt . zero_grad ()
task_model = maml . clone () # torch.clone() for nn.Modules
adaptation_loss = compute_loss ( task_model )
task_model . adapt ( adaptation_loss ) # computes gradient, update task_model in-place
evaluation_loss = compute_loss ( task_model )
evaluation_loss . backward () # gradients w.r.t. maml.parameters()
opt . step () LearnableOptimizer 사용하여 모든 종류의 최적화 알고리즘을 배우십시오. (예제 및 문서)
linear = nn . Linear ( 784 , 10 )
transform = l2l . optim . ModuleTransform ( l2l . nn . Scale )
metaopt = l2l . optim . LearnableOptimizer ( linear , transform , lr = 0.01 ) # metaopt has .step()
opt = torch . optim . SGD ( metaopt . parameters (), lr = 0.001 ) # metaopt also has .parameters()
metaopt . zero_grad ()
opt . zero_grad ()
error = loss ( linear ( X ), y )
error . backward ()
opt . step () # update metaopt
metaopt . step () # update linear 많은 표준화 된 데이터 세트 (Omniglot, Mini-/Tiered-Imagenet, FC100, CIFAR-FS)는 learn2learn.vision.datasets 에서 쉽게 사용할 수 있습니다. (선적 서류 비치)
dataset = l2l . data . MetaDataset ( MyDataset ()) # any PyTorch dataset
transforms = [ # Easy to define your own transform
l2l . data . transforms . NWays ( dataset , n = 5 ),
l2l . data . transforms . KShots ( dataset , k = 1 ),
l2l . data . transforms . LoadData ( dataset ),
]
taskset = Taskset ( dataset , transforms , num_tasks = 20000 )
for task in taskset :
X , y = task
# Meta-train on the task 자신의 메타 환경을 AsyncVectorEnv 와 병렬화하거나 표준화 된 것들을 사용하십시오. (선적 서류 비치)
def make_env ():
env = l2l . gym . HalfCheetahForwardBackwardEnv ()
env = cherry . envs . ActionSpaceScaler ( env )
return env
env = l2l . gym . AsyncVectorEnv ([ make_env for _ in range ( 16 )]) # uses 16 threads
for task_config in env . sample_tasks ( 20 ):
env . set_task ( task ) # all threads receive the same task
state = env . reset () # use standard Gym API
action = my_policy ( env )
env . step ( action )Pytorch 모듈의 업데이트를 통해 배우고 차별화하십시오. (선적 서류 비치)
model = MyModel ()
transform = l2l . optim . KroneckerTransform ( l2l . nn . KroneckerLinear )
learned_update = l2l . optim . ParameterUpdate ( # learnable update function
model . parameters (), transform )
clone = l2l . clone_module ( model ) # torch.clone() for nn.Modules
error = loss ( clone ( X ), y )
updates = learned_update ( # similar API as torch.autograd.grad
error ,
clone . parameters (),
create_graph = True ,
)
l2l . update_module ( clone , updates = updates )
loss ( clone ( X ), y ). backward () # Gradients w.r.t model.parameters() and learned_update.parameters() changelog.md 파일에서 휴먼 읽기 가능한 changelog를 사용할 수 있습니다.
학술 간행물에서 learn2learn 저장소를 인용하려면 다음 참조를 사용하십시오.
Arnold, Sebastien MR, Praateek Mahajan, Debajyoti Datta, Ian Bunner 및 Konstantinos Saitas Zarkias. 2020.“Learn2Learn : 메타 학습 연구를위한 도서관.” arxiv [cs.lg]. http://arxiv.org/abs/2008.12284.
다음 Bibtex 항목을 사용할 수도 있습니다.
@article { Arnold2020-ss ,
title = " learn2learn: A Library for {Meta-Learning} Research " ,
author = " Arnold, S{'e}bastien M R and Mahajan, Praateek and Datta,
Debajyoti and Bunner, Ian and Zarkias, Konstantinos Saitas " ,
month = aug,
year = 2020 ,
url = " http://arxiv.org/abs/2008.12284 " ,
archivePrefix = " arXiv " ,
primaryClass = " cs.LG " ,
eprint = " 2008.12284 "
}
nn.Module 이 무국적으로 유지되는 동안, Learn2Learn은 상태의 Pytorch 모양과 느낌을 유지합니다. 자세한 내용은 ARXIV 용지를 참조하십시오.