该存储库包含源代码和经过训练的模型,用于大规模的对话响应生成模型。人类评估结果表明,在单弯对话测试下,对话产生的响应与人类反应质量相当。
该存储库基于HuggingFace Pytorch-Transformer和OpenAI GPT-2,其中包含数据提取脚本,型号训练代码以及预处理的小型(117m)培养基(345m)和大型(762m)模型检查点。
该模型经过Reddit讨论线程的1.47亿多转向对话的培训。最大的型号可以在几个小时内使用8 V100机器进行培训(但这不是必需的),并具有分布式培训和FP16选项。
包含脚本可用于重现DSTC-7接地对话生成挑战的结果,以及从Reddit数据创建的6K多参考数据集。
项目网页:https://www.microsoft.com/en-us/research/project/large-scale-pretraining-for-response-generation/
arxiv纸:https://arxiv.org/abs/1911.00536
(更新07/09/2022)更改files.pushshift.io/reddit服务器导致我们的数据生成管道破裂。这些问题现在已经解决,以下数据准备小节中解释的步骤应再次起作用。数据在大约10个小时内使用8个过程( -j 8 )生成,并且需要800GB的临时磁盘空间。
(更新06/23/2021)我们发布了Dialogpt(retgen)的检索效果/接地版本,请查看Retgen Repo和Retgen Paper
(更新05/20/2021)Prakhar Mishra在YouTube上进行对话的很棒的视频演练
(更新03/31/2021)AK391使用Gradio Web演示试验的第三方演示
(更新09/15/2020)已发布了一组大型对话框排名模型!
通过与我们的最新对话排名模型集成,对话的生成得到了改善
(更新07/08/2020)已发布6K多REF测试集!
为了生成数据,请求运行demo.py并将数据选项设置为“完整”,生成的6k多REF测试集将位于
./data/test.refs.txt
(更新03/10/2020)HuggingFace Transformers中可用的型号卡!
请在HuggingFace Transformers存储库中查看我们的型号卡。有了几行代码,与对话进行交互式播放应该非常直截了当。
小型型号:https://huggingface.co/microsoft/dialogpt-small
中型型号:https://huggingface.co/microsoft/dialogpt-medium
大型型号:https://huggingface.co/microsoft/dialogpt-large
(新)排名模型:https://huggingface.co/microsoft/dialogrpt-updown
(更新01/06/2020)一些第三方解码脚本实现:
Dialogpt完全是在Ubuntu 16.04上开发的,并且 - 取决于我们的可用性 - 如果您在同一配置上运行代码的困难,我们会尝试提供支持。但是,我们无法为其他分布或操作系统提供支持。代码的部分可能会在其他UNIX口味(MACOS,Windows子系统,linux,Cygwin等)上运行,但建议将Ubuntu用于主培训代码。
训练代码可以在CPU上运行,但可能会很慢。我们建议使用GPU训练和Finetune所有型号。 GPU的数量没有最小的限制。但是,如果使用分布式列车进行多种GPU配置,则加速与GPU的数量大致是亚线性的。要在使用较少的GPU时模拟相同的批次化,请在模型培训中使用较大的gradient_accumulation_steps 。
1,17m和345m的模型可以加载在具有12G内存的单个GPU中。 762m的模型将需要一个大于16G内存的GPU才能进行有效的训练。具有50m培训实例和V100 GPU的基准数据的训练速度:
| n_gpu | 时期时间(h) | 令牌/秒 |
|---|---|---|
| 1 | 118 | 10847 |
| 2 | 62 | 20645 |
| 4 | 34 | 37647 |
| 8 | 18 | 71356 |
我们在新数据集上进行的验证模型进行微调通常需要1-2个时代。
我们创建了一个演示脚本demo.py ,以减轻该系统部署的困难。 demo.py包含一个模型下载,数据提取,数据预处理和模型培训的管道,对一个命令行中的虚拟数据集进行了培训。
请使用以下命令来克隆,安装要求并加载Conda环境(请注意,需要NVIDIA CUDA 10.0开发人员工具包):
sudo apt-get install -y make wget gzip bzip2 xz-utils zstd sedgit clone https://github.com/microsoft/DialoGPT.git
cd DialoGPT
conda env create -f LSP-linux.yml -n LSP
conda activate LSP如果您在Linux以外的其他架构上运行此操作,请使用LSP-generic.yml代替LSP-linux.yml ,但请注意,通用型号并未在所有平台中测试,因此无法稳定性。要使用FP16培训,请通过以下命令安装APEX
conda activate LSP
git clone https://github.com/NVIDIA/apex
cd apex
git reset --hard 3d01e4a0a188cc8df54bc6e44cf5eb40ff6b4cc5
pip install -v --no-cache-dir --global-option= " --cpp_ext " --global-option= " --cuda_ext " .
python3.6 demo.py首先,首先从他们的官方存储店中安装Docker和Nvidia-Docker。运行代码的图像环境可以如下加载:
Nvidia-Docker V2。 *
$ docker run --gpus all --ipc=host --rm -it -v $PWD :/workspace --network=host icaruszyz/large-scale-training:dialogpt bashNvidia-Docker V1。 *
$ nvidia-docker --rm -it -v $PWD :/workspace --network=host icaruszyz/large-scale-training:dialogpt bash在Docker容器中,运行
python demo.py本节解释了demo.py中的所有组件。
在运行demo.py之前,您可以在demo.py中设置data_folder (默认值./models ),作为要下载所有数据和预审计/微调模型的位置。然后简单地运行
python demo.py到
请注意,默认情况下, demo.py将使用虚拟数据,请使用选项--data指定REDDIT培训数据。提供三个选项: dummy , small而full 。
python demo.py --data small
python demo.py --data full小的REDDIT数据约为140MB,完整的REDDIT数据大于27GB。使用完整的Reddit数据处理时,您可以准备一杯咖啡,因为这需要很长时间!
为了生成6K多REF测试集数据,请求运行demo.py并将数据选项设置为“完整”,该一代将位于
./data/test.refs.txt
预处理和微调的模型可在Azure Blobstorage上使用。有关如何下载/使用这些型号的更多详细信息,请运行/查看demo.py或者,您可以使用demo_utils.py中的链接直接下载。
首先,使用prepare4db.sh将TSV数据文件转换为以下脚本可以识别的正确格式。然后需要将Trainig数据处理到数据库文件中,并使用下面的命令行:
python prepro.py --corpus $DATA_PATH 该培训脚本可用于单个GPU或多个GPU设置(单个节点中的多个GPU的分布式培训):
python ./LSP_train.py # Single GPU training
python -m torch.distributed.launch --nproc_per_node=8 ./LSP_train.py # Training on 8 GPUs培训脚本接受几个论点来调整培训:
| 争论 | 类型 | 默认值 | 描述 |
|---|---|---|---|
| max_seq_length | int | 128 | 每个培训实例的最大令牌数量。 |
| train_input_file | str | "" | 以.db格式的培训数据集的路径 |
| eval_input_file | str | "" | 以TSV格式设置验证的路径 |
| 继续_从 | int | 0 | 在指定数量的步骤后恢复培训 |
| FP16 | boolean | True | 是否使用16位浮点进行模型训练。 |
| train_batch_size | int | 4 | 培训的批量尺寸 |
| 有效_batch_size | int | 4 | 验证批量尺寸 |
| gradient_accumulation_steps | int | 2 | 在几个步骤上积累梯度 |
| Learning_rate | float | 1e-5 | 学习率 |
| lr_schedule | str | noam | 可以从[ noam , noamwd , BERT , None ]中选择学习率的时间表 |
| num_optim_steps | int | 1000000 | 训练优化步骤数量 |
| no_token_id | boolean | True | 如果设置为true,则使用全零令牌式嵌入。 |
在培训期间,将更新两个日志文件。 train_log.txt和eval_log.txt包含培训和开发设置的模型丢失,困惑和训练速度(令牌/秒)统计。
可以在./models/output_model中找到日志文件和保存的模型检查点
我们注意到,即使使用正确过滤的Reddit数据集,有时我们的模型仍然可以产生中等有毒/不适当的响应。由于这个原因,我们目前无法提供解码脚本(实时演示和解码脚本访问仅受邀请才)。目前,我们仍在研究一种受控的解码方法,以防止该系统脱离有毒产生。请继续关注。
有关第三方解码方法的一些讨论,请参见第3期和Reddit讨论。
有关一些第三方解码方法,请参见下文:
我们发布了6个微调模型,可以在低资源用户注定的数据集中进行进一步的微调。与OpenAI GPT-2模型大小一致,这些模型中的总参数范围为1.17亿至762m。
| 模型 | 从GPT-2进行微调 | 从头开始训练 |
|---|---|---|
| Dialogpt 762M型号 | [链接] [Huggingface型号卡] | [关联] |
| 对话345m型号 | [链接] [Huggingface型号卡] | [关联] |
| Dialogpt 1.17亿型号 | [链接] [Huggingface型号卡] | [关联] |
| 对话345m型号(反向,用于MMI) | 关联 | - |
| 拨号(新排名模型) | 关联 | - |
可以将模型文件完全加载为GPT-2模型检查点,从HuggingFace的变压器中加载。您可以在Dialogpt的repo in ./configs/*中找到相应的配置文件( merges.txt , config.json , vocab.json )。
反向模型正在从目标中预测源。该模型用于MMI重新疗程。
我们最近提出的排名模型用于预测响应的人类反馈(upvots,答复)。这些模型可用于提高对话生成质量(有关详细信息,请参见我们的EMNLP论文)。
重新训练完整模型的第一步是生成上述27GB REDDIT数据集。这涉及从https://files.pushshift.io/reddit下载完整的reddit提交和注释转储并创建中间文件,总体上需要700GB的本地磁盘空间。下载和处理完整数据需要大约1-2天,具体取决于您的(CPU)计算(例如,在最近的计算机上使用8个内核〜24小时)。假设您运行了上述设置和安装步骤(Conda激活LSP等),则可以通过运行以下来创建完整的数据集:
python demo.py --data full
或者
cd reddit_extractor; SIZE=full make -j 8; cd ..
前命令称之为后者,因此两种方法是等效的。我们建议前者,因为后者遇到任何问题或想自定义任何参数(例如, make Command允许您仅构建数据的一个子集),因此最有用。请注意,下载阶段可能是错误的,例如基于您的地理位置(防火墙等)。如果以上命令无法生成data/train.tsv ,或者该文件不接近27GB,则意味着出现问题。在这种情况下,您可能需要检查reddit_extractor/wget-log和reddit_extractor/logs/*.log是否有任何明显的错误(例如,wget无法从pusphshift.io下载)。如果错误消息对您没有意义,请随时与我们联系。如果是这样,请确保包括从这些日志文件收集的任何错误消息。
培训数据统计信息:生成的培训TSV文件应大约26.8 GB未压缩,具有146.80万的培训实例,3.87b源代币和2.14B目标令牌(包括Tusterance-level 0/1重量)。最终的train.tsv文件应包含146,846,215行。
我们建议使用demo.py --data full生成上述数据,因为它(1)生成数据,(2)将其转换为DB格式,并且(3)使用python LSP_train.py训练模型。如果要自定义任何超参数,请直接编辑demo.py
我们的模型实现了DSTC-7挑战响应生成任务的最新结果。
| 实验 | NIST2 | NIST4 | bleu2 | bleu4 | 流星 | ENT-4 | DIST-1 | DIST-2 | avg。伦 |
|---|---|---|---|---|---|---|---|---|---|
| 人类反应 | 2.62 | 2.65 | 12.35% | 3.13% | 8.31% | 10.45 | 16.66% | 67.01% | 18.8 |
| DSTC-7获奖者 | 2.51 | 2.52 | 14.35% | 1.83% | 8.07% | 9.03 | 10.89% | 32.49% | 15.1 |
| 对话345m | 2.80 | 2.82 | 14.16% | 2.31% | 8.51% | 10.08 | 9.13% | 39.73% | 16.9 |
| 对话345m(BS) | 2.92 | 2.97 | 19.18% | 6.05% | 9.29% | 9.57 | 15.73% | 51.03% | 14.2 |
ent表示熵分数,而DIST表示独特的分数。对于除平均长度以外的所有指标,更大的时间更好。
请注意,与人类反应相比,上级自动评估并不是必需的,这意味着我们的模型实现了人类的平价。请查看我们的论文以进行更多详细的分析。
要微调DSTC-7挑战数据上的345M对话框模型,该数据具有8 V100 GPU的服务器,请运行以下命令行(可以在DSTC-7 Repo上找到DSTC数据):
python3 -m torch.distributed.launch --nproc_per_node=8 train_LSP.py --init_checkpoint ./models/medium/medium_ft.pkl --train_input_file ./data/DSTC_train.db --eval_input_file ./data/DSTC_valid.tsv --model_name_or_path ./model/medium/ --learning_rate 1e-4 --train_batch_size 64 --eval_batch_size 64 --no_token_id训练有素的模型可以在DSTC培养基模型上找到
请下载以下第三方软件包,然后保存到空文件夹3rdparty :
cpan install ):XML:twig,stort:sort:自然和字符串:util。请关注DSTC-7官方存储库以提取数据,并将data-official-test/test.refs.txt放入./dstc/data/文件夹中。
在下面运行提取脚本以产生人类响应假设文件human.resp.txt :
python extract_human.py最后,为了重现DSTC数据集上人类假设的结果,请在repo文件夹下运行以下命令:
python batch_eval.py评估结果将在文件夹中生成./dstc/eval/
我们从Reddit上测试了6K多REF数据集。结果总结在下面
| 实验 | NIST2 | NIST4 | bleu2 | bleu4 | 流星 | ENT-4 | DIST-1 | DIST-2 | avg。伦 |
|---|---|---|---|---|---|---|---|---|---|
| 人类反应 | 3.41 | 4.25 | 17.90% | 7.48% | 10.64% | 11 | 14.50% | 63.00% | 13.1 |
| 对话1.17亿 | 2.39 | 2.41 | 10.54% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% | 12.8 |
| 对话345m | 3 | 3.06 | 16.96% | 4.56% | 9.81% | 9.13 | 6.80% | 26.30% | 12.2 |
| 对话762m | 2.84 | 2.9 | 18.66% | 5.25% | 9.66% | 9.72 | 7.76% | 29.93% | 11.2 |
| 对话345m(BS) | 3.4 | 3.5 | 21.76% | 7.92% | 10.74% | 10.48 | 12.38% | 48.74% | 11.3 |
| 对话345m(w/mmi) | 3.28 | 3.33 | 15.68% | 3.94% | 11.23% | 11.25 | 9.39% | 45.55% | 17.2 |
我们进一步进行人类评估(每种方法的6K示例,每个示例均由3名人类法官评估)。结果表明,在这种非相互作用的图灵测试下,我们的一代质量是要接近实际人类反应的质量:
相关性:A和B,这与源提示更相关。
| 系统 | 胜利(%) | 领带(%) | B赢(%) | 系统b |
|---|---|---|---|---|
| 对话345m | 2671(45%) | 513(9%) | 2816(47%) | 人类的反应 |
| 对话345m | 3281(72%) | 394(9%) | 882(19%) | 人格chat |
| 对话345m w/ mmi | 2871 (48%) | 522(9%) | 2607(43%) | 人类的反应 |
信息性:A和B,这更满足和信息丰富。
| 系统 | 胜利(%) | 领带(%) | B赢(%) | 系统b |
|---|---|---|---|---|
| 对话345m | 2722(45%) | 234(4%) | 3044(51%) | 人类的反应 |
| 对话345m | 3490(77%) | 206(5%) | 861(19%) | 人格chat |
| 对话345m w/ mmi | 3011 (50%) | 234(4%) | 2755(46%) | 人类的反应 |
类似人:A和B,您认为这更可能是由人类产生的。
| 系统 | 胜利(%) | 领带(%) | B赢(%) | 系统b |
|---|---|---|---|---|
| 对话345m | 2716(45%) | 263(4%) | 3021(50%) | 人类的反应 |
| 对话345m | 3462(76%) | 196(4%) | 899(20%) | 人格chat |
| 对话345m w/ mmi | 2978 (50%) | 241(4%) | 2781(46%) | 人类的反应 |
请在我们的Arxiv论文中查看全部详细信息。
| 角色 | 回复 |
|---|---|
| 用户 | 金钱会买幸福吗? |
| 机器人 | 取决于您花多少钱。 |
| 用户 | 购买幸福的最佳方法是什么? |
| 机器人 | 您只需要在20多岁之前成为百万富翁,然后您就可以开心。 |
| 用户 | 这是如此困难! |
| 机器人 | 您不知道成为百万富翁和快乐有多困难。有钱有很多钱是有原因的 |
| 角色 | 回复 |
|---|---|
| 用户 | 生活的意义是什么? |
| 机器人 | 意思是与您的家人和朋友在一起。 |
| 机器人 | 我要猜测:这意味着您的父母和朋友爱您。我相信这就是意思。 |
| 机器人 | 我认为意思是确保您死后还活着。因此,您将与家人,朋友和您自己在一起。 |
| 机器人 | 不是,但感谢您的帮助。 |
| 用户 | 机器人 |
|---|---|
| 谁是美国第一任总统 | 乔治华盛顿 |
| 水的沸点是什么? | 我认为大约是212F。 |
| 哪一个更大,太阳还是月亮? | 太阳。 |
| 哪种动物有黑白条纹? | 黑白条纹猫。 |
| 用户 | 机器人 |
|---|---|
| 中国与美国之间的贸易战争仍在发生。 | 我确定中国政府和U。 s。只要他们将手臂放在桌子上和钱流动,就可以了。 |
| 谁在2018年赢得了世界杯? | 我相信德国。 |
| NVIDIA的Titan RTX真的很好。 | 我有这张卡,我喜欢它。没有投诉。 |
| 保时捷可以用新的Taycan EV击败特斯拉吗? | 我不这么认为。特斯拉比保时捷更强大。 |
| 用户 | 机器人 |
|---|---|
| 美好生活的意义是什么? | 我认为这是我们所有人都与我们的前世联系,生活的意义就是按照您想要的方式和想要的方式生活。 |
| 哪个更重要 - 成功还是快乐? | 成功是无尽的。快乐是唯一重要的事情。 |
| 和平是制止战争的唯一方法吗? | 如果世界不愿意与您合作。和平是最终目标,但和平不仅需要这一目标。 |
| 人们应该有多少自由? | 只要您不伤害人或财产,适当的自由就是随意做的自由。 |
| 用户 | 机器人 |
|---|---|
| 您如何看待对话回购? | 很好,但是很容易找到错误。 |
如果您发现任何问题,请开始问题:)
实时演示和解码脚本访问才受邀请。请关注完整版本。
retgen:https://github.com/dreasysnail/retgen。检索仪式/接地对话及其他。 RETGEN是一个联合培训框架,同时以端到端的方式优化了密集的通道检索器和知识接地的文本生成器。
Microsoft Icecaps:https://github.com/microsoft/icecaps。
作为该项目的正交存储库,Microsoft Icecaps是用于构建神经对话系统的开源工具包(在TensorFlow中)。 ICECAPS在灵活的范式中提供了最新的对话建模和一般NLP文献的一系列工具,该范式可以实现复杂的多任务学习设置。
预处理的Unilm:https://github.com/microsoft/unilm
MT-DNN:https://github.com/namisan/mt-dnn
中国对话的扬jianxin1。 https://github.com/yangjianxin1/gpt2-chitchat。我们很高兴看到我们在Dialogpt中使用的MMI策略也改善了该项目的性能!
如果您有任何疑问/建议,请联系[email protected]。但是,响应将是零星的。请期待延迟。
该项目欢迎贡献和建议。大多数捐款要求您同意撰写贡献者许可协议(CLA),宣布您有权并实际上授予我们使用您的贡献的权利。有关详细信息,请访问https://cla.opensource.microsoft.com。
当您提交拉动请求时,CLA机器人将自动确定您是否需要提供CLA并适当装饰PR(例如状态检查,评论)。只需按照机器人提供的说明即可。您只需要使用我们的CLA在所有存储库中进行一次。
该项目采用了Microsoft开源的行为代码。有关更多信息,请参见《行为守则常见问题守则》或与其他问题或评论联系[email protected]。
该存储库旨在促进对会话数据进行大规模预处理的研究。该工具包仅包含在运行对话框中实际生成模型权重文件所需的建模机械。该模型本身仅提供有关各种文本跨度的权重的信息;为了使研究人员实际使用它,他们将需要自行携带对话数据并从验证的系统中解码响应生成。微软对验证系统的第三方利用率中的任何一代都不承担任何责任。
如果您在研究中使用此代码,则可以引用我们的Arxiv论文:
@inproceedings{zhang2019dialogpt,
title={DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation},
author={Yizhe Zhang and Siqi Sun and Michel Galley and Yen-Chun Chen and Chris Brockett and Xiang Gao and Jianfeng Gao and Jingjing Liu and Bill Dolan},
year={2020},
booktitle={ACL, system demonstration}
}