NVIDIA FLARE 实战教程:在非独立同分布 CIFAR-10 上构建并比较 FedAvg 与 FedProx 联邦学习
本文提供一份使用 NVIDIA FLARE 框架在非独立同分布 CIFAR-10 数据集上构建联邦学习实验的详细指南。通过对比 FedAvg 和 FedProx 两种聚合算法,展示如何处理标签分布不均的现实场景,并可视化全局模型精度变化。适合有 PyTorch 基础、希望入门联邦学习的开发者。
一句话看懂
NVIDIA FLARE 官方教程,手把手教你用 FedAvg 和 FedProx 在非独立同分布 CIFAR-10 上做联邦学习对比实验,含完整代码和可视化。
详细发生了什么
NVIDIA 发布了一篇面向开发者的联邦学习实战教程,使用其开源框架 NVIDIA FLARE(NVFlare)在 CIFAR-10 图像分类数据集上,对比两种经典聚合算法 FedAvg 和 FedProx 的表现。
教程的核心设定是非独立同分布(non-IID)数据分布:通过 Dirichlet 分布(alpha=0.3)将 CIFAR-10 训练集按标签倾斜分配给 3 个模拟客户端,每个客户端最多 4000 个样本,模拟真实世界中各站点数据分布不均的场景。
实验流程:
- 使用 NVFlare Job API 定义联邦任务,ScriptRunner 在每个客户端运行相同的训练脚本。
- 客户端训练一个小型 CNN 模型(无 BatchNorm),每轮接收全局权重,在本地数据上训练 1 个 epoch(SGD + momentum 0.9),然后上传更新。
- FedProx 在本地损失中加入近端项(proximal term),惩罚与全局模型的偏离,参数 mu 设为 0.1。
- 共运行 5 轮通信,每轮结束后在共享测试集上评估全局模型精度,结果写入 CSV 文件。
教程提供了完整的可运行代码,包括数据分区、客户端训练循环、作业配置和模拟运行。
中文圈视角
这篇教程对中文开发者有直接参考价值:
-
国产框架的对比:国内联邦学习框架如 FATE(微众银行)、SecretFlow(蚂蚁集团)更侧重金融场景和隐私计算,而 NVFlare 更贴近通用深度学习研究。如果你用 PyTorch 做 CV 或 NLP 实验,NVFlare 的学习成本更低。
-
非独立同分布场景的实用性:中文互联网数据(如医疗影像、方言语音)天然存在标签分布不均,教程中的 Dirichlet 分区方法可以直接复用。
-
硬件门槛:教程支持 CPU 和 GPU,但 3 个客户端模拟在普通 GPU(如 RTX 3060)上也能跑完。国内用户无需额外配置,直接 pip install nvflare 即可。
-
盲点提醒:教程只用了 5 轮和 1 个 local epoch,实际生产环境需要更多轮次和调参。FedProx 的 mu 值对收敛影响很大,中文社区目前缺少系统性的超参数对比实验。
几条值得记住的细节
- 数据分区:使用 Dirichlet 分布(alpha=0.3)生成非独立同分布标签倾斜,每个客户端最多 4000 个样本。
- 模型结构:小型 CNN(2 个卷积层 + 2 个全连接层),无 BatchNorm 以保证 state_dict 兼容。
- FedProx 实现:在本地损失中加入近端项
(mu/2) * sum((w - g)^2),mu=0.1。 - 运行方式:使用 NVFlare Job API 的
FedAvgJob和ScriptRunner,支持模拟运行(simulator_run)。 - 结果记录:全局模型精度在每个通信轮次后写入 CSV,site-1 负责记录。
一句话总结
如果你想在非独立同分布数据上快速上手联邦学习,这篇教程提供了可直接复用的 NVFlare 代码和 FedAvg/FedProx 对比模板。