diff --git a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py index f5111186..2d33679d 100755 --- a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py @@ -52,7 +52,11 @@ def _initialize_kv_cache(self, dtype, device): return kv_cache1 = [] ws = self._sp_world_size() - self.kv_cache_size = self._kv_size // ws + if self.local_attn_size != -1: + kv_cache_size = self.local_attn_size * self.frame_seq_length // ws + else: + kv_cache_size = self._kv_size // ws + self.kv_cache_size = kv_cache_size n, d = self.num_heads, self.head_dim if self.kv_quant_config is not None: