Codebase for RetroMAE and beyond.
We have uploaded some checkpoints to the Hugging Face Hub.
| Model | Description | Link |
|---|---|---|
| RetroMAE | Pre-trained on Wikipedia and BookCorpus | Shitao/RetroMAE |
| RetroMAE_MSMARCO | Pre-trained on the MSMARCO passage data | Shitao/RetroMAE_MSMARCO |
| RetroMAE_MSMARCO_finetune | RetroMAE_MSMARCO fine-tuned on the MSMARCO passage data | Shitao/RetroMAE_MSMARCO_finetune |
| RetroMAE_MSMARCO_distill | RetroMAE_MSMARCO fine-tuned on the MSMARCO passage data by minimizing the KL-divergence with a cross-encoder | Shitao/RetroMAE_MSMARCO_distill |
| RetroMAE_BEIR | RetroMAE fine-tuned on the MSMARCO passage data for BEIR (using the official negatives provided by BEIR) | Shitao/RetroMAE_BEIR |
You can load them easily using their identifier strings. For example:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained('Shitao/RetroMAE')
```

RetroMAE provides a strong initialization for dense retrievers: after fine-tuning with in-domain data, it yields high-quality supervised retrieval performance in the corresponding scenario. Besides, it substantially improves the pre-trained model's transferability, leading to superior zero-shot performance on out-of-domain datasets.
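A minimal usage sketch for retrieval-style scoring follows. It assumes (as in the paper) that the [CLS] hidden state serves as the text embedding; the `encode` helper is illustrative, not an API of this repo:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Shitao/RetroMAE')
model = AutoModel.from_pretrained('Shitao/RetroMAE').eval()

@torch.no_grad()
def encode(texts):
    # Tokenize, run the encoder, and take the [CLS] hidden state as the embedding.
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
    return model(**inputs).last_hidden_state[:, 0]

query = encode(['what is masked auto-encoding?'])
passages = encode(['Masked auto-encoding reconstructs a corrupted input from a compressed representation.',
                   'The weather was pleasant over the weekend.'])
print(query @ passages.T)  # dot-product relevance scores, shape [1, 2]
```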
Supervised retrieval performance on the MSMARCO passage dataset:

| Model | MRR@10 | Recall@1000 |
|---|---|---|
| Bert | 0.346 | 0.964 |
| RetroMAE | 0.382 | 0.981 |

Comparison with other retrieval-oriented pre-training methods on the MSMARCO passage dataset:

| Model | MRR@10 | Recall@1000 |
|---|---|---|
| coCondenser | 0.382 | 0.984 |
| RetroMAE | 0.393 | 0.985 |
| RetroMAE(distillation) | 0.416 | 0.988 |

Zero-shot retrieval performance on the BEIR benchmark (average NDCG@10 over 18 datasets):

| Model | Avg NDCG@10 (18 datasets) |
|---|---|
| Bert | 0.371 |
| Condenser | 0.407 |
| RetroMAE | 0.452 |
| RetroMAE v2 | 0.491 |
```bash
git clone https://github.com/staoxiao/RetroMAE.git
cd RetroMAE
pip install .
```

For development, install as editable:

```bash
pip install -e .
```
This repo covers two stages: pre-training and fine-tuning. First, pre-train RetroMAE on a general dataset (or the downstream dataset) with the masked language modeling loss. Then fine-tune RetroMAE on the downstream dataset with a contrastive loss. To achieve better performance, you can also fine-tune RetroMAE by distilling the scores provided by a cross-encoder. For the detailed workflow, please refer to our examples.
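For intuition, here is a minimal, self-contained sketch of the two fine-tuning objectives described above: an in-batch contrastive loss and a KL-divergence distillation loss against cross-encoder scores. This is an illustration only, not the repo's implementation; the function names, tensor shapes, and temperature values are assumptions.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(q_emb, p_emb, temperature=0.05):
    """In-batch contrastive loss: the i-th query matches the i-th passage,
    and all other passages in the batch act as negatives."""
    scores = q_emb @ p_emb.T / temperature          # [batch, batch] similarity matrix
    labels = torch.arange(q_emb.size(0), device=q_emb.device)
    return F.cross_entropy(scores, labels)

def distill_loss(student_scores, teacher_scores, temperature=1.0):
    """KL-divergence between the bi-encoder's score distribution and the
    cross-encoder's score distribution over the same candidate passages."""
    student_log_probs = F.log_softmax(student_scores / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_scores / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

# Toy usage with random embeddings/scores:
q = F.normalize(torch.randn(8, 768), dim=-1)   # 8 query embeddings
p = F.normalize(torch.randn(8, 768), dim=-1)   # 8 positive passage embeddings
print(contrastive_loss(q, p))
print(distill_loss(torch.randn(8, 16), torch.randn(8, 16)))  # 16 candidates per query
```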
Pre-training:

```bash
torchrun --nproc_per_node 8 \
  -m pretrain.run \
  --output_dir {path to save ckpt} \
  --data_dir {your data} \
  --do_train True \
  --model_name_or_path bert-base-uncased \
  --pretrain_method {retromae or dupmae}
```
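For intuition about the `retromae` pre-training method, the sketch below illustrates the masked auto-encoding idea in a heavily simplified form: the encoder compresses a moderately masked input into its [CLS] embedding, and a shallow decoder must reconstruct an aggressively masked copy of the input from that embedding. This is a conceptual sketch only; the mask ratios are illustrative, and it omits the enhanced decoding (two-stream attention with position-specific attention masks) described in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
encoder = BertModel.from_pretrained('bert-base-uncased')
hidden = encoder.config.hidden_size

# Shallow decoder and an MLM head for reconstructing the aggressively masked copy.
decoder = nn.TransformerEncoderLayer(d_model=hidden, nhead=12, batch_first=True)
mlm_head = nn.Linear(hidden, tokenizer.vocab_size)

def random_mask(input_ids, ratio):
    """Replace a random subset of ordinary tokens with [MASK]."""
    ids = input_ids.clone()
    special = (ids == tokenizer.cls_token_id) | (ids == tokenizer.sep_token_id) | (ids == tokenizer.pad_token_id)
    mask = (torch.rand(ids.shape) < ratio) & ~special
    ids[mask] = tokenizer.mask_token_id
    return ids, mask

batch = tokenizer(['RetroMAE pre-trains a retrieval-oriented encoder with masked auto-encoding.'],
                  return_tensors='pt')
enc_ids, _ = random_mask(batch['input_ids'], ratio=0.3)         # moderate masking for the encoder
dec_ids, dec_mask = random_mask(batch['input_ids'], ratio=0.5)  # aggressive masking for the decoder

# The encoder's [CLS] hidden state is the sentence embedding the decoder must rely on.
sentence_emb = encoder(input_ids=enc_ids, attention_mask=batch['attention_mask']).last_hidden_state[:, :1]

# Decoder input: sentence embedding in the [CLS] slot + embeddings of the masked copy.
dec_inputs = encoder.embeddings(input_ids=dec_ids)
dec_inputs = torch.cat([sentence_emb, dec_inputs[:, 1:]], dim=1)

logits = mlm_head(decoder(dec_inputs))
loss = F.cross_entropy(logits[dec_mask], batch['input_ids'][dec_mask])
print(loss)
```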
Fine-tuning the bi-encoder:

```bash
torchrun --nproc_per_node 8 \
  -m bi_encoder.run \
  --output_dir {path to save ckpt} \
  --model_name_or_path Shitao/RetroMAE \
  --do_train \
  --corpus_file ./data/BertTokenizer_data/corpus \
  --train_query_file ./data/BertTokenizer_data/train_query \
  --train_qrels ./data/BertTokenizer_data/train_qrels.txt \
  --neg_file ./data/train_negs.tsv
```
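After fine-tuning, the saved bi-encoder can be used to embed a corpus and run dense retrieval. Below is a minimal sketch with faiss; the index type, the `{path to save ckpt}` placeholder, and the use of the [CLS] vector as the embedding mirror the commands above and are illustrative rather than the repo's own evaluation pipeline (see our examples for that).

```python
import faiss
import torch
from transformers import AutoModel, AutoTokenizer

ckpt = '{path to save ckpt}'  # the --output_dir used for fine-tuning above
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt).eval()

@torch.no_grad()
def encode(texts):
    # Embed texts with the [CLS] hidden state of the bi-encoder.
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
    return model(**inputs).last_hidden_state[:, 0].numpy().astype('float32')

passages = ['first passage ...', 'second passage ...']
index = faiss.IndexFlatIP(model.config.hidden_size)  # exact inner-product (dot) search
index.add(encode(passages))

scores, ids = index.search(encode(['example query']), 2)  # retrieve the top-2 passages
print(list(zip(ids[0].tolist(), scores[0].tolist())))
```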
If you find our work helpful, please consider citing us:
```
@inproceedings{RetroMAE,
  title={RetroMAE: Pre-Training Retrieval-oriented Language Models Via Masked Auto-Encoder},
  author={Shitao Xiao and Zheng Liu and Yingxia Shao and Zhao Cao},
  url={https://arxiv.org/abs/2205.12035},
  booktitle={EMNLP},
  year={2022},
}
```