Adjust kv_cache_size calculation based on local_attn_size#1008
Conversation
Code Review
This pull request updates the KV cache initialization in transformer_infer.py to account for local attention settings. A review comment suggests using self.max_attention_size instead of the manual calculation, to stay consistent with configuration overrides and avoid a potential runtime crash or wasted memory.
```python
if self.local_attn_size != -1:
    kv_cache_size = self.local_attn_size * self.frame_seq_length // ws
else:
    kv_cache_size = self._kv_size // ws
self.kv_cache_size = kv_cache_size
```
The current logic for calculating kv_cache_size ignores the max_attention_size override from the configuration (cfg_max in reinit_caches). If cfg_max is set to a value larger than the calculated local_attn_size * self.frame_seq_length, the attention mechanism will attempt to access indices outside the allocated KV cache, leading to a crash. Conversely, if cfg_max is smaller, memory is wasted. Since self.max_attention_size already correctly accounts for this override and sequence parallelism when local_attn_size != -1, it should be used to set the cache size to ensure consistency and robustness.
Suggested change:

```diff
-if self.local_attn_size != -1:
-    kv_cache_size = self.local_attn_size * self.frame_seq_length // ws
-else:
-    kv_cache_size = self._kv_size // ws
-self.kv_cache_size = kv_cache_size
+if self.local_attn_size != -1:
+    self.kv_cache_size = self.max_attention_size
+else:
+    self.kv_cache_size = self._kv_size // ws
```
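The failure mode described in the review can be illustrated with a minimal, self-contained sketch. The attribute names (local_attn_size, frame_seq_length, _kv_size, max_attention_size) follow the snippet; the concrete numbers and the idea that max_attention_size folds in a larger cfg_max override are illustrative assumptions, not values from the PR:

```python
def kv_cache_size_manual(local_attn_size, frame_seq_length, _kv_size, ws):
    # Current logic: derives the cache size directly from local_attn_size,
    # ignoring any max_attention_size / cfg_max override.
    if local_attn_size != -1:
        return local_attn_size * frame_seq_length // ws
    return _kv_size // ws

def kv_cache_size_suggested(local_attn_size, max_attention_size, _kv_size, ws):
    # Suggested logic: max_attention_size already accounts for the cfg_max
    # override and sequence parallelism when local_attn_size != -1.
    if local_attn_size != -1:
        return max_attention_size
    return _kv_size // ws

# Hypothetical scenario: cfg_max raises max_attention_size to 2048, which is
# larger than local_attn_size * frame_seq_length = 1024. The manual cache is
# then smaller than the attention window, so attention indices up to 2048
# would fall outside the allocated 1024-slot cache.
manual = kv_cache_size_manual(local_attn_size=4, frame_seq_length=256,
                              _kv_size=8192, ws=1)
suggested = kv_cache_size_suggested(local_attn_size=4, max_attention_size=2048,
                                    _kv_size=8192, ws=1)
print(manual, suggested)  # 1024 2048
```

With cfg_max smaller than the manual product, the situation reverses: the manual cache over-allocates and the extra slots are never addressed, which is the memory-waste case the review mentions.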