From 0e3f5f81ffc4de9f009e69513bfd59a339b30011 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Fri, 5 Jun 2026 13:16:15 +0200 Subject: [PATCH 1/6] fix(mla): always write partial output + LSE for prefill PS tiles The PS scheduler in v1_2_host.cuh's generate_work_info tagged tiles that a single threadgroup could absorb from block 0 with the sentinel partial_o_loc = -1, causing generate_reduce_info to exclude them from reduce_indptr / reduce_final_map / reduce_partial_map. As a result mla_reduce_v1 never wrote final_lse[qo_start:qo_end, :] for those tiles, leaving stale workspace bytes that downstream consumers (vLLM's chunked-context MLA prefill on MI350) read as garbage LSE and fed into merge_attn_states, causing a ~2-3pp gsm8k accuracy regression on DeepSeek-V3. Always assign a real partial_o_loc so every tile produces a partial output + partial LSE that reduce processes (num_partials==1 for what used to be the fast-path tiles). The kernel already supports this path -- only the host scheduler needed changing; no assembly edits. The if(work.partial_o_loc == -1) continue; guard in generate_reduce_info is left in as defensive coding. Mirrors PR #3542 ("[Triton-Gluon-MLA-GFX950] return_lse: full decode + merged fp32 lse"), which fixed the analogous bug in the triton-gluon decode kernel's NUM_KV_SPLITS==1 fast path. --- csrc/kernels/mla/metadata/v1_2_host.cuh | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/csrc/kernels/mla/metadata/v1_2_host.cuh b/csrc/kernels/mla/metadata/v1_2_host.cuh index e54276d3fb..40ff7f351e 100644 --- a/csrc/kernels/mla/metadata/v1_2_host.cuh +++ b/csrc/kernels/mla/metadata/v1_2_host.cuh @@ -178,11 +178,17 @@ void kn_generate_ps_metadata(std::vector& seqlens_qo_indptr, { consuming_blocks = remaining_blocks; // This TG can process all of this qo_tile's remaining_blocks to the causal - // boundary - const int32_t partial_o_loc = - (current_block_idx == 0) - ? 
-1 - : (qlen_granularity * partial_tile_idx++); // -1 - no split + // boundary. + // NOTE: previously, when current_block_idx == 0 (a single TG absorbs the + // entire tile without splitting), partial_o_loc was set to the sentinel -1 + // and generate_reduce_info skipped the tile, so the reduce kernel never + // wrote final_lse[qo_start:qo_end, :]. Downstream consumers (chunked-context + // prefill in vLLM) then read stale workspace bytes as LSE, causing a + // ~2-3pp gsm8k accuracy regression on DeepSeek-V3 on MI350. Always emit a + // partial slot so mla_reduce_v1 runs (with num_partials == 1 for unsplit + // tiles) and fully populates final_lse. Mirrors the fix landed in + // ROCm/aiter#3542 for the triton-gluon MLA decode NUM_KV_SPLITS==1 path. + const int32_t partial_o_loc = qlen_granularity * partial_tile_idx++; const int32_t kv_end = std::min(kv_start + consuming_blocks, pages_kv_indptr[current_tile.batch_idx + 1]); From d6f85ab85020bbaa4e66f00d8ba781bb8fbff758 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Fri, 5 Jun 2026 16:02:26 +0200 Subject: [PATCH 2/6] fix: don't skip reduction of tiles with num_split==1 Signed-off-by: simondanielsson --- csrc/kernels/mla/reduce.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/csrc/kernels/mla/reduce.cu b/csrc/kernels/mla/reduce.cu index 12478ad844..d9e0203123 100644 --- a/csrc/kernels/mla/reduce.cu +++ b/csrc/kernels/mla/reduce.cu @@ -738,9 +738,17 @@ __launch_bounds__(Traits::kNumThreads, Traits::kOccupancy) __global__ p_lds); } } - // In theory, we can handle the case that #split = 1. However, it is meaningless and - // metadata should be in charge of getting rid of this kind of scenario. - else if(num_splits > 1) + // NOTE: num_splits == 1 must also run through the simple reduce path. + // The prefill PS scheduler (v1_2_host.cuh) used to filter these out via + // a partial_o_loc == -1 sentinel and the prefill kernel wrote `out` + // directly without LSE. 
That dropped final_lse on the floor for tiles + // that a single TG could absorb, corrupting downstream chunked-context + // merges in vLLM. The scheduler now always emits a real partial slot, + // so reduce must process num_splits == 1 too. The math degenerates + // cleanly: the merge loop body runs zero times, sum_e_lse stays 1.0, + // max_lse stays at partial_lse[0], final_lse = log(1) + max_lse = + // max_lse, final_out = partial_out / 1.0 = partial_out. + else { mla_reduce_v1_impl_simple( params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds); @@ -805,9 +813,11 @@ __launch_bounds__(Traits::kNumThreads, Traits::kOccupancy) __global__ params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds); } } - // In theory, we can handle the case that #split = 1. However, it is meaningless and metadata - // should be in charge of getting rid of this kind of scenario. - else if(num_splits > 1) + // num_splits == 1 must also run through the simple reduce path so that + // final_lse is written for tiles that a single TG can absorb. The math + // degenerates cleanly (sum_e_lse == 1.0, max_lse == partial_lse[0], + // final_lse = max_lse, final_out = partial_out). 
+ else { mla_reduce_v1_impl_simple( params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds); From 08d060e0e260461e0f940a10a9ac6356b86d003c Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Fri, 5 Jun 2026 16:53:13 +0200 Subject: [PATCH 3/6] fix: increase max_partials now that unsplit tiles emit partial slots Signed-off-by: simondanielsson --- aiter/ops/attention.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 3790d84b0f..e258e8e605 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -837,10 +837,18 @@ def get_ps_metadata_info_v1( qo_tile_cnt = batch_size * max_qo_split_per_batch # TODO: consider split q to reduce max_works & max_partials max_works = (batch_size + cus_per_cluster - 1) * max_qo_split_per_batch * num_head_k - max_partials = ( - min(batch_size + cus_per_cluster - 1, (cus_per_cluster - 1) * 2) - * max_qo_split_per_batch - ) + # NOTE: the previous max_partials formula assumed unsplit tiles (a single TG + # absorbing an entire q-tile) were tagged with partial_o_loc == -1 and thus + # contributed zero partial slots. v1_2_host.cuh now always assigns a real + # partial_o_loc so mla_reduce_v1 runs for those tiles too (required to + # populate final_lse, which downstream chunked-context merges in vLLM + # consume). Each work_info entry from one cluster now contributes one + # partial slot, so the bound must dominate per-cluster work_info count + # rather than just multi-split slots. Per-cluster work_info count is at + # most max_works / num_clusters; using (batch_size + cus_per_cluster - 1) + # * max_qo_split_per_batch is a tight upper bound that subsumes both the + # original split-driven partials and the newly-emitted unsplit-tile slots. 
+ max_partials = (batch_size + cus_per_cluster - 1) * max_qo_split_per_batch return ( (2, torch.uint64), # work_metadata_ptrs From 5e5db60d55ad99fc65e1de281c57ff1c510f6687 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Fri, 5 Jun 2026 17:36:31 +0200 Subject: [PATCH 4/6] fix: compute the work and partial based on actual kv size Signed-off-by: simondanielsson --- aiter/ops/attention.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index e258e8e605..8e1ab6122c 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -814,6 +814,8 @@ def get_ps_metadata_info_v1( num_head_k: int, max_qlen: int, qlen_granularity: int = 256, + max_kvlen: int | None = None, + kvlen_granularity: int = 128, ): """ Returns: @@ -833,22 +835,24 @@ def get_ps_metadata_info_v1( cus_per_cluster = cu_num // num_clusters max_qo_split_per_batch = math.ceil(max_qlen / qlen_granularity) + # When max_kvlen is not provided, fall back to max_qlen (causal-only sizing + # where KV length per q-tile is bounded by q length). Noncausal callers + # (e.g. chunked-context prefill) must pass the true max KV length per + # sequence so per-q-tile KV splits are accounted for. + effective_max_kvlen = max_kvlen if max_kvlen is not None else max_qlen + max_kv_split_per_qtile = max(1, math.ceil(effective_max_kvlen / kvlen_granularity)) qo_tile_cnt = batch_size * max_qo_split_per_batch - # TODO: consider split q to reduce max_works & max_partials - max_works = (batch_size + cus_per_cluster - 1) * max_qo_split_per_batch * num_head_k - # NOTE: the previous max_partials formula assumed unsplit tiles (a single TG - # absorbing an entire q-tile) were tagged with partial_o_loc == -1 and thus - # contributed zero partial slots. 
v1_2_host.cuh now always assigns a real - # partial_o_loc so mla_reduce_v1 runs for those tiles too (required to - # populate final_lse, which downstream chunked-context merges in vLLM - # consume). Each work_info entry from one cluster now contributes one - # partial slot, so the bound must dominate per-cluster work_info count - # rather than just multi-split slots. Per-cluster work_info count is at - # most max_works / num_clusters; using (batch_size + cus_per_cluster - 1) - # * max_qo_split_per_batch is a tight upper bound that subsumes both the - # original split-driven partials and the newly-emitted unsplit-tile slots. - max_partials = (batch_size + cus_per_cluster - 1) * max_qo_split_per_batch + # Per-cluster work entries are bounded by total_units (sum over q-tiles of + # their KV-split count) plus one trailing slot per cluster per q-tile from + # the scheduler's allocate_work tail. total_units <= qo_tile_cnt * + # max_kv_split_per_qtile, so per-cluster work count <= qo_tile_cnt * + # (max_kv_split_per_qtile + cus_per_cluster). max_works covers all heads. + per_cluster_work = qo_tile_cnt * (max_kv_split_per_qtile + cus_per_cluster) + max_works = num_head_k * per_cluster_work + # Each work_info entry contributes at most one partial slot. max_partials + # bounds the per-head reduce_partial_map. 
+ max_partials = per_cluster_work return ( (2, torch.uint64), # work_metadata_ptrs From 9c44558782d22917101131bfe4d7f8c241c9aee7 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Mon, 8 Jun 2026 09:56:21 +0200 Subject: [PATCH 5/6] fix: return final_lse instead of partial attn_lse from kernel wrapper Signed-off-by: simondanielsson --- aiter/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/mla.py b/aiter/mla.py index 4d59925bd4..f15c36c323 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -655,7 +655,7 @@ def mla_prefill_ps_fwd( final_lse, ) - return output.view(total_s, nhead, v_head_dim), attn_lse + return output.view(total_s, nhead, v_head_dim), final_lse @triton.jit From 253301bed8f32862b81de6ff85c6e2fcea54da28 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Mon, 8 Jun 2026 09:58:34 +0200 Subject: [PATCH 6/6] test: add proper final lse validation test Signed-off-by: simondanielsson --- op_tests/test_mla_prefill_ps.py | 201 +++++++++++++++++--------------- 1 file changed, 106 insertions(+), 95 deletions(-) diff --git a/op_tests/test_mla_prefill_ps.py b/op_tests/test_mla_prefill_ps.py index e6d4d54d17..01b6b732ab 100644 --- a/op_tests/test_mla_prefill_ps.py +++ b/op_tests/test_mla_prefill_ps.py @@ -13,7 +13,7 @@ from aiter import dtypes from aiter import per_tensor_quant -from aiter.test_common import benchmark, checkAllclose, perftest, run_perftest +from aiter.test_common import benchmark, checkAllclose, perftest from aiter.jit.utils.chip_info import get_gfx from typing import Tuple, Optional @@ -29,30 +29,27 @@ torch.set_printoptions(sci_mode=False) -def calculate_pass_rate(df): - if "acc result" not in df.columns: +def _print_pass_rate(df, column, label): + if column not in df.columns: return - - num_tests = df["acc result"].value_counts().sum() - if "passed" in df["acc result"].value_counts(): - num_passed = df["acc result"].value_counts()["passed"] - else: - num_passed = 0 - if "warning" in df["acc result"].value_counts(): - 
num_warning = df["acc result"].value_counts()["warning"] - else: - num_warning = 0 - if "failed" in df["acc result"].value_counts(): - num_failed = df["acc result"].value_counts()["failed"] - else: - num_failed = 0 + counts = df[column].value_counts() + num_tests = counts.sum() + num_passed = counts.get("passed", 0) + num_warning = counts.get("warning", 0) + num_failed = counts.get("failed", 0) aiter.logger.info( - f"\033[32mpassed {num_passed}/{num_tests}({num_passed / num_tests * 100:.2f}%) \ - \033[33mwarning {num_warning}/{num_tests}({num_warning / num_tests * 100:.2f}%) \ - \033[31mfailed {num_failed}/{num_tests}({num_failed / num_tests * 100:.2f}%) \033[0m" + f"{label}: " + f"\033[32mpassed {num_passed}/{num_tests}({num_passed / num_tests * 100:.2f}%) " + f"\033[33mwarning {num_warning}/{num_tests}({num_warning / num_tests * 100:.2f}%) " + f"\033[31mfailed {num_failed}/{num_tests}({num_failed / num_tests * 100:.2f}%) \033[0m" ) +def calculate_pass_rate(df): + _print_pass_rate(df, "acc result", "Output") + _print_pass_rate(df, "lse result", "LSE") + + def ref_masked_attention( query: torch.Tensor, key: torch.Tensor, @@ -152,7 +149,7 @@ def torch_mla_extend( os.append(o) lses.append(lse) o = torch.concat(os) - lse = torch.concat(lses).transpose(0, 1) + lse = torch.concat(lses, dim=1) # [nhead, total_q] return o, lse @@ -219,7 +216,7 @@ def run_aiter_mla_reduce( output, final_lse, ) - return output, attn_lse + return output, final_lse @benchmark() @@ -389,58 +386,60 @@ def test_mla_prefill( output = torch.empty((num_tokens, num_head_q, v_head_dim), dtype=torch.bfloat16) - if profile_ps: - # pre-allocate final and partial output & lse - total_s, nhead, v_head_dim = output.shape - - tile_q = 256 - logits = torch.empty( - (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), - dtype=dtypes.fp32, - device=device, - ) - attn_lse = torch.empty( - (reduce_partial_map.size(0) * tile_q, nhead), - dtype=dtypes.fp32, - device=device, - ) - final_lse = 
torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + # Always run two-phase (asm + reduce) to get both output and final_lse + # for LSE validation, regardless of profile_ps. + total_s, nhead, _ = output.shape + tile_q = 256 + logits = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), + dtype=dtypes.fp32, + device=device, + ) + attn_lse = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead), + dtype=dtypes.fp32, + device=device, + ) + final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) - out_mla_prefill_asm, us_mla_prefill_asm = run_aiter_mla_prefill_asm( - q_quant, - k_quant, - v_quant, - output, - qo_indptr, - kv_indptr, - kv_indices, - work_indptr, - work_info, - max_qlen, - is_causal, - softmax_scale, - logits, - attn_lse, - q_scale, - k_scale, - v_scale, - ) - output, logits, attn_lse = out_mla_prefill_asm + out_mla_prefill_asm, us_mla_prefill_asm = run_aiter_mla_prefill_asm( + q_quant, + k_quant, + v_quant, + output, + qo_indptr, + kv_indptr, + kv_indices, + work_indptr, + work_info, + max_qlen, + is_causal, + softmax_scale, + logits, + attn_lse, + q_scale, + k_scale, + v_scale, + ) + output, logits, attn_lse = out_mla_prefill_asm - out_reduce, us_reduce = run_aiter_mla_reduce( - logits, - attn_lse, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - tile_q, - output, - final_lse, - ) - output, final_lse = out_reduce - output = output.view(total_s, nhead, v_head_dim) + out_reduce, us_reduce = run_aiter_mla_reduce( + logits, + attn_lse, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + tile_q, + output, + final_lse, + ) + output, final_lse = out_reduce + output = output.view(total_s, nhead, v_head_dim) - us_mla_prefill_ps = us_mla_prefill_asm + us_reduce + us_mla_prefill_ps = us_mla_prefill_asm + us_reduce + ret["us_mla_prefill_ps"] = us_mla_prefill_ps + + if profile_ps: # calculate mla_prefill_ps kernel tflops # for causal, only take the lower triangle(ops/2) g_div = 2 if 
is_causal else 1 @@ -505,30 +504,6 @@ def test_mla_prefill( ret["us_reduce"] = us_reduce ret["us_reduce_ratio"] = us_reduce / us_mla_prefill_ps ret["bw_reduce(TB/s)"] = bw_reduce if effective_final_tiles > 0 else 0 - else: - _, us_aiter_asm = run_perftest( - aiter.mla.mla_prefill_ps_fwd, - q_quant, - k_quant, - v_quant, - output, - qo_indptr, - kv_indptr, - kv_indices, - work_indptr, - work_info, - max_qlen, - is_causal, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - softmax_scale, - q_scale, - k_scale, - v_scale, - ) - - ret["us_mla_prefill_ps"] = us_aiter_asm if not skip_reference: # TODO: optimize reference implementation(too slow for large context length) @@ -562,6 +537,42 @@ def test_mla_prefill( ret["err fp8"] = err ret["acc result"] = status + # LSE validation: final_lse [total_q, nhead] vs lse_ref [nhead, total_q] + # Transpose final_lse to [nhead, total_q] for comparison. + asm_lse = final_lse.transpose(0, 1) # [nhead, total_q] + valid_mask = lse_ref.isfinite() + if valid_mask.any(): + asm_lse_valid = asm_lse[valid_mask] + ref_lse_valid = lse_ref[valid_mask] + lse_err = checkAllclose( + ref_lse_valid, + asm_lse_valid, + rtol=5e-2, + atol=5e-2, + msg="mla_prefill_lse [torch vs aiter_asm]: us......", + ) + if lse_err == 0: + lse_status = "passed" + elif 0 < lse_err <= 0.05: + lse_status = "warning" + else: + lse_status = "failed" + else: + lse_err = 0 + lse_status = "passed" + ret["err lse"] = lse_err + ret["lse result"] = lse_status + + # Detailed LSE stats for debugging + if valid_mask.any(): + lse_diff = (asm_lse_valid - ref_lse_valid).abs() + ret["lse max_abs_diff"] = lse_diff.max().item() + ret["lse mean_abs_diff"] = lse_diff.mean().item() + lse_denom = ref_lse_valid.abs().clamp(min=1e-6) + lse_rel = lse_diff / lse_denom + ret["lse max_rel_err"] = lse_rel.max().item() + ret["lse mean_rel_err"] = lse_rel.mean().item() + return ret