Ce référentiel met en œuvre l'augmentation des cas de commutation et la récupération négative dure à partir de l'article "améliorant l'apprentissage contrastif des intégres de phrases avec des positifs et des négatifs récupérés de cas". La combinaison des deux approches avec SIMCSE mène au modèle appelé apprentissage contrastif avec des données augmentées et récupérées pour l'intégration de la phrase (cartes).
Tableau 1. Exemple de phrases d'échantillons à commutation de cas et récupérées.
| Taper | Phrase |
|---|---|
| Original | L'histoire du premier livre se poursuit. |
| À commutation de cas | L'histoire du premier livre se poursuit. |
| Récupéré | L'histoire commence comme une histoire d'amour typique. |
| Aléatoire | Ceci est considéré comme un résultat temporaire. |
Tableau 2. Performance sur les tâches d'intégration des phrases
| Pré-formation | Réglage fin | STS12 | STS13 | STS14 | STS15 | STS16 | STSB | Malade-r | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| base de Roberta | Cartes simcse + | 72.65 | 84.26 | 76,52 | 82.98 | 82.73 | 82.04 | 70,66 | 78.83 |
| plus grand | Cartes simcse + | 74.63 | 86.27 | 79.25 | 85,93 | 83.17 | 83.86 | 72.77 | 80.84 |
Lien de téléchargement: cartes-Roberta-base (téléchargement, 440 Mo), cartes-Roberta-large (téléchargement, 1,23 Go).
Tableau 3. Performance sur les tâches de colle
| Pré-formation | Réglage fin | Mnli-m | QQP | QNLI | SST-2 | Cola | STS-B | MRPC | Rte | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| Debertav2-xxlarge | R-drop + switch-case | 92.0 | 93.0 | 96.3 | 97.2 | 75,5 | 93.6 | 93.9 | 94.2 | 91.7 |
Ce repo est construit sur la base des transformateurs HuggingFace et SimCSE. Voir exigences.txt pour les versions de package.
# 1. Download wiki-1m dataset:
# - use wget -P target_folder in data/datasets/download_wiki.sh, and run
bash data/datasets/download_wiki.sh
# - modify train_file in scripts/bert/run_simcse_pretraining_v2.sh
# 2. preprocess wiki-1m dataset for negative retrieval
# - deduplicate the wiki-1m dataset, and (optionally) remove sentences with less than three words
# - modify paths in data/datasets/simcse_utils.py then run it to get model representations for all sentences in dataset
python data/datasets/simcse_utils.py
# 3. Download SentEval evaluation data:
# - use wget -P target_folder in data/datasets/download_senteval.sh, and run
bash data/datasets/download_senteval.shAvant d'exécuter le code, l'utilisateur peut devoir modifier le point de contrôle du modèle par défaut et les chemins d'E / S, y compris:
scripts/bert/run_simcse_grid.sh : ligne 42-50 (Train_file, train_file_dedupl (facultatif), output_dir, tensorboard_dir, sent_rep_cache_file, senteval_data_dir)scripts/bert/run_simcse_pretraining.sh : ligne 17-20 (Train_file, output_dir, tensorboard_dir, Senteval_data_dir), ligne 45 (send_rep_cache_files), ligne 166-213 (Model_Name_Or_Path, Config_Name). # MUST cd to the folder which contains data/, examples/, models/, scripts/, training/ and utils/
cd YOUR_CARDS_WORKING_DIRECTORY
# roberta-base
new_train_file=path_to_wiki1m
sent_rep_cache_file=path_to_sentence_representation_file # generated by data/datasets/simcse_utils.py
# run a model with a single set of hyper-parameters
# when running the model for the very first time, need to add overwrite_cache=True, this will produce a processed training data cache.
bash scripts/bert/run_simcse_grid.sh
model_type=roberta model_size=base
cuda=0,1,2,3 seed=42 learning_rate=4e-5
new_train_file= ${new_train_file} sent_rep_cache_file= ${sent_rep_cache_file}
dyn_knn=65 sample_k=1 knn_metric=cos
switch_case_probability=0.05 switch_case_method=v2
print_only=False
# grid-search on hyper-parameters
bash scripts/bert/run_simcse_grid.sh
model_type=roberta model_size=base
cuda=0,1,2,3 seed=42 learning_rate=1e-5,2e-5,4e-5
new_train_file= ${new_train_file} sent_rep_cache_file= ${sent_rep_cache_file}
dyn_knn=0,9,65 sample_k=1 knn_metric=cos
switch_case_probability=0,0.05,0.1,0.15 switch_case_method=v2
print_only=False
# roberta-large
bash scripts/bert/run_simcse_grid.sh
model_type=roberta model_size=large
cuda=0,1,2,3 seed=42 learning_rate=7.5e-6
new_train_file= ${new_train_file} sent_rep_cache_file= ${sent_rep_cache_file}
dyn_knn=9 sample_k=1 knn_metric=cos
switch_case_probability=0.1 switch_case_method=v1
print_only=False # provide train_file, output_dir, tensorboard_dir if different to the default values
model_name=name_of_saved_mdoel # e.g., roberta_large_bs128x4_lr2e-5_switchcase0.1_v2
bash ./scripts/bert/run_simcse_pretraining.sh
model_name_or_path= ${output_dir} / ${model_name} model_name= ${model_name} config_name= ${output_dir} / ${model_name} /config.json
train_file= ${train_file} output_dir= ${output_dir} /test_only tensorboard_dir= ${tensorboard_dir}
model_type=roberta model_size=base do_train=False
cuda=0 ngpu=1Pour des raisons inconnues, l'ensemble des bons hyper-paramètres de modèle était différent lors de la travail avec les transformateurs HuggingFace V4.11.3 et v4.15.0. Les hyper-paramètres énumérés ci-dessus ont été recherchés sur les transformateurs 4.11.3.
@inproceedings{cards,
title = "Improving Contrastive Learning of Sentence Embeddings with Case-Augmented Positives and Retrieved Negatives",
author = "Wei Wang and Liangzhu Ge and Jingqiao Zhang and Cheng Yang",
booktitle = "The 45th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR)",
year = "2022"
}