一种基于网络深度压缩的模型蒸馏方法、装置和介质与流程

专利查询2023-7-13  128



1.本发明涉及模型压缩领域,尤其涉及一种基于网络深度压缩的模型蒸馏方法、装置和介质。


背景技术:

2.迄今为止,深度学习已经作为机器学习的主流分支广泛应用于各行各业。但是,大多数深度模型的高复杂度、昂贵计算成本,使其不仅难以在移动端或嵌入设备运行,且难以应用于实时任务中。如何有效压缩模型,使其减少时间计算成本的同时,保证模型精度损失最小化,在近年来得到了广泛关注。
3.知识蒸馏作为模型压缩的重要技术之一,自2015年hinton首次提出迄今,已获得广泛研究,同时已广泛应用于工业界。知识蒸馏的主要思想是先训练一个复杂网络模型,利用复杂网络引导小网络训练,使小网络的模型效果接近复杂网络。知识蒸馏方法主要可以划分为基于类别(logits)蒸馏、基于特征蒸馏蒸馏两大类,分别表示类间关系的知识蒸馏、网络中间特征的知识蒸馏。这些方法主要针对老师模型和学生模型的相近深度的网络层实现信息蒸馏,大多未在蒸馏过程中考虑网络深度对模型精度的促进,未从如何让网络浅层学习更深层次信息入手。


技术实现要素:

4.为了解决现有技术中存在的深度模型时间计算成本大、网络复杂度高,难以运行于移动端或嵌入设备的问题,且目前常用的知识蒸馏方法大都未考虑网络深度对模型精度的提升,主要对老师模型和学生模型相近深度的特征或logits进行知识蒸馏,对此,本发明提出了一种基于网络深度压缩的模型蒸馏方法、装置和介质,其具体技术方案如下:一种基于网络深度压缩的模型蒸馏方法,包括以下步骤:步骤s1:构建一类分类数据集,采用交叉熵损失函数预训练复杂教师模型,对该分类数据集的分类数据进行分类。
5.骤s2:计算分类数据经过复杂教师模型后获得的最后一层的特征距离,构建简单学生模型,计算分类数据经过简单学生模型后获得的不同网络层的特征距离;步骤s3:对经过复杂教师模型获得的最后一层的特征距离和经过简单学生模型获得的不同网络层的特征距离,利用斯皮尔曼公式匹配特征相关性,并以此构建特征关联损失函数;步骤s4:特征关联损失函数叠加交叉熵损失函数,训练简单学生模型。
6.进一步的,所述步骤s1,具体为:给定具有不同类别的样本分类数据,获取复杂教师模型网络映射函数和网络参数,输入不同类别的样本分类数据至复杂教师模型,利用交叉熵损失函数训练复杂教师模型网络,对样本分类数据进行分类。
7.进一步的,所述交叉熵损失函数训练复杂教师模型网络,表达式为:
其中为的独热编码形式,为样本,为属性分类标签,复杂教师模型网络测试结果为,其中为复杂教师模型的网络参数,为复杂教师模型网络映射函数。
8.进一步的,所述步骤s2,具体为:将样本分类数据即样本图像传入复杂教师模型网络,利用复杂教师模型网络的特征映射函数和网络参数,得到复杂教师模型网络每一层的特征结果;将样本图像传入简单学生模型网络,利用简单学生模型网络的特征映射函数和网络参数,得到简单学生模型网络每一层的特征结果;然后对多个样本图像的网络中间层特征,利用余弦相似度计算不同样本之间的特征距离,即可得到复杂教师模型网络最后一层的特征距离和简单学生模型网络不同网络层的特征距离。
9.进一步的,所述步骤s3,具体为:基于不同样本图像之间特征距离大小的排名,结合斯皮尔曼相关公式构建特征关联损失函数,表达式为其中表示特征关联损失函数,表示不同样本间特征距离从小到大的排名,
ꢀ‑
1表示网络最后一层的层系数,表示复杂教师模型网络最后一层样本特征距离,表示简单学生模型网络第层特征距离,表示输入网络样本数。
10.进一步的,所述步骤s4,具体为:关联损失函数和交叉熵损失函数,训练简单学生模型:其中表示训练简单学生模型网络所用的损失函数,其为交叉熵损失函数和特征关联损失函数的和,表示权重参数。
11.一种基于网络深度压缩的模型蒸馏装置,包括一个或多个处理器,用于实现所述的基于网络深度压缩的模型蒸馏方法。
12.一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现所述的基于网络深度压缩的模型蒸馏方法。
13.本发明的优点:相较于大部分已有知识蒸馏方法主要应用于复杂教师模型网络和简单学生模型网络相近深度的网络层,本发明方法考虑深度对网络提升的影响,直接进行浅层向深层的蒸馏学习,且本发明方法实现方法简便,效果提升显著,并可以与已有知识蒸馏方法同时使
用提升效果。
附图说明
14.图1为本发明方法整体流程示意图;图2为本发明提供的一种基于网络深度压缩的模型蒸馏装置的结构框图。
具体实施方式
15.为了使本发明的目的、技术方案和技术效果更加清楚明白,以下结合说明书附图和实施例,对本发明作进一步详细说明。
16.本发明的一种基于深度压缩的模型蒸馏方法,使用余弦距离计算不同数据经过模型后的特征关系,基于斯皮尔曼相关公式构建损失函数,将简单模型不同层特征关系与复杂模型最后一层的特征关系进行匹配,并以此构建损失函数,引导简单模型不同层的数据特征关系向复杂模型最深层的数据特征关系靠近,使简单模型的浅层学习到更深层特征信息,从而实现网络深度的压缩。
17.具体的,包括如下步骤:步骤s1:构建一类分类数据集,采用交叉熵损失函数预训练复杂教师模型,对该分类数据集的分类数据进行分类。
18.具体的,给定具有不同类别的样本分类数据,获取复杂教师模型网络映射函数和网络参数,输入不同类别的样本分类数据至复杂教师模型,利用交叉熵损失函数训练复杂教师模型网络,对样本分类数据进行分类。
19.例如,构建鸟类分类数据集共1.2万张,采用交叉熵损失函数预训练复杂教师模型,对200个类别的鸟类数据进行分类。
20.详细的:给定个样本和个类别的鸟类数据,为样本,为属性分类标签,复杂教师模型网络测试结果为,其中为复杂教师模型的网络参数,为复杂教师模型网络映射函数;利用如下交叉熵损失函数训练教师模型:其中为的独热编码形式。
21.骤s2:计算分类数据经过复杂教师模型后获得的最后一层的特征距离,构建简单学生模型,计算分类数据经过简单学生模型后获得的不同网络层的特征距离。
22.具体的,将样本分类数据即样本图像传入复杂教师模型网络,利用复杂教师模型网络的特征映射函数和网络参数,得到复杂教师模型网络每一层的特征结果;将样本图像传入简单学生模型网络,利用简单学生模型网络的特征映射函数和网络参数,得到简单学生模型网络每一层的特征结果;然后对多个样本图像的网络中间层特征,利用余弦相似度计算不同样本之间的特征距离,即可得到复杂教师模型网络最后一层的特征距离和简单学生模型网络不同网络层
的特征距离。
23.详细的:单个鸟类图像传入复杂教师模型网络,每一层的特征结果为,其中为复杂教师模型网络前层的特征映射函数,为复杂教师模型网络前层网络参数;简单学生模型网络每一层特征结果为,其中为简单学生模型网络前层的特征映射函数,为简单学生模型网络前层参数;对多个样本图像的网络中间层特征,利用余弦相似度计算不同样本图像之间的特征距离:其中表示批数据传入复杂教师模型网络或简单学生模型网络第层的特征矩阵的拉伸结果,表示复杂教师模型网络或简单学生模型网络第层的样本间特征距离;假定输入数据数目为个,则为行列,其中表示特征像素数,的维度为行列,描述个样本相互之间的特征距离。
24.步骤s3:如图1所示,对经过复杂教师模型获得的最后一层的特征距离和经过简单学生模型获得的不同网络层的特征距离,利用斯皮尔曼公式匹配特征相关性,并以此构建特征关联损失函数。
25.详细的:基于不同鸟类样本之间特征距离大小的排名,结合斯皮尔曼相关公式构建特征关联损失函数:其中表示特征关联损失函数,表示不同样本间特征距离从小到大的排名,
ꢀ‑
1表示网络最后一层的层系数,表示复杂教师模型网络最后一层样本特征距离,表示简单学生模型网络第层特征距离,表示输入网络样本数。
26.步骤s4:特征关联损失函数叠加交叉熵损失函数,训练简单学生模型,引导简单学生模型不同层的数据特征关系向复杂教师模型最深层的数据特征关系靠近,使简单学生模型的浅层学习到更深层特征信息,从而实现网络深度的压缩。
27.详细的:关联损失函数和交叉熵损失函数,训练简单学生模型:其中表示训练简单学生模型网络所用的损失函数,其为交叉熵损失函数和特
征关联损失函数的和,表示权重参数。
28.本发明的方法使用pytorch框架进行实验,为更明显地展示本发明方法的效果,特选择复杂老师模型和简单学生模型精度差异大的分类问题进行蒸馏操作:cub鸟类数据上选择5类1次(5-way 1-shot)的少样本分类问题,复杂老师模型和简单学生模型的基础结构分别采用conv6和conv4结构,分类方式采用prototypical networks中的特征匹配模式。使用初始学习率为0.001的adam优化器,图像尺寸为84*84,训练600次epoch,训练过程中支持集(support set)每类1个(5-way),查询集(query set)每类16个(5-way)。本发明所有实验均从零开始训练。
29.具体的,如下表1所示,各方法在cub鸟类数据集上5-way 1-shot训练的测试结果:本发明方法分别展示conv4训练、conv6训练、conv4使用kd蒸馏训练(conv4+kd)、conv4第一层与conv6最后一层特征关系匹配(conv4+fr1)、conv4第二层与conv6最后一层特征关系匹配(conv4+fr2)、conv4最后一层与conv6最后一层特征关系匹配(conv4+fr-1)、conv4所有层与conv6最后一层特征关系匹配(conv4+frall)、conv4所有层与conv6最后一层特征关系匹配结合kd蒸馏(conv4+frall+kd)的训练结果。
30.表1:与前述基于深度压缩的模型蒸馏方法的实施例相对应,本发明还提供了基于深度压缩的模型蒸馏装置的实施例。
31.参见图2,本发明实施例提供的一种基于深度压缩的模型蒸馏装置,包括一个或多个处理器,用于实现上述实施例中的基于深度压缩的模型蒸馏方法。
32.本发明基于深度压缩的模型蒸馏装置的实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或装置。装置实施例可以通过软件实现,也可以通过硬件或者软硬件结合的方式实现。以软件实现为例,作为一个逻辑意义上的装置,是通过其所在任意具备数据处理能力的设备的处理器将非易失性存储器中对应的计算机程序指令读取到内存中运行形成的。从硬件层面而言,如图2所示,为本发明基于深度压缩的模型蒸馏装置所在任意具备数据处理能力的设备的一种硬件结构图,除了图2所示的处理器、内存、网络接口、以及非易失性存储器之外,实施例中装置所在的任意具备数据处理能力的设备通常根据该任意具备数据处理能力的设备的实际功能,还可以包括其他硬件,对此不再赘述。
33.上述装置中各个单元的功能和作用的实现过程具体详见上述方法中对应步骤的实现过程,在此不再赘述。
34.对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以
不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本发明方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
35.本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例中的基于多突触可塑性脉冲神经网络快速记忆编码方法。
36.所述计算机可读存储介质可以是前述任一实施例所述的任意具备数据处理能力的设备的内部存储单元,例如硬盘或内存。所述计算机可读存储介质也可以是风力发电机的外部存储设备,例如所述设备上配备的插接式硬盘、智能存储卡(smart media card,smc)、sd卡、闪存卡(flash card)等。进一步的,所述计算机可读存储介质还可以既包括任意具备数据处理能力的设备的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算机程序以及所述任意具备数据处理能力的设备所需的其他程序和数据,还可以用于暂时地存储已经输出或者将要输出的数据。
37.以上所述,仅为本发明的优选实施案例,并非对本发明做任何形式上的限制。虽然前文对本发明的实施过程进行了详细说明,对于熟悉本领域的人员来说,其依然可以对前述各实例记载的技术方案进行修改,或者对其中部分技术特征进行同等替换。凡在本发明精神和原则之内所做修改、同等替换等,均应包含在本发明的保护范围之内。

技术特征:
表示网络最后一层的层系数,表示复杂教师模型网络最后一层样本特征距离,表示简单学生模型网络第层特征距离,表示输入网络样本数。6.如权利要求5所述的一种基于网络深度压缩的模型蒸馏方法,其特征在于,所述步骤s4,具体为:关联损失函数和交叉熵损失函数,训练简单学生模型:其中表示训练简单学生模型网络所用的损失函数,其为交叉熵损失函数和特征关联损失函数的和,表示权重参数。7.一种基于网络深度压缩的模型蒸馏装置,其特征在于,包括一个或多个处理器,用于实现权利要求1-6中任一项所述的基于网络深度压缩的模型蒸馏方法。8.一种计算机可读存储介质,其特征在于,其上存储有程序,该程序被处理器执行时,实现权利要求1-6中任一项所述的基于网络深度压缩的模型蒸馏方法。

技术总结
本发明涉及模型压缩领域,尤其涉及一种基于网络深度压缩的模型蒸馏方法、装置和介质,该方法使用余弦距离计算不同数据经过模型后的特征关系,基于斯皮尔曼相关公式构建损失函数,将简单模型不同层特征关系与复杂模型最后一层的特征关系进行匹配,并以此构建损失函数,引导简单模型不同层的数据特征关系向复杂模型最深层的数据特征关系靠近,使简单模型的浅层学习到更深层特征信息,从而实现网络深度的压缩。相较于大部分已有知识蒸馏方法主要应用于教师网络和学生网络相近深度的网络层,本发明方法考虑深度对网络提升的影响,直接进行浅层向深层的蒸馏学习,且本发明方法实现方法简便,效果提升显著,并可以与已有知识蒸馏方法同时使用提升效果。法同时使用提升效果。法同时使用提升效果。


技术研发人员:苏慧 程乐超 杨非 鲍虎军
受保护的技术使用者:之江实验室
技术研发日:2022.02.10
技术公布日:2022/3/8

最新回复(0)