本发明实施例涉及人工智能,尤其涉及一种transformer网络模型的训练方法、装置、电子设备及介质。
背景技术:
1、transformer是一种基于注意力机制的序列模型,与传统的循环神经网络(rnn)和卷积神经网络(cnn)不同,transformer仅使用自注意力机制(self-attention)来处理输入序列和输出序列,因此可以并行计算,极大地提高了计算效率。
2、随着transformer网络模型层数越来越深,参数越来越多,存在网络结构复杂,运算量大,速度慢的缺点,很难移植到低算力的终端(如嵌入式设备)中。如何在降低transformer模型参数量和计算复杂度的同时,减小模型的精度损失,是亟待解决的问题。
技术实现思路
1、本发明实施例提供一种transformer网络模型的训练方法、装置、电子设备及介质,用于解决如何在降低transformer模型参数量和计算复杂度的同时,减小模型的精度损失的问题。
2、为了解决上述技术问题,本发明是这样实现的:
3、第一方面,本发明实施例提供了一种transformer网络模型的训练方法,包括:
4、获取原始transformer网络模型;
5、采用第一训练数据集对所述原始transformer网络模型进行训练,得到教师模型;
6、对所述教师模型进行剪枝,得到学生模型;
7、采用所述教师模型和第二训练数据集对所述学生模型进行训练,其中,所述采用所述教师模型和第二训练数据集对所述学生模型进行训练包括:对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失。
8、可选的,所述对所述教师模型进行剪枝,得到学生模型,包括:
9、减少所述教师模型的编码器的层数和/或解码器的层数,得到学生模型。
10、可选的,所述对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失,包括:
11、采用两阶段蒸馏方式对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失;
12、其中,所述两阶段蒸馏方式包括:
13、在第一阶段,对所述教师模型的中间层的输出进行蒸馏,得到所述教师模型的中间层的蒸馏损失,并根据所述教师模型的中间层的蒸馏损失和所述教师模型的硬标签损失,确定所述第一阶段训练所述学生模型过程中使用的损失;
14、在第二阶段,对所述教师模型的最后输出层进行蒸馏,得到所述教师模型的最后输出层的蒸馏损失,将所述教师模型的最后输出层的蒸馏损失作为所述第二阶段练所述学生模型过程中使用的损失。
15、可选的,所述对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失,包括:
16、采用联合学习蒸馏方式对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失;
17、其中,所述联合学习蒸馏方式包括:
18、对所述教师模型的中间层的输出进行蒸馏,得到将所述教师模型的中间层的蒸馏损失;
19、对所述教师模型的最后输出层的输出进行蒸馏,得到所述教师模型的最后输出层的蒸馏损失;
20、将所述教师模型的中间层的蒸馏损失和所述教师模型的最后输出层的蒸馏损失之和,作为训练所述学生模型过程中使用的损失。
21、可选的,所述采用所述教师模型和第二训练数据集对所述学生模型进行训练包括:
22、将所述第二训练数据集输入所述教师模型,得到第一输出,将所述第一输出整除t,并计算整除后的结果的归一化指数函数,得到软目标,其中,t为蒸馏过程中使用的温度;
23、将所述第二训练数据集输入所述学生模型,得到第二输出,将所述第二输出整除t,并计算整除后的结果的归一化指数函数,根据计算后的结果和所述软目标,确定软损失,所述软损失包括所述教师模型的ctc分支软标签的损失和所述教师模型的注意力分支软标签的损失;
24、计算所述第二输出的归一化指数函数,根据计算后的结果和硬目标,确定硬损失,所述硬损失包括所述教师模型的ctc分支硬标签的损失和所述教师模型的注意力分支硬标签的损失;
25、根据所述软损失和所述硬损失,确定所述教师模型的最后输出层的蒸馏损失。
26、可选的,所述教师模型的最后输出层的蒸馏损失由以下至少一项确定:
27、所述教师模型的ctc分支硬标签的损失;
28、所述教师模型的注意力分支硬标签的损失;
29、所述教师模型的ctc分支软标签的损失;
30、所述教师模型的注意力分支软标签的损失。
31、可选的,所述教师模型的最后输出层的蒸馏损失的计算公式如下:
32、losstail=λ*(1-αctc_distill)*lossctc_hard+(1-λ)*(1-αatt_distill)*lossatt_hard+λ*αctc_distill*lossctc_distill+(1-λ)*αatt_distill*lossatt_distill
33、其中,losstail为所述教师模型的最后输出层的蒸馏损失,lossctc_hard为所述教师模型的ctc分支硬标签的损失,lossatt_hard为所述教师模型的注意力分支硬标签的损失,λ为lossctc_hard的权重,lossctc_distill为所述教师模型的ctc分支软标签的损失,lossatt_distill为所述教师模型的注意力分支软标签的损失,αctc_distill为lossctc_distill的权重。
34、可选的,所述对所述教师模型的中间层的输出进行蒸馏,包括:
35、选取所述教师模型的编码器或解码器中的多个目标中间层,所述目标中间层的层数与所述学生模型的编码器或解码器的中间层的层数相等,且一一对应;
36、将所述目标中间层的最后输出和所述目标中间层中的注意力层的输出作为标签,对所述学生模型的编码器或解码器的对应的中间层进行蒸馏训练。
37、可选的,所述选取所述教师模型的编码器或解码器中的多个目标中间层,包括:
38、针对所述教师模型的编码器或解码器,每隔k层选取一层目标中间层;其中,k=m/n,m为所述教师模型的编码器或解码器的中间层的总层数,n为所述学生模型的编码器或解码器的中间层的总层数。
39、可选的,所述教师模型的中间层的蒸馏损失由以下至少一项确定:
40、所述教师模型的编码器或解码器的中间层的总层数;
41、所述中间层中的隐藏层的参数;
42、所述中间层中的注意力层的参数。
43、可选的,所述教师模型的中间层的蒸馏损失的计算公式如下:
44、
45、其中,lossmiddle为所述教师模型的编码器或解码器的中间层的蒸馏损失,m为所述教师模型的编码器或解码器的中间层的总层数,i为所述教师模型的编码器或解码器的中间层的层数序号,ht为所述教师模型的编码器或解码器的中间层中的隐藏层的参数,hs为所述学生模型的编码器或解码器的中间层中的隐藏层的参数,attt为所述教师模型的编码器或解码器的中间层中的注意力层的参数,atts为所述学生模型的编码器或解码器的中间层中的注意力层的参数,k为所述教师模型的编码器或解码器的中间层的总层数,与所述学生模型的编码器或解码器的中间层的总层数的比值,mes为均方损失函数。
46、可选的,所述方法还包括:
47、采用训练后的所述学生模型对待推理数据进行推理,得到推理结果。
48、可选的,所述transformer网络模型为端到端语音识别模型,所述训练后的学生模型应用于以下场景中的至少之一:
49、离线环境;
50、私有化部署环境。
51、第二方面,本发明实施例提供了一种transformer网络模型的训练系统,包括:
52、第一获取模块,用于获取原始transformer网络模型;
53、第一训练模块,用于采用第一训练数据集对所述原始transformer网络模型进行训练,得到教师模型;
54、剪枝模块,用于对所述教师模型进行剪枝,得到学生模型;
55、第二训练模块,用于采用所述教师模型和第二训练数据集对所述学生模型进行训练,其中,所述采用所述教师模型和第二训练数据集对所述学生模型进行训练包括:对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失。
56、第三方面,本发明实施例提供了一种电子设备,包括:处理器、存储器及存储在所述存储器上并可在所述处理器上运行的程序,所述程序被所述处理器执行时实现如上述第一方面所述的transformer网络模型的训练方法的步骤。
57、第四方面,本发明实施例提供了一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如上述第一方面所述的transformer网络模型的训练方法的步骤。
58、在本发明实施例中,对transformer网络模型进行训练得到教师模型,并对教师模型进行剪枝和蒸馏,得到一个参数量更小、计算复杂度更低的学生模型,通过蒸馏让学生模型能够更好的学习到教师模型的知识,且学生模型精度与教师模型相当,使得transformer网络模型运行在低算力的终端(如嵌入式设备)上成为可能。
1.一种transformer网络模型的训练方法,其特征在于,包括:
2.如权利要求1所述的方法,其特征在于,所述对所述教师模型进行剪枝,得到学生模型,包括:
3.如权利要求1所述的方法,其特征在于,所述对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失,包括:
4.如权利要求1所述的方法,其特征在于,所述对所述教师模型的中间层和最后输出层的输出进行蒸馏,并根据所述中间层和最后输出层的蒸馏损失,确定训练所述学生模型过程中使用的损失,包括:
5.如权利要求3或4所述的方法,其特征在于,所述采用所述教师模型和第二训练数据集对所述学生模型进行训练包括:
6.如权利要求3或4所述的方法,其特征在于,所述教师模型的最后输出层的蒸馏损失由以下至少一项确定:
7.如权利要求6所述的方法,其特征在于,所述教师模型的最后输出层的蒸馏损失的计算公式如下:
8.如权利要求3或4所述的方法,其特征在于,所述对所述教师模型的中间层的输出进行蒸馏,包括:
9.如权利要求8所述的方法,其特征在于,所述选取所述教师模型的编码器或解码器中的多个目标中间层,包括:
10.如权利要求3或4所述的方法,其特征在于,所述教师模型的中间层的蒸馏损失由以下至少一项确定:
11.如权利要求10所述的方法,其特征在于,所述教师模型的中间层的蒸馏损失的计算公式如下:
12.如权利要求1所述的方法,其特征在于,还包括:
13.如权利要求12所述的方法,其特征在于,所述transformer网络模型为端到端语音识别模型,所述训练后的学生模型应用于以下场景中的至少之一:
14.一种transformer网络模型的训练系统,其特征在于,包括:
15.一种电子设备,其特征在于,包括:处理器、存储器及存储在所述存储器上并可在所述处理器上运行的程序,所述程序被所述处理器执行时实现如权利要求1至13中任一项所述的transformer网络模型的训练方法的步骤。
16.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至13中任一项所述的transformer网络模型的训练方法的步骤。
