一种基于联邦学习的交通推荐系统分布式训练方法

专利查询22小时前  1


本发明涉及一种基于联邦学习的交通推荐系统分布式训练方法,属于边缘计算。


背景技术:

1、车联网实现了“车-万物”(v2x)通信,彻底改变了汽车行业,使车辆能够对交通环境进行分析,并提供大量相关的智能服务(例如障碍物识别、交通流量预测、交通推荐系统)。交通推荐系统旨在向乘客和司机提供个性化内容,包括音乐、视频、最佳行驶路线等,以提升用户的旅行体验,确保司乘出行安全。

2、传统的基于协同过滤的推荐系统在输入数据稀疏的情况下难以做出准确的推荐。这意味着基于协同过滤的推荐系统不适合应用于车联网场景,因为在车联网场景中,服务提供商需要在任何情况下都能提供准确的推荐服务。考虑到深度学习模型在各领域中所表现的出色的学习能力,它们为交通推荐系统提供了可行的解决方案。为了提供足够的数据用于模型训练并保障学习效果,传统方法是在中心服务器上使用来自不同用户的大量数据样本训练模型。然而,中心服务器并不总是可信的,可能会遭受外部攻击。此外,在传输过程中长时间暴露用户的原始数据也可能导致隐私泄露。

3、联邦学习是解决上述隐私问题的可行方案。在联邦学习过程中,每个用户只使用本地数据训练自己本地的深度学习模型,只有模型的梯度或参数被共享用于全局聚合。随后,将全局聚合所得参数发送给每个用户进行本地更新。因此,联邦学习不仅通过避免共享敏感的原始数据有效地保护了用户隐私,还通过全局聚合保证了模型训练性能。从这一点来看,基于联邦学习的训练框架十分适用于面向隐私保护的交通推荐系统。

4、然而,实际车联网场景中车辆的数据通常是非独立同分布(异质)的。由于不同车辆上的数据异质性,在每轮全局更新中,每辆车上的本地梯度更新方向都不相同,这严重阻碍了全局损失函数的收敛。此外,联邦学习是一个迭代过程,会导致大量延迟,从而降低车联网智能服务的质量。在数据非独立同分布的场景中,大多数现有的联邦学习研究仅关注标签分布偏差和特征分布偏差中的一种,而没有考虑到这两种数据异质性可能同时出现。另外,对于联邦学习加速方面的研究,现有工作主要集中于减少通信开销,而随着通信技术的快速发展,传输延迟在未来将大大减少,因此在确保学习准确性的同时缩短本地更新时间是一个更为重要和迫切的优化方向。


技术实现思路

1、目的:鉴于以上技术问题中的至少一项,本发明提供一种基于联邦学习的交通推荐系统分布式训练方法,充分考虑车联网场景中不同用户本地数据的标签和特征分布情况,提高交通推荐系统的训练精度,降低交通推荐系统的训练时延。

2、技术方案:为解决上述技术问题,本发明采用的技术方案为:

3、第一方面,本发明提供一种基于联邦学习的交通推荐系统分布式训练方法,包括:

4、基于由具备异质数据样本的车辆与边缘服务器构成的端-边协同的车联网系统下,各边缘服务器分别与各预设区域一一对应,针对预设区域中各车辆上交通推荐系统对应的深度学习模型训练任务,端-边协同的车联网系统执行以下步骤:

5、步骤a:针对每个车辆,利用香农熵生成车辆上本地数据样本的标签分布,利用车辆上本地数据样本对初始本地模型进行本地训练,获取训练后得到的本地模型参数以及本地模型最后一个卷积层输出的所有特征图;根据特征图利用稀疏度向量生成车辆上本地数据样本的特征分布;并将车辆上本地数据样本的标签分布、特征分布和本地模型参数上传给对应的边缘服务器;

6、步骤b:边缘服务器获取对应区域内各车辆上本地数据样本的标签分布和特征分布,基于各车辆上本地数据样本的标签分布的相似性和特征分布的相似性为每个车辆选择合作者,并将当前全局更新中合作者最新的本地模型参数发送给对应的车辆;

7、步骤c:针对每个车辆,获取当前全局更新中合作者最新的本地模型参数,计算各合作者的聚合权重,根据该车辆的本地模型参数、所有合作者的最新的本地模型参数和聚合权重进行聚合得到全局更新后的本地模型参数,并利用车辆上本地数据样本对全局更新后的本地模型进行本地训练,获取训练后得到的本地模型参数以及本地模型最后一个卷积层输出的所有特征图,根据特征图利用稀疏度向量生成车辆上本地数据样本的特征分布;并将该车辆上本地数据样本的特征分布和本地模型参数上传至边缘服务器进行更新;

8、步骤d:重复步骤b至步骤c,直至全局损失函数收敛。

9、在一些实施例中,步骤a中,利用香农熵生成车辆上本地数据样本的标签分布,包括:

10、针对预设区域内第个车辆,车辆上本地数据样本数量为,样本所涉及类别数量为,属于第个类别的样本数量为,根据香农熵,该车辆上本地数据样本的标签分布为:

11、。

12、在一些实施例中,步骤a、步骤c中,根据特征图利用稀疏度向量生成车辆上本地数据样本的特征分布,包括:

13、针对车辆中本地模型最后一个卷积层输出的特征图,表示所有特征图第个通道中0元素数量的总和,表示所有特征图第


最新回复(0)