学習埋め込みのためのシャムおよび三重項ネットワークのPytorch実装。
シャムおよびトリプレットネットワークは、距離が類似性の尺度に対応するコンパクトなユークリッド空間へのマッピングを学習するのに役立ちます[2]。このような方法でトレーニングされた埋め込みは、分類または少数のショット学習タスクのための機能ベクトルとして使用できます。
TorchVision 0.2.1を備えたPytorch 0.4が必要です
Pytorch 0.3互換性チェックアウトタグTorch-0.3.1の場合
Mnistデータセットの埋め込みをトレーニングします。実験はJupyterノートブックで実行されました。
MNISTデータセット上のさまざまな損失関数を使用して、監視された機能の埋め込みを学習します。これは視覚化のためだけのものであるため、実際には最良の選択ではない2次元の埋め込みを使用します。
すべての実験では、同じ埋め込みネットワークが使用されます(32 Conv 5x5-> Preu-> Maxpool 2x2-> 64 Conv 5x5-> Preu-> Maxpool 2x2-> dense 256-> Prelu-> dense 256-> Prelu-> Dense 2)。
クラスの数を含む完全に接続されたレイヤーを追加し、ソフトマックスとクロスエントロピーを使用して分類のためにネットワークをトレーニングします。ネットワークは、約99%の精度をトレーニングします。最後から2つの寸法の埋め込みを抽出します。
電車セット:

テストセット:

埋め込みは分離可能に見えますが(これは私たちがそれらを訓練したものです)、それらは良いメトリック特性を持っていません。それらは、新しいクラスの記述子としての最良の選択ではないかもしれません。
ここで、画像をペアにして埋め込みを訓練するシャムネットワークをトレーニングして、同じクラスからの距離が最小化され、異なるクラスを表す場合はマージン値よりも大きいようになります。対照的な損失関数を最小限に抑えます[1]:

Siamesemnistクラスは、ランダムな正と負のペアをサンプルし、シャムネットワークに供給されます。
ここに20個のトレーニングの後、トレーニングセットのために得られる埋め込みがあります。

テストセット:

学習した埋め込みは、クラス内ではるかによくクラスター化されています。
トリプレットネットワークをトレーニングします。これは、アンカー、(アンカーと同じクラスの)ポジティブな(アンカーとは異なるクラスの)否定的な例を採用します。目的は、アンカーが何らかのマージン値によって否定的な例よりも肯定的な例に近づくように、埋め込みを学習することです。
出典: Schroff、Florian、Dmitry Kalenichenko、James Philbin。 FaceNet:顔認識とクラスタリングのための統一された埋め込み。 CVPR 2015。
トリプレット損失: 
tripletmnistクラスは、可能なすべてのアンカーに対して正と否定的な例をサンプリングします。
ここに20個のトレーニングの後、トレーニングセットのために得られる埋め込みがあります。

テストセット:

学習した埋め込みは、シャムネットワークの場合ほどクラス内で互いに近いものではありませんが、それは私たちが最適化したものではありません。私たちは、埋め込みが他のクラスよりも同じクラスからの他の埋め込みに近づくことを望んでいました。それがトレーニングが行く場所であることがわかります。
シャムおよびトリプレットネットワークにはいくつかの問題があります。
これらの問題に効率的に対処するために、分類のために行ったように、標準のミニバッチでネットワークを供給します。損失関数は、ミニバッチ内のハードペアとトリプレットの選択を担当します。 10クラスごとに16の画像でネットワークをフィードすると、最大159*160/2 = 12720ペアと10*16*15/2*(9*16)= 172800トリプレットを処理できます。
通常、ミニバッチ内のすべての可能なペアまたはトリプレットを処理することは最良のアイデアではありません。 [2]および[3]でトリプレットを選択する方法に関するいくつかの戦略を見つけることができます。
分類ネットワークで行ったように、ネットワークにミニバッチでフィードします。今回は、各クラス内のN_CLASSESとN_SAMPLESをサンプリングする特別なバッチサンプラーを使用して、サイズN_CLASSES*N_SAMPLESのミニバッチを作成します。
各ミニバッチの正のペアとネガティブペアは、提供されたラベルを使用して選択されます。
Mnistはかなり簡単なデータセットであり、ランダムに選択されたペアからの埋め込みはすでに非常に良かったので、ここではあまり改善されていません。
埋め込みをトレーニング:

埋め込みをテスト:

オンラインペアの選択と同様に、ネットワークにミニバッチを使用します。ラベルと予測された埋め込みを与えられたトリプレット選択に使用できる戦略はいくつかあります。
トリプレット選択の戦略は慎重に選択する必要があります。悪い戦略は、非効率的なトレーニングにつながる可能性があります。さらに悪いことに、崩壊をモデル化することになります(すべての埋め込みは同じ値を持つことになります)。
ポジティブペアごとにランダムなハードネガで得たものは次のとおりです。
トレーニングセット:

テストセット:

FashionMnist Datasetについても同様の実験が行われ、オンラインのネガティブマイニングの利点がわずかに目立つようになりました。 2次元の埋め込みのみを備えたまったく同じネットワークアーキテクチャが使用されましたが、これはおそらく、良い埋め込みを学習するのに十分なほど複雑ではありません。より高い数のクラスを持つより複雑なデータセットは、オンラインマイニングからさらに利益を得る必要があります。

ランダムに選択されたペアを持つシャムネットワーク

マイナスマイニングによるオンライン対照損失

ランダムなトリプレットを備えたトリプレットネットワーク

マイナスマイニングによるオンライントリプレット損失

[1] Raia Hadsell、Sumit Chopra、Yann Lecun、Invariantマッピングの学習による次元削減、CVPR 2006
[2] Schroff、Florian、Dmitry Kalenichenko、およびJames Philbin。 FaceNet:顔認識とクラスタリングのための統一された埋め込み。 CVPR 2015
[3] Alexander Hermans、Lucas Beyer、Bastian Leibe、2017
[4] Brandon Amos、Bartosz Ludwiczuk、Mahadev Satyanarayanan、Openface:モバイルアプリケーションを備えた汎用フェイス認識ライブラリ、2016年
[5] Yi Sun、Xiaogang Wang、Xiaoou Tang、共同識別検証による深い学習顔表現、NIPS 2014