Flexibler Training und Abruf für späte Interaktionsmodelle

Pylate ist eine Bibliothek, die auf Satztransformatoren aufgebaut ist und die Feinabstimmung, Inferenz und Abruf mit hochmodernen Colbert-Modellen vereinfacht und optimiert. Es ermöglicht eine einfache Feinabstimmung sowohl für einzelne als auch bei mehreren GPUs und bietet Flexibilität für verschiedene Hardware-Setups. Pylate optimiert auch das Abrufen von Dokumenten und ermöglicht es Ihnen, eine breite Palette von Modellen zu laden, sodass Sie Colbert-Modelle aus den meisten vorgebliebenen Sprachmodellen konstruieren können.
Sie können Pylat mit PIP installieren:
pip install pylateVerwenden Sie für Bewertungsabhängigkeiten:
pip install " pylate[eval] " Die vollständige Dokumentation finden Sie hier, einschließlich ausführlicher Führer, Beispiele und API-Referenzen.
Hier ist ein einfaches Beispiel für das Training eines Colbert -Modells im MS Marco -Datensatz -Triplet -Datensatz mit Pylat. Dieses Skript demonstriert ein Training mit kontrastivem Verlust und Bewertung des Modells an einem festgehaltenen Eval-Set:
import torch
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer ,
SentenceTransformerTrainingArguments ,
)
from pylate import evaluation , losses , models , utils
# Define model parameters for contrastive training
model_name = "bert-base-uncased" # Choose the pre-trained model you want to use as base
batch_size = 32 # Larger batch size often improves results, but requires more memory
num_train_epochs = 1 # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/ { run_name } "
# 1. Here we define our ColBERT model. If not a ColBERT model, will add a linear layer to the base encoder.
model = models . ColBERT ( model_name_or_path = model_name )
# Compiling the model makes the training faster
model = torch . compile ( model )
# Load dataset
dataset = load_dataset ( "sentence-transformers/msmarco-bm25" , "triplet" , split = "train" )
# Split the dataset (this dataset does not have a validation set, so we split the training set)
splits = dataset . train_test_split ( test_size = 0.01 )
train_dataset = splits [ "train" ]
eval_dataset = splits [ "test" ]
# Define the loss function
train_loss = losses . Contrastive ( model = model )
# Initialize the evaluator
dev_evaluator = evaluation . ColBERTTripletEvaluator (
anchors = eval_dataset [ "query" ],
positives = eval_dataset [ "positive" ],
negatives = eval_dataset [ "negative" ],
)
# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments (
output_dir = output_dir ,
num_train_epochs = num_train_epochs ,
per_device_train_batch_size = batch_size ,
per_device_eval_batch_size = batch_size ,
fp16 = True , # Set to False if you get an error that your GPU can't run on FP16
bf16 = False , # Set to True if you have a GPU that supports BF16
run_name = run_name , # Will be used in W&B if `wandb` is installed
learning_rate = 3e-6 ,
)
# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer (
model = model ,
args = args ,
train_dataset = train_dataset ,
eval_dataset = eval_dataset ,
loss = train_loss ,
evaluator = dev_evaluator ,
data_collator = utils . ColBERTCollator ( model . tokenize ),
)
# Start the training process
trainer . train ()Nach dem Training kann das Modell mit dem Ausgangsverzeichnispfad geladen werden:
from pylate import models
model = models . ColBERT ( model_name_or_path = "contrastive-bert-base-uncased" )Um die beste Leistung beim Training eines Colbert -Modells zu erzielen, sollten Sie die Wissensdestillation verwenden, um das Modell mit den Punktzahlen eines starken Lehrermodells zu trainieren. Hier ist ein einfaches Beispiel dafür, wie ein Modell mit Wissensdestillation in Pylat auf MS Marco trainiert wird:
import torch
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer ,
SentenceTransformerTrainingArguments ,
)
from pylate import losses , models , utils
# Load the datasets required for knowledge distillation (train, queries, documents)
train = load_dataset (
path = "lightonai/ms-marco-en-bge" ,
name = "train" ,
)
queries = load_dataset (
path = "lightonai/ms-marco-en-bge" ,
name = "queries" ,
)
documents = load_dataset (
path = "lightonai/ms-marco-en-bge" ,
name = "documents" ,
)
# Set the transformation to load the documents/queries texts using the corresponding ids on the fly
train . set_transform (
utils . KDProcessing ( queries = queries , documents = documents ). transform ,
)
# Define the base model, training parameters, and output directory
model_name = "bert-base-uncased" # Choose the pre-trained model you want to use as base
batch_size = 16
num_train_epochs = 1
# Set the run name for logging and output directory
run_name = "knowledge-distillation-bert-base"
output_dir = f"output/ { run_name } "
# Initialize the ColBERT model from the base model
model = models . ColBERT ( model_name_or_path = model_name )
# Compiling the model to make the training faster
model = torch . compile ( model )
# Configure the training arguments (e.g., epochs, batch size, learning rate)
args = SentenceTransformerTrainingArguments (
output_dir = output_dir ,
num_train_epochs = num_train_epochs ,
per_device_train_batch_size = batch_size ,
fp16 = True , # Set to False if you get an error that your GPU can't run on FP16
bf16 = False , # Set to True if you have a GPU that supports BF16
run_name = run_name ,
learning_rate = 1e-5 ,
)
# Use the Distillation loss function for training
train_loss = losses . Distillation ( model = model )
# Initialize the trainer
trainer = SentenceTransformerTrainer (
model = model ,
args = args ,
train_dataset = train ,
loss = train_loss ,
data_collator = utils . ColBERTCollator ( tokenize_fn = model . tokenize ),
)
# Start the training process
trainer . train ()Pylate unterstützt die Umarmung von Gesichtsdatensätzen und ermöglicht ein nahtloses Triplett- / Wissensdestillationsbasis. Für kontrastives Training können Sie einen der vorhandenen Triplet -Datensätze für Satztransformatoren verwenden. Im Folgenden finden Sie ein Beispiel für das Erstellen eines benutzerdefinierten Triplet -Datensatzes für das Training:
from datasets import Dataset
dataset = [
{
"query" : "example query 1" ,
"positive" : "example positive document 1" ,
"negative" : "example negative document 1" ,
},
{
"query" : "example query 2" ,
"positive" : "example positive document 2" ,
"negative" : "example negative document 2" ,
},
{
"query" : "example query 3" ,
"positive" : "example positive document 3" ,
"negative" : "example negative document 3" ,
},
]
dataset = Dataset . from_list ( mapping = dataset )
train_dataset , test_dataset = dataset . train_test_split ( test_size = 0.3 )Um einen Wissensdestillationsdatensatz zu erstellen, können Sie den folgenden Snippet verwenden:
from datasets import Dataset
dataset = [
{
"query_id" : 54528 ,
"document_ids" : [
6862419 ,
335116 ,
339186 ,
],
"scores" : [
0.4546215673141326 ,
0.6575686537173476 ,
0.26825184192900203 ,
],
},
{
"query_id" : 749480 ,
"document_ids" : [
6862419 ,
335116 ,
339186 ,
],
"scores" : [
0.2546215673141326 ,
0.7575686537173476 ,
0.96825184192900203 ,
],
},
]
dataset = Dataset . from_list ( mapping = dataset )
documents = [
{ "document_id" : 6862419 , "text" : "example doc 1" },
{ "document_id" : 335116 , "text" : "example doc 2" },
{ "document_id" : 339186 , "text" : "example doc 3" },
]
queries = [
{ "query_id" : 749480 , "text" : "example query" },
]
documents = Dataset . from_list ( mapping = documents )
queries = Dataset . from_list ( mapping = queries )Pylate ermöglicht das einfache Abrufen von Top -Dokumenten für eine bestimmte Abfrage mit dem ausgebildeten Colbert -Modell und dem Voyager -Index, laden Sie einfach das Modell und initieren Sie den Index:
from pylate import indexes , models , retrieve
model = models . ColBERT (
model_name_or_path = "lightonai/colbertv2.0" ,
)
index = indexes . Voyager (
index_folder = "pylate-index" ,
index_name = "index" ,
override = True ,
)
retriever = retrieve . ColBERT ( index = index )Sobald das Modell und der Index eingerichtet sind, können wir den Index mit ihren Einbettungen und entsprechenden IDs Dokumente hinzufügen:
documents_ids = [ "1" , "2" , "3" ]
documents = [
"document 1 text" , "document 2 text" , "document 3 text"
]
# Encode the documents
documents_embeddings = model . encode (
documents ,
batch_size = 32 ,
is_query = False , # Encoding documents
show_progress_bar = True ,
)
# Add the documents ids and embeddings to the Voyager index
index . add_documents (
documents_ids = documents_ids ,
documents_embeddings = documents_embeddings ,
)Anschließend können wir die Top-K-Dokumente für einen bestimmten Satz von Abfragen abrufen:
queries_embeddings = model . encode (
[ "query for document 3" , "query for document 1" ],
batch_size = 32 ,
is_query = True , # Encoding queries
show_progress_bar = True ,
)
scores = retriever . retrieve (
queries_embeddings = queries_embeddings ,
k = 10 ,
)
print ( scores )Beispielausgabe:
[
[
{ "id" : "3" , "score" : 11.266985893249512 },
{ "id" : "1" , "score" : 10.303335189819336 },
{ "id" : "2" , "score" : 9.502392768859863 },
],
[
{ "id" : "1" , "score" : 10.88800048828125 },
{ "id" : "3" , "score" : 9.950843811035156 },
{ "id" : "2" , "score" : 9.602447509765625 },
],
]Wenn Sie nur das Colbert-Modell verwenden möchten, um eine Überholung über Ihre Pipeline der ersten Stufe durchzuführen, ohne einen Index zu erstellen, können Sie einfach die Rangfunktion verwenden und die Abfragen und Dokumente zur Umarbeitung übergeben:
from pylate import rank
queries = [
"query A" ,
"query B" ,
]
documents = [
[ "document A" , "document B" ],
[ "document 1" , "document C" , "document B" ],
]
documents_ids = [
[ 1 , 2 ],
[ 1 , 3 , 2 ],
]
queries_embeddings = model . encode (
queries ,
is_query = True ,
)
documents_embeddings = model . encode (
documents ,
is_query = False ,
)
reranked_documents = rank . rerank (
documents_ids = documents_ids ,
queries_embeddings = queries_embeddings ,
documents_embeddings = documents_embeddings ,
)Wir begrüßen Beiträge! Um loszulegen:
pip install " pylate[dev] "make testmake ruffmake livedocSie können mit diesem Bibtex auf die Bibliothek verweisen:
@misc { PyLate ,
title = { PyLate: Flexible Training and Retrieval for Late Interaction Models } ,
author = { Chaffin, Antoine and Sourty, Raphaël } ,
url = { https://github.com/lightonai/pylate } ,
year = { 2024 }
}