ReGen
1.0.0
这是我们ACL 2023调查结果纸的代码存储库:通过培训数据生成,零弹性文本分类,并进行渐进式密集检索。
更新:结帐如何使用大型语言模型在我们最近使用代码的预印本中使用大型语言模型来改善重新恢复!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
语料库可以下载:
{Ag News,Dbpedia,Yahoo,IMDB}的测试集可以在HuggingFace Data Hub上轻松找到。其他数据集的测试集可以在test文件夹中建立。
_id代表类ID, text是文档的内容。
示例(对于SST-2数据集):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
我们将Coco-DR的代码适应预处理。请检查原始实施以获取详细信息。
在2023年9月7日更新:审慎的模型已在拥抱面上发布:
有关详细信息,请参见retrieval文件gen_embedding.sh中的代码。
有关详细信息,请参见retrieval/retrieve.py的代码。
一些关键的超参数:
args.target :实验中使用的目标数据集。args.model :本研究中使用的检索模型。args.corpus_folder/args.corpus_name :实验中使用的语料库的文件夹/名称(例如新闻,wiki)。args.topN :KNN搜索中使用的顶部(通常设置为50-100)。args.round :检索回合。将第一轮设置为0(仅使用标签名称/模板用于检索)和1,2,...以后的回合。注意:原则上,我们的模型与任何密集的检索器兼容(经过适当的培训)。如果您想使用自己的密集检索模型,请确保密集的检索模型还将[Cls]令牌的嵌入作为序列嵌入。否则,您可能需要在嵌入生成零件中修改代码,以确保生成的嵌入是正确的。
从filter ”文件夹中查看代码。示例命令应该是
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
这里
folder_for_data是检索到的数据的文件夹。unlabel_file_used_for_filtering是检索到的数据的文件名。task是任务的名称。model_type是用作鉴别器(例如Roberta)的PLM。从classification文件夹中查看代码。示例命令应该是
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
它的实现方式与以前的检索步骤相似。有关详细信息,请参阅retrieval/retrieve.py的代码。唯一的区别是,您需要将变量args.round设置为大于0 。您还需要将prev_retrieve_path_name和prev_retrieve_folder设置为文档的路径,以进行过滤后的最新检索结果。
生成的数据集可以在此链接上找到。
如果您发现此回购对您的研究有用,请请我们的论文。谢谢!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}