Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/memos/mem_scheduler/schemas/message_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ def to_dict(self) -> dict:
"user_id": self.user_id,
"cube_id": self.mem_cube_id,
"trace_id": self.trace_id,
"session_id": self.session_id,
"label": self.label,
"cube": "Not Applicable", # Custom cube serialization
"content": self.content,
"timestamp": self.timestamp.isoformat(),
"user_name": self.user_name,
"info": self.info if self.info is not None else {},
"task_id": self.task_id if self.task_id is not None else "",
"api_path": self.api_path if self.api_path is not None else "",
"chat_history": self.chat_history if self.chat_history is not None else [],
Expand Down Expand Up @@ -132,6 +134,18 @@ def _decode(val: Any) -> Any:
else:
chat_history = raw_chat_history

raw_info = _decode(data.get("info"))
if isinstance(raw_info, str):
if raw_info:
try:
info = json.loads(raw_info)
except Exception:
info = None
else:
info = None
else:
info = raw_info

raw_user_context = _decode(data.get("user_context"))
if isinstance(raw_user_context, str):
if raw_user_context:
Expand All @@ -149,11 +163,13 @@ def _decode(val: Any) -> Any:
user_id=_decode(data["user_id"]),
mem_cube_id=_decode(data["cube_id"]),
trace_id=_decode(data.get("trace_id", generate_trace_id())),
session_id=_decode(data.get("session_id", "")),
label=_decode(data["label"]),
content=_decode(data["content"]),
timestamp=timestamp,
user_name=_decode(data.get("user_name")),
task_id=_decode(data.get("task_id")),
info=info,
task_id=_decode(data.get("task_id")) or None,
api_path=_decode(data.get("api_path")),
chat_history=chat_history,
user_context=UserContext.model_validate(raw_user_context) if raw_user_context else None,
Expand Down
45 changes: 42 additions & 3 deletions src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def __init__(
self.seen_streams = set()

# Task Orchestrator — cap in-memory cache to avoid unbounded growth
self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50)
self._cache_max_packs = max(
1, int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50)
)
self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque(
maxlen=self._cache_max_packs
)
Expand All @@ -138,6 +140,7 @@ def __init__(
self._stream_keys_lock = threading.Lock()
self._stream_keys_refresh_thread: ContextThread | None = None
self._stream_keys_refresh_stop_event = threading.Event()
self._stream_read_offset = 0
self._initial_scan_max_keys = int(
os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_MAX_KEYS", "1000") or 1000
)
Expand Down Expand Up @@ -315,6 +318,20 @@ def task_broker(
stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix)
if not stream_keys:
return []
stream_key_count = len(stream_keys)
if stream_key_count > self._cache_max_packs:
start = self._stream_read_offset % stream_key_count
end = start + self._cache_max_packs
if end <= stream_key_count:
stream_keys = stream_keys[start:end]
else:
stream_keys = stream_keys[start:] + stream_keys[: end % stream_key_count]
self._stream_read_offset = (start + self._cache_max_packs) % stream_key_count
logger.debug(
"[REDIS_QUEUE] Broker stream scan capped. scanned_streams=%s cache_max_packs=%s",
len(stream_keys),
self._cache_max_packs,
)

# Determine per-stream quotas for this cycle
stream_quotas = self.orchestrator.get_stream_quotas(
Expand Down Expand Up @@ -353,6 +370,28 @@ def task_broker(
if claimed_messages:
messages.extend(claimed_messages)

max_cached_messages = max(consume_batch_size, consume_batch_size * self._cache_max_packs)
limited_messages: list[tuple[str, list[tuple[str, dict]]]] = []
remaining = max_cached_messages
for stream_key, stream_messages in messages:
if remaining <= 0:
break
if len(stream_messages) <= remaining:
limited_messages.append((stream_key, stream_messages))
remaining -= len(stream_messages)
else:
limited_messages.append((stream_key, stream_messages[:remaining]))
remaining = 0
if remaining == 0 and len(limited_messages) < len(messages):
logger.debug(
"[REDIS_QUEUE] Broker prefetch capped. streams=%s capped_messages=%s cache_max_packs=%s consume_batch=%s",
len(messages),
max_cached_messages,
self._cache_max_packs,
consume_batch_size,
)
messages = limited_messages

cache: list[ScheduleMessageItem] = self._convert_messages(messages)

# pack messages
Expand Down Expand Up @@ -400,9 +439,9 @@ def _is_refill_thread_available(self) -> bool:
return True
if (time.time() - self._refill_thread_start) > self._refill_thread_timeout:
logger.warning(
f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale"
f"Refill thread has been running for >{self._refill_thread_timeout}s; "
"skip starting another refill thread to avoid duplicate memory growth"
)
return True
return False

def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
Expand Down
Loading