Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions src/memos/memories/textual/general.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os

from collections import OrderedDict
from datetime import datetime
from typing import Any

Expand All @@ -24,6 +25,11 @@
class GeneralTextMemory(BaseTextMemory):
"""General textual memory implementation for storing and retrieving memories."""

# Upper bound on the per-instance query-embedding cache. Caps memory growth
# for long-running processes that handle a large vocabulary of distinct
# queries; tuned against typical embedding dims (~1k floats per entry).
_QUERY_EMBED_CACHE_MAX = 256

def __init__(self, config: GeneralTextMemoryConfig):
"""Initialize memory with the given configuration."""
# Set mode from class default or override if needed
Expand All @@ -34,6 +40,9 @@ def __init__(self, config: GeneralTextMemoryConfig):
)
self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db)
self.embedder: OllamaEmbedder | ArkEmbedder = EmbedderFactory.from_config(config.embedder)
# LRU cache of sentence -> embedding. Keyed on raw text; identical
# queries within a process share a single embedder round-trip.
self._query_embed_cache: OrderedDict[str, list[float]] = OrderedDict()

@retry(
stop=stop_after_attempt(3),
Expand Down Expand Up @@ -128,20 +137,18 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
"""
query_vector = self._embed_one_sentence(query)
search_results = self.vector_db.search(query_vector, top_k)
search_results = sorted( # make higher score first
search_results, key=lambda x: x.score, reverse=True
)
result_memories = [
TextualMemoryItem(**search_item.payload) for search_item in search_results
# The vector DB (Qdrant `query_points`) returns results in descending
# score order already; a redundant in-process sort was removed here.
return [
TextualMemoryItem.model_validate(search_item.payload) for search_item in search_results
]
return result_memories

def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem:
"""Get a memory by its ID."""
result = self.vector_db.get_by_id(memory_id)
if result is None:
raise ValueError(f"Memory with ID {memory_id} not found")
return TextualMemoryItem(**result.payload)
return TextualMemoryItem.model_validate(result.payload)

def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
"""Get memories by their IDs.
Expand All @@ -151,17 +158,15 @@ def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
list[TextualMemoryItem]: List of memories with the specified IDs.
"""
db_items = self.vector_db.get_by_ids(memory_ids)
memories = [TextualMemoryItem(**db_item.payload) for db_item in db_items]
return memories
return [TextualMemoryItem.model_validate(db_item.payload) for db_item in db_items]

def get_all(self) -> list[TextualMemoryItem]:
"""Get all memories.
Returns:
list[TextualMemoryItem]: List of all memories.
"""
all_items = self.vector_db.get_all()
all_memories = [TextualMemoryItem(**memo.payload) for memo in all_items]
return all_memories
return [TextualMemoryItem.model_validate(memo.payload) for memo in all_items]

def delete(self, memory_ids: list[str]) -> None:
"""Delete a memory."""
Expand Down Expand Up @@ -217,8 +222,26 @@ def drop(
pass

def _embed_one_sentence(self, sentence: str) -> list[float]:
"""Embed a single sentence."""
return self.embedder.embed([sentence])[0]
"""Embed a single sentence, reusing recent embeddings via a bounded LRU cache.

The cache is keyed on the raw sentence text and capped at
``_QUERY_EMBED_CACHE_MAX`` entries. Identical queries within a process
(common during retry / rerank loops and chat sessions that revisit
the same topic) skip the embedder round-trip. Embeddings are deep-copied
on the way in and on the way out so callers cannot mutate cached state.
"""
cached = self._query_embed_cache.get(sentence)
if cached is not None:
# Mark as most-recently used; return a defensive copy so downstream
# mutation (e.g. by a vec_db client) cannot poison the cache.
self._query_embed_cache.move_to_end(sentence)
return list(cached)

embedding = self.embedder.embed([sentence])[0]
self._query_embed_cache[sentence] = list(embedding)
if len(self._query_embed_cache) > self._QUERY_EMBED_CACHE_MAX:
self._query_embed_cache.popitem(last=False)
return embedding

def parse_json_result(self, response_text):
try:
Expand Down
97 changes: 97 additions & 0 deletions tests/memories/textual/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,103 @@ def test_delete_all_memories(self):
self.mock_vector_db.delete_collection.assert_called_once_with(collection_name)
self.mock_vector_db.create_collection.assert_called_once() # Assumes create_collection is called after delete

def test_embed_one_sentence_caches_repeated_queries(self):
"""Repeated identical sentences should hit a per-instance cache and skip the embedder."""
sentence = "frequent query"
embedding = [0.7, 0.8, 0.9]
self.mock_embedder.embed.return_value = [embedding]

first = self.memory._embed_one_sentence(sentence)
second = self.memory._embed_one_sentence(sentence)
third = self.memory._embed_one_sentence(sentence)

self.assertEqual(first, embedding)
self.assertEqual(second, embedding)
self.assertEqual(third, embedding)
# Embedder must be invoked exactly once across the three identical lookups.
self.mock_embedder.embed.assert_called_once_with([sentence])

def test_embed_one_sentence_cache_evicts_when_full(self):
"""Cache must bound memory growth — oldest entries are evicted, recent ones still hit."""
# The cache cap is an implementation detail; we use the public attribute to size the test.
cap = getattr(self.memory, "_QUERY_EMBED_CACHE_MAX", 256)
self.mock_embedder.embed.side_effect = lambda batch: [[float(len(batch[0]))]]

# Fill the cache to exactly cap entries (sentence-0 .. sentence-{cap-1}).
for i in range(cap):
self.memory._embed_one_sentence(f"sentence-{i}")
self.assertEqual(self.mock_embedder.embed.call_count, cap)

# One more distinct sentence pushes us over the limit — sentence-0 must be evicted.
self.memory._embed_one_sentence("sentence-overflow")
calls_after_fill = self.mock_embedder.embed.call_count
self.assertEqual(calls_after_fill, cap + 1)

# The most-recent sentence is still cached — no extra embedder call.
self.memory._embed_one_sentence("sentence-overflow")
self.assertEqual(self.mock_embedder.embed.call_count, calls_after_fill)

# The oldest sentence WAS evicted — re-querying it triggers a new embed call.
self.memory._embed_one_sentence("sentence-0")
self.assertEqual(self.mock_embedder.embed.call_count, calls_after_fill + 1)

def test_search_preserves_vector_db_order(self):
"""Search must return items in the order the vector DB returned them (already sorted)."""
query = "ranked query"
top_k = 3
self.mock_embedder.embed.return_value = [[0.1, 0.2, 0.3]]

ordered_ids = [str(uuid.uuid4()) for _ in range(top_k)]
# Vector DB returns descending-by-score; search() must NOT shuffle it.
ordered_results = [
VecDBItem(
id=ordered_ids[0],
vector=[0.1],
payload={"id": ordered_ids[0], "memory": "top hit", "metadata": {}},
score=0.99,
),
VecDBItem(
id=ordered_ids[1],
vector=[0.2],
payload={"id": ordered_ids[1], "memory": "middle hit", "metadata": {}},
score=0.55,
),
VecDBItem(
id=ordered_ids[2],
vector=[0.3],
payload={"id": ordered_ids[2], "memory": "low hit", "metadata": {}},
score=0.10,
),
]
self.mock_vector_db.search.return_value = ordered_results

results = self.memory.search(query, top_k)

self.assertEqual([m.id for m in results], ordered_ids)

def test_update_uses_cached_embedding_when_query_matches(self):
"""update() shares the same embedding cache as search() — repeated text shouldn't re-embed."""
text = "shared text body"
embedding = [0.42, 0.42, 0.42]
self.mock_embedder.embed.return_value = [embedding]

# Warm the cache via a search call.
self.mock_vector_db.search.return_value = []
self.memory.search(text, top_k=1)

# Now update with the same memory text — should reuse the cached embedding.
memory_id = str(uuid.uuid4())
item = TextualMemoryItem(memory=text, metadata={"source": "conversation"})
self.memory.update(memory_id, item)

# Embedder was invoked once (during search) and reused on update().
self.mock_embedder.embed.assert_called_once_with([text])
# And vector_db.update received an item carrying the cached vector.
self.mock_vector_db.update.assert_called_once()
called_args = self.mock_vector_db.update.call_args.args
self.assertEqual(called_args[0], memory_id)
self.assertEqual(called_args[1].vector, embedding)


if __name__ == "__main__":
unittest.main()
Loading