
Conversation

@m96-chan (Owner) commented on Jan 1, 2026

Summary

  • Add native CUDA kernels for FLUX.1 transformer operations (layer norm, modulate, gated residual, concat, split, scale, broadcast)
  • Implement cuBLAS strided batched GEMM for efficient batched matrix multiplication
  • Update FLUX ops.py to use GPU-native kernels, eliminating H2D/D2H transfers
  • Add NumPy validation tests for all new kernels

Changes

CUDA Kernels (native/ops/nn/diffusion/flux_kernels.cuh)

  • layer_norm_simple_kernel - Simple layer normalization (no affine)
  • modulate_kernel - AdaLN modulation: y = x * (1 + scale) + shift
  • gated_residual_kernel - Gated residual: y = residual + gate * value (both elementwise kernels are sketched after this list)
  • scale_tensor_kernel - Element-wise scaling
  • concat_axis1_kernel / split_axis1_kernel - Axis-1 concatenation/splitting
  • add_broadcast_kernel - Broadcasting addition
  • layer_norm_modulate_fused_kernel - Fused LayerNorm + modulation
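
The actual implementations live in native/ops/nn/diffusion/flux_kernels.cuh; as a rough sketch of the two elementwise operations above, assuming contiguous fp32 (rows, cols) tensors with shift/scale vectors of length cols broadcast over rows (the _sketch suffix marks these as illustrative, not the shipped kernels):

```cuda
// Sketch only -- not the shipped kernels. Assumes contiguous fp32 tensors of
// shape (rows, cols), with shift/scale of length cols broadcast over rows.
__global__ void modulate_sketch(const float* __restrict__ x,
                                const float* __restrict__ shift,
                                const float* __restrict__ scale,
                                float* __restrict__ y,
                                int rows, int cols) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= rows * cols) return;
    int col = idx % cols;
    // AdaLN modulation: y = x * (1 + scale) + shift
    y[idx] = x[idx] * (1.0f + scale[col]) + shift[col];
}

// Gated residual: y = residual + gate * value. Here gate is treated as already
// broadcast to the full tensor shape; a per-channel gate would index like above.
__global__ void gated_residual_sketch(const float* __restrict__ residual,
                                      const float* __restrict__ gate,
                                      const float* __restrict__ value,
                                      float* __restrict__ y,
                                      int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        y[idx] = residual[idx] + gate[idx] * value[idx];
    }
}
```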

cuBLAS Integration (native/ops/matmul/batched.cu)

  • Implement batched_matmul_fp32 using cuBLAS sgemm_strided_batched
  • Proper row-major to column-major conversion (see the sketch after this list)
  • Stream-compatible for CUDA Graph capture
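
The row-major handling typically relies on the identity C = A·B ⇔ C^T = B^T·A^T: a row-major buffer reinterpreted in cuBLAS's column-major convention is the transpose of the logical matrix, so swapping the operand order and the m/n dimensions lets cuBLAS write C back in row-major layout with no explicit transpose. A hedged sketch of such a wrapper (the real entry point is batched_matmul_fp32; the name, argument order, and error handling here are illustrative):

```cuda
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Sketch: batched row-major C[b] = A[b] * B[b], with A (m x k), B (k x n), C (m x n),
// all contiguous fp32. Reinterpreting the row-major buffers in cuBLAS's column-major
// convention yields A^T, B^T, C^T, so we request C^T = B^T * A^T by swapping the
// operands and the m/n dimensions -- no explicit transpose kernels needed.
cublasStatus_t batched_matmul_fp32_sketch(cublasHandle_t handle, cudaStream_t stream,
                                          const float* A, const float* B, float* C,
                                          int batch, int m, int n, int k) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    cublasSetStream(handle, stream);  // honor the caller's stream (CUDA Graph capturable)
    return cublasSgemmStridedBatched(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        /*m=*/n, /*n=*/m, /*k=*/k,
        &alpha,
        /*A=*/B, /*lda=*/n, /*strideA=*/(long long)k * n,
        /*B=*/A, /*ldb=*/k, /*strideB=*/(long long)m * k,
        &beta,
        /*C=*/C, /*ldc=*/n, /*strideC=*/(long long)m * n,
        batch);
}
```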

Python API Updates (src/pygpukit/diffusion/models/flux/ops.py)

  • All operations now use native GPU kernels instead of NumPy fallbacks
  • Eliminates CPU roundtrips during inference

Test plan

  • Build passes (SM86)
  • 10 NumPy validation tests pass (tests/test_flux_kernels.py)
  • Ruff lint passes
  • Mypy type check passes

Verification

Run python -m pytest tests/test_flux_kernels.py -v to verify:

test_layer_norm_simple PASSED
test_modulate PASSED
test_gated_residual PASSED
test_scale_tensor PASSED
test_concat_axis1 PASSED
test_split_axis1 PASSED
test_add_broadcast PASSED
test_layer_norm_modulate PASSED
test_batched_matmul_3d PASSED
test_batched_matmul_4d PASSED

Closes #187

🤖 Generated with Claude Code

m96-chan and others added 2 commits January 2, 2026 05:29
Implements GPU-native operations to eliminate H2D/D2H transfer overhead:

CUDA Kernels:
- layer_norm_simple: LayerNorm without learnable params
- modulate: AdaLN-style modulation (x * (1 + scale) + shift)
- gated_residual: Gated residual connection
- scale_tensor: Scalar multiplication
- concat_axis1/split_axis1: Tensor manipulation along axis 1
- apply_rope: Rotary position embedding
- layer_norm_modulate: Fused LayerNorm + modulation (sketched after this list)
- add_broadcast: Broadcasting addition
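
For context, a fused LayerNorm + modulation kernel along these lines can be sketched as one thread block per row with a shared-memory reduction for mean and variance, applying the AdaLN modulation in the same pass. Everything below (names, eps handling, the power-of-two block size, the shift/scale broadcast) is an illustrative assumption, not the kernel shipped in flux_kernels.cuh:

```cuda
// Sketch: launch as layer_norm_modulate_sketch<<<rows, threads, threads * sizeof(float)>>>
// with `threads` a power of two. Normalizes each row of x (no affine parameters),
// then applies AdaLN modulation in the same pass: y = xhat * (1 + scale) + shift.
__global__ void layer_norm_modulate_sketch(const float* __restrict__ x,
                                           const float* __restrict__ shift,
                                           const float* __restrict__ scale,
                                           float* __restrict__ y,
                                           int cols, float eps) {
    extern __shared__ float smem[];                 // blockDim.x partial sums
    const float* row = x + (size_t)blockIdx.x * cols;
    float* out       = y + (size_t)blockIdx.x * cols;

    // Row mean via block reduction.
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) sum += row[c];
    smem[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    float mean = smem[0] / cols;
    __syncthreads();                                // smem is reused below

    // Row variance via a second block reduction.
    float sq = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float d = row[c] - mean;
        sq += d * d;
    }
    smem[threadIdx.x] = sq;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    float inv_std = rsqrtf(smem[0] / cols + eps);

    // Normalize and modulate in one pass (shift/scale of length cols, broadcast over rows).
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float xhat = (row[c] - mean) * inv_std;
        out[c] = xhat * (1.0f + scale[c]) + shift[c];
    }
}
```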

Fixes:
- batched_matmul now uses cuBLAS sgemm_strided_batched
- Proper row-major to column-major conversion for cuBLAS

Tests:
- NumPy validation tests for all new kernels
- 3D and 4D batched matmul tests

Closes #187

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Update blocks.py to use GPU-native operations:
  - Replace NumPy gated residual with gpu_gated_residual
  - Replace NumPy layer norm with gpu_layer_norm
  - Replace NumPy modulate with gpu_modulate
  - Use gpu_concat_axis1 and gpu_split_axis1 for tensor ops
- Update attention.py layer_norm to use GPU-native kernel
- Add benchmark script for PyGPUkit vs Diffusers comparison

Performance analysis shows that the matmul operations are fast (30-70 TFLOPS),
but there is significant overhead from:
- GC overhead (6.6s) from temporary GPUArray allocations
- reshape_copy operations (5.7s)
Further optimization requires reducing temporary allocations.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues:

  • perf(diffusion): FLUX.1 transformer performance optimization