Este é o repositório de código para o nosso ACL 2023 Conchtings Paper Regen: Classificação de texto com tiro zero via geração de dados de treinamento com recuperação de densa progressiva.
ATUALIZAÇÃO : check -out como melhorar o Regen usando grandes modelos de idiomas em nossa recente pré -impressão com código!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
O corpus pode ser baixado em:
O conjunto de testes de {Ag News, dbpedia, yahoo, imdb} pode ser facilmente encontrado no hub de dados do HuggingFace. Os conjuntos de testes para outros conjuntos de dados podem ser fundados na pasta test .
O _id significa o ID da classe e text é o conteúdo do documento.
Exemplo (para o conjunto de dados SST-2):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
Adaptamos o código do Coco-DR para pré-treinamento. Verifique a implementação original para obter detalhes.
Atualizado em 7 de setembro de 2023 : O modelo pré -treinado foi lançado no Huggingface:
Consulte o código da pasta retrieval , gen_embedding.sh para obter detalhes.
Consulte o código da retrieval/retrieve.py para obter detalhes.
Alguns hiperparâmetros importantes:
args.target : o conjunto de dados de destino usado no experimento.args.model : O modelo de recuperação usado neste estudo.args.corpus_folder/args.corpus_name : a pasta/nome do corpus usado (por exemplo, notícias, wiki) nos experimentos.args.topN : o topn usado na pesquisa de KNN (geralmente definido como 50-100).args.round : As rodadas de recuperação. Defina como 0 para as primeiras rodadas (usando o nome/modelo da etiqueta apenas para recuperação) e 1,2, ... para rodadas posteriores.Nota : Em princípio, nosso modelo é compatível com qualquer retrievers denso (após o treinamento adequado). Se você deseja usar seu próprio modelo denso de recuperação, verifique se o modelo de densidade de recuperação também usa a incorporação do token [CLS] como incorporação de sequência. Caso contrário, pode ser necessário modificar o código na incorporação de peças de geração para garantir que a incorporação gerada esteja correta .
Consulte o código da pasta filter . O comando de exemplo deve ser
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Aqui
folder_for_data é a pasta dos dados recuperados.unlabel_file_used_for_filtering é o nome do arquivo dos dados recuperados.task é o nome da tarefa.model_type é o PLM usado como discriminador (por exemplo, Roberta). Consulte o código da pasta classification . O comando de exemplo deve ser
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
É alcançado com uma maneira semelhante à etapa de recuperação anterior. Consulte o código da retrieval/retrieve.py novamente para obter detalhes. A única diferença é que você precisa definir a variável args.round como maior que 0 . Você também precisa definir o prev_retrieve_path_name e prev_retrieve_folder no caminho dos documentos para obter os resultados mais recentes da recuperação após a filtragem .
O conjunto de dados gerado pode ser encontrado neste link.
Por favor, cite nosso artigo se achar esse repositório útil para sua pesquisa. Obrigado!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}