
11 Desember : v2.8.0
2 November : v2.7.0
Lihat folder contoh untuk buku catatan yang dapat Anda unduh atau jalankan di Google Colab.
Perpustakaan ini berisi 9 modul, yang masing -masing dapat digunakan secara independen dalam basis kode Anda yang ada, atau digabungkan bersama untuk alur kerja kereta/pengujian lengkap.

Mari kita inisialisasi tripletmarginloss polos:
from pytorch_metric_learning import losses
loss_func = losses . TripletMarginLoss ()Untuk menghitung kerugian dalam loop pelatihan Anda, lewati embeddings yang dihitung oleh model Anda, dan label yang sesuai. Embeddings harus memiliki ukuran (n, embedding_size), dan label harus memiliki ukuran (n), di mana n adalah ukuran batch.
# your training loop
for i , ( data , labels ) in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = model ( data )
loss = loss_func ( embeddings , labels )
loss . backward ()
optimizer . step ()Tripletmarginloss menghitung semua triplet yang mungkin dalam batch, berdasarkan label yang Anda masukkan ke dalamnya. Pasangan anchor-positif dibentuk oleh embeddings yang memiliki label yang sama, dan pasangan jangkar negatif dibentuk oleh embeddings yang memiliki label yang berbeda.
Terkadang dapat membantu menambahkan fungsi penambangan:
from pytorch_metric_learning import miners , losses
miner = miners . MultiSimilarityMiner ()
loss_func = losses . TripletMarginLoss ()
# your training loop
for i , ( data , labels ) in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = model ( data )
hard_pairs = miner ( embeddings , labels )
loss = loss_func ( embeddings , labels , hard_pairs )
loss . backward ()
optimizer . step ()Dalam kode di atas, penambang menemukan pasangan positif dan negatif yang menurutnya sangat sulit. Perhatikan bahwa meskipun tripletmarginloss beroperasi pada kembar tiga, masih mungkin untuk lulus berpasangan. Ini karena perpustakaan secara otomatis mengkonversi pasangan menjadi kembar tiga dan kembar tiga menjadi pasangan, bila perlu.
Fungsi kerugian dapat disesuaikan dengan menggunakan jarak, pereduksi, dan regulteran. Dalam diagram di bawah ini, seorang penambang menemukan indeks pasangan keras dalam satu batch. Ini digunakan untuk mengindeks ke dalam matriks jarak, dihitung oleh objek jarak. Untuk diagram ini, fungsi kerugian berbasis pasangan, sehingga menghitung kerugian per pasangan. Selain itu, regularer telah disediakan, sehingga kehilangan regularisasi dihitung untuk setiap penyematan dalam batch. Kerugian per-pasangan dan per-elemen diteruskan ke peredam, yang (dalam diagram ini) hanya menjaga kerugian dengan nilai tinggi. Rata-rata dihitung untuk pasangan bernilai tinggi dan kehilangan elemen, dan kemudian ditambahkan bersama untuk mendapatkan kerugian akhir.

Sekarang inilah contoh tripletmarginloss yang disesuaikan:
from pytorch_metric_learning . distances import CosineSimilarity
from pytorch_metric_learning . reducers import ThresholdReducer
from pytorch_metric_learning . regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses . TripletMarginLoss ( distance = CosineSimilarity (),
reducer = ThresholdReducer ( high = 0.3 ),
embedding_regularizer = LpRegularizer ())Kehilangan triplet yang disesuaikan ini memiliki properti berikut:
Pembungkus SelfSupervisedLoss disediakan untuk pembelajaran yang di-swadaya:
from pytorch_metric_learning . losses import SelfSupervisedLoss
loss_func = SelfSupervisedLoss ( TripletMarginLoss ())
# your training for-loop
for i , data in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = your_model ( data )
augmented = your_model ( your_augmentation ( data ))
loss = loss_func ( embeddings , augmented )
loss . backward ()
optimizer . step ()Jika Anda tertarik dengan persiapan diri bergaya MOCO, lihatlah MOCO di CIFAR10 Notebook. Ini menggunakan CrossBatchMemory untuk mengimplementasikan momentum encoder antrian, yang berarti Anda dapat menggunakan kehilangan tuple, dan penambang tuple untuk mengekstrak sampel keras dari antrian.
Jika Anda kekurangan waktu dan menginginkan alur kerja kereta/tes lengkap, lihat contoh notebook Google Colab.
Untuk mempelajari lebih lanjut tentang semua hal di atas, lihat dokumentasi.
pytorch-metric-learning >= v0.9.90 membutuhkan torch >= 1.6pytorch-metric-learning < v0.9.90 tidak memiliki persyaratan versi, tetapi diuji dengan torch >= 1.2 Ketergantungan Lainnya: numpy, scikit-learn, tqdm, torchvision
pip install pytorch-metric-learning
Untuk mendapatkan versi dev terbaru :
pip install pytorch-metric-learning --pre
Untuk menginstal di windows :
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning
Untuk menginstal dengan kemampuan evaluasi dan logging
(Ini akan menginstal versi PYPI tidak resmi dari FAISS-GPU, ditambah rekor-penjaga dan tensorboard):
pip install pytorch-metric-learning[with-hooks]
Untuk menginstal dengan kemampuan evaluasi dan logging (CPU)
(Ini akan menginstal versi PYPI tidak resmi dari FAISS-CPU, ditambah pemelihara rekaman dan Tensorboard):
pip install pytorch-metric-learning[with-hooks-cpu]
conda install -c conda-forge pytorch-metric-learning
Untuk menggunakan modul pengujian, Anda akan memerlukan FAISS, yang dapat dipasang melalui Conda juga. Lihat instruksi instalasi untuk FAISS.
Lihat Benchmarker yang kuat untuk melihat hasil benchmark dan menggunakan alat tolok ukur.
Pengembangan dilakukan di cabang dev :
git checkout dev
Tes unit dapat dijalankan dengan perpustakaan unittest default:
python -m unittest discoverAnda dapat menentukan tipe data dan perangkat uji sebagai variabel lingkungan. Misalnya, untuk menguji menggunakan float32 dan float64 di CPU:
TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discoverUntuk menjalankan file uji tunggal alih -alih seluruh suite tes, tentukan nama file:
python -m unittest tests/losses/test_angular_loss.py Kode diformat menggunakan black dan isort :
pip install black isort
./format_code.shTerima kasih kepada para kontributor yang membuat permintaan tarik!
| Penyumbang | Highlight |
|---|---|
| domenicomuscill0 | - Manifoldloss - P2Sgradloss - Histogramloss - DynamicSoftMarginLoss - RANKEDLISTLOSS |
| Mlopezantequera | - Membuat penguji mengerjakan kombinasi kueri dan set referensi apa pun - Made AccuracyCalculator berfungsi dengan perbandingan label sewenang -wenang |
| cwkeam | - SelfSupervisedLoss - Vicregloss - Menambahkan akurasi peringkat timbal balik rata -rata ke AccuracyCalculator - Baselosswrapper |
| IR2718 | - ThresholdConsistentMarginloss - Modul Dataset |
| marijnl | - Batcheasyhardminer - twostreammetricloss - GlobalTwostreamMBedDingsPacetester - Contoh Menggunakan Pelatih.TwostreamMetricloss |
| Chingisooinar | Subcenterarcfaceloss |
| Elias-Ramzi | HierarchicalSampler |
| fjsj | Supconloss |
| Alenubuntu | Circleloss |
| menarikzhuo | Pnploss |
| Wconnell | Mempelajari embedding metrik scrnaseq |
| mkmenta | get_all_triplets_indices yang ditingkatkan (memperbaiki kesalahan INT_MAX ) |
| Alexschuy | utils.loss_and_miner_utils.get_random_triplet_indices |
| Johngiorgi | all_gather di utils.Distributed |
| Hummer12007 | utils.key_checker |
| Vltanh | Dibuat dataset menerima InferenceModel.train_indexer |
| Btseytlin | get_nearest_neighbors di inferencemodel |
| MLW214 | Menambahkan return_per_class ke AccoreCyCalculator |
| Layumi | Instanceloss |
| Notody | Membantu menambahkan ref_emb dan ref_labels ke pembungkus yang didistribusikan. |
| Elisonsherton | Memperbaiki case tepi di Arcfaceloss. |
| STOMPSJO | Dokumentasi yang ditingkatkan untuk ntxentloss. |
| Puzer | Perbaikan bug untuk pnploss. |
| Elisim | Perbaikan pengembang untuk distributedlosswrapper. |
| Gaetanlepage | |
| Z1W | |
| Thinline72 | |
| tpanum | |
| Fralik | |
| Joaqo | |
| Jookuma | |
| Gkouros | |
| yutanakamura-tky | |
| Kinglittleq | |
| Martin0258 | |
| Michaeldeyzel | |
| Hsinger04 | |
| selesma | |
| Bot66 |
Terima kasih kepada Ser-Nam Lim di Facebook AI, dan penasihat penelitian saya, Profesor Serge milik. Proyek ini dimulai selama magang saya di Facebook AI di mana saya menerima umpan balik berharga dari Ser-Nam, dan tim visi komputer dan insinyur pembelajaran mesin serta ilmuwan penelitian. Secara khusus, terima kasih kepada Ashish Shah dan Austin Reiter untuk meninjau kode saya selama tahap awal pengembangan.
Perpustakaan ini berisi kode yang telah diadaptasi dan dimodifikasi dari repo open-source yang hebat berikut:
Terima kasih kepada Jeff Musgrave karena telah merancang logo.
Jika Anda ingin mengutip Pytorch-Metric-Learning di koran Anda, Anda dapat menggunakan Bibtex ini:
@article{Musgrave2020PyTorchML,
title={PyTorch Metric Learning},
author={Kevin Musgrave and Serge J. Belongie and Ser-Nam Lim},
journal={ArXiv},
year={2020},
volume={abs/2008.09164}
}