diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index ceb89e0de..29e15198f 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -89,11 +89,13 @@ def to_dict(self) -> dict: "user_id": self.user_id, "cube_id": self.mem_cube_id, "trace_id": self.trace_id, + "session_id": self.session_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization "content": self.content, "timestamp": self.timestamp.isoformat(), "user_name": self.user_name, + "info": self.info if self.info is not None else {}, "task_id": self.task_id if self.task_id is not None else "", "api_path": self.api_path if self.api_path is not None else "", "chat_history": self.chat_history if self.chat_history is not None else [], @@ -132,6 +134,18 @@ def _decode(val: Any) -> Any: else: chat_history = raw_chat_history + raw_info = _decode(data.get("info")) + if isinstance(raw_info, str): + if raw_info: + try: + info = json.loads(raw_info) + except Exception: + info = None + else: + info = None + else: + info = raw_info + raw_user_context = _decode(data.get("user_context")) if isinstance(raw_user_context, str): if raw_user_context: @@ -149,11 +163,13 @@ def _decode(val: Any) -> Any: user_id=_decode(data["user_id"]), mem_cube_id=_decode(data["cube_id"]), trace_id=_decode(data.get("trace_id", generate_trace_id())), + session_id=_decode(data.get("session_id", "")), label=_decode(data["label"]), content=_decode(data["content"]), timestamp=timestamp, user_name=_decode(data.get("user_name")), - task_id=_decode(data.get("task_id")), + info=info, + task_id=_decode(data.get("task_id")) or None, api_path=_decode(data.get("api_path")), chat_history=chat_history, user_context=UserContext.model_validate(raw_user_context) if raw_user_context else None, diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 561d7931f..79f40def4 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -124,7 +124,9 @@ def __init__( self.seen_streams = set() # Task Orchestrator — cap in-memory cache to avoid unbounded growth - self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + self._cache_max_packs = max( + 1, int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + ) self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque( maxlen=self._cache_max_packs ) @@ -138,6 +140,7 @@ def __init__( self._stream_keys_lock = threading.Lock() self._stream_keys_refresh_thread: ContextThread | None = None self._stream_keys_refresh_stop_event = threading.Event() + self._stream_read_offset = 0 self._initial_scan_max_keys = int( os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_MAX_KEYS", "1000") or 1000 ) @@ -315,6 +318,20 @@ def task_broker( stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix) if not stream_keys: return [] + stream_key_count = len(stream_keys) + if stream_key_count > self._cache_max_packs: + start = self._stream_read_offset % stream_key_count + end = start + self._cache_max_packs + if end <= stream_key_count: + stream_keys = stream_keys[start:end] + else: + stream_keys = stream_keys[start:] + stream_keys[: end % stream_key_count] + self._stream_read_offset = (start + self._cache_max_packs) % stream_key_count + logger.debug( + "[REDIS_QUEUE] Broker stream scan capped. scanned_streams=%s cache_max_packs=%s", + len(stream_keys), + self._cache_max_packs, + ) # Determine per-stream quotas for this cycle stream_quotas = self.orchestrator.get_stream_quotas( @@ -353,6 +370,28 @@ def task_broker( if claimed_messages: messages.extend(claimed_messages) + max_cached_messages = max(consume_batch_size, consume_batch_size * self._cache_max_packs) + limited_messages: list[tuple[str, list[tuple[str, dict]]]] = [] + remaining = max_cached_messages + for stream_key, stream_messages in messages: + if remaining <= 0: + break + if len(stream_messages) <= remaining: + limited_messages.append((stream_key, stream_messages)) + remaining -= len(stream_messages) + else: + limited_messages.append((stream_key, stream_messages[:remaining])) + remaining = 0 + if remaining == 0 and len(limited_messages) < len(messages): + logger.debug( + "[REDIS_QUEUE] Broker prefetch capped. streams=%s capped_messages=%s cache_max_packs=%s consume_batch=%s", + len(messages), + max_cached_messages, + self._cache_max_packs, + consume_batch_size, + ) + messages = limited_messages + cache: list[ScheduleMessageItem] = self._convert_messages(messages) # pack messages @@ -400,9 +439,9 @@ def _is_refill_thread_available(self) -> bool: return True if (time.time() - self._refill_thread_start) > self._refill_thread_timeout: logger.warning( - f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale" + f"Refill thread has been running for >{self._refill_thread_timeout}s; " + "skip starting another refill thread to avoid duplicate memory growth" ) - return True return False def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: