diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index f6b41f9d..562937d1 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -17,6 +17,7 @@ from src.api.dependencies import ( enforce_rate_limit, + get_code_pipeline, get_ingest_pipeline, get_retrieval_pipeline, require_api_key, @@ -41,7 +42,6 @@ StatusEnum, WeaverSummary, ) -from src.pipelines.retrieval import RetrievalPipeline from bs4 import BeautifulSoup import json @@ -49,7 +49,12 @@ from playwright.sync_api import sync_playwright from src.config import settings -from src.jobs.durable import serialize_job +from src.jobs.durable import ( + QUEUED, + get_default_job_store, + run_job, + serialize_job, +) logger = logging.getLogger("xmem.api.routes.memory") @@ -68,6 +73,18 @@ dependencies=[Depends(enforce_rate_limit)], ) +v2_router = APIRouter( + prefix="/v2/memory", + tags=["memory"], + dependencies=[Depends(require_ready), Depends(enforce_rate_limit)], +) + +v2_scrape_router = APIRouter( + prefix="/v2/memory", + tags=["memory"], + dependencies=[Depends(enforce_rate_limit)], +) + # Helpers def _model_name(model: Any) -> str: @@ -84,7 +101,9 @@ def _build_domain_result(judge: Any, weaver: Any) -> DomainResult | None: ws = None if weaver: ws = WeaverSummary( - succeeded=weaver.succeeded, skipped=weaver.skipped, failed=weaver.failed, + succeeded=weaver.succeeded, + skipped=weaver.skipped, + failed=weaver.failed, ) return DomainResult(confidence=judge.confidence, operations=ops, weaver=ws) @@ -121,7 +140,9 @@ def _error( def _is_static_key_user(user: dict) -> bool: - return user.get("email") == "static@xmem.ai" or user.get("name") == "Static Key User" + return ( + user.get("email") == "static@xmem.ai" or user.get("name") == "Static Key User" + ) def _current_user_id(user: dict, requested_user_id: str = "") -> str: @@ -176,6 +197,13 @@ def _job_accepted( return _wrap(request, data, elapsed_ms) +def _invalidate_retrieval_cache(user_id: str) -> None: + try: + get_retrieval_pipeline().invalidate_user_cache(user_id) + except Exception: + logger.debug("Retrieval cache invalidation skipped", exc_info=True) + + async def _run_ingest_payload( payload: Dict[str, Any], user_id: str, @@ -210,6 +238,7 @@ async def _run_ingest_payload( result.get("image_weaver"), ), ) + _invalidate_retrieval_cache(user_id) return data.model_dump() @@ -241,7 +270,11 @@ def _detect_chat_provider(*urls: str) -> str: lowered = (url or "").lower() if not lowered: continue - if "chatgpt.com" in lowered or "chat.openai.com" in lowered or "openai.com" in lowered: + if ( + "chatgpt.com" in lowered + or "chat.openai.com" in lowered + or "openai.com" in lowered + ): return "chatgpt" if "claude.ai" in lowered or "claude.com" in lowered: return "claude" @@ -303,7 +336,10 @@ def _get_or_create_browser(): if channel: kwargs["channel"] = channel _browser_instance = _pw_instance.chromium.launch(**kwargs) - logger.info("[scrape] Playwright browser launched (channel=%s)", channel or "bundled") + logger.info( + "[scrape] Playwright browser launched (channel=%s)", + channel or "bundled", + ) return _browser_instance except Exception as exc: launch_errors.append(f"{channel or 'bundled chromium'}: {exc}") @@ -387,10 +423,12 @@ def _extract_chat_pairs( user_msgs = soup.find_all("div", {"data-message-author-role": "user"}) asst_msgs = soup.find_all("div", {"data-message-author-role": "assistant"}) for u, a in zip(user_msgs, asst_msgs): - pairs.append(MessagePair( - user_query=u.get_text(separator="\n").strip(), - agent_response=a.get_text(separator="\n").strip(), - )) + pairs.append( + MessagePair( + user_query=u.get_text(separator="\n").strip(), + agent_response=a.get_text(separator="\n").strip(), + ) + ) if pairs: extraction_method = "dom" @@ -411,10 +449,12 @@ def _extract_chat_pairs( if msg.get("sender") == "human": current_user = msg.get("text", "") elif msg.get("sender") == "assistant": - pairs.append(MessagePair( - user_query=current_user, - agent_response=msg.get("text", ""), - )) + pairs.append( + MessagePair( + user_query=current_user, + agent_response=msg.get("text", ""), + ) + ) current_user = "" if pairs: extraction_method = "structured" @@ -425,10 +465,12 @@ def _extract_chat_pairs( user_blocks = soup.select("message-content[role='user'], div.user-query") model_blocks = soup.select("message-content[role='model'], div.model-response") for u, m in zip(user_blocks, model_blocks): - pairs.append(MessagePair( - user_query=u.get_text(separator="\n").strip(), - agent_response=m.get_text(separator="\n").strip(), - )) + pairs.append( + MessagePair( + user_query=u.get_text(separator="\n").strip(), + agent_response=m.get_text(separator="\n").strip(), + ) + ) if pairs: extraction_method = "dom" @@ -445,10 +487,12 @@ def _extract_chat_pairs( if unique_paras: text = "\n\n".join(unique_paras[:50]) - pairs.append(MessagePair( - user_query="Extracted text from link", - agent_response=text[:10000], - )) + pairs.append( + MessagePair( + user_query="Extracted text from link", + agent_response=text[:10000], + ) + ) extraction_method = "fallback" return provider, extraction_method, pairs @@ -456,7 +500,7 @@ def _extract_chat_pairs( def _parse_cursor_transcript(text: str) -> List[MessagePair]: """Parse a Cursor-exported markdown transcript into message pairs. - + Cursor transcripts have the format: _Exported on ... from Cursor_ --- @@ -469,41 +513,47 @@ def _parse_cursor_transcript(text: str) -> List[MessagePair]: ... """ pairs: List[MessagePair] = [] - + # Split by --- separator sections = text.split("---") - + # Skip the first section if it's the header (contains "Exported on") start_idx = 0 if sections and "Exported on" in sections[0]: start_idx = 1 - + current_user_query = None - + for section in sections[start_idx:]: section = section.strip() if not section: continue - + # Check if this is a User message if section.startswith("**User**"): # Extract the user message (remove the **User** header) content = section.replace("**User**", "", 1).strip() current_user_query = content - + # Check if this is a Cursor/Agent message elif section.startswith("**Cursor**") or section.startswith("**Assistant**"): # Extract the agent response - content = section.replace("**Cursor**", "", 1).replace("**Assistant**", "", 1).strip() - + content = ( + section.replace("**Cursor**", "", 1) + .replace("**Assistant**", "", 1) + .strip() + ) + # If we have a user query, create a pair if current_user_query: - pairs.append(MessagePair( - user_query=current_user_query, - agent_response=content, - )) + pairs.append( + MessagePair( + user_query=current_user_query, + agent_response=content, + ) + ) current_user_query = None - + return pairs @@ -552,10 +602,12 @@ def _parse_antigravity_transcript(text: str) -> List[MessagePair]: if re.match(r"###\s+User Input", block, re.IGNORECASE): # Flush any pending planner chunks as a completed pair if current_user_query and planner_chunks: - pairs.append(MessagePair( - user_query=current_user_query, - agent_response="\n\n".join(planner_chunks).strip(), - )) + pairs.append( + MessagePair( + user_query=current_user_query, + agent_response="\n\n".join(planner_chunks).strip(), + ) + ) planner_chunks = [] # The next block (index i+1) is the content of this user turn current_user_query = None # will be filled by the content block below @@ -572,10 +624,12 @@ def _parse_antigravity_transcript(text: str) -> List[MessagePair]: if re.match(r"###\s+User Input", prev_heading, re.IGNORECASE): # New user turn — flush previous pair first if current_user_query and planner_chunks: - pairs.append(MessagePair( - user_query=current_user_query, - agent_response="\n\n".join(planner_chunks).strip(), - )) + pairs.append( + MessagePair( + user_query=current_user_query, + agent_response="\n\n".join(planner_chunks).strip(), + ) + ) planner_chunks = [] current_user_query = block @@ -586,10 +640,12 @@ def _parse_antigravity_transcript(text: str) -> List[MessagePair]: # Flush last pair if current_user_query and planner_chunks: - pairs.append(MessagePair( - user_query=current_user_query, - agent_response="\n\n".join(planner_chunks).strip(), - )) + pairs.append( + MessagePair( + user_query=current_user_query, + agent_response="\n\n".join(planner_chunks).strip(), + ) + ) return pairs @@ -597,14 +653,14 @@ def _parse_antigravity_transcript(text: str) -> List[MessagePair]: async def _parse_transcript_with_llm(text: str) -> List[MessagePair]: """Use an LLM to parse transcript text when format detection fails.""" from src.models import get_model - + # Limit text size to avoid token issues max_chars = 50000 if len(text) > max_chars: text = text[:max_chars] - + model = get_model(temperature=0.0) - + prompt = f"""You are parsing a chat transcript. Extract all user-agent message pairs from the following text. Return a JSON array of objects with this structure: @@ -618,24 +674,27 @@ async def _parse_transcript_with_llm(text: str) -> List[MessagePair]: Transcript: {text} """ - + try: response = await model.ainvoke(prompt) content = response.content if hasattr(response, "content") else str(response) - + # Try to extract JSON from the response - json_match = re.search(r'\[.*\]', content, re.DOTALL) + json_match = re.search(r"\[.*\]", content, re.DOTALL) if json_match: data = json.loads(json_match.group(0)) pairs = [ - MessagePair(user_query=item.get("user_query", ""), agent_response=item.get("agent_response", "")) + MessagePair( + user_query=item.get("user_query", ""), + agent_response=item.get("agent_response", ""), + ) for item in data if isinstance(item, dict) ] return pairs except Exception as exc: logger.warning("LLM transcript parsing failed: %s", exc) - + return [] @@ -649,7 +708,9 @@ def _parse_transcript_text(text: str) -> tuple[str, List[MessagePair]]: return "cursor", pairs # Detect Antigravity format - if "# Chat Conversation" in text and ("### User Input" in text or "### Planner Response" in text): + if "# Chat Conversation" in text and ( + "### User Input" in text or "### Planner Response" in text + ): pairs = _parse_antigravity_transcript(text) if pairs: return "antigravity", pairs @@ -659,7 +720,9 @@ def _parse_transcript_text(text: str) -> tuple[str, List[MessagePair]]: async def _scrape_chat_share(url: str) -> Dict[str, Any]: html, final_url = await _render_chat_share(url) - provider, extraction_method, pairs = _extract_chat_pairs(final_url or url, html, url) + provider, extraction_method, pairs = _extract_chat_pairs( + final_url or url, html, url + ) return { "provider": provider, @@ -678,7 +741,9 @@ async def _scrape_chat_share(url: str) -> Dict[str, Any]: response_model=APIResponse, summary="Ingest a conversation turn into long-term memory", ) -async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depends(require_api_key)): +async def ingest_memory( + req: IngestRequest, request: Request, user: dict = Depends(require_api_key) +): start = time.perf_counter() user_id = _current_user_id(user, req.user_id) payload = req.model_dump() @@ -697,6 +762,57 @@ async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depen return _error(request, str(exc), 500, elapsed) +# POST /v2/memory/ingest +@v2_router.post( + "/ingest", + response_model=APIResponse, + summary="Start an async durable memory ingest job", +) +async def ingest_memory_v2( + req: IngestRequest, request: Request, user: dict = Depends(require_api_key) +): + start = time.perf_counter() + user_id = _current_user_id(user, req.user_id) + job_user_id = _current_user_id(user) + payload = req.model_dump() + payload["user_id"] = user_id + + try: + store = get_default_job_store() + job, created = await asyncio.to_thread( + store.enqueue, + job_type="memory_ingest", + payload=payload, + idempotency_fields={ + "user_id": user_id, + "user_query": req.user_query, + "agent_response": req.agent_response or "", + "session_datetime": req.session_datetime, + "image_url": req.image_url, + "effort_level": req.effort_level, + }, + user_id=job_user_id, + timeout_seconds=float(settings.memory_ingest_timeout_seconds), + max_attempts=3, + ) + _schedule_job( + job, + lambda: _run_ingest_payload(payload, user_id), + ) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _job_accepted( + request, + job, + created, + f"/v2/memory/ingest/{job['job_id']}/status", + elapsed, + ) + + except Exception as exc: + elapsed = round((time.perf_counter() - start) * 1000, 2) + logger.exception("Ingest enqueue failed for user=%s", user_id) + return _error(request, str(exc), 500, elapsed) + def _safe_classifications(result: Dict[str, Any]) -> list: cr = result.get("classification_result") if cr and getattr(cr, "classifications", None): @@ -713,6 +829,38 @@ async def _read_user_job(job_id: str, user_id: str) -> Dict[str, Any] | None: return job +@v2_router.get( + "/ingest/{job_id}/status", + response_model=APIResponse, + summary="Poll an async memory ingest job", +) +async def ingest_job_status( + job_id: str, request: Request, user: dict = Depends(require_api_key) +): + start = time.perf_counter() + job = await _read_user_job(job_id, _current_user_id(user)) + if not job: + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error(request, "Job not found.", 404, elapsed) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _wrap(request, _job_status_data(job), elapsed) + + +@v2_router.get( + "/jobs/{job_id}/status", + response_model=APIResponse, + summary="Poll an async memory job", +) +async def memory_job_status( + job_id: str, request: Request, user: dict = Depends(require_api_key) +): + start = time.perf_counter() + job = await _read_user_job(job_id, _current_user_id(user)) + if not job: + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error(request, "Job not found.", 404, elapsed) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _wrap(request, _job_status_data(job), elapsed) # POST /v1/memory/batch-ingest @@ -721,7 +869,9 @@ async def _read_user_job(job_id: str, user_id: str) -> Dict[str, Any] | None: response_model=APIResponse, summary="Ingest multiple conversation turns into long-term memory sequentially", ) -async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: dict = Depends(require_api_key)): +async def batch_ingest_memory( + req: BatchIngestRequest, request: Request, user: dict = Depends(require_api_key) +): start = time.perf_counter() user_id = _current_user_id(user) @@ -745,6 +895,59 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d return _error(request, str(exc), 500, elapsed) +# POST /v2/memory/batch-ingest +@v2_router.post( + "/batch-ingest", + response_model=APIResponse, + summary="Start an async durable batch memory ingest job", +) +async def batch_ingest_memory_v2( + req: BatchIngestRequest, request: Request, user: dict = Depends(require_api_key) +): + start = time.perf_counter() + user_id = _current_user_id(user) + payload = req.model_dump() + payload["user_id"] = user_id + payload["items"] = [_scoped_ingest_payload(user, item) for item in req.items] + + try: + store = get_default_job_store() + job, created = await asyncio.to_thread( + store.enqueue, + job_type="memory_batch_ingest", + payload=payload, + idempotency_fields={ + "user_id": user_id, + "items": payload["items"], + }, + user_id=user_id, + timeout_seconds=max( + float(settings.memory_ingest_timeout_seconds), + min( + len(req.items) * float(settings.memory_ingest_timeout_seconds), + 3600.0, + ), + ), + max_attempts=3, + ) + _schedule_job( + job, + lambda: _run_batch_ingest_payload(payload), + ) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _job_accepted( + request, + job, + created, + f"/v2/memory/jobs/{job['job_id']}/status", + elapsed, + ) + + except Exception as exc: + elapsed = round((time.perf_counter() - start) * 1000, 2) + logger.exception("Batch ingest enqueue failed for user=%s", user_id) + return _error(request, str(exc), 500, elapsed) + # POST /v1/memory/retrieve @router.post( @@ -752,10 +955,12 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d response_model=APIResponse, summary="Retrieve an LLM-generated answer backed by stored memories", ) -async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = Depends(require_api_key)): +async def retrieve_memory( + req: RetrieveRequest, request: Request, user: dict = Depends(require_api_key) +): start = time.perf_counter() pipeline = get_retrieval_pipeline() - + # Get username from authenticated user user_id = _current_user_id(user, req.user_id) @@ -766,8 +971,10 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D answer=result.answer, sources=[ SourceRecord( - domain=s.domain, content=s.content, - score=round(s.score, 3), metadata=s.metadata, + domain=s.domain, + content=s.content, + score=round(s.score, 3), + metadata=s.metadata, ) for s in result.sources ], @@ -786,26 +993,91 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D @router.post( "/search", response_model=APIResponse, - summary="Raw semantic search across memory domains (no LLM answer)", + summary="Raw semantic search across memory domains with optional LLM answer", ) -async def search_memory(req: SearchRequest, request: Request, user: dict = Depends(require_api_key)): +async def search_memory( + req: SearchRequest, request: Request, user: dict = Depends(require_api_key) +): start = time.perf_counter() pipeline = get_retrieval_pipeline() - + # Get username from authenticated user user_id = _current_user_id(user, req.user_id) try: - all_results: List[SourceRecord] = [] + if "code" in req.domains and not req.org_id: + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error( + request, + "org_id is required when searching the code domain", + 400, + elapsed, + ) - if "profile" in req.domains: - all_results.extend(_search_profile(pipeline, user_id)) - if "temporal" in req.domains: - all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) - if "summary" in req.domains: - all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) + memory_domains = [domain for domain in req.domains if domain != "code"] + all_results = [] + latency: Dict[str, Any] = {} - data = SearchResponse(results=all_results, total=len(all_results)) + if memory_domains: + memory_results, raw_latency = await pipeline.raw_search( + query=req.query, + user_id=user_id, + domains=memory_domains, + top_k=req.top_k, + ) + all_results.extend(memory_results) + latency["raw"] = raw_latency + + if "code" in req.domains: + code_start = time.perf_counter() + try: + code_results = await _search_code( + query=req.query, + user_id=user_id, + org_id=req.org_id, + repo=req.repo or "", + top_k=req.top_k, + ) + all_results.extend(code_results) + except Exception: + logger.exception( + "Code search failed for user=%s org=%s repo=%s", + user_id, + req.org_id, + req.repo or "", + ) + latency["code"] = pipeline.record_latency( + "code", (time.perf_counter() - code_start) * 1000 + ) + + all_results.sort(key=lambda source: source.score, reverse=True) + all_results = all_results[: req.top_k] + answer = "" + confidence = 0.0 + mode = "raw" + + if req.answer: + answer_start = time.perf_counter() + try: + answer = await pipeline.synthesize_answer(req.query, all_results) + confidence = min(1.0, len(all_results) * 0.2) if all_results else 0.1 + mode = "answer" + except Exception: + logger.exception("Answer synthesis failed for user=%s", user_id) + latency["answer"] = pipeline.record_latency( + "answer", + (time.perf_counter() - answer_start) * 1000, + ) + + api_results = [_to_api_source(source) for source in all_results] + data = SearchResponse( + results=api_results, + total=len(api_results), + answer=answer, + confidence=confidence, + mode=mode, + latency=latency, + ) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, data, elapsed) @@ -815,59 +1087,42 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen return _error(request, str(exc), 500, elapsed) -def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRecord]: - try: - raw = pipeline.vector_store.search_by_metadata( - filters={"user_id": user_id, "domain": "profile"}, top_k=100, - ) - return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw] - except Exception as exc: - logger.warning("Profile search error: %s", exc) - return [] +def _to_api_source(source: Any) -> SourceRecord: + return SourceRecord( + domain=source.domain, + content=source.content, + score=round(float(source.score or 0.0), 3), + metadata=source.metadata, + ) -def _search_temporal(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]: - try: - events = pipeline.neo4j.search_events_by_embedding( - user_id=user_id, query_text=query, top_k=top_k, similarity_threshold=0.15, - ) - results = [] - for ev in events: - parts = [] - if ev.get("date"): - d = ev["date"] - if ev.get("year"): - d += f", {ev['year']}" - parts.append(f"Date: {d}") - if ev.get("event_name"): - parts.append(f"Event: {ev['event_name']}") - if ev.get("desc"): - parts.append(f"Description: {ev['desc']}") - if ev.get("time"): - parts.append(f"Time: {ev['time']}") - results.append(SourceRecord( - domain="temporal", content=" | ".join(parts), - score=ev.get("similarity_score", 0.0), metadata=ev, - )) - return results - except Exception as exc: - logger.warning("Temporal search error: %s", exc) - return [] - +async def _search_code( + query: str, + user_id: str, + org_id: str, + repo: str, + top_k: int, +) -> List[SourceRecord]: + code_pipeline = get_code_pipeline(org_id=org_id, repo=repo) + raw_records = await code_pipeline.raw_search( + query=query, + user_id=user_id, + repo=repo, + top_k=top_k, + ) -async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]: - try: - raw = await pipeline.vector_store.search_by_text( - query_text=query, top_k=top_k, - filters={"user_id": user_id, "domain": "summary"}, + results: List[SourceRecord] = [] + for source in raw_records: + metadata = {**source.metadata, "source_domain": source.domain} + results.append( + SourceRecord( + domain="code", + content=source.content, + score=source.score, + metadata=metadata, + ) ) - return [ - SourceRecord(domain="summary", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata}) - for r in raw - ] - except Exception as exc: - logger.warning("Summary search error: %s", exc) - return [] + return results # POST /v1/memory/scrape @@ -898,6 +1153,57 @@ async def scrape_chat_link(req: ScrapeRequest, request: Request): return _error(request, str(exc) or repr(exc), 500, elapsed) +# POST /v2/memory/scrape +@v2_scrape_router.post( + "/scrape", + response_model=APIResponse, + summary="Start an async durable scrape job", +) +async def scrape_chat_link_v2(req: ScrapeRequest, request: Request): + start = time.perf_counter() + payload = req.model_dump() + + try: + store = get_default_job_store() + job, created = await asyncio.to_thread( + store.enqueue, + job_type="memory_scrape", + payload=payload, + idempotency_fields={"url": req.url}, + user_id="anonymous", + timeout_seconds=60.0, + max_attempts=2, + ) + _schedule_job(job, lambda: _run_scrape_payload(payload)) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _job_accepted( + request, + job, + created, + f"/v2/memory/scrape/{job['job_id']}/status", + elapsed, + ) + + except Exception as exc: + elapsed = round((time.perf_counter() - start) * 1000, 2) + logger.exception("Scrape enqueue failed for url=%s", req.url) + return _error(request, str(exc) or repr(exc), 500, elapsed) + + +@v2_scrape_router.get( + "/scrape/{job_id}/status", + response_model=APIResponse, + summary="Poll an async scrape job", +) +async def scrape_job_status(job_id: str, request: Request): + start = time.perf_counter() + job = await asyncio.to_thread(get_default_job_store().get, job_id) + if not job or job.get("user_id") != "anonymous": + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _error(request, "Job not found.", 404, elapsed) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _wrap(request, _job_status_data(job), elapsed) + # POST /v1/memory/parse_transcript @scrape_router.post( @@ -907,36 +1213,40 @@ async def scrape_chat_link(req: ScrapeRequest, request: Request): ) async def parse_transcript( request: Request, - file: UploadFile = File(..., description="Chat transcript file (.txt, .md, .json)") + file: UploadFile = File(..., description="Chat transcript file (.txt, .md, .json)"), ): start = time.perf_counter() - + try: # Read file content content_bytes = await file.read() text = content_bytes.decode("utf-8", errors="ignore") - + if not text.strip(): return _error(request, "Uploaded file is empty.", 400) - + # Try to parse the transcript format_detected, pairs = _parse_transcript_text(text) - + # If no pairs found, try LLM fallback if not pairs: logger.info("Format detection failed, trying LLM fallback") pairs = await _parse_transcript_with_llm(text) - + if not pairs: - return _error(request, "Could not extract message pairs from the transcript.", 400) - + return _error( + request, "Could not extract message pairs from the transcript.", 400 + ) + data = ScrapeResponse(pairs=pairs) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, data, elapsed) - + except UnicodeDecodeError: elapsed = round((time.perf_counter() - start) * 1000, 2) - return _error(request, "Could not decode file. Please upload a text file.", 400, elapsed) + return _error( + request, "Could not decode file. Please upload a text file.", 400, elapsed + ) except Exception as exc: elapsed = round((time.perf_counter() - start) * 1000, 2) logger.exception("Transcript parsing failed for file=%s", file.filename) diff --git a/src/api/schemas.py b/src/api/schemas.py index ff10e6de..7e3c1159 100644 --- a/src/api/schemas.py +++ b/src/api/schemas.py @@ -7,7 +7,6 @@ from __future__ import annotations -from datetime import datetime from enum import Enum import re from typing import Any, Dict, List, Optional @@ -34,6 +33,7 @@ def normalize_user_id_field(cls, v: Any) -> str: # ── Shared envelope ──────────────────────────────────────────────────────── + class StatusEnum(str, Enum): OK = "ok" ERROR = "error" @@ -41,6 +41,7 @@ class StatusEnum(str, Enum): class APIResponse(BaseModel): """Standard wrapper returned by every endpoint.""" + status: StatusEnum = StatusEnum.OK request_id: Optional[str] = None data: Optional[Any] = None @@ -50,6 +51,7 @@ class APIResponse(BaseModel): # ── Health ───────────────────────────────────────────────────────────────── + class HealthResponse(BaseModel): status: str pipelines_ready: bool @@ -62,16 +64,22 @@ class HealthResponse(BaseModel): class IngestRequest(UserScopedModel): """Store a new memory from a conversation turn.""" + user_query: str = Field( - ..., min_length=1, max_length=10_000, + ..., + min_length=1, + max_length=10_000, description="The user's message to memorize", ) agent_response: str = Field( - default="", max_length=10_000, + default="", + max_length=10_000, description="The assistant's reply (used for summary extraction)", ) user_id: str = Field( - ..., min_length=1, max_length=256, + ..., + min_length=1, + max_length=256, description="User identifier. Friendly names are normalized internally.", ) session_datetime: str = Field( @@ -79,7 +87,8 @@ class IngestRequest(UserScopedModel): description="ISO-8601 datetime context for temporal event extraction", ) image_url: str = Field( - default="", max_length=50_000, + default="", + max_length=50_000, description="URL or base64 data-URI of an attached image", ) effort_level: str = Field( @@ -121,27 +130,36 @@ class IngestResponse(BaseModel): class BatchIngestRequest(BaseModel): """Store multiple new memories in a single batch.""" + items: List[IngestRequest] = Field( - ..., min_length=1, max_length=100, - description="List of conversation turns to ingest" + ..., + min_length=1, + max_length=100, + description="List of conversation turns to ingest", ) + class BatchIngestResponse(BaseModel): """Response for a batch ingest operation.""" - results: List[IngestResponse] = Field(default_factory=list) + results: List[IngestResponse] = Field(default_factory=list) # ── Retrieve (answer a question from memory) ────────────────────────────── class RetrieveRequest(UserScopedModel): """Ask a question answered from stored memories.""" + query: str = Field( - ..., min_length=1, max_length=5_000, + ..., + min_length=1, + max_length=5_000, description="The question to answer from memory", ) user_id: str = Field( - ..., min_length=1, max_length=256, + ..., + min_length=1, + max_length=256, ) top_k: int = Field(default=5, ge=1, le=50) @@ -167,23 +185,43 @@ class RetrieveResponse(BaseModel): # ── Search (raw vector / graph search without LLM answer) ───────────────── class SearchRequest(UserScopedModel): - """Raw semantic search across memory domains.""" + """Raw semantic search across memory domains, optionally with answer synthesis.""" + query: str = Field( - ..., min_length=1, max_length=5_000, + ..., + min_length=1, + max_length=5_000, ) user_id: str = Field( - ..., min_length=1, max_length=256, + ..., + min_length=1, + max_length=256, ) domains: List[str] = Field( - default=["profile", "temporal", "summary"], + default=["profile", "temporal", "summary", "snippet"], description="Which memory domains to search", ) top_k: int = Field(default=10, ge=1, le=100) + answer: bool = Field( + default=False, + description="When true, synthesize an LLM answer from raw search hits.", + ) + org_id: Optional[str] = Field( + default=None, + min_length=1, + max_length=256, + description="Organization id required when the code domain is requested.", + ) + repo: Optional[str] = Field( + default=None, + max_length=256, + description="Optional repository filter for code search.", + ) @field_validator("domains") @classmethod def validate_domains(cls, v: List[str]) -> List[str]: - allowed = {"profile", "temporal", "summary"} + allowed = {"profile", "temporal", "summary", "snippet", "code"} for d in v: if d not in allowed: raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}") @@ -193,21 +231,31 @@ def validate_domains(cls, v: List[str]) -> List[str]: class SearchResponse(BaseModel): results: List[SourceRecord] = Field(default_factory=list) total: int = 0 + answer: str = "" + confidence: float = 0.0 + mode: str = "raw" + latency: Dict[str, Any] = Field(default_factory=dict) # ── Scrape (extract from shared chat links) ──────────────────────────────── + class ScrapeRequest(BaseModel): """Request to scrape a shared AI chat link.""" + url: str = Field( - ..., min_length=1, max_length=2000, - description="Public share link (ChatGPT, Claude, Gemini)" + ..., + min_length=1, + max_length=2000, + description="Public share link (ChatGPT, Claude, Gemini)", ) + class MessagePair(BaseModel): user_query: str agent_response: str + class ScrapeResponse(BaseModel): pairs: List[MessagePair] = Field(default_factory=list) error: Optional[str] = None @@ -217,6 +265,7 @@ class ScrapeResponse(BaseModel): class CodeQueryRequest(UserScopedModel): """Query a codebase via the code retrieval pipeline.""" + org_id: str = Field(..., min_length=1, max_length=256) repo: str = Field(..., min_length=1, max_length=256) query: str = Field(..., min_length=1, max_length=5_000) @@ -237,6 +286,7 @@ class CodeQueryResponse(BaseModel): class ExecuteToolRequest(UserScopedModel): """Execute a specific raw code retrieval tool natively.""" + org_id: str = Field(..., min_length=1, max_length=256) repo: str = Field(..., min_length=1, max_length=256) tool_name: str = Field(..., min_length=1, max_length=128) diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 2187e328..c67d432e 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -397,6 +397,55 @@ def _rrf_fuse( return fused +def _source_record_key(record: SourceRecord) -> tuple[str, str, str, str]: + """Return a stable identity for deduping source records across native tools.""" + metadata = record.metadata + repo = str(metadata.get("repo", "")) + qualified_name = metadata.get("qualified_name") + if qualified_name: + return ( + "symbol", + repo, + str(qualified_name), + str(metadata.get("file_path", "")), + ) + + file_path = metadata.get("file_path") + if file_path: + return ("file", repo, str(file_path), "") + + return (record.domain, repo, record.content, "") + + +def _fuse_source_records( + ranked_lists: List[List[SourceRecord]], + limit: int, + k: int = 60, +) -> List[SourceRecord]: + """Fuse native SourceRecord lists while deduping and enforcing a hard limit.""" + fused_scores: Dict[tuple[str, str, str, str], float] = {} + best_records: Dict[tuple[str, str, str, str], SourceRecord] = {} + first_positions: Dict[tuple[str, str, str, str], tuple[int, int]] = {} + + for list_index, ranked_list in enumerate(ranked_lists): + for rank_pos, record in enumerate(ranked_list, start=1): + record_key = _source_record_key(record) + fused_scores[record_key] = ( + fused_scores.get(record_key, 0.0) + 1.0 / (k + rank_pos) + ) + first_positions.setdefault(record_key, (rank_pos, list_index)) + current = best_records.get(record_key) + if current is None or record.score > current.score: + best_records[record_key] = record + + ranked_keys = sorted( + fused_scores, + key=lambda record_key: (-fused_scores[record_key], first_positions[record_key]), + ) + + return [best_records[record_key] for record_key in ranked_keys[:limit]] + + # ═══════════════════════════════════════════════════════════════════════════ # Deterministic fast-path detection # ═══════════════════════════════════════════════════════════════════════════ @@ -491,6 +540,42 @@ def __init__( # Public entry point # ------------------------------------------------------------------ + async def raw_search( + self, + query: str, + user_id: str = "", + repo: str = "", + top_k: int = 10, + ) -> List[SourceRecord]: + """Return direct code search hits without LLM answer synthesis.""" + raw_tool_results = await asyncio.gather( + self._execute_tool( + tool_name="search_symbols", + tool_args={"query": query, "repo": repo}, + repo=repo, + top_k=top_k, + user_id=user_id, + ), + self._execute_tool( + tool_name="search_files", + tool_args={"query": query, "repo": repo}, + repo=repo, + top_k=top_k, + user_id=user_id, + ), + return_exceptions=True, + ) + + tool_results = [ + result + for result in raw_tool_results + if not isinstance(result, BaseException) + ] + if not tool_results: + return [] + + return _fuse_source_records(tool_results, limit=top_k) + async def run( self, query: str, diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 8f640b22..f912ce7f 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -22,13 +22,16 @@ import asyncio import logging -from typing import Any, Callable, Dict, List, Optional +import time +from collections import OrderedDict, deque +from typing import Any, Callable, Dict, List, Optional, Tuple from dotenv import load_dotenv from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from pydantic import BaseModel, Field +from src.config.metrics import METRICS from src.config import settings from src.graph.neo4j_client import Neo4jClient from src.prompts.retrieval import ANSWER_PROMPT, build_system_prompt @@ -41,11 +44,51 @@ logger = logging.getLogger("xmem.pipelines.retrieval") +DEFAULT_RAW_SEARCH_DOMAINS = ("profile", "temporal", "summary", "snippet") +PROFILE_CATALOG_CACHE_TTL_SECONDS = 60.0 +PROFILE_CATALOG_CACHE_MAX = 256 +RETRIEVAL_PLAN_CACHE_TTL_SECONDS = 60.0 +RETRIEVAL_PLAN_CACHE_MAX = 256 + + +class _LatencyTracker: + """In-process latency percentiles for retrieval modes.""" + + def __init__(self, max_samples: int = 512) -> None: + self._max_samples = max_samples + self._samples: Dict[str, deque[float]] = {} + + def record(self, mode: str, elapsed_ms: float) -> Dict[str, Any]: + samples = self._samples.setdefault(mode, deque(maxlen=self._max_samples)) + samples.append(elapsed_ms) + stats = self.snapshot(mode) + stats["current_ms"] = round(elapsed_ms, 2) + return stats + + def snapshot(self, mode: str) -> Dict[str, Any]: + samples = list(self._samples.get(mode, ())) + return { + "mode": mode, + "samples": len(samples), + "p50_ms": self._percentile(samples, 50), + "p95_ms": self._percentile(samples, 95), + "p99_ms": self._percentile(samples, 99), + } + + @staticmethod + def _percentile(samples: List[float], percentile: int) -> float: + if not samples: + return 0.0 + ordered = sorted(samples) + index = round((len(ordered) - 1) * (percentile / 100)) + return round(ordered[index], 2) + # ═══════════════════════════════════════════════════════════════════════════ # Tool schemas — These are the "function signatures" exposed to the LLM # ═══════════════════════════════════════════════════════════════════════════ + class SearchProfile(BaseModel): """Look up user profile facts by topic. Use when the question asks about a specific attribute like job, name, hobby, food preference, etc. @@ -59,21 +102,28 @@ class SearchTemporal(BaseModel): """Search for date-based events like appointments, birthdays, milestones. Use when the question involves 'when', dates, schedules, or events.""" - query: str = Field(description="Short search query describing the event, e.g. 'dentist appointment'") + query: str = Field( + description="Short search query describing the event, e.g. 'dentist appointment'" + ) class SearchSummary(BaseModel): """Search general conversation summaries for broad context. Use as a fallback for questions that don't fit profile or temporal domains.""" - query: str = Field(description="Short search query, e.g. 'what does the user enjoy'") + query: str = Field( + description="Short search query, e.g. 'what does the user enjoy'" + ) class SearchSnippet(BaseModel): """Search for personal code snippets previously saved by the user. - Use when the question asks about a specific piece of code, script, or technical configuration the user wrote.""" + Use when the question asks about a specific piece of code, script, or technical configuration the user wrote. + """ - query: str = Field(description="Short search query, e.g. 'python database connection script'") + query: str = Field( + description="Short search query, e.g. 'python database connection script'" + ) TOOLS = [SearchProfile, SearchTemporal, SearchSummary, SearchSnippet] @@ -83,8 +133,10 @@ class SearchSnippet(BaseModel): # Embedding helper (reuses the cached model from ingest) # ═══════════════════════════════════════════════════════════════════════════ + def _get_embed_fn() -> Callable[[str], List[float]]: from src.pipelines.ingest import embed_text + return embed_text @@ -92,6 +144,7 @@ def _get_embed_fn() -> Callable[[str], List[float]]: # RetrievalPipeline # ═══════════════════════════════════════════════════════════════════════════ + class RetrievalPipeline: """Two-step agentic retrieval: tool-call → fetch → answer.""" @@ -104,6 +157,7 @@ def __init__( # ── LLM ─────────────────────────────────────────────────────── if model is None: from src.models import get_model + override = settings.retrieval_model self.model = get_model(model_name=override) if override else get_model() else: @@ -133,6 +187,15 @@ def __init__( self.embed_fn = embed_fn self._snippet_stores: Dict[str, BaseVectorStore] = {} + self._cached_profile_records: List[Any] = [] + self._profile_catalog_cache: OrderedDict[ + str, Tuple[float, List[Dict[str, str]], List[Any]] + ] = OrderedDict() + self._retrieval_plan_cache: OrderedDict[ + Tuple[Any, ...], Tuple[float, List[Dict[str, Any]]] + ] = OrderedDict() + self._user_memory_versions: Dict[str, int] = {} + self._latency_tracker = _LatencyTracker() logger.info("RetrievalPipeline initialized") @@ -147,6 +210,7 @@ async def run( top_k: int = 5, ) -> RetrievalResult: """Run the two-step retrieval pipeline.""" + start = time.perf_counter() logger.info("=" * 60) logger.info("RETRIEVAL PIPELINE START") @@ -169,14 +233,28 @@ async def run( HumanMessage(content=query), ] - ai_response: AIMessage = await self.model_with_tools.ainvoke(messages) - logger.info("LLM response received (tool_calls=%d)", len(ai_response.tool_calls or [])) + plan_cache_key = self._retrieval_plan_cache_key( + user_id=user_id, + query=query, + profile_catalog=profile_catalog, + ) + cached_tool_calls = self._get_cached_retrieval_plan(plan_cache_key) + if cached_tool_calls is not None: + ai_response: AIMessage = AIMessage(content="") + tool_calls = cached_tool_calls + logger.info("Using cached retrieval plan (tool_calls=%d)", len(tool_calls)) + else: + ai_response = await self.model_with_tools.ainvoke(messages) + tool_calls = ai_response.tool_calls or [] + if tool_calls: + self._remember_retrieval_plan(plan_cache_key, tool_calls) + logger.info("LLM response received (tool_calls=%d)", len(tool_calls)) # ── Step 2: Execute tool calls ──────────────────────────────── sources: List[SourceRecord] = [] tool_messages: List[ToolMessage] = [] - if ai_response.tool_calls: + if tool_calls: called_tools = set() async def _process_tool_call(tc): @@ -185,11 +263,16 @@ async def _process_tool_call(tc): tool_id = tc["id"] logger.info(" Tool call: %s(%s)", tool_name, tool_args) records = await self._execute_tool( - tool_name, tool_args, user_id, top_k, + tool_name, + tool_args, + user_id, + top_k, ) return tool_name, tool_args, tool_id, records - tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in ai_response.tool_calls]) + tool_results = await asyncio.gather( + *[_process_tool_call(tc) for tc in tool_calls] + ) for tool_name, tool_args, tool_id, records in tool_results: sources.extend(records) @@ -206,7 +289,9 @@ async def _process_tool_call(tc): if "searchsummary" not in called_tools: logger.info(" Auto-adding summary context (top_k=5)") extra = await self._search_summary( - query=query, user_id=user_id, top_k=20, + query=query, + user_id=user_id, + top_k=20, ) if extra: sources.extend(extra) @@ -215,7 +300,7 @@ async def _process_tool_call(tc): tool_messages.append( ToolMessage( content=f"[Auto-fetched summary context]\n{extra_text}", - tool_call_id=ai_response.tool_calls[-1]["id"], + tool_call_id=tool_calls[-1]["id"], ) ) @@ -223,37 +308,23 @@ async def _process_tool_call(tc): # Only send the retrieved context + user query to the LLM. # No need for the system prompt, tool schemas, or tool-call history. context_text = "\n".join(tm.content for tm in tool_messages) - answer_prompt = ANSWER_PROMPT.format( - context=context_text, - query=query, - ) - - final_response = await self.model.ainvoke( - [HumanMessage(content=answer_prompt)] - ) - answer = final_response.content + answer = await self._generate_answer(query=query, context_text=context_text) else: # No tool calls — LLM answered directly (shouldn't happen often) answer = ai_response.content logger.info("LLM answered without tool calls") - if isinstance(answer, list): - parts = [] - for c in answer: - if isinstance(c, dict) and "text" in c: - parts.append(c["text"]) - elif isinstance(c, str): - parts.append(c) - else: - parts.append(str(c)) - answer = "\n".join(parts) + answer = self._coerce_answer_text(answer) confidence = min(1.0, len(sources) * 0.2) if sources else 0.1 + self.record_latency("answer", (time.perf_counter() - start) * 1000) logger.info("=" * 60) logger.info("RETRIEVAL PIPELINE COMPLETE") logger.info(" sources: %d", len(sources)) - logger.info(" answer: %s", answer[:100] + "..." if len(answer) > 100 else answer) + logger.info( + " answer: %s", answer[:100] + "..." if len(answer) > 100 else answer + ) logger.info("=" * 60) return RetrievalResult( @@ -263,6 +334,122 @@ async def _process_tool_call(tc): confidence=confidence, ) + async def raw_search( + self, + query: str, + user_id: str, + domains: Optional[List[str]] = None, + top_k: int = 10, + ) -> Tuple[List[SourceRecord], Dict[str, Any]]: + """Return ranked raw memory hits without LLM tool selection or synthesis.""" + start = time.perf_counter() + selected_domains = domains or list(DEFAULT_RAW_SEARCH_DOMAINS) + selected = [ + domain + for domain in selected_domains + if domain in DEFAULT_RAW_SEARCH_DOMAINS + ] + + results: List[SourceRecord] = [] + if "profile" in selected: + results.extend(self._raw_profile_records(user_id=user_id, top_k=top_k)) + + async def _safe_search(domain: str, search_coro) -> List[SourceRecord]: + try: + return await search_coro + except Exception as exc: + logger.warning("Raw %s search failed: %s", domain, exc) + return [] + + tasks = [] + if "temporal" in selected: + tasks.append( + _safe_search("temporal", self._search_temporal(query, user_id, top_k)) + ) + if "summary" in selected: + tasks.append( + _safe_search("summary", self._search_summary(query, user_id, top_k)) + ) + if "snippet" in selected: + tasks.append( + _safe_search("snippet", self._search_snippet(query, user_id, top_k)) + ) + + if tasks: + for records in await asyncio.gather(*tasks): + results.extend(records) + + results.sort(key=lambda record: record.score, reverse=True) + results = results[:top_k] + latency = self.record_latency("raw", (time.perf_counter() - start) * 1000) + logger.info( + "Raw retrieval complete (domains=%s results=%d current=%sms)", + ",".join(selected), + len(results), + latency["current_ms"], + ) + return results, latency + + async def synthesize_answer(self, query: str, sources: List[SourceRecord]) -> str: + """Generate an answer from already-retrieved raw sources.""" + context_text = self._format_tool_results(sources) + return await self._generate_answer(query=query, context_text=context_text) + + def record_latency(self, mode: str, elapsed_ms: float) -> Dict[str, Any]: + stats = self._latency_tracker.record(mode, elapsed_ms) + try: + METRICS.pipeline_stage_duration.labels( + pipeline="retrieval", + stage=mode, + ).observe(elapsed_ms / 1000) + except Exception: + logger.debug("Failed to observe retrieval latency metric", exc_info=True) + logger.info( + "Retrieval latency mode=%s current=%sms p50=%sms p95=%sms p99=%sms", + mode, + stats["current_ms"], + stats["p50_ms"], + stats["p95_ms"], + stats["p99_ms"], + ) + return {"mode": mode, "current_ms": stats["current_ms"]} + + def invalidate_user_cache(self, user_id: str) -> None: + """Drop cached retrieval state for a user after memory writes.""" + self._user_memory_versions[user_id] = ( + self._user_memory_versions.get(user_id, 0) + 1 + ) + self._profile_catalog_cache.pop(user_id, None) + stale_plan_keys = [ + cache_key + for cache_key in self._retrieval_plan_cache + if cache_key and cache_key[0] == user_id + ] + for cache_key in stale_plan_keys: + self._retrieval_plan_cache.pop(cache_key, None) + + async def _generate_answer(self, query: str, context_text: str) -> str: + answer_prompt = ANSWER_PROMPT.format( + context=context_text, + query=query, + ) + final_response = await self.model.ainvoke([HumanMessage(content=answer_prompt)]) + return self._coerce_answer_text(final_response.content) + + @staticmethod + def _coerce_answer_text(answer: Any) -> str: + if isinstance(answer, list): + parts = [] + for c in answer: + if isinstance(c, dict) and "text" in c: + parts.append(c["text"]) + elif isinstance(c, str): + parts.append(c) + else: + parts.append(str(c)) + return "\n".join(parts) + return str(answer) + # ------------------------------------------------------------------ # Tool execution # ------------------------------------------------------------------ @@ -332,17 +519,19 @@ def _search_profile( parts = main_content.split("_", 1) sub_topic = parts[1] if len(parts) == 2 else "" - records.append(SourceRecord( - domain="profile", - content=r.content, - score=r.score, - metadata={ - "id": r.id, - "topic": topic, - "sub_topic": sub_topic, - **r.metadata, - }, - )) + records.append( + SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={ + "id": r.id, + "topic": topic, + "sub_topic": sub_topic, + **r.metadata, + }, + ) + ) logger.info(" → Profile [%s]: %d results", topic, len(records)) return records @@ -368,7 +557,7 @@ async def _search_temporal( query_text=query, top_k=top_k, similarity_threshold=0.15, - ) + ), ) records = [] @@ -395,12 +584,14 @@ async def _search_temporal( content = " | ".join(content_parts) - records.append(SourceRecord( - domain="temporal", - content=content, - score=ev.get("similarity_score", 0.0), - metadata=ev, - )) + records.append( + SourceRecord( + domain="temporal", + content=content, + score=ev.get("similarity_score", 0.0), + metadata=ev, + ) + ) logger.info(" → Temporal [%s]: %d results", query, len(records)) return records @@ -426,12 +617,14 @@ async def _search_summary( records = [] for r in results: - records.append(SourceRecord( - domain="summary", - content=r.content, - score=r.score, - metadata={"id": r.id, **r.metadata}, - )) + records.append( + SourceRecord( + domain="summary", + content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + ) logger.info(" → Summary [%s]: %d results", query, len(records)) return records @@ -455,7 +648,7 @@ async def _search_snippet( ) -> List[SourceRecord]: """Semantic search over user code snippets (sandboxed namespace).""" store = self._get_snippet_store(user_id) - + # In the sandboxed namespace, we can just search. We pass domain filter just in case. results = await store.search_by_text( query_text=query, @@ -472,12 +665,14 @@ async def _search_snippet( lang = r.metadata.get("language", "") content += f"\n```{lang}\n{snippet}\n```" - records.append(SourceRecord( - domain="snippet", - content=content, - score=r.score, - metadata={"id": r.id, **r.metadata}, - )) + records.append( + SourceRecord( + domain="snippet", + content=content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + ) logger.info(" → Snippet [%s]: %d results", query, len(records)) return records @@ -486,6 +681,89 @@ async def _search_snippet( # Profile catalog (tells the LLM what profile keys exist) # ------------------------------------------------------------------ + def _raw_profile_records(self, user_id: str, top_k: int) -> List[SourceRecord]: + _, raw_records = self._fetch_profile_catalog(user_id) + ranked_records = sorted( + raw_records, key=lambda record: record.score, reverse=True + ) + return [ + SourceRecord( + domain="profile", + content=r.content, + score=r.score, + metadata={"id": r.id, **r.metadata}, + ) + for r in ranked_records[:top_k] + ] + + def _retrieval_plan_cache_key( + self, + user_id: str, + query: str, + profile_catalog: List[Dict[str, str]], + ) -> Tuple[Any, ...]: + catalog_key = tuple( + sorted( + (entry.get("topic", ""), entry.get("sub_topic", "")) + for entry in profile_catalog + ) + ) + return ( + user_id, + query.strip().lower(), + self._user_memory_versions.get(user_id, 0), + catalog_key, + ) + + def _get_cached_retrieval_plan( + self, + cache_key: Tuple[Any, ...], + ) -> Optional[List[Dict[str, Any]]]: + cached = self._retrieval_plan_cache.get(cache_key) + if not cached: + return None + + cached_at, tool_calls = cached + if time.monotonic() - cached_at > RETRIEVAL_PLAN_CACHE_TTL_SECONDS: + self._retrieval_plan_cache.pop(cache_key, None) + return None + + self._retrieval_plan_cache.move_to_end(cache_key) + return [ + { + "name": call["name"], + "args": dict(call.get("args", {})), + "id": f"cached-plan-{idx}", + } + for idx, call in enumerate(tool_calls) + ] + + def _remember_retrieval_plan( + self, + cache_key: Tuple[Any, ...], + tool_calls: List[Dict[str, Any]], + ) -> None: + clean_calls = [ + { + "name": str(call.get("name", "")), + "args": dict(call.get("args", {})), + } + for call in tool_calls + if call.get("name") + ] + if not clean_calls: + return + + if cache_key in self._retrieval_plan_cache: + self._retrieval_plan_cache.pop(cache_key, None) + elif len(self._retrieval_plan_cache) >= RETRIEVAL_PLAN_CACHE_MAX: + self._retrieval_plan_cache.popitem(last=False) + + self._retrieval_plan_cache[cache_key] = ( + time.monotonic(), + clean_calls, + ) + def _fetch_profile_catalog(self, user_id: str): """Fetch all profile entries for a user. @@ -494,6 +772,12 @@ def _fetch_profile_catalog(self, user_id: str): catalog — list of {topic, sub_topic} for the prompt raw_results — the full SearchResult list, cached for _search_profile """ + now = time.monotonic() + cached = self._profile_catalog_cache.get(user_id) + if cached and now - cached[0] <= PROFILE_CATALOG_CACHE_TTL_SECONDS: + self._profile_catalog_cache.move_to_end(user_id) + return cached[1], cached[2] + try: results = self.vector_store.search_by_metadata( filters={"user_id": user_id, "domain": "profile"}, @@ -514,16 +798,24 @@ def _fetch_profile_catalog(self, user_id: str): parts = main_content.split("_", 1) if len(parts) == 2: - catalog.append({ - "topic": parts[0], - "sub_topic": parts[1], - }) + catalog.append( + { + "topic": parts[0], + "sub_topic": parts[1], + } + ) else: - catalog.append({ - "topic": main_content, - "sub_topic": "", - }) + catalog.append( + { + "topic": main_content, + "sub_topic": "", + } + ) + self._profile_catalog_cache[user_id] = (now, catalog, results) + self._profile_catalog_cache.move_to_end(user_id) + if len(self._profile_catalog_cache) > PROFILE_CATALOG_CACHE_MAX: + self._profile_catalog_cache.popitem(last=False) return catalog, results def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: