Descrição chinesa | Inglês
Modelscope | Demo | Papel | Blog
Este projeto é uma versão chinesa do modelo de clipe e usa dados chineses em larga escala para treinamento (~ 200 milhões de pares gráficos e de texto), com o objetivo de ajudar os usuários a realizar rapidamente tarefas como características gráficas e de texto e cálculo de similaridade, recuperação cruzada e classificação de imagem de amostra zero no campo chinês. O código deste projeto é baseado no projeto Open_Clip e é otimizado para dados de campo chinês e para obter melhores resultados nos dados chineses. Este projeto fornece API, código de treinamento e código de teste, e os detalhes serão descritos em detalhes abaixo.
Atualmente, o Chinese-Clip é de origem aberta em 5 escalas diferentes, e suas informações de modelo e métodos de download são mostrados na tabela a seguir:
| Tamanho do modelo | Baixar link | Quantidade de parâmetro | Esqueleto lateral visual | Quantidade de parâmetro lateral visual | Esqueleto do lado do texto | Quantidade do parâmetro do lado do texto | Resolução |
|---|---|---|---|---|---|---|---|
| CN-CLIP RN50 | Download | 77m | Resnet50 | 38m | RBT3 | 39m | 224 |
| CN-Clip Vit-B/16 | Download | 188m | Vit-B/16 | 86m | Roberta-Wwm-Base | 102m | 224 |
| CN-Clip Vit-L/14 | Download | 406m | Vit-L/14 | 304m | Roberta-Wwm-Base | 102m | 224 |
| CN-Clip Vit-L/14@336px | Download | 407m | Vit-L/14 | 304m | Roberta-Wwm-Base | 102m | 336 |
| CN-Clip Vit-H/14 | Download | 958m | Vit-H/14 | 632m | Roberta-Wwm-Large | 326m | 224 |
Para a tarefa de recuperação gráfica e de texto, realizamos experimentos zero e finetune sobre recuperação de Muge, flickr30k-cn e coco-cn. Para classificação de amostra zero da imagem, realizamos experimentos em 10 conjuntos de dados de elevador. Os resultados experimentais são mostrados na tabela abaixo. Devido a limitações de espaço, fornecemos aqui os resultados ideais do modelo de escala do modelo de linha de base e do clipe chinês. Para obter indicadores de resultados detalhados de cada escala de clipe chinês, consulte Results.md para obter detalhes.
MUGE Recuperação de texto para imagem (conjunto de validação oficial) :
| Configurar | Zero-shot | Afinar | ||||||
|---|---|---|---|---|---|---|---|---|
| Métrica | R@1 | R@5 | R@10 | SENHOR | R@1 | R@5 | R@10 | SENHOR |
| Wukong | 42.7 | 69.0 | 78.0 | 63.2 | 52.7 | 77.9 | 85.6 | 72.1 |
| R2d2 | 49.5 | 75.7 | 83.2 | 69.5 | 60.1 | 82.9 | 89.4 | 77.5 |
| CN-clip | 63.0 | 84.1 | 89.2 | 78.8 | 68.9 | 88.7 | 93.1 | 83.6 |
Flickr30K-CN Recuperação (conjunto de testes oficiais) :
| Tarefa | Texto para imagem | Imagem para texto | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Configurar | Zero-shot | Afinar | Zero-shot | Afinar | ||||||||
| Métrica | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 |
| Wukong | 51.7 | 78.9 | 86.3 | 77.4 | 94.5 | 97.0 | 76.1 | 94.8 | 97.5 | 92.7 | 99.1 | 99.6 |
| Taiyi | 60.8 | 85.0 | 91.0 | - | - | - | - | - | - | - | - | - |
| R2d2 | 60.9 | 86.8 | 92.7 | 84.4 | 96.7 | 98.4 | 77.6 | 96.7 | 98.9 | 95.6 | 99.8 | 100.0 |
| CN-clip | 71.2 | 91.4 | 95.5 | 83.8 | 96.9 | 98.6 | 81.6 | 97.5 | 98.8 | 95.3 | 99.7 | 100.0 |
Recuperação Coco-CN (conjunto de testes oficiais) :
| Tarefa | Texto para imagem | Imagem para texto | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Configurar | Zero-shot | Afinar | Zero-shot | Afinar | ||||||||
| Métrica | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 | R@1 | R@5 | R@10 |
| Wukong | 53.4 | 80.2 | 90.1 | 74.0 | 94.4 | 98.1 | 55.2 | 81.0 | 90.6 | 73.3 | 94.0 | 98.0 |
| Taiyi | 60.0 | 84.0 | 93.3 | - | - | - | - | - | - | - | - | - |
| R2d2 | 56.4 | 85.0 | 93.1 | 79.1 | 96.5 | 98.9 | 63.3 | 89.3 | 95.7 | 79.3 | 97.1 | 98.7 |
| CN-clip | 69.2 | 89.9 | 96.1 | 81.5 | 96.9 | 99.1 | 63.0 | 86.6 | 92.9 | 83.5 | 97.3 | 99.2 |
Classificação de imagem com tiro zero :
| Tarefa | Cifar10 | Cifar100 | Dtd | Eurosat | Fer | FGVC | Kitti | Mnist | PC | Voc |
|---|---|---|---|---|---|---|---|---|---|---|
| Git | 88.5 | 61.1 | 42.9 | 43.4 | 41.4 | 6.7 | 22.1 | 68.9 | 50.0 | 80.2 |
| ALINHAR | 94.9 | 76.8 | 66.1 | 52.1 | 50.8 | 25.0 | 41.2 | 74.0 | 55.2 | 83.0 |
| GRAMPO | 94.9 | 77.0 | 56.0 | 63.0 | 48.3 | 33.3 | 11.5 | 79.0 | 62.3 | 84.0 |
| Wukong | 95.4 | 77.1 | 40.9 | 50.3 | - | - | - | - | - | - |
| CN-clip | 96.0 | 79.7 | 51.2 | 52.0 | 55.1 | 26.2 | 49.9 | 79.4 | 63.5 | 84.9 |
Antes de iniciar este projeto, você precisa verificar se os seguintes requisitos de configuração ambiental são atendidos:
Execute o comando a seguir para instalar as bibliotecas de três partes necessárias para este projeto.
pip install -r requirements.txtAqui está um exemplo simples de código para ilustrar como usar a API de clipe chinês. Antes de começar a usar, instale CN_CLIP:
# 通过pip安装
pip install cn_clip
# 或者从源代码安装
cd Chinese-CLIP
pip install -e .Após o sucesso da instalação, você pode chamar a API facilmente através dos métodos a seguir, passar na imagem especificada (exemplo) e texto, extrair o vetor de recurso gráfico e calcular a similaridade:
import torch
from PIL import Image
import cn_clip . clip as clip
from cn_clip . clip import load_from_name , available_models
print ( "Available models:" , available_models ())
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
device = "cuda" if torch . cuda . is_available () else "cpu"
model , preprocess = load_from_name ( "ViT-B-16" , device = device , download_root = './' )
model . eval ()
image = preprocess ( Image . open ( "examples/pokemon.jpeg" )). unsqueeze ( 0 ). to ( device )
text = clip . tokenize ([ "杰尼龟" , "妙蛙种子" , "小火龙" , "皮卡丘" ]). to ( device )
with torch . no_grad ():
image_features = model . encode_image ( image )
text_features = model . encode_text ( text )
# 对特征进行归一化,请使用归一化后的图文特征用于下游任务
image_features /= image_features . norm ( dim = - 1 , keepdim = True )
text_features /= text_features . norm ( dim = - 1 , keepdim = True )
logits_per_image , logits_per_text = model . get_similarity ( image , text )
probs = logits_per_image . softmax ( dim = - 1 ). cpu (). numpy ()
print ( "Label probs:" , probs ) # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]Também preparamos suporte relevante para implantar modelos ONNX e Tensorrt. Para detalhes, consulte o implantação.md.
Se você não estiver satisfeito em usar a API, continue lendo este documento para aprender a usar nosso projeto para treinamento e teste de modelos de clipes.
O seguinte incluirá tutoriais de recuperação cruzada-modal (incluindo Finetune e Inferência, Cálculo de KNN, etc.) e tutoriais de classificação de imagem de amostra zero.
Depois de baixar este projeto, crie uma nova pasta ${DATAPATH} para armazenar o conjunto de dados, o CKPT pré-treinado e o modelo de registro e CKPT gerado pelo FineTune. A estrutura do diretório da área de trabalho recomendada é a seguinte:
Chinese-CLIP/
├── run_scripts/
│ ├── muge_finetune_vit-b-16_rbt-base.sh
│ ├── flickr30k_finetune_vit-b-16_rbt-base.sh
│ └── ... # 更多finetune或评测脚本...
└── cn_clip/
├── clip/
├── eval/
├── preprocess/
└── training/
${DATAPATH}
├── pretrained_weights/
├── experiments/
├── deploy/ # 用于存放ONNX & TensorRT部署模型
└── datasets/
├── MUGE/
├── Flickr30k-CN/
└── .../ # 更多自定义数据集...
Aqui, fornecemos o método de download de parâmetros de modelo pré-treinado, bem como o processo de pré-processamento de dados antes do FineTune.
Consulte a seção anterior da escala do modelo e do download do link para baixar o modelo correspondente CKPT. Recomenda -se armazenar o arquivo CKPT baixado no ${DATAPATH}/pretrained_weights/ diretório.
Para se adaptar ao código-clipe chinês e garantir a eficiência do processamento e leitura de dados, recomendamos que os conjuntos de dados gráficos e de texto usados para treinamento e avaliação sejam organizados nos seguintes métodos:
${DATAPATH}
└── datasets/
└── ${dataset_name}/
├── train_imgs.tsv # 图片id & 图片内容
├── train_texts.jsonl # 文本id & 文本内容,连同匹配的图片id列表
├── valid_imgs.tsv
├── valid_texts.jsonl
├── test_imgs.tsv
└── test_texts.jsonl
Onde ${dataset_name} refere -se ao nome do conjunto de dados (como Muge)
Para garantir a eficiência do processamento de arquivos, não armazenamos imagens em grandes quantidades de arquivos pequenos, mas armazenamos imagens de treinamento/verificação/teste na base64 no arquivo ${split}_imgs.tsv , respectivamente. Cada linha do arquivo representa uma imagem, incluindo o ID da imagem (Int Type) e a imagem base64, separada por guia, e o formato é o seguinte:
1000002 /9j/4AAQSkZJ...YQj7314oA//2Q==
A maneira de converter o arquivo de imagem original em base64 é muito simples, execute o seguinte código Python:
from PIL import Image
from io import BytesIO
import base64
img = Image . open ( file_name ) # 访问图片路径
img_buffer = BytesIO ()
img . save ( img_buffer , format = img . format )
byte_data = img_buffer . getvalue ()
base64_str = base64 . b64encode ( byte_data ) # bytes
base64_str = base64_str . decode ( "utf-8" ) # str As informações de texto e a relação correspondente entre os pares gráficos e de texto são salvos no arquivo ${split}_texts.jsonl . Cada linha do arquivo é uma linha de JSON, o formato é o seguinte:
{"text_id": 8428, "text": "高级感托特包斜挎", "image_ids": [1076345, 517602]}
Para o conjunto de testes, existe apenas texto e não conheço a relação correspondente entre os pares de figuras e texto, o campo image_ids de cada linha pode ser processado como uma lista vazia, ou seja, "image_ids": [] .
Por fim, também precisamos serializar os arquivos TSV e JSONL juntos e convertê-los em arquivos de banco de dados LMDB indexados pela memória para facilitar a leitura aleatória durante o treinamento.
python cn_clip/preprocess/build_lmdb_dataset.py
--data_dir ${DATAPATH}/datasets/${dataset_name}
--splits train,valid,test
Por exemplo, para o MUGE DATASET, ${dataset_name} está definido como MUGE e --splits Especifica a divisão do conjunto de dados que precisa ser convertida, separada por vírgulas sem espaços. Após a conversão, os seguintes arquivos serializados LMDB serão adicionados à pasta do conjunto de dados.
${DATAPATH}
└── datasets/
└── ${dataset_name}/
└── lmdb/
├── train
│ ├── imgs
│ └── pairs
├── valid
└── test
Para reduzir a dificuldade de começar, também fornecemos o pacote de dados do MUGE (Download Link) e FlickR30K-CN (Link para download) pré-processados de acordo com as etapas acima. Basta baixar e descompactar e colocá -lo no ${DATAPATH}/datasets/ Directory. Se os dados do Coco-CN forem necessários, entre em contato conosco por e-mail depois de solicitar permissão do autor original.
Aqui, apresentamos as etapas de treinamento para facilitar outros usuários para entender os detalhes do modelo e usar o modelo pré-treinado do clipe chinês que fornecemos para o Finetune. Com base nos dois conjuntos de dados de pesquisa a jusante de Muge e Flickr30K-CN, fornecemos scripts de amostra de treinamento run_scripts/muge_finetune_vit-b-16_rbt-base.sh e run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh . A execução de scripts suporta o treinamento distribuído de uma máquina única (cartões únicos ou múltiplos) e de várias máquinas. Antes de executar, preencha as configurações relacionadas distribuídas de acordo com as diretrizes e comentários no início do script e, em seguida, execute os seguintes comandos para iniciar o treinamento (execute os comandos em cada máquina para treinamento em várias máquinas). Para memória de vídeo insuficiente, você pode considerar ativar a estratégia de recálculo no item de configuração. Os arquivos CKPT de log e modelo gerados pelo treinamento serão salvos automaticamente no diretório especificado pelo usuário:
cd Chinese-CLIP/
bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}Os itens de configuração de treinamento relacionados incluem:
WORKER_CNT : o número de máquinas treinadasGPUS_PER_NODE : número de GPUs em cada máquinatrain-data : o diretório LMDB de dados de treinamento, veja acima para o processo de pré-processamento para preparar arquivos de dados LMDB.val-data : Verifique o diretório LMDB de dados. Quando especificado como nenhuma, a verificação durante o treinamento não será realizada.num-workers : o número de processos no conjunto de dados do conjunto de treinamento (Dataloader), o padrão é 4.valid-num-workers : O número de processos para o conjunto de dados do conjunto de verificação (Dataloader) (se a validação for executada), o padrão é 1.vision-model : especifique o backbone visual, selecione FROM ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] .text-model : especifique o backbone do texto, selecione ["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"] .context-length : comprimento da sequência de entrada de texto.warmup : etapas de aquecimento.batch-size : tamanho de lote de cartão único durante o treinamento. (Certifique-se de que训练样本总数> batch-size * GPU数, que atende a pelo menos 1 lote de treinamento)lr : Taxa de aprendizado.wd : Decaimento do peso.max-steps : O número de etapas de treinamento e o número de rodadas de treinamento também podem ser especificadas através de max-epochs .freeze-vision : se deve congelar o backbone visual.use-augment : se deve usar o AutoAgment para aprimorar os dados da imagem.valid-batch-size : tamanho em lote independente durante a verificação. (Certifique-se de que验证集样本总数> batch-size * GPU数, satisfazendo pelo menos 1 lote de verificação)valid-step-interval e valid-epoch-interval : verifique a frequência da etapa/epocê. Se especificado como -1, a verificação não será realizada durante o treinamento.grad-checkpointing : Use a estratégia de recálculo para não salvar os resultados intermediários durante o processo avançado, em troca de menos sobrecarga de memória no tempo de treinamento, o que é adequado para memória insuficiente. (Parâmetro store_true , basta adicionar --grad-checkpointing ao script, atualmente é necessário pytorch> 1.8.0)mask-ratio : Referindo -se à estratégia FLIP, uma máscara aleatória pode ser especificada para uma certa proporção de patches de imagem durante o Finetune para reduzir a sobrecarga da memória e acelerar o treinamento. O padrão é 0,0, o que significa que essa política não está ativada.use-flash-attention : O uso do Flashattion pode acelerar significativamente o processo Finetune do clipe chinês e reduzir o uso da memória sem afetar o efeito. (Parâmetro store_true . Depois de configurar o ambiente, adicione --use-flash-attention ao script. Por favor, consulte Flash_Attion.md para obter detalhes)accum-freq : A frequência de acumulação de gradiente é 1 por padrão. Quando especificado como um número inteiro maior que 1, o acúmulo de gradiente de aprendizado comparativo é permitido simular um tamanho de lote maior. Se o tamanho do lote de cartão único for m , o tamanho total do lote será accum_freq * m * GPU数.gather-with-grad : Se você deve executar o GRAY GRAY com gradientes completos durante o treinamento distribuído, ele é desligado por padrão.name : Especifique o caminho de saída. O log HyperParameter, o log de treinamento e a saída CKPT serão armazenados em ${DATAPATH}/experiments/${name}/ .save-step-frequency e save-epoch-frequency : o intervalo entre as etapas ou rodadas do CKPT.report-training-batch-acc : se o log relata a precisão do gráfico de treinamento para textos e texto em lote de gráficos.resume : o caminho a ser lido por peso. O script de amostra especifica como um caminho CKPT pré-treinado, ou pode ser especificado como o caminho CKPT do Finetune do usuário para treinamento contínuo.reset-data-offset : Se você deve continuar sendo executado do ponto de interrupção dos dados anteriores. Se o tamanho do lote ou o número da placa GPU exceder o parâmetro, é recomendável ativar esta opção.reset-optimizer : se deve usar o estado do otimizador. Após o treinamento, o log existirá automaticamente ${DATAPATH}/experiments/${name}/out_${timestamp}.log . O formato de log de treinamento é o seguinte:
2022-12-11,20:40:34 | INFO | Rank 0 | Global Steps: 1/735 | Train Epoch: 1 [1024/250880 (0%)] | Loss: 2.371020 | Image2Text Acc: 49.90 | Text2Image Acc: 48.73 | Data Time: 1.039s | Batch Time: 3.625s | LR: 0.000000 | logit_scale: 4.605 | Global Batch Size: 1024
O formato de log de verificação é o seguinte:
2022-12-11,20:42:47 | INFO | Rank 0 | Validation Result (epoch 1 @ 150 steps) | Valid Loss: 0.502810 | Image2Text Acc: 84.95 | Text2Image Acc: 84.26 | logit_scale: 4.605 | Valid Batch Size: 128
Nota : A comparação da convergência do treinamento e a estabilidade do aprendizado estão correlacionadas com o tamanho total do lote. Se você usar um tamanho de lote menor (em comparação com a configuração padrão de 128 por GPU * 8 GPU), é recomendável usar uma taxa de aprendizado menor. Recomendamos o uso de mais GPUs e maior tamanho em lote para obter melhores resultados.
Fornecemos o processo de extração de recursos e avaliação de tarefas de recuperação gráfica, que é a seguinte:
Atualmente, este código suporta o uso da placa única da GPU para extração de recursos gráficos, consulte o comando a seguir. Também fornecemos suporte para implantar modelos ONNX e Tensorrt para acelerar a inferência de recursos, consulte o implantação.md para obter detalhes.
cd Chinese-CLIP/
export CUDA_VISIBLE_DEVICES=0
export PYTHONPATH= ${PYTHONPATH} : ` pwd ` /cn_clip
split=valid # 指定计算valid或test集特征
resume= ${DATAPATH} /pretrained_weights/clip_cn_vit-b-16.pt
python -u cn_clip/eval/extract_features.py
--extract-image-feats
--extract-text-feats
--image-data= " ${DATAPATH} /datasets/ ${dataset_name} /lmdb/ ${split} /imgs "
--text-data= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.jsonl "
--img-batch-size=32
--text-batch-size=32
--context-length=52
--resume= ${resume}
--vision-model=ViT-B-16
--text-model=RoBERTa-wwm-ext-base-chinese Os recursos gráficos de saída serão salvos no diretório ${DATAPATH}/datasets/${dataset_name} por padrão, e os recursos da imagem são salvos no arquivo ${split}_imgs.img_feat.jsonl . Cada linha armazena os recursos de uma imagem em JSON, e o formato é o seguinte:
{"image_id": 1000002, "feature": [0.0198, ..., -0.017, 0.0248]}
Os recursos de texto são salvos em ${split}_texts.txt_feat.jsonl , com o formato da seguinte forma:
{"text_id": 248816, "feature": [0.1314, ..., 0.0018, -0.0002]}
Para conjuntos de dados de pesquisa acadêmica em pequena escala, fornecemos uma simples implementação de pesquisa de KNN para facilitar o cálculo dos resultados da recordação de Top-K para pesquisa de texto em graphic e gráfico para texto (dicas: se você deseja criar uma demonstração de pesquisa no projeto, recomenda-se que a estrutura de clipes de clipes e combina a estrutura de mecanismo de origem.
Para pesquisa de texto para imagem (imagens relacionadas ao recall de texto), execute o seguinte comando:
cd Chinese-CLIP/
split=valid # 指定计算valid或test集特征
python -u cn_clip/eval/make_topk_predictions.py
--image-feats= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _imgs.img_feat.jsonl "
--text-feats= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.txt_feat.jsonl "
--top-k=10
--eval-batch-size=32768
--output= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _predictions.jsonl "O resultado é salvo no arquivo JSONL especificado. Cada linha representa o ID da imagem mais importante de um recall de texto, e o formato é o seguinte:
{ "text_id" : 153915 , "image_ids" : [ 5791244 , 1009692167 , 7454547004 , 3564007203 , 38130571 , 2525270674 , 2195419145 , 2503091968 , 4966265765 , 3690431163 ]}Para pesquisa de imagem para texto (texto relacionado à recuperação da imagem), da mesma forma, execute o seguinte comando:
split=valid # 指定计算valid或test集特征
python -u cn_clip/eval/make_topk_predictions_tr.py
--image-feats= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _imgs.img_feat.jsonl "
--text-feats= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.txt_feat.jsonl "
--top-k=10
--eval-batch-size=32768
--output= " ${DATAPATH} /datasets/ ${dataset_name} / ${split} _tr_predictions.jsonl "Cada linha dos resultados da saída representa o ID de texto superior de um recall de imagem, e o formato é o seguinte:
{ "image_id" : 977856234 , "text_ids" : [ 156914 , 157914 , 158914 , 155914 , 156179 , 158907 , 157179 , 154179 , 154914 , 154723 ]}Fornecemos o script de avaliação para calcular o recall@1/5/10 da tarefa de pesquisa e damos o recall médio (a média de recall@1/5/10). Execute o seguinte comando para obter a pontuação:
Para pesquisa de texto para foto, execute o comando:
split=valid # 指定计算valid或test集特征
python cn_clip/eval/evaluation.py
${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.jsonl
${DATAPATH} /datasets/ ${dataset_name} / ${split} _predictions.jsonl
output.json
cat output.jsonPara pesquisa de imagem para texto, execute o seguinte comando primeiro para converter o arquivo jsonl marcado com o formato de imagem para texto em imagem para texto:
python cn_clip/eval/transform_ir_annotation_to_tr.py
--input ${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.jsonlApós a conclusão, execute o comando:
split=valid # 指定计算valid或test集特征
python cn_clip/eval/evaluation_tr.py
${DATAPATH} /datasets/ ${dataset_name} / ${split} _texts.tr.jsonl
${DATAPATH} /datasets/ ${dataset_name} / ${split} _tr_predictions.jsonl
output.json
cat output.jsonO formato do resultado impresso será o seguinte:
{ "success" : true , "score" : 85.67 , "scoreJson" : { "score" : 85.67 , "mean_recall" : 85.67 , "r1" : 71.2 , "r5" : 90.5 , "r10" : 95.3 }}Em relação ao processo de treinamento e teste de recuperação cruzada, tomamos o conjunto de dados de pesquisa MuGE (gráficos de comércio eletrônico multimodais e desafio de texto) como exemplo, e também fornece um notebook Jupyter (link para download) que inclui todos os processos acima e pode ser executado. Todos são bem -vindos para praticá -lo.
Esta seção apresenta como usar o clipe chinês para implementar a classificação de imagem de amostra zero, levando o conjunto de dados no Benchmark Elevater como exemplo. Elevater é um conjunto de avaliação composto por vários conjuntos de dados classificados conhecidos (incluindo CIFAR-10, CIFAR-100, MNIST, etc.) para avaliar o efeito de amostra zero do modelo nesses conjuntos de dados. Em nosso experimento, preparamos uma versão chinesa do Propt, etiquetas de categoria e imagens originais para cada conjunto de dados. Consulte o documento de dados para obter detalhes para testar o modelo de clipe chinês. Para mais detalhes sobre este benchmark, clique no link. Você também pode consultar o processo que fornecemos para preparar dados e testá -los em seu próprio conjunto de dados de classificação chinesa.
Primeiro prepare os dados no seguinte formato. Como a classificação de imagem de amostra zero requer apenas testes, você só precisa preparar o conjunto de testes e os parâmetros pré-treinados do modelo e armazená-los sob o ${DATAPATH} de acordo com a seguinte estrutura de diretório:
${DATAPATH}
├── pretrained_weights/
└── datasets/
└── ${dataset_name}/
├── label_cn.txt
└── test/
├── 000/ # label id,如label个数大于10,则将其向左补零到3位数保证字典序
│ ├── image_0003.jpg # 图片样本,命名无特殊要求
│ ├── image_0005.jpg
│ └── ...
├── 001/
│ ├── image_0001.jpg
│ ├── image_0002.jpg
│ └── ...
└── 002/
├── image_0003.jpg
├── image_0005.jpg
└── ...
...
O conjunto de testes garante que os dados na pasta de teste sejam divididos de acordo com o ID correspondente ao rótulo e garante que o ID esteja em ordem de dicionário (vários dígitos acima de 10 devem ser suplementados com zeros ao label.zfill(3) , como 001, 002, etc.). label_cn.txt é um rótulo de dados com um nome de etiqueta por linha, como mostrado abaixo:
手风琴
飞机
锚
...
O ID do rótulo correspondente ao rótulo de cada linha é行号-1 , como o ID do rótulo da primeira linha, é 0, e o ID do rótulo da segunda linha é 1. Se o número total de tags for maior que 10, então o zero a três dígitos será adicionado à esquerda, por exemplo, o número de tags é 100 e o ID da TAG será 000-099 . O usuário precisa gerar a pasta correspondente para cada ID da etiqueta e colocar a amostra marcada com a etiqueta nela. Tomamos o conjunto de dados CIFAR-100 em Elevater como exemplo, clique no link para baixar os dados processados. Se você deseja tentar testar o clipe chinês em outros conjuntos de dados contidos em Elevater, consulte nossa documentação de dados.
Preparamos o script de previsão, confira run_scripts/zeroshot_eval.sh . Um exemplo de execução de um comando é o seguinte:
bash run_scripts/zeroshot_eval.sh 0
${DATAPATH} ${dataset_name}
${vision_model} ${text_model}
${ckpt_path} ${index_file}Os significados de cada parâmetro são:
0 é o ID da GPUDATAPATH Consulte a seção de preparação acima, insira o caminho correspondente de acordo com o local real.dataset_name Consulte a seção de preparação acima e digite o nome do diretório do conjunto de dados para avaliação, como cifar-100vision_model é o tipo de modelo especificado, e as opções incluem ["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "RN50", "ViT-H-14"]text_model inclui ["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"]ckpt_path é o caminho completo do CKPT pré-treinadoindex_file (Opcional, apenas a avaliação do site oficial da Elevater precisa ser especificada), consulte o documento de dados Por exemplo, se você usar o modelo pré-treinado em escala Vit-B/16 para avaliar o CIFAR-100, execute ( ${DATAPATH} precisará ser substituído de acordo com as condições reais):
bash run_scripts/zeroshot_eval.sh 0
${DATAPATH} cifar-100
ViT-B-16 RoBERTa-wwm-ext-base-chinese
${DATAPATH} /pretrained_weights/clip_cn_vit-b-16.ptRetornar o resultado imprimirá a precisão do TOP-1.
Result:
zeroshot-top1: 0.6444
No CIFAR-100, deve-se esperar que o clipe chinês da escala Vit-B/16 atinja 64,4%. Para detalhes, consulte Results.md para obter detalhes sobre nossos resultados de classificação de amostra zero para outras escalas e outros conjuntos de dados.
Ao mesmo tempo, o programa também salvará um arquivo JSON para enviar o alívio oficial. O conteúdo do arquivo JSON é o seguinte:
{ "model_name" : " CN-CLIP-ViT-B-16 " , "dataset_name" : " cifar-100 " , "num_trainable_params" : 0 , "num_params" : 188262913 , "num_visual_params" : 86192640 , "num_backbone_params" : 188262913 , "n_shot" : 0 , "rnd_seeds" : [ 123 ], "predictions" : " prediction probability tensor [size: (1, 10000, 100)] " } Isso inclui as meta informações do nome do modelo model_name , DataSet Name dataset_name , Quantidade de parâmetro total num_params , quantidade visual de parâmetros de torre num_visual_params e outros modelos, bem como o resultado da saída do modelo, ou seja, o número de labels do modelo e o tamanho é [1, 样本数, 标签个数] .
Com base em nossa API de extração de recursos integrada aos Transformers Huggingface, fornecemos uma demonstração (API de inferência hospedada) que pode simplesmente tentar a classificação de imagem zero-amostra on-line no hub de modelo Huggingface. Veja os links de demonstração para cada escala de modelo abaixo. Bem -vindo para experimentar!
Se você acha que este projeto é útil, espero que você possa nos dar uma estrela e compartilhá -lo com os usuários ao seu redor. Bem -vindo à citação de trabalho relevante, obrigado pelo seu apoio!
@article{chinese-clip,
title={Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese},
author={Yang, An and Pan, Junshu and Lin, Junyang and Men, Rui and Zhang, Yichang and Zhou, Jingren and Zhou, Chang},
journal={arXiv preprint arXiv:2211.01335},
year={2022}
}