該存儲庫的目的是包含清潔,可讀和測試的代碼,以復制幾乎沒有射擊的學習研究。
該項目用Python 3.6和Pytorch編寫,並假設您有GPU。
有關更多信息,請參閱這些中等文章
列表在requirements.txt中。使用pip install -r requirements.txt安裝virtualenv中。
config.py中的DATA_PATH變量編輯為存儲Omniglot和Miniimagenet數據集的位置。
獲取數據並運行設置腳本後,您的文件夾結構應該看起來像
DATA_PATH/
Omniglot/
images_background/
images_evaluation/
miniImageNet/
images_background/
images_evaluation/
Omniglot數據集。從https://github.com/brendenlake/omniglot/tree/master/python下載,將提取的文件放在DATA_PATH/Omniglot_Raw中,然後運行scripts/prepare_omniglot.py
Miniimagenet數據集。 scripts/prepare_mini_imagenet.py https://drive.google.com/file/d/0b3irx3uqnobmq1flnxjszudywee/view下載data/miniImageNet/images
添加數據集後,在根目錄中運行pytest以運行所有測試。
文件experiments/experiments.txt包含我用來獲取以下結果的超參數。

運行experiments/proto_nets.py來重現原始網絡的結果,以進行幾次學習(Snell等)。
爭論
| 全能 | ||||
|---|---|---|---|---|
| k-way | 5 | 5 | 20 | 20 |
| N-shot | 1 | 5 | 1 | 5 |
| 出版 | 98.8 | 99.7 | 96.0 | 98.9 |
| 這個存儲庫 | 98.2 | 99.4 | 95.8 | 98.6 |
| 迷你膠原 | ||
|---|---|---|
| k-way | 5 | 5 |
| N-shot | 1 | 5 |
| 出版 | 49.4 | 68.2 |
| 這個存儲庫 | 48.0 | 66.2 |
一個可區分的最近的鄰居分類器。

運行experiments/matching_nets.py來複製一個鏡頭學習的匹配網絡的結果(Vinyals等人)。
爭論
我使用餘弦距離度量重現本文的結果很難重現,因為我發現收斂速度緩慢,最終性能取決於隨機初始化。但是,我能夠使用L2距離度量來複製本文的結果(並略微超過)。
| 全能 | ||||
|---|---|---|---|---|
| k-way | 5 | 5 | 20 | 20 |
| N-shot | 1 | 5 | 1 | 5 |
| 出版(餘弦) | 98.1 | 98.9 | 93.8 | 98.5 |
| 這個倉庫(餘弦) | 92.0 | 93.2 | 75.6 | 77.8 |
| 這個倉庫(L2) | 98.3 | 99.8 | 92.8 | 97.8 |
| 迷你膠原 | ||
|---|---|---|
| k-way | 5 | 5 |
| N-shot | 1 | 5 |
| 出版(餘弦,fce) | 44.2 | 57.0 |
| 這個倉庫(餘弦,fce) | 42.8 | 53.6 |
| 這個倉庫(L2) | 46.0 | 58.4 |

我使用Max Pooling而不是基礎的捲積,以便與其他論文保持一致。使用2階MAML的迷你imagenet實驗使我一整天都可以運行。
運行experiments/maml.py以復制模型敏銳的元學習的結果(Finn等人)。
爭論
NB:對於MAML N,K和Q,在火車和測試之間固定。您可能需要調整元批量大小以適合您的GPU。第二階MAML使用更多內存。
| 全能 | ||||
|---|---|---|---|---|
| k-way | 5 | 5 | 20 | 20 |
| N-shot | 1 | 5 | 1 | 5 |
| 出版 | 98.7 | 99.9 | 95.8 | 98.9 |
| 這個倉庫(1) | 95.5 | 99.5 | 92.2 | 97.7 |
| 這個倉庫(2) | 98.1 | 99.8 | 91.6 | 95.9 |
| 迷你膠原 | ||
|---|---|---|
| k-way | 5 | 5 |
| N-shot | 1 | 5 |
| 出版 | 48.1 | 63.2 |
| 這個倉庫(1) | 46.4 | 63.3 |
| 這個倉庫(2) | 47.5 | 64.7 |
括號中的數字表示一階或第二階MAML。