Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/memos/api/handlers/formatters_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@

logger = get_logger(__name__)

# Explicit public API of this module: only these names are exported by a
# star-import (``from <module> import *``).
__all__ = [
"format_memory_item",
"post_process_textual_mem",
"rerank_knowledge_mem",
"to_iter",
]


def to_iter(running: Any) -> list[Any]:
"""
Expand Down Expand Up @@ -164,6 +171,7 @@ def rerank_knowledge_mem(
text_mem: list[dict[str, Any]],
top_k: int,
file_mem_proportion: float = 0.5,
strip_conversation_sources: bool = False,
) -> list[dict[str, Any]]:
"""
Rerank knowledge memories and keep conversation memories.
Expand Down Expand Up @@ -193,8 +201,9 @@ def rerank_knowledge_mem(
item["memory"] = item["metadata"]["sources"][0]["content"]
item["metadata"]["sources"] = []

for item in conversation_mem:
item.setdefault("metadata", {})["sources"] = []
if strip_conversation_sources:
for item in conversation_mem:
item.setdefault("metadata", {})["sources"] = []

# deduplicate: remove items with duplicate memory content
original_count = len(reranked_knowledge_mem)
Expand Down
17 changes: 17 additions & 0 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
top_k=search_req_local.top_k,
file_mem_proportion=0.5,
)
hooked_results = trigger_hook(
H.SEARCH_RESULTS_AFTER_RERANK,
handler=self,
search_req=search_req_local,
results=results,
)
if hooked_results is not None:
results = hooked_results

hooked_results = trigger_hook(
H.SEARCH_CONTEXT_RENDER,
handler=self,
search_req=search_req_local,
results=results,
)
if hooked_results is not None:
results = hooked_results

self.logger.info(
f"[SearchHandler] Final search results: count={len(results)} results={results}"
Expand Down
13 changes: 13 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ class APISearchRequest(BaseRequest):
),
)

rerank: bool = Field(
True,
description=(
"Whether to apply the textual memory reranker during search. "
"Set false to return retrieval-order candidates before post-search dedup/formatting."
),
)

pref_top_k: int = Field(
6,
ge=0,
Expand Down Expand Up @@ -464,6 +472,11 @@ class APISearchRequest(BaseRequest):
description="Number of skill memories to retrieve (top-K). Default: 3.",
)

context_format: str = Field(
"memory",
description="Optional search context format passed through to installed plugins.",
)

# ==== Filter conditions ====
# TODO: maybe add detailed description later
filter: dict[str, Any] | None = Field(
Expand Down
100 changes: 79 additions & 21 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ def __init__(
self.tokenizer = tokenizer
self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")

def _maybe_rerank(
self,
enabled: bool,
*,
query: str,
graph_results: list[TextualMemoryItem],
top_k: int,
**kwargs,
) -> list[tuple[TextualMemoryItem, float]]:
if not enabled or self.reranker is None:
return [(item, 0.0) for item in graph_results[:top_k]]
return self.reranker.rerank(
query=query,
graph_results=graph_results,
top_k=top_k,
**kwargs,
)

@staticmethod
def _query_embedding_for_rerank(enabled: bool, query_embedding):
if not enabled:
return None
return query_embedding[0]

@timed
def retrieve(
self,
Expand All @@ -99,6 +123,7 @@ def retrieve(
logger.info(
f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}, user_name={user_name}"
)
rerank = bool(kwargs.get("rerank", True))
parsed_goal, query_embedding, _context, query = self._parse_task(
query,
info,
Expand All @@ -125,6 +150,7 @@ def retrieve(
skill_mem_top_k,
include_preference_memory,
pref_mem_top_k,
rerank,
)
return results

Expand Down Expand Up @@ -350,6 +376,7 @@ def _retrieve_paths(
skill_mem_top_k: int = 3,
include_preference_memory: bool = False,
pref_mem_top_k: int = 6,
rerank: bool = True,
):
"""Run A/B/C/D/E/F retrieval paths in parallel"""
tasks = []
Expand All @@ -372,6 +399,7 @@ def _retrieve_paths(
search_priority,
user_name,
id_filter,
rerank=rerank,
)
)
tasks.append(
Expand All @@ -387,6 +415,7 @@ def _retrieve_paths(
user_name,
id_filter,
mode=mode,
rerank=rerank,
)
)
tasks.append(
Expand All @@ -400,6 +429,7 @@ def _retrieve_paths(
mode,
memory_type,
user_name,
rerank=rerank,
)
)
if self.use_fulltext:
Expand All @@ -415,6 +445,7 @@ def _retrieve_paths(
search_priority,
user_name,
id_filter,
rerank=rerank,
)
)
if search_tool_memory:
Expand All @@ -431,6 +462,7 @@ def _retrieve_paths(
user_name,
id_filter,
mode=mode,
rerank=rerank,
)
)
if include_skill_memory:
Expand All @@ -447,6 +479,7 @@ def _retrieve_paths(
user_name,
id_filter,
mode=mode,
rerank=rerank,
)
)
if include_preference_memory:
Expand All @@ -463,6 +496,7 @@ def _retrieve_paths(
user_name,
id_filter,
mode=mode,
rerank=rerank,
)
)
results = []
Expand All @@ -485,6 +519,7 @@ def _retrieve_from_working_memory(
search_priority: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
rerank: bool = True,
):
"""Retrieve and rerank from WorkingMemory"""
if memory_type not in ["All", "WorkingMemory"]:
Expand All @@ -501,9 +536,10 @@ def _retrieve_from_working_memory(
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
)
return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=items,
top_k=top_k,
parsed_goal=parsed_goal,
Expand Down Expand Up @@ -613,6 +649,7 @@ def _retrieve_from_keyword(
search_priority: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
rerank: bool = True,
) -> list[tuple[TextualMemoryItem, float]]:
"""Keyword/fulltext path that directly calls graph DB fulltext search."""

Expand Down Expand Up @@ -687,9 +724,10 @@ def _retrieve_from_keyword(
ordered_nodes.append(node)

results = [TextualMemoryItem.from_dict(n) for n in ordered_nodes]
return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=results,
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -710,6 +748,7 @@ def _retrieve_from_long_term_and_user(
user_name: str | None = None,
id_filter: dict | None = None,
mode: str = "fast",
rerank: bool = True,
):
"""Retrieve and rerank from LongTermMemory and UserMemory"""
results = []
Expand Down Expand Up @@ -781,9 +820,10 @@ def _retrieve_from_long_term_and_user(
results = self._deduplicate_rawfile_results(results, user_name=user_name)
results = self._filter_intermediate_content(results)

return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=results,
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -792,7 +832,13 @@ def _retrieve_from_long_term_and_user(

@timed
def _retrieve_from_memcubes(
self, query, parsed_goal, query_embedding, top_k, cube_name="memos_cube01"
self,
query,
parsed_goal,
query_embedding,
top_k,
cube_name="memos_cube01",
rerank: bool = True,
):
"""Retrieve and rerank from LongTermMemory and UserMemory"""
results = self.graph_retriever.retrieve_from_cube(
Expand All @@ -802,9 +848,10 @@ def _retrieve_from_memcubes(
cube_name=cube_name,
user_name=cube_name,
)
return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=results,
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -822,6 +869,7 @@ def _retrieve_from_internet(
mode,
memory_type,
user_id: str | None = None,
rerank: bool = True,
):
"""Retrieve and rerank from Internet source"""
if not self.internet_retriever:
Expand All @@ -838,9 +886,10 @@ def _retrieve_from_internet(
query=query, top_k=2 * top_k, parsed_goal=parsed_goal, info=info, mode=mode
)
logger.info(f"[PATH-C] '{query}' Retrieved from internet {len(items)} items: {items}")
return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=items,
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -860,6 +909,7 @@ def _retrieve_from_tool_memory(
user_name: str | None = None,
id_filter: dict | None = None,
mode: str = "fast",
rerank: bool = True,
):
"""Retrieve and rerank from ToolMemory"""
results = {
Expand Down Expand Up @@ -920,17 +970,19 @@ def _retrieve_from_tool_memory(
elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory":
results["ToolTrajectoryMemory"].extend(rsp)

schema_reranked = self.reranker.rerank(
schema_reranked = self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=results["ToolSchemaMemory"],
top_k=top_k,
parsed_goal=parsed_goal,
search_filter=search_filter,
)
trajectory_reranked = self.reranker.rerank(
trajectory_reranked = self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=results["ToolTrajectoryMemory"],
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -952,6 +1004,7 @@ def _retrieve_from_skill_memory(
user_name: str | None = None,
id_filter: dict | None = None,
mode: str = "fast",
rerank: bool = True,
):
"""Retrieve and rerank from SkillMemory"""

Expand Down Expand Up @@ -982,9 +1035,10 @@ def _retrieve_from_skill_memory(
use_fast_graph=self.use_fast_graph,
)

return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=items,
top_k=top_k,
parsed_goal=parsed_goal,
Expand All @@ -1004,6 +1058,7 @@ def _retrieve_from_preference_memory(
user_name: str | None = None,
id_filter: dict | None = None,
mode: str = "fast",
rerank: bool = True,
):
"""Retrieve and rerank from PreferenceMemory"""
if memory_type not in ["All", "PreferenceMemory"]:
Expand Down Expand Up @@ -1033,9 +1088,10 @@ def _retrieve_from_preference_memory(
use_fast_graph=self.use_fast_graph,
)

return self.reranker.rerank(
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embedding[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embedding),
graph_results=items,
top_k=top_k,
parsed_goal=parsed_goal,
Expand Down Expand Up @@ -1086,9 +1142,11 @@ def _retrieve_simple(
logger.info(
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
)
return self.reranker.rerank(
rerank = bool(kwargs.get("rerank", True))
return self._maybe_rerank(
rerank,
query=query,
query_embedding=query_embeddings[0],
query_embedding=self._query_embedding_for_rerank(rerank, query_embeddings),
graph_results=selected_items,
top_k=top_k,
)
Expand Down
Loading
Loading