
retyryu4345/fedfuse


FedFuse

Federated Fusion CUDA Primitives
Chinese name: 联邦融合算子库 (Federated Fusion Operator Library)

CUDA primitives for federated learning server-side aggregation, optimization, compression, and robustness.

FedFuse is not another federated learning framework. It is a low-level CUDA/C++ kernel backend for the hot paths that FL frameworks usually express as Python, NumPy, PyTorch, or strategy-layer code.

Name meaning:

  • Fed: federated learning.
  • Fuse: fused CUDA operators that combine multiple FL server-side steps into one optimized primitive.
  • 联邦融合算子库: a library of fused CUDA operators for the federated learning server side.

The goal is simple: make FL-specific server operations fast, reproducible, and easy to plug into frameworks such as Flower, FedLab, NVIDIA FLARE, OpenFL, FATE, or custom PyTorch services.

A Chinese-language introduction follows below under 中文说明.

Why FedFuse?

Modern PyTorch already handles plain dense FedAvg well. FedFuse is valuable where server-side FL logic involves more than a single matrix-vector reduction:

  • robust aggregation: coordinate median, trimmed mean, and planned Krum/Multi-Krum
  • mixed precision updates: fp16/bf16 input with fp32 accumulation
  • compressed updates: top-k sparse merge
  • privacy-aware aggregation: DP noise fused with aggregation
  • server optimizers: FedAdam/FedOpt-style fused update paths
  • future large-model support: chunked and streaming aggregation

FedFuse focuses on those FL-specific fused primitives rather than competing with PyTorch on generic BLAS.
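To make the robust-aggregation case concrete, here is a minimal NumPy sketch (not FedFuse's CUDA code) of coordinate median and trimmed mean next to a plain mean, with one client sending an outlier update. The helper names and the per-side meaning of `trim_count` are assumptions for illustration.

```python
import numpy as np

# Illustrative NumPy reference, not FedFuse's CUDA implementation.
def coordinate_median(updates: np.ndarray) -> np.ndarray:
    """Median of each coordinate across clients; updates has shape (clients, params)."""
    return np.median(updates, axis=0)

def trimmed_mean(updates: np.ndarray, trim_count: int) -> np.ndarray:
    """Mean of each coordinate after dropping the trim_count smallest and the
    trim_count largest client values (per-side trimming is an assumption)."""
    num_clients = updates.shape[0]
    if 2 * trim_count >= num_clients:
        raise ValueError("trim_count leaves no clients to average")
    ordered = np.sort(updates, axis=0)  # sort client values within each coordinate
    return ordered[trim_count:num_clients - trim_count].mean(axis=0)

# Four honest clients send similar updates; a fifth sends a huge outlier.
honest = np.array([[1.0, 2.0], [1.1, 2.1], [0.9, 1.9], [1.0, 2.0]], dtype=np.float32)
outlier = np.array([[1000.0, -1000.0]], dtype=np.float32)
updates = np.vstack([honest, outlier])

print("mean:        ", updates.mean(axis=0))
print("median:      ", coordinate_median(updates))
print("trimmed mean:", trimmed_mean(updates, trim_count=1))
```

The single outlier drags the plain mean to roughly 200 in the first coordinate, while the median and trimmed mean stay near the honest value of 1. Both robust operators need a per-coordinate sort or selection across clients, which makes them costlier than a weighted sum.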

Current Primitives

Implemented CUDA primitives:

  • weighted_sum_f32
  • weighted_sum_f32_vec4, vec8, vec16
  • clipped_weighted_sum_f32
  • weighted_sum_f16_f32
  • weighted_sum_bf16_f32
  • topk_sparse_weighted_sum_f32
  • noisy_weighted_sum_f32
  • fedadam_weighted_update_f32
  • coordinate_median_f32
  • trimmed_mean_f32

CPU/Python reference implementations and benchmarks are included for correctness checking and fair comparison.
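As an illustration of what such reference implementations check, the sketch below gives plausible NumPy semantics for `weighted_sum_f32` and `clipped_weighted_sum_f32`. The function names and the clipping rule (scale each client's update down to an L2 norm of at most `max_norm`) are assumptions; the actual kernel may clip differently.

```python
import numpy as np

# Hypothetical NumPy references; names and clipping rule are assumptions,
# not the fedfuse API.
def weighted_sum(updates: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Return sum_i weights[i] * updates[i] computed in float32.
    updates: (clients, params), weights: (clients,)."""
    return np.asarray(weights, dtype=np.float32) @ updates.astype(np.float32)

def clipped_weighted_sum(updates: np.ndarray, weights: np.ndarray,
                         max_norm: float) -> np.ndarray:
    """Scale each client's update down to L2 norm <= max_norm, then weighted-sum."""
    norms = np.linalg.norm(updates, axis=1)
    # Clients already within the bound get scale 1; the floor avoids dividing by 0.
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return weighted_sum(updates * scale[:, None], weights)

updates = np.array([[3.0, 4.0], [0.3, 0.4]], dtype=np.float32)  # L2 norms 5.0 and 0.5
weights = np.array([0.5, 0.5], dtype=np.float32)
print(weighted_sum(updates, weights))
print(clipped_weighted_sum(updates, weights, max_norm=1.0))
```

Only the first client exceeds the bound, so clipping shrinks its contribution while leaving the second client untouched.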

Results Snapshot

Measured on an RTX 5070 Ti Laptop GPU with CUDA 12.8.

Modern PyTorch GPU-resident baselines, 100k-1M parameter workloads:

| Operator | Scenario | PyTorch GPU | FedFuse CUDA | Speedup |
|---|---|---|---|---|
| coordinate median | 32 clients, 100k params | 0.379 ms | 0.099 ms | 3.81x |
| trimmed mean | 32 clients, 100k params | 0.426 ms | 0.097 ms | 4.38x |
| coordinate median | 128 clients, 100k params | 2.227 ms | 0.815 ms | 2.73x |
| trimmed mean | 128 clients, 100k params | 2.201 ms | 0.977 ms | 2.25x |
| fp16 update + fp32 accumulation | 32 clients, 1M params | 0.630 ms | 0.107 ms | 5.89x |
| bf16 update + fp32 accumulation | 128 clients, 1M params | 2.609 ms | 0.411 ms | 6.35x |
| FedAdam fused update | 128 clients, 1M params | 1.140 ms | 0.834 ms | 1.37x |
| top-k sparse merge | 128 clients, 1M params, k=10k | 0.072 ms | 0.049 ms | 1.47x |

See docs/benchmarks.md for the full benchmark log, including Flower, LEAF, NumPy, PyTorch CPU, PyTorch CPU-to-GPU-to-CPU, and native CUDA comparisons.
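The fp16/bf16 rows in the table above use fp32 accumulation. The NumPy sketch below (not the CUDA kernel) shows why that matters: averaging 4096 identical fp16 updates of 1.0 with equal weights stalls in an fp16 accumulator at 0.5, because from there each 2^-12 increment is exactly half an fp16 ulp (2^-11) and round-to-nearest-even discards it, while an fp32 accumulator reaches the expected 1.0. NumPy has no bfloat16 type, so only fp16 is shown, and the per-client loop is purely illustrative.

```python
import numpy as np

# Illustrative sketch of accumulation precision, not FedFuse's kernel.
def weighted_sum_fp16_acc(updates16: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Naive variant: multiply and accumulate entirely in float16."""
    acc = np.zeros(updates16.shape[1], dtype=np.float16)
    for update, w in zip(updates16, weights):
        acc = acc + np.float16(w) * update  # float16 op: result rounded to float16
    return acc

def weighted_sum_f16_f32(updates16: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """fp16 storage, fp32 math: upcast each update before multiply-accumulate."""
    acc = np.zeros(updates16.shape[1], dtype=np.float32)
    for update, w in zip(updates16, weights):
        acc += np.float32(w) * update.astype(np.float32)
    return acc

clients, params = 4096, 8
updates16 = np.ones((clients, params), dtype=np.float16)     # every client sends 1.0
weights = np.full(clients, 1.0 / clients, dtype=np.float32)  # equal FedAvg weights

print("fp16 accumulation:", weighted_sum_fp16_acc(updates16, weights)[0])
print("fp32 accumulation:", weighted_sum_f16_f32(updates16, weights)[0])
```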

Repository Layout

  • include/flk/: public C++ API
  • src/cuda/: CUDA kernels and launchers
  • src/cpp/: host-side wrappers
  • python/flk/: Python reference implementations
  • benchmarks/: native, Flower/LEAF-style, and PyTorch baselines
  • tests/: C++/CUDA and Python correctness tests
  • docs/: project plan, benchmark notes, market scan, environment notes
  • scripts/: Windows setup, build, test, and benchmark helpers

Build On Windows

Tested stack:

  • Windows
  • Visual Studio 2022 Build Tools
  • CUDA 12.8
  • CMake + Ninja
  • Python 3.11

Create the conda environment:

conda env create -f environment.yml
conda activate flkernel-dev

Configure, build, and test:

.\scripts\configure.ps1
.\scripts\build.ps1
.\scripts\test_native.ps1
.\scripts\test_python.ps1

Install the optional dependencies for the modern framework baselines:

pip install -r requirements-optional.txt

Python Quickstart

The current Python package provides NumPy reference implementations under the public package name fedfuse. They are useful for trying the API and checking CUDA results. With the flkernel-dev environment activated, run the quickstart:

python .\examples\quickstart.py

Minimal usage:

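import numpy as np
import fedfuse as ff

# 5 clients, each sending a 1024-element float32 update
updates = np.random.randn(5, 1024).astype(np.float32)
# equal aggregation weights that sum to 1 (plain FedAvg)
weights = np.ones(5, dtype=np.float32) / 5

avg = ff.weighted_sum(updates, weights)                      # weighted average across clients
robust = ff.coordinate_median(updates)                       # per-coordinate median across clients
trimmed = ff.coordinate_trimmed_mean(updates, trim_count=1)  # per-coordinate trimmed mean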

Benchmarks

Native FL-specific operators:

.\scripts\benchmark_fl_ops.ps1 -Clients 32 -Params 1000000 -K 10000 -Device 0
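The `-K` argument presumably sets k for the top-k sparse merge benchmark (the `k=10k` row in the results table). The sketch below shows one common formulation, assumed here rather than taken from FedFuse: each client sends its k largest-magnitude entries as (index, value) pairs, and the server scatter-adds the weighted values into a dense buffer. Function names are hypothetical.

```python
import numpy as np

# Assumed formulation; FedFuse's sparse update format may differ.
def topk_compress(update: np.ndarray, k: int):
    """Client side: keep the k largest-magnitude entries as (indices, values)."""
    idx = np.argpartition(np.abs(update), -k)[-k:]  # unordered top-k positions
    return idx.astype(np.int64), update[idx].astype(np.float32)

def topk_sparse_weighted_sum(sparse_updates, weights, num_params: int) -> np.ndarray:
    """Server side: scatter-add each client's weighted values into a dense buffer."""
    out = np.zeros(num_params, dtype=np.float32)
    for (idx, vals), w in zip(sparse_updates, weights):
        # np.add.at accumulates correctly even if an index repeats
        np.add.at(out, idx, np.float32(w) * vals)
    return out

u1 = np.array([0.1, -5.0, 0.2, 3.0], dtype=np.float32)
u2 = np.array([4.0, 0.0, -0.1, 1.0], dtype=np.float32)
sparse = [topk_compress(u1, k=2), topk_compress(u2, k=2)]
print(topk_sparse_weighted_sum(sparse, [0.5, 0.5], num_params=4))
```

Indices selected by several clients (index 3 here) are where the merge does real accumulation work; on a GPU this typically calls for atomic adds or a sort-and-reduce step.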

Modern PyTorch FL operator baselines:

python .\benchmarks\modern_fl_ops_bench.py --clients 32 --params 1000000 --k 10000 --device 0

Flower and LEAF-style aggregation baselines:

python .\benchmarks\flower_fedavg_bench.py --clients 32 --params 1000000 --layers 96
python .\benchmarks\leaf_fedavg_bench.py --clients 32 --profile leaf_femnist_cnn
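Flower-style FedAvg typically averages a model layer by layer, weighting each client by its number of training examples; the `--layers 96` option presumably controls how many such per-layer arrays the baseline uses. The sketch below (hypothetical helpers, NumPy only) checks that flattening all layers into one contiguous buffer and doing a single weighted reduction gives the same result, which is the shape of work a single fused kernel can take over.

```python
import numpy as np

# Hypothetical helpers for illustration; they do not call Flower or fedfuse.
def fedavg_per_layer(client_layers, num_examples):
    """Average each layer separately, weighting clients by example count."""
    total = float(sum(num_examples))
    num_layers = len(client_layers[0])
    return [
        sum(layers[i] * (n / total) for layers, n in zip(client_layers, num_examples))
        for i in range(num_layers)
    ]

def fedavg_flat(client_layers, num_examples):
    """Concatenate each client's layers into one buffer, reduce once, split back."""
    shapes = [layer.shape for layer in client_layers[0]]
    flat = np.stack([np.concatenate([layer.ravel() for layer in layers])
                     for layers in client_layers])
    weights = np.asarray(num_examples, dtype=np.float32)
    weights /= weights.sum()
    merged = weights @ flat  # one (clients,) x (clients, total_params) reduction
    result, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        result.append(merged[offset:offset + size].reshape(shape))
        offset += size
    return result

rng = np.random.default_rng(0)
shapes = [(8, 4), (4,), (4, 2)]
clients = [[rng.standard_normal(s).astype(np.float32) for s in shapes] for _ in range(3)]
counts = [10, 20, 30]
for a, b in zip(fedavg_per_layer(clients, counts), fedavg_flat(clients, counts)):
    print(a.shape, np.allclose(a, b, atol=1e-6))
```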

Contributing

FedFuse is meant to be optimized in the open. Good first contribution areas:

  • optimize coordinate median and trimmed mean further
  • add Krum/Multi-Krum kernels
  • add chunked/streaming aggregation for models larger than GPU memory
  • add PyTorch extension bindings
  • add Linux CI and CUDA container builds
  • benchmark on RTX 4090, 5090, A100, H100, L40S, and laptop GPUs
  • integrate with Flower/FedLab/NVIDIA FLARE as backend adapters

See CONTRIBUTING.md.
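For the chunked/streaming item in the contribution list: coordinate-wise operators such as weighted sum, coordinate median, and trimmed mean treat each parameter independently, so they can be computed chunk by chunk along the parameter axis with identical results. A minimal sketch, where `load_chunk` is a hypothetical callback standing in for whatever streams client updates from host memory or disk:

```python
import numpy as np

def chunked_coordinate_median(load_chunk, num_params: int, chunk_size: int) -> np.ndarray:
    """Compute a coordinate median chunk by chunk along the parameter axis.

    load_chunk(start, stop) is a hypothetical callback returning the
    (clients, stop - start) slice of all client updates; only one slice of
    client data needs to be resident at a time.
    """
    out = np.empty(num_params, dtype=np.float32)
    for start in range(0, num_params, chunk_size):
        stop = min(start + chunk_size, num_params)
        out[start:stop] = np.median(load_chunk(start, stop), axis=0)
    return out

# Stand-in for updates that would normally live in host memory or on disk.
updates = np.random.default_rng(1).standard_normal((7, 1000)).astype(np.float32)
streamed = chunked_coordinate_median(lambda a, b: updates[:, a:b], 1000, chunk_size=128)
print(np.array_equal(streamed, np.median(updates, axis=0)))
```

Operators that need whole-vector quantities, such as Krum's pairwise distances, do not split this way directly; their per-chunk partial squared distances would have to be accumulated across all chunks first.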

License

Apache-2.0. See LICENSE.


中文说明 (Chinese Overview)

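FedFuse is a CUDA operator library for the hot paths of federated learning servers.

It is not a new FL framework but a low-level backend that Flower, FedLab, NVIDIA FLARE, OpenFL, FATE, or in-house PyTorch services can call. The focus is not on reimplementing the training workflow, but on optimizing operators specific to the FL server side: aggregation, compression, robustness, privacy-leakage protection, and server optimizers.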

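Why build this?

When plain FedAvg is already GPU-resident, modern PyTorch is already fast, so FedFuse should not pitch itself as "slightly faster than PyTorch at plain averaging." The real value lies in these FL-specific paths:

  • Robust aggregation: coordinate median, trimmed mean, and later Krum/Multi-Krum
  • Mixed-precision updates: fp16/bf16 update + fp32 accumulation
  • Sparse compressed updates: top-k sparse merge
  • Private aggregation: DP noise + aggregate fusion
  • Server optimizers: FedAdam/FedOpt fused update
  • Large-model scenarios: planned chunked and streaming aggregation to avoid running out of GPU memory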
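A minimal NumPy sketch of the "DP noise + aggregate" path, assuming the common central-DP pattern of per-client L2 clipping, a weighted sum, and Gaussian noise added once to the aggregate. The function name and parameters are hypothetical, and `noise_std` is left to the caller: calibrating it to a privacy budget (sensitivity, accounting) is outside this sketch. A fused kernel would presumably cover clipping, reduction, and noise in fewer passes over memory than this multi-step version.

```python
import numpy as np

def noisy_weighted_sum(updates, weights, clip_norm, noise_std, rng):
    """Hypothetical reference: clip each client to L2 norm <= clip_norm,
    take the weighted sum, then add Gaussian noise once to the aggregate.
    noise_std is caller-supplied; privacy calibration is out of scope."""
    norms = np.linalg.norm(updates, axis=1)
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    aggregate = np.asarray(weights, dtype=np.float32) @ (updates * scale[:, None])
    noise = rng.normal(0.0, noise_std, size=aggregate.shape).astype(np.float32)
    return aggregate + noise

rng = np.random.default_rng(0)
updates = rng.standard_normal((4, 10_000)).astype(np.float32)
weights = np.full(4, 0.25, dtype=np.float32)
clean = noisy_weighted_sum(updates, weights, clip_norm=1.0, noise_std=0.0, rng=rng)
noisy = noisy_weighted_sum(updates, weights, clip_norm=1.0, noise_std=0.1, rng=rng)
print("empirical noise std:", float(np.std(noisy - clean)))
```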

Implemented so far

  • dense weighted aggregation
  • clip + weighted aggregate
  • fp16/bf16 update + fp32 accumulation
  • top-k sparse merge
  • DP noise + weighted aggregate
  • FedAdam fused update
  • coordinate median
  • trimmed mean
  • Flower / LEAF / NumPy / PyTorch / native CUDA benchmarks
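For the FedAdam fused update, the sketch below follows the FedAdam server step from Reddi et al., "Adaptive Federated Optimization": average the client deltas into a pseudo-gradient, update the first and second moments, and apply an adaptive step. The function name, argument order, and default hyperparameters are illustrative assumptions, not FedFuse's API.

```python
import numpy as np

# Illustrative sketch; not FedFuse's API.
def fedadam_weighted_update(params, m, v, deltas, weights,
                            lr=1e-2, beta1=0.9, beta2=0.99, tau=1e-3):
    """One FedAdam server step, updating params, m and v in place.

    deltas[i] is client i's (local model - params); weights sum to 1.
    Hyperparameter defaults are illustrative only.
    """
    delta = np.asarray(weights, dtype=np.float32) @ deltas  # pseudo-gradient
    m *= beta1
    m += (1.0 - beta1) * delta              # first moment
    v *= beta2
    v += (1.0 - beta2) * delta * delta      # second moment (no bias correction)
    params += lr * m / (np.sqrt(v) + tau)   # adaptive server step
    return params

params = np.zeros(3, dtype=np.float32)
m = np.zeros_like(params)
v = np.zeros_like(params)
deltas = np.array([[1.0, -1.0, 0.0], [1.0, -1.0, 0.0]], dtype=np.float32)
fedadam_weighted_update(params, m, v, deltas, [0.5, 0.5])
print(params)
```

A fused kernel can perform the reduction and the element-wise moment and parameter updates in one pass per parameter, instead of the several separate array operations above.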

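Python Quickstart

The current Python package is named fedfuse and provides NumPy reference implementations, so you can get familiar with the API and verify results first. With the flkernel-dev environment activated, run:

python .\examples\quickstart.py

Example: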

import numpy as np
import fedfuse as ff

updates = np.random.randn(5, 1024).astype(np.float32)
weights = np.ones(5, dtype=np.float32) / 5

avg = ff.weighted_sum(updates, weights)
robust = ff.coordinate_median(updates)
trimmed = ff.coordinate_trimmed_mean(updates, trim_count=1)

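Current Conclusions

As a prototype, FedFuse has already shown that:

  • Plain FedAvg is not the main differentiator.
  • Robust aggregation, mixed precision, sparse merging, and DP/FedOpt fusion offer more value.
  • Against modern PyTorch GPU baselines, the robust aggregation and mixed-precision paths already show more convincing speedups.

Everyone is welcome to join on GitHub to optimize kernels, add benchmarks, and integrate with FL frameworks.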
