reptile pytorch
1.0.0
監視された学習のためのOpenaiの爬虫類アルゴリズムのPytorch実装。
現在、Omniglotで実行されていますが、まだMiniimagenetでは実行されていません。
コードは広範囲にテストされていません。貢献とフィードバックは大歓迎です!
TorchvisionにはすでにOmniglot Datasetクラスがありますが、少数のショット学習よりも監視された学習に適しているようです。
omniglot.py 、OmniglotのK-Shot N-wayベースタスクとさまざまなユーティリティをサンプリングして、メタトレーニングセットとベースタスクを分割する方法を提供します。
Omniglotデータセットの2つの部分をダウンロードしてください。
リポジトリにomniglot/フォルダーを作成し、2つのファイルを解凍してマージして、次のフォルダー構造を持つようにします。
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
でトレーニングを開始します
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # with CPU
python train_omniglot.py log $HYPERPARAMETERS # with CUDA
$ HyperParametersは、タスクとハイパーパラメーターに依存します。
行動:
log/にチェックポイントが見つからない場合、これにより、テンソルボード情報とチェックポイントを保存するlog/フォルダーが作成されます。log/にチェックポイントが見つかった場合、これは最後のチェックポイントから再開されます。トレーニングは^Cでいつでも中断し、同じコマンドを再実行することにより最後のチェックポイントから再開できます。
次のハイパーパラメーターのセットはきちんと動作します。それらはOpenaiの実装から取られますが、 meta-batch=1にわずかに適合しています。


5ウェイ5ショット(赤い曲線)の場合:
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.0015ウェイ1ショット(青い曲線)の場合:
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044