Federated Fusion CUDA Primitives
Chinese name: 联邦融合算子库 ("federated fusion operator library")
CUDA primitives for federated learning server-side aggregation, optimization, compression, and robustness.
FedFuse is not another federated learning framework. It is a low-level CUDA/C++ kernel backend for the hot paths that FL frameworks usually express as Python, NumPy, PyTorch, or strategy-layer code.
Name meaning:
- Fed: federated learning.
- Fuse: fused CUDA operators that combine multiple FL server-side steps into one optimized primitive.
- 联邦融合算子库: the Chinese name, meaning a library of fused CUDA operators for the federated learning server side.
The goal is simple: make FL-specific server operations fast, reproducible, and easy to plug into frameworks such as Flower, FedLab, NVIDIA FLARE, OpenFL, FATE, or custom PyTorch services.
A condensed summary, originally written in Chinese, appears at the end of this document.
Modern PyTorch is already very strong for plain dense FedAvg. This project is valuable where FL logic becomes more than a single matrix-vector reduction:
- robust aggregation: coordinate median, trimmed mean, future Krum/Multi-Krum
- mixed precision updates: fp16/bf16 input with fp32 accumulation
- compressed updates: top-k sparse merge
- privacy-aware aggregation: DP noise fused with aggregation
- server optimizers: FedAdam/FedOpt-style fused update paths
- future large-model support: chunked and streaming aggregation
FedFuse focuses on those FL-specific fused primitives rather than competing with PyTorch on generic BLAS.
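To make the robust aggregation entries above concrete, here is a minimal NumPy sketch of coordinate-wise median and trimmed mean. It is illustrative only and is not FedFuse's kernel or reference code: the function names and the `trim_count` semantics (drop that many values from each end, per coordinate) are assumptions made for this example.

```python
import numpy as np


def coordinate_median(updates: np.ndarray) -> np.ndarray:
    """Per-coordinate median across clients; updates has shape (clients, params)."""
    return np.median(updates, axis=0)


def trimmed_mean(updates: np.ndarray, trim_count: int) -> np.ndarray:
    """Drop the trim_count smallest and largest values per coordinate, then average."""
    n = updates.shape[0]
    if not 0 <= 2 * trim_count < n:
        raise ValueError("trim_count must leave at least one client per coordinate")
    ordered = np.sort(updates, axis=0)  # sort each coordinate across clients
    return ordered[trim_count:n - trim_count].mean(axis=0)


# Four well-behaved clients near 1.0 and one extreme outlier.
updates = np.array(
    [[1.0, 1.1], [0.9, 1.0], [1.1, 0.9], [1.0, 1.0], [100.0, -100.0]],
    dtype=np.float32,
)
weights = np.full(5, 0.2, dtype=np.float32)

plain = weights @ updates                       # dragged far away by the outlier
robust = coordinate_median(updates)             # stays near 1.0
trimmed = trimmed_mean(updates, trim_count=1)   # stays near 1.0
```

Both robust variants need an ordering step per coordinate rather than a single reduction, which is why they do not map onto one matrix-vector product.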
Implemented CUDA primitives:
- weighted_sum_f32
- weighted_sum_f32_vec4, vec8, vec16
- clipped_weighted_sum_f32
- weighted_sum_f16_f32
- weighted_sum_bf16_f32
- topk_sparse_weighted_sum_f32
- noisy_weighted_sum_f32
- fedadam_weighted_update_f32
- coordinate_median_f32
- trimmed_mean_f32
CPU/Python references and benchmarks are included for correctness and fair comparison.
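As a rough idea of what a CPU reference for some of these primitives could look like, the sketch below covers clipping, mixed-precision accumulation, and top-k sparse merge in NumPy. Everything here is an assumption for illustration: the function signatures, per-client L2-norm clipping, and the sparse layout (one index array and one value array per client) are not taken from the FedFuse sources.

```python
import numpy as np


def clipped_weighted_sum(updates, weights, max_norm):
    """Scale each client update to L2 norm <= max_norm, then take the weighted sum."""
    norms = np.linalg.norm(updates, axis=1)
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return (weights * scale) @ updates


def weighted_sum_f16_f32(updates_f16, weights):
    """Read half-precision updates but accumulate in float32."""
    acc = np.zeros(updates_f16.shape[1], dtype=np.float32)
    for update, w in zip(updates_f16, weights):
        acc += np.float32(w) * update.astype(np.float32)
    return acc


def topk_sparse_weighted_sum(indices, values, weights, num_params):
    """Scatter-add each client's weighted (index, value) pairs into a dense vector."""
    out = np.zeros(num_params, dtype=np.float32)
    for idx, val, w in zip(indices, values, weights):
        np.add.at(out, idx, np.float32(w) * val)  # handles repeated indices correctly
    return out


weights = np.array([0.5, 0.5], dtype=np.float32)
updates = np.array([[3.0, 4.0], [0.3, 0.4]], dtype=np.float32)

# Client 0 has norm 5 and is scaled by 0.2; client 1 has norm 0.5 and is untouched.
clipped = clipped_weighted_sum(updates, weights, max_norm=1.0)

# Same data stored as fp16; the result is float32 and close to the fp32 sum.
mixed = weighted_sum_f16_f32(updates.astype(np.float16), weights)

# Two clients, each sending two (index, value) pairs for a 4-parameter model.
indices = [np.array([0, 2]), np.array([2, 3])]
values = [np.array([1.0, 2.0], dtype=np.float32), np.array([4.0, 8.0], dtype=np.float32)]
merged = topk_sparse_weighted_sum(indices, values, weights, num_params=4)
```

A reference like this is mainly useful as a ground truth in tests, for example by comparing a kernel's output to it with `np.allclose` and a tolerance suited to the input precision.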
Measured on an RTX 5070 Ti Laptop GPU with CUDA 12.8.
Modern PyTorch GPU-resident baselines, 100k-1M parameter workloads:
| Operator | Scenario | PyTorch GPU | FedFuse CUDA | Speedup |
|---|---|---|---|---|
| coordinate median | 32 clients, 100k params | 0.379 ms | 0.099 ms | 3.81x |
| trimmed mean | 32 clients, 100k params | 0.426 ms | 0.097 ms | 4.38x |
| coordinate median | 128 clients, 100k params | 2.227 ms | 0.815 ms | 2.73x |
| trimmed mean | 128 clients, 100k params | 2.201 ms | 0.977 ms | 2.25x |
| fp16 update + fp32 accumulation | 32 clients, 1M params | 0.630 ms | 0.107 ms | 5.89x |
| bf16 update + fp32 accumulation | 128 clients, 1M params | 2.609 ms | 0.411 ms | 6.35x |
| FedAdam fused update | 128 clients, 1M params | 1.140 ms | 0.834 ms | 1.37x |
| top-k sparse merge | 128 clients, 1M params, k=10k | 0.072 ms | 0.049 ms | 1.47x |
See docs/benchmarks.md for the full benchmark log, including Flower, LEAF, NumPy, PyTorch CPU, PyTorch CPU-to-GPU-to-CPU, and native CUDA comparisons.
- include/flk/: public C++ API
- src/cuda/: CUDA kernels and launchers
- src/cpp/: host-side wrappers
- python/flk/: Python reference implementations
- benchmarks/: native, Flower/LEAF-style, and PyTorch baselines
- tests/: C++/CUDA and Python correctness tests
- docs/: project plan, benchmark notes, market scan, environment notes
- scripts/: Windows setup, build, test, and benchmark helpers
Tested stack:
- Windows
- Visual Studio 2022 Build Tools
- CUDA 12.8
- CMake + Ninja
- Python 3.11
Create the conda environment:
conda env create -f environment.yml
conda activate flkernel-dev
Configure, build, and test:
.\scripts\configure.ps1
.\scripts\build.ps1
.\scripts\test_native.ps1
.\scripts\test_python.ps1
Optional modern framework baselines:
pip install -r requirements-optional.txt
The current Python package provides NumPy reference implementations under the public package name fedfuse. They are useful for trying the API and checking CUDA results.
D:\conda\envs\flkernel-dev\python.exe .\examples\quickstart.py
Minimal usage:
import numpy as np
import fedfuse as ff
updates = np.random.randn(5, 1024).astype(np.float32)
weights = np.ones(5, dtype=np.float32) / 5
avg = ff.weighted_sum(updates, weights)
robust = ff.coordinate_median(updates)
trimmed = ff.coordinate_trimmed_mean(updates, trim_count=1)
Native FL-specific operators:
.\scripts\benchmark_fl_ops.ps1 -Clients 32 -Params 1000000 -K 10000 -Device 0
Modern PyTorch FL operator baselines:
D:\conda\envs\flkernel-dev\python.exe .\benchmarks\modern_fl_ops_bench.py --clients 32 --params 1000000 --k 10000 --device 0
Flower and LEAF-style aggregation baselines:
D:\conda\envs\flkernel-dev\python.exe .\benchmarks\flower_fedavg_bench.py --clients 32 --params 1000000 --layers 96
D:\conda\envs\flkernel-dev\python.exe .\benchmarks\leaf_fedavg_bench.py --clients 32 --profile leaf_femnist_cnn
FedFuse is meant to be optimized in the open. Good first contribution areas:
- optimize coordinate median and trimmed mean further
- add Krum/Multi-Krum kernels
- add chunked/streaming aggregation for models larger than GPU memory
- add PyTorch extension bindings
- add Linux CI and CUDA container builds
- benchmark on RTX 4090, 5090, A100, H100, L40S, and laptop GPUs
- integrate with Flower/FedLab/NVIDIA FLARE as backend adapters
See CONTRIBUTING.md.
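For contributors looking at the chunked/streaming item, the idea can be sketched on the CPU as follows. This is only a conceptual NumPy illustration under assumed names: `chunked_weighted_sum` and the `load_chunk(client, start, stop)` callback are hypothetical and do not exist in FedFuse. The point is that only one chunk per client needs to be resident at a time.

```python
import numpy as np


def chunked_weighted_sum(load_chunk, weights, num_params, chunk_size):
    """Weighted sum computed one parameter chunk at a time.

    load_chunk(client, start, stop) is a hypothetical callback that returns
    the float32 slice [start, stop) of that client's update.
    """
    out = np.empty(num_params, dtype=np.float32)
    for start in range(0, num_params, chunk_size):
        stop = min(start + chunk_size, num_params)
        acc = np.zeros(stop - start, dtype=np.float32)
        for client, w in enumerate(weights):
            acc += np.float32(w) * load_chunk(client, start, stop)
        out[start:stop] = acc
    return out


# Three clients, four parameters, processed in chunks of three.
updates = np.arange(12, dtype=np.float32).reshape(3, 4)
weights = np.array([0.5, 0.25, 0.25], dtype=np.float32)

result = chunked_weighted_sum(
    lambda client, start, stop: updates[client, start:stop],
    weights,
    num_params=4,
    chunk_size=3,
)
```

In a real implementation the callback would presumably stream from host memory or disk, and the result should match the unchunked weighted sum up to floating-point rounding.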
Apache-2.0. See LICENSE.
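FedFuse is a CUDA operator library for the hot paths on the federated learning server side.
It is not a new FL framework but a low-level backend that Flower, FedLab, NVIDIA FLARE, OpenFL, FATE, or an in-house PyTorch service can call. The point is not to reimplement the training workflow but to optimize the operators specific to the FL server: aggregation, compression, robustness, privacy-leak protection, and server optimizers.
Plain FedAvg is already fast in modern PyTorch when the data is GPU-resident, so FedFuse should not pitch itself as "slightly faster than PyTorch at plain averaging". The real value lies in these FL-specific paths:
- Robust aggregation: coordinate median, trimmed mean, and later Krum/Multi-Krum
- Mixed-precision updates: fp16/bf16 update + fp32 accumulation
- Sparse compressed updates: top-k sparse merge
- Privacy-aware aggregation: DP noise + aggregate fusion
- Server optimizers: FedAdam/FedOpt fused update
- Large-model scenarios: planned chunked and streaming aggregation to avoid running out of GPU memory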
Implemented so far:
- dense weighted aggregation
- clip + weighted aggregate
- fp16/bf16 update + fp32 accumulation
- top-k sparse merge
- DP noise + weighted aggregate
- FedAdam fused update
- coordinate median
- trimmed mean
- Flower / LEAF / NumPy / PyTorch / native CUDA benchmarks
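The "DP noise + weighted aggregate" and "FedAdam fused update" entries in the list above can be read as the two steps sketched below. This is a plain NumPy illustration, not FedFuse's code: the function names, the Gaussian noise model, and the hyperparameter defaults are assumptions, and the update rule follows the commonly published FedAdam form in which the aggregated client delta acts as a pseudo-gradient.

```python
import numpy as np


def noisy_weighted_sum(updates, weights, noise_std, rng):
    """Weighted sum of client updates plus Gaussian noise of the given std."""
    aggregate = weights @ updates
    noise = rng.normal(0.0, noise_std, size=aggregate.shape).astype(np.float32)
    return aggregate + noise


def fedadam_step(params, delta, m, v, lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
    """One server-side FedAdam step using the aggregated delta as pseudo-gradient."""
    m = beta1 * m + (1.0 - beta1) * delta
    v = beta2 * v + (1.0 - beta2) * delta**2
    params = params + lr * m / (np.sqrt(v) + tau)
    return params, m, v


rng = np.random.default_rng(0)
updates = np.array([[1.0, -1.0], [1.0, -1.0]], dtype=np.float32)
weights = np.array([0.5, 0.5], dtype=np.float32)

# With noise_std=0 the aggregate is exactly [1, -1], which keeps the demo deterministic.
delta = noisy_weighted_sum(updates, weights, noise_std=0.0, rng=rng)

params = np.zeros(2, dtype=np.float32)
m = np.zeros(2, dtype=np.float32)
v = np.zeros(2, dtype=np.float32)
params, m, v = fedadam_step(params, delta, m, v)

# A noisy aggregate for comparison; its exact values depend on the random seed.
noisy_delta = noisy_weighted_sum(updates, weights, noise_std=0.1, rng=rng)
```

Run separately, these steps read the aggregate from memory more than once; fusing them into a single kernel is what would avoid the intermediate round trips.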
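The current Python package is named fedfuse and provides NumPy reference implementations, so that users can first get a feel for the API and verify results:
D:\conda\envs\flkernel-dev\python.exe .\examples\quickstart.py
Example: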
import numpy as np
import fedfuse as ff
updates = np.random.randn(5, 1024).astype(np.float32)
weights = np.ones(5, dtype=np.float32) / 5
avg = ff.weighted_sum(updates, weights)
robust = ff.coordinate_median(updates)
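trimmed = ff.coordinate_trimmed_mean(updates, trim_count=1)
As a prototype, FedFuse has already shown that:
- Plain FedAvg is not the main moat.
- Robust aggregation, mixed precision, sparse merge, and DP/FedOpt fusion are more valuable.
- Against a modern PyTorch GPU baseline, the robust aggregation and mixed-precision paths already show a more convincing advantage.
Everyone is welcome to help optimize kernels, add benchmarks, and integrate with FL frameworks on GitHub.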