Ceci est le référentiel de code pour notre régénération du papier ACL 2023: Classification de texte à tirs zéro via la génération de données de formation avec une récupération dense progressive.
MISE À JOUR : Vérifiez comment améliorer Regen en utilisant de grands modèles de langage dans notre récent préimprimée avec le code!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
Le corpus peut être téléchargé à:
L'ensemble de test de {Ag News, Dbpedia, Yahoo, IMDB} peut être facilement trouvé sur HuggingFace Data Hub. Les ensembles de tests pour d'autres ensembles de données peuvent être fondés dans le dossier test .
Le _id représente l'ID de classe et text est le contenu du document.
Exemple (pour l'ensemble de données SST-2):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
Nous adaptons le code de Coco-Dr pour pré-formation. Veuillez vérifier la mise en œuvre d'origine pour plus de détails.
Mis à jour le 7 septembre 2023 : Le modèle pré-entraîné a été publié sur le Huggingface:
Voir le code du dossier retrieval , gen_embedding.sh pour plus de détails.
Voir le code de retrieval/retrieve.py pour plus de détails.
Quelques hyperparamètres clés:
args.target : l'ensemble de données cible utilisé dans l'expérience.args.model : Le modèle de récupération utilisé dans cette étude.args.corpus_folder/args.corpus_name : le dossier / nom du corpus utilisé (par exemple, News, wiki) dans les expériences.args.topN : le topn utilisé dans la recherche KNN (généralement réglé sur 50-100).args.round : les tournées de récupération. Réglé sur 0 pour les premiers tours (en utilisant le nom / modèle d'étiquette pour la récupération uniquement) et 1,2, ... pour les tours ultérieurs.Remarque : En principe, notre modèle est compatible avec tous les retrievers denses (après une formation correctement). Si vous souhaitez utiliser votre propre modèle de récupération dense, assurez-vous que le modèle de récupération dense utilise également l'incorporation de jeton [CLS] comme incorporation de séquence. Sinon, vous devrez peut-être modifier le code en pièces de génération d'intégration pour vous assurer que l'intégration générée est correcte .
Voir le code du dossier filter . L'exemple de commande doit être
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Ici
folder_for_data est le dossier des données récupérées.unlabel_file_used_for_filtering est le nom de fichier des données récupérées.task est le nom de la tâche.model_type est le PLM utilisé comme discriminateur (par exemple Roberta). Voir le code du dossier classification . L'exemple de commande doit être
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Il est réalisé avec un moyen similaire à l'étape de récupération précédente. Voir le code de retrieval/retrieve.py à nouveau pour plus de détails. La seule différence est que vous devez définir la variable args.round à plus de 0 . Vous devez également définir le prev_retrieve_path_name et prev_retrieve_folder sur le chemin des documents pour les derniers résultats de récupération après le filtrage .
L'ensemble de données généré peut être trouvé sur ce lien.
Veuillez citer notre article si vous trouvez ce repo utile pour vos recherches. Merci!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}