Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ issues = "https://github.com/MemTensor/MemOS/issues"
[project.scripts]
memos = "memos.cli:main"

[project.entry-points."memos.plugins"]
dream = "memos.dream:CommunityDreamPlugin"

[project.optional-dependencies]
# These are optional dependencies for various features of MemoryOS.
# Developers install: `poetry install --extras <feature>`. e.g., `poetry install --extras general-mem`
Expand Down
117 changes: 117 additions & 0 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

import copy
import math
import os

from contextlib import suppress
from typing import Any

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.dream.contextualization import CONTEXT_MEMORY_TYPE
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
cosine_similarity_matrix,
Expand All @@ -25,6 +28,20 @@

logger = get_logger(__name__)

_ENV_CONTEXT_RECALL = "MEMOS_DREAM_CONTEXT_RECALL"
_ENV_CONTEXT_RECALL_TOP_K = "MEMOS_DREAM_CONTEXT_RECALL_TOP_K"
_DEFAULT_CONTEXT_RECALL_TOP_K = 2


def _env_enabled(name: str, default: str = "off") -> bool:
return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"}


def _env_int(name: str, default: int) -> int:
with suppress(TypeError, ValueError):
return int(os.getenv(name, str(default)))
return default


class SearchHandler(BaseHandler):
"""
Expand Down Expand Up @@ -71,6 +88,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
# Search and deduplicate
cube_view = self._build_cube_view(search_req_local)
results = cube_view.search_memories(search_req_local)
self._merge_context_recall(results=results, search_req=search_req_local)
if not search_req_local.relativity:
search_req_local.relativity = 0
self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}")
Expand Down Expand Up @@ -102,6 +120,105 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)

def _merge_context_recall(
self, *, results: dict[str, Any], search_req: APISearchRequest
) -> None:
if not _env_enabled(_ENV_CONTEXT_RECALL, "off"):
return

top_k = max(0, _env_int(_ENV_CONTEXT_RECALL_TOP_K, _DEFAULT_CONTEXT_RECALL_TOP_K))
if top_k <= 0:
return

context_buckets = self._recall_context_buckets(search_req=search_req, top_k=top_k)
if not context_buckets:
return

results.setdefault("text_mem", []).extend(context_buckets)

def _recall_context_buckets(
self, *, search_req: APISearchRequest, top_k: int
) -> list[dict[str, Any]]:
graph_db = self.graph_db or getattr(self.searcher, "graph_store", None)
embedder = self.embedder or getattr(self.searcher, "embedder", None)
if graph_db is None or embedder is None:
self.logger.info(
"[SearchHandler] Context recall skipped: graph_db or embedder unavailable."
)
return []

try:
query_embedding = embedder.embed([search_req.query])[0]
except Exception:
self.logger.warning("[SearchHandler] Context recall embedding failed.", exc_info=True)
return []

buckets: list[dict[str, Any]] = []
for cube_id in self._resolve_cube_ids(search_req):
try:
hits = graph_db.search_by_embedding(
query_embedding,
top_k=top_k,
scope=CONTEXT_MEMORY_TYPE,
status="activated",
user_name=cube_id,
return_fields=[
"memory",
"key",
"created_at",
"updated_at",
"source",
"internal_info",
],
)
except Exception:
self.logger.warning(
"[SearchHandler] Context recall search failed for cube=%s.",
cube_id,
exc_info=True,
)
continue

memories = [self._format_context_hit(hit) for hit in hits or [] if hit.get("memory")]
if not memories:
continue
buckets.append(
{
"cube_id": cube_id,
"memories": memories,
"total_nodes": len(memories),
}
)
return buckets

@staticmethod
def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]:
context_id = str(hit.get("id", ""))
score = float(hit.get("score", 0.0) or 0.0)
metadata = {
"id": context_id,
"memory": hit.get("memory", ""),
"memory_type": CONTEXT_MEMORY_TYPE,
"source": hit.get("source") or "dream",
"key": hit.get("key", ""),
"relativity": score,
"score": score,
"embedding": [],
"sources": [],
"usage": [],
"ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]",
}
for field in ("created_at", "updated_at", "internal_info"):
if hit.get(field) is not None:
metadata[field] = hit[field]

return {
"id": context_id,
"memory": hit.get("memory", ""),
"metadata": metadata,
"ref_id": metadata["ref_id"],
}

@staticmethod
def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> dict[str, Any]:
if relativity <= 0:
Expand Down
Loading
Loading