diff --git a/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py b/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py index c928da80d2..0fee31a7a3 100644 --- a/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py +++ b/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py @@ -58,6 +58,260 @@ def _load_unshuffle_segment( return w +@triton.jit +def _triton_gather_kv_b_proj_fp4( + batch_size, + k_buffer, # [num_block, block_size, kv_c_dim + kv_pe_dim] + k_scale, # [1] or None + kv_indptr, # [batch_size + 1] + kv_indices, # [total_kv] + kv_prefix_sum_context_lens, # [batch_size + 1] + kv_proj_weight, # packed fp4: [tp_heads * (qk_nope_head_dim + v_head_dim), kv_c_dim // 2] + kv_proj_scale, # e8m0 per-1x32: [weight_n, kv_c_dim // 32] + k_prefix, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim] + v_prefix, # [total_kv, tp_k_head_num, v_head_dim] + KBlockSize: tl.constexpr, + TpNumHeads: tl.constexpr, + QkNopeHeadDim: tl.constexpr, + VHeadDim: tl.constexpr, + KV_CDim: tl.constexpr, + KV_PeDim: tl.constexpr, + ChunkK: tl.constexpr, + PaddedK: tl.constexpr, + PaddedV: tl.constexpr, + ScaleCols: tl.constexpr, + WEIGHT_PRESHUFFLE: tl.constexpr = False, +): + """FP4/per-1x32 gather + kv_b_proj expansion for raw MXFP4 weights.""" + stride_k_buffer = tl.full([], KBlockSize * (KV_CDim + KV_PeDim), dtype=tl.int64) + stride_k_prefix = tl.full( + [], TpNumHeads * (QkNopeHeadDim + KV_PeDim), dtype=tl.int64 + ) + stride_v_prefix = tl.full([], TpNumHeads * VHeadDim, dtype=tl.int64) + + ScaleKGranularity: tl.constexpr = 32 + PackedScaleKGranularity: tl.constexpr = ScaleKGranularity // 2 + PackedKV_CDim: tl.constexpr = KV_CDim // 2 + KBlocksPerChunkK: tl.constexpr = ChunkK // KBlockSize + NumKSegments: tl.constexpr = KV_CDim // ScaleKGranularity + + pid = tl.program_id(0) + pid_batch = pid // TpNumHeads + pid_head = pid % TpNumHeads + + kv_block_start = tl.load(kv_indptr + pid_batch) + kv_block_end = tl.load(kv_indptr + pid_batch + 1) + + context_start = tl.load(kv_prefix_sum_context_lens + pid_batch) + context_end = tl.load(kv_prefix_sum_context_lens + pid_batch + 1) + + total_kv_block = kv_block_end - kv_block_start + total_kv_chunk = (total_kv_block + KBlocksPerChunkK - 1) // KBlocksPerChunkK + + k_type = k_buffer.dtype.element_ty + if k_type == tl.bfloat16: + k_scalar_scale = 1.0 + else: + k_scalar_scale = tl.load(k_scale) + + offs_n_k = tl.arange(0, PaddedK) + offs_n_v = tl.arange(0, PaddedV) + mask_k = offs_n_k < QkNopeHeadDim + mask_v = offs_n_v < VHeadDim + offs_k = tl.arange(0, ScaleKGranularity) + offs_k_packed = tl.arange(0, PackedScaleKGranularity) + + head_row_base = pid_head * (QkNopeHeadDim + VHeadDim) + k_abs_rows = head_row_base + offs_n_k + v_abs_rows = head_row_base + QkNopeHeadDim + offs_n_v + + for chunk_id in range(total_kv_chunk): + block_lane_valid = ( + chunk_id * KBlocksPerChunkK + tl.arange(0, ChunkK) // KBlockSize + < total_kv_block + ) + kv_block_idx = tl.load( + kv_indices + + kv_block_start + + chunk_id * KBlocksPerChunkK + + tl.arange(0, ChunkK) // KBlockSize, + mask=block_lane_valid, + other=0, + ) + + accum_k = tl.zeros((ChunkK, PaddedK), dtype=tl.float32) + accum_v = tl.zeros((ChunkK, PaddedV), dtype=tl.float32) + row_mask = block_lane_valid[:, None] + + for seg in range(NumKSegments): + kv_c_data = tl.load( + k_buffer + + kv_block_idx[:, None] * stride_k_buffer + + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) + + seg * ScaleKGranularity + + offs_k[None, :], + mask=row_mask, + other=0.0, + ).to(tl.bfloat16) + + packed_cols = seg * PackedScaleKGranularity + offs_k_packed[:, None] + + if WEIGHT_PRESHUFFLE: + # Inverse of aiter.ops.shuffle.shuffle_weight for uint8/packed-fp4 + # tensors with layout=(16, 16). It maps logical [N, K//2] + # coordinates back into the preshuffled storage. + k_n0 = k_abs_rows[None, :] // 16 + k_bn = k_abs_rows[None, :] % 16 + k_k0 = packed_cols // 32 + k_r = (packed_cols % 32) // 16 + k_c = packed_cols % 16 + k_weight_offset = ( + ((k_n0 * (PackedKV_CDim // 32) + k_k0) * 2 + k_r) * 16 + k_bn + ) * 16 + k_c + + v_n0 = v_abs_rows[None, :] // 16 + v_bn = v_abs_rows[None, :] % 16 + v_k0 = packed_cols // 32 + v_r = (packed_cols % 32) // 16 + v_c = packed_cols % 16 + v_weight_offset = ( + ((v_n0 * (PackedKV_CDim // 32) + v_k0) * 2 + v_r) * 16 + v_bn + ) * 16 + v_c + else: + k_weight_offset = k_abs_rows[None, :] * PackedKV_CDim + packed_cols + v_weight_offset = v_abs_rows[None, :] * PackedKV_CDim + packed_cols + + k_weight = tl.load( + kv_proj_weight + k_weight_offset, + mask=mask_k[None, :], + other=0, + ) + v_weight = tl.load( + kv_proj_weight + v_weight_offset, + mask=mask_v[None, :], + other=0, + ) + + if WEIGHT_PRESHUFFLE: + # Inverse of fp4_utils.e8m0_shuffle / shuffle_scale. + k_scale_n = k_abs_rows[:, None] + k_scale_g = seg + k_scale_a = k_scale_n // 32 + k_scale_b = (k_scale_n % 32) // 16 + k_scale_c = k_scale_n % 16 + k_scale_d = k_scale_g // 8 + k_scale_e = (k_scale_g % 8) // 4 + k_scale_f = k_scale_g % 4 + k_scale_offset = ( + ( + ((k_scale_a * (ScaleCols // 8) + k_scale_d) * 4 + k_scale_f) + * 16 + + k_scale_c + ) + * 2 + + k_scale_e + ) * 2 + k_scale_b + + v_scale_n = v_abs_rows[:, None] + v_scale_g = seg + v_scale_a = v_scale_n // 32 + v_scale_b = (v_scale_n % 32) // 16 + v_scale_c = v_scale_n % 16 + v_scale_d = v_scale_g // 8 + v_scale_e = (v_scale_g % 8) // 4 + v_scale_f = v_scale_g % 4 + v_scale_offset = ( + ( + ((v_scale_a * (ScaleCols // 8) + v_scale_d) * 4 + v_scale_f) + * 16 + + v_scale_c + ) + * 2 + + v_scale_e + ) * 2 + v_scale_b + else: + k_scale_offset = k_abs_rows[:, None] * ScaleCols + seg + v_scale_offset = v_abs_rows[:, None] * ScaleCols + seg + + k_weight_scale = tl.load( + kv_proj_scale + k_scale_offset, + mask=mask_k[:, None], + other=0, + ) + v_weight_scale = tl.load( + kv_proj_scale + v_scale_offset, + mask=mask_v[:, None], + other=0, + ) + + accum_k = tl.dot_scaled( + kv_c_data, + None, + "bf16", + k_weight, + k_weight_scale, + "e2m1", + acc=accum_k, + fast_math=True, + ) + accum_v = tl.dot_scaled( + kv_c_data, + None, + "bf16", + v_weight, + v_weight_scale, + "e2m1", + acc=accum_v, + fast_math=True, + ) + + kv_pe_data = tl.load( + k_buffer + + kv_block_idx[:, None] * stride_k_buffer + + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) + + KV_CDim + + tl.arange(0, KV_PeDim)[None, :], + mask=row_mask, + other=0.0, + ) + + accum_k *= k_scalar_scale + accum_v *= k_scalar_scale + kv_pe_data *= k_scalar_scale + + context_mask = ( + context_start + chunk_id * ChunkK + tl.arange(0, ChunkK) < context_end + ) + tl.store( + k_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_k_prefix + + pid_head * (QkNopeHeadDim + KV_PeDim) + + QkNopeHeadDim + + tl.arange(0, KV_PeDim)[None, :], + kv_pe_data, + mask=context_mask[:, None], + ) + tl.store( + k_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_k_prefix + + pid_head * (QkNopeHeadDim + KV_PeDim) + + offs_n_k[None, :], + accum_k, + mask=context_mask[:, None] & mask_k[None, :], + ) + tl.store( + v_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_v_prefix + + pid_head * VHeadDim + + offs_n_v[None, :], + accum_v, + mask=context_mask[:, None] & mask_v[None, :], + ) + + @triton.jit def _triton_gather_kv_b_proj( batch_size, diff --git a/aiter/ops/triton/gather_kv_b_proj.py b/aiter/ops/triton/gather_kv_b_proj.py index c0d9579d35..15176f1bec 100644 --- a/aiter/ops/triton/gather_kv_b_proj.py +++ b/aiter/ops/triton/gather_kv_b_proj.py @@ -6,6 +6,7 @@ from aiter.ops.triton._triton_kernels.gather_kv_b_proj import ( _next_pow2, _triton_gather_kv_b_proj, + _triton_gather_kv_b_proj_fp4, ) @@ -23,7 +24,12 @@ def gather_kv_b_proj( ): num_block, block_size, hidden_dim = k_buffer.shape batch_size = kv_indptr.shape[0] - 1 - weight_n, weight_k = kv_proj_weight.shape + weight_n, packed_weight_k = kv_proj_weight.shape + fp4_weight_dtype = getattr(torch, "float4_e2m1fn_x2", None) + is_fp4_weight = ( + fp4_weight_dtype is not None and kv_proj_weight.dtype == fp4_weight_dtype + ) + weight_k = packed_weight_k * 2 if is_fp4_weight else packed_weight_k total_kv_k, tp_k_head_num_k, qk_nope_pe_dim = k_prefix.shape total_kv_v, tp_k_head_num_v, v_head_dim = v_prefix.shape @@ -55,8 +61,28 @@ def gather_kv_b_proj( scale_n, scale_k = kv_proj_scale.shape scale_k_granularity = weight_k // scale_k scale_n_granularity = weight_n // scale_n - assert scale_k_granularity == 128 - assert scale_n_granularity == 128 + if is_fp4_weight: + if weight_preshuffle: + assert scale_k >= (weight_k + 31) // 32, ( + "Preshuffled FP4 gather_kv_b_proj expects padded per-1x32 scale columns, " + f"got scale cols {scale_k} for logical K {weight_k}" + ) + assert scale_n >= weight_n, ( + "Preshuffled FP4 gather_kv_b_proj expects padded per-output-row MXFP4 scales, " + f"got scale rows {scale_n} for weight rows {weight_n}" + ) + else: + assert scale_k_granularity == 32, ( + "FP4 gather_kv_b_proj expects per-1x32 weight scales, " + f"got K granularity {scale_k_granularity}" + ) + assert scale_n_granularity == 1, ( + "FP4 gather_kv_b_proj expects per-output-row MXFP4 scales, " + f"got N granularity {scale_n_granularity}" + ) + else: + assert scale_k_granularity == 128 + assert scale_n_granularity == 128 ChunkK = 16 if k_buffer.dtype in [torch.float16, torch.bfloat16] else 32 @@ -68,6 +94,33 @@ def gather_kv_b_proj( padded_v = _next_pow2(v_head_dim) grid = (batch_size * tp_k_head_num_k,) + if is_fp4_weight: + _triton_gather_kv_b_proj_fp4[grid]( + batch_size, + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight.view(torch.uint8), + kv_proj_scale.view(torch.uint8), + k_prefix, + v_prefix, + KBlockSize=block_size, + TpNumHeads=tp_k_head_num_k, + QkNopeHeadDim=qk_nope_head_dim, + VHeadDim=v_head_dim, + KV_CDim=weight_k, + KV_PeDim=qk_nope_pe_dim - qk_nope_head_dim, + ChunkK=ChunkK, + PaddedK=padded_k, + PaddedV=padded_v, + ScaleCols=scale_k if not no_scale and not per_row_scale else 1, + WEIGHT_PRESHUFFLE=weight_preshuffle, + num_stages=3, + ) + return + _triton_gather_kv_b_proj[grid]( batch_size, k_buffer,