diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index c8024baa3..638565bb9 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -13,6 +13,13 @@ logger = get_logger(__name__) +__all__ = [ + "format_memory_item", + "post_process_textual_mem", + "rerank_knowledge_mem", + "to_iter", +] + def to_iter(running: Any) -> list[Any]: """ @@ -164,6 +171,7 @@ def rerank_knowledge_mem( text_mem: list[dict[str, Any]], top_k: int, file_mem_proportion: float = 0.5, + strip_conversation_sources: bool = False, ) -> list[dict[str, Any]]: """ Rerank knowledge memories and keep conversation memories. @@ -193,8 +201,9 @@ def rerank_knowledge_mem( item["memory"] = item["metadata"]["sources"][0]["content"] item["metadata"]["sources"] = [] - for item in conversation_mem: - item.setdefault("metadata", {})["sources"] = [] + if strip_conversation_sources: + for item in conversation_mem: + item.setdefault("metadata", {})["sources"] = [] # deduplicate: remove items with duplicate memory content original_count = len(reranked_knowledge_mem) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 5cfaf658b..03e6977ad 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -118,6 +118,23 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse top_k=search_req_local.top_k, file_mem_proportion=0.5, ) + hooked_results = trigger_hook( + H.SEARCH_RESULTS_AFTER_RERANK, + handler=self, + search_req=search_req_local, + results=results, + ) + if hooked_results is not None: + results = hooked_results + + hooked_results = trigger_hook( + H.SEARCH_CONTEXT_RENDER, + handler=self, + search_req=search_req_local, + results=results, + ) + if hooked_results is not None: + results = hooked_results self.logger.info( f"[SearchHandler] Final search results: count={len(results)} 
results={results}" diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 110b97e65..ff9d94859 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -422,6 +422,14 @@ class APISearchRequest(BaseRequest): ), ) + rerank: bool = Field( + True, + description=( + "Whether to apply the textual memory reranker during search. " + "Set false to return retrieval-order candidates before post-search dedup/formatting." + ), + ) + pref_top_k: int = Field( 6, ge=0, @@ -464,6 +472,11 @@ class APISearchRequest(BaseRequest): description="Number of skill memories to retrieve (top-K). Default: 3.", ) + context_format: str = Field( + "memory", + description="Optional search context format passed through to installed plugins.", + ) + # ==== Filter conditions ==== # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index dec6a4458..cd27d92a1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -77,6 +77,30 @@ def __init__( self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") + def _maybe_rerank( + self, + enabled: bool, + *, + query: str, + graph_results: list[TextualMemoryItem], + top_k: int, + **kwargs, + ) -> list[tuple[TextualMemoryItem, float]]: + if not enabled or self.reranker is None: + return [(item, 0.0) for item in graph_results[:top_k]] + return self.reranker.rerank( + query=query, + graph_results=graph_results, + top_k=top_k, + **kwargs, + ) + + @staticmethod + def _query_embedding_for_rerank(enabled: bool, query_embedding): + if not enabled: + return None + return query_embedding[0] + @timed def retrieve( self, @@ -99,6 +123,7 @@ def retrieve( logger.info( f"[RECALL] Start 
query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}, user_name={user_name}" ) + rerank = bool(kwargs.get("rerank", True)) parsed_goal, query_embedding, _context, query = self._parse_task( query, info, @@ -125,6 +150,7 @@ def retrieve( skill_mem_top_k, include_preference_memory, pref_mem_top_k, + rerank, ) return results @@ -350,6 +376,7 @@ def _retrieve_paths( skill_mem_top_k: int = 3, include_preference_memory: bool = False, pref_mem_top_k: int = 6, + rerank: bool = True, ): """Run A/B/C/D/E/F retrieval paths in parallel""" tasks = [] @@ -372,6 +399,7 @@ def _retrieve_paths( search_priority, user_name, id_filter, + rerank=rerank, ) ) tasks.append( @@ -387,6 +415,7 @@ def _retrieve_paths( user_name, id_filter, mode=mode, + rerank=rerank, ) ) tasks.append( @@ -400,6 +429,7 @@ def _retrieve_paths( mode, memory_type, user_name, + rerank=rerank, ) ) if self.use_fulltext: @@ -415,6 +445,7 @@ def _retrieve_paths( search_priority, user_name, id_filter, + rerank=rerank, ) ) if search_tool_memory: @@ -431,6 +462,7 @@ def _retrieve_paths( user_name, id_filter, mode=mode, + rerank=rerank, ) ) if include_skill_memory: @@ -447,6 +479,7 @@ def _retrieve_paths( user_name, id_filter, mode=mode, + rerank=rerank, ) ) if include_preference_memory: @@ -463,6 +496,7 @@ def _retrieve_paths( user_name, id_filter, mode=mode, + rerank=rerank, ) ) results = [] @@ -485,6 +519,7 @@ def _retrieve_from_working_memory( search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, + rerank: bool = True, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -501,9 +536,10 @@ def _retrieve_from_working_memory( id_filter=id_filter, use_fast_graph=self.use_fast_graph, ) - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=items, top_k=top_k, 
parsed_goal=parsed_goal, @@ -613,6 +649,7 @@ def _retrieve_from_keyword( search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, + rerank: bool = True, ) -> list[tuple[TextualMemoryItem, float]]: """Keyword/fulltext path that directly calls graph DB fulltext search.""" @@ -687,9 +724,10 @@ def _retrieve_from_keyword( ordered_nodes.append(node) results = [TextualMemoryItem.from_dict(n) for n in ordered_nodes] - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=results, top_k=top_k, parsed_goal=parsed_goal, @@ -710,6 +748,7 @@ def _retrieve_from_long_term_and_user( user_name: str | None = None, id_filter: dict | None = None, mode: str = "fast", + rerank: bool = True, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -781,9 +820,10 @@ def _retrieve_from_long_term_and_user( results = self._deduplicate_rawfile_results(results, user_name=user_name) results = self._filter_intermediate_content(results) - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=results, top_k=top_k, parsed_goal=parsed_goal, @@ -792,7 +832,13 @@ def _retrieve_from_long_term_and_user( @timed def _retrieve_from_memcubes( - self, query, parsed_goal, query_embedding, top_k, cube_name="memos_cube01" + self, + query, + parsed_goal, + query_embedding, + top_k, + cube_name="memos_cube01", + rerank: bool = True, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = self.graph_retriever.retrieve_from_cube( @@ -802,9 +848,10 @@ def _retrieve_from_memcubes( cube_name=cube_name, user_name=cube_name, ) - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - 
query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=results, top_k=top_k, parsed_goal=parsed_goal, @@ -822,6 +869,7 @@ def _retrieve_from_internet( mode, memory_type, user_id: str | None = None, + rerank: bool = True, ): """Retrieve and rerank from Internet source""" if not self.internet_retriever: @@ -838,9 +886,10 @@ def _retrieve_from_internet( query=query, top_k=2 * top_k, parsed_goal=parsed_goal, info=info, mode=mode ) logger.info(f"[PATH-C] '{query}' Retrieved from internet {len(items)} items: {items}") - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=items, top_k=top_k, parsed_goal=parsed_goal, @@ -860,6 +909,7 @@ def _retrieve_from_tool_memory( user_name: str | None = None, id_filter: dict | None = None, mode: str = "fast", + rerank: bool = True, ): """Retrieve and rerank from ToolMemory""" results = { @@ -920,17 +970,19 @@ def _retrieve_from_tool_memory( elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory": results["ToolTrajectoryMemory"].extend(rsp) - schema_reranked = self.reranker.rerank( + schema_reranked = self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=results["ToolSchemaMemory"], top_k=top_k, parsed_goal=parsed_goal, search_filter=search_filter, ) - trajectory_reranked = self.reranker.rerank( + trajectory_reranked = self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=results["ToolTrajectoryMemory"], top_k=top_k, parsed_goal=parsed_goal, @@ -952,6 +1004,7 @@ def _retrieve_from_skill_memory( user_name: str | None = None, id_filter: dict | None = None, mode: str = 
"fast", + rerank: bool = True, ): """Retrieve and rerank from SkillMemory""" @@ -982,9 +1035,10 @@ def _retrieve_from_skill_memory( use_fast_graph=self.use_fast_graph, ) - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=items, top_k=top_k, parsed_goal=parsed_goal, @@ -1004,6 +1058,7 @@ def _retrieve_from_preference_memory( user_name: str | None = None, id_filter: dict | None = None, mode: str = "fast", + rerank: bool = True, ): """Retrieve and rerank from PreferenceMemory""" if memory_type not in ["All", "PreferenceMemory"]: @@ -1033,9 +1088,10 @@ def _retrieve_from_preference_memory( use_fast_graph=self.use_fast_graph, ) - return self.reranker.rerank( + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embedding[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embedding), graph_results=items, top_k=top_k, parsed_goal=parsed_goal, @@ -1086,9 +1142,11 @@ def _retrieve_simple( logger.info( f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" ) - return self.reranker.rerank( + rerank = bool(kwargs.get("rerank", True)) + return self._maybe_rerank( + rerank, query=query, - query_embedding=query_embeddings[0], + query_embedding=self._query_embedding_for_rerank(rerank, query_embeddings), graph_results=selected_items, top_k=top_k, ) diff --git a/src/memos/plugins/README.md b/src/memos/plugins/README.md new file mode 100644 index 000000000..c9b6a0c29 --- /dev/null +++ b/src/memos/plugins/README.md @@ -0,0 +1,448 @@ +# MemOS Plugin System + +This document explains how to use and extend the open-source MemOS plugin system. + +MemOS keeps the plugin framework in `src/memos/plugins`. A plugin can contribute +FastAPI routes, middleware, runtime components, and Hook callbacks without +modifying the core request handlers directly. 
The built-in open-source Dream +feature is implemented this way and is registered as the `dream` plugin. + +## Quick Navigation + +1. [Architecture](#architecture) +2. [Plugin Lifecycle](#plugin-lifecycle) +3. [Creating a Plugin](#creating-a-plugin) +4. [Registering a Plugin](#registering-a-plugin) +5. [Using Hooks](#using-hooks) +6. [Testing](#testing) +7. [Runtime Verification](#runtime-verification) +8. [Development Checklist](#development-checklist) + +## Architecture + +The core framework lives in: + +```text +src/memos/plugins/ +├── base.py # MemOSPlugin base class +├── manager.py # Plugin discovery, loading, and initialization +├── hooks.py # Hook registration and trigger runtime +├── hook_defs.py # Core Hook declarations and constants +└── component_bootstrap.py # Runtime component context bootstrap helpers +``` + +Plugins are discovered through the Python entry point group: + +```toml +[project.entry-points."memos.plugins"] +dream = "memos.dream:CommunityDreamPlugin" +``` + +At startup, `PluginManager` loads installed entry points, instantiates plugins, +keeps the highest-priority implementation when multiple providers expose the +same logical plugin name, and initializes enabled plugins. + +## Plugin Lifecycle + +All plugins inherit from `memos.plugins.base.MemOSPlugin`. + +```python +class MemOSPlugin: + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + priority: int = 0 + + def on_load(self) -> None: + ... + + def init_components(self, context: dict) -> None: + ... + + def init_app(self) -> None: + ... + + def on_shutdown(self) -> None: + ... +``` + +Lifecycle methods are called in this order: + +1. `on_load()`: called after discovery. Use it for lightweight state setup and + Hook registration that does not require a FastAPI app. +2. `init_components(context)`: called during server bootstrap. Use it when the + plugin needs access to shared runtime components such as scheduler handles or + memory backends. +3. 
`init_app()`: called after the FastAPI app is bound. Register routers, + middleware, and app-bound integrations here. +4. `on_shutdown()`: called when the service shuts down. Release resources here. + +Plugins can register capabilities with: + +```python +self.register_router(router) +self.register_middleware(MiddlewareClass) +self.register_hook("hook.name", callback) +self.register_hooks(["hook.a", "hook.b"], callback) +``` + +## Creating a Plugin + +The simplest plugin is a Python package that exposes a `MemOSPlugin` subclass. +The example below uses `memos_foo_plugin` as the package name. + +```text +memos_foo_plugin/ +├── __init__.py +├── plugin.py +├── routes.py +├── hooks.py +└── tests/ + ├── __init__.py + ├── conftest.py + ├── test_lifecycle.py + ├── test_hooks.py + └── test_routes.py +``` + +### Package Entry + +`memos_foo_plugin/__init__.py` + +```python +from memos_foo_plugin.plugin import FooPlugin + +__all__ = ["FooPlugin"] +``` + +### Plugin Class + +`memos_foo_plugin/plugin.py` + +```python +import logging +from functools import partial + +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H + +logger = logging.getLogger(__name__) + + +class FooPlugin(MemOSPlugin): + name = "foo" + version = "0.1.0" + description = "Foo plugin" + + def on_load(self) -> None: + self.counter: dict[str, int] = {} + logger.info("[Foo] plugin loaded") + + def init_app(self) -> None: + from memos_foo_plugin.hooks import on_add_after + from memos_foo_plugin.routes import create_router + + self.register_router(create_router(self)) + self.register_hook(H.ADD_AFTER, partial(on_add_after, self)) + + logger.info("[Foo] plugin initialized") + + def on_shutdown(self) -> None: + logger.info("[Foo] plugin shutdown") +``` + +### Routes + +`memos_foo_plugin/routes.py` + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import APIRouter + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + + 
+def create_router(plugin: FooPlugin) -> APIRouter: + router = APIRouter(prefix="/foo", tags=["foo"]) + + @router.get("/health") + async def health(): + return {"status": "ok", "plugin": plugin.name} + + @router.get("/stats") + async def stats(): + return {"counter": plugin.counter} + + return router +``` + +### Hook Callback + +`memos_foo_plugin/hooks.py` + +```python +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from memos_foo_plugin.plugin import FooPlugin + + +def on_add_after(plugin: FooPlugin, *, request, result, **kwargs) -> None: + user_id = getattr(request, "user_id", "unknown") + plugin.counter[user_id] = plugin.counter.get(user_id, 0) + 1 +``` + +### Middleware + +Middleware is optional. If a plugin needs one, register it in `init_app()`. + +```python +from starlette.middleware.base import BaseHTTPMiddleware + + +class FooMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + response = await call_next(request) + response.headers["X-MemOS-Plugin"] = "foo" + return response +``` + +```python +self.register_middleware(FooMiddleware) +``` + +## Registering a Plugin + +Plugins are registered with the `memos.plugins` entry point group. + +For a plugin shipped inside this repository, add an entry point in +`pyproject.toml`: + +```toml +[project.entry-points."memos.plugins"] +foo = "memos_foo_plugin:FooPlugin" +``` + +For a plugin distributed as a separate Python package, declare the same entry +point in that package's project metadata. After installing the package into the +same environment as MemOS, the plugin manager can discover it automatically. + +Reinstall the package after adding or changing entry points: + +```bash +pip install -e . +``` + +If you only changed the implementation of an already installed editable plugin, +restarting the service is usually enough. 
+ +### Disabling Plugins + +Set `MEMOS_DISABLED_PLUGINS` to a comma-separated list of logical plugin names: + +```bash +MEMOS_DISABLED_PLUGINS=dream,foo uvicorn memos.api.server_api:app --port 8001 +``` + +### Plugin Priority + +`priority` resolves duplicate logical plugin names. If two installed packages +provide `name = "dream"`, MemOS keeps the implementation with the higher +priority. If the highest priority is tied, startup fails so the ambiguity is +visible. + +## Using Hooks + +Core Hook names are defined in `memos.plugins.hook_defs.H`. + +Common extension points include: + +```python +H.ADD_BEFORE +H.ADD_AFTER +H.SEARCH_BEFORE +H.SEARCH_AFTER +H.SEARCH_MEMORY_RESULTS +H.MEM_READER_PRE_EXTRACT +H.MEMORY_ITEMS_AFTER_FINE_EXTRACT +H.DREAM_EXECUTE +``` + +Hook callbacks receive keyword arguments. Some Hooks define a `pipe_key`; when a +callback returns a non-`None` value, that value replaces the named argument for +later callbacks and for the caller. + +For example, `add.after` can replace `result`: + +```python +def on_add_after(*, request, result, **kwargs): + result.metadata["handled_by"] = "foo" + return result +``` + +### Defining Plugin-Owned Hooks + +Plugins may define their own Hook names when they need internal extension +points. Keep those declarations inside the plugin package rather than adding +plugin-specific names to `memos.plugins.hook_defs`. 
+ +`memos_foo_plugin/hook_defs.py` + +```python +from memos.plugins.hook_defs import define_hook + + +class FooH: + RESULT_ENRICH = "foo.result.enrich" + + +define_hook( + FooH.RESULT_ENRICH, + description="Enrich Foo result data", + params=["user_id", "result"], + pipe_key="result", +) +``` + +Trigger it from plugin code: + +```python +from memos.plugins.hooks import trigger_hook +from memos_foo_plugin.hook_defs import FooH + +updated = trigger_hook(FooH.RESULT_ENRICH, user_id="alice", result=data) +data = updated if updated is not None else data +``` + +## Built-In Dream Plugin + +The open-source Dream implementation is a built-in plugin: + +```text +src/memos/dream/ +├── __init__.py +├── plugin.py +├── hooks.py +├── hook_defs.py +├── pipeline/ +└── routers/ +``` + +It is registered in `pyproject.toml`: + +```toml +[project.entry-points."memos.plugins"] +dream = "memos.dream:CommunityDreamPlugin" +``` + +`CommunityDreamPlugin` demonstrates the recommended pattern: + +- initialize state in `on_load()` +- register scheduler-facing Hooks such as `H.DREAM_EXECUTE` +- bind shared runtime context in `init_components()` +- register HTTP routes in `init_app()` +- keep pipeline stages replaceable behind clear module boundaries + +Use it as the primary in-repository reference when extending the plugin system. + +## Testing + +Framework tests live under `tests/plugins/`. Plugin-specific tests should live +next to the plugin package or under an appropriate `tests/<plugin_name>/` +directory. + +### Test Hook Declarations + +Tests that trigger `@hookable`-generated Hooks should declare those Hooks before +registration.
A small `conftest.py` is usually enough: + +```python +from memos.plugins.hooks import hookable + +hookable("add") +hookable("search") + +# Import plugin-owned hook definitions when needed: +# import memos_foo_plugin.hook_defs # noqa: F401 +``` + +### Lifecycle Test + +```python +from fastapi import FastAPI + + +def test_foo_plugin_lifecycle(): + from memos_foo_plugin.plugin import FooPlugin + + app = FastAPI() + plugin = FooPlugin() + plugin.on_load() + plugin._bind_app(app) + plugin.init_app() + + paths = [route.path for route in app.routes] + assert "/foo/health" in paths +``` + +### Running Tests + +Run the plugin framework tests: + +```bash +python -m pytest tests/plugins/ -v +``` + +Run Dream plugin tests: + +```bash +python -m pytest tests/dream/ -v +``` + +Run a plugin's own tests: + +```bash +python -m pytest path/to/plugin/tests/ -v +``` + +## Runtime Verification + +Start the API server: + +```bash +uvicorn memos.api.server_api:app --port 8001 +``` + +Startup logs should show discovered and initialized plugins: + +```text +INFO: Plugin discovered: dream v0.1.0 (priority=10) +INFO: Plugin initialized: dream +``` + +If your plugin registers a health route, verify it with: + +```bash +curl http://127.0.0.1:8001/foo/health +``` + +## Development Checklist + +- [ ] Plugin class inherits from `MemOSPlugin` +- [ ] `name`, `version`, and `description` are set +- [ ] `priority` is set when duplicate providers may exist +- [ ] State setup belongs in `on_load()` +- [ ] Shared runtime component wiring belongs in `init_components()` +- [ ] Routes and middleware are registered in `init_app()` +- [ ] Hook callbacks are registered with `self.register_hook(...)` +- [ ] Plugin-owned Hooks are declared inside the plugin package +- [ ] Entry point is declared under `memos.plugins` +- [ ] Package is reinstalled after entry point changes +- [ ] Tests cover lifecycle, Hook callbacks, and routes where applicable +- [ ] Service startup logs show the plugin was discovered and 
initialized diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py index 3650e60c0..f955da0af 100644 --- a/src/memos/plugins/hook_defs.py +++ b/src/memos/plugins/hook_defs.py @@ -74,6 +74,8 @@ class H: # Search extension point before core threshold/dedup/rerank processing. SEARCH_MEMORY_RESULTS = "search.memory_results" + SEARCH_RESULTS_AFTER_RERANK = "search.results.after_rerank" + SEARCH_CONTEXT_RENDER = "search.context.render" # Custom Hook (manually triggered via trigger_hook) ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" @@ -119,6 +121,20 @@ class H: pipe_key="results", ) +define_hook( + H.SEARCH_RESULTS_AFTER_RERANK, + description="Allow plugins to update search results after core rerank and before rendering.", + params=["handler", "search_req", "results"], + pipe_key="results", +) + +define_hook( + H.SEARCH_CONTEXT_RENDER, + description="Render final search context after retrieval, rerank, and result-level plugins.", + params=["handler", "search_req", "results"], + pipe_key="results", +) + define_hook( H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, description="Post-process memory items after mem_reader fine extraction completes", diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index f4092d168..db9f9a6aa 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -94,5 +94,6 @@ def search_text_memories( include_preference_memory=search_req.include_preference, pref_mem_top_k=search_req.pref_top_k, dedup=search_req.dedup, + rerank=search_req.rerank, include_embedding=include_embedding, ) diff --git a/tests/api/test_search_pipeline_hooks.py b/tests/api/test_search_pipeline_hooks.py new file mode 100644 index 000000000..e96c3fbf5 --- /dev/null +++ b/tests/api/test_search_pipeline_hooks.py @@ -0,0 +1,78 @@ +from memos.api.handlers.formatters_handler import rerank_knowledge_mem +from memos.api.product_models import APISearchRequest +from memos.plugins.hook_defs import H, 
get_hook_spec + + +def _memory(memory_id: str, memory: str, memory_type: str = "LongTermMemory") -> dict: + return { + "id": memory_id, + "memory": memory, + "metadata": { + "memory_type": memory_type, + "relativity": 1.0, + "sources": [{"content": f"source for {memory}"}], + }, + } + + +def test_search_request_passes_context_format_through_to_plugins(): + req = APISearchRequest( + user_id="user", + query="What did Maria buy?", + context_format="plugin-owned-format", + ) + + assert req.context_format == "plugin-owned-format" + + +def test_search_pipeline_hook_specs_are_registered(): + after_rerank = get_hook_spec(H.SEARCH_RESULTS_AFTER_RERANK) + render = get_hook_spec(H.SEARCH_CONTEXT_RENDER) + + assert after_rerank is not None + assert after_rerank.pipe_key == "results" + assert after_rerank.params == ["handler", "search_req", "results"] + + assert render is not None + assert render.pipe_key == "results" + assert render.params == ["handler", "search_req", "results"] + + +def test_rerank_knowledge_mem_preserves_conversation_sources_by_default(): + text_mem = [ + { + "cube_id": "cube", + "memories": [ + _memory("mem-1", "conversation memory", memory_type="WorkingMemory"), + _memory("mem-2", "knowledge memory", memory_type="LongTermMemory"), + ], + } + ] + + reranked = rerank_knowledge_mem(None, "query", text_mem, top_k=2)[0]["memories"] + + conversation = next(item for item in reranked if item["memory"] == "conversation memory") + assert conversation["metadata"]["sources"] == [{"content": "source for conversation memory"}] + + +def test_rerank_knowledge_mem_can_strip_conversation_sources(): + text_mem = [ + { + "cube_id": "cube", + "memories": [ + _memory("mem-1", "conversation memory", memory_type="WorkingMemory"), + _memory("mem-2", "knowledge memory", memory_type="LongTermMemory"), + ], + } + ] + + reranked = rerank_knowledge_mem( + None, + "query", + text_mem, + top_k=2, + strip_conversation_sources=True, + )[0]["memories"] + + conversation = next(item for item 
in reranked if item["memory"] == "conversation memory") + assert conversation["metadata"]["sources"] == [] diff --git a/tests/memories/textual/test_tree_searcher.py b/tests/memories/textual/test_tree_searcher.py index 3d1469d00..b79958ca1 100644 --- a/tests/memories/textual/test_tree_searcher.py +++ b/tests/memories/textual/test_tree_searcher.py @@ -82,6 +82,28 @@ def retrieve_side_effect(*args, **kwargs): assert all(isinstance(item, TextualMemoryItem) for item in result) +def test_searcher_can_skip_rerank_per_request(mock_searcher): + parsed_goal = MagicMock() + parsed_goal.memories = ["Cats are cute"] + parsed_goal.rephrased_query = None + mock_searcher.task_goal_parser.parse.return_value = parsed_goal + mock_searcher.embedder.embed.return_value = [[0.1] * 5, [0.2] * 5] + mock_searcher.graph_retriever.retrieve.return_value = [make_item("wm1", 0.9)[0]] + + result = mock_searcher.search( + query="Tell me about cats", + top_k=1, + info={"test": True}, + mode="fast", + memory_type="WorkingMemory", + rerank=False, + ) + + mock_searcher.reranker.rerank.assert_not_called() + assert len(result) == 1 + assert result[0].memory == "wm1" + + def test_searcher_fine_mode_triggers_reasoner(mock_searcher): parsed_goal = MagicMock() parsed_goal.memories = ["Cats"]