Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def mla_prefill_ps_fwd(
final_lse,
)

return output.view(total_s, nhead, v_head_dim), attn_lse
return output.view(total_s, nhead, v_head_dim), final_lse


@triton.jit
Expand Down
24 changes: 18 additions & 6 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,8 @@ def get_ps_metadata_info_v1(
num_head_k: int,
max_qlen: int,
qlen_granularity: int = 256,
max_kvlen: int | None = None,
kvlen_granularity: int = 128,
):
"""
Returns:
Expand All @@ -833,14 +835,24 @@ def get_ps_metadata_info_v1(
cus_per_cluster = cu_num // num_clusters

max_qo_split_per_batch = math.ceil(max_qlen / qlen_granularity)
# When max_kvlen is not provided, fall back to max_qlen (causal-only sizing
# where KV length per q-tile is bounded by q length). Noncausal callers
# (e.g. chunked-context prefill) must pass the true max KV length per
# sequence so per-q-tile KV splits are accounted for.
effective_max_kvlen = max_kvlen if max_kvlen is not None else max_qlen
max_kv_split_per_qtile = max(1, math.ceil(effective_max_kvlen / kvlen_granularity))

qo_tile_cnt = batch_size * max_qo_split_per_batch
# TODO: consider split q to reduce max_works & max_partials
max_works = (batch_size + cus_per_cluster - 1) * max_qo_split_per_batch * num_head_k
max_partials = (
min(batch_size + cus_per_cluster - 1, (cus_per_cluster - 1) * 2)
* max_qo_split_per_batch
)
# Per-cluster work entries are bounded by total_units (sum over q-tiles of
# their KV-split count) plus one trailing slot per cluster per q-tile from
# the scheduler's allocate_work tail. total_units <= qo_tile_cnt *
# max_kv_split_per_qtile, so per-cluster work count <= qo_tile_cnt *
# (max_kv_split_per_qtile + cus_per_cluster). max_works covers all heads.
per_cluster_work = qo_tile_cnt * (max_kv_split_per_qtile + cus_per_cluster)
max_works = num_head_k * per_cluster_work
# Each work_info entry contributes at most one partial slot. max_partials
# bounds the per-head reduce_partial_map.
max_partials = per_cluster_work

return (
(2, torch.uint64), # work_metadata_ptrs
Expand Down
16 changes: 11 additions & 5 deletions csrc/kernels/mla/metadata/v1_2_host.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,17 @@ void kn_generate_ps_metadata(std::vector<int32_t>& seqlens_qo_indptr,
{
consuming_blocks = remaining_blocks;
// This TG can process all of this qo_tile's remaining_blocks to the causal
// boundary
const int32_t partial_o_loc =
(current_block_idx == 0)
? -1
: (qlen_granularity * partial_tile_idx++); // -1 - no split
// boundary.
// NOTE: previously, when current_block_idx == 0 (a single TG absorbs the
// entire tile without splitting), partial_o_loc was set to the sentinel -1
// and generate_reduce_info skipped the tile, so the reduce kernel never
// wrote final_lse[qo_start:qo_end, :]. Downstream consumers (chunked-context
// prefill in vLLM) then read stale workspace bytes as LSE, causing a
// ~2-3 percentage-point GSM8K accuracy regression on DeepSeek-V3 on MI350. Always emit a
// partial slot so mla_reduce_v1 runs (with num_partials == 1 for unsplit
// tiles) and fully populates final_lse. Mirrors the fix landed in
// ROCm/aiter#3542 for the triton-gluon MLA decode NUM_KV_SPLITS==1 path.
const int32_t partial_o_loc = qlen_granularity * partial_tile_idx++;
const int32_t kv_end =
std::min(kv_start + consuming_blocks,
pages_kv_indptr[current_tile.batch_idx + 1]);
Expand Down
22 changes: 16 additions & 6 deletions csrc/kernels/mla/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,17 @@ __launch_bounds__(Traits::kNumThreads, Traits::kOccupancy) __global__
p_lds);
}
}
// In theory, we can handle the case that #split = 1. However, it is meaningless and
// metadata should be in charge of getting rid of this kind of scenario.
else if(num_splits > 1)
// NOTE: num_splits == 1 must also run through the simple reduce path.
// The prefill PS scheduler (v1_2_host.cuh) used to filter these out via
// a partial_o_loc == -1 sentinel and the prefill kernel wrote `out`
// directly without LSE. That dropped final_lse on the floor for tiles
// that a single TG could absorb, corrupting downstream chunked-context
// merges in vLLM. The scheduler now always emits a real partial slot,
// so reduce must process num_splits == 1 too. The math degenerates
// cleanly: the merge loop body runs zero times, sum_e_lse stays 1.0,
// max_lse stays at partial_lse[0], final_lse = log(1) + max_lse =
// max_lse, final_out = partial_out / 1.0 = partial_out.
else
{
mla_reduce_v1_impl_simple<Traits, lse_t, out_t>(
params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds);
Expand Down Expand Up @@ -805,9 +813,11 @@ __launch_bounds__(Traits::kNumThreads, Traits::kOccupancy) __global__
params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds);
}
}
// In theory, we can handle the case that #split = 1. However, it is meaningless and metadata
// should be in charge of getting rid of this kind of scenario.
else if(num_splits > 1)
// num_splits == 1 must also run through the simple reduce path so that
// final_lse is written for tiles that a single TG can absorb. The math
// degenerates cleanly (sum_e_lse == 1.0, max_lse == partial_lse[0],
// final_lse = max_lse, final_out = partial_out).
else
{
mla_reduce_v1_impl_simple<Traits, lse_t, out_t>(
params, head_idx, block_idx, tile_idx, reduce_tile_start, reduce_tile_end, p_lds);
Expand Down
201 changes: 106 additions & 95 deletions op_tests/test_mla_prefill_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from aiter import dtypes
from aiter import per_tensor_quant
from aiter.test_common import benchmark, checkAllclose, perftest, run_perftest
from aiter.test_common import benchmark, checkAllclose, perftest
from aiter.jit.utils.chip_info import get_gfx

from typing import Tuple, Optional
Expand All @@ -29,30 +29,27 @@
torch.set_printoptions(sci_mode=False)


def calculate_pass_rate(df):
if "acc result" not in df.columns:
def _print_pass_rate(df, column, label):
if column not in df.columns:
return

num_tests = df["acc result"].value_counts().sum()
if "passed" in df["acc result"].value_counts():
num_passed = df["acc result"].value_counts()["passed"]
else:
num_passed = 0
if "warning" in df["acc result"].value_counts():
num_warning = df["acc result"].value_counts()["warning"]
else:
num_warning = 0
if "failed" in df["acc result"].value_counts():
num_failed = df["acc result"].value_counts()["failed"]
else:
num_failed = 0
counts = df[column].value_counts()
num_tests = counts.sum()
num_passed = counts.get("passed", 0)
num_warning = counts.get("warning", 0)
num_failed = counts.get("failed", 0)
aiter.logger.info(
f"\033[32mpassed {num_passed}/{num_tests}({num_passed / num_tests * 100:.2f}%) \
\033[33mwarning {num_warning}/{num_tests}({num_warning / num_tests * 100:.2f}%) \
\033[31mfailed {num_failed}/{num_tests}({num_failed / num_tests * 100:.2f}%) \033[0m"
f"{label}: "
f"\033[32mpassed {num_passed}/{num_tests}({num_passed / num_tests * 100:.2f}%) "
f"\033[33mwarning {num_warning}/{num_tests}({num_warning / num_tests * 100:.2f}%) "
f"\033[31mfailed {num_failed}/{num_tests}({num_failed / num_tests * 100:.2f}%) \033[0m"
)


def calculate_pass_rate(df):
    """Report pass/warning/failed rates for the output and LSE result columns.

    Delegates to ``_print_pass_rate`` once per result column, in order:
    ``"acc result"`` (labelled "Output") and then ``"lse result"``
    (labelled "LSE").
    """
    summaries = (
        ("acc result", "Output"),
        ("lse result", "LSE"),
    )
    for column, label in summaries:
        _print_pass_rate(df, column, label)


def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -152,7 +149,7 @@ def torch_mla_extend(
os.append(o)
lses.append(lse)
o = torch.concat(os)
lse = torch.concat(lses).transpose(0, 1)
lse = torch.concat(lses, dim=1) # [nhead, total_q]
return o, lse


Expand Down Expand Up @@ -219,7 +216,7 @@ def run_aiter_mla_reduce(
output,
final_lse,
)
return output, attn_lse
return output, final_lse


@benchmark()
Expand Down Expand Up @@ -389,58 +386,60 @@ def test_mla_prefill(

output = torch.empty((num_tokens, num_head_q, v_head_dim), dtype=torch.bfloat16)

if profile_ps:
# pre-allocate final and partial output & lse
total_s, nhead, v_head_dim = output.shape

tile_q = 256
logits = torch.empty(
(reduce_partial_map.size(0) * tile_q, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
attn_lse = torch.empty(
(reduce_partial_map.size(0) * tile_q, nhead),
dtype=dtypes.fp32,
device=device,
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
# Always run two-phase (asm + reduce) to get both output and final_lse
# for LSE validation, regardless of profile_ps.
total_s, nhead, _ = output.shape
tile_q = 256
logits = torch.empty(
(reduce_partial_map.size(0) * tile_q, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
attn_lse = torch.empty(
(reduce_partial_map.size(0) * tile_q, nhead),
dtype=dtypes.fp32,
device=device,
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)

out_mla_prefill_asm, us_mla_prefill_asm = run_aiter_mla_prefill_asm(
q_quant,
k_quant,
v_quant,
output,
qo_indptr,
kv_indptr,
kv_indices,
work_indptr,
work_info,
max_qlen,
is_causal,
softmax_scale,
logits,
attn_lse,
q_scale,
k_scale,
v_scale,
)
output, logits, attn_lse = out_mla_prefill_asm
out_mla_prefill_asm, us_mla_prefill_asm = run_aiter_mla_prefill_asm(
q_quant,
k_quant,
v_quant,
output,
qo_indptr,
kv_indptr,
kv_indices,
work_indptr,
work_info,
max_qlen,
is_causal,
softmax_scale,
logits,
attn_lse,
q_scale,
k_scale,
v_scale,
)
output, logits, attn_lse = out_mla_prefill_asm

out_reduce, us_reduce = run_aiter_mla_reduce(
logits,
attn_lse,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
tile_q,
output,
final_lse,
)
output, final_lse = out_reduce
output = output.view(total_s, nhead, v_head_dim)
out_reduce, us_reduce = run_aiter_mla_reduce(
logits,
attn_lse,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
tile_q,
output,
final_lse,
)
output, final_lse = out_reduce
output = output.view(total_s, nhead, v_head_dim)

us_mla_prefill_ps = us_mla_prefill_asm + us_reduce
us_mla_prefill_ps = us_mla_prefill_asm + us_reduce
ret["us_mla_prefill_ps"] = us_mla_prefill_ps

if profile_ps:
# calculate mla_prefill_ps kernel tflops
# for causal, only take the lower triangle(ops/2)
g_div = 2 if is_causal else 1
Expand Down Expand Up @@ -505,30 +504,6 @@ def test_mla_prefill(
ret["us_reduce"] = us_reduce
ret["us_reduce_ratio"] = us_reduce / us_mla_prefill_ps
ret["bw_reduce(TB/s)"] = bw_reduce if effective_final_tiles > 0 else 0
else:
_, us_aiter_asm = run_perftest(
aiter.mla.mla_prefill_ps_fwd,
q_quant,
k_quant,
v_quant,
output,
qo_indptr,
kv_indptr,
kv_indices,
work_indptr,
work_info,
max_qlen,
is_causal,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
softmax_scale,
q_scale,
k_scale,
v_scale,
)

ret["us_mla_prefill_ps"] = us_aiter_asm

if not skip_reference:
# TODO: optimize reference implementation(too slow for large context length)
Expand Down Expand Up @@ -562,6 +537,42 @@ def test_mla_prefill(
ret["err fp8"] = err
ret["acc result"] = status

# LSE validation: final_lse [total_q, nhead] vs lse_ref [nhead, total_q]
# Transpose final_lse to [nhead, total_q] for comparison.
asm_lse = final_lse.transpose(0, 1) # [nhead, total_q]
valid_mask = lse_ref.isfinite()
if valid_mask.any():
asm_lse_valid = asm_lse[valid_mask]
ref_lse_valid = lse_ref[valid_mask]
lse_err = checkAllclose(
ref_lse_valid,
asm_lse_valid,
rtol=5e-2,
atol=5e-2,
msg="mla_prefill_lse [torch vs aiter_asm]: us......",
)
if lse_err == 0:
lse_status = "passed"
elif 0 < lse_err <= 0.05:
lse_status = "warning"
else:
lse_status = "failed"
else:
lse_err = 0
lse_status = "passed"
ret["err lse"] = lse_err
ret["lse result"] = lse_status

# Detailed LSE stats for debugging
if valid_mask.any():
lse_diff = (asm_lse_valid - ref_lse_valid).abs()
ret["lse max_abs_diff"] = lse_diff.max().item()
ret["lse mean_abs_diff"] = lse_diff.mean().item()
lse_denom = ref_lse_valid.abs().clamp(min=1e-6)
lse_rel = lse_diff / lse_denom
ret["lse max_rel_err"] = lse_rel.max().item()
ret["lse mean_rel_err"] = lse_rel.mean().item()

return ret


Expand Down
Loading