Este es el repositorio de código para nuestro Regen de documento de hallazgos de ACL 2023: clasificación de texto de disparo cero a través de la generación de datos de capacitación con recuperación progresiva densa.
ACTUALIZACIÓN : ¡Consulte cómo mejorar la regen utilizando modelos de idiomas grandes en nuestra preimpresión reciente con código!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
El corpus se puede descargar en:
El conjunto de pruebas de {AG News, Dbpedia, Yahoo, IMDB} se puede encontrar fácilmente en Huggingface Data Hub. Los conjuntos de pruebas para otros conjuntos de datos se pueden fundar en la carpeta test .
El _id significa la ID de clase, y text es el contenido del documento.
Ejemplo (para el conjunto de datos SST-2):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
Adaptamos el código de COCO-DR para el pretrénmente. Consulte la implementación original para obtener más detalles.
Actualizado el 7 de septiembre de 2023 : el modelo previamente se ha lanzado en la cara de Hugging:
Vea el código de la carpeta retrieval , gen_embedding.sh para más detalles.
Consulte el código de retrieval/retrieve.py para más detalles.
Algunos hiperparámetros clave:
args.target : el conjunto de datos de destino utilizado en el experimento.args.model : El modelo de recuperación utilizado en este estudio.args.corpus_folder/args.corpus_name : la carpeta/nombre del corpus utilizado (por ejemplo, noticias, wiki) en los experimentos.args.topN : El TOPN utilizado en KNN Search (generalmente establecido en 50-100).args.round : Las rondas de recuperación. Establecer en 0 para las primeras rondas (usando el nombre/plantilla de la etiqueta solo para recuperación) y 1,2, ... para rondas posteriores.NOTA : En principio, nuestro modelo es compatible con cualquier retriever denso (después de un entrenamiento adecuado). Si desea utilizar su propio modelo de recuperación densa, asegúrese de que el modelo de recuperación densa también use la incrustación del token [CLS] como incrustaciones de secuencia. De lo contrario, es posible que deba modificar el código en la incrustación de piezas de generación para asegurarse de que la incrustación generada sea correcta .
Vea el código desde la carpeta filter . El comando de ejemplo debe ser
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Aquí
folder_for_data es la carpeta de los datos recuperados.unlabel_file_used_for_filtering es el nombre de archivo de los datos recuperados.task es el nombre de la tarea.model_type es el PLM utilizado como discriminador (por ejemplo, Roberta). Vea el código de la carpeta classification . El comando de ejemplo debe ser
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Se logra de manera similar al paso de recuperación anterior. Consulte el código de retrieval/retrieve.py nuevamente para más detalles. La única diferencia es que necesita establecer la variable args.round en mayor que 0 . También debe establecer el prev_retrieve_path_name y prev_retrieve_folder en la ruta de los documentos para los últimos resultados de recuperación después del filtrado .
El conjunto de datos generado se puede encontrar en este enlace.
Por favor, cita nuestro documento si encuentra este repositorio útil para su investigación. ¡Gracias!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}