[gfx1250][FlyDSL] Add Ragged-M OOB Support for PTPC FP8 GEMM#3582
aoli26 wants to merge 15 commits into
df142bd to 0c9c070
Pull request overview
This PR updates the gfx1250 FlyDSL PTPC FP8 GEMM path to correctly and efficiently handle ragged (non-tile-aligned) M without host-side padding, including split-K and clustered launches. It introduces runtime out-of-bounds clipping for A loads / C stores and expands the tuning/kernel-id surface to include warp-shape parameters (m_warp, n_warp).
Changes:
- Add runtime M OOB handling via `oob_outer_bound` (native FlyDSL when available; otherwise a vendored descriptor builder fallback).
- Remove host-side M padding in the gfx1250 bpreshuffle backend and compile/cache kernels independent of M.
- Extend WMMA kernel naming/tuning to include `m_warp`/`n_warp` and add tests covering ragged-M, split-K predicates, and the vendored OOB fallback path.
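The idea behind runtime M clipping can be illustrated with a small NumPy sketch. This is not the FlyDSL kernel or its descriptor API; `ragged_m_gemm` and `tile_m` are hypothetical names, and the sketch only assumes that out-of-bounds rows read as zero on load and are dropped on store, while the grid is rounded up to cover the last partial tile.

```python
import numpy as np


def ragged_m_gemm(a, b, tile_m=16):
    """Tiled A @ B where rows past M are zero on load and skipped on store.

    Illustrative only: models the effect of OOB clipping, not the hardware
    descriptor that implements it.
    """
    m, k = a.shape
    n = b.shape[1]
    c = np.zeros((m, n), dtype=np.float32)
    num_tiles = (m + tile_m - 1) // tile_m  # grid rounds up, no host padding
    for t in range(num_tiles):
        row0 = t * tile_m
        valid = min(tile_m, m - row0)  # rows in this tile with row < M
        a_tile = np.zeros((tile_m, k), dtype=np.float32)
        a_tile[:valid] = a[row0:row0 + valid]  # clipped load: OOB rows stay 0
        acc = a_tile @ b  # full-tile compute, as a fixed-shape kernel would do
        c[row0:row0 + valid] = acc[:valid]  # clipped store: OOB rows dropped
    return c
```

Because the tile shape never depends on M, the same compiled kernel could in principle serve any M; only the runtime bound changes.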
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py | Expands correctness coverage for ragged-M, split-K row predicates, warp configs, and vendored OOB fallback. |
| aiter/ops/flydsl/kernels/tdm_oob.py | Adds a vendored OOB-capable TDM 2D descriptor builder for older FlyDSL versions. |
| aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | Adds feature-detected OOB descriptor routing, split-K per-lane row predicates, and partial-tile store fallback. |
| aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py | Passes m_warp/n_warp from tuned instances into the runner. |
| aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py | Extends candidate space and LDS estimation to account for warp shapes and updated kernel LDS allocation. |
| aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py | Removes host padding for ragged-M, compiles with M=0 for cache reuse, and updates kernelName format + parsing. |
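For readers unfamiliar with the updated kernelName format, here is a minimal parsing sketch based on the one gfx1250 name that appears in this PR's tuned CSV row (`flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1`). The helper names and the field meanings (`nb` as num_buffers, `sk` as split_k, `cm`/`cn` as cluster dimensions) are assumptions for illustration; the actual parser in `bpreshuffle_gemm_gfx1250.py` may differ.

```python
import re

# Hypothetical pattern inferred from a single example name in this PR.
_NAME_RE = re.compile(
    r"^flydsl_bpreshuffle_wmma_"
    r"t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)"
    r"_mw(?P<m_warp>\d+)_nw(?P<n_warp>\d+)"
    r"_nb(?P<num_buffers>\d+)_sk(?P<split_k>\d+)"
    r"_cm(?P<cluster_m>\d+)_cn(?P<cluster_n>\d+)$"
)

_NAME_FMT = (
    "flydsl_bpreshuffle_wmma_t{tile_m}x{tile_n}x{tile_k}"
    "_mw{m_warp}_nw{n_warp}_nb{num_buffers}_sk{split_k}"
    "_cm{cluster_m}_cn{cluster_n}"
)


def parse_kernel_name(name):
    """Return the integer fields of a kernel name, or raise ValueError."""
    match = _NAME_RE.match(name)
    if match is None:
        # mw/nw are mandatory in this sketch, so older names are rejected.
        raise ValueError(f"unrecognized kernel name: {name!r}")
    return {key: int(value) for key, value in match.groupdict().items()}


def format_kernel_name(cfg):
    """Inverse of parse_kernel_name for a dict with the same keys."""
    return _NAME_FMT.format(**cfg)
```

A name that omits the `mw`/`nw` fields fails to parse here, which mirrors the PR's statement that they are mandatory.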
6d46de6 to 7f73b08
d4ebb5a to ebda70a
gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0
gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0
gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0
gfx1250,256,1,2112,7168,torch.float8_e4m3fn,flydsl,261,0,8.8799,flydsl_bpreshuffle_wmma_t16x32x256_mw1_nw2_nb4_sk1_cm1_cn1,3.41,1706.12,0.0066
Why is there a nonzero error ratio even though split-K is not enabled?
We've narrowed the error down to the WMMA FP8 instruction itself, likely due to internal accumulator precision. A minimal handwritten HIP reproducer bitwise matches the FlyDSL kernel, which effectively rules out the FlyDSL kernel as the source of the precision issue.
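As a generic illustration of how accumulator precision alone can produce a measurable error, the sketch below compares the same dot products accumulated in float16 and in float32 against a float64 reference. It does not model the WMMA FP8 instruction or its internal accumulator width (which this thread only describes as the likely cause); `dot_with_accumulator` and the chosen sizes are hypothetical.

```python
import numpy as np


def dot_with_accumulator(x, y, acc_dtype):
    """Sequential dot product with every product and sum rounded to acc_dtype."""
    acc = acc_dtype(0)
    for xi, yi in zip(x, y):
        acc = acc_dtype(acc + acc_dtype(xi) * acc_dtype(yi))
    return float(acc)


def mean_abs_error(acc_dtype, rows=8, k=1024, seed=0):
    """Mean absolute error against a float64 reference over a few rows."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.0, 1.0, size=(rows, k))
    y = rng.uniform(-1.0, 1.0, size=(rows, k))
    errors = []
    for xr, yr in zip(x, y):
        ref = float(np.dot(xr, yr))  # float64 reference
        errors.append(abs(dot_with_accumulator(xr, yr, acc_dtype) - ref))
    return float(np.mean(errors))


err_fp16 = mean_abs_error(np.float16)
err_fp32 = mean_abs_error(np.float32)
```

With identical inputs and an identical reduction order, the narrower accumulator gives the larger error, so a nonzero error ratio does not by itself require split-K.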
015aa51 to 08daf3a
- catalogue: exact per-stage LDS arena estimate (A row-pad + 16/128/1024
alignment + epilogue D buffer) so over-LDS tiles (e.g. t64x256x256_nb4 =
331776 B > 320 KiB) are filtered out instead of faulting the GPU at launch
- make m_warp/n_warp first-class tuned params: kernel name carries mw{m}_nw{n}
(mandatory), catalogue sweeps _WARP_COMBOS and allows tile_m=16, host dispatch
and tuner thread them through; reaches the decode-winning m_warp=1 configs
- compile with M=0: compile-time M is codegen-unused (runtime i32_m drives all
bounds via OOB), so the kernel caches per (N,K,config) and is reused across M
instead of recompiling per M
- kernel_fits_shape: drop the now-unneeded cluster M-divisibility (OOB handles
ragged M; grid rounds up), keep N tile/cluster divisibility
- drop a redundant double .contiguous() on the scale tensors
- tests: m_warp/n_warp config coverage, ragged-M split-k, full name roundtrip
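The over-LDS figure quoted above can be reproduced with a simplified estimate. This is a sketch, not the catalogue's estimator: it assumes 1-byte FP8 elements, a 16-byte pad per A row, and one A tile plus one B tile per buffer stage, and it leaves out the 128/1024 alignment steps and the epilogue D buffer. Those assumptions were chosen because they happen to match the quoted 331776 B for t64x256x256_nb4; all names are hypothetical.

```python
LDS_BUDGET_BYTES = 320 * 1024  # 320 KiB limit quoted in the commit message


def align_up(value, alignment):
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def estimate_lds_bytes(tile_m, tile_n, tile_k, num_buffers,
                       a_row_pad=16, elem_bytes=1):
    """Simplified per-kernel LDS estimate (assumed layout, see lead-in)."""
    a_stage = align_up(tile_m * (tile_k * elem_bytes + a_row_pad), 16)
    b_stage = align_up(tile_n * tile_k * elem_bytes, 16)
    return num_buffers * (a_stage + b_stage)


def fits_in_lds(tile_m, tile_n, tile_k, num_buffers):
    """True if the estimated arena stays within the LDS budget."""
    return estimate_lds_bytes(tile_m, tile_n, tile_k, num_buffers) <= LDS_BUDGET_BYTES


too_big = estimate_lds_bytes(64, 256, 256, num_buffers=4)   # 331776
small_ok = fits_in_lds(16, 32, 256, num_buffers=4)          # True
```

Without the row pad the same tile would land exactly on 327680 B (320 KiB), which shows why an estimate that ignores padding can let a faulting configuration through.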
aiter runs op_tests via `python3 <file>` (not pytest), so keep a dedicated main-driven test for the gfx1250 PTPC FP8 bpreshuffle backend, styled after op_tests/test_gemm_a8w8.py: @perftest-timed runner + @benchmark test_gemm into a DataFrame / markdown table, correctness in err/pass columns.
- run_sweep: -mnk x -d shape sweep (real model shapes).
- run_features: ragged M, strided A/C, split-k, m_warp/n_warp + cluster configs, vendored OOB descriptor, via test_gemm variants.
Timing uses run_perftest(use_cuda_event=True) (the FlyDSL kernel is JIT-dispatched, which torch.profiler/ROCTracer can miss).
Args mirror test_gemm_a8w8 (-d / -mnk / -o / --suffix) plus tile / split_k / num_buffers / m_warp / n_warp / --no-features.
Skipped off gfx1250.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
e426134 to ea69899
@perftest(num_iters=TEST_NUM_ITERS, num_rotate_args=1, use_cuda_event=True)
def run_gemm_flydsl_gfx1250(
@benchmark()
def test_gemm_flydsl_feature(
Do we need this specific feature test here?
Motivation
Enable the gfx1250 FlyDSL PTPC FP8 GEMM path to support non-tile-aligned M without host-side padding. This reduces extra allocations and copies for decode/ragged-M shapes while preserving correctness for split-K and clustered launches.
Technical Details
`row < M` and fall back to clipped buffer stores for partial M tiles.
`m_warp`/`n_warp`, including small-M decode-friendly configs.
Test Plan
`pytest op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py`
Test Result
All tests passed.
Submission Checklist