diff --git a/apps/memos-local-plugin/core/capture/capture.ts b/apps/memos-local-plugin/core/capture/capture.ts index 9d52f749e..74d632218 100644 --- a/apps/memos-local-plugin/core/capture/capture.ts +++ b/apps/memos-local-plugin/core/capture/capture.ts @@ -388,41 +388,72 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { const normalized = normalizeSteps(rawAll, deps.cfg); const normalizeMs = now() - normStart; - // Pair each normalized step with its already-persisted trace row - // (matched by ts). If runLite was skipped for any step, fall back - // to a fresh insert path so we don't lose data. - const existing = deps.tracesRepo.list({ episodeId: input.episode.id }); - const traceByTs = new Map(); - for (const tr of existing) traceByTs.set(tr.ts, tr); - const orphan = normalized.filter((s) => !traceByTs.has(s.ts)); - if (orphan.length > 0) { + // Pair each normalized step with its already-persisted trace row. + // Matching by timestamp alone is not stable during startup recovery: + // recovered snapshots are rebuilt from trace rows and the extractor may + // shift duplicate tool timestamps. Use content signatures first, and in + // recovered replay allow a timing-insensitive tool signature fallback. + const existing = deps.tracesRepo.listAllForEpisode(input.episode.id); + const recoveredReplay = isRecoveredReplay(input.episode); + const matcher = createTraceMatcher(existing, { allowRelaxedToolTiming: recoveredReplay }); + const matchedRows = normalized.map((s) => matcher.take(s)); + const orphanEntries = normalized + .map((step, index) => ({ step, index })) + .filter(({ index }) => matchedRows[index] === null); + if (orphanEntries.length > 0) { log.warn("reflect.orphan_steps", { episodeId: input.episode.id, - count: orphan.length, - action: "fallback_insert", + count: orphanEntries.length, + action: recoveredReplay ? "skip_recovery_insert" : "fallback_insert", }); + const maxRecoveryOrphans = Math.max(0, Math.floor(deps.cfg.maxRecoveryOrphanInserts)); + const insertableEntries = recoveredReplay + ? orphanEntries.slice(0, maxRecoveryOrphans) + : orphanEntries; + const skipped = orphanEntries.length - insertableEntries.length; + if (skipped > 0) { + warnings.push({ + stage: "persist", + message: "skipped recovered orphan trace inserts to avoid replay duplicates", + detail: { + episodeId: input.episode.id, + skipped, + maxRecoveryOrphanInserts: maxRecoveryOrphans, + }, + }); + } // These steps never went through runLite (likely a test path or a // dropped event). Insert them now with reflection=null so the // batch pass below can patch them like the rest. - const summStart = now(); - const { summaries } = await runSummarize( - orphan.map((s) => ({ + if (insertableEntries.length > 0) { + const orphan = insertableEntries.map(({ step }) => step); + const summStart = now(); + const { summaries } = recoveredReplay + ? { + summaries: orphan.map(heuristicTraceSummary), + } + : await runSummarize( + orphan.map((s) => ({ + ...s, + reflection: { text: null, alpha: 0, usable: false, source: "none" }, + })), + summStart, + llmCalls, + warnings, + { episodeId: input.episode.id, phase: "reflect" }, + ); + const orphanScored: ScoredStep[] = orphan.map((s) => ({ ...s, reflection: { text: null, alpha: 0, usable: false, source: "none" }, - })), - summStart, - llmCalls, - warnings, - { episodeId: input.episode.id, phase: "reflect" }, - ); - const orphanScored: ScoredStep[] = orphan.map((s) => ({ - ...s, - reflection: { text: null, alpha: 0, usable: false, source: "none" }, - })); - const { vecs } = await runEmbed(orphanScored, summaries, warnings); - const orphanRows = buildRows(orphanScored, summaries, vecs, input.episode); - await persistRows(orphanRows, input, warnings); - for (const r of orphanRows) traceByTs.set(r.ts, r); + })); + const { vecs } = await runEmbed(orphanScored, summaries, warnings); + const orphanRows = buildRows(orphanScored, summaries, vecs, input.episode); + await persistRows(orphanRows, input, warnings); + orphanRows.forEach((row, i) => { + const entry = insertableEntries[i]; + if (entry) matchedRows[entry.index] = row; + }); + } } if (normalized.length === 0) { @@ -461,8 +492,18 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { taskSummary: taskSummary ? taskSummary.slice(0, 240) : null, }); let scored: ScoredStep[] = []; + const reflectBudget = createReflectLlmBudget(deps.cfg.maxReflectLlmCalls, warnings, input.episode.id); if (useBatch) { - scored = await runBatchScoring(normalized, rLlm!, deps, warnings, llmCalls, input.episode.id, taskSummary); + scored = await runBatchScoring( + normalized, + rLlm!, + deps, + warnings, + llmCalls, + input.episode.id, + taskSummary, + reflectBudget, + ); } if (!useBatch || scored.length === 0) { scored = await runPerStepScoring( @@ -473,6 +514,7 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { llmCalls, input.episode.id, buildReflectionContexts(normalized, taskSummary, downstreamByStep), + reflectBudget, ); } const reflectMs = now() - reflectStart; @@ -482,12 +524,13 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { // orphan-fallback above) are skipped with a warning. const persistStart = now(); const patchedTraceIds: string[] = []; - for (const s of scored) { - const row = traceByTs.get(s.ts); + for (let i = 0; i < scored.length; i++) { + const s = scored[i]!; + const row = matchedRows[i]; if (!row) { warnings.push({ stage: "persist", - message: "reflect: no trace row for step ts; skipping", + message: "reflect: no trace row for step signature; skipping", detail: { ts: s.ts, key: s.key }, }); continue; @@ -527,8 +570,8 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { // assignment). For reflect-phase rows we re-emit ScoredStep-shaped // candidates carrying the freshly computed reflection + α; the // already-existing trace ids come from the matched DB rows. - const traces: TraceCandidate[] = scored.map((s) => { - const row = traceByTs.get(s.ts); + const traces: TraceCandidate[] = scored.map((s, i) => { + const row = matchedRows[i]; return { ...s, traceId: (row?.id ?? "") as TraceId, @@ -868,6 +911,65 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { return out; } + function createTraceMatcher( + rows: TraceRow[], + opts: { allowRelaxedToolTiming: boolean }, + ): { take(step: StepCandidate): TraceRow | null } { + const exact = indexRows(rows, traceIdentitySignature); + const relaxed = opts.allowRelaxedToolTiming + ? indexRows(rows, traceRelaxedIdentitySignature) + : new Map(); + const used = new Set(); + + function takeFrom(index: Map, signature: string): TraceRow | null { + const candidates = index.get(signature) ?? []; + for (const row of candidates) { + if (used.has(row.id)) continue; + used.add(row.id); + return row; + } + return null; + } + + return { + take(step) { + const exactMatch = takeFrom(exact, stepIdentitySignature(step)); + if (exactMatch) return exactMatch; + if (!opts.allowRelaxedToolTiming) return null; + return takeFrom(relaxed, stepRelaxedIdentitySignature(step)); + }, + }; + } + + function indexRows( + rows: TraceRow[], + signatureOf: (row: TraceRow) => string, + ): Map { + const out = new Map(); + for (const row of rows) { + const signature = signatureOf(row); + const bucket = out.get(signature); + if (bucket) bucket.push(row); + else out.set(signature, [row]); + } + return out; + } + + function isRecoveredReplay(episode: CaptureInput["episode"]): boolean { + const meta = episode.meta ?? {}; + return Boolean(meta.recoveredAtStartup) || typeof meta.recoveryReason === "string"; + } + + function heuristicTraceSummary(step: NormalizedStep): string { + const tool = step.toolCalls[0]; + const base = firstNonEmpty([ + step.userText, + step.agentText, + tool ? `Tool ${tool.name}` : "", + ]) || "(empty turn)"; + return base.replace(/\s+/g, " ").trim().slice(0, 140); + } + function stepIdentitySignature(step: StepCandidate): string { const tool = step.toolCalls[0]; const turnId = pickTurnId(step.meta, step.ts); @@ -891,6 +993,25 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { return ["user", turnId, step.ts, step.userText.trim()].join("\x1f"); } + function stepRelaxedIdentitySignature(step: StepCandidate): string { + const tool = step.toolCalls[0]; + const turnId = pickTurnId(step.meta, step.ts); + if (tool) { + return [ + "tool", + turnId, + tool.name, + stableJson(tool.input), + stableJson(tool.output), + tool.errorCode ?? "", + ].join("\x1f"); + } + if (step.agentText.trim()) { + return ["assistant", turnId, step.agentText.trim()].join("\x1f"); + } + return ["user", turnId, step.userText.trim()].join("\x1f"); + } + function traceIdentitySignature(row: TraceRow): string { const tool = row.toolCalls[0]; if (tool) { @@ -913,6 +1034,24 @@ export function createCaptureRunner(deps: CaptureDeps): CaptureRunner { return ["user", row.turnId, row.ts, row.userText.trim()].join("\x1f"); } + function traceRelaxedIdentitySignature(row: TraceRow): string { + const tool = row.toolCalls[0]; + if (tool) { + return [ + "tool", + row.turnId, + tool.name, + stableJson(tool.input), + stableJson(tool.output), + tool.errorCode ?? "", + ].join("\x1f"); + } + if (row.agentText.trim()) { + return ["assistant", row.turnId, row.agentText.trim()].join("\x1f"); + } + return ["user", row.turnId, row.userText.trim()].join("\x1f"); + } + function stableJson(value: unknown): string { if (value === undefined) return ""; return JSON.stringify(sortJson(value)); @@ -1070,12 +1209,15 @@ async function runBatchScoring( llmCalls: { reflectionSynth: number; alphaScoring: number; batchedReflection: number }, episodeId: string, taskSummary: string | null, + budget: ReflectLlmBudget, ): Promise { const inputs: BatchScoreInput[] = normalized.map((step) => ({ step, existingReflection: extractReflection(step), })); + if (!budget.tryUse("batch")) return []; + try { const out = await batchScoreReflections(llm, inputs, { synthReflections: deps.cfg.synthReflections, @@ -1089,6 +1231,7 @@ async function runBatchScoring( reflection: out.scores[i] ?? disabledScore(null, "none"), })); } catch (err) { + budget.stopIfTerminal(err, "batch"); // Single failure mode: the batched call (or its validator) threw. // Fall back to per-step in the caller. We surface a warning so the // viewer can show "batch path degraded" without crashing capture. @@ -1109,13 +1252,14 @@ async function runPerStepScoring( llmCalls: { reflectionSynth: number; alphaScoring: number }, episodeId: string, contexts: ReflectionContext[], + budget: ReflectLlmBudget, ): Promise { const concurrency = Math.max(1, deps.cfg.llmConcurrency); return runConcurrently(normalized, concurrency, async (step, idx): Promise => { const context = contexts[idx] ?? {}; - const { score, synthCount } = await resolveReflection(step, llm, deps, warnings, episodeId, context); + const { score, synthCount } = await resolveReflection(step, llm, deps, warnings, episodeId, context, budget); llmCalls.reflectionSynth += synthCount; - const finalScore = await resolveAlpha(step, score, llm, deps, warnings, episodeId, context); + const finalScore = await resolveAlpha(step, score, llm, deps, warnings, episodeId, context, budget); if (finalScore !== score) llmCalls.alphaScoring += 1; return { ...step, reflection: finalScore }; }); @@ -1128,6 +1272,7 @@ async function resolveReflection( warnings: CaptureResult["warnings"], episodeId: string, context: ReflectionContext, + budget: ReflectLlmBudget, ): Promise<{ score: ReflectionScore; synthCount: number }> { const adapterProvided = step.rawReflection !== null && step.rawReflection.trim().length > 0; const extracted = extractReflection(step); @@ -1140,6 +1285,9 @@ async function resolveReflection( if (!deps.cfg.synthReflections || !llm) { return { score: disabledScore(null, "none"), synthCount: 0 }; } + if (!budget.tryUse("reflection.synth")) { + return { score: disabledScore(null, "none"), synthCount: 0 }; + } try { const synth = await synthesizeReflection(llm, step, { episodeId, @@ -1156,6 +1304,7 @@ async function resolveReflection( } return { score: disabledScore(null, "none"), synthCount: 1 }; } catch (err) { + budget.stopIfTerminal(err, "reflection.synth"); warnings.push({ stage: "reflection.synth", message: "synth failed", @@ -1173,9 +1322,11 @@ async function resolveAlpha( warnings: CaptureResult["warnings"], episodeId: string, context: ReflectionContext, + budget: ReflectLlmBudget, ): Promise { if (!current.text) return current; // nothing to grade if (!deps.cfg.alphaScoring || !llm) return current; + if (!budget.tryUse("alpha")) return current; try { const scored = await scoreReflection(llm, { @@ -1195,6 +1346,7 @@ async function resolveAlpha( model: scored.model, }; } catch (err) { + budget.stopIfTerminal(err, "alpha"); warnings.push({ stage: "alpha", message: "alpha scoring failed; keeping neutral α", @@ -1226,6 +1378,78 @@ async function runConcurrently( return out; } +interface ReflectLlmBudget { + tryUse(stage: string): boolean; + stopIfTerminal(err: unknown, stage: string): void; +} + +function createReflectLlmBudget( + configuredLimit: number, + warnings: CaptureResult["warnings"], + episodeId: string, +): ReflectLlmBudget { + const limit = Math.max(0, Math.floor(configuredLimit)); + let used = 0; + let stopped = false; + let exhaustedWarned = false; + let terminalWarned = false; + + return { + tryUse(stage) { + if (stopped || used >= limit) { + if (!exhaustedWarned) { + exhaustedWarned = true; + warnings.push({ + stage, + message: "reflect LLM budget exhausted; using non-LLM fallback for remaining steps", + detail: { episodeId, limit, used, stopped }, + }); + } + return false; + } + used += 1; + return true; + }, + stopIfTerminal(err, stage) { + if (!isTerminalReflectLlmError(err)) return; + stopped = true; + if (terminalWarned) return; + terminalWarned = true; + warnings.push({ + stage, + message: "terminal reflect LLM error; stopped remaining reflect LLM calls", + detail: { episodeId, limit, used, ...errDetail(err) }, + }); + }, + }; +} + +function isTerminalReflectLlmError(err: unknown): boolean { + if (!(err instanceof MemosError)) { + const msg = err instanceof Error ? err.message : String(err); + return terminalMessage(msg); + } + if (err.code !== ERROR_CODES.LLM_UNAVAILABLE) return terminalMessage(err.message); + const details = (err.details ?? {}) as Record; + if (details.circuitOpen === true) return true; + const status = Number(details.status); + if (status === 401 || status === 402 || status === 403) return true; + return terminalMessage(err.message); +} + +function terminalMessage(message: string): boolean { + const msg = message.toLowerCase(); + return ( + msg.includes("circuit_open") || + msg.includes("insufficient balance") || + msg.includes("invalid api key") || + msg.includes("invalid_api_key") || + msg.includes("unauthorized") || + msg.includes("account suspended") || + msg.includes("billing") + ); +} + function errDetail(err: unknown): Record { if (err instanceof MemosError) return { code: err.code, message: err.message, ...(err.details ?? {}) }; if (err instanceof Error) return { name: err.name, message: err.message }; diff --git a/apps/memos-local-plugin/core/capture/types.ts b/apps/memos-local-plugin/core/capture/types.ts index efdc637f2..5e37c4448 100644 --- a/apps/memos-local-plugin/core/capture/types.ts +++ b/apps/memos-local-plugin/core/capture/types.ts @@ -193,6 +193,19 @@ export interface CaptureConfig { alphaScoring: boolean; synthReflections: boolean; llmConcurrency: number; + /** + * Hard cap for LLM calls made by one topic-end reflect pass. This bounds + * startup recovery / dirty-episode replay so a single large episode cannot + * generate unbounded paid requests. + */ + maxReflectLlmCalls: number; + /** + * Startup-recovered episodes are reconstructed from already-persisted + * traces. Any "orphan" during that replay is usually a matching drift, not + * genuinely missing content. Keep inserts disabled by default to avoid + * duplicating historical rows while still allowing operators to opt in. + */ + maxRecoveryOrphanInserts: number; /** * V7 §3.2 batched variant. Controls when reflection synthesis + α scoring * collapse into ONE LLM call per episode instead of N per-step calls. diff --git a/apps/memos-local-plugin/core/config/defaults.ts b/apps/memos-local-plugin/core/config/defaults.ts index 1cf2d2cf6..828e8ec45 100644 --- a/apps/memos-local-plugin/core/config/defaults.ts +++ b/apps/memos-local-plugin/core/config/defaults.ts @@ -72,6 +72,12 @@ export const DEFAULT_CONFIG: ResolvedConfig = { // still contribute useful α values. synthReflections: true, llmConcurrency: 4, + // Bound topic-end reflect work so dirty startup recovery cannot replay + // a huge historical episode into thousands of paid LLM calls. + maxReflectLlmCalls: 128, + // Recovered episodes are reconstructed from persisted traces; replay + // orphans are usually matching drift, so do not insert duplicate rows. + maxRecoveryOrphanInserts: 0, // V7 §3.2 batched variant. With "auto" we issue a single LLM call // per episode for both reflection synth and α scoring as long as // the episode is short enough — this collapses 2N per-step calls diff --git a/apps/memos-local-plugin/core/config/schema.ts b/apps/memos-local-plugin/core/config/schema.ts index 7c9ff193b..a009daeee 100644 --- a/apps/memos-local-plugin/core/config/schema.ts +++ b/apps/memos-local-plugin/core/config/schema.ts @@ -114,6 +114,10 @@ const AlgorithmSchema = Type.Object({ synthReflections: Bool(false), /** Concurrency for α scoring + synth LLM calls (per_step mode only). */ llmConcurrency: NumberInRange(4, 1, 32), + /** Hard cap for one topic-end reflect pass, including recovery replay. */ + maxReflectLlmCalls: NumberInRange(128, 0, 10_000), + /** Max orphan trace inserts allowed during startup-recovered replay. */ + maxRecoveryOrphanInserts: NumberInRange(0, 0, 10_000), /** * V7 §3.2 batched variant. When/how to fold per-step reflection synth + * α scoring into one episode-level LLM call: diff --git a/apps/memos-local-plugin/core/llm/client.ts b/apps/memos-local-plugin/core/llm/client.ts index 38ac2a395..d969d714d 100644 --- a/apps/memos-local-plugin/core/llm/client.ts +++ b/apps/memos-local-plugin/core/llm/client.ts @@ -164,8 +164,12 @@ export function createLlmClientWithProvider( } function throwBreakerOpen(): never { + throw makeBreakerOpenError(); + } + + function makeBreakerOpenError(): MemosError { const until = circuitOpenUntil ?? breakerNow(); - throw new MemosError( + return new MemosError( ERROR_CODES.LLM_UNAVAILABLE, `circuit_open: ${circuitOpenedReason ?? "terminal provider error"}`, { @@ -177,6 +181,14 @@ export function createLlmClientWithProvider( ); } + function canUseHostFallback(): boolean { + return ( + config.fallbackToHost === true && + provider.name !== "host" && + getHostLlmBridge() !== null + ); + } + /** * Mark a successful primary-provider call. We **do not** clear * `lastError` / `lastFallbackAt` here — the viewer picks the most @@ -258,12 +270,18 @@ export function createLlmClientWithProvider( op: string, ): Promise<{ completion: LlmCompletion }> { // ── Circuit breaker short-circuit ── - // When the breaker is open we never reach the provider, so no paid - // request is generated. We still emit (coalesced) `circuit_open` - // status rows so the Logs viewer / Overview can surface that - // suppression is happening. + // When the breaker is open we never reach the primary provider, so + // no request is generated against the broken paid API. We still + // emit (coalesced) `circuit_open` status rows so the Logs viewer / + // Overview can surface that suppression is happening. if (breakerIsOpen()) { maybeEmitCircuitOpenStatus(opts, op); + if (canUseHostFallback()) { + return callHostFallback(makeBreakerOpenError(), messages, input, opts, op, { + keepBreakerOpen: true, + notifyError: false, + }); + } throwBreakerOpen(); } requests++; @@ -295,54 +313,13 @@ export function createLlmClientWithProvider( return { completion }; } catch (err) { if (shouldFallback(err, config, provider.name)) { - const hostProv = new HostLlmProvider(); + const primaryTerminal = breakerIsTerminal(err); + if (primaryTerminal) breakerTrip(err); try { - const res = await hostProv.complete(messages, input, makeCtx(opts, asProviderLog(rootLogger.child({ channel: "llm.host" })))); - hostFallbacks++; - facadeLog.warn("host.fallback", { - from: provider.name, - op, - reason: summarizeErr(err), - }); - const completion: LlmCompletion = { - text: res.text, - provider: provider.name, - model: config.model, - finishReason: res.finishReason, - usage: res.usage, - servedBy: "host_fallback", - durationMs: res.durationMs, - }; - record(completion, op, messages); - // The primary provider is still broken even though the host - // bridge saved this call. Tag the slot yellow (`lastFallbackAt`) - // and surface the upstream error to the user via the - // system_error log so they can see *why* fallback engaged. - // - // The circuit breaker stays CLOSED here: from the caller's - // perspective the call was rescued, and tripping the breaker - // on host-fallback success would defeat the point of the - // bridge (it exists precisely to keep going when the primary - // is down). The fallback path also already records the - // primary's failure, so the operator still sees the red trail - // in the Logs viewer. - const fallbackAt = markFallback(err); - breakerRecordSuccess(); - notifyOnError(err); - notifyStatus({ - status: "fallback", - provider: provider.name, - model: config.model, - message: summarizeErrMessage(err), - code: err instanceof MemosError ? err.code : undefined, - at: fallbackAt, - durationMs: completion.durationMs, - fallbackProvider: "host", - op, - episodeId: opts?.episodeId, - phase: opts?.phase, + return await callHostFallback(err, messages, input, opts, op, { + keepBreakerOpen: primaryTerminal, + notifyError: true, }); - return { completion }; } catch (hostErr) { failures++; const failAt = markFail(hostErr); @@ -350,7 +327,7 @@ export function createLlmClientWithProvider( primary: summarizeErr(err), host: summarizeErr(hostErr), }); - // Primary AND host bridge both failed terminally. Trip on the + // Primary AND host bridge both failed. Trip on a terminal // primary error (the one the operator typically needs to fix // — host bridge failures are usually transient stdio issues). if (breakerIsTerminal(err)) breakerTrip(err); @@ -617,6 +594,59 @@ export function createLlmClientWithProvider( } } + async function callHostFallback( + primaryErr: unknown, + messages: LlmMessage[], + input: ProviderCallInput, + opts: LlmCallOptions | undefined, + op: string, + behavior: { keepBreakerOpen: boolean; notifyError: boolean }, + ): Promise<{ completion: LlmCompletion }> { + const hostProv = new HostLlmProvider(); + const res = await hostProv.complete( + messages, + input, + makeCtx(opts, asProviderLog(rootLogger.child({ channel: "llm.host" }))), + ); + hostFallbacks++; + facadeLog.warn("host.fallback", { + from: provider.name, + op, + reason: summarizeErr(primaryErr), + }); + const completion: LlmCompletion = { + text: res.text, + provider: provider.name, + model: config.model, + finishReason: res.finishReason, + usage: res.usage, + servedBy: "host_fallback", + durationMs: res.durationMs, + }; + record(completion, op, messages); + // The primary provider is still broken even though the host bridge + // saved this call. Keep the breaker open for terminal primary + // errors so later calls can go straight to host fallback without + // touching the paid provider again. + const fallbackAt = markFallback(primaryErr); + if (!behavior.keepBreakerOpen) breakerRecordSuccess(); + if (behavior.notifyError) notifyOnError(primaryErr); + notifyStatus({ + status: "fallback", + provider: provider.name, + model: config.model, + message: summarizeErrMessage(primaryErr), + code: primaryErr instanceof MemosError ? primaryErr.code : undefined, + at: fallbackAt, + durationMs: completion.durationMs, + fallbackProvider: "host", + op, + episodeId: opts?.episodeId, + phase: opts?.phase, + }); + return { completion }; + } + const client: LlmClient = { provider: provider.name, model: config.model, diff --git a/apps/memos-local-plugin/core/storage/repos/traces.ts b/apps/memos-local-plugin/core/storage/repos/traces.ts index fe9e5b623..34f82e021 100644 --- a/apps/memos-local-plugin/core/storage/repos/traces.ts +++ b/apps/memos-local-plugin/core/storage/repos/traces.ts @@ -170,6 +170,20 @@ export function makeTracesRepo(db: StorageDb) { return db.prepare(sql).all(params).map(mapRow); }, + /** + * Full episode-scoped trace fetch with no pagination cap. Capture + * reconciliation must see every row; the normal `list()` path applies + * viewer pagination and can misclassify rows past the page as missing. + */ + listAllForEpisode(episodeId: EpisodeId): TraceRow[] { + const sql = + `SELECT ${COLUMNS.join(", ")} FROM traces WHERE episode_id = @episode_id ORDER BY ts ASC`; + return db + .prepare<{ episode_id: string }, RawTraceRow>(sql) + .all({ episode_id: episodeId }) + .map(mapRow); + }, + /** * Total row count matching the same filter (no limit/offset). * Used by list endpoints so the viewer can show "Page N of M". diff --git a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts index d86290517..194441d5f 100644 --- a/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts +++ b/apps/memos-local-plugin/tests/unit/capture/capture-batch.test.ts @@ -80,6 +80,8 @@ function baseConfig(overrides: Partial = {}): CaptureConfig { alphaScoring: true, synthReflections: true, llmConcurrency: 2, + maxReflectLlmCalls: 128, + maxRecoveryOrphanInserts: 0, batchMode: "auto", batchThreshold: 12, reflectionContextMode: "none", diff --git a/apps/memos-local-plugin/tests/unit/capture/capture.test.ts b/apps/memos-local-plugin/tests/unit/capture/capture.test.ts index 95d728338..b30cbe4c1 100644 --- a/apps/memos-local-plugin/tests/unit/capture/capture.test.ts +++ b/apps/memos-local-plugin/tests/unit/capture/capture.test.ts @@ -87,6 +87,8 @@ function baseConfig(overrides: Partial = {}): CaptureConfig { alphaScoring: true, synthReflections: false, llmConcurrency: 2, + maxReflectLlmCalls: 128, + maxRecoveryOrphanInserts: 0, // Default to per-step here so the existing assertions on // `llmCalls.alphaScoring`/`reflectionSynth` continue to hold. The // batched path has its own dedicated test file. @@ -133,15 +135,18 @@ function episodeSnapshot(opts: { function traceRow(opts: { id: string; + episodeId?: string; + sessionId?: string; ts: number; + turnId?: number; userText?: string; agentText?: string; toolCalls?: TraceRow["toolCalls"]; }): TraceRow { return { id: opts.id as TraceId, - episodeId: "ep_1" as EpisodeId, - sessionId: "se_1" as SessionId, + episodeId: (opts.episodeId ?? "ep_1") as EpisodeId, + sessionId: (opts.sessionId ?? "se_1") as SessionId, ts: opts.ts as EpochMs, userText: opts.userText ?? "", agentText: opts.agentText ?? "", @@ -157,7 +162,7 @@ function traceRow(opts: { errorSignatures: [], vecSummary: null, vecAction: null, - turnId: 1_000 as EpochMs, + turnId: (opts.turnId ?? 1_000) as EpochMs, schemaVersion: 1, }; } @@ -217,6 +222,129 @@ describe("capture/pipeline (end-to-end)", () => { }); } + it("recovered replay matches tool traces by payload when timestamps drift", async () => { + tmp.repos.traces.insert(traceRow({ + id: "tr_tool", + ts: 1_001, + turnId: 1_000, + userText: "run search", + toolCalls: [{ + name: "search", + input: { q: "memos" }, + output: "ok", + }], + })); + tmp.repos.traces.insert(traceRow({ + id: "tr_response", + ts: 1_003, + turnId: 1_000, + agentText: "done", + })); + + const runner = buildRunner({ + alphaScoring: false, + synthReflections: false, + embedTraces: false, + }, null, null); + const ep = episodeSnapshot({ + id: "ep_1", + sessionId: "se_1", + turns: [ + turn("user", "run search", 1_000), + turn("tool", "ok", 1_002, { + name: "search", + input: { q: "memos" }, + output: "ok", + }), + turn("assistant", "done", 1_003), + ], + }); + ep.meta = { recoveredAtStartup: 1_004, recoveryReason: "dirty_reward_rescore" }; + + const result = await runner.runReflect({ episode: ep, closedBy: "finalized" }); + + expect(result.traceIds).toHaveLength(2); + expect(result.warnings.some((w) => w.message.includes("skipped recovered orphan"))).toBe(false); + expect(tmp.repos.traces.listAllForEpisode("ep_1" as EpisodeId)).toHaveLength(2); + }); + + it("does not insert replay orphans for startup-recovered episodes by default", async () => { + const runner = buildRunner({ + alphaScoring: false, + synthReflections: false, + embedTraces: false, + }, null, null); + const ep = episodeSnapshot({ + id: "ep_1", + sessionId: "se_1", + turns: [ + turn("user", "hello", 1_000), + turn("assistant", "hi", 1_010), + ], + }); + ep.meta = { recoveredAtStartup: 1_020, recoveryReason: "dirty_reward_rescore" }; + + const result = await runner.runReflect({ episode: ep, closedBy: "finalized" }); + + expect(tmp.repos.traces.listAllForEpisode("ep_1" as EpisodeId)).toHaveLength(0); + expect(result.traceIds).toHaveLength(0); + expect(result.warnings.some((w) => w.message.includes("skipped recovered orphan"))).toBe(true); + }); + + it("caps reflect-phase LLM calls per episode", async () => { + tmp.repos.traces.insert(traceRow({ + id: "tr_1", + ts: 1_100, + turnId: 1_000, + userText: "u1", + agentText: "a1", + })); + tmp.repos.traces.insert(traceRow({ + id: "tr_2", + ts: 2_100, + turnId: 2_000, + userText: "u2", + agentText: "a2", + })); + tmp.repos.traces.insert(traceRow({ + id: "tr_3", + ts: 3_100, + turnId: 3_000, + userText: "u3", + agentText: "a3", + })); + const llm = fakeLlm({ + complete: { + "capture.reflection.synth": "I chose the next action for this task.", + }, + }); + const runner = buildRunner({ + alphaScoring: false, + synthReflections: true, + embedTraces: false, + batchMode: "per_step", + maxReflectLlmCalls: 1, + }, llm, null); + const ep = episodeSnapshot({ + id: "ep_1", + sessionId: "se_1", + turns: [ + turn("user", "u1", 1_000), + turn("assistant", "a1", 1_100), + turn("user", "u2", 2_000), + turn("assistant", "a2", 2_100), + turn("user", "u3", 3_000), + turn("assistant", "a3", 3_100), + ], + }); + + const result = await runner.runReflect({ episode: ep, closedBy: "finalized" }); + + expect(result.llmCalls.reflectionSynth).toBe(1); + expect(result.warnings.some((w) => w.message.includes("reflect LLM budget exhausted"))).toBe(true); + expect(llm.stats().requests).toBe(1); + }); + it("lightweight capture merges one turn into one memory with summary-only embedding", async () => { const llm = fakeLlm({ completeJson: { diff --git a/apps/memos-local-plugin/tests/unit/capture/normalizer.test.ts b/apps/memos-local-plugin/tests/unit/capture/normalizer.test.ts index a09f7d1f5..3cd0877ac 100644 --- a/apps/memos-local-plugin/tests/unit/capture/normalizer.test.ts +++ b/apps/memos-local-plugin/tests/unit/capture/normalizer.test.ts @@ -11,6 +11,8 @@ const cfg: CaptureConfig = { alphaScoring: false, synthReflections: false, llmConcurrency: 1, + maxReflectLlmCalls: 128, + maxRecoveryOrphanInserts: 0, batchMode: "per_step", batchThreshold: 12, }; diff --git a/apps/memos-local-plugin/tests/unit/llm/client.test.ts b/apps/memos-local-plugin/tests/unit/llm/client.test.ts index f103ccf5d..7e904a2c6 100644 --- a/apps/memos-local-plugin/tests/unit/llm/client.test.ts +++ b/apps/memos-local-plugin/tests/unit/llm/client.test.ts @@ -429,14 +429,16 @@ describe("llm/client", () => { expect(client.stats().circuitOpen).toBe(false); }); - it("does NOT trip when host fallback rescues the call", async () => { + it("trips on terminal primary error even when host fallback rescues the call", async () => { const sink = statusSink(); const provider = new ThrowingProvider( new MemosError(ERROR_CODES.LLM_UNAVAILABLE, "402", { status: 402 }), ); + let hostCalls = 0; registerHostLlmBridge({ id: "test.host", async complete() { + hostCalls++; return { text: "rescued", model: "host-m", durationMs: 1 }; }, }); @@ -450,12 +452,16 @@ describe("llm/client", () => { ); const r = await client.complete("call-1"); expect(r.servedBy).toBe("host_fallback"); - // Breaker still closed: fallback rescued the call. - expect(client.stats().circuitOpen).toBe(false); + // The terminal primary error still opens the breaker even though + // host fallback rescued the user-visible call. + expect(client.stats().circuitOpen).toBe(true); const r2 = await client.complete("call-2"); expect(r2.servedBy).toBe("host_fallback"); - // Provider hit twice; not short-circuited. - expect(provider.calls).toBe(2); + // The second call goes directly to host fallback and never touches + // the broken paid provider again. + expect(provider.calls).toBe(1); + expect(hostCalls).toBe(2); + expect(sink.rows.map((row) => row.status)).toContain("circuit_open"); }); it("disabled when circuitBreaker.enabled=false (legacy behavior)", async () => {