reptile pytorch
1.0.0
Pytorch實施OpenAI的爬行動物算法,以進行監督學習。
目前,它在Omniglot上運行,但尚未在Miniimagenet上運行。
該代碼尚未進行廣泛測試。貢獻和反饋非常歡迎!
火炬設施中已經有一個omniglot數據集類別,但是它似乎比幾乎沒有射擊的學習更適合監督學習。
omniglot.py提供了一種方法,可以從Omniglot中採樣K-Shot N-Way基底任務,以及各種實用程序來分配元訓練集以及基礎任務。
下載Omniglot數據集的兩個部分:
在存儲庫中創建一個omniglot/文件夾,解開拉鍊並合併兩個文件以具有以下文件夾結構:
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
現在開始訓練
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # with CPU
python train_omniglot.py log $HYPERPARAMETERS # with CUDA
$超參數取決於您的任務和超參數。
行為:
log/中找不到檢查點,則將創建一個log/文件夾來存儲張量板信息和檢查點。log/中找到了檢查點,則將從最後一個檢查點恢復。可以隨時使用^C進行訓練,並通過重新運行同一命令從上一個檢查站恢復。
以下一組超參數效果很好。它們是從OpenAI實施中獲取的,但適用於meta-batch=1 。


對於5速5速(紅色曲線):
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001對於5向1攝(藍色曲線):
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044