ReGen
1.0.0
這是我們ACL 2023調查結果紙的代碼存儲庫:通過培訓數據生成,零彈性文本分類,並進行漸進式密集檢索。
更新:結帳如何使用大型語言模型在我們最近使用代碼的預印本中使用大型語言模型來改善重新恢復!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
語料庫可以下載:
{Ag News,Dbpedia,Yahoo,IMDB}的測試集可以在HuggingFace Data Hub上輕鬆找到。其他數據集的測試集可以在test文件夾中建立。
_id代表類ID, text是文檔的內容。
示例(對於SST-2數據集):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
我們將Coco-DR的代碼適應預處理。請檢查原始實施以獲取詳細信息。
在2023年9月7日更新:審慎的模型已在擁抱面上發布:
有關詳細信息,請參見retrieval文件gen_embedding.sh中的代碼。
有關詳細信息,請參見retrieval/retrieve.py的代碼。
一些關鍵的超參數:
args.target :實驗中使用的目標數據集。args.model :本研究中使用的檢索模型。args.corpus_folder/args.corpus_name :實驗中使用的語料庫的文件夾/名稱(例如新聞,wiki)。args.topN :KNN搜索中使用的頂部(通常設置為50-100)。args.round :檢索回合。將第一輪設置為0(僅使用標籤名稱/模板用於檢索)和1,2,...以後的回合。注意:原則上,我們的模型與任何密集的檢索器兼容(經過適當的培訓)。如果您想使用自己的密集檢索模型,請確保密集的檢索模型還將[Cls]令牌的嵌入作為序列嵌入。否則,您可能需要在嵌入生成零件中修改代碼,以確保生成的嵌入是正確的。
從filter ”文件夾中查看代碼。示例命令應該是
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
這裡
folder_for_data是檢索到的數據的文件夾。unlabel_file_used_for_filtering是檢索到的數據的文件名。task是任務的名稱。model_type是用作鑑別器(例如Roberta)的PLM。從classification文件夾中查看代碼。示例命令應該是
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
它的實現方式與以前的檢索步驟相似。有關詳細信息,請參閱retrieval/retrieve.py的代碼。唯一的區別是,您需要將變量args.round設置為大於0 。您還需要將prev_retrieve_path_name和prev_retrieve_folder設置為文檔的路徑,以進行過濾後的最新檢索結果。
生成的數據集可以在此鏈接上找到。
如果您發現此回購對您的研究有用,請請我們的論文。謝謝!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}