Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 254 additions & 0 deletions aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,260 @@ def _load_unshuffle_segment(
return w


@triton.jit
def _triton_gather_kv_b_proj_fp4(
    batch_size,
    k_buffer,  # [num_block, block_size, kv_c_dim + kv_pe_dim]
    k_scale,  # [1] or None
    kv_indptr,  # [batch_size + 1]
    kv_indices,  # [total_kv]
    kv_prefix_sum_context_lens,  # [batch_size + 1]
    kv_proj_weight,  # packed fp4: [tp_heads * (qk_nope_head_dim + v_head_dim), kv_c_dim // 2]
    kv_proj_scale,  # e8m0 per-1x32: [weight_n, kv_c_dim // 32]
    k_prefix,  # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim]
    v_prefix,  # [total_kv, tp_k_head_num, v_head_dim]
    KBlockSize: tl.constexpr,
    TpNumHeads: tl.constexpr,
    QkNopeHeadDim: tl.constexpr,
    VHeadDim: tl.constexpr,
    KV_CDim: tl.constexpr,
    KV_PeDim: tl.constexpr,
    ChunkK: tl.constexpr,
    PaddedK: tl.constexpr,
    PaddedV: tl.constexpr,
    ScaleCols: tl.constexpr,
    WEIGHT_PRESHUFFLE: tl.constexpr = False,
):
    """Gather paged KV-cache rows and expand them through a packed-FP4 kv_b_proj.

    One program handles one (batch, head) pair, decoded from ``program_id(0)``
    as ``pid // TpNumHeads`` and ``pid % TpNumHeads``.  For that pair the kernel
    walks the cache blocks listed in ``kv_indices[kv_indptr[b]:kv_indptr[b+1]]``
    in chunks of ``ChunkK`` rows and, for every chunk:

    * multiplies the first ``KV_CDim`` columns of each cache row with this
      head's weight rows, 32 K-elements at a time, via ``tl.dot_scaled``
      (bf16 activations, ``e2m1`` weights with one scale byte per 32-wide
      group), accumulating in fp32;
    * reads the trailing ``KV_PeDim`` columns of each cache row unchanged;
    * multiplies all three results by ``k_scale`` (skipped, i.e. 1.0, when
      ``k_buffer`` is bf16);
    * stores the K projection and the trailing columns side by side into
      ``k_prefix`` and the V projection into ``v_prefix``, at rows starting
      from ``kv_prefix_sum_context_lens[b]`` and masked by
      ``kv_prefix_sum_context_lens[b + 1]``.

    ``WEIGHT_PRESHUFFLE`` selects between plain row-major addressing of the
    weight/scale tensors and the remapped addressing used for preshuffled
    storage.  ``batch_size`` is accepted but never read by the kernel body.

    NOTE(review): the index math assumes ``ChunkK`` is a multiple of
    ``KBlockSize`` and ``KV_CDim`` is a multiple of 32 -- confirm at the
    call site.
    """
    TokenWidth: tl.constexpr = KV_CDim + KV_PeDim
    KOutWidth: tl.constexpr = QkNopeHeadDim + KV_PeDim

    # int64 scalar strides keep every "row index * stride" product in 64 bits.
    block_stride = tl.full([], KBlockSize * TokenWidth, dtype=tl.int64)
    k_out_stride = tl.full([], TpNumHeads * KOutWidth, dtype=tl.int64)
    v_out_stride = tl.full([], TpNumHeads * VHeadDim, dtype=tl.int64)

    GroupK: tl.constexpr = 32  # K elements covered by one weight scale
    PackedGroupK: tl.constexpr = GroupK // 2  # same group in packed bytes
    PackedKV_CDim: tl.constexpr = KV_CDim // 2
    BlocksPerChunk: tl.constexpr = ChunkK // KBlockSize
    NumGroups: tl.constexpr = KV_CDim // GroupK

    program = tl.program_id(0)
    batch_idx = program // TpNumHeads
    head_idx = program % TpNumHeads

    # Range of this request's entries in kv_indices.
    first_block = tl.load(kv_indptr + batch_idx)
    last_block = tl.load(kv_indptr + batch_idx + 1)
    num_blocks = last_block - first_block
    num_chunks = (num_blocks + BlocksPerChunk - 1) // BlocksPerChunk

    # Range of this request's rows in the output tensors.
    out_begin = tl.load(kv_prefix_sum_context_lens + batch_idx)
    out_end = tl.load(kv_prefix_sum_context_lens + batch_idx + 1)

    # The branch is resolved at compile time from the cache element type.
    cache_dtype = k_buffer.dtype.element_ty
    if cache_dtype == tl.bfloat16:
        cache_scale = 1.0
    else:
        cache_scale = tl.load(k_scale)

    # Each of the ChunkK lanes maps to (block within chunk, row within block).
    lane = tl.arange(0, ChunkK)
    lane_block = lane // KBlockSize
    slot_offset = (lane % KBlockSize)[:, None] * TokenWidth

    k_cols = tl.arange(0, PaddedK)
    v_cols = tl.arange(0, PaddedV)
    k_col_ok = k_cols < QkNopeHeadDim
    v_col_ok = v_cols < VHeadDim
    group_lane = tl.arange(0, GroupK)
    packed_lane = tl.arange(0, PackedGroupK)
    pe_lane = tl.arange(0, KV_PeDim)

    # Per head the weight rows are laid out as QkNopeHeadDim K rows followed
    # by VHeadDim V rows.
    weight_row0 = head_idx * (QkNopeHeadDim + VHeadDim)
    k_w_rows = weight_row0 + k_cols
    v_w_rows = weight_row0 + QkNopeHeadDim + v_cols

    for chunk in range(num_chunks):
        chunk_blocks = chunk * BlocksPerChunk + lane_block
        lane_ok = chunk_blocks < num_blocks
        block_ids = tl.load(
            kv_indices + first_block + chunk_blocks,
            mask=lane_ok,
            other=0,
        )
        row_ok = lane_ok[:, None]
        # Element offset of each lane's cache row inside k_buffer.
        src_rows = block_ids[:, None] * block_stride + slot_offset

        k_acc = tl.zeros((ChunkK, PaddedK), dtype=tl.float32)
        v_acc = tl.zeros((ChunkK, PaddedV), dtype=tl.float32)

        for group in range(NumGroups):
            latent = tl.load(
                k_buffer + src_rows + group * GroupK + group_lane[None, :],
                mask=row_ok,
                other=0.0,
            ).to(tl.bfloat16)

            # Packed-byte columns of this K group, shaped [PackedGroupK, 1].
            w_cols = group * PackedGroupK + packed_lane[:, None]
            k_s_rows = k_w_rows[:, None]
            v_s_rows = v_w_rows[:, None]

            if WEIGHT_PRESHUFFLE:
                # Weights: per the original author's note, this inverts
                # aiter.ops.shuffle.shuffle_weight for uint8/packed-fp4
                # tensors with layout=(16, 16), mapping logical [N, K//2]
                # coordinates back into the preshuffled storage.  The column
                # split is the same for the K and V halves.
                col_tile = w_cols // 32
                col_half = (w_cols % 32) // 16
                col_in = w_cols % 16

                k_row_tile = k_w_rows[None, :] // 16
                k_row_in = k_w_rows[None, :] % 16
                k_w_off = (
                    ((k_row_tile * (PackedKV_CDim // 32) + col_tile) * 2 + col_half)
                    * 16
                    + k_row_in
                ) * 16 + col_in

                v_row_tile = v_w_rows[None, :] // 16
                v_row_in = v_w_rows[None, :] % 16
                v_w_off = (
                    ((v_row_tile * (PackedKV_CDim // 32) + col_tile) * 2 + col_half)
                    * 16
                    + v_row_in
                ) * 16 + col_in

                # Scales: per the original author's note, this inverts
                # fp4_utils.e8m0_shuffle / shuffle_scale.  The split of the
                # group index is shared by the K and V halves.
                g_hi = group // 8
                g_mid = (group % 8) // 4
                g_lo = group % 4

                k_s_off = (
                    (
                        (((k_s_rows // 32) * (ScaleCols // 8) + g_hi) * 4 + g_lo)
                        * 16
                        + k_s_rows % 16
                    )
                    * 2
                    + g_mid
                ) * 2 + (k_s_rows % 32) // 16

                v_s_off = (
                    (
                        (((v_s_rows // 32) * (ScaleCols // 8) + g_hi) * 4 + g_lo)
                        * 16
                        + v_s_rows % 16
                    )
                    * 2
                    + g_mid
                ) * 2 + (v_s_rows % 32) // 16
            else:
                # Plain row-major weight and scale tensors.
                k_w_off = k_w_rows[None, :] * PackedKV_CDim + w_cols
                v_w_off = v_w_rows[None, :] * PackedKV_CDim + w_cols
                k_s_off = k_s_rows * ScaleCols + group
                v_s_off = v_s_rows * ScaleCols + group

            # Weight tiles are [PackedGroupK, PaddedN]; padded columns read 0.
            k_w = tl.load(
                kv_proj_weight + k_w_off,
                mask=k_col_ok[None, :],
                other=0,
            )
            v_w = tl.load(
                kv_proj_weight + v_w_off,
                mask=v_col_ok[None, :],
                other=0,
            )
            # Scale tiles are [PaddedN, 1]: one scale per output row and group.
            k_w_scale = tl.load(
                kv_proj_scale + k_s_off,
                mask=k_col_ok[:, None],
                other=0,
            )
            v_w_scale = tl.load(
                kv_proj_scale + v_s_off,
                mask=v_col_ok[:, None],
                other=0,
            )

            k_acc = tl.dot_scaled(
                latent,
                None,
                "bf16",
                k_w,
                k_w_scale,
                "e2m1",
                acc=k_acc,
                fast_math=True,
            )
            v_acc = tl.dot_scaled(
                latent,
                None,
                "bf16",
                v_w,
                v_w_scale,
                "e2m1",
                acc=v_acc,
                fast_math=True,
            )

        # Trailing KV_PeDim columns of each cache row are copied, not projected.
        pe_part = tl.load(
            k_buffer + src_rows + KV_CDim + pe_lane[None, :],
            mask=row_ok,
            other=0.0,
        )

        k_acc *= cache_scale
        v_acc *= cache_scale
        pe_part *= cache_scale

        # Destination rows for this chunk; rows at or past out_end are dropped.
        out_rows = out_begin + chunk * ChunkK + lane
        out_ok = (out_rows < out_end)[:, None]

        tl.store(
            k_prefix
            + out_rows[:, None] * k_out_stride
            + head_idx * KOutWidth
            + QkNopeHeadDim
            + pe_lane[None, :],
            pe_part,
            mask=out_ok,
        )
        tl.store(
            k_prefix
            + out_rows[:, None] * k_out_stride
            + head_idx * KOutWidth
            + k_cols[None, :],
            k_acc,
            mask=out_ok & k_col_ok[None, :],
        )
        tl.store(
            v_prefix
            + out_rows[:, None] * v_out_stride
            + head_idx * VHeadDim
            + v_cols[None, :],
            v_acc,
            mask=out_ok & v_col_ok[None, :],
        )


@triton.jit
def _triton_gather_kv_b_proj(
batch_size,
Expand Down
59 changes: 56 additions & 3 deletions aiter/ops/triton/gather_kv_b_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aiter.ops.triton._triton_kernels.gather_kv_b_proj import (
_next_pow2,
_triton_gather_kv_b_proj,
_triton_gather_kv_b_proj_fp4,
)


Expand All @@ -23,7 +24,12 @@ def gather_kv_b_proj(
):
num_block, block_size, hidden_dim = k_buffer.shape
batch_size = kv_indptr.shape[0] - 1
weight_n, weight_k = kv_proj_weight.shape
weight_n, packed_weight_k = kv_proj_weight.shape
fp4_weight_dtype = getattr(torch, "float4_e2m1fn_x2", None)
is_fp4_weight = (
fp4_weight_dtype is not None and kv_proj_weight.dtype == fp4_weight_dtype
)
weight_k = packed_weight_k * 2 if is_fp4_weight else packed_weight_k
total_kv_k, tp_k_head_num_k, qk_nope_pe_dim = k_prefix.shape
total_kv_v, tp_k_head_num_v, v_head_dim = v_prefix.shape

Expand Down Expand Up @@ -55,8 +61,28 @@ def gather_kv_b_proj(
scale_n, scale_k = kv_proj_scale.shape
scale_k_granularity = weight_k // scale_k
scale_n_granularity = weight_n // scale_n
assert scale_k_granularity == 128
assert scale_n_granularity == 128
if is_fp4_weight:
if weight_preshuffle:
assert scale_k >= (weight_k + 31) // 32, (
"Preshuffled FP4 gather_kv_b_proj expects padded per-1x32 scale columns, "
f"got scale cols {scale_k} for logical K {weight_k}"
)
assert scale_n >= weight_n, (
"Preshuffled FP4 gather_kv_b_proj expects padded per-output-row MXFP4 scales, "
f"got scale rows {scale_n} for weight rows {weight_n}"
)
else:
assert scale_k_granularity == 32, (
"FP4 gather_kv_b_proj expects per-1x32 weight scales, "
f"got K granularity {scale_k_granularity}"
)
assert scale_n_granularity == 1, (
"FP4 gather_kv_b_proj expects per-output-row MXFP4 scales, "
f"got N granularity {scale_n_granularity}"
)
else:
assert scale_k_granularity == 128
assert scale_n_granularity == 128

ChunkK = 16 if k_buffer.dtype in [torch.float16, torch.bfloat16] else 32

Expand All @@ -68,6 +94,33 @@ def gather_kv_b_proj(
padded_v = _next_pow2(v_head_dim)

grid = (batch_size * tp_k_head_num_k,)
if is_fp4_weight:
_triton_gather_kv_b_proj_fp4[grid](
batch_size,
k_buffer,
k_scale,
kv_indptr,
kv_indices,
kv_prefix_sum_context_lens,
kv_proj_weight.view(torch.uint8),
kv_proj_scale.view(torch.uint8),
k_prefix,
v_prefix,
KBlockSize=block_size,
TpNumHeads=tp_k_head_num_k,
QkNopeHeadDim=qk_nope_head_dim,
VHeadDim=v_head_dim,
KV_CDim=weight_k,
KV_PeDim=qk_nope_pe_dim - qk_nope_head_dim,
ChunkK=ChunkK,
PaddedK=padded_k,
PaddedV=padded_v,
ScaleCols=scale_k if not no_scale and not per_row_scale else 1,
WEIGHT_PRESHUFFLE=weight_preshuffle,
num_stages=3,
)
return

_triton_gather_kv_b_proj[grid](
batch_size,
k_buffer,
Expand Down
Loading