OML ist ein in Pytorch basierender Rahmen, um die Modelle zu trainieren und zu validieren, die hochwertige Einbettungen erzeugen.
ㅤㅤ
Es gibt eine Reihe von Menschen von Oxford und HSE -Universitäten, die OML in ihren Thesen verwendet haben. [1] [2] [3]
Das Update konzentriert sich auf mehrere Komponenten:
Wir haben "offizielle" Texteunterstützung und die entsprechenden Python -Beispiele hinzugefügt. (Beachten Sie, dass die Support von Texten in Pipelines noch nicht unterstützt wird.)
Wir haben die RetrievalResults ( RR ) -Klasse vorgestellt - einen Container, der für bestimmte Abfragen abgerufen wurde. RR bietet eine einheitliche Möglichkeit, Vorhersagen zu visualisieren und Metriken zu berechnen (wenn die Grundwahrheiten bekannt sind). Es vereinfacht auch die Nachbearbeitung, wobei ein RR Objekt als Eingabe genommen wird und ein weiteres RR_upd als Ausgabe erzeugt wird. Wenn diese beiden Objekte visuell oder durch Metriken vergleichbarer Abrufergebnisse ermöglichen. Darüber hinaus können Sie problemlos eine Kette solcher Postprozessoren erstellen.
RR ist Speicher optimiert, weil sie Stapel verwenden: Mit anderen Worten, es speichert keine vollständige Matrix der Abfragegalleriestrecken. (Es macht die Suche jedoch ungefähr ungefähr). Wir haben Model und Dataset zu den einzigen Klassen gemacht, die für die Verarbeitung modalitätsspezifischer Logik verantwortlich sind. Model ist für die Interpretation seiner Eingangsdimensionen verantwortlich: Beispielsweise BxCxHxW für Bilder oder BxLxD für Sequenzen wie Texte. Dataset ist für die Vorbereitung eines Elements verantwortlich: Es kann Transforms für Bilder oder Tokenizer für Texte verwenden. Funktionen berechnen Metriken wie calc_retrieval_metrics_rr , RetrievalResults , PairwiseReranker und andere Klassen und Funktionen sind einheitlich, um mit jeder Modalität zu arbeiten.
IVisualizableDataset mit Methode .visaulize() hinzugefügt, die ein einzelnes Element anzeigt. Bei der Implementierung können RetrievalResults das Layout von abgerufenen Ergebnissen zeigen. Der einfachste Weg, Änderungen nachzuholen, besteht darin, die Beispiele erneut zu lesen!
Die empfohlene Validierungsmethode besteht darin, RetrievalResults und Funktionen wie calc_retrieval_metrics_rr , calc_fnmr_at_fmr_rr und andere zu verwenden. Die EmbeddingMetrics wird zur Verwendung mit Pytorch -Blitz und Innenleitungen aufbewahrt. Beachten Sie, dass die Signaturen der EmbeddingMetrics -Methoden leicht verändert wurden, siehe Blitzbeispiele dafür.
Da die modalitätsspezifische Logik auf Dataset beschränkt ist, gibt es nicht mehr PATHS_KEY , X1_KEY , X2_KEY , Y1_KEY und Y2_KEY aus. Tasten, die nicht modalitätsspezifisch sind wie LABELS_KEY , IS_GALLERY , IS_QUERY_KEY , CATEGORIES_KEY sind noch verwendet.
inference_on_images ist jetzt inference und funktioniert mit jeder Modalität.
Geringfügig veränderte Schnittstellen von Datasets. Zum Beispiel haben wir IQueryGalleryDataset und IQueryGalleryLabeledDataset -Schnittstellen. Der erste muss für die Inferenz verwendet werden, die zweite für die Validierung. Auch fügte IVisualizableDataset -Schnittstelle hinzu.
Einige Internale wie IMetricDDP , EmbeddingMetricsDDP , calc_distance_matrix , calc_gt_mask , calc_mask_to_ignore , apply_mask_to_ignore entfernt. Diese Änderungen sollten Sie nicht beeinflussen. Auch der Code, der sich auf eine Pipeline mit vorkundigten Tripletts bezieht.
Feature -Extraktion: Keine Änderungen, außer dass ein optionales Argument hinzugefügt wird - mode_for_checkpointing = (min | max) . Es kann nützlich sein, zwischen den unteren, desto besseren und größeren, desto besseren Metriken zu wechseln.
Paarweise postprozessierende Pipeline: Veränderte den Namen und die Argumente der postprocessor Subkonfiguration leicht- pairwise_images ist jetzt pairwise_reranker und benötigt keine Transformationen.
Sie mögen denken , "Wenn ich Bildeinbettungen brauche, kann ich einfach einen Vanilleklassifikator trainieren und seine vorletzte Schicht nehmen" . Nun, es macht Sinn als Ausgangspunkt. Es gibt jedoch mehrere mögliche Nachteile:
Wenn Sie Emetten verwenden möchten, um die Suche durchzuführen, müssen Sie einen gewissen Abstand zwischen ihnen berechnen (z. B. Cosinus oder L2). Normalerweise optimieren Sie diese Entfernungen während des Trainings im Klassifizierungsaufbau nicht direkt . Sie können also nur hoffen, dass endgültige Einbettungen die gewünschten Eigenschaften haben.
Das zweite Problem ist der Validierungsprozess . In der Such-Setup kümmern Sie sich normalerweise darum, wie mit Ihren Top-N-Ausgängen mit der Abfrage zusammenhängt. Die natürliche Möglichkeit, das Modell zu bewerten, besteht darin, die Suchanforderungen in den Referenzsatz zu simulieren und eine der Abrufmetriken anzuwenden. Es gibt also keine Garantie dafür, dass die Klassifizierungsgenauigkeit mit diesen Metriken korreliert.
Schließlich möchten Sie möglicherweise selbst eine metrische Lernpipeline implementieren. Es gibt eine Menge Arbeit : Um Triplett -Verlust zu verwenden, müssen Sie Stapel auf eine bestimmte Weise bilden, verschiedene Arten von Tripletts -Mining, Verfolgung von Entfernungen usw. implementieren. Für die Validierung müssen Sie auch Abrufmetriken implementieren, einschließlich einer effektiven Einbettungsakkumulation während der Epoche, die Deckung von Eckfällen usw. Es ist noch schwer, wenn Sie mehrere GPUs verwenden und DDP verwenden. Möglicherweise möchten Sie auch Ihre Suchanfragen visualisieren, indem Sie gute und schlechte Suchergebnisse hervorheben. Anstatt es selbst zu tun, können Sie OML einfach für Ihre Zwecke verwenden.
PML ist die beliebte Bibliothek für metrisches Lernen und enthält eine reichhaltige Sammlung von Verlusten, Bergleuten, Entfernungen und Reduzierern. Deshalb liefern wir einfache Beispiele für die Verwendung von OML. Anfangs haben wir versucht, PML zu verwenden, aber am Ende haben wir unsere Bibliothek entwickelt, die mehr Pipeline / Rezepte ausgerichtet ist. So unterscheidet sich OML von PML:
OML verfügt über Pipelines, mit denen Trainingsmodelle durch die Erstellung einer Konfiguration und Ihre Daten im erforderlichen Format vorbereitet werden können (es ist wie das Konvertieren von Daten in das CoCo -Format, um einen Detektor aus mmdetektion zu schulen).
OML konzentriert sich auf End-to-End-Pipelines und praktische Anwendungsfälle. Es verfügt über konfigurierte Beispiele zu beliebten Benchmarks in der Nähe des wirklichen Lebens (wie Fotos von Produkten von Tausend -IDs). Wir fanden einige gute Kombinationen von Hyperparametern in diesen Datensätzen, geschulten und veröffentlichten Modellen und deren Konfigurationen. Daher macht OML mehr Rezepte orientiert als PML, und sein Autor bestätigt, dass seine Bibliothek eine Reihe von Tools ist, sondern die Rezepte. Darüber hinaus sind die Beispiele in PML hauptsächlich für CIFAR- und MNIST -Datensätze gelten.
OML hat den Zoo von vorbereiteten Modellen, auf die leicht aus dem Code zugegriffen werden kann wie in torchvision (wenn Sie resnet50(pretrained=True) eingeben).
OML ist in Pytorch Lightning integriert, sodass wir die Kraft seines Trainers nutzen können. Dies ist besonders hilfreich, wenn wir mit DDP arbeiten. Sie vergleichen also unser DDP -Beispiel und das PMLS -Beispiel. Übrigens hat PML auch Trainer, aber es wird in den Beispielen nicht weit verbreitet, und stattdessen werden benutzerdefinierte train / test verwendet.
Wir glauben, dass Pipelines, lakonische Beispiele und Zoo von vorbereiteten Modellen den Eingangsschwellenwert auf einen sehr niedrigen Wert setzen.
Metrisches Lernproblem (auch als extremes Klassifizierungsproblem bezeichnet) bedeutet eine Situation, in der wir Tausende von IDs einiger Entitäten haben, aber nur wenige Stichproben für jede Entität. Oft gehen wir davon aus, dass wir während der Teststufe (oder Produktion) mit unsichtbaren Unternehmen zu tun haben, was es unmöglich macht, die Vanilleklassifizierungspipeline direkt anzuwenden. In vielen Fällen werden Einbettungen verwendet, um Such- oder Übereinstimmungsverfahren darüber durchzuführen.
Hier sind einige Beispiele für solche Aufgaben aus der Computer Vision Sphere:
embedding - Modellausgabe (auch als features vector oder descriptor bezeichnet).query - Eine Probe, die als Anforderung im Abrufverfahren verwendet wird.gallery set - Der Satz von Entitäten, um Elemente ähnlich wie query zu durchsuchen (auch als reference oder index bezeichnet).Sampler - Ein Argument für DataLoader , mit dem Chargen gebildet werdenMiner - Das Objekt zur Bildung von Paaren oder Drillingen, nachdem die Charge durch Sampler gebildet wurde. Es ist nicht notwendig, die Kombinationen von Proben nur innerhalb der aktuellen Charge zu bilden, daher kann die Speicherbank Teil des Miner sein.Samples / Labels / Instances - In einem Beispiel betrachten wir den Deepfashion -Datensatz. Es enthält Tausende von Modeartikel -IDs (wir nennen sie labels ) und mehrere Fotos für jede Element -ID (wir nennen das individuelle Foto als instance oder sample ). Alle IDs von Fashion Items haben ihre Gruppen wie "Röcke", "Jacken", "Shorts" und so weiter (wir nennen sie categories ). Beachten Sie, dass wir es vermeiden, den Begriff class zu verwenden, um Missverständnisse zu vermeiden.training epoch - Stapel -Sampler, die wir für kombinationsbasierte Verluste verwenden, haben normalerweise eine Länge [number of labels in training dataset] / [numbers of labels in one batch] . Dies bedeutet, dass wir nicht alle verfügbaren Trainingsproben in einer Epoche beobachten (im Gegensatz zur Vanilleklassifizierung), stattdessen beobachten wir alle verfügbaren Etiketten.Es kann mit den aktuellen (2022-Jahres-) SOTA-Methoden vergleichbar sein, z. B. Hyp-vit. (Nur wenige Wörter zu diesem Ansatz: Es handelt sich um eine VIT-Architektur, die mit kontrastivem Verlust trainiert wurde, aber die Einbettungen wurden in einen hyperbolischen Raum projiziert. Wie die Autoren behaupteten, kann ein solcher Raum die verschachtelte Struktur von Daten realer Welt beschreiben. Das Papier erfordert eine starke Mathematik, um die üblichen Operationen für den hyperbolischen Raum anzupassen.))
Wir haben dieselbe Architektur mit Triplettverlust geschult und den Rest der Parameter repariert: Trainings- und Testtransformationen, Bildgröße und Optimierer. Siehe Konfigurationen im Modelle Zoo. Der Trick war die Heuristik in unserem Bergmann und Sampler:
Die Kategorie -Balance -Stichprobentler bildet die Chargen, die die Anzahl der darin enthaltenen Kategorien einschränken. Zum Beispiel, wenn C = 1 nur Jacken in eine Charge und nur Jeans in eine andere steckt (nur ein Beispiel). Es macht die negativen Paare automatisch schwieriger: Es ist für ein Modell aussagekräftiger zu erkennen, warum zwei Jacken anders sind, als dasselbe über eine Jacke und ein T-Shirt zu verstehen.
Hard Tripletts Miner macht die Aufgabe noch schwieriger, nur die schwierigsten Tripletts (mit maximalen positiven und minimalen negativen Entfernungen) zu halten.
Hier sind CMC@1 -Ergebnisse für 2 beliebte Benchmarks. SOP-Datensatz: Hyp-vit-85,9, unsere-86,6. DeepFashion-Datensatz: Hyp-vit-92,5, unsere-92,1. Durch die Verwendung einfacher Heuristiken und der Vermeidung schwerer Mathematik können wir auf SOTA -Ebene durchführen.
Neuere Forschungen in SSL erzielten definitiv großartige Ergebnisse. Das Problem ist, dass diese Ansätze eine enorme Computermenge benötigten, um das Modell zu trainieren. In unserem Framework betrachten wir jedoch den häufigsten Fall, wenn der durchschnittliche Benutzer nicht mehr als ein paar GPUs hat.
Gleichzeitig wäre es unklug, den Erfolg in diesem Bereich zu ignorieren, also nutzen wir ihn immer noch auf zwei Arten:
Nein, du nicht. OML ist ein Framework-Agnostic. Obwohl wir Pytorch Lightning als Loop -Läufer für die Experimente verwenden, behalten wir auch die Möglichkeit, alles auf reinem Pytorch zu betreiben. Daher ist nur der winzige Teil von OML blitzspezifisch und wir halten diese Logik getrennt von einem anderen Code (siehe oml.lightning ). Selbst wenn Sie Blitz verwenden, müssen Sie es nicht wissen, da wir bereit sind, Pipelines zu verwenden.
Die Möglichkeit, reine Pytorch und modulare Struktur des Codes zu verwenden, lässt nach der Implementierung der notwendigen Wrapper einen Raum für die Verwendung von OML mit Ihrem bevorzugten Framework.
Ja. Um das Experiment mit Pipelines durchzuführen, müssen Sie nur einen Konverter in unser Format schreiben (es bedeutet, die .csv -Tabelle mit einigen vordefinierten Spalten vorzubereiten). Das war's!
Wahrscheinlich haben wir bereits ein geeignetes vorgebildetes Modell für Ihre Domäne in unserem Models Zoo . In diesem Fall müssen Sie es nicht einmal trainieren.
Derzeit unterstützen wir den Exportieren von Modellen nach ONNX nicht direkt. Sie können jedoch die integrierten Pytorch-Funktionen verwenden, um dies zu erreichen. Weitere Informationen finden Sie in diesem Problem.
DOKUMENTATION
Tutorial zu Beginn mit: Englisch | Russisch | chinesisch
Die Demo für unsere Arbeit aufträgt: Siamese -Transformatoren für die Nachbearbeitung im Bildabruf
Treffen Sie OpenMetricLearning (OML) auf MarktechPost
Der Bericht für das Berliner Meetup: "Computer Vision in der Produktion". November 2022. Link
pip install -U open-metric-learning ; # minimum dependencies
pip install -U open-metric-learning[nlp]
pip install -U open-metric-learning[audio]docker pull omlteam/oml:gpu
docker pull omlteam/oml:cpu Verluste | Bergleute miner = AllTripletsMiner ()
miner = NHardTripletsMiner ()
miner = MinerWithBank ()
...
criterion = TripletLossWithMiner ( 0.1 , miner )
criterion = ArcFaceLoss ()
criterion = SurrogatePrecision () | Sampler labels = train . get_labels ()
l2c = train . get_label2category ()
sampler = BalanceSampler ( labels )
sampler = CategoryBalanceSampler ( labels , l2c )
sampler = DistinctCategoryBalanceSampler ( labels , l2c ) |
Konfiguration Support max_epochs : 10
sampler :
name : balance
args :
n_labels : 2
n_instances : 2 | Vorausgebildete Modelle model_hf = AutoModel . from_pretrained ( "roberta-base" )
tokenizer = AutoTokenizer . from_pretrained ( "roberta-base" )
extractor_txt = HFWrapper ( model_hf )
extractor_img = ViTExtractor . from_pretrained ( "vits16_dino" )
transforms , _ = get_transforms_for_pretrained ( "vits16_dino" ) |
Nachbearbeitung emb = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( emb , dataset )
postprocessor = AdaptiveThresholding ()
rr_upd = postprocessor . process ( rr , dataset ) | Nachbearbeitung durch NN | Papier embeddings = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
postprocessor = PairwiseReranker ( ConcatSiamese (), top_n = 3 )
rr_upd = postprocessor . process ( rr , dataset ) |
Protokollierung logger = TensorBoardPipelineLogger ()
logger = NeptunePipelineLogger ()
logger = WandBPipelineLogger ()
logger = MLFlowPipelineLogger ()
logger = ClearMLPipelineLogger () | PML from pytorch_metric_learning import losses
criterion = losses . TripletMarginLoss ( 0.2 , "all" )
pred = ViTExtractor ()( data )
criterion ( pred , gts ) |
Kategorienunterstützung # train
loader = DataLoader ( CategoryBalanceSampler ())
# validation
rr = RetrievalResults . from_embeddings ()
m . calc_retrieval_metrics_rr ( rr , query_categories ) | Metriken embeddigs = inference ( model , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
m . calc_retrieval_metrics_rr ( rr , precision_top_k = ( 5 ,))
m . calc_fnmr_at_fmr_rr ( rr , fmr_vals = ( 0.1 ,))
m . calc_topological_metrics ( embeddings , pcf_variance = ( 0.5 ,)) |
Blitz import pytorch_lightning as pl
model = ViTExtractor . from_pretrained ( "vits16_dino" )
clb = MetricValCallback ( EmbeddingMetrics ( dataset ))
module = ExtractorModule ( model , criterion , optimizer )
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ])
trainer . fit ( module , train_loader , val_loader ) | Lightning DDP clb = MetricValCallback ( EmbeddingMetrics ( val ))
module = ExtractorModuleDDP (
model , criterion , optimizer , train , val
)
ddp = { "devices" : 2 , "strategy" : DDPStrategy ()}
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ], ** ddp )
trainer . fit ( module ) |
Hier ist ein Beispiel dafür, wie Sie das Modell in einem winzigen Datensatz mit Bildern oder Texten trainieren, validieren und nachbearbeiten können. Weitere Details zum Datensatzformat finden Sie.
| Bilder | Texte |
from torch . optim import Adam
from torch . utils . data import DataLoader
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_images_dataset
model = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" ). train ()
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
df_train , df_val = get_mock_images_dataset ( global_paths = True )
train = d . ImageLabeledDataset ( df_train , transform = transform )
val = d . ImageQueryGalleryLabeledDataset ( df_val , transform = transform )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () | from torch . optim import Adam
from torch . utils . data import DataLoader
from transformers import AutoModel , AutoTokenizer
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import HFWrapper
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_texts_dataset
model = HFWrapper ( AutoModel . from_pretrained ( "bert-base-uncased" ), 768 ). to ( "cpu" ). train ()
tokenizer = AutoTokenizer . from_pretrained ( "bert-base-uncased" )
df_train , df_val = get_mock_texts_dataset ()
train = d . TextLabeledDataset ( df_train , tokenizer = tokenizer )
val = d . TextQueryGalleryLabeledDataset ( df_val , tokenizer = tokenizer )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () |
Ausgabe{ 'active_tri' : 0.125 , 'pos_dist' : 82.5 , 'neg_dist' : 100.5 } # batch 1
{ 'active_tri' : 0.0 , 'pos_dist' : 36.3 , 'neg_dist' : 56.9 } # batch 2
{ 'cmc' : { 1 : 0.75 }, 'precision' : { 5 : 0.75 }, 'map' : { 3 : 0.8 }} | Ausgabe{ 'active_tri' : 0.0 , 'pos_dist' : 8.5 , 'neg_dist' : 11.0 } # batch 1
{ 'active_tri' : 0.25 , 'pos_dist' : 8.9 , 'neg_dist' : 9.8 } # batch 2
{ 'cmc' : { 1 : 0.8 }, 'precision' : { 5 : 0.7 }, 'map' : { 3 : 0.9 }} |
Zusätzliche Abbildungen, Erklärungen und Tipps für den obigen Code.
Hier ist ein Beispiel für Inferenzzeit (mit anderen Worten, Abrufen am Testsatz). Der folgende Code funktioniert sowohl für Texte als auch für Bilder.
from oml . datasets import ImageQueryGalleryDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . utils import get_mock_images_dataset
from oml . retrieval import RetrievalResults , AdaptiveThresholding
_ , df_test = get_mock_images_dataset ( global_paths = True )
del df_test [ "label" ] # we don't need gt labels for doing predictions
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
dataset = ImageQueryGalleryDataset ( df_test , transform = transform )
embeddings = inference ( extractor , dataset , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , dataset , n_items = 5 )
rr = AdaptiveThresholding ( n_std = 3.5 ). process ( rr )
rr . visualize ( query_ids = [ 0 , 1 ], dataset = dataset , show = True )
# you get the ids of retrieved items and the corresponding distances
print ( rr )Hier ist ein Beispiel, in dem Abfragen und Galerien getrennt verarbeitet wurden.
import pandas as pd
from oml . datasets import ImageBaseDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , ConstantThresholding
from oml . utils import get_mock_images_dataset
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
paths = pd . concat ( get_mock_images_dataset ( global_paths = True ))[ "path" ]
galleries , queries1 , queries2 = paths [: 20 ], paths [ 20 : 22 ], paths [ 22 : 24 ]
# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset ( galleries , transform = transform )
embeddings_gallery = inference ( extractor , dataset_gallery , batch_size = 4 , num_workers = 0 )
# queries come "online" in stream
for queries in [ queries1 , queries2 ]:
dataset_query = ImageBaseDataset ( queries , transform = transform )
embeddings_query = inference ( extractor , dataset_query , batch_size = 4 , num_workers = 0 )
# for the operation below we are going to provide integrations with vector search DB like QDrant or Faiss
rr = RetrievalResults . from_embeddings_qg (
embeddings_query = embeddings_query , embeddings_gallery = embeddings_gallery ,
dataset_query = dataset_query , dataset_gallery = dataset_gallery
)
rr = ConstantThresholding ( th = 80 ). process ( rr )
rr . visualize_qg ([ 0 , 1 ], dataset_query = dataset_query , dataset_gallery = dataset_gallery , show = True )
print ( rr )Pipelines bieten eine Möglichkeit, metrische Lernexperimente durch Ändern der Konfigurationsdatei auszuführen. Alles, was Sie brauchen, ist, Ihren Datensatz in einem erforderlichen Format vorzubereiten.
Weitere Informationen finden Sie unter den Ordner Pipelines -Ordner:
Hier finden Sie eine leichte Integration in Modelle von Huggingface -Transformatoren. Sie können es durch andere willkürliche Modelle ersetzen, die von Iextractor geerbt wurden.
Beachten Sie, dass wir momentan keinen eigenen Textmodelle Zoo haben.
pip install open-metric-learning[nlp] from transformers import AutoModel , AutoTokenizer
from oml . models import HFWrapper
model = AutoModel . from_pretrained ( 'bert-base-uncased' ). eval ()
tokenizer = AutoTokenizer . from_pretrained ( 'bert-base-uncased' )
extractor = HFWrapper ( model = model , feat_dim = 768 )
inp = tokenizer ( text = "Hello world" , return_tensors = "pt" , add_special_tokens = True )
embeddings = extractor ( inp )Sie können ein Bildmodell aus unserem Zoo verwenden oder andere willkürliche Modelle verwenden, nachdem Sie es von Iextractor geerbt haben.
from oml . const import CKPT_SAVE_ROOT as CKPT_DIR , MOCK_DATASET_PATH as DATA_DIR
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
model = ViTExtractor . from_pretrained ( "vits16_dino" ). eval ()
transforms , im_reader = get_transforms_for_pretrained ( "vits16_dino" )
img = im_reader ( DATA_DIR / "images" / "circle_1.jpg" ) # put path to your image here
img_tensor = transforms ( img )
# img_tensor = transforms(image=img)["image"] # for transforms from Albumentations
features = model ( img_tensor . unsqueeze ( 0 ))
# Check other available models:
print ( list ( ViTExtractor . pretrained_models . keys ()))
# Load checkpoint saved on a disk:
model_ = ViTExtractor ( weights = CKPT_DIR / "vits16_dino.ckpt" , arch = "vits16" , normalise_features = False )Models, von uns ausgebildet. Die folgenden Metriken sind für 224 x 224 Bilder:
| Modell | CMC1 | Datensatz | Gewichte | Experiment |
|---|---|---|---|---|
ViTExtractor.from_pretrained("vits16_inshop") | 0,921 | Deepfashion Inshop | Link | Link |
ViTExtractor.from_pretrained("vits16_sop") | 0,866 | Stanford Online -Produkte | Link | Link |
ViTExtractor.from_pretrained("vits16_cars") | 0,907 | Autos 196 | Link | Link |
ViTExtractor.from_pretrained("vits16_cub") | 0,837 | Cub 200 2011 | Link | Link |
Modelle, geschult von anderen Forschern. Beachten Sie, dass einige Metriken zu bestimmten Benchmarks so hoch sind, weil sie Teil des Trainingsdatensatzes waren (z. B. unicom ). Die folgenden Metriken sind für 224 x 224 Bilder:
| Modell | Stanford Online -Produkte | Deepfashion Inshop | Cub 200 2011 | Autos 196 |
|---|---|---|---|---|
ViTUnicomExtractor.from_pretrained("vitb16_unicom") | 0,700 | 0,734 | 0,847 | 0,916 |
ViTUnicomExtractor.from_pretrained("vitb32_unicom") | 0,690 | 0,722 | 0,796 | 0,893 |
ViTUnicomExtractor.from_pretrained("vitl14_unicom") | 0,726 | 0,790 | 0,868 | 0,922 |
ViTUnicomExtractor.from_pretrained("vitl14_336px_unicom") | 0,745 | 0,810 | 0,875 | 0,924 |
ViTCLIPExtractor.from_pretrained("sber_vitb32_224") | 0,547 | 0,514 | 0,448 | 0,618 |
ViTCLIPExtractor.from_pretrained("sber_vitb16_224") | 0,565 | 0,565 | 0,524 | 0,648 |
ViTCLIPExtractor.from_pretrained("sber_vitl14_224") | 0,512 | 0,555 | 0,606 | 0,707 |
ViTCLIPExtractor.from_pretrained("openai_vitb32_224") | 0,612 | 0,491 | 0,560 | 0,693 |
ViTCLIPExtractor.from_pretrained("openai_vitb16_224") | 0,648 | 0,606 | 0,665 | 0,767 |
ViTCLIPExtractor.from_pretrained("openai_vitl14_224") | 0,670 | 0,675 | 0,745 | 0,844 |
ViTExtractor.from_pretrained("vits16_dino") | 0,648 | 0,509 | 0,627 | 0,265 |
ViTExtractor.from_pretrained("vits8_dino") | 0,651 | 0,524 | 0,661 | 0,315 |
ViTExtractor.from_pretrained("vitb16_dino") | 0,658 | 0,514 | 0,541 | 0,288 |
ViTExtractor.from_pretrained("vitb8_dino") | 0,689 | 0,599 | 0,506 | 0,313 |
ViTExtractor.from_pretrained("vits14_dinov2") | 0,566 | 0,334 | 0,797 | 0,503 |
ViTExtractor.from_pretrained("vits14_reg_dinov2") | 0,566 | 0,332 | 0,795 | 0,740 |
ViTExtractor.from_pretrained("vitb14_dinov2") | 0,565 | 0,342 | 0,842 | 0,644 |
ViTExtractor.from_pretrained("vitb14_reg_dinov2") | 0,557 | 0,324 | 0,833 | 0,828 |
ViTExtractor.from_pretrained("vitl14_dinov2") | 0,576 | 0,352 | 0,844 | 0,692 |
ViTExtractor.from_pretrained("vitl14_reg_dinov2") | 0,571 | 0,340 | 0,840 | 0,871 |
ResnetExtractor.from_pretrained("resnet50_moco_v2") | 0,493 | 0,267 | 0,264 | 0,149 |
ResnetExtractor.from_pretrained("resnet50_imagenet1k_v1") | 0,515 | 0,284 | 0,455 | 0,247 |
Die Metriken können sich von den von Papieren gemeldeten unterscheiden, da sich die Version von Zug/VAL -Split und Verwendung von Begrenzungsboxen unterscheiden kann.
Wir begrüßen neue Mitwirkende! Bitte sehen Sie unsere:
Das Projekt wurde 2020 als Modul für die Catalyst Library gestartet. Ich möchte Leuten danken, die mit mir an diesem Modul gearbeitet haben: Julia Shenhina, Nikita Balagansky, Sergey Kolesnikov und andere.
Ich möchte mich bei Leuten bedanken, die weiterhin an dieser Pipeline arbeiten, als sie ein separates Projekt wurde: Julia Shenhina, Misha Kindulov, Aron Dik, Aleksei Tarasov und Verkhovtsev Leonid.
Ich möchte auch Newyorker danken, da der Teil der Funktionalität von seinem Computer Vision -Team entwickelt (und verwendet) wurde.