Merged
```diff
@@ -52,7 +52,11 @@ def _initialize_kv_cache(self, dtype, device):
             return
         kv_cache1 = []
         ws = self._sp_world_size()
-        self.kv_cache_size = self._kv_size // ws
+        if self.local_attn_size != -1:
+            kv_cache_size = self.local_attn_size * self.frame_seq_length // ws
+        else:
+            kv_cache_size = self._kv_size // ws
+        self.kv_cache_size = kv_cache_size
```
Comment on lines +55 to +59

Contributor (severity: medium)
The current logic for calculating kv_cache_size ignores the max_attention_size override from the configuration (cfg_max in reinit_caches). If cfg_max is larger than the computed local_attn_size * self.frame_seq_length, the attention mechanism will index past the end of the allocated KV cache and crash; conversely, if cfg_max is smaller, memory is wasted. Since self.max_attention_size already accounts for both this override and sequence parallelism when local_attn_size != -1, it should be used to set the cache size, ensuring consistency and robustness.
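The failure mode described above can be sketched in a few lines. This is a minimal, self-contained illustration; the helper function and the numbers are hypothetical and not the repository's actual API:

```python
def allocate_kv_cache_len(local_attn_size, frame_seq_length, kv_size, ws):
    """Mirrors the pre-fix sizing logic: ignores any cfg_max override."""
    if local_attn_size != -1:
        return local_attn_size * frame_seq_length // ws
    return kv_size // ws

# Illustrative values: the configured override (cfg_max) exceeds the
# locally derived cache size.
local_attn_size, frame_seq_length, kv_size, ws = 4, 256, 8192, 2
cache_len = allocate_kv_cache_len(local_attn_size, frame_seq_length, kv_size, ws)
cfg_max = 768  # max_attention_size override, e.g. from reinit_caches

cache = [None] * cache_len  # allocated too small for cfg_max
try:
    cache[cfg_max - 1] = "kv"  # attention may index up to cfg_max - 1
except IndexError:
    print("out-of-range access: cache smaller than max_attention_size")
```

Sizing the cache from self.max_attention_size instead makes the allocation and the indexing bound come from the same quantity, so this mismatch cannot arise.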

Suggested change

```diff
-        if self.local_attn_size != -1:
-            kv_cache_size = self.local_attn_size * self.frame_seq_length // ws
-        else:
-            kv_cache_size = self._kv_size // ws
-        self.kv_cache_size = kv_cache_size
+        if self.local_attn_size != -1:
+            self.kv_cache_size = self.max_attention_size
+        else:
+            self.kv_cache_size = self._kv_size // ws
```


```diff
         n, d = self.num_heads, self.head_dim
         if self.kv_quant_config is not None:
```