Skip to content

[Bugfix][MLA] Correct final_lse in PS MLA prefill kernel for chunked prefill#3606

Draft
simondanielsson wants to merge 6 commits into
ROCm:mainfrom
simondanielsson:fix/mla-prefill-ps-final-lse
Draft

[Bugfix][MLA] Correct final_lse in PS MLA prefill kernel for chunked prefill#3606
simondanielsson wants to merge 6 commits into
ROCm:mainfrom
simondanielsson:fix/mla-prefill-ps-final-lse

Conversation

@simondanielsson
Copy link
Copy Markdown

@simondanielsson simondanielsson commented Jun 8, 2026

Motivation

Ensure LSE output from the PS ASM MLA prefill kernel is correct.

vllm-project/vllm#44544 integrates the non-causal prefill kernel into vLLM. However, using the kernel shipped with AITER v0.1.13.post1 yields an accuracy degradation of ~3pp on GSM8k with DeepSeekv3. As this bug report reports, this has to do with incorrect LSE's being calculated and returned from the kernel. The accuracy degradation disappears when disabling chunked prefill, implying the issue is in the non-causal path.

AFAIK no other frameworks (SGL, ATOM) use the non-causal ASM PS prefill kernel for chunked contexts (they default to FA for those batches), why this haven't been spotted and/or fixed earlier.

Technical Details

Bug was initially spotted and reported here, but no fix was pushed.

Test Plan

python3 op_tests/test_mla_prefill_ps.py -qkh 192 -vh 128 -n 1 -c 4096 -b 1 --causal true
python3 op_tests/test_mla_prefill_ps.py -qkh 192 -vh 128 -n 1 -c 4096 -b 1 --causal false
python3 op_tests/test_mla_prefill_ps.py -qkh 192 -vh 128 -n 1 -c 8192 -b 1 --causal true
python3 op_tests/test_mla_prefill_ps.py -qkh 192 -vh 128 -n 1 -c 8192 -b 1 --causal false

Test Result

All test passes. For the 8192 context non-causal test:

$ python3 op_tests/test_mla_prefill_ps.py -qkh 192 -vh 128 -n 1 -c 8192 -b 1 --causal false
[aiter] import [module_aiter_core] under /home/pedaniel/repos/aiter/.venv/lib/python3.12/site-packages/aiter/jit/module_aiter_core.so
[aiter]
calling test_mla_prefill(ctx_lens                     = 8192,
                         batch_size                   = 1,
                         num_head                     = 1,
                         qk_head_dim                  = 192,
                         v_head_dim                   = 128,
                         dtype                        = torch.float8_e4m3fn,
                         kv_dtype                     = torch.float8_e4m3fn,
                         block_size                   = 1,
                         varlen                       = False,
                         is_causal                    = False,
                         load_metadata                = False,
                         dump_metadata                = False,
                         profile_ps                   = False,
                         skip_reference               = False)
[aiter] import [module_ps_metadata] under /home/pedaniel/repos/aiter/.venv/lib/python3.12/site-packages/aiter/jit/module_ps_metadata.so
[aiter] LoadKernel: _ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E hsaco: /home/pedaniel/repos/aiter/.venv/lib/python3.12/site-packages/aiter_meta/hsa//gfx950/mla/mla_pfl_qh192_vh128_m32x8_n128x1_causal0.co
/home/pedaniel/repos/aiter/.venv/lib/python3.12/site-packages/torch/profiler/profiler.py:272: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
  _warn_once(
USDT:2026-06-08 08:59:21 303490:303490 ActivityProfilerController.cpp:415] profiler_start
USDT:2026-06-08 08:59:21 303490:303490 ActivityProfilerController.cpp:455] profiler_stop
[W608 08:59:21.930083243 collection.cpp:1182] Warning: ROCTracer produced duplicate flow start: 4 (function operator())
[aiter] import [module_mla_reduce] under /home/pedaniel/repos/aiter/.venv/lib/python3.12/site-packages/aiter/jit/module_mla_reduce.so
USDT:2026-06-08 08:59:21 303490:303490 ActivityProfilerController.cpp:415] profiler_start
USDT:2026-06-08 08:59:21 303490:303490 ActivityProfilerController.cpp:455] profiler_stop
[aiter] mla_prefill_ps    [torch vs aiter_asm]: us......[checkAllclose atol=0.05 rtol=0.05 passed~]
[aiter] mla_prefill_lse   [torch vs aiter_asm]: us......[checkAllclose atol=0.05 rtol=0.05 passed~]
[aiter] mla_prefill_ps summary (markdown):
|   ctx_lens |   batch_size |   num_head |   qk_head_dim |   v_head_dim | dtype               | kv_dtype            |   block_size | varlen   | is_causal   | load_metadata   | dump_metadata   | profile_ps   | skip_reference   |   us_mla_prefill_ps |   err fp8 | acc result   |   err lse | lse result   |   lse max_abs_diff |   lse mean_abs_diff |   lse max_rel_err |   lse mean_rel_err |
|-----------:|-------------:|-----------:|--------------:|-------------:|:--------------------|:--------------------|-------------:|:---------|:------------|:----------------|:----------------|:-------------|:-----------------|--------------------:|----------:|:-------------|----------:|:-------------|-------------------:|--------------------:|------------------:|-------------------:|
|       8192 |            1 |          1 |           192 |          128 | torch.float8_e4m3fn | torch.float8_e4m3fn |            1 | False    | False       | False           | False           | False        | False            |             49.8771 |         0 | passed       |         0 | passed       |          0.0152264 |          0.00302858 |        0.00159649 |        0.000318251 |
[aiter] Output: passed 1/1(100.00%) warning 0/1(0.00%) failed 0/1(0.00%)
[aiter] LSE: passed 1/1(100.00%) warning 0/1(0.00%) failed 0/1(0.00%)

Submission Checklist

The PS scheduler in v1_2_host.cuh's generate_work_info tagged tiles
that a single threadgroup could absorb from block 0 with the sentinel
partial_o_loc = -1, causing generate_reduce_info to exclude them
from reduce_indptr / reduce_final_map / reduce_partial_map. As a
result mla_reduce_v1 never wrote final_lse[qo_start:qo_end, :] for
those tiles, leaving stale workspace bytes that downstream consumers
(vLLM's chunked-context MLA prefill on MI350) read as garbage LSE
and fed into merge_attn_states, causing a ~2-3pp gsm8k accuracy
regression on DeepSeek-V3.

Always assign a real partial_o_loc so every tile produces a partial
output + partial LSE that reduce processes (num_partials==1 for what
used to be the fast-path tiles). The kernel already supports this
path -- only the host scheduler needed changing; no assembly edits.

The if(work.partial_o_loc == -1) continue; guard in
generate_reduce_info is left in as defensive coding.

Mirrors PR ROCm#3542 ("[Triton-Gluon-MLA-GFX950] return_lse: full decode
+ merged fp32 lse"), which fixed the analogous bug in the
triton-gluon decode kernel's NUM_KV_SPLITS==1 fast path.
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 8, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3606 --add-label <label>

@simondanielsson simondanielsson changed the title [Bugfix][MLA] Correct PS MLA prefill final_lse for chunked prefill [Bugfix][MLA] Correct final_lse in PS MLA prefill kernel for chunked prefill Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant