[gfx1250][FlyDSL] Add Ragged-M OOB Support for PTPC FP8 GEMM#3582

Open: aoli26 wants to merge 15 commits into main from aoli/flydsl_ptpc_gemm

Conversation

@aoli26
Contributor

@aoli26 aoli26 commented Jun 7, 2026

Motivation

Enable the gfx1250 FlyDSL PTPC FP8 GEMM path to support non-tile-aligned M without host-side padding. This reduces extra allocations and copies for decode/ragged-M shapes while preserving correctness for split-K and clustered launches.

Technical Details

  • Add runtime M OOB handling for A TDM loads and output stores, with a vendored OOB-capable TDM descriptor fallback for older FlyDSL versions.
  • Remove host-side M padding in the gfx1250 bpreshuffle backend and cache compiled kernels independent of M.
  • Predicate split-K atomic stores on row < M and fall back to clipped buffer stores for partial M tiles.
  • Extend WMMA kernel names and tuning candidates with m_warp / n_warp, including small-M decode-friendly configs.
  • Update the LDS estimate to match the actual kernel allocation and pass warp config through the tuner.
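
As a rough illustration of the extended kernel naming, the sketch below round-trips a name of the form seen in the tuned CSV row quoted later in this PR (`flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1`). The `WmmaConfig` class and `parse_kernel_name` helper are hypothetical, not the functions in `bpreshuffle_gemm_gfx1250.py`, and reading `cm`/`cn` as cluster dimensions is an assumption.

```python
import re
from dataclasses import dataclass

# Field meanings are inferred from names used in this PR (m_warp, n_warp,
# num_buffers, split_k). Treating cm/cn as cluster dims is an assumption.
_NAME_RE = re.compile(
    r"^flydsl_bpreshuffle_wmma_t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)"
    r"_mw(?P<m_warp>\d+)_nw(?P<n_warp>\d+)_nb(?P<num_buffers>\d+)"
    r"_sk(?P<split_k>\d+)_cm(?P<cluster_m>\d+)_cn(?P<cluster_n>\d+)$"
)


@dataclass(frozen=True)
class WmmaConfig:  # hypothetical container, for illustration only
    tile_m: int
    tile_n: int
    tile_k: int
    m_warp: int
    n_warp: int
    num_buffers: int
    split_k: int
    cluster_m: int
    cluster_n: int

    def name(self) -> str:
        """Format the config back into a kernel name."""
        return (
            f"flydsl_bpreshuffle_wmma_t{self.tile_m}x{self.tile_n}x{self.tile_k}"
            f"_mw{self.m_warp}_nw{self.n_warp}_nb{self.num_buffers}"
            f"_sk{self.split_k}_cm{self.cluster_m}_cn{self.cluster_n}"
        )


def parse_kernel_name(name: str) -> WmmaConfig:
    """Parse a kernel name; mw/nw are mandatory, so older names are rejected."""
    match = _NAME_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized kernel name: {name!r}")
    return WmmaConfig(**{key: int(val) for key, val in match.groupdict().items()})
```

Making the warp fields mandatory in the pattern means a name without `mw`/`nw` fails loudly instead of silently falling back to a default warp shape.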

Test Plan

  • pytest op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py

Test Result

All tests passed.

Submission Checklist

@aoli26 aoli26 force-pushed the aoli/flydsl_ptpc_gemm branch from df142bd to 0c9c070 Compare June 7, 2026 03:56
@github-actions
Contributor

github-actions Bot commented Jun 7, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3582 --add-label <label>

@aoli26 aoli26 requested a review from coderfeli June 7, 2026 04:35
@coderfeli coderfeli marked this pull request as ready for review June 7, 2026 04:38
@coderfeli coderfeli requested a review from a team June 7, 2026 04:38
Copilot AI review requested due to automatic review settings June 7, 2026 04:43
Contributor

Copilot AI left a comment

Pull request overview

This PR updates the gfx1250 FlyDSL PTPC FP8 GEMM path to correctly and efficiently handle ragged (non-tile-aligned) M without host-side padding, including split-K and clustered launches. It introduces runtime out-of-bounds clipping for A loads / C stores and expands the tuning/kernel-id surface to include warp-shape parameters (m_warp, n_warp).
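
The clipping idea can be sketched in plain Python as a reference model. This is not the FlyDSL kernel: the function names are hypothetical, the zero-fill on out-of-bounds rows is an assumed behaviour of the clipped load, and the `+=` merely stands in for the atomic add used by split-K.

```python
def ceil_div(a, b):
    return -(-a // b)


def ragged_gemm(a, b, tile_m, split_k=1):
    """Reference C = A @ B over row tiles, with no padding of A or C.

    a is M x K and b is K x N, both lists of lists. M need not be a
    multiple of tile_m.
    """
    m, k, n = len(a), len(b), len(b[0])
    c = [[0] * n for _ in range(m)]        # exactly M rows are allocated
    k_chunk = ceil_div(k, split_k)
    for tile in range(ceil_div(m, tile_m)):  # the grid rounds up over M
        row0 = tile * tile_m
        for s in range(split_k):
            k0, k1 = s * k_chunk, min(k, (s + 1) * k_chunk)
            for i in range(tile_m):
                row = row0 + i
                # Clipped load: rows at or beyond M read as zeros (assumed).
                a_row = a[row] if row < m else [0] * k
                partial = [
                    sum(a_row[p] * b[p][j] for p in range(k0, k1))
                    for j in range(n)
                ]
                # Store predicate: only rows inside the tensor are written.
                if row < m:
                    for j in range(n):
                        c[row][j] += partial[j]  # stands in for an atomic add
    return c
```

With M = 5 and `tile_m = 4`, the second tile covers rows 4 to 7, but only row 4 is loaded from memory and written back; the other three lanes compute on zeros and are masked at the store.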

Changes:

  • Add runtime M OOB handling via oob_outer_bound (native FlyDSL when available; otherwise a vendored descriptor builder fallback).
  • Remove host-side M padding in the gfx1250 bpreshuffle backend and compile/cache kernels independent of M.
  • Extend WMMA kernel naming/tuning to include m_warp/n_warp and add tests covering ragged-M, split-K predicates, and the vendored OOB fallback path.
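
One common way to implement the "native when available, otherwise vendored" routing mentioned in the first bullet is to inspect the builder's signature. The sketch below is an assumption about the approach, not the detection code in `gemm_fp8fp4_gfx1250.py`; both helper names and the builder arguments are hypothetical stand-ins.

```python
import inspect


def supports_kwarg(fn, name):
    """Return True if callable `fn` declares a parameter called `name`."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some builtins and extension callables expose no signature.
        return False
    return name in params


def select_descriptor_builder(native_builder, vendored_builder):
    """Prefer the native builder only if it accepts `oob_outer_bound`.

    `native_builder` may be None when the installed library has no
    descriptor builder at all.
    """
    if native_builder is not None and supports_kwarg(native_builder, "oob_outer_bound"):
        return native_builder
    return vendored_builder
```

Checking the signature once at import time keeps the decision out of the hot path, and it avoids comparing version strings, which can drift from the actual feature set.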

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py | Expands correctness coverage for ragged-M, split-K row predicates, warp configs, and vendored OOB fallback. |
| aiter/ops/flydsl/kernels/tdm_oob.py | Adds a vendored OOB-capable TDM 2D descriptor builder for older FlyDSL versions. |
| aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | Adds feature-detected OOB descriptor routing, split-K per-lane row predicates, and partial-tile store fallback. |
| aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py | Passes m_warp/n_warp from tuned instances into the runner. |
| aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py | Extends candidate space and LDS estimation to account for warp shapes and updated kernel LDS allocation. |
| aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | Removes host padding for ragged-M, compiles with M=0 for cache reuse, and updates kernelName format + parsing. |
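
The cache-reuse point in the last row can be illustrated with a minimal sketch, assuming the compiled kernel depends only on N, K, and the tile config while M arrives at launch time. `get_kernel`, `compile_calls`, and the returned `launch` closure are hypothetical; the real compile step is replaced by a counter.

```python
from functools import lru_cache

compile_calls = []  # records every simulated compilation


def ceil_div(a, b):
    return -(-a // b)


@lru_cache(maxsize=None)
def get_kernel(n, k, tile_m, tile_n):
    """Simulate compiling one kernel per (N, K, config); M is not a key."""
    compile_calls.append((n, k, tile_m, tile_n))

    def launch(m):
        # M is a runtime argument: it only changes the grid, which rounds
        # up so that a partial last tile is still covered.
        return (ceil_div(m, tile_m), ceil_div(n, tile_n))

    return launch
```

Calling `get_kernel(2112, 7168, 16, 32)` and then launching with M = 1, 17, and 256 triggers a single simulated compilation; only a different N, K, or tile config adds a new cache entry.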

Comment thread aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread aiter/ops/flydsl/kernels/tdm_oob.py Outdated
@aoli26 aoli26 force-pushed the aoli/flydsl_ptpc_gemm branch from 6d46de6 to 7f73b08 Compare June 7, 2026 07:24
Comment thread aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py Outdated
Comment thread aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py Outdated
Comment thread aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py Outdated
Comment thread aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py Outdated
Comment thread aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py Outdated
Comment thread op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py Outdated
@aoli26 aoli26 requested a review from valarLip June 7, 2026 12:51
@coderfeli coderfeli force-pushed the aoli/flydsl_ptpc_gemm branch 7 times, most recently from d4ebb5a to ebda70a Compare June 7, 2026 15:24
gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0
gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0
gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0
gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066
Collaborator

Why is there a nonzero error ratio even though split-K is not enabled?

Contributor Author

We've narrowed the error down to the WMMA FP8 instruction itself, likely due to internal accumulator precision. A minimal handwritten HIP reproducer bitwise matches the FlyDSL kernel, which effectively rules out the FlyDSL kernel as the source of the precision issue.

Comment thread op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py Outdated
@aoli26 aoli26 force-pushed the aoli/flydsl_ptpc_gemm branch 2 times, most recently from 015aa51 to 08daf3a Compare June 8, 2026 06:32
@aoli26 aoli26 requested a review from valarLip June 8, 2026 07:31
Comment thread op_tests/test_gemm_a8w8.py Outdated
Comment thread op_tests/test_gemm_a8w8.py Outdated
aoli26 and others added 7 commits June 8, 2026 15:54
- catalogue: exact per-stage LDS arena estimate (A row-pad + 16/128/1024
  alignment + epilogue D buffer) so over-LDS tiles (e.g. t64x256x256_nb4 =
  331776 B > 320 KiB) are filtered out instead of faulting the GPU at launch
- make m_warp/n_warp first-class tuned params: kernel name carries mw{m}_nw{n}
  (mandatory), catalogue sweeps _WARP_COMBOS and allows tile_m=16, host dispatch
  and tuner thread them through; reaches the decode-winning m_warp=1 configs
- compile with M=0: compile-time M is codegen-unused (runtime i32_m drives all
  bounds via OOB), so the kernel caches per (N,K,config) and is reused across M
  instead of recompiling per M
- kernel_fits_shape: drop the now-unneeded cluster M-divisibility (OOB handles
  ragged M; grid rounds up), keep N tile/cluster divisibility
- drop a redundant double .contiguous() on the scale tensors
- tests: m_warp/n_warp config coverage, ragged-M split-k, full name roundtrip
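
The over-LDS filter from the first bullet can be sketched as a budget check. The decomposition below (per stage: an A tile with a 16-byte row pad plus a B tile, one byte per FP8 element) is a guess that happens to reproduce the 331776 B figure quoted above; the actual estimator in this PR also accounts for alignment and an epilogue D buffer and may be structured differently. All names are hypothetical.

```python
LDS_LIMIT_BYTES = 320 * 1024  # 327680 B, the 320 KiB limit quoted above


def estimate_lds_bytes(tile_m, tile_n, tile_k, num_buffers,
                       a_row_pad=16, elem_bytes=1):
    """Simplified, assumed estimate: padded A tile plus B tile per stage."""
    a_stage = tile_m * (tile_k * elem_bytes + a_row_pad)
    b_stage = tile_n * tile_k * elem_bytes
    return num_buffers * (a_stage + b_stage)


def fits_lds(tile_m, tile_n, tile_k, num_buffers):
    """Keep a candidate only if its estimate is within the LDS budget."""
    return estimate_lds_bytes(tile_m, tile_n, tile_k, num_buffers) <= LDS_LIMIT_BYTES
```

Under these assumptions `t64x256x256_nb4` comes to 4 * (64 * 272 + 256 * 256) = 331776 B, which exceeds 327680 B and is rejected before launch, while a small decode tile such as `t16x32x256_nb4` fits comfortably.
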
aoli26 and others added 6 commits June 8, 2026 15:54
aiter runs op_tests via `python3 <file>` (not pytest), so keep a dedicated
main-driven test for the gfx1250 PTPC FP8 bpreshuffle backend, styled after
op_tests/test_gemm_a8w8.py: @perftest-timed runner + @benchmark test_gemm
into a DataFrame / markdown table, correctness in err/pass columns.

- run_sweep: -mnk x -d shape sweep (real model shapes).
- run_features: ragged M, strided A/C, split-k, m_warp/n_warp + cluster
  configs, vendored OOB descriptor, via test_gemm variants.

Timing uses run_perftest(use_cuda_event=True) (the FlyDSL kernel is
JIT-dispatched, which torch.profiler/ROCTracer can miss). Args mirror
test_gemm_a8w8 (-d / -mnk / -o / --suffix) plus tile / split_k /
num_buffers / m_warp / n_warp / --no-features. Skipped off gfx1250.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@aoli26 aoli26 force-pushed the aoli/flydsl_ptpc_gemm branch from e426134 to ea69899 Compare June 8, 2026 07:55
Comment thread op_tests/test_gemm_a8w8.py Outdated


@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True)
def run_gemm_flydsl_gfx1250(
Collaborator

num_rotate_args

Comment thread op_tests/test_gemm_a8w8.py Outdated


@benchmark()
def test_gemm_flydsl_feature(
Collaborator

Do we need this specific feature test here?

4 participants