簡單的替代實現Pytorch中的原型網絡,用於少數拍攝學習(紙,代碼)。
As shown in the reference paper Prototypical Networks are trained to embed samples features in a vectorial space, in particular, at each episode (iteration), a number of samples for a subset of classes are selected and sent through the model, for each subset of class c a number of samples' features ( n_support ) are used to guess the prototype (their barycentre coordinates in the vectorial space) for that class, so then the distances between the remaining n_query樣品及其類Barycentre可以最小化。

訓練後,您可以計算模型生成的特徵的T-SNE(在此存儲庫中不完成,在此處進行更多有關T-SNE的信息),這是本文所示的樣本。

感謝@ludc的貢獻:Pytorch/Vision#46。如果不暗示代碼的大更改,我們將使用官方數據集將其添加到torchvision中。
我們像[匹配一個鏡頭學習的匹配網絡]中實現了Vynials分裂方法。這應該與論文中使用的方法相同(實際上,我從“公開”倉庫中下載了拆分文件)。然後,我們在其中描述了相同的旋轉。通過這種方式,我們應該能夠比較通過運行此代碼與參考文件中描述的結果獲得的結果。
如PYDOC中所述,該類用於生成每批索引的原型訓練算法。
特別是,該對像是通過傳遞數據集的標籤列表來實例化的,採樣器滲透了類,然後為每個類的總數,並為每個類NI的數據集創建一組索引。在每個情節中,採樣器選擇n_classes隨機類,並返回每個選定類的樣本索引的數字( n_support + n_query )。
按照引用的論文計算損失,主要是由其一位作者啟發的。
在prototypical_loss.py中,實施了損失函數和損失類別。
該函數從模型中輸入批處理輸入,示例的地面真相以及用於用作支持樣本的樣本的數字n_suppport 。從目標列表中推斷情節類, n_support樣品為每個類隨機提取,計算其類的barycentres,以及每個剩餘的樣本中每個班級嵌入的距離的距離,以及每個情節類別的每個樣本的概率,都計算出每個情節類別的範圍;然後,從分類問題中的往常中,從錯誤的預測概率(對於查詢樣本)計算出損失。
請注意,培訓代碼僅出於演示目的。
要在此任務上訓練Protonet,請將CD介紹到該倉庫的src根文件夾中並執行:
$ python train.py
該腳本採用以下命令行選項:
dataset_root :存儲數據集的根目錄,默認為'../dataset'
nepochs :訓練的時期數量,默認為100
learning_rate :模型的學習率,默認為0.001
lr_scheduler_step :Steplr學習率調度程序步驟,默認為20
lr_scheduler_gamma :Steplr學習率調度程序伽瑪,默認為0.5
iterations :每個時期的發作數。默認為100
classes_per_it_tr :訓練每集的隨機類數量。默認為60
num_support_tr :每個課程的樣本數量用於培訓。默認為5
num_query_tr :每班樣本的nnumber用作培訓的查詢。默認為5
classes_per_it_val :每集的隨機類數量進行驗證。默認為5
num_support_val :每個類的樣本數量用於驗證。默認為5
num_query_val :每個類的樣本數量用於查詢驗證。默認為15
manual_seed :手動種子初始化的輸入,默認為7
cuda :啟用CUDA(商店True )
在沒有參數的情況下運行命令將使用默認的超參數值訓練模型(上面顯示的結果)。
我們正在嘗試重現參考紙表演,我們將在此處更新我們的最佳結果。
| 模型 | 1-shot(5向Acc。) | 5-shot(5-way Acc。) | 1 -shot(20向上ACC。) | 5-shot(20向上ACC。) |
|---|---|---|---|---|
| 參考文件 | 98.8% | 99.7% | 96.0% | 98.9% |
| 這個存儲庫 | 98.5%** | 99.6%* | 95.1%° | 98.6%° |
*使用默認參數實現(使用--cuda選項)
**實現了運行python train.py --cuda -nsTr 1 -nsVa 1
°實現了運行python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20
°°實現了python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20
引用紙張如下(為您複製從Arxiv複製):
@article{DBLP:journals/corr/SnellSZ17,
author = {Jake Snell and
Kevin Swersky and
Richard S. Zemel},
title = {Prototypical Networks for Few-shot Learning},
journal = {CoRR},
volume = {abs/1703.05175},
year = {2017},
url = {http://arxiv.org/abs/1703.05175},
archivePrefix = {arXiv},
eprint = {1703.05175},
timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
biburl = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
bibsource = {dblp computer science bibliography, http://dblp.org}
}
該項目已根據麻省理工學院許可證
版權(C)2018 Daniele E. Ciriello,Orobix SRL(www.orobix.com)。