Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions aiter/ops/triton/fusions/fused_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def fused_qk_rope_cat_and_cache_mla(
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension
the concatentaed k_nope and k_pe are copied to kv_cache inplace
the concatenated k_nope and k_pe are copied to kv_cache inplace

Key parameters:
- q_nope: Matrix X with shape (B, QH, D1).
Expand All @@ -117,7 +117,7 @@ def fused_qk_rope_cat_and_cache_mla(
- kv_cache: Matrix W with shape (B_cache, KH, D1 + D2).
- slot_mapping: Matrix W with shape (B_slot, ).

B is the number of decode tokens, B_slot is the number of prefill + decode tokens, B_cahce is the max number of tokens of kv_cache
B is the number of decode tokens, B_slot is the number of prefill + decode tokens, B_cache is the max number of tokens of kv_cache
QH must be multiple of KH

Returns:
Expand Down Expand Up @@ -334,12 +334,12 @@ def fused_qk_rope_reshape_and_cache(
zeros_out: torch.Tensor = None,
):
"""
Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace
Perform RoPE on q and k along the last dimension and copy k and v into key_cache and value_cache in place

Key parameters:
- q: shape (T, QH, D).
- k: shape (T_slot, KH, D).
- v: shape (T_slot, KH, D).
- k: shape (T, KH, D).
- v: shape (T, KH, D).
- if flash_layout:
- key_cache: shape (T_cache, block_size, KH, D).
- value_cache: shape (T_cache, block_size, KH, D).
Expand All @@ -348,7 +348,7 @@ def fused_qk_rope_reshape_and_cache(
- value_cache: shape (T_cache, KH, D, block_size).
- slot_mapping: shape (T_slot, ).

T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache
T is the number of decode tokens, T_cache * block_size is the max number of tokens of kv_cache
QH must be multiple of KH

Returns:
Expand Down Expand Up @@ -407,8 +407,8 @@ def fused_qk_rope_reshape_and_cache(
(t_slot,) = slot_mapping.shape

assert (
t == tk == tv and t_slot <= tk
), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}"
t == tk == tv and t <= t_slot
), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more less that of q, k and v, {t=} {tk=} {tv=} {t_slot=}"
assert (
block_size == block_size_v
), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}"
Expand Down Expand Up @@ -623,12 +623,12 @@ def fused_qk_rope_cosine_cache_llama(
q_out: torch.Tensor = None,
):
"""
Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace
Perform RoPE on q and k along the last dimension and copy k and v into key_cache and value_cache in place

Key parameters:
- q: shape (T, QH, D).
- k: shape (T_slot, KH, D).
- v: shape (T_slot, KH, D).
- k: shape (T, KH, D).
- v: shape (T, KH, D).
- if flash_layout:
- key_cache: shape (T_cache, block_size, KH, D).
- value_cache: shape (T_cache, block_size, KH, D).
Expand All @@ -637,7 +637,7 @@ def fused_qk_rope_cosine_cache_llama(
- value_cache: shape (T_cache, KH, D, block_size).
- slot_mapping: shape (T_slot, ).

T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache
T is the number of decode tokens, T_cache * block_size is the max number of tokens of kv_cache
QH must be multiple of KH

Returns:
Expand All @@ -662,8 +662,8 @@ def fused_qk_rope_cosine_cache_llama(
(t_slot,) = slot_mapping.shape

assert (
t == tk == tv and t_slot <= tk
), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}"
t == tk == tv and t <= t_slot
), f"Number of tokens should be identical for q, k and v. The number of tokens of slot_mapping should be no less than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}"
assert (
block_size == block_size_v
), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}"
Expand Down
Loading