OML es un marco basado en Pytorch para entrenar y validar los modelos que producen incrustaciones de alta calidad.
ㅤㅤ
Hay varias personas de las universidades de Oxford y HSE que han usado OML en sus tesis. [1] [2] [3]
La actualización se centra en varios componentes:
Agregamos el soporte de textos "oficiales" y los ejemplos de Python correspondientes. (Tenga en cuenta que el soporte de textos en tuberías aún no está compatible).
Introdujimos la clase RetrievalResults ( RR ), un contenedor para almacenar artículos de la galería recuperados para consultas dadas. RR proporciona una forma unificada de visualizar las predicciones y calcular métricas (si se conocen las verdades terrestres). También simplifica el postprocesamiento, donde un objeto RR se toma como entrada y otro RR_upd se produce como salida. Tener estos dos objetos permite los resultados de recuperación de comparación visualmente o por métricas. Además, puede crear fácilmente una cadena de tales postprocesadores.
RR está optimizado para la memoria debido al uso de lotes: en otras palabras, no almacena una matriz completa de distancias de consulta. (Sin embargo, no hace que la búsqueda sea aproximada). Hicimos Model y Dataset las únicas clases responsables del procesamiento de la lógica específica de modalidad. Model es responsable de interpretar sus dimensiones de entrada: por ejemplo, BxCxHxW para imágenes o BxLxD para secuencias como textos. Dataset es responsable de preparar un elemento: puede usar Transforms para imágenes o Tokenizer para textos. Funciones que calculan las métricas como calc_retrieval_metrics_rr , RetrievalResults , PairwiseReranker y otras clases y funciones están unificadas para trabajar con cualquier modalidad.
IVisualizableDataset con método .visaulize() que muestra un solo elemento. Si se implementa, RetrievalResults puede mostrar el diseño de los resultados recuperados. ¡La forma más fácil de ponerse al día con los cambios es volver a leer los ejemplos!
La forma de validación recomendada es utilizar RetrievalResults y funciones de recuperación como calc_retrieval_metrics_rr , calc_fnmr_at_fmr_rr y otros. La clase EmbeddingMetrics se mantiene para su uso con rayos de Pytorch y tuberías internas. Tenga en cuenta que las firmas de los métodos EmbeddingMetrics se han cambiado ligeramente, ver ejemplos de rayos para eso.
Dado que la lógica específica de modalidad se limita al Dataset , ya no emite PATHS_KEY , X1_KEY , X2_KEY , Y1_KEY y Y2_KEY . Claves que no son específicas de la modalidad como LABELS_KEY , IS_GALLERY , IS_QUERY_KEY , CATEGORIES_KEY todavía están en uso.
inference_on_images ahora es inference y funciona con cualquier modalidad.
Interfaces ligeramente cambiadas de Datasets. Por ejemplo, tenemos interfaces IQueryGalleryDataset e IQueryGalleryLabeledDataset . El primero tiene que usarse para la inferencia, el segundo para la validación. También se agregó la interfaz IVisualizableDataset .
Eliminaron algunas partes internas como IMetricDDP , EmbeddingMetricsDDP , calc_distance_matrix , calc_gt_mask , calc_mask_to_ignore , apply_mask_to_ignore . Estos cambios no deberían afectarte. También se eliminó el código relacionado con una tubería con trillizos precomputados.
Extracción de características: no hay cambios, excepto para agregar un argumento opcional - mode_for_checkpointing = (min | max) . Puede ser útil cambiar entre el tipo más bajo, mejor y mayor, mejor tipo de métricas.
Tubería de poste-poste de pares: cambió ligeramente el nombre y los argumentos del postprocessor Sub Config: pairwise_images ahora es pairwise_reranker y no necesita transformaciones.
Puede pensar "si necesito incrustaciones de imágenes, simplemente puedo entrenar un clasificador de vainilla y tomar su penúltima capa" . Bueno, tiene sentido como punto de partida. Pero hay varios inconvenientes posibles:
Si desea usar incrustaciones para realizar la búsqueda, necesita calcular cierta distancia entre ellos (por ejemplo, coseno o L2). Por lo general, no optimiza directamente estas distancias durante la capacitación en la configuración de clasificación. Por lo tanto, solo puede esperar que las incrustaciones finales tengan las propiedades deseadas.
El segundo problema es el proceso de validación . En la configuración de búsqueda, generalmente le importa qué tan relacionadas estén sus resultados TOP-N con la consulta. La forma natural de evaluar el modelo es simular las solicitudes de búsqueda al conjunto de referencias y aplicar una de las métricas de recuperación. Por lo tanto, no hay garantía de que la precisión de clasificación se correlacione con estas métricas.
Finalmente, es posible que desee implementar una tubería de aprendizaje métrico usted mismo. Hay mucho trabajo : para usar la pérdida de triplete, debe formar lotes de una manera específica, implementar diferentes tipos de tripletes de minería, rastreo de distancias, etc. Para la validación, también debe implementar métricas de recuperación, que incluyen acumulación efectiva de integración durante la época, cubriendo los casos de esquina, etc. Es aún más resistente si tiene varios GPU y usa DDP. También es posible que desee visualizar sus solicitudes de búsqueda destacando los resultados de búsqueda buenos y malos. En lugar de hacerlo usted mismo, simplemente puede usar OML para sus propósitos.
PML es la biblioteca popular para el aprendizaje métrico, e incluye una rica colección de pérdidas, mineros, distancias y reductores; Es por eso que proporcionamos ejemplos directos de usarlos con OML. Inicialmente, tratamos de usar PML, pero al final, se nos ocurrió nuestra biblioteca, que está más orientada a la tubería / recetas. Así es como OML difiere de PML:
OML tiene tuberías que permiten modelos de entrenamiento preparando una configuración y sus datos en el formato requerido (es como convertir datos en formato de Coco para entrenar un detector de MMDetection).
OML se centra en las tuberías de extremo a extremo y los casos de uso práctico. Tiene ejemplos basados en configuración en puntos de referencia populares cerca de la vida real (como fotos de productos de miles de ID). Encontramos algunas buenas combinaciones de hiperparámetros en estos conjuntos de datos, modelos capacitados y publicados y sus configuraciones. Por lo tanto, hace que OML esté más orientada a las recetas que PML, y su autor confirma que esto dice que su biblioteca es un conjunto de herramientas más bien las recetas, además, los ejemplos en PML son principalmente para conjuntos de datos CIFAR y MNIST.
OML tiene el zoológico de modelos previos a la aparición a los que se puede acceder fácilmente desde el código de la misma manera que en torchvision (cuando escribe resnet50(pretrained=True) ).
OML está integrado con Pytorch Lightning, por lo que podemos usar el poder de su entrenador. Esto es especialmente útil cuando trabajamos con DDP, por lo tanto, compara nuestro ejemplo DDP y el PMLS. Por cierto, PML también tiene entrenadores, pero no se usa ampliamente en los ejemplos y se utilizan funciones train / test personalizadas.
Creemos que tener tuberías, ejemplos lacónicos y zoológico de modelos previos a los pretrados establece el umbral de entrada a un valor realmente bajo.
El problema de aprendizaje métrico (también conocido como problema de clasificación extrema ) significa una situación en la que tenemos miles de identificaciones de algunas entidades, pero solo unas pocas muestras para cada entidad. A menudo suponemos que durante la etapa de prueba (o producción) trataremos con entidades invisibles, lo que hace que sea imposible aplicar la tubería de clasificación de vainilla directamente. En muchos casos, los incrustaciones obtenidas se utilizan para realizar procedimientos de búsqueda o correspondencia sobre ellos.
Aquí hay algunos ejemplos de tales tareas de la esfera de la visión de la computadora:
embedding : salida del modelo (también conocida como features vector o descriptor de características).query : una muestra que se utiliza como solicitud en el procedimiento de recuperación.gallery set : el conjunto de entidades para buscar elementos similares a query (también conocido como reference o index ).Sampler : un argumento para DataLoader que se utiliza para formar lotesMiner : el objeto para formar pares o trillizos después de que el lote fue formado por Sampler . No es necesario formar las combinaciones de muestras solo dentro del lote actual, por lo tanto, el banco de memoria puede ser parte del Miner .Samples / Labels / Instances : como ejemplo, consideremos el conjunto de datos de la moda profunda. Incluye miles de identificaciones de elementos de moda (las nombramos labels ) y varias fotos para cada identificación del elemento (nombramos la foto individual como instance o sample ). Todas las identificaciones de artículos de moda tienen sus grupos como "faldas", "chaquetas", "cortos", etc. (los nombramos categories ). Tenga en cuenta que evitamos usar el término class para evitar malentendidos.training epoch : las muestras de lotes que utilizamos para las pérdidas basadas en combinaciones generalmente tienen una longitud igual a [number of labels in training dataset] / [numbers of labels in one batch] . Significa que no observamos todas las muestras de entrenamiento disponibles en una época (a diferencia de la clasificación de vainilla), en cambio, observamos todas las etiquetas disponibles.Puede ser comparable con los métodos SOTA actuales (2022 años), por ejemplo, HyP-VIT. (Pocas palabras sobre este enfoque: es una arquitectura VIT entrenada con pérdida contrastante, pero las incrustaciones se proyectaron en algún espacio hiperbólico. Como afirmaron los autores, dicho espacio puede describir la estructura anidada de los datos del mundo real. Entonces, el documento requiere algunas matemáticas pesadas para adaptar las operaciones habituales para el espacio hiperbolético).
Entrenamos la misma arquitectura con pérdida de triplete, fijando el resto de los parámetros: transformaciones de entrenamiento y prueba, tamaño de imagen y optimizador. Consulte Configuras en Models Zoo. El truco fue en heurística en nuestro minero y muestra:
Category Balance Sampler Forma los lotes que limitan el número de categorías C en él. Por ejemplo, cuando C = 1 solo pone chaquetas en un lote y solo jeans en otro (solo un ejemplo). Automáticamente hace que los pares negativos sean más difíciles: es más significativo que un modelo se dé cuenta de por qué dos chaquetas son diferentes a comprender lo mismo de una chaqueta y una camiseta.
Triplors duros minero hace que la tarea sea aún más difícil mantener solo los trillizos más duros (con distancias negativas positivas y mínimas máximas).
Aquí hay puntajes CMC@1 para 2 puntos de referencia populares. SOP DataSet: HyP-VIT-85.9, el nuestro-86.6. Conjunto de datos de DeepFashion: HyP-VIT-92.5, el nuestro-92.1. Por lo tanto, utilizando heurísticas simples y evitar matemáticas pesadas, podemos realizar a nivel SOTA.
Investigaciones recientes en SSL definitivamente obtuvieron excelentes resultados. El problema es que estos enfoques requirieron una enorme cantidad de computación para entrenar el modelo. Pero en nuestro marco, consideramos el caso más común cuando el usuario promedio no tiene más de unas pocas GPU.
Al mismo tiempo, sería imprudente ignorar el éxito en esta esfera, por lo que todavía lo explotamos de dos maneras:
No, no lo haces. OML es un marco-agnóstico. A pesar de que usamos Pytorch Lightning como corredor de bucle para los experimentos, también tenemos la posibilidad de ejecutar todo en Pytorch puro. Por lo tanto, solo la pequeña parte de OML es específica del rayo y mantenemos esta lógica por separado de otro código (ver oml.lightning ). Incluso cuando usa Lightning, no necesita saberlo, ya que proporcionamos tuberías listas para usar.
La posibilidad de usar Pytorch puro y estructura modular del código deja un espacio para utilizar OML con su marco favorito después de la implementación de los envoltorios necesarios.
Sí. Para ejecutar el experimento con tuberías, solo necesita escribir un convertidor en nuestro formato (significa preparar la tabla .csv con algunas columnas predefinidas). ¡Eso es todo!
Probablemente ya tenemos un modelo previamente capacitado adecuado para su dominio en nuestro zoológico de modelos . En este caso, ni siquiera necesita entrenarlo.
Actualmente, no admitimos exportar modelos a ONNX directamente. Sin embargo, puede usar las capacidades de Pytorch incorporadas para lograr esto. Para obtener más información, consulte este problema.
DOCUMENTACIÓN
Tutorial para comenzar con: inglés | Ruso | Chino
La demostración de nuestro documento revuelta: transformadores siameses para el posprocesamiento de recuperación de imágenes
Conozca OpenMetriclearning (OML) en MarkTechPost
El informe para la reunión con sede en Berlín: "Visión por computadora en producción". Noviembre de 2022. Enlace
pip install -U open-metric-learning ; # minimum dependencies
pip install -U open-metric-learning[nlp]
pip install -U open-metric-learning[audio]docker pull omlteam/oml:gpu
docker pull omlteam/oml:cpu Pérdidas | Mineros miner = AllTripletsMiner ()
miner = NHardTripletsMiner ()
miner = MinerWithBank ()
...
criterion = TripletLossWithMiner ( 0.1 , miner )
criterion = ArcFaceLoss ()
criterion = SurrogatePrecision () | Muestras labels = train . get_labels ()
l2c = train . get_label2category ()
sampler = BalanceSampler ( labels )
sampler = CategoryBalanceSampler ( labels , l2c )
sampler = DistinctCategoryBalanceSampler ( labels , l2c ) |
Configuración de soporte max_epochs : 10
sampler :
name : balance
args :
n_labels : 2
n_instances : 2 | Modelos previamente capacitados model_hf = AutoModel . from_pretrained ( "roberta-base" )
tokenizer = AutoTokenizer . from_pretrained ( "roberta-base" )
extractor_txt = HFWrapper ( model_hf )
extractor_img = ViTExtractor . from_pretrained ( "vits16_dino" )
transforms , _ = get_transforms_for_pretrained ( "vits16_dino" ) |
Postprocesamiento emb = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( emb , dataset )
postprocessor = AdaptiveThresholding ()
rr_upd = postprocessor . process ( rr , dataset ) | Postprocesamiento por NN | Papel embeddings = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
postprocessor = PairwiseReranker ( ConcatSiamese (), top_n = 3 )
rr_upd = postprocessor . process ( rr , dataset ) |
Explotación florestal logger = TensorBoardPipelineLogger ()
logger = NeptunePipelineLogger ()
logger = WandBPipelineLogger ()
logger = MLFlowPipelineLogger ()
logger = ClearMLPipelineLogger () | PML from pytorch_metric_learning import losses
criterion = losses . TripletMarginLoss ( 0.2 , "all" )
pred = ViTExtractor ()( data )
criterion ( pred , gts ) |
Las categorías apoyan # train
loader = DataLoader ( CategoryBalanceSampler ())
# validation
rr = RetrievalResults . from_embeddings ()
m . calc_retrieval_metrics_rr ( rr , query_categories ) | Métricas misceláneas embeddigs = inference ( model , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
m . calc_retrieval_metrics_rr ( rr , precision_top_k = ( 5 ,))
m . calc_fnmr_at_fmr_rr ( rr , fmr_vals = ( 0.1 ,))
m . calc_topological_metrics ( embeddings , pcf_variance = ( 0.5 ,)) |
Iluminación import pytorch_lightning as pl
model = ViTExtractor . from_pretrained ( "vits16_dino" )
clb = MetricValCallback ( EmbeddingMetrics ( dataset ))
module = ExtractorModule ( model , criterion , optimizer )
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ])
trainer . fit ( module , train_loader , val_loader ) | Lightning DDP clb = MetricValCallback ( EmbeddingMetrics ( val ))
module = ExtractorModuleDDP (
model , criterion , optimizer , train , val
)
ddp = { "devices" : 2 , "strategy" : DDPStrategy ()}
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ], ** ddp )
trainer . fit ( module ) |
Aquí hay un ejemplo de cómo entrenar, validar y procesar el modelo en un pequeño conjunto de datos de imágenes o textos. Vea más detalles sobre el formato de conjunto de datos.
| Imágenes | TEXTOS |
from torch . optim import Adam
from torch . utils . data import DataLoader
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_images_dataset
model = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" ). train ()
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
df_train , df_val = get_mock_images_dataset ( global_paths = True )
train = d . ImageLabeledDataset ( df_train , transform = transform )
val = d . ImageQueryGalleryLabeledDataset ( df_val , transform = transform )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () | from torch . optim import Adam
from torch . utils . data import DataLoader
from transformers import AutoModel , AutoTokenizer
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import HFWrapper
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_texts_dataset
model = HFWrapper ( AutoModel . from_pretrained ( "bert-base-uncased" ), 768 ). to ( "cpu" ). train ()
tokenizer = AutoTokenizer . from_pretrained ( "bert-base-uncased" )
df_train , df_val = get_mock_texts_dataset ()
train = d . TextLabeledDataset ( df_train , tokenizer = tokenizer )
val = d . TextQueryGalleryLabeledDataset ( df_val , tokenizer = tokenizer )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () |
Producción{ 'active_tri' : 0.125 , 'pos_dist' : 82.5 , 'neg_dist' : 100.5 } # batch 1
{ 'active_tri' : 0.0 , 'pos_dist' : 36.3 , 'neg_dist' : 56.9 } # batch 2
{ 'cmc' : { 1 : 0.75 }, 'precision' : { 5 : 0.75 }, 'map' : { 3 : 0.8 }} | Producción{ 'active_tri' : 0.0 , 'pos_dist' : 8.5 , 'neg_dist' : 11.0 } # batch 1
{ 'active_tri' : 0.25 , 'pos_dist' : 8.9 , 'neg_dist' : 9.8 } # batch 2
{ 'cmc' : { 1 : 0.8 }, 'precision' : { 5 : 0.7 }, 'map' : { 3 : 0.9 }} |
Ilustraciones, explicaciones y consejos adicionales para el código anterior.
Aquí hay un ejemplo de tiempo de inferencia (en otras palabras, recuperación en el conjunto de pruebas). El siguiente código funciona tanto para textos como para imágenes.
from oml . datasets import ImageQueryGalleryDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . utils import get_mock_images_dataset
from oml . retrieval import RetrievalResults , AdaptiveThresholding
_ , df_test = get_mock_images_dataset ( global_paths = True )
del df_test [ "label" ] # we don't need gt labels for doing predictions
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
dataset = ImageQueryGalleryDataset ( df_test , transform = transform )
embeddings = inference ( extractor , dataset , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , dataset , n_items = 5 )
rr = AdaptiveThresholding ( n_std = 3.5 ). process ( rr )
rr . visualize ( query_ids = [ 0 , 1 ], dataset = dataset , show = True )
# you get the ids of retrieved items and the corresponding distances
print ( rr )Aquí hay un ejemplo en el que las consultas y las galerías se procesaron por separado.
import pandas as pd
from oml . datasets import ImageBaseDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , ConstantThresholding
from oml . utils import get_mock_images_dataset
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
paths = pd . concat ( get_mock_images_dataset ( global_paths = True ))[ "path" ]
galleries , queries1 , queries2 = paths [: 20 ], paths [ 20 : 22 ], paths [ 22 : 24 ]
# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset ( galleries , transform = transform )
embeddings_gallery = inference ( extractor , dataset_gallery , batch_size = 4 , num_workers = 0 )
# queries come "online" in stream
for queries in [ queries1 , queries2 ]:
dataset_query = ImageBaseDataset ( queries , transform = transform )
embeddings_query = inference ( extractor , dataset_query , batch_size = 4 , num_workers = 0 )
# for the operation below we are going to provide integrations with vector search DB like QDrant or Faiss
rr = RetrievalResults . from_embeddings_qg (
embeddings_query = embeddings_query , embeddings_gallery = embeddings_gallery ,
dataset_query = dataset_query , dataset_gallery = dataset_gallery
)
rr = ConstantThresholding ( th = 80 ). process ( rr )
rr . visualize_qg ([ 0 , 1 ], dataset_query = dataset_query , dataset_gallery = dataset_gallery , show = True )
print ( rr )Las tuberías proporcionan una forma de ejecutar experimentos de aprendizaje métrico al cambiar solo el archivo de configuración. Todo lo que necesita es preparar su conjunto de datos en un formato requerido.
Consulte la carpeta de tuberías para obtener más detalles:
Aquí hay una integración liviana con los modelos de Transformers Huggingface. Puede reemplazarlo con otros modelos arbitrarios heredados de IEXtractor.
Tenga en cuenta que no tenemos nuestro propio zoológico de modelos de texto en este momento.
pip install open-metric-learning[nlp] from transformers import AutoModel , AutoTokenizer
from oml . models import HFWrapper
model = AutoModel . from_pretrained ( 'bert-base-uncased' ). eval ()
tokenizer = AutoTokenizer . from_pretrained ( 'bert-base-uncased' )
extractor = HFWrapper ( model = model , feat_dim = 768 )
inp = tokenizer ( text = "Hello world" , return_tensors = "pt" , add_special_tokens = True )
embeddings = extractor ( inp )Puede usar un modelo de imagen de nuestro zoológico o usar otros modelos arbitrarios después de heredarlo de IEXtractor.
from oml . const import CKPT_SAVE_ROOT as CKPT_DIR , MOCK_DATASET_PATH as DATA_DIR
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
model = ViTExtractor . from_pretrained ( "vits16_dino" ). eval ()
transforms , im_reader = get_transforms_for_pretrained ( "vits16_dino" )
img = im_reader ( DATA_DIR / "images" / "circle_1.jpg" ) # put path to your image here
img_tensor = transforms ( img )
# img_tensor = transforms(image=img)["image"] # for transforms from Albumentations
features = model ( img_tensor . unsqueeze ( 0 ))
# Check other available models:
print ( list ( ViTExtractor . pretrained_models . keys ()))
# Load checkpoint saved on a disk:
model_ = ViTExtractor ( weights = CKPT_DIR / "vits16_dino.ckpt" , arch = "vits16" , normalise_features = False )Modelos, entrenados por nosotros. Las métricas a continuación son para 224 x 224 imágenes:
| modelo | CMC1 | conjunto de datos | pesas | experimento |
|---|---|---|---|---|
ViTExtractor.from_pretrained("vits16_inshop") | 0.921 | Inshop de moda profunda | enlace | enlace |
ViTExtractor.from_pretrained("vits16_sop") | 0.866 | Productos en línea de Stanford | enlace | enlace |
ViTExtractor.from_pretrained("vits16_cars") | 0.907 | Autos 196 | enlace | enlace |
ViTExtractor.from_pretrained("vits16_cub") | 0.837 | Cub 200 2011 | enlace | enlace |
Modelos, entrenados por otros investigadores. Tenga en cuenta que algunas métricas en puntos de referencia particulares son tan altas porque eran parte del conjunto de datos de entrenamiento (por ejemplo, unicom ). Las métricas a continuación son para 224 x 224 imágenes:
| modelo | Productos en línea de Stanford | Inshop de moda profunda | Cub 200 2011 | Autos 196 |
|---|---|---|---|---|
ViTUnicomExtractor.from_pretrained("vitb16_unicom") | 0.700 | 0.734 | 0.847 | 0.916 |
ViTUnicomExtractor.from_pretrained("vitb32_unicom") | 0.690 | 0.722 | 0.796 | 0.893 |
ViTUnicomExtractor.from_pretrained("vitl14_unicom") | 0.726 | 0.790 | 0.868 | 0.922 |
ViTUnicomExtractor.from_pretrained("vitl14_336px_unicom") | 0.745 | 0.810 | 0.875 | 0.924 |
ViTCLIPExtractor.from_pretrained("sber_vitb32_224") | 0.547 | 0.514 | 0.448 | 0.618 |
ViTCLIPExtractor.from_pretrained("sber_vitb16_224") | 0.565 | 0.565 | 0.524 | 0.648 |
ViTCLIPExtractor.from_pretrained("sber_vitl14_224") | 0.512 | 0.555 | 0.606 | 0.707 |
ViTCLIPExtractor.from_pretrained("openai_vitb32_224") | 0.612 | 0.491 | 0.560 | 0.693 |
ViTCLIPExtractor.from_pretrained("openai_vitb16_224") | 0.648 | 0.606 | 0.665 | 0.767 |
ViTCLIPExtractor.from_pretrained("openai_vitl14_224") | 0.670 | 0.675 | 0.745 | 0.844 |
ViTExtractor.from_pretrained("vits16_dino") | 0.648 | 0.509 | 0.627 | 0.265 |
ViTExtractor.from_pretrained("vits8_dino") | 0.651 | 0.524 | 0.661 | 0.315 |
ViTExtractor.from_pretrained("vitb16_dino") | 0.658 | 0.514 | 0.541 | 0.288 |
ViTExtractor.from_pretrained("vitb8_dino") | 0.689 | 0.599 | 0.506 | 0.313 |
ViTExtractor.from_pretrained("vits14_dinov2") | 0.566 | 0.334 | 0.797 | 0.503 |
ViTExtractor.from_pretrained("vits14_reg_dinov2") | 0.566 | 0.332 | 0.795 | 0.740 |
ViTExtractor.from_pretrained("vitb14_dinov2") | 0.565 | 0.342 | 0.842 | 0.644 |
ViTExtractor.from_pretrained("vitb14_reg_dinov2") | 0.557 | 0.324 | 0.833 | 0.828 |
ViTExtractor.from_pretrained("vitl14_dinov2") | 0.576 | 0.352 | 0.844 | 0.692 |
ViTExtractor.from_pretrained("vitl14_reg_dinov2") | 0.571 | 0.340 | 0.840 | 0.871 |
ResnetExtractor.from_pretrained("resnet50_moco_v2") | 0.493 | 0.267 | 0.264 | 0.149 |
ResnetExtractor.from_pretrained("resnet50_imagenet1k_v1") | 0.515 | 0.284 | 0.455 | 0.247 |
Las métricas pueden ser diferentes de las informadas por los documentos, porque la versión de la división de trenes/val y el uso de las cajas limitadas puede diferir.
¡Agradecemos a los nuevos contribuyentes! Por favor, vea nuestro:
El proyecto se inició en 2020 como un módulo para Catalyst Library. Quiero agradecer a las personas que trabajaron conmigo en ese módulo: Julia Shenshina, Nikita Balagansky, Sergey Kolesnikov y otros.
Me gustaría agradecer a las personas que continúan trabajando en esta tubería cuando se convirtió en un proyecto separado: Julia Shenshina, Misha Kindulov, Aron Dik, Aleksei Tarasov y Verkhovtsev Leonid.
También quiero agradecer a Newyorker, ya que la parte de la funcionalidad fue desarrollada (y utilizada) por su equipo de visión por computadora dirigida por mí.