
12月11日:v2.8.0
11月2日:v2.7.0
請參閱示例文件夾以獲取可以在Google Colab上下載或運行的筆記本。
該庫包含9個模塊,每個模塊都可以在您現有的代碼庫中獨立使用,也可以合併在一起以進行完整的火車/測試工作流程。

讓我們初始化一個普通的TripletMarginloss:
from pytorch_metric_learning import losses
loss_func = losses . TripletMarginLoss ()要計算訓練循環中的損失,請傳遞模型計算的嵌入式以及相應的標籤。嵌入式應具有大小(n,embedding_size),標籤應具有尺寸(n),其中n是批處理大小。
# your training loop
for i , ( data , labels ) in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = model ( data )
loss = loss_func ( embeddings , labels )
loss . backward ()
optimizer . step ()TripletMarginloss根據您傳遞的標籤計算批處理內的所有可能的三重態。錨陽性對由共享同一標籤的嵌入形成形成,並且錨點陰性對由具有不同標籤的嵌入形成。
有時它可以幫助添加採礦功能:
from pytorch_metric_learning import miners , losses
miner = miners . MultiSimilarityMiner ()
loss_func = losses . TripletMarginLoss ()
# your training loop
for i , ( data , labels ) in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = model ( data )
hard_pairs = miner ( embeddings , labels )
loss = loss_func ( embeddings , labels , hard_pairs )
loss . backward ()
optimizer . step ()在上面的代碼中,礦工發現它認為特別困難的正面和負面對。請注意,即使TripletMarginloss在三胞胎上運行,仍然可以成對通過。這是因為在必要時,庫會自動將對轉換為三胞胎和三重態。
損失功能可以使用距離,還原和正規化器來定制。在下圖中,礦工在批處理中找到了硬對的索引。這些用於將距離矩陣索引到距離對象計算的距離矩陣。對於此圖,損耗函數是基於對的,因此它計算每對損耗。此外,還提供了正規化程序,因此計算批處理中的每個嵌入的正規化損失。人均和每元素損失傳遞給還原器,該還原器(在此圖中)僅將損失保持較高的值。計算高價值對和元素損失的平均值,然後將其添加在一起以獲得最終損失。

現在,這是自定義的TripletMarginloss的示例:
from pytorch_metric_learning . distances import CosineSimilarity
from pytorch_metric_learning . reducers import ThresholdReducer
from pytorch_metric_learning . regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses . TripletMarginLoss ( distance = CosineSimilarity (),
reducer = ThresholdReducer ( high = 0.3 ),
embedding_regularizer = LpRegularizer ())該自定義的三重態損失具有以下屬性:
為自我監督學習提供了一個SelfSupervisedLoss包裝器:
from pytorch_metric_learning . losses import SelfSupervisedLoss
loss_func = SelfSupervisedLoss ( TripletMarginLoss ())
# your training for-loop
for i , data in enumerate ( dataloader ):
optimizer . zero_grad ()
embeddings = your_model ( data )
augmented = your_model ( your_augmentation ( data ))
loss = loss_func ( embeddings , augmented )
loss . backward ()
optimizer . step ()如果您對Moco風格的自我設計感興趣,請查看CIFAR10筆記本上的Moco。它使用CrossBatchMemory來實現動量編碼器隊列,這意味著您可以使用任何元組損失,並且任何元組礦工從隊列中提取硬樣品。
如果您沒有時間並想要完整的火車/測試工作流程,請查看Google Colab筆記本示例。
要了解以上所有內容的更多信息,請參見文檔。
pytorch-metric-learning >= v0.9.90需要torch >= 1.6pytorch-metric-learning < v0.9.90沒有版本要求,但用torch >= 1.2其他依賴性: numpy, scikit-learn, tqdm, torchvision
pip install pytorch-metric-learning
要獲得最新的開發版本:
pip install pytorch-metric-learning --pre
在Windows上安裝:
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning
使用評估和記錄功能安裝
(這將安裝FAISS-GPU的非官方PYPI版本,以及記錄保存器和Tensorboard):
pip install pytorch-metric-learning[with-hooks]
使用評估和記錄功能(CPU)安裝
(這將安裝Faiss-CPU的非官方PYPI版本,以及記錄保存器和張板):
pip install pytorch-metric-learning[with-hooks-cpu]
conda install -c conda-forge pytorch-metric-learning
要使用測試模塊,您將需要FAISS,也可以通過Conda安裝。請參閱FAISS的安裝說明。
請參閱功能強大的基準測試器以查看基準結果並使用基準測試工具。
開發部門在dev部門進行:
git checkout dev
可以使用默認的unittest庫進行單元測試:
python -m unittest discover您可以將測試數據類型和測試設備指定為環境變量。例如,在CPU上使用Float32和Float64進行測試:
TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discover要運行一個測試文件,而不是整個測試套件,請指定文件名:
python -m unittest tests/losses/test_angular_loss.py代碼是使用black和isort格式化的:
pip install black isort
./format_code.sh感謝提出拉的貢獻者!
| 貢獻者 | 亮點 |
|---|---|
| Domenicomuscill0 | - 歧管 -p2sgradloss - 直方圖 -DynamicsOftMarginloss - 級別 |
| mlopezantequera | - 使測試人員可以解決查詢和參考集的任何組合 - 使精確度量表與任意標籤比較一起工作 |
| cwkeam | - 自我求婚 - vicregloss - 添加了精確度計算器的平均相互等級精度 - BaselossWrapper |
| IR2718 | - 閾值consistentmarginloss - 數據集模塊 |
| Marijnl | - batcheasyhardminer - 兩座米洛斯 - GlobalTwoStreameMbeddingsPacetester - 使用Trainers.twostreammetricloss示例 |
| chingisooinar | 子centerarcfaceloss |
| Elias-Ramzi | 層次採樣器 |
| fjsj | supconloss |
| Alenubuntu | Circleloss |
| 有趣的祖 | pnploss |
| Wconnell | 學習scrnaseq公制嵌入 |
| Mkmenta | 改進的get_all_triplets_indices (修復了INT_MAX錯誤) |
| Alexschuy | 優化的utils.loss_and_miner_utils.get_random_triplet_indices |
| 約翰吉奧吉 | utils.distribed中的all_gather |
| 悍馬12007 | utils.key_checker |
| Vltanh | 製作InferenceModel.train_indexer接受數據集 |
| btseytlin | get_nearest_neighbors中的推論emodel |
| MLW214 | 添加了return_per_class到精確度計算器 |
| Layumi | Instanceloss |
| 不可思議 | 幫助將ref_emb和ref_labels添加到分佈式包裝器中。 |
| Elisonsherton | 修復了Arcfaceloss中的邊緣外殼。 |
| Stompsjo | 改進了NTXENTLOSS的文檔。 |
| puzer | PNPLOSS的錯誤修復。 |
| Elisim | 開發人員改進了DistributeLossWrapper。 |
| Gaetanlepage | |
| Z1W | |
| Thinline72 | |
| tpanum | |
| 弗拉利克 | |
| joaqo | |
| Jookuma | |
| Gkouros | |
| yutanakamura-tky | |
| Kinglittleq | |
| MARTIN0258 | |
| Michaeldeyzel | |
| hsinger04 | |
| 感冒 | |
| BOT66 |
感謝您在Facebook AI的Ser-Nam Lim和我的研究顧問Serge屬於您的研究。該項目始於我在Facebook AI實習期間,我收到了Ser-Nam的寶貴反饋,他的計算機視覺和機器學習工程師和研究科學家團隊。特別是,感謝Ashish Shah和Austin Reiter在開發的早期階段審查了我的代碼。
該庫包含已從以下大型開源存儲庫進行調整和修改的代碼:
感謝Jeff Musgrave設計徽標。
如果您想在論文中引用Pytorch-metric-Learning,則可以使用此Bibtex:
@article{Musgrave2020PyTorchML,
title={PyTorch Metric Learning},
author={Kevin Musgrave and Serge J. Belongie and Ser-Nam Lim},
journal={ArXiv},
year={2020},
volume={abs/2008.09164}
}