From ecdd118a4fd7dddfd5791165c414f40dc0c69cef Mon Sep 17 00:00:00 2001 From: MemOS AutoDev Date: Wed, 17 Jun 2026 15:37:17 +0800 Subject: [PATCH] perf(memories/textual): cache query embeddings and skip redundant sort MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three targeted hot-path improvements to GeneralTextMemory: - Add a per-instance bounded LRU cache (size 256) on `_embed_one_sentence` so identical queries within a process — common during retry / rerank loops, repeated chat lookups, and update() calls that reuse a recent search() text — skip the embedder round-trip. Cached values are deep- copied on the way in and out so downstream mutation cannot poison the cache. - Drop the redundant `sorted(..., reverse=True)` in `search`. Qdrant's `query_points` already returns hits in descending-by-score order; the in-process sort was a no-op. A regression test now guards the contract. - Use `TextualMemoryItem.model_validate(payload)` instead of `TextualMemoryItem(**payload)` on the read paths (`search`, `get`, `get_by_ids`, `get_all`) — equivalent semantics on a dict input and the recommended Pydantic v2 idiom. Adds 4 unit tests under tests/memories/textual/test_general.py covering cache hit, LRU eviction, search ordering preservation, and search↔update cache sharing. All 62 textual-memory tests pass; ruff format + check clean. No public API change, no schema change, no new dependency. Refs #1926 --- src/memos/memories/textual/general.py | 49 +++++++++---- tests/memories/textual/test_general.py | 97 ++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 13 deletions(-) diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index b90f2a6ab..9a88031d2 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -1,6 +1,7 @@ import json import os +from collections import OrderedDict from datetime import datetime from typing import Any @@ -24,6 +25,11 @@ class GeneralTextMemory(BaseTextMemory): """General textual memory implementation for storing and retrieving memories.""" + # Upper bound on the per-instance query-embedding cache. Caps memory growth + # for long-running processes that handle a large vocabulary of distinct + # queries; tuned against typical embedding dims (~1k floats per entry). + _QUERY_EMBED_CACHE_MAX = 256 + def __init__(self, config: GeneralTextMemoryConfig): """Initialize memory with the given configuration.""" # Set mode from class default or override if needed @@ -34,6 +40,9 @@ def __init__(self, config: GeneralTextMemoryConfig): ) self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db) self.embedder: OllamaEmbedder | ArkEmbedder = EmbedderFactory.from_config(config.embedder) + # LRU cache of sentence -> embedding. Keyed on raw text; identical + # queries within a process share a single embedder round-trip. + self._query_embed_cache: OrderedDict[str, list[float]] = OrderedDict() @retry( stop=stop_after_attempt(3), @@ -128,20 +137,18 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem """ query_vector = self._embed_one_sentence(query) search_results = self.vector_db.search(query_vector, top_k) - search_results = sorted( # make higher score first - search_results, key=lambda x: x.score, reverse=True - ) - result_memories = [ - TextualMemoryItem(**search_item.payload) for search_item in search_results + # The vector DB (Qdrant `query_points`) returns results in descending + # score order already; a redundant in-process sort was removed here. + return [ + TextualMemoryItem.model_validate(search_item.payload) for search_item in search_results ] - return result_memories def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem: """Get a memory by its ID.""" result = self.vector_db.get_by_id(memory_id) if result is None: raise ValueError(f"Memory with ID {memory_id} not found") - return TextualMemoryItem(**result.payload) + return TextualMemoryItem.model_validate(result.payload) def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: """Get memories by their IDs. @@ -151,8 +158,7 @@ def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: list[TextualMemoryItem]: List of memories with the specified IDs. """ db_items = self.vector_db.get_by_ids(memory_ids) - memories = [TextualMemoryItem(**db_item.payload) for db_item in db_items] - return memories + return [TextualMemoryItem.model_validate(db_item.payload) for db_item in db_items] def get_all(self) -> list[TextualMemoryItem]: """Get all memories. @@ -160,8 +166,7 @@ def get_all(self) -> list[TextualMemoryItem]: list[TextualMemoryItem]: List of all memories. """ all_items = self.vector_db.get_all() - all_memories = [TextualMemoryItem(**memo.payload) for memo in all_items] - return all_memories + return [TextualMemoryItem.model_validate(memo.payload) for memo in all_items] def delete(self, memory_ids: list[str]) -> None: """Delete a memory.""" @@ -217,8 +222,26 @@ def drop( pass def _embed_one_sentence(self, sentence: str) -> list[float]: - """Embed a single sentence.""" - return self.embedder.embed([sentence])[0] + """Embed a single sentence, reusing recent embeddings via a bounded LRU cache. + + The cache is keyed on the raw sentence text and capped at + ``_QUERY_EMBED_CACHE_MAX`` entries. Identical queries within a process + (common during retry / rerank loops and chat sessions that revisit + the same topic) skip the embedder round-trip. Embeddings are deep-copied + on the way in and on the way out so callers cannot mutate cached state. + """ + cached = self._query_embed_cache.get(sentence) + if cached is not None: + # Mark as most-recently used; return a defensive copy so downstream + # mutation (e.g. by a vec_db client) cannot poison the cache. + self._query_embed_cache.move_to_end(sentence) + return list(cached) + + embedding = self.embedder.embed([sentence])[0] + self._query_embed_cache[sentence] = list(embedding) + if len(self._query_embed_cache) > self._QUERY_EMBED_CACHE_MAX: + self._query_embed_cache.popitem(last=False) + return embedding def parse_json_result(self, response_text): try: diff --git a/tests/memories/textual/test_general.py b/tests/memories/textual/test_general.py index bebedcb56..f78e034af 100644 --- a/tests/memories/textual/test_general.py +++ b/tests/memories/textual/test_general.py @@ -336,6 +336,103 @@ def test_delete_all_memories(self): self.mock_vector_db.delete_collection.assert_called_once_with(collection_name) self.mock_vector_db.create_collection.assert_called_once() # Assumes create_collection is called after delete + def test_embed_one_sentence_caches_repeated_queries(self): + """Repeated identical sentences should hit a per-instance cache and skip the embedder.""" + sentence = "frequent query" + embedding = [0.7, 0.8, 0.9] + self.mock_embedder.embed.return_value = [embedding] + + first = self.memory._embed_one_sentence(sentence) + second = self.memory._embed_one_sentence(sentence) + third = self.memory._embed_one_sentence(sentence) + + self.assertEqual(first, embedding) + self.assertEqual(second, embedding) + self.assertEqual(third, embedding) + # Embedder must be invoked exactly once across the three identical lookups. + self.mock_embedder.embed.assert_called_once_with([sentence]) + + def test_embed_one_sentence_cache_evicts_when_full(self): + """Cache must bound memory growth — oldest entries are evicted, recent ones still hit.""" + # The cache cap is an implementation detail; we use the public attribute to size the test. + cap = getattr(self.memory, "_QUERY_EMBED_CACHE_MAX", 256) + self.mock_embedder.embed.side_effect = lambda batch: [[float(len(batch[0]))]] + + # Fill the cache to exactly cap entries (sentence-0 .. sentence-{cap-1}). + for i in range(cap): + self.memory._embed_one_sentence(f"sentence-{i}") + self.assertEqual(self.mock_embedder.embed.call_count, cap) + + # One more distinct sentence pushes us over the limit — sentence-0 must be evicted. + self.memory._embed_one_sentence("sentence-overflow") + calls_after_fill = self.mock_embedder.embed.call_count + self.assertEqual(calls_after_fill, cap + 1) + + # The most-recent sentence is still cached — no extra embedder call. + self.memory._embed_one_sentence("sentence-overflow") + self.assertEqual(self.mock_embedder.embed.call_count, calls_after_fill) + + # The oldest sentence WAS evicted — re-querying it triggers a new embed call. + self.memory._embed_one_sentence("sentence-0") + self.assertEqual(self.mock_embedder.embed.call_count, calls_after_fill + 1) + + def test_search_preserves_vector_db_order(self): + """Search must return items in the order the vector DB returned them (already sorted).""" + query = "ranked query" + top_k = 3 + self.mock_embedder.embed.return_value = [[0.1, 0.2, 0.3]] + + ordered_ids = [str(uuid.uuid4()) for _ in range(top_k)] + # Vector DB returns descending-by-score; search() must NOT shuffle it. + ordered_results = [ + VecDBItem( + id=ordered_ids[0], + vector=[0.1], + payload={"id": ordered_ids[0], "memory": "top hit", "metadata": {}}, + score=0.99, + ), + VecDBItem( + id=ordered_ids[1], + vector=[0.2], + payload={"id": ordered_ids[1], "memory": "middle hit", "metadata": {}}, + score=0.55, + ), + VecDBItem( + id=ordered_ids[2], + vector=[0.3], + payload={"id": ordered_ids[2], "memory": "low hit", "metadata": {}}, + score=0.10, + ), + ] + self.mock_vector_db.search.return_value = ordered_results + + results = self.memory.search(query, top_k) + + self.assertEqual([m.id for m in results], ordered_ids) + + def test_update_uses_cached_embedding_when_query_matches(self): + """update() shares the same embedding cache as search() — repeated text shouldn't re-embed.""" + text = "shared text body" + embedding = [0.42, 0.42, 0.42] + self.mock_embedder.embed.return_value = [embedding] + + # Warm the cache via a search call. + self.mock_vector_db.search.return_value = [] + self.memory.search(text, top_k=1) + + # Now update with the same memory text — should reuse the cached embedding. + memory_id = str(uuid.uuid4()) + item = TextualMemoryItem(memory=text, metadata={"source": "conversation"}) + self.memory.update(memory_id, item) + + # Embedder was invoked once (during search) and reused on update(). + self.mock_embedder.embed.assert_called_once_with([text]) + # And vector_db.update received an item carrying the cached vector. + self.mock_vector_db.update.assert_called_once() + called_args = self.mock_vector_db.update.call_args.args + self.assertEqual(called_args[0], memory_id) + self.assertEqual(called_args[1].vector, embedding) + if __name__ == "__main__": unittest.main()