diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py
index 839c9a5..da244a2 100644
--- a/src/model/kv_cache.py
+++ b/src/model/kv_cache.py
@@ -185,6 +185,21 @@ def reset(self):
             layer.reset()
         self._is_frozen = True
 
+    @torch.inference_mode()
+    def get_state(self):
+        """Return a deep-copied snapshot of each layer's KV tensors and write counters."""
+        # detach().clone() so later in-place cache writes cannot alias the saved state.
+        layers = [(layer.kv.detach().clone(), layer.written.detach().clone()) for layer in self.layers]
+        return {"_is_frozen": self._is_frozen, "layers": layers}
+
+    @torch.inference_mode()
+    def load_state(self, state):
+        """Restore, in place, a snapshot previously produced by get_state."""
+        self._is_frozen = bool(state.get("_is_frozen", True))
+        for layer, (kv, written) in zip(self.layers, state["layers"]):
+            layer.kv.copy_(kv)
+            layer.written.copy_(written)
+
     def set_frozen(self, is_frozen: bool):
         self._is_frozen = is_frozen
diff --git a/src/world_engine.py b/src/world_engine.py
index 09a65e3..ae56a94 100644
--- a/src/world_engine.py
+++ b/src/world_engine.py
@@ -93,6 +93,17 @@ def reset(self):
         for v in self._ctx.values():
             v.zero_()
 
+    @torch.inference_mode()
+    def get_state(self):
+        """Captures a world state to continue via load_state. Doesn't save model"""
+        return {"kv_cache": self.kv_cache.get_state(), "frame_ts": self.frame_ts.detach().clone()}
+
+    @torch.inference_mode()
+    def load_state(self, state):
+        """Loads a world state object saved via get_state. Doesn't load or change model"""
+        self.kv_cache.load_state(state["kv_cache"])
+        self.frame_ts.copy_(state["frame_ts"])
+
     def set_prompt(self, prompt: str):
         """Apply text conditioning for T2V"""
         if self.prompt_encoder is None: