一种基于强化学习的联邦学习鲁棒性聚合方法及系统

专利查询1月前  27


本发明涉及联邦学习,尤其涉及一种基于强化学习的联邦学习鲁棒性聚合方法及系统。


背景技术:

1、联邦学习是分布式机器学习的新范式,旨在多方在不共享本地数据的情况下协作提升人工智能模型效果。然而,在分布式场景下,各个参与方的数据规模和数据质量往往参差不一,每个参与方节点训练出的本地模型可能与最终的全局模型存在显著的偏差。同时,在联邦学习的实际应用中,各个几点可能受限于客户端状态、网络条件等物理设备等因素,参与方节点的模型参数在训练或传输过程中可能受到损害,从而影响了全局模型的性能。

2、此外,针对分布式系统的拜占庭攻击,是联邦学习系统安全的一大威胁。恶意的参与者可以通过直接修改从本地设备发送到协调服务器的本地模型参数来故意操纵训练过程,即模型中毒攻击,或通过改变局部训练集中的样本来使局部模型产生偏差,通过改变个体节点的数据,导致个体节点的训练效果较差,而现有技术通常对各个节点进行均匀聚合,难以消除较差的个体节点的影响。


技术实现思路

1、鉴于此,本发明实施例提供了一种基于强化学习的联邦学习鲁棒性聚合方法,以消除或改善现有技术中存在的一个或更多个缺陷。

2、本发明的一个方面提供了一种基于强化学习的联邦学习鲁棒性聚合方法,本方法应用于联邦网络,所述联邦网络包括服务端节点和客户端节点,该方法包括以下步骤:

3、服务端节点接收客户端节点在完成本地训练后上传的模型参数;

4、将每个客户端节点上传的模型参数构建为初始参数向量,将服务端节点当前的模型参数构建为本轮参数向量,基于所述初始参数向量、当前的本轮参数向量和客户端节点当前的权重值计算几何中值向量;

5、计算每个客户端节点当前的初始参数向量与几何中值向量的距离,并构建为状态向量输入到预设的强化学习模型中,所述强化学习模型对应每个客户端节点输出更新的权重值;

6、基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量计算模型参数向量,所述模型参数向量中各个维度的值均为模型参数。

7、采用上述方案,本方案首先通过各个节点上传的模型参数,计算全部节点的几何中值向量,通过所述强化学习模型对应每个客户端节点输出更新的权重值,在通过更新的权重值计算更新的模型参数,则当单个节点遭遇攻击时,上传的模型参数存在较大误差,则模型参数与几何中值向量的距离较远,则更新的权重值也会相对其他客户端节点变小,则对最终计算的影响也变小,则本方案能够消除较差的个体节点的影响。

8、在本发明的一些实施方式中,在基于所述初始参数向量、当前的本轮参数向量和客户端节点当前的权重值计算几何中值向量的步骤中,

9、基于每个客户端节点对应的初始参数向量和权重值,以及本轮参数向量计算每个客户端节点的过渡权重;

10、基于所述本轮参数向量和每个客户端节点的过渡权重计算几何中值向量。

11、在本发明的一些实施方式中,在基于每个客户端节点对应的初始参数向量和权重值,以及本轮参数向量计算每个客户端节点的过渡权重的步骤中,基于如下公式计算客户端节点的过渡权重:

12、

13、其中,表示客户端节点k的过渡权重;αk表示客户端节点k的权重值;w表示本轮参数向量;wk表示客户端节点k对应的初始参数向量。

14、在本发明的一些实施方式中,在基于所述本轮参数向量和每个客户端节点的过渡权重计算几何中值向量的步骤中,基于如下公式计算几何中值向量:

15、

16、其中,gm表示几何中值向量;表示客户端节点k的过渡权重;wk表示客户端节点k对应的初始参数向量;k表示客户端节点的总数量。

17、在本发明的一些实施方式中,在计算每个客户端节点当前的初始参数向量与几何中值向量的距离的步骤中,计算每个客户端节点当前的初始参数向量与几何中值向量的欧氏距离。

18、在本发明的一些实施方式中,在基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量计算模型参数向量的步骤中,基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量重新计算每个客户端节点的过渡权重,并基于所述本轮参数向量和每个客户端节点的过渡权重重新计算几何中值向量,将重新计算出的几何中值向量作为模型参数向量。

19、在本发明的一些实施方式中,所述方法的步骤还包括,所述服务端节点将计算模型参数向量向各个客户端节点反馈。

20、在本发明的一些实施方式中,所述方法的步骤还包括:将每次服务端节点接收客户端节点在完成本地训练后上传的模型参数,并计算出模型参数向量的过程作为一个联邦训练轮次,记录多个联邦训练轮次的训练数据。

21、在本发明的一些实施方式中,所述方法的步骤还包括:将所述训练数据中相邻联邦训练轮次的状态向量分别作为前一轮次的状态和后一个轮次的状态;将所述训练数据中各个客户端节点输出更新的权重值进行组合,作为动作;将所述训练数据中通过每个联邦训练轮次得到的模型参数向量而构建的模型的准确率作为奖励;对所述强化学习模型进行训练,更新所述强化学习模型。

22、本发明的第二方面还提供一种基于强化学习的联邦学习鲁棒性聚合系统,该系统包括计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机指令,所述处理器用于执行所述存储器中存储的计算机指令,当所述计算机指令被处理器执行时该系统实现如前所述方法所实现的步骤。

23、本发明的第三方面还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述基于强化学习的联邦学习鲁棒性聚合方法所实现的步骤。

24、本发明的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本发明的实践而获知。本发明的目的和其它优点可以通过在说明书以及附图中具体指出并获得。

25、本领域技术人员将会理解的是,能够用本发明实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本发明能够实现的上述和其他目的。



技术特征:

1.一种基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,本方法应用于联邦网络,所述联邦网络包括服务端节点和客户端节点,该方法包括以下步骤:

2.根据权利要求1所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,在基于所述初始参数向量、当前的本轮参数向量和客户端节点当前的权重值计算几何中值向量的步骤中,

3.根据权利要求2所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,在基于每个客户端节点对应的初始参数向量和权重值,以及本轮参数向量计算每个客户端节点的过渡权重的步骤中,基于如下公式计算客户端节点的过渡权重:

4.根据权利要求2所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,在基于所述本轮参数向量和每个客户端节点的过渡权重计算几何中值向量的步骤中,基于如下公式计算几何中值向量:

5.根据权利要求1所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,在计算每个客户端节点当前的初始参数向量与几何中值向量的距离的步骤中,计算每个客户端节点当前的初始参数向量与几何中值向量的欧氏距离。

6.根据权利要求2所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,在基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量计算模型参数向量的步骤中,基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量重新计算每个客户端节点的过渡权重,并基于所述本轮参数向量和每个客户端节点的过渡权重重新计算几何中值向量,将重新计算出的几何中值向量作为模型参数向量。

7.根据权利要求1所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,所述方法的步骤还包括,所述服务端节点将计算模型参数向量向各个客户端节点反馈。

8.根据权利要求1~7任一项所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,所述方法的步骤还包括:将每次服务端节点接收客户端节点在完成本地训练后上传的模型参数,并计算出模型参数向量的过程作为一个联邦训练轮次,记录多个联邦训练轮次的训练数据。

9.根据权利要求8所述的基于强化学习的联邦学习鲁棒性聚合方法,其特征在于,所述方法的步骤还包括:将所述训练数据中相邻联邦训练轮次的状态向量分别作为前一轮次的状态和后一个轮次的状态;将所述训练数据中各个客户端节点输出更新的权重值进行组合,作为动作;将所述训练数据中通过每个联邦训练轮次得到的模型参数向量而构建的模型的准确率作为奖励;对所述强化学习模型进行训练,更新所述强化学习模型。

10.一种基于强化学习的联邦学习鲁棒性聚合系统,其特征在于,该系统包括计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机指令,所述处理器用于执行所述存储器中存储的计算机指令,当所述计算机指令被处理器执行时该系统实现如权利要求1~9任一项所述方法所实现的步骤。


技术总结
本发明提供一种基于强化学习的联邦学习鲁棒性聚合方法及系统,该方法包括以下步骤:服务端节点接收客户端节点在完成本地训练后上传的模型参数;将每个客户端节点上传的模型参数构建为初始参数向量,将服务端节点当前的模型参数构建为本轮参数向量,基于所述初始参数向量、当前的本轮参数向量和客户端节点当前的权重值计算几何中值向量;计算每个客户端节点当前的初始参数向量与几何中值向量的距离,并构建为状态向量输入到强化学习模型中,所述强化学习模型对应每个客户端节点输出更新的权重值;基于每个客户端节点输出更新的权重值、所述初始参数向量和当前的本轮参数向量计算模型参数向量,所述模型参数向量中各个维度的值均为模型参数。

技术研发人员:杜军平,闫思铮,管泽礼,梁美玉
受保护的技术使用者:北京邮电大学
技术研发日:
技术公布日:2024/12/5

最新回复(0)