Это кодовое репо для нашего ACL 2023 выводы Paper Regen: классификация текста с нулевым выстрелом посредством генерации обучающих данных с прогрессивным плотным поиском.
ОБНОВЛЕНИЕ : Проверьте, как улучшить реген с использованием больших языковых моделей в нашем недавнем препринте с кодом!
python 3.8
transformers==4.2.0
pytorch==1.8.0
scikit-learn
faiss-cpu==1.6.4
tqdm>=4.62.2
nltk
Корпус можно загрузить по адресу:
Тестовый набор {Ag News, Dbpedia, Yahoo, IMDB} можно легко найти в Hub Data Hub HuggingFace. Тестовые наборы для других наборов данных могут быть основаны в test папке.
_id означает идентификатор класса, а text - это содержание документа.
Пример (для набора данных SST-2):
{
{"_id": 0, "text": "It seems to me the film is about the art of ripping people off without ever letting them consciously know you have done so."}
{"_id": 0, "text": "In the end , the movie collapses on its shaky foundation despite the best efforts of director joe carnahan."}
{"_id": 1, "text": "Despite its title , punch-drunk love is never heavy-handed ."}
{"_id": 1, "text": "Though only 60 minutes long , the film is packed with information and impressions."}
...
}
Мы адаптируем код из Coco-DR для предварительной подготовки. Пожалуйста, проверьте оригинальную реализацию для получения подробной информации.
Обновлено на 7 сентября 2023 года : предварительная модель была выпущена на Huggingface:
См. Код из папки retrieval , gen_embedding.sh для получения подробной информации.
Смотрите код из retrieval/retrieve.py для получения подробной информации.
Некоторые ключевые гиперпараметры:
args.target : целевой набор данных, используемый в эксперименте.args.model : модель поиска, используемая в этом исследовании.args.corpus_folder/args.corpus_name : папка/имя используемого корпуса (например, viki) в экспериментах.args.topN : topn, используемый в поиске KNN (обычно устанавливается на 50-100).args.round : раунды поиска. Установите на 0 для первых раундов (с использованием названия/шаблона метки только для поиска) и 1,2, ... для последующих раундов.Примечание . В принципе, наша модель совместима с любыми плотными ретриверами (после надлежащего обучения). Если вы хотите использовать свою собственную плотную модель поиска, пожалуйста, убедитесь, что плотная модель поиска также использует встраивание токена [CLS] в качестве последовательности. В противном случае вам может потребоваться изменить код в частях встраивания генерации, чтобы убедиться, что сгенерированное внедрение является правильным .
Смотрите код из папки filter . Пример команда должна быть
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 inference.py --task=${task}
--unlabel_file=${unlabel_file_used_for_filtering}
--data_dir=${folder_for_data}
--cache_dir="${task}/cache" --output_dir=${output_dir} --round=${round}
--load_from_prev=1
--gpu=${gpu} --eval_batch_size=${eval_batch_size}
--max_seq_len=${max_seq_len} --auto_load=0
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Здесь
folder_for_data - это папка полученных данных.unlabel_file_used_for_filtering - это имя файла полученных данных.task - это имя задачи.model_type - это PLM, используемый в качестве дискриминатора (например, Роберта). См. Код из папки classification . Пример команда должна быть
train_cmd="CUDA_VISIBLE_DEVICES=0 python3 main.py --do_train --do_eval --task=${task}
--train_file={PATH_FOR_GENERATED_DATASET}
--dev_file={PATH_FOR_GENERATED_VALID_DATASET
--test_file={PATH_FOR_TEST_DATASET
--unlabel_file=unlabeled.json
--data_dir=../datasets/${task}-${label_per_class} --train_seed=${train_seed}
--cache_dir="../datasets/${task}-${label_per_class}/cache"
--output_dir=${output_dir}
--logging_steps=${logging_steps}
--n_gpu=${n_gpu} --num_train_epochs=6
--learning_rate=2e-5 --weight_decay=1e-8
--batch_size=32 --eval_batch_size=128
--max_seq_len=128 --auto_load=1
--model_type=${model_type}"
echo $train_cmd
eval $train_cmd
Это достигается с помощью предыдущего шага поиска. Смотрите код из retrieval/retrieve.py снова для получения подробной информации. Единственное отличие состоит в том, что вам необходимо установить переменную args.round больше, чем 0 . Вам также необходимо установить prev_retrieve_path_name и prev_retrieve_folder на путь документов для последних результатов поиска после фильтрации .
Сгенерированный набор данных можно найти по этой ссылке.
Пожалуйста, процитируйте нашу статью, если вы найдете это репо полезным для вашего исследования. Спасибо!
@inproceedings{yu2023zero,
title={ReGen: Zero-Shot Text Classification via Training Data Generation with Progressive Dense Retrieval},
author={Yu, Yue and Zhuang, Yuchen and Zhang, Rongzhi and Meng, Yu and Shen, Jiaming and Zhang, Chao},
booktitle={Findings of ACL},
year={2023}
}