From 8553057778fca54e1729bb05f4a1b2005735917d Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 17:10:04 +0530 Subject: [PATCH 1/6] feat(lifecycle): add MemoryLifecycle schema and pure lifecycle functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add src/schemas/memory_lifecycle.py with the forget/TTL boundary model (MemoryLifecycle + RESERVED_LIFECYCLE_KEYS) and src/pipelines/lifecycle.py with pure, side-effect-free helpers: - is_retrievable(metadata, now): retrieval-time gate; handles forgotten state, expired TTL, and legacy records (no lifecycle keys → always retrievable). - build_lifecycle_metadata(now, ttl_days, reason): stamps forget + expires_at onto a metadata dict for storage at v2 ingest time. Co-Authored-By: Claude Sonnet 4.6 --- src/pipelines/lifecycle.py | 71 +++++++++++++++++++++++++++++++++ src/schemas/memory_lifecycle.py | 24 +++++++++++ 2 files changed, 95 insertions(+) create mode 100644 src/pipelines/lifecycle.py create mode 100644 src/schemas/memory_lifecycle.py diff --git a/src/pipelines/lifecycle.py b/src/pipelines/lifecycle.py new file mode 100644 index 0000000..fe01837 --- /dev/null +++ b/src/pipelines/lifecycle.py @@ -0,0 +1,71 @@ +""" +Memory lifecycle — pure, deterministic helper functions for forget/TTL. + +is_retrievable() is the single retrieval-time gate. build_lifecycle_metadata() +stamps the lifecycle fields onto a metadata dict for storage. + +All functions are side-effect-free so they can be tested without live services. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Mapping, Optional + + +def is_retrievable(metadata: Mapping[str, Any], now: datetime) -> bool: + """Return True when a record should appear in retrieval results. + + Rules (applied in order): + 1. ``lifecycle_state == "forgotten"`` → hidden (manual soft-forget). + 2. 
``forget is True`` and ``expires_at`` is present and in the past → hidden (TTL expired). + 3. Everything else (including all legacy records with no lifecycle keys) → retrievable. + + Missing keys default to the legacy-safe value so records stored before + lifecycle was introduced are never hidden. + """ + if metadata.get("lifecycle_state", "active") == "forgotten": + return False + + if metadata.get("forget"): + expires_raw = metadata.get("expires_at") + if expires_raw: + try: + expires_at = datetime.fromisoformat(str(expires_raw)) + # Make both sides timezone-aware or both naive for comparison + if expires_at.tzinfo is None and now.tzinfo is not None: + from datetime import timezone + expires_at = expires_at.replace(tzinfo=timezone.utc) + elif expires_at.tzinfo is not None and now.tzinfo is None: + expires_at = expires_at.replace(tzinfo=None) + if expires_at < now: + return False + except (ValueError, TypeError): + pass + + return True + + +def build_lifecycle_metadata( + now: datetime, + ttl_days: float, + reason: Optional[str] = None, +) -> Dict[str, Any]: + """Return the lifecycle metadata dict to merge onto a forget=true record. + + Called once at v2 ingestion time when the caller sets ``forget=true``. + The result is stored as part of the vector record's metadata so the + retrieval-time filter can enforce the TTL without any background sweeper. 
+ """ + from datetime import timedelta + expires_at = now + timedelta(days=ttl_days) + meta: Dict[str, Any] = { + "forget": True, + "expires_at": expires_at.isoformat(), + "lifecycle_state": "active", + "created_at": now.isoformat(), + "updated_at": now.isoformat(), + } + if reason: + meta["forget_reason"] = reason + return meta diff --git a/src/schemas/memory_lifecycle.py b/src/schemas/memory_lifecycle.py new file mode 100644 index 0000000..a8531c2 --- /dev/null +++ b/src/schemas/memory_lifecycle.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal, Optional + +from pydantic import BaseModel + + +class MemoryLifecycle(BaseModel): + """Boundary model for forget/TTL lifecycle metadata stored on vector records. + + All fields are optional and defaulted so legacy records (no lifecycle keys) + validate cleanly and are treated as active + retrievable. + """ + forget: bool = False + expires_at: Optional[datetime] = None + lifecycle_state: Literal["active", "forgotten"] = "active" + forgotten_at: Optional[datetime] = None + forget_reason: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +RESERVED_LIFECYCLE_KEYS: frozenset[str] = frozenset(MemoryLifecycle.model_fields) From e92161824ac69d6e4fa1897eedfff540539f2c55 Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 17:10:19 +0530 Subject: [PATCH 2/6] feat(lifecycle): add forget TTL setting and thread extra_metadata through Weaver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - src/config/settings.py: add memory_forget_default_ttl_days (default 30, configurable via MEMORY_FORGET_DEFAULT_TTL_DAYS env var). - src/pipelines/weaver.py: add extra_metadata=None to execute() and _execute_batched_vector(); inject _now callable for deterministic tests. extra_metadata is merged into vector metadata at both ADD sites (batched flush + _vector_add). 
Temporal, code, and snippet paths are unaffected. extra_metadata=None is a no-op — existing behaviour is byte-identical. Co-Authored-By: Claude Sonnet 4.6 --- src/config/settings.py | 12 ++++++++++++ src/pipelines/weaver.py | 19 +++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/config/settings.py b/src/config/settings.py index acc7fc0..934ad05 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -506,6 +506,18 @@ class Settings(BaseSettings): description="Razorpay subscription plan ID for the global USD Pro plan", ) + # ============================================================================= + # Memory Lifecycle — Forget / TTL (default OFF — fully backward compatible) + # ============================================================================= + memory_forget_default_ttl_days: float = Field( + default=30.0, + gt=0.0, + description=( + "Default TTL in days for memories ingested with forget=true via the v2 API. " + "Expired records are hidden from retrieval (filtered at read time; no sweeper)." 
+ ), + ) + @field_validator("fallback_order") @classmethod def validate_fallback_order(cls, v: List[str]) -> List[str]: diff --git a/src/pipelines/weaver.py b/src/pipelines/weaver.py index 55d55cb..8a8ee52 100644 --- a/src/pipelines/weaver.py +++ b/src/pipelines/weaver.py @@ -13,6 +13,7 @@ from __future__ import annotations import asyncio +from datetime import datetime, timezone from functools import partial import logging from typing import Any, Callable, Dict, List, Optional @@ -63,6 +64,7 @@ def __init__( code_vector_store: Optional[BaseVectorStore] = None, graph_create_annotation: Optional[GraphCreateAnnotationFn] = None, snippet_vector_store: Optional[BaseVectorStore] = None, + _now: Optional[Callable[[], datetime]] = None, ) -> None: self.vector_store = vector_store self.embed_fn = embed_fn @@ -72,6 +74,7 @@ def __init__( self.code_vector_store = code_vector_store self.graph_create_annotation = graph_create_annotation self.snippet_vector_store = snippet_vector_store + self._now: Callable[[], datetime] = _now or (lambda: datetime.now(timezone.utc)) # ------------------------------------------------------------------ # Public entry point @@ -82,6 +85,7 @@ async def execute( judge_result: JudgeResult, domain: JudgeDomain, user_id: str, + extra_metadata: Optional[Dict[str, Any]] = None, ) -> WeaverResult: result = WeaverResult() @@ -91,7 +95,9 @@ async def execute( # Optimization: Batch vector operations if possible if domain not in (JudgeDomain.TEMPORAL, JudgeDomain.CODE, JudgeDomain.SNIPPET) and self.vector_store: - batched_executed = await self._execute_batched_vector(judge_result.operations, domain, user_id) + batched_executed = await self._execute_batched_vector( + judge_result.operations, domain, user_id, extra_metadata=extra_metadata + ) result.executed.extend(batched_executed) else: for op in judge_result.operations: @@ -106,6 +112,7 @@ async def _execute_batched_vector( operations: List[Operation], domain: JudgeDomain, user_id: str, + extra_metadata: 
Optional[Dict[str, Any]] = None, ) -> List[ExecutedOp]: """Batch ADD and DELETE operations to reduce vector store round-trips.""" executed_ops: List[ExecutedOp] = [] @@ -142,6 +149,8 @@ async def flush_add_batch(): try: meta = {"user_id": user_id, "domain": domain.value} meta.update(_extract_structured_metadata(op.content)) + if extra_metadata: + meta.update(extra_metadata) valid_ops.append(op) texts.append(op.content) @@ -371,7 +380,11 @@ async def _execute_vector( ) async def _vector_add( - self, op: Operation, domain: JudgeDomain, user_id: str, + self, + op: Operation, + domain: JudgeDomain, + user_id: str, + extra_metadata: Optional[Dict[str, Any]] = None, ) -> ExecutedOp: if not self.embed_fn: return ExecutedOp( @@ -385,6 +398,8 @@ async def _vector_add( # Store structured metadata for deterministic lookups structured = _extract_structured_metadata(op.content) metadata.update(structured) + if extra_metadata: + metadata.update(extra_metadata) ids = self.vector_store.add( texts=[op.content], From 1698c9ce44905855d774cd9c6dd19c4085cfc19c Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 17:10:31 +0530 Subject: [PATCH 3/6] feat(v2/ingest): add forget flag to v2 ingest API and thread lifecycle_metadata - src/api/schemas.py: add forget: bool = False to IngestRequest. Additive, default off; accepted by v1 but only acted on in the v2 path. - src/api/routes/v2/memory.py: when forget=true, compute lifecycle_metadata {forget, expires_at, lifecycle_state} using memory_forget_default_ttl_days and store in payload. Include forget in idempotency fields so a forget vs non-forget request of the same content are treated as distinct jobs. - src/api/routes/v2/activities.py: pass lifecycle_metadata from payload into per-domain state dicts (profile, temporal, summary, image). 
- src/pipelines/ingest.py: pass extra_metadata=state.get("lifecycle_metadata") to weaver.execute() in the three vector-domain nodes (_node_extract_profile, _node_extract_image, _node_extract_summary). Temporal, code, snippet deferred. Co-Authored-By: Claude Sonnet 4.6 --- src/api/routes/v2/activities.py | 6 ++++++ src/api/routes/v2/memory.py | 16 ++++++++++++++++ src/api/schemas.py | 8 ++++++++ src/pipelines/ingest.py | 3 +++ 4 files changed, 33 insertions(+) diff --git a/src/api/routes/v2/activities.py b/src/api/routes/v2/activities.py index 930128d..91f59a6 100644 --- a/src/api/routes/v2/activities.py +++ b/src/api/routes/v2/activities.py @@ -109,11 +109,14 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]: billing_account_id=payload.get("billing_account_id"), user_id=user_id, ): + lifecycle_metadata = payload.get("lifecycle_metadata") + if domain == "profile": result = await pipeline._node_extract_profile( { "profile_queries": payload.get("queries", []), "user_id": user_id, + "lifecycle_metadata": lifecycle_metadata, } ) return {"domain": domain, "result": _domain_payload(result, "profile")} @@ -124,6 +127,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]: "temporal_queries": payload.get("queries", []), "session_datetime": payload.get("session_datetime", ""), "user_id": user_id, + "lifecycle_metadata": lifecycle_metadata, } ) return {"domain": domain, "result": _domain_payload(result, "temporal")} @@ -134,6 +138,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]: "user_query": payload.get("user_query", ""), "agent_response": payload.get("agent_response", ""), "user_id": user_id, + "lifecycle_metadata": lifecycle_metadata, } ) return {"domain": domain, "result": _domain_payload(result, "summary")} @@ -144,6 +149,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]: "classifier_output": payload.get("classifier_output", ""), "image_url": 
payload.get("image_url", ""), "user_id": user_id, + "lifecycle_metadata": lifecycle_metadata, } ) return {"domain": domain, "result": _domain_payload(result, "image")} diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index f6c68b2..7479b53 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from datetime import datetime, timedelta, timezone import time from typing import Any, Dict @@ -122,6 +123,18 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De payload = req.model_dump() payload["user_id"] = user_id payload["timeout_seconds"] = float(settings.memory_ingest_timeout_seconds) + + # When forget=true, compute lifecycle_metadata and thread it through so the + # weaver stamps the forget flag + TTL on every vector record it writes. + if req.forget: + now = datetime.now(timezone.utc) + expires_at = now + timedelta(days=settings.memory_forget_default_ttl_days) + payload["lifecycle_metadata"] = { + "forget": True, + "expires_at": expires_at.isoformat(), + "lifecycle_state": "active", + } + idempotency_fields = { "user_id": user_id, "org_id": payload.get("org_id", "default"), @@ -131,6 +144,9 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De "session_datetime": req.session_datetime, "image_url": req.image_url, "effort_level": req.effort_level, + # Include forget in idempotency so forget vs non-forget of the + # same content are treated as distinct requests. 
+ "forget": req.forget, }), } job_id = _durable_job_id("memory_ingest", idempotency_fields) diff --git a/src/api/schemas.py b/src/api/schemas.py index ff10e6d..63dc648 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ -86,6 +86,14 @@ class IngestRequest(UserScopedModel): default="low", description="'low' (fast, single pass) or 'high' (chunked parallel extraction)", ) + forget: bool = Field( + default=False, + description=( + "When true, the stored memory is tagged with a TTL and will be " + "automatically excluded from retrieval after it expires. " + "Only honoured on the v2 ingest path." + ), + ) @field_validator("user_query") @classmethod diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index f68a209..2f23e01 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -870,6 +870,7 @@ async def _node_extract_profile(self, state: IngestState) -> Dict[str, Any]: judge_result=judge_result, domain=JudgeDomain.PROFILE, user_id=user_id, + extra_metadata=state.get("lifecycle_metadata"), ) return { "profile_result": result, @@ -963,6 +964,7 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]: judge_result=judge_result, domain=JudgeDomain.SUMMARY, user_id=user_id, + extra_metadata=state.get("lifecycle_metadata"), ) return { @@ -1090,6 +1092,7 @@ async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]: judge_result=judge_result, domain=JudgeDomain.SUMMARY, user_id=state.get("user_id", "default"), + extra_metadata=state.get("lifecycle_metadata"), ) return { "summary_result": result, From 485f7334b9d5090d8595de7ce003f4949b27a53f Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 17:10:40 +0530 Subject: [PATCH 4/6] feat(retrieval): filter expired forget records at read time in _search_summary + profile catalog MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply is_retrievable(metadata, now) in the two vector read seams: - _search_summary: 
expired/forgotten records excluded from summary results. - _fetch_profile_catalog: expired/forgotten records stripped from the profile catalog and cached records before the LLM sees them. Legacy records (no lifecycle keys) always pass through — no behavior change for any memory not ingested with forget=true. _now is injectable for tests. Co-Authored-By: Claude Sonnet 4.6 --- src/pipelines/retrieval.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 8f640b2..275f171 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio +from datetime import datetime, timezone import logging from typing import Any, Callable, Dict, List, Optional @@ -31,6 +32,7 @@ from src.config import settings from src.graph.neo4j_client import Neo4jClient +from src.pipelines.lifecycle import is_retrievable from src.prompts.retrieval import ANSWER_PROMPT, build_system_prompt from src.schemas.retrieval import RetrievalResult, SourceRecord from src.schemas.code import snippets_namespace @@ -100,6 +102,7 @@ def __init__( model: Optional[BaseChatModel] = None, vector_store: Optional[BaseVectorStore] = None, neo4j_client: Optional[Neo4jClient] = None, + _now: Optional[Callable[[], datetime]] = None, ) -> None: # ── LLM ─────────────────────────────────────────────────────── if model is None: @@ -133,6 +136,7 @@ def __init__( self.embed_fn = embed_fn self._snippet_stores: Dict[str, BaseVectorStore] = {} + self._now: Callable[[], datetime] = _now or (lambda: datetime.now(timezone.utc)) logger.info("RetrievalPipeline initialized") @@ -413,8 +417,11 @@ async def _search_summary( user_id: str, top_k: int = 10, ) -> List[SourceRecord]: - """Semantic search over summary entries in Pinecone.""" + """Semantic search over summary entries in Pinecone. 
+ Records ingested with ``forget=true`` whose TTL has passed are filtered + out at read time. Legacy records (no lifecycle keys) always pass through. + """ results = await self.vector_store.search_by_text( query_text=query, top_k=top_k, @@ -424,8 +431,11 @@ async def _search_summary( }, ) + now = self._now() records = [] for r in results: + if not is_retrievable(r.metadata, now): + continue records.append(SourceRecord( domain="summary", content=r.content, @@ -503,10 +513,15 @@ def _fetch_profile_catalog(self, user_id: str): logger.warning("Failed to fetch profile catalog: %s", exc) return [], [] + now = self._now() catalog: List[Dict[str, str]] = [] seen = set() + live_results = [] for r in results: + if not is_retrievable(r.metadata, now): + continue + live_results.append(r) main_content = r.metadata.get("main_content", "") if not main_content or main_content in seen: continue @@ -524,7 +539,7 @@ def _fetch_profile_catalog(self, user_id: str): "sub_topic": "", }) - return catalog, results + return catalog, live_results def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: """Format profile catalog for the system prompt.""" From 9a6f8fe3f0e97b267f3e502ffb7fa536501f2d58 Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 17:10:52 +0530 Subject: [PATCH 5/6] test(lifecycle): add unit + integration tests for forget/TTL; fix conftest merge bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tests/conftest.py: fix InMemoryVectorStore.update() to merge metadata instead of replacing it. Prod backends (SQLite local.py, Pinecone) merge; the old fake replaced, creating false-green test risk. A regression test (test_fake_update_merges_metadata) locks the fix. 
- tests/unit/test_memory_lifecycle.py: 13 unit tests covering is_retrievable (legacy/forgotten/TTL expired/live/no-expires_at), build_lifecycle_metadata, Weaver extra_metadata merge on ADD, extra_metadata=None no-op, and _search_summary filtering (expired hidden, non-forget unaffected). - tests/integration/test_weaver_pipeline.py: two new lifecycle round-trip tests — forget metadata persisted on every ADD; extra_metadata=None no-op. - tests/integration/test_ingest_pipeline.py: update RecordingWeaver.execute() to accept extra_metadata=None so existing ingest-pipeline tests keep passing. - CHANGELOG.md: document v2 forget flag + TTL. Co-Authored-By: Claude Sonnet 4.6 --- CHANGELOG.md | 1 + tests/conftest.py | 3 +- tests/integration/test_ingest_pipeline.py | 2 +- tests/integration/test_weaver_pipeline.py | 60 ++++++ tests/unit/test_memory_lifecycle.py | 235 ++++++++++++++++++++++ 5 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_memory_lifecycle.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e069a9..87d629a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add `forget: bool` flag to the v2 ingest API (`POST /v2/memory/ingest`): memories ingested with `forget=true` are tagged with a TTL (`expires_at`) and automatically excluded from all retrieval results after the TTL elapses (default 30 days, configurable via `MEMORY_FORGET_DEFAULT_TTL_DAYS`). No background sweeper required — enforcement is at read time. - Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering. - Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints. - Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. 
diff --git a/tests/conftest.py b/tests/conftest.py index 6533fb2..fa74836 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -122,7 +122,8 @@ def update(self, id, text=None, embedding=None, metadata=None): if embedding is not None: current["embedding"] = embedding if metadata is not None: - current["metadata"] = metadata + # Merge (not replace) — mirrors prod backends (local.py:177, pinecone.py:669) + current["metadata"] = {**current.get("metadata", {}), **metadata} return True def delete(self, ids): diff --git a/tests/integration/test_ingest_pipeline.py b/tests/integration/test_ingest_pipeline.py index d4c06b2..1ec2745 100644 --- a/tests/integration/test_ingest_pipeline.py +++ b/tests/integration/test_ingest_pipeline.py @@ -52,7 +52,7 @@ def __init__(self): self.calls = [] self.snippet_vector_store = None - async def execute(self, judge_result, domain, user_id): + async def execute(self, judge_result, domain, user_id, extra_metadata=None): self.calls.append((judge_result, domain, user_id)) return WeaverResult( executed=[ diff --git a/tests/integration/test_weaver_pipeline.py b/tests/integration/test_weaver_pipeline.py index 798982a..e6569fc 100644 --- a/tests/integration/test_weaver_pipeline.py +++ b/tests/integration/test_weaver_pipeline.py @@ -1,11 +1,15 @@ from __future__ import annotations +from datetime import datetime, timedelta, timezone + import pytest from src.pipelines.weaver import Weaver from src.schemas.judge import JudgeDomain, JudgeResult, Operation, OperationType from src.schemas.weaver import OpStatus +NOW = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + @pytest.mark.asyncio async def test_weaver_batches_vector_add_update_and_delete(vector_store, fast_embed): @@ -120,3 +124,59 @@ async def create_annotation(**kwargs): assert graph_annotations[0]["target_symbol"] == "Payment.process" assert vector_store.add_calls[-1]["metadata"][0]["domain"] == "snippet" assert vector_store.add_calls[-1]["metadata"][0]["language"] == "python" + + 
+@pytest.mark.asyncio +async def test_weaver_forget_lifecycle_metadata_stored(vector_store, fast_embed): + """extra_metadata carrying forget/TTL fields is written into every ADD record.""" + expires_at = (NOW + timedelta(days=30)).isoformat() + lifecycle = { + "forget": True, + "expires_at": expires_at, + "lifecycle_state": "active", + } + weaver = Weaver( + vector_store=vector_store, + embed_fn=fast_embed, + _now=lambda: NOW, + ) + result = await weaver.execute( + JudgeResult( + operations=[ + Operation(type=OperationType.ADD, content="work / title = Engineer"), + Operation(type=OperationType.ADD, content="work / company = XMem"), + ], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "user-1", + extra_metadata=lifecycle, + ) + + assert result.succeeded == 2 + for record in vector_store.records.values(): + meta = record["metadata"] + assert meta.get("forget") is True + assert meta.get("expires_at") == expires_at + assert meta.get("lifecycle_state") == "active" + # Core metadata must still be present + assert meta.get("user_id") == "user-1" + assert meta.get("domain") == "profile" + + +@pytest.mark.asyncio +async def test_weaver_extra_metadata_none_no_regression(vector_store, fast_embed): + """extra_metadata=None must not change any existing behaviour.""" + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed) + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.ADD, content="work / title = Engineer")], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "user-1", + ) + assert result.succeeded == 1 + meta = list(vector_store.records.values())[0]["metadata"] + assert "forget" not in meta + assert "expires_at" not in meta diff --git a/tests/unit/test_memory_lifecycle.py b/tests/unit/test_memory_lifecycle.py new file mode 100644 index 0000000..997fd56 --- /dev/null +++ b/tests/unit/test_memory_lifecycle.py @@ -0,0 +1,235 @@ +""" +Unit tests for memory lifecycle (forget/TTL) — Issue #166 PR #1. 
+ +All tests use InMemoryVectorStore; no live services are called. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest + +from src.pipelines.lifecycle import build_lifecycle_metadata, is_retrievable +from src.pipelines.weaver import Weaver +from src.schemas.judge import JudgeDomain, JudgeResult, Operation, OperationType +from tests.conftest import InMemoryVectorStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +NOW = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc) +FUTURE = NOW + timedelta(days=30) +PAST = NOW - timedelta(days=1) + + +def _fixed_now(dt: datetime): + return lambda: dt + + +# --------------------------------------------------------------------------- +# 1. is_retrievable — legacy records always pass through +# --------------------------------------------------------------------------- + +def test_is_retrievable_legacy_record(): + """Records stored before lifecycle was introduced must be retrievable.""" + assert is_retrievable({}, NOW) is True + assert is_retrievable({"user_id": "u1", "domain": "summary"}, NOW) is True + + +def test_is_retrievable_forgotten_state(): + assert is_retrievable({"lifecycle_state": "forgotten"}, NOW) is False + + +def test_is_retrievable_ttl_not_expired(): + meta = { + "forget": True, + "expires_at": FUTURE.isoformat(), + "lifecycle_state": "active", + } + assert is_retrievable(meta, NOW) is True + + +def test_is_retrievable_ttl_expired(): + meta = { + "forget": True, + "expires_at": PAST.isoformat(), + "lifecycle_state": "active", + } + assert is_retrievable(meta, NOW) is False + + +def test_is_retrievable_forget_true_no_expires_at(): + """forget=true with no expires_at should NOT hide the record (missing key = legacy-safe).""" + meta = {"forget": True} + assert is_retrievable(meta, NOW) is True + + +def 
test_is_retrievable_forget_false_ignores_expires_at(): + """forget=false means the record is always active regardless of expires_at.""" + meta = {"forget": False, "expires_at": PAST.isoformat()} + assert is_retrievable(meta, NOW) is True + + +# --------------------------------------------------------------------------- +# 2. build_lifecycle_metadata +# --------------------------------------------------------------------------- + +def test_build_lifecycle_metadata_sets_correct_fields(): + meta = build_lifecycle_metadata(now=NOW, ttl_days=7.0) + assert meta["forget"] is True + assert meta["lifecycle_state"] == "active" + + # expires_at must be NOW + 7 days + expected = (NOW + timedelta(days=7.0)).isoformat() + assert meta["expires_at"] == expected + + assert meta["created_at"] == NOW.isoformat() + assert meta["updated_at"] == NOW.isoformat() + + +def test_build_lifecycle_metadata_with_reason(): + meta = build_lifecycle_metadata(now=NOW, ttl_days=1.0, reason="gdpr") + assert meta.get("forget_reason") == "gdpr" + + +# --------------------------------------------------------------------------- +# 3. 
Weaver — extra_metadata merged onto stored records +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_weaver_extra_metadata_merged_on_add(vector_store, fast_embed): + """extra_metadata is merged onto the record stored by a batched ADD.""" + lifecycle_meta = { + "forget": True, + "expires_at": FUTURE.isoformat(), + "lifecycle_state": "active", + } + weaver = Weaver( + vector_store=vector_store, + embed_fn=fast_embed, + _now=_fixed_now(NOW), + ) + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.ADD, content="I like hiking")], + confidence=1.0, + ), + JudgeDomain.SUMMARY, + "u1", + extra_metadata=lifecycle_meta, + ) + + assert result.succeeded == 1 + stored_meta = list(vector_store.records.values())[0]["metadata"] + assert stored_meta["forget"] is True + assert stored_meta["expires_at"] == FUTURE.isoformat() + assert stored_meta["user_id"] == "u1" + assert stored_meta["domain"] == "summary" + + +@pytest.mark.asyncio +async def test_weaver_extra_metadata_none_is_noop(vector_store, fast_embed): + """extra_metadata=None leaves metadata exactly as before — no regression.""" + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed) + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.ADD, content="I like coffee")], + confidence=1.0, + ), + JudgeDomain.SUMMARY, + "u1", + ) + assert result.succeeded == 1 + stored_meta = list(vector_store.records.values())[0]["metadata"] + assert "forget" not in stored_meta + assert "expires_at" not in stored_meta + + +# --------------------------------------------------------------------------- +# 4. 
_search_summary filters expired records +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_search_summary_filters_expired_records(vector_store): + """Expired forget records are excluded; non-forget records are unaffected.""" + # A normal record — always retrievable + vector_store.seed("live-1", "normal summary", { + "user_id": "u1", "domain": "summary", + }, score=0.9) + + # A forget record whose TTL has not yet expired — still retrievable + vector_store.seed("forget-live", "temporary note (live)", { + "user_id": "u1", "domain": "summary", + "forget": True, + "expires_at": FUTURE.isoformat(), + "lifecycle_state": "active", + }, score=0.8) + + # A forget record whose TTL has passed — must be hidden + vector_store.seed("forget-expired", "temporary note (expired)", { + "user_id": "u1", "domain": "summary", + "forget": True, + "expires_at": PAST.isoformat(), + "lifecycle_state": "active", + }, score=0.7) + + from src.pipelines.retrieval import RetrievalPipeline + from tests.conftest import FakeChatModel, FakeNeo4jClient + + pipeline = RetrievalPipeline( + model=FakeChatModel(), + vector_store=vector_store, + neo4j_client=FakeNeo4jClient(), + _now=_fixed_now(NOW), + ) + records = await pipeline._search_summary("query", "u1", top_k=10) + + ids = [r.metadata["id"] for r in records] + assert "live-1" in ids + assert "forget-live" in ids + assert "forget-expired" not in ids + + +@pytest.mark.asyncio +async def test_search_summary_no_regression_without_lifecycle_keys(vector_store): + """Legacy records with no lifecycle keys must still be returned.""" + for i in range(3): + vector_store.seed(f"legacy-{i}", f"legacy content {i}", { + "user_id": "u1", "domain": "summary", + }) + + from src.pipelines.retrieval import RetrievalPipeline + from tests.conftest import FakeChatModel, FakeNeo4jClient + + pipeline = RetrievalPipeline( + model=FakeChatModel(), + vector_store=vector_store, + 
neo4j_client=FakeNeo4jClient(), + _now=_fixed_now(NOW), + ) + records = await pipeline._search_summary("query", "u1", top_k=10) + assert len(records) == 3 + + +# --------------------------------------------------------------------------- +# 5. conftest merge fix — InMemoryVectorStore.update() merges, not replaces +# --------------------------------------------------------------------------- + +def test_fake_update_merges_metadata(): + """update() must merge metadata, not replace it. + + This test FAILS on the old replace-fake and PASSES after the merge fix, + locking the correction so future regressions surface immediately. + """ + store = InMemoryVectorStore() + store.seed("r1", "content", {"a": 1, "b": 2}) + + store.update("r1", metadata={"b": 99, "c": 3}) + + meta = store.records["r1"]["metadata"] + assert meta["a"] == 1, "Pre-existing key must be preserved" + assert meta["b"] == 99, "Updated key must reflect new value" + assert meta["c"] == 3, "New key must be added" From 8ae91760e4b4b15b141b5494f6857af898ceeaaf Mon Sep 17 00:00:00 2001 From: Chirag Khatri Date: Wed, 3 Jun 2026 22:14:54 +0530 Subject: [PATCH 6/6] fix(lifecycle): patch UPDATE data-leak, sanitize extra_metadata, harden lifecycle guards Addresses P1 data-leak (Greptile) and two Gemini security-high findings from PR #228 review. `_vector_update` + batched UPDATE path now carry `extra_metadata` so forget/TTL is preserved when the Judge emits an UPDATE op. `extra_metadata` is sanitized once at `execute()` via a PROTECTED_METADATA_KEYS denylist so callers cannot overwrite `user_id`/`domain`. Route refactored to use `build_lifecycle_metadata()`; batch ingest rejects `forget=true` with HTTP 400 (was a silent no-op on a privacy path). Defensive guards added to `is_retrievable(None)` and datetime parse. 15 new regression + route tests; 92/92 passing. 
Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 4 +- src/api/routes/v2/memory.py | 28 ++-- src/pipelines/lifecycle.py | 12 +- src/pipelines/retrieval.py | 2 +- src/pipelines/weaver.py | 49 ++++-- src/schemas/memory_lifecycle.py | 9 ++ tests/api/test_memory_versioning.py | 52 +++++++ tests/integration/test_weaver_pipeline.py | 172 ++++++++++++++++++++++ tests/unit/test_memory_lifecycle.py | 17 +++ 9 files changed, 314 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87d629a..357b5a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,9 @@ ## Unreleased -- Add `forget: bool` flag to the v2 ingest API (`POST /v2/memory/ingest`): memories ingested with `forget=true` are tagged with a TTL (`expires_at`) and automatically excluded from all retrieval results after the TTL elapses (default 30 days, configurable via `MEMORY_FORGET_DEFAULT_TTL_DAYS`). No background sweeper required — enforcement is at read time. +- Add `forget: bool` flag to `POST /v2/memory/ingest`: memories with `forget=true` get a TTL (`expires_at`) and are excluded from retrieval (`_search_summary` + profile catalog) once expired. Read-time enforcement; no sweeper. +- Add `memory_forget_default_ttl_days` setting (env `MEMORY_FORGET_DEFAULT_TTL_DAYS`, default 30). Known limitation: changing it does not refresh an already-cached (idempotent) forget job's TTL; resolved when TTL becomes a client-supplied field. +- `POST /v2/memory/batch-ingest` rejects `forget=true` with HTTP 400 (per-item forget not yet supported in batch). - Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering. - Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints. - Add modular LoCoMo and BEAM benchmark runners for the Python XMem API. 
diff --git a/src/api/routes/v2/memory.py b/src/api/routes/v2/memory.py index 7479b53..1b3ed6a 100644 --- a/src/api/routes/v2/memory.py +++ b/src/api/routes/v2/memory.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import time from typing import Any, Dict @@ -19,6 +19,7 @@ read_user_job, ) from src.api.routes.v2.temporal_client import start_job_workflow +from src.pipelines.lifecycle import build_lifecycle_metadata from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum from src.billing import InsufficientCredits, get_default_billing_service from src.config import settings @@ -127,13 +128,10 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De # When forget=true, compute lifecycle_metadata and thread it through so the # weaver stamps the forget flag + TTL on every vector record it writes. if req.forget: - now = datetime.now(timezone.utc) - expires_at = now + timedelta(days=settings.memory_forget_default_ttl_days) - payload["lifecycle_metadata"] = { - "forget": True, - "expires_at": expires_at.isoformat(), - "lifecycle_state": "active", - } + payload["lifecycle_metadata"] = build_lifecycle_metadata( + now=datetime.now(timezone.utc), + ttl_days=settings.memory_forget_default_ttl_days, + ) idempotency_fields = { "user_id": user_id, @@ -144,8 +142,11 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De "session_datetime": req.session_datetime, "image_url": req.image_url, "effort_level": req.effort_level, - # Include forget in idempotency so forget vs non-forget of the - # same content are treated as distinct requests. + # forget distinguishes forget vs non-forget of identical content. + # KNOWN LIMITATION (PR #2): server-default TTL is intentionally NOT hashed. 
+ # Idempotency = "same request → same job"; the request didn't change, the + # server config did. Changing MEMORY_FORGET_DEFAULT_TTL_DAYS won't refresh a + # cached forget job's TTL. Resolved when forget_ttl_days becomes a client field. "forget": req.forget, }), } @@ -223,6 +224,13 @@ async def memory_job_status(job_id: str, request: Request, user: dict = Depends( @router.post("/batch-ingest", response_model=APIResponse, summary="Start an async durable batch memory ingest job") async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user: dict = Depends(require_api_key)): start = time.perf_counter() + if any(getattr(item, "forget", False) for item in req.items): + return _error( + request, + "forget=true is not supported in batch ingest yet; use POST /v2/memory/ingest per item.", + 400, + elapsed_ms(start), + ) user_id = memory_v1._current_user_id(user) items = [memory_v1._scoped_ingest_payload(user, item) for item in req.items] payload = { diff --git a/src/pipelines/lifecycle.py b/src/pipelines/lifecycle.py index fe01837..1496711 100644 --- a/src/pipelines/lifecycle.py +++ b/src/pipelines/lifecycle.py @@ -1,7 +1,7 @@ """ Memory lifecycle — pure, deterministic helper functions for forget/TTL. -is_retrievable() is the single retrieval-time gate. build_forget_metadata() +is_retrievable() is the single retrieval-time gate. build_lifecycle_metadata() stamps the lifecycle fields onto a metadata dict for storage. All functions are side-effect-free so they can be tested without live services. @@ -13,7 +13,7 @@ from typing import Any, Dict, Mapping, Optional -def is_retrievable(metadata: Mapping[str, Any], now: datetime) -> bool: +def is_retrievable(metadata: Optional[Mapping[str, Any]], now: datetime) -> bool: """Return True when a record should appear in retrieval results. 
Rules (applied in order): @@ -24,6 +24,9 @@ def is_retrievable(metadata: Mapping[str, Any], now: datetime) -> bool: Missing keys default to the legacy-safe value so records stored before lifecycle was introduced are never hidden. """ + if not metadata: + return True + if metadata.get("lifecycle_state", "active") == "forgotten": return False @@ -31,7 +34,10 @@ def is_retrievable(metadata: Mapping[str, Any], now: datetime) -> bool: expires_raw = metadata.get("expires_at") if expires_raw: try: - expires_at = datetime.fromisoformat(str(expires_raw)) + if isinstance(expires_raw, datetime): + expires_at = expires_raw + else: + expires_at = datetime.fromisoformat(str(expires_raw)) # Make both sides timezone-aware or both naive for comparison if expires_at.tzinfo is None and now.tzinfo is not None: from datetime import timezone diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 275f171..cd5b2dc 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -522,7 +522,7 @@ def _fetch_profile_catalog(self, user_id: str): if not is_retrievable(r.metadata, now): continue live_results.append(r) - main_content = r.metadata.get("main_content", "") + main_content = (r.metadata or {}).get("main_content", "") if not main_content or main_content in seen: continue seen.add(main_content) diff --git a/src/pipelines/weaver.py b/src/pipelines/weaver.py index 8a8ee52..1054f55 100644 --- a/src/pipelines/weaver.py +++ b/src/pipelines/weaver.py @@ -24,6 +24,7 @@ Operation, OperationType, ) +from src.schemas.memory_lifecycle import PROTECTED_METADATA_KEYS from src.schemas.weaver import ExecutedOp, OpStatus, WeaverResult from src.storage.base import BaseVectorStore @@ -93,15 +94,20 @@ async def execute( logger.info("Nothing to execute — all NOOPs or empty.") return result + # Sanitize ONCE here — execute() is the sole entry point. 
+ safe_extra: Optional[Dict[str, Any]] = None + if extra_metadata: + safe_extra = {k: v for k, v in extra_metadata.items() if k not in PROTECTED_METADATA_KEYS} + # Optimization: Batch vector operations if possible if domain not in (JudgeDomain.TEMPORAL, JudgeDomain.CODE, JudgeDomain.SNIPPET) and self.vector_store: batched_executed = await self._execute_batched_vector( - judge_result.operations, domain, user_id, extra_metadata=extra_metadata + judge_result.operations, domain, user_id, extra_metadata=safe_extra ) result.executed.extend(batched_executed) else: for op in judge_result.operations: - executed = await self._execute_one(op, domain, user_id) + executed = await self._execute_one(op, domain, user_id, extra_metadata=safe_extra) result.executed.append(executed) self._log_summary(domain, result) @@ -149,8 +155,7 @@ async def flush_add_batch(): try: meta = {"user_id": user_id, "domain": domain.value} meta.update(_extract_structured_metadata(op.content)) - if extra_metadata: - meta.update(extra_metadata) + _merge_extra_metadata(meta, extra_metadata) valid_ops.append(op) texts.append(op.content) @@ -288,14 +293,12 @@ async def flush_delete_batch(): elif current_op.type == OperationType.UPDATE: await flush_add_batch() await flush_delete_batch() - # Execute individual UPDATE - executed_ops.append(await self._execute_one(current_op, domain, user_id)) + executed_ops.append(await self._execute_one(current_op, domain, user_id, extra_metadata=extra_metadata)) else: - # Fallback for unknown types - await flush_add_batch() - await flush_delete_batch() - executed_ops.append(await self._execute_one(current_op, domain, user_id)) + await flush_add_batch() + await flush_delete_batch() + executed_ops.append(await self._execute_one(current_op, domain, user_id, extra_metadata=extra_metadata)) # Final flush await flush_add_batch() @@ -312,6 +315,7 @@ async def _execute_one( op: Operation, domain: JudgeDomain, user_id: str, + extra_metadata: Optional[Dict[str, Any]] = None, ) -> 
ExecutedOp: # ── Guard rails ────────────────────────────────────────────── if op.type == OperationType.NOOP: @@ -341,7 +345,7 @@ async def _execute_one( elif domain == JudgeDomain.SNIPPET: return await self._execute_snippet(op, user_id) else: - return await self._execute_vector(op, domain, user_id) + return await self._execute_vector(op, domain, user_id, extra_metadata=extra_metadata) # ------------------------------------------------------------------ # Profile / Summary → Pinecone @@ -352,6 +356,7 @@ async def _execute_vector( op: Operation, domain: JudgeDomain, user_id: str, + extra_metadata: Optional[Dict[str, Any]] = None, ) -> ExecutedOp: if not self.vector_store: return ExecutedOp( @@ -361,9 +366,9 @@ async def _execute_vector( try: if op.type == OperationType.ADD: - return await self._vector_add(op, domain, user_id) + return await self._vector_add(op, domain, user_id, extra_metadata=extra_metadata) elif op.type == OperationType.UPDATE: - return await self._vector_update(op, domain, user_id) + return await self._vector_update(op, domain, user_id, extra_metadata=extra_metadata) elif op.type == OperationType.DELETE: return await self._vector_delete(op) else: @@ -398,8 +403,7 @@ async def _vector_add( # Store structured metadata for deterministic lookups structured = _extract_structured_metadata(op.content) metadata.update(structured) - if extra_metadata: - metadata.update(extra_metadata) + _merge_extra_metadata(metadata, extra_metadata) ids = self.vector_store.add( texts=[op.content], @@ -414,6 +418,7 @@ async def _vector_add( async def _vector_update( self, op: Operation, domain: JudgeDomain, user_id: str, + extra_metadata: Optional[Dict[str, Any]] = None, ) -> ExecutedOp: if not self.embed_fn: return ExecutedOp( @@ -427,6 +432,8 @@ async def _vector_update( # Store structured metadata for deterministic lookups structured = _extract_structured_metadata(op.content) metadata.update(structured) + # TODO(PR#2): UPDATE re-stamps created_at to "now"; fix when 
versioning lands + _merge_extra_metadata(metadata, extra_metadata) success = self.vector_store.update( id=op.embedding_id, @@ -443,7 +450,7 @@ async def _vector_update( logger.warning( "UPDATE target %s not found — falling back to ADD.", op.embedding_id, ) - return await self._vector_add(op, domain, user_id) + return await self._vector_add(op, domain, user_id, extra_metadata=extra_metadata) async def _vector_delete(self, op: Operation) -> ExecutedOp: success = self.vector_store.delete(ids=[op.embedding_id]) @@ -867,6 +874,16 @@ def _log_summary(self, domain: JudgeDomain, result: WeaverResult) -> None: # Helpers # --------------------------------------------------------------------------- +def _merge_extra_metadata(base: Dict[str, Any], extra: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge already-sanitized extra_metadata onto base, in place. + Protected-key stripping happens ONCE at the execute() boundary, so this is a + plain merge. Kept as one helper so every vector write funnels through one path. + """ + if extra: + base.update(extra) + return base + + def _extract_structured_metadata(content: str) -> Dict[str, str]: """Extract structured metadata from profile/summary content. diff --git a/src/schemas/memory_lifecycle.py b/src/schemas/memory_lifecycle.py index a8531c2..10bafcb 100644 --- a/src/schemas/memory_lifecycle.py +++ b/src/schemas/memory_lifecycle.py @@ -14,6 +14,8 @@ class MemoryLifecycle(BaseModel): """ forget: bool = False expires_at: Optional[datetime] = None + # NOTE: the manual soft-forget write path (lifecycle_state="forgotten" + forgotten_at) + # is READ by is_retrievable() but not yet WRITTEN by any endpoint — lands in PR #2. 
lifecycle_state: Literal["active", "forgotten"] = "active" forgotten_at: Optional[datetime] = None forget_reason: Optional[str] = None @@ -22,3 +24,10 @@ class MemoryLifecycle(BaseModel): RESERVED_LIFECYCLE_KEYS: frozenset[str] = frozenset(MemoryLifecycle.model_fields) + +# Identity/routing keys that caller-supplied extra_metadata must NEVER overwrite. +# Denylist (not allowlist): main_content/subcontent are deterministically set by +# _extract_structured_metadata before the merge, and an allowlist tied to lifecycle +# would silently drop PR #2 versioning keys (version/parent_memory_id/is_current) +# that flow through the same extra_metadata channel. +PROTECTED_METADATA_KEYS: frozenset[str] = frozenset({"user_id", "domain"}) diff --git a/tests/api/test_memory_versioning.py b/tests/api/test_memory_versioning.py index d7e1f90..3da6f0b 100644 --- a/tests/api/test_memory_versioning.py +++ b/tests/api/test_memory_versioning.py @@ -445,6 +445,58 @@ def fail_mark_cancelled(job_id): assert store.jobs["job-1"]["status"] == "running" +def test_batch_ingest_rejects_forget(monkeypatch): + """Batch ingest with any forget=true item must return 400 before enqueuing any job.""" + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + monkeypatch.setattr(memory_v2, "get_default_job_store", lambda: store) + + payload = { + "items": [ + {"user_query": "remember alpha", "user_id": "alice", "forget": True}, + ], + } + + response = TestClient(app).post("/v2/memory/batch-ingest", json=payload) + + assert response.status_code == 400 + body = response.json() + assert body["status"] == "error" + assert "forget=true is not supported in batch" in body["error"] + assert store.jobs == {} + + +def test_v2_ingest_forget_stamps_lifecycle(monkeypatch): + """Ingest with forget=true must thread lifecycle_metadata with created_at/updated_at.""" + app, _ = _build_app(monkeypatch) + store = FakeJobStore() + scheduled = [] + + async def fake_start_job_workflow(job): + scheduled.append(job["job_id"]) 
+ + monkeypatch.setattr(memory_v2, "get_default_job_store", lambda: store) + monkeypatch.setattr(memory_v2, "start_job_workflow", fake_start_job_workflow) + + payload = { + "user_query": "delete this later", + "agent_response": "done", + "user_id": "hunter", + "forget": True, + } + + response = TestClient(app).post("/v2/memory/ingest", json=payload) + + assert response.status_code == 200 + job = list(store.jobs.values())[0] + lc = job["payload"].get("lifecycle_metadata") + assert lc is not None, "lifecycle_metadata must be present in payload" + assert lc.get("forget") is True + assert "created_at" in lc + assert "updated_at" in lc + assert "expires_at" in lc + + def test_v1_batch_ingest_scopes_each_item_for_local_static_key(monkeypatch): monkeypatch.setattr(memory.settings, "environment", "development", raising=False) static_user = {"id": "static-key", "name": "Static Key User", "email": "static@xmem.ai"} diff --git a/tests/integration/test_weaver_pipeline.py b/tests/integration/test_weaver_pipeline.py index e6569fc..ce7e943 100644 --- a/tests/integration/test_weaver_pipeline.py +++ b/tests/integration/test_weaver_pipeline.py @@ -180,3 +180,175 @@ async def test_weaver_extra_metadata_none_no_regression(vector_store, fast_embed meta = list(vector_store.records.values())[0]["metadata"] assert "forget" not in meta assert "expires_at" not in meta + + +# --------------------------------------------------------------------------- +# Lifecycle regression tests (bug 1 + 2) +# --------------------------------------------------------------------------- + +FUTURE = NOW + timedelta(days=30) +PAST = NOW - timedelta(days=1) +_LIFECYCLE = { + "forget": True, + "expires_at": FUTURE.isoformat(), + "lifecycle_state": "active", + "created_at": NOW.isoformat(), + "updated_at": NOW.isoformat(), +} + + +@pytest.mark.asyncio +async def test_weaver_update_preserves_extra_metadata_REGRESSION(vector_store, fast_embed): + """UPDATE op must carry lifecycle metadata — was the P1 data-leak 
bug.""" + vector_store.seed("r1", "work / title = OldTitle", {"user_id": "u1", "domain": "profile"}) + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed, _now=lambda: NOW) + + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.UPDATE, embedding_id="r1", content="work / title = Engineer")], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "u1", + extra_metadata=_LIFECYCLE, + ) + + assert result.succeeded == 1 + meta = vector_store.records["r1"]["metadata"] + assert meta["forget"] is True + assert meta["expires_at"] == FUTURE.isoformat() + + +@pytest.mark.asyncio +async def test_weaver_update_fallback_to_add_carries_extra_metadata(vector_store, fast_embed): + """UPDATE with missing target falls back to ADD, which must carry lifecycle.""" + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed, _now=lambda: NOW) + + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.UPDATE, embedding_id="nonexistent", content="work / title = Engineer")], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "u1", + extra_metadata=_LIFECYCLE, + ) + + assert result.succeeded == 1 + stored = list(vector_store.records.values()) + assert len(stored) == 1 + meta = stored[0]["metadata"] + assert meta["forget"] is True + assert meta["expires_at"] == FUTURE.isoformat() + + +@pytest.mark.asyncio +async def test_weaver_batched_multi_op_carries_lifecycle(vector_store, fast_embed): + """ADD + UPDATE(existing) + ADD all carry lifecycle (pins bug 1 across flush boundaries).""" + vector_store.seed("existing-1", "work / company = OldCo", {"user_id": "u1", "domain": "profile"}) + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed, _now=lambda: NOW) + + result = await weaver.execute( + JudgeResult( + operations=[ + Operation(type=OperationType.ADD, content="work / title = Engineer"), + Operation(type=OperationType.UPDATE, embedding_id="existing-1", content="work / company = XMem"), + 
Operation(type=OperationType.ADD, content="work / location = Remote"), + ], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "u1", + extra_metadata=_LIFECYCLE, + ) + + assert result.succeeded == 3 + for record in vector_store.records.values(): + meta = record["metadata"] + assert meta.get("forget") is True, f"Record missing forget: {meta}" + assert meta.get("expires_at") == FUTURE.isoformat(), f"Record missing expires_at: {meta}" + + +@pytest.mark.asyncio +async def test_weaver_extra_metadata_cannot_override_user_id(vector_store, fast_embed): + """Caller-supplied extra_metadata must not overwrite user_id or domain.""" + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed) + + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.ADD, content="work / title = Engineer")], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "u1", + extra_metadata={"user_id": "HACK", "domain": "evil", "forget": True}, + ) + + assert result.succeeded == 1 + meta = list(vector_store.records.values())[0]["metadata"] + assert meta["user_id"] == "u1" + assert meta["domain"] == "profile" + assert meta["forget"] is True + + +@pytest.mark.asyncio +async def test_weaver_extra_metadata_passthrough_unknown_key(vector_store, fast_embed): + """Unknown keys (e.g. 
PR #2 versioning) pass through — denylist not allowlist.""" + weaver = Weaver(vector_store=vector_store, embed_fn=fast_embed) + + result = await weaver.execute( + JudgeResult( + operations=[Operation(type=OperationType.ADD, content="work / title = Engineer")], + confidence=1.0, + ), + JudgeDomain.PROFILE, + "u1", + extra_metadata={"version": 3, "parent_memory_id": "abc"}, + ) + + assert result.succeeded == 1 + meta = list(vector_store.records.values())[0]["metadata"] + assert meta["version"] == 3 + assert meta["parent_memory_id"] == "abc" + + +@pytest.mark.asyncio +async def test_update_then_expired_filtered_from_summary_and_catalog(vector_store, fast_embed): + """Seeded forget records with a past expiry are hidden in _search_summary AND _fetch_profile_catalog (records are seeded directly; no UPDATE op is executed here).""" + from src.pipelines.retrieval import RetrievalPipeline + from tests.conftest import FakeChatModel, FakeNeo4jClient + + # Seed an expired forget record and a live sibling + vector_store.seed("expired-summary", "summary note expired", { + "user_id": "u1", "domain": "summary", + "forget": True, "expires_at": PAST.isoformat(), "lifecycle_state": "active", + }, score=0.9) + vector_store.seed("live-summary", "summary note live", { + "user_id": "u1", "domain": "summary", + }, score=0.8) + vector_store.seed("expired-profile", "work / title = OldTitle", { + "user_id": "u1", "domain": "profile", + "main_content": "work_title", + "forget": True, "expires_at": PAST.isoformat(), "lifecycle_state": "active", + }, score=0.9) + vector_store.seed("live-profile", "work / company = XMem", { + "user_id": "u1", "domain": "profile", + "main_content": "work_company", + }, score=0.8) + + pipeline = RetrievalPipeline( + model=FakeChatModel(), + vector_store=vector_store, + neo4j_client=FakeNeo4jClient(), + _now=lambda: NOW, + ) + + summary_records = await pipeline._search_summary("query", "u1", top_k=10) + summary_ids = [r.metadata["id"] for r in summary_records] + assert "expired-summary" not in summary_ids + assert "live-summary" in
summary_ids + + catalog, live_results = pipeline._fetch_profile_catalog("u1") + live_ids = [r.id for r in live_results] + assert "expired-profile" not in live_ids + assert "live-profile" in live_ids diff --git a/tests/unit/test_memory_lifecycle.py b/tests/unit/test_memory_lifecycle.py index 997fd56..04aa2ba 100644 --- a/tests/unit/test_memory_lifecycle.py +++ b/tests/unit/test_memory_lifecycle.py @@ -233,3 +233,20 @@ def test_fake_update_merges_metadata(): assert meta["a"] == 1, "Pre-existing key must be preserved" assert meta["b"] == 99, "Updated key must reflect new value" assert meta["c"] == 3, "New key must be added" + + +# --------------------------------------------------------------------------- +# 6. is_retrievable — None guard + datetime short-circuit +# --------------------------------------------------------------------------- + +def test_is_retrievable_none_metadata(): + """is_retrievable(None, now) must return True (legacy-safe).""" + assert is_retrievable(None, NOW) is True + + +def test_is_retrievable_datetime_expires_at(): + """is_retrievable must handle a real datetime object (isinstance short-circuit).""" + meta_past = {"forget": True, "expires_at": PAST, "lifecycle_state": "active"} + meta_future = {"forget": True, "expires_at": FUTURE, "lifecycle_state": "active"} + assert is_retrievable(meta_past, NOW) is False + assert is_retrievable(meta_future, NOW) is True