reptile pytorch
1.0.0
Pytorch实施OpenAI的爬行动物算法,以进行监督学习。
目前,它在Omniglot上运行,但尚未在Miniimagenet上运行。
该代码尚未进行广泛测试。贡献和反馈非常欢迎!
火炬设施中已经有一个omniglot数据集类别,但是它似乎比几乎没有射击的学习更适合监督学习。
omniglot.py提供了一种方法,可以从Omniglot中采样K-Shot N-Way基底任务,以及各种实用程序来分配元训练集以及基础任务。
下载Omniglot数据集的两个部分:
在存储库中创建一个omniglot/文件夹,解开拉链并合并两个文件以具有以下文件夹结构:
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
现在开始训练
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # with CPU
python train_omniglot.py log $HYPERPARAMETERS # with CUDA
$超参数取决于您的任务和超参数。
行为:
log/中找不到检查点,则将创建一个log/文件夹来存储张量板信息和检查点。log/中找到了检查点,则将从最后一个检查点恢复。可以随时使用^C进行训练,并通过重新运行同一命令从上一个检查站恢复。
以下一组超参数效果很好。它们是从OpenAI实施中获取的,但适用于meta-batch=1 。


对于5速5速(红色曲线):
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001对于5向1摄(蓝色曲线):
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044