Saya senang mengetahui bahwa kode ini telah digunakan dan dikutip dalam makalah berikut:
Domino: Menemukan kesalahan sistematis dengan embeddings lintas-modal oleh Eyuboglu et. al. di ICLR 2022
GSCLIP: Kerangka kerja untuk menjelaskan pergeseran distribusi dalam bahasa alami oleh Zhu et. al. di ICML 2022
UIC-NLP di Semeval-2022 Tugas 5: Menjelajahi pembelajaran kontras untuk deteksi multimodal meme misoginis oleh Cuervo et. al. di Semeval-2022
CDSBERT - Memperluas model bahasa protein dengan kesadaran kodon oleh Hallee ET. al. Dari University of Delaware (Sep 2023)
Enigma-51: Menuju pemahaman yang halus tentang interaksi objek manusia dalam skenario industri oleh Ragusa et. al. (Nov 2023)
Anda dapat menemukan info kutipan di bagian kanan halaman repo GitHub ini bernama: CITE Repositori ini atau gunakan info kutipan di bawah ini.
@software { Shariatnia_Simple_CLIP_2021 ,
author = { Shariatnia, M. Moein } ,
doi = { 10.5281/zenodo.6845731 } ,
month = { 4 } ,
title = { {Simple CLIP} } ,
version = { 1.0.0 } ,
year = { 2021 }
}Pada bulan Januari 2021 Openai mengumumkan dua model baru: Dall-E dan Clip , kedua model multi-modalitas yang menghubungkan teks dan gambar dengan cara tertentu. Dalam artikel ini kita akan menerapkan model klip dari awal di Pytorch . OpenAI telah bersumber terbuka beberapa kode yang berkaitan dengan model klip tetapi saya merasa mengintimidasi dan itu jauh dari sesuatu yang pendek dan sederhana. Saya juga menemukan tutorial yang baik yang diilhami oleh model klip pada contoh kode keras dan saya menerjemahkan beberapa bagiannya ke dalam pytorch untuk membangun tutorial ini sepenuhnya dengan pytorch tercinta kami!
Dalam mempelajari model visual yang dapat ditransfer dari kertas pengawasan bahasa alami, Openai memperkenalkan model baru mereka yang disebut klip , untuk pra-pelatihan gambar-gambar kontras . Singkatnya, model ini mempelajari hubungan antara seluruh kalimat dan gambar yang dijelaskannya; Dalam arti bahwa ketika model dilatih, diberi kalimat input, ia akan dapat mengambil gambar yang paling terkait yang sesuai dengan kalimat itu. Yang penting di sini adalah dilatih pada kalimat penuh alih -alih kelas tunggal seperti mobil, anjing, dll. Intuisinya adalah bahwa ketika dilatih pada seluruh kalimat, model dapat mempelajari lebih banyak hal dan menemukan beberapa pola antara gambar dan teks. Mereka juga menunjukkan bahwa ketika model ini dilatih pada set data besar gambar dan teks yang sesuai, itu juga dapat bertindak sebagai classifier juga. Saya mendorong Anda untuk mempelajari makalah ini untuk mempelajari lebih lanjut tentang model yang menarik ini dan hasil mereka yang mencengangkan pada dataset benchmarking. Untuk menyebutkan hanya satu, model klip yang dilatih dengan strategi ini mengklasifikasikan Imagenet lebih baik daripada model SOTA yang dilatih pada Imagenet itu sendiri dioptimalkan untuk satu -satunya tugas klasifikasi!
Sebagai penggoda (!), Mari kita lihat apa model terakhir yang akan kita bangun dalam artikel ini dari awal mampu: Diberikan kueri (teks mentah) seperti "seorang anak laki -laki melompat dengan skateboard" atau "seorang gadis melompat dari ayunan", model akan mengambil gambar yang paling relevan:

Mari kita lihat beberapa output lagi:

# !pip install timm
# !pip install transformers import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm . autonotebook import tqdm
import albumentations as A
import torch
from torch import nn
import torch . nn . functional as F
import timm
from transformers import DistilBertModel , DistilBertConfig , DistilBertTokenizer Catatan tentang Config dan CFG: Saya menulis kode dengan skrip Python dan kemudian mengonversinya menjadi buku catatan Jupyter. Jadi, dalam kasus skrip python, config adalah file python normal di mana saya meletakkan semua hyperparameters dan dalam kasus jupyter notebook, ini adalah kelas yang ditentukan di awal buku catatan untuk menyimpan semua hiperparameter.
class CFG :
debug = False
image_path = "C:/Moein/AI/Datasets/Flicker-8k/Images"
captions_path = "C:/Moein/AI/Datasets/Flicker-8k"
batch_size = 32
num_workers = 4
head_lr = 1e-3
image_encoder_lr = 1e-4
text_encoder_lr = 1e-5
weight_decay = 1e-3
patience = 1
factor = 0.8
epochs = 4
device = torch . device ( "cuda" if torch . cuda . is_available () else "cpu" )
model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"
max_length = 200
pretrained = True # for both image encoder and text encoder
trainable = True # for both image encoder and text encoder
temperature = 1.0
# image size
size = 224
# for projection head; used for both image and text encoders
num_projection_layers = 1
projection_dim = 256
dropout = 0.1 class AvgMeter :
def __init__ ( self , name = "Metric" ):
self . name = name
self . reset ()
def reset ( self ):
self . avg , self . sum , self . count = [ 0 ] * 3
def update ( self , val , count = 1 ):
self . count += count
self . sum += val * count
self . avg = self . sum / self . count
def __repr__ ( self ):
text = f" { self . name } : { self . avg :.4f } "
return text
def get_lr ( optimizer ):
for param_group in optimizer . param_groups :
return param_group [ "lr" ]Seperti yang dapat Anda lihat di gambar tittle dari artikel ini, kita perlu menyandikan kedua gambar dan teks yang menggambarkannya. Jadi, dataset perlu mengembalikan gambar dan teks . Tentu saja kita tidak akan memberi makan teks mentah ke encoder teks kita! Kami akan menggunakan model Distilbert (yang lebih kecil dari Bert tetapi berkinerja hampir sama dengan Bert) dari perpustakaan Huggingface sebagai encoder teks kami; Jadi, kita perlu tokenize kalimat (teks) dengan Tokenizer Distilbert dan kemudian memberi makan token ID (input_ids) dan topeng perhatian untuk distilbert. Oleh karena itu, dataset perlu mengurus tokenisasi juga. Di bawah ini Anda dapat melihat kode dataset. Di bawah ini saya akan menjelaskan hal -hal terpenting yang terjadi dalam kode.
Di __init__ kami menerima objek tokenizer yang sebenarnya merupakan tokinzer permukaan pelukan; Tokenizer ini akan dimuat saat menjalankan model. Kami sedang memadukan dan memotong teks ke max_length yang ditentukan. Dalam __getItem__ kita pertama -tama akan memuat judul yang dikodekan yang merupakan kamus dengan kunci input_ids dan attention_mask, membuat tensor keluar dari nilainya dan setelah itu kita akan memuat gambar yang sesuai, mengubah dan menambahnya (jika ada!) Dan kemudian kita menjadikannya tletsor dan memasukkannya ke dalam kamus dengan "gambar sebagai kunci. Akhirnya kami meletakkan teks mentah dari keterangan dengan "keterangan" kunci di kamus hanya untuk tujuan visualisasi.
Saya tidak menggunakan augmentasi data tambahan tetapi Anda dapat menambahkannya jika Anda ingin meningkatkan kinerja model.
class CLIPDataset ( torch . utils . data . Dataset ):
def __init__ ( self , image_filenames , captions , tokenizer , transforms ):
"""
image_filenames and cpations must have the same length; so, if there are
multiple captions for each image, the image_filenames must have repetitive
file names
"""
self . image_filenames = image_filenames
self . captions = list ( captions )
self . encoded_captions = tokenizer (
list ( captions ), padding = True , truncation = True , max_length = CFG . max_length
)
self . transforms = transforms
def __getitem__ ( self , idx ):
item = {
key : torch . tensor ( values [ idx ])
for key , values in self . encoded_captions . items ()
}
image = cv2 . imread ( f" { CFG . image_path } / { self . image_filenames [ idx ] } " )
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
image = self . transforms ( image = image )[ 'image' ]
item [ 'image' ] = torch . tensor ( image ). permute ( 2 , 0 , 1 ). float ()
item [ 'caption' ] = self . captions [ idx ]
return item
def __len__ ( self ):
return len ( self . captions )
def get_transforms ( mode = "train" ):
if mode == "train" :
return A . Compose (
[
A . Resize ( CFG . size , CFG . size , always_apply = True ),
A . Normalize ( max_pixel_value = 255.0 , always_apply = True ),
]
)
else :
return A . Compose (
[
A . Resize ( CFG . size , CFG . size , always_apply = True ),
A . Normalize ( max_pixel_value = 255.0 , always_apply = True ),
]
)Kode encoder gambar lurus ke depan. Saya menggunakan Pytorch Image Model Library (TIMM) di sini yang membuat banyak model gambar berbeda yang tersedia dari resnet ke efisiensi dan banyak lagi. Di sini kita akan menggunakan ResNet50 sebagai encoder gambar kita. Anda dapat dengan mudah menggunakan pustaka TorchVision untuk menggunakan resnet jika Anda tidak ingin menginstal pustaka baru.
Kode mengkodekan setiap gambar ke vektor ukuran tetap dengan ukuran saluran output model (dalam kasus resnet50 ukuran vektor adalah 2048 ). Ini adalah output setelah lapisan NN.AdaptiveAVGPool2D ().
class ImageEncoder ( nn . Module ):
"""
Encode images to a fixed size vector
"""
def __init__ (
self , model_name = CFG . model_name , pretrained = CFG . pretrained , trainable = CFG . trainable
):
super (). __init__ ()
self . model = timm . create_model (
model_name , pretrained , num_classes = 0 , global_pool = "avg"
)
for p in self . model . parameters ():
p . requires_grad = trainable
def forward ( self , x ):
return self . model ( x )Seperti yang saya sebutkan sebelumnya, saya akan menggunakan Distilbert sebagai encoder teks. Seperti kakaknya yang lebih besar, Bert, dua token khusus akan ditambahkan ke token input yang sebenarnya: CLS dan SEP yang menandai awal dan akhir kalimat. Untuk meraih seluruh representasi kalimat (seperti yang ditunjukkan oleh Papers Bert dan Distilbert), kami menggunakan representasi akhir dari token CLS dan kami berharap bahwa representasi ini menangkap makna keseluruhan dari kalimat (Keterangan). Memikirkannya dengan cara ini, mirip dengan apa yang kami lakukan dengan gambar dan mengubahnya menjadi vektor ukuran tetap.
Dalam kasus Distilbert (dan juga Bert) representasi tersembunyi output untuk setiap token adalah vektor dengan ukuran 768 . Jadi, seluruh keterangan akan dikodekan dalam representasi token CLS yang ukurannya adalah 768.
class TextEncoder ( nn . Module ):
def __init__ ( self , model_name = CFG . text_encoder_model , pretrained = CFG . pretrained , trainable = CFG . trainable ):
super (). __init__ ()
if pretrained :
self . model = DistilBertModel . from_pretrained ( model_name )
else :
self . model = DistilBertModel ( config = DistilBertConfig ())
for p in self . model . parameters ():
p . requires_grad = trainable
# we are using the CLS token hidden representation as the sentence's embedding
self . target_token_idx = 0
def forward ( self , input_ids , attention_mask ):
output = self . model ( input_ids = input_ids , attention_mask = attention_mask )
last_hidden_state = output . last_hidden_state
return last_hidden_state [:, self . target_token_idx , :]Saya menggunakan contoh kode keras implementasi proyeksi head untuk menulis yang berikut di pytorch. Sekarang kami telah mengkodekan gambar dan teks kami ke dalam vektor ukuran tetap (2048 untuk gambar dan 768 untuk teks) yang perlu kami bawa (proyek) ke dunia baru (!) Dengan dimensi yang sama untuk gambar dan teks agar dapat membandingkannya dan memisahkan gambar dan teks yang tidak relevan dan menyatukan yang cocok. Jadi, kode berikut akan membawa vektor dimensi 2048 dan 768 ke dunia dimensi 256 (proyection_dim), di mana kita dapat membandingkannya .
"embedding_dim" adalah ukuran vektor input (2048 untuk gambar dan 768 untuk teks) dan "proyection_dim" adalah ukuran vektor output yang akan menjadi 256 untuk kasus kami. Untuk memahami detail bagian ini, Anda dapat merujuk ke kertas klip.
class ProjectionHead ( nn . Module ):
def __init__ (
self ,
embedding_dim ,
projection_dim = CFG . projection_dim ,
dropout = CFG . dropout
):
super (). __init__ ()
self . projection = nn . Linear ( embedding_dim , projection_dim )
self . gelu = nn . GELU ()
self . fc = nn . Linear ( projection_dim , projection_dim )
self . dropout = nn . Dropout ( dropout )
self . layer_norm = nn . LayerNorm ( projection_dim )
def forward ( self , x ):
projected = self . projection ( x )
x = self . gelu ( projected )
x = self . fc ( x )
x = self . dropout ( x )
x = x + projected
x = self . layer_norm ( x )
return x Bagian ini adalah tempat semua kesenangan terjadi! Saya juga akan berbicara tentang fungsi kerugian di sini. Saya menerjemahkan beberapa kode dari contoh kode KERAS ke Pytorch untuk menulis bagian ini. Lihatlah kode dan kemudian baca penjelasan di bawah blok kode ini.
Di sini kami akan menggunakan modul sebelumnya yang kami buat untuk mengimplementasikan model utama. Fungsi __init__ cukup jelas. Dalam fungsi maju, pertama -tama kita mengkodekan gambar dan teks secara terpisah menjadi vektor ukuran tetap (dengan dimensi yang berbeda). Setelah itu, menggunakan modul proyeksi terpisah, kami memproyeksikannya ke dunia bersama (ruang) yang saya bicarakan sebelumnya. Di sini pengkodean akan menjadi bentuk yang sama (256 dalam kasus kami). Setelah itu kami akan menghitung kerugian. Sekali lagi saya merekomendasikan membaca kertas klip untuk mendapatkannya lebih baik tetapi saya akan mencoba yang terbaik untuk menjelaskan bagian ini.
Dalam aljabar linier , satu cara umum untuk mengukur jika dua vektor memiliki karakteristik yang sama (mereka seperti satu sama lain) adalah dengan menghitung produk titik mereka (mengalikan entri yang cocok dan mengambil jumlah mereka); Jika angka terakhirnya besar, mereka sama dan jika kecil mereka tidak (relatif berbicara)!
Oke! Apa yang baru saja saya katakan adalah hal terpenting yang perlu diingat untuk memahami fungsi kehilangan ini. Mari kita lanjutkan. Kami berbicara tentang dua vektor, tapi, apa yang kami miliki di sini? Kami memiliki Image_embeddings, matriks dengan bentuk (batch_size, 256) dan text_embeddings dengan bentuk (batch_size, 256). Cukup mudah! Ini berarti kami memiliki dua kelompok vektor, bukan dua vektor tunggal. Bagaimana kita mengukur seberapa mirip dua kelompok vektor (dua matriks) satu sama lain? Sekali lagi, dengan produk dot (@ operator di pytorch melakukan produk titik atau multiplikasi matriks dalam kasus ini). Untuk dapat melipatgandakan kedua matriks ini bersama -sama, kami mengubah yang kedua. Oke, kami mendapatkan matriks dengan bentuk (batch_size, batch_size) yang akan kami sebut logit. (Suhu sama dengan 1,0 dalam kasus kami, jadi, itu tidak membuat perbedaan. Anda dapat bermain dengannya dan melihat apa bedanya. Lihat juga kertas untuk melihat mengapa itu ada di sini!).
Saya harap Anda masih bersama saya! Jika tidak, tidak apa -apa, cukup tinjau kode dan periksa bentuknya. Sekarang kami memiliki logit kami, kami membutuhkan target. Saya perlu mengatakan bahwa ada cara yang lebih lurus ke depan untuk mendapatkan target tetapi saya harus melakukan ini untuk kasus kami (saya akan berbicara tentang mengapa dalam paragraf berikutnya).
Mari kita pertimbangkan apa yang kami harap model ini belajar: kami ingin belajar "representasi serupa (vektor)" untuk gambar yang diberikan dan keterangan yang menggambarkannya. Artinya kita memberikan gambar atau teks yang menggambarkannya, kami ingin menghasilkan vektor berukuran 256 yang sama untuk keduanya.
class CLIPModel ( nn . Module ):
def __init__ (
self ,
temperature = CFG . temperature ,
image_embedding = CFG . image_embedding ,
text_embedding = CFG . text_embedding ,
):
super (). __init__ ()
self . image_encoder = ImageEncoder ()
self . text_encoder = TextEncoder ()
self . image_projection = ProjectionHead ( embedding_dim = image_embedding )
self . text_projection = ProjectionHead ( embedding_dim = text_embedding )
self . temperature = temperature
def forward ( self , batch ):
# Getting Image and Text Features
image_features = self . image_encoder ( batch [ "image" ])
text_features = self . text_encoder (
input_ids = batch [ "input_ids" ], attention_mask = batch [ "attention_mask" ]
)
# Getting Image and Text Embeddings (with same dimension)
image_embeddings = self . image_projection ( image_features )
text_embeddings = self . text_projection ( text_features )
# Calculating the Loss
logits = ( text_embeddings @ image_embeddings . T ) / self . temperature
images_similarity = image_embeddings @ image_embeddings . T
texts_similarity = text_embeddings @ text_embeddings . T
targets = F . softmax (
( images_similarity + texts_similarity ) / 2 * self . temperature , dim = - 1
)
texts_loss = cross_entropy ( logits , targets , reduction = 'none' )
images_loss = cross_entropy ( logits . T , targets . T , reduction = 'none' )
loss = ( images_loss + texts_loss ) / 2.0 # shape: (batch_size)
return loss . mean ()
def cross_entropy ( preds , targets , reduction = 'none' ):
log_softmax = nn . LogSoftmax ( dim = - 1 )
loss = ( - targets * log_softmax ( preds )). sum ( 1 )
if reduction == "none" :
return loss
elif reduction == "mean" :
return loss . mean ()Jadi, dalam skenario kasus terbaik, Text_embeddings dan matricies Image_embedding harus sama karena mereka menggambarkan hal -hal serupa. Mari kita pikirkan sekarang: jika ini terjadi, seperti apa matriks logit? Mari kita lihat dengan contoh sederhana!
# A simple Example
batch_size = 4
dim = 256
embeddings = torch . randn ( batch_size , dim )
out = embeddings @ embeddings . T
print ( F . softmax ( out , dim = - 1 ))Jadi logit, dalam kasus terbaik, akan menjadi matriks bahwa jika kita mengambil softmax, akan memiliki 1,0 di diagonal (matriks identitas untuk menyebutnya dengan kata -kata mewah!). Karena tugas fungsi kerugian adalah membuat prediksi model yang mirip dengan target (setidaknya dalam kebanyakan kasus!), Kami menginginkan matriks seperti target kami. Itulah alasan mengapa kami menghitung gambar Images_Similarity dan Texts_Similarity di blok kode di atas.
Sekarang kami memiliki matriks target kami, kami akan menggunakan entropi silang sederhana untuk menghitung kerugian aktual. Saya telah menulis bentuk matriks lengkap dari entropi silang sebagai fungsi yang dapat Anda lihat di bagian bawah blok kode. Oke! Kami sudah selesai! Bukankah itu sederhana?! Baiklah, Anda dapat mengabaikan paragraf berikutnya tetapi jika Anda penasaran, ada catatan penting dalam hal itu.
Inilah mengapa saya tidak menggunakan pendekatan yang lebih sederhana : Saya perlu mengakui bahwa ada cara yang lebih sederhana untuk menghitung kerugian ini di Pytorch; Dengan melakukan ini: nn.crossentropyloss () (logit, torch.Arange (batch_size)). Mengapa saya tidak menggunakannya di sini? Karena 2 alasan. 1- Dataset yang kami gunakan memiliki banyak teks untuk satu gambar; Jadi, ada kemungkinan bahwa dua gambar yang identik dengan teks yang sama ada dalam batch (jarang tetapi itu bisa terjadi). Mengambil kerugian dengan metode yang lebih mudah ini akan mengabaikan kemungkinan ini dan model belajar untuk memisahkan dua representasi (menganggapnya berbeda) yang sebenarnya sama. Jelas, kami tidak ingin ini terjadi, jadi saya menghitung seluruh matriks target dengan cara yang menangani kasus -kasus tepi ini. 2- Melakukannya seperti yang saya lakukan, memberi saya pemahaman yang lebih baik tentang apa yang terjadi dalam fungsi kerugian ini; Jadi, saya pikir itu akan memberi Anda intuisi yang lebih baik juga!
Berikut adalah beberapa kesenangan untuk membantu kami memuat kereta dan dataloader yang valid, model kami dan kemudian melatih dan mengevaluasi model kami tentang itu. Tidak banyak yang terjadi di sini; hanya fungsi pelatihan dan fungsi utilitas sederhana
def make_train_valid_dfs ():
dataframe = pd . read_csv ( f" { CFG . captions_path } /captions.csv" )
max_id = dataframe [ "id" ]. max () + 1 if not CFG . debug else 100
image_ids = np . arange ( 0 , max_id )
np . random . seed ( 42 )
valid_ids = np . random . choice (
image_ids , size = int ( 0.2 * len ( image_ids )), replace = False
)
train_ids = [ id_ for id_ in image_ids if id_ not in valid_ids ]
train_dataframe = dataframe [ dataframe [ "id" ]. isin ( train_ids )]. reset_index ( drop = True )
valid_dataframe = dataframe [ dataframe [ "id" ]. isin ( valid_ids )]. reset_index ( drop = True )
return train_dataframe , valid_dataframe
def build_loaders ( dataframe , tokenizer , mode ):
transforms = get_transforms ( mode = mode )
dataset = CLIPDataset (
dataframe [ "image" ]. values ,
dataframe [ "caption" ]. values ,
tokenizer = tokenizer ,
transforms = transforms ,
)
dataloader = torch . utils . data . DataLoader (
dataset ,
batch_size = CFG . batch_size ,
num_workers = CFG . num_workers ,
shuffle = True if mode == "train" else False ,
)
return dataloaderBerikut fungsi praktis untuk melatih model kami. Tidak banyak yang terjadi di sini; Hanya memuat batch, memberi makan mereka ke model dan menginjak pengoptimal dan LR_SCHEDuler.
def train_epoch ( model , train_loader , optimizer , lr_scheduler , step ):
loss_meter = AvgMeter ()
tqdm_object = tqdm ( train_loader , total = len ( train_loader ))
for batch in tqdm_object :
batch = { k : v . to ( CFG . device ) for k , v in batch . items () if k != "caption" }
loss = model ( batch )
optimizer . zero_grad ()
loss . backward ()
optimizer . step ()
if step == "batch" :
lr_scheduler . step ()
count = batch [ "image" ]. size ( 0 )
loss_meter . update ( loss . item (), count )
tqdm_object . set_postfix ( train_loss = loss_meter . avg , lr = get_lr ( optimizer ))
return loss_meter
def valid_epoch ( model , valid_loader ):
loss_meter = AvgMeter ()
tqdm_object = tqdm ( valid_loader , total = len ( valid_loader ))
for batch in tqdm_object :
batch = { k : v . to ( CFG . device ) for k , v in batch . items () if k != "caption" }
loss = model ( batch )
count = batch [ "image" ]. size ( 0 )
loss_meter . update ( loss . item (), count )
tqdm_object . set_postfix ( valid_loss = loss_meter . avg )
return loss_meter
def main ():
train_df , valid_df = make_train_valid_dfs ()
tokenizer = DistilBertTokenizer . from_pretrained ( CFG . text_tokenizer )
train_loader = build_loaders ( train_df , tokenizer , mode = "train" )
valid_loader = build_loaders ( valid_df , tokenizer , mode = "valid" )
model = CLIPModel (). to ( CFG . device )
params = [
{ "params" : model . image_encoder . parameters (), "lr" : CFG . image_encoder_lr },
{ "params" : model . text_encoder . parameters (), "lr" : CFG . text_encoder_lr },
{ "params" : itertools . chain (
model . image_projection . parameters (), model . text_projection . parameters ()
), "lr" : CFG . head_lr , "weight_decay" : CFG . weight_decay }
]
optimizer = torch . optim . AdamW ( params , weight_decay = 0. )
lr_scheduler = torch . optim . lr_scheduler . ReduceLROnPlateau (
optimizer , mode = "min" , patience = CFG . patience , factor = CFG . factor
)
step = "epoch"
best_loss = float ( 'inf' )
for epoch in range ( CFG . epochs ):
print ( f"Epoch: { epoch + 1 } " )
model . train ()
train_loss = train_epoch ( model , train_loader , optimizer , lr_scheduler , step )
model . eval ()
with torch . no_grad ():
valid_loss = valid_epoch ( model , valid_loader )
if valid_loss . avg < best_loss :
best_loss = valid_loss . avg
torch . save ( model . state_dict (), "best.pt" )
print ( "Saved Best Model!" )
lr_scheduler . step ( valid_loss . avg )Menjalankan sel mulai sel berikutnya melatih model. Letakkan kernel pada mode GPU. Setiap zaman harus memakan waktu sekitar 24 menit di GPU (bahkan satu zaman sudah cukup!). Ini bisa memakan waktu satu menit sebelum pelatihan benar -benar dimulai karena kita akan menyandikan semua teks sekali di kereta dan dataset yang valid, jadi tolong jangan hentikan! Segala sesuatu bekerja dengan baik.
main ()Oke! Kami selesai dengan melatih model. Sekarang, kita perlu melakukan inferensi yang dalam kasus kami akan memberikan model teks dan ingin mengambil gambar yang paling relevan dari set validasi (atau tes) yang tidak terlihat.
Dalam fungsi ini, kami memuat model yang kami simpan setelah pelatihan, memberi makan gambar dalam set validasi dan mengembalikan Image_embeddings dengan bentuk (valid_set_size, 256) dan model itu sendiri.
def get_image_embeddings ( valid_df , model_path ):
tokenizer = DistilBertTokenizer . from_pretrained ( CFG . text_tokenizer )
valid_loader = build_loaders ( valid_df , tokenizer , mode = "valid" )
model = CLIPModel (). to ( CFG . device )
model . load_state_dict ( torch . load ( model_path , map_location = CFG . device ))
model . eval ()
valid_image_embeddings = []
with torch . no_grad ():
for batch in tqdm ( valid_loader ):
image_features = model . image_encoder ( batch [ "image" ]. to ( CFG . device ))
image_embeddings = model . image_projection ( image_features )
valid_image_embeddings . append ( image_embeddings )
return model , torch . cat ( valid_image_embeddings ) _ , valid_df = make_train_valid_dfs ()
model , image_embeddings = get_image_embeddings ( valid_df , "best.pt" )Fungsi ini melakukan tugas akhir yang kami harapkan dari model kami mampu: ia mendapatkan model, Image_embeddings, dan kueri teks. Ini akan menampilkan gambar yang paling relevan dari set validasi! Bukankah itu luar biasa? Mari kita lihat bagaimana kinerjanya!
def find_matches ( model , image_embeddings , query , image_filenames , n = 9 ):
tokenizer = DistilBertTokenizer . from_pretrained ( CFG . text_tokenizer )
encoded_query = tokenizer ([ query ])
batch = {
key : torch . tensor ( values ). to ( CFG . device )
for key , values in encoded_query . items ()
}
with torch . no_grad ():
text_features = model . text_encoder (
input_ids = batch [ "input_ids" ], attention_mask = batch [ "attention_mask" ]
)
text_embeddings = model . text_projection ( text_features )
image_embeddings_n = F . normalize ( image_embeddings , p = 2 , dim = - 1 )
text_embeddings_n = F . normalize ( text_embeddings , p = 2 , dim = - 1 )
dot_similarity = text_embeddings_n @ image_embeddings_n . T
values , indices = torch . topk ( dot_similarity . squeeze ( 0 ), n * 5 )
matches = [ image_filenames [ idx ] for idx in indices [:: 5 ]]
_ , axes = plt . subplots ( 3 , 3 , figsize = ( 10 , 10 ))
for match , ax in zip ( matches , axes . flatten ()):
image = cv2 . imread ( f" { CFG . image_path } / { match } " )
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
ax . imshow ( image )
ax . axis ( "off" )
plt . show ()Beginilah cara kami menggunakan fungsi ini. Aaaannnnddd. Hasilnya:
find_matches ( model ,
image_embeddings ,
query = "a group of people dancing in a party" ,
image_filenames = valid_df [ 'image' ]. values ,
n = 9 )
Saya harap Anda menikmati artikel ini. Menerapkan makalah ini adalah pengalaman yang sangat menarik bagi saya. Saya ingin mengucapkan terima kasih kepada Khalid Salama untuk contoh kode keras yang hebat yang ia berikan yang mengilhami saya untuk menulis sesuatu yang serupa di Pytorch.