[Bugfix][MLA] Correct final_lse in PS MLA prefill kernel for chunked prefill #3606
Draft
simondanielsson wants to merge 6 commits into
The PS scheduler in v1_2_host.cuh's generate_work_info tagged tiles that a single threadgroup could absorb from block 0 with the sentinel partial_o_loc = -1, causing generate_reduce_info to exclude them from reduce_indptr / reduce_final_map / reduce_partial_map.

As a result mla_reduce_v1 never wrote final_lse[qo_start:qo_end, :] for those tiles, leaving stale workspace bytes that downstream consumers (vLLM's chunked-context MLA prefill on MI350) read as garbage LSE and fed into merge_attn_states, causing a ~2-3pp gsm8k accuracy regression on DeepSeek-V3.

Always assign a real partial_o_loc so every tile produces a partial output + partial LSE that reduce processes (num_partials==1 for what used to be the fast-path tiles). The kernel already supports this path -- only the host scheduler needed changing; no assembly edits. The if(work.partial_o_loc == -1) continue; guard in generate_reduce_info is left in as defensive coding.

Mirrors PR ROCm#3542 ("[Triton-Gluon-MLA-GFX950] return_lse: full decode + merged fp32 lse"), which fixed the analogous bug in the triton-gluon decode kernel's NUM_KV_SPLITS==1 fast path.
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
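To make the failure mode in the commit message above concrete, here is a toy Python model of it. It is only a sketch and not the AITER host code: the tuple layout, the `always_assign_partial` flag, and the helper `rows_with_final_lse` are hypothetical, and only the names `generate_work_info`, `partial_o_loc`, and the `-1` sentinel are taken from the description.

```python
# Toy model of the scheduling bug. NOT the AITER code: the data layout and
# helper signatures are hypothetical; only the names generate_work_info,
# partial_o_loc and the -1 sentinel come from the commit message.
NO_PARTIAL = -1


def generate_work_info(tiles, always_assign_partial):
    """tiles: list of (qo_start, qo_end, num_splits) tuples."""
    work, next_loc = [], 0
    for qo_start, qo_end, num_splits in tiles:
        for _ in range(num_splits):
            if num_splits == 1 and not always_assign_partial:
                loc = NO_PARTIAL          # old fast path: bypass the reduce stage
            else:
                loc = next_loc            # real slot in the partial buffer
                next_loc += qo_end - qo_start
            work.append((qo_start, qo_end, loc))
    return work


def rows_with_final_lse(work, num_rows):
    """Rows whose final_lse the (modelled) reduce stage writes."""
    written = [False] * num_rows
    for qo_start, qo_end, loc in work:
        if loc == NO_PARTIAL:
            continue                      # tile excluded from the reduce maps
        for row in range(qo_start, qo_end):
            written[row] = True
    return written


tiles = [(0, 4, 1), (4, 8, 2)]  # the first tile fits in a single work item
old = rows_with_final_lse(generate_work_info(tiles, False), 8)
new = rows_with_final_lse(generate_work_info(tiles, True), 8)
```

In this model, `old` leaves rows 0 to 3 unwritten (they would keep whatever bytes were in the workspace), while `new` covers every row because each tile goes through the reduce stage, with a single partial for the former fast-path tile.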
Motivation
Ensure LSE output from the PS ASM MLA prefill kernel is correct.
vllm-project/vllm#44544 integrates the non-causal prefill kernel into vLLM. However, using the kernel shipped with AITER v0.1.13.post1 yields an accuracy degradation of ~3pp on GSM8K with DeepSeek-V3. As this bug report describes, the cause is incorrect LSE values being computed and returned by the kernel. The accuracy degradation disappears when chunked prefill is disabled, implying the issue is in the non-causal path.
AFAIK no other frameworks (SGL, ATOM) use the non-causal ASM PS prefill kernel for chunked contexts (they default to FA for those batches), which is likely why this has not been spotted or fixed earlier.
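For readers less familiar with why a wrong LSE hurts accuracy, the sketch below shows the standard log-sum-exp-weighted merge of two partial attention results, using scalar values for brevity. It is an illustration under assumptions, not vLLM's `merge_attn_states` implementation: the helper names `attend` and `merge_states` and the example numbers are made up here.

```python
import math


def attend(scores, values):
    """Softmax attention over one chunk; returns (output, lse)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)


def merge_states(o1, lse1, o2, lse2):
    """Combine two partial results, weighting each by exp(lse)."""
    m = max(lse1, lse2)
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (w1 * o1 + w2 * o2) / (w1 + w2), m + math.log(w1 + w2)


scores = [0.5, 1.5, -0.3, 2.0]   # illustrative attention logits
values = [1.0, 2.0, 3.0, 4.0]    # illustrative scalar values

full_out, full_lse = attend(scores, values)
ctx_out, ctx_lse = attend(scores[:2], values[:2])   # chunked-context part
new_out, new_lse = attend(scores[2:], values[2:])   # new-token part

# With correct LSEs the merge reproduces full attention.
merged_out, merged_lse = merge_states(ctx_out, ctx_lse, new_out, new_lse)

# A stale LSE (modelled here as an arbitrary offset) skews the weighting.
bad_out, _ = merge_states(ctx_out, ctx_lse + 5.0, new_out, new_lse)
```

With correct LSEs, `merged_out` matches `full_out` up to floating-point error. With the perturbed LSE, `bad_out` is pulled toward the context chunk's output, which is the kind of silent error that shows up only as an accuracy regression.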
Technical Details
The bug was initially spotted and reported here, but no fix was pushed.
Test Plan
Test Result
All tests pass. For the 8192-context non-causal test:
Submission Checklist