From 31370d12ac93443d7bb2f5c9ddd3bbfa3efeb635 Mon Sep 17 00:00:00 2001 From: Leon Ling Date: Mon, 8 Jun 2026 07:16:56 +0000 Subject: [PATCH] Fix assert in triton fused_kv_cache --- aiter/ops/triton/fusions/fused_kv_cache.py | 28 +++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/aiter/ops/triton/fusions/fused_kv_cache.py b/aiter/ops/triton/fusions/fused_kv_cache.py index ecfb01e197..818a1399e3 100644 --- a/aiter/ops/triton/fusions/fused_kv_cache.py +++ b/aiter/ops/triton/fusions/fused_kv_cache.py @@ -107,7 +107,7 @@ def fused_qk_rope_cat_and_cache_mla( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension - the concatentaed k_nope and k_pe are copied to kv_cache inplace + the concatenated k_nope and k_pe are copied to kv_cache inplace Key parameters: - q_nope: Matrix X with shape (B, QH, D1). @@ -117,7 +117,7 @@ def fused_qk_rope_cat_and_cache_mla( - kv_cache: Matrix W with shape (B_cache, KH, D1 + D2). - slot_mapping: Matrix W with shape (B_slot, ). - B is the number of decode tokens, B_slot is the number of prefill + decode tokens, B_cahce is the max number of tokens of kv_cache + B is the number of decode tokens, B_slot is the number of prefill + decode tokens, B_cache is the max number of tokens of kv_cache QH must be multiple of KH Returns: @@ -334,12 +334,12 @@ def fused_qk_rope_reshape_and_cache( zeros_out: torch.Tensor = None, ): """ - Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace + Perform RoPE on q and k and along the last dimension and copy k and v into key_cache and value_cache inplace Key parameters: - q: shape (T, QH, D). - - k: shape (T_slot, KH, D). - - v: shape (T_slot, KH, D). + - k: shape (T, KH, D). + - v: shape (T, KH, D). - if flash_layout: - key_cache: shape (T_cache, block_size, KH, D). - value_cache: shape (T_cache, block_size, KH, D). @@ -348,7 +348,7 @@ def fused_qk_rope_reshape_and_cache( - value_cache: shape (T_cache, KH, D, block_size). - slot_mapping: shape (T_slot, ). - T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache + T is the number of decode tokens, T_cache * block_size is the max number of tokens of kv_cache QH must be multiple of KH Returns: @@ -407,8 +407,8 @@ def fused_qk_rope_reshape_and_cache( (t_slot,) = slot_mapping.shape assert ( - t == tk == tv and t_slot <= tk - ), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" + t == tk == tv and t <= t_slot + ), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more less that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" assert ( block_size == block_size_v ), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}" @@ -623,12 +623,12 @@ def fused_qk_rope_cosine_cache_llama( q_out: torch.Tensor = None, ): """ - Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace + Perform RoPE on q and k and along the last dimension and copy k and v into key_cache and value_cache inplace Key parameters: - q: shape (T, QH, D). - - k: shape (T_slot, KH, D). - - v: shape (T_slot, KH, D). + - k: shape (T, KH, D). + - v: shape (T, KH, D). - if flash_layout: - key_cache: shape (T_cache, block_size, KH, D). - value_cache: shape (T_cache, block_size, KH, D). @@ -637,7 +637,7 @@ def fused_qk_rope_cosine_cache_llama( - value_cache: shape (T_cache, KH, D, block_size). - slot_mapping: shape (T_slot, ). - T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache + T is the number of decode tokens, T_cache * block_size is the max number of tokens of kv_cache QH must be multiple of KH Returns: @@ -662,8 +662,8 @@ def fused_qk_rope_cosine_cache_llama( (t_slot,) = slot_mapping.shape assert ( - t == tk == tv and t_slot <= tk - ), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" + t == tk == tv and t <= t_slot + ), f"Number of tokens should be identical for q, k and v. The number of tokens of slot_mapping should be no less than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" assert ( block_size == block_size_v ), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}"