您现在的位置是:主页 > 语言相关 >
动手做个DialoGPT:基于LM的生成式多轮对话模型
2021-01-05 17:50:57语言相关 65614人已围观
前段时间刷Arixv的时候,发现清华大学开源了一个大规模的中文闲聊语料库LCCC(论文链接,项目地址),从开源的文件上来看,这可能是目前开源的数量最大、质量最好的闲聊语料库了,而且还包含了部分多轮对话聊天,总的来说可玩性还是蛮强的。笔者也被它吸引到了,尝试着用它来训练了一个闲聊对话模型,结果看上去还是不错的,在此分享一下自己的经验。
语料简介 #
这里简单介绍一下LCCC这个数据集(Large-scale Cleaned Chinese Conversation),具体细节大家可以去Github上看,下载链接也在上面。LCCC分base和large两个版本,base主要是来源于微博对话,large则是在base的基础上融合了其他开源对话语料,按照作者的说法,LCCC经过了严格的清洗过程,所以整体质量看上去还是很不错的。
为了简化任务,所有样本都被处理成双人对话。下面是一些样本示例:
A: 等过年咱们回去买点兔头好好吃顿火锅
B: 太原就没看见有好吃的兔头
A: 我从虹桥给你带个回去那天瞅到一正宗的
B: 最爱你了
A: 那是必须A: 嗯嗯,我再等等!你现在在上海吧?上海风好像比南京还大呢,少出门吧
B: 对啊,我在家,没事儿。一定要小心啊!A: 我去年也去转了一圈,还碰见以前的体育老师了,合了个影
B: 哈哈我还去找高一时侯的英语老师没找到她刚好有事情没在学校~
A: 你也是真心找回忆了哦
B: 哈哈毕业了没去过想去看看啊
模型设计 #
知道了数据长什么样之后,我们接下来就要去设计模型了。显然,我们需要做的就是训练一个模型,预测下一个该回复什么。既然语料里包含了多轮对话,那么我们还要求这个模型支持多轮对话。考虑对话历史的最简单的方式,就是把直到当前句的所有历史对话都拼接成单句文本,来作为模型的输入信息了。
给定一些输入,预测一个输出,从形式上来看我们应该用Seq2Seq模型。直接用Seq2Seq其实问题也不大,但标准的Seq2Seq一般用于形式比较固定的输入输出,比如输入的文本长度应该是集中在某个范围内,不宜变化太大,但考虑多轮对话的话,理论上我们也不知道前面有多少轮对话,因此原则上输入文本长度是无限制的。用Seq2Seq的话还有训练效率低的问题,就是我们每轮对话每次我们只能训练一句回复,如果一个多轮对话有nn句回复,那么那么就要拆分为nn个样本来训练了。
因此,我们需要一个长度能相当自由地变化的、同时能预测整一个多轮对话的模型,实现这个需求的比较适当的选择就是单向语言模型(LM、GPT),做法如下图:
如图所示,我们选择当前主流的Transformer模型,按照BERT的常规输入格式,将每句对话用[SEP]拼接起来,然后就训练一个从左往右的单向语言模型。为了区分不同的说话角色,我们对不同的说话者用不同的Segment Id区分。此外,考虑到BERT和GPT都是用了绝对位置编码,可处理的文本长度存在一个上限,而对话轮数理论上是无限的,所以这里我们采用了相对位置编码的NEZHA作为基本结构,并使用NEZHA的预训练权重作为模型的初始化权重。
说白了,就是往NEZHA里边加入了下三角形式的Attention Mask,使其变为一个语言模型,相关介绍请参考《从语言模型到Seq2Seq:Transformer如戏,全靠Mask》。
训练细节 #
首先,给出笔者的参考实现以及训练好的模型:
代码基于bert4keras实现,基本结构是NEZHA Base。为了使得训练更加“平民化”,笔者只用了一张卡(22G的RTX)配合梯度累积来训练,显存不够大的读者可以缩小batch_size。开源的模型是用2×10?52×10?5的学习率、0.010.01的权重衰减率的Adam优化器进行训练的,训练了1000万步(20多天),单步batch_size为16,累积梯度步数为16,所以相当于以256的batch_size训练了62.5万步,最终的loss大概是2.83(平均每token的交叉熵)。
我们来测试一下我们的训练结果。
>>> print(chatbot.response([u"我最近天天吃火锅"]))
火锅好吃啊
>>> print(chatbot.response([u"我最近天天吃火锅"]))
我也是
>>> print(chatbot.response([u"我最近天天吃火锅"]))
哈哈哈哈哈哈我最近一周吃了两顿火锅
>>> print(chatbot.response([u"我最近天天吃火锅"]))
我也是,每次都吃的撑
>>> print(chatbot.response([u"我最近天天吃火锅"]))
吃火锅不长胖
>>> print(chatbot.response([u"我最近天天吃火锅"]))
火锅好啊
>>> print(chatbot.response([u"我最近天天吃火锅"]))
你们那边有这种火锅?
>>> print(chatbot.response([u"我最近天天吃火锅"]))
火锅好吃吗?我也想吃
>>> print(chatbot.response([u"我最近天天吃火锅"]))
吃吧,吃了就开心了
>>> print(chatbot.response([u"我最近天天吃火锅"]))
我也想天天吃火锅!
>>> print(chatbot.response([u"我最近天天吃火锅"]))
火锅吃不了多久的。>>> print(chatbot.response([u"你那天气好吗"]))
还好,不热,有点冷
>>> print(chatbot.response([u"你那天气好吗", u"还好,不热,有点冷", u"那你得穿多厚的衣服"]))
我穿了两件
对比分析 #
CDial-GPT也开源了自己训练的预训练模型,笔者也将它转换为bert4keras能加载的格式了(CDial-GPT-tf),读者也可以测试比对一下。从训练上来看,CDial-GPT使用pytorch实现的模型,基本结构是GPT Base,使用了4张2080Ti,总batch_size为32,累积梯度64步,论文说训练了30个epoch,总步数约2100万步(笔者的两倍),因此大概相当于batch_size为2048训练了33万步。
在输入设计上,CDial-GPT也有所不同,如下图:
如图所示,CDial-GPT跟我们前述设计的主要不同是多轮对话之间的拼接方式,我们之前是直接用[SEP]连接,它是用[speaker1]、[speaker2](图中简记为S1、S2)这样的角色标记来连接,最后才用一个[SEP]表示回复结束。这样一来,由于预测部分的格式跟历史的格式不一样,因此每次只能训练一句回复,多轮对话要拆分为多个样本来训练,理论上是增加了训练复杂性的(要训练多步才能把一个多轮对话样本训练完)。
至于效果上,个人测试的感觉是两者没什么明显差别。有兴趣的读者也可以自行比较测试。
文章总结 #
本文主要分享了一次对话模型实践,基于CDial-GPT开源的LCCC闲聊语料库,利用语言模型(GPT)对多轮对话进行生成式建模,得到了一个相对通用的闲聊对话模型,最后将本文的思路与CDial-GPT本身开源的模型进行了比较。
下一篇:Mysql优化学习笔记
随机图文
-
游戏智能对战领域让电脑智能化地对战玩家或其他电脑
I. 简介 游戏智能对战是指让电脑智能化地对战玩家或其他电脑的应用。随着人工智能技术的不断发展,游戏智能对战已经成为了游戏领域的一个热门话题。本文将从游戏智能对战的定义和背景以及应用场景入手,深入探讨游戏智能对战的技术原理、应用案例和未来发展。 II. 游戏智能对战的技术原理 游戏智能对战的基本原理是通过人工智能技术,让电脑能够像人一样思考和决策,从而实现与玩家或其他电脑的对战。游戏智能对战的技 -
文生视频领域下智能字幕生成-实现自动字幕生成、智能调整等功能
I. 简介 随着文生视频领域的不断发展,智能字幕生成技术的应用越来越广泛。智能字幕生成技术可以帮助视频制作者快速生成字幕,提高视频的可读性和可搜索性,同时也可以为听力障碍者提供更好的观看体验。本文将介绍智能字幕生成技术的实现原理、应用场景以及未来发展趋势。 II. 自动字幕生成的实现 自动字幕生成技术是指通过计算机程序自动将视频中的语音转换成文字,并将文字显示在视频中。自动字幕生成技术的实现 -
短视频的彷徨与退让:积蓄力量,正待春暖花开
内容加密 -
AI换脸领域下智能换脸-提供自然语言交互,实现换脸功能
I. 概述 AI换脸技术是一种基于人工智能的图像处理技术,它可以将一个人的脸部特征转移到另一个人的脸上,从而实现换脸效果。随着人工智能技术的不断发展,AI换脸技术也越来越成熟,应用场景也越来越广泛。自然语言交互技术在AI换脸中的应用也越来越受到关注,它可以为用户提供更加便捷、自然的交互方式,实现更加智能化的换脸体验。 II. 自然语言交互 自然语言交互是一种基于自然语言的人机交互方式,它可以
猜你喜欢
站点信息
- 文章统计: 442 篇文章
- 微信公众号:扫描二维码,关注我们