diff --git a/.scripts/source/grant_lakebase_permissions.py b/.scripts/source/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/.scripts/source/grant_lakebase_permissions.py
+++ b/.scripts/source/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
            "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-langgraph-advanced/agent_server/agent.py b/agent-langgraph-advanced/agent_server/agent.py
index 22d0e8bc..49316d76 100644
--- a/agent-langgraph-advanced/agent_server/agent.py
+++ b/agent-langgraph-advanced/agent_server/agent.py
@@ -23,6 +23,7 @@
 from agent_server.prompts import SYSTEM_PROMPT
 from agent_server.utils import (
     _get_or_create_thread_id,
+    deduplicate_input,
     get_user_workspace_client,
     init_mcp_client,
     process_agent_astream_events,
@@ -110,10 +111,7 @@ async def stream_handler(
     if user_id:
         config["configurable"]["user_id"] = user_id
 
-    input_state: dict[str, Any] = {
-        "messages": to_chat_completions_input([i.model_dump() for i in request.input]),
-        "custom_inputs": dict(request.custom_inputs or {}),
-    }
+    incoming_messages = to_chat_completions_input([i.model_dump() for i in request.input])
 
     try:
         async with lakebase_context(LAKEBASE_CONFIG) as (checkpointer, store):
@@ -123,6 +121,11 @@
             # For on-behalf-of user authentication, pass get_user_workspace_client() to init_agent.
             agent = await init_agent(store=store, checkpointer=checkpointer)
 
+            input_state: dict[str, Any] = {
+                "messages": await deduplicate_input(agent, config, incoming_messages),
+                "custom_inputs": dict(request.custom_inputs or {}),
+            }
+
             async for event in process_agent_astream_events(
                 agent.astream(input_state, config, stream_mode=["updates", "messages"])
             ):
diff --git a/agent-langgraph-advanced/agent_server/utils.py b/agent-langgraph-advanced/agent_server/utils.py
index 75b92de2..2d04b39f 100644
--- a/agent-langgraph-advanced/agent_server/utils.py
+++ b/agent-langgraph-advanced/agent_server/utils.py
@@ -40,6 +40,27 @@
 def _is_databricks_app_env() -> bool:
     return bool(os.getenv("DATABRICKS_APP_NAME"))
 
 
+async def deduplicate_input(
+    agent: Any, config: dict[str, Any], messages: list[dict[str, Any]]
+) -> list[dict[str, Any]]:
+    """Drop UI-echoed history when the checkpointer already holds the thread.
+
+    The chatbot UI replays the full conversation on each turn, but LangGraph's
+    checkpointer already has the prior messages keyed by ``thread_id``. Sending
+    them again duplicates everything in the agent's view. When we detect an
+    existing checkpoint for this thread, keep only the latest user message.
+    """
+    if not messages:
+        return messages
+    try:
+        state = await agent.aget_state(config)
+    except Exception:
+        return messages
+    if state and state.values.get("messages"):
+        return messages[-1:]
+    return messages
+
+
 def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient:
     host_name = get_databricks_host_from_env()
     return DatabricksMultiServerMCPClient(
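For orientation, and not part of the patch: a minimal sketch of the turn-level behavior the new helper gives the LangGraph template. The stub agent, stub state, and the inlined copy of the helper below are illustrative stand-ins for a real compiled graph and checkpointer.

import asyncio
from typing import Any


class StubState:
    def __init__(self, messages: list[dict[str, Any]]):
        self.values = {"messages": messages}


class StubAgent:
    """Pretends to be a LangGraph agent; aget_state returns checkpointed state."""

    def __init__(self, checkpointed: list[dict[str, Any]]):
        self._checkpointed = checkpointed

    async def aget_state(self, config: dict[str, Any]) -> StubState:
        return StubState(self._checkpointed)


async def deduplicate_input(agent, config, messages):
    # Same logic as the hunk above: keep only the newest message when the
    # checkpointer already holds this thread's history.
    if not messages:
        return messages
    try:
        state = await agent.aget_state(config)
    except Exception:
        return messages
    if state and state.values.get("messages"):
        return messages[-1:]
    return messages


async def main() -> None:
    config = {"configurable": {"thread_id": "t1"}}
    turn1 = [{"role": "user", "content": "hi"}]
    turn2 = turn1 + [
        {"role": "assistant", "content": "hello!"},
        {"role": "user", "content": "and again?"},
    ]
    # First turn: no checkpoint yet, everything passes through.
    assert await deduplicate_input(StubAgent([]), config, turn1) == turn1
    # Replayed second turn: checkpoint exists, only the latest message survives.
    assert await deduplicate_input(StubAgent(turn1), config, turn2) == turn2[-1:]


asyncio.run(main())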
diff --git a/agent-langgraph-advanced/scripts/grant_lakebase_permissions.py b/agent-langgraph-advanced/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-langgraph-advanced/scripts/grant_lakebase_permissions.py
+++ b/agent-langgraph-advanced/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-langgraph/scripts/grant_lakebase_permissions.py b/agent-langgraph/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-langgraph/scripts/grant_lakebase_permissions.py
+++ b/agent-langgraph/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-migration-from-model-serving/scripts/grant_lakebase_permissions.py b/agent-migration-from-model-serving/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-migration-from-model-serving/scripts/grant_lakebase_permissions.py
+++ b/agent-migration-from-model-serving/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-non-conversational/scripts/grant_lakebase_permissions.py b/agent-non-conversational/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-non-conversational/scripts/grant_lakebase_permissions.py
+++ b/agent-non-conversational/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-openai-advanced/agent_server/utils.py b/agent-openai-advanced/agent_server/utils.py
index 7cd07e8c..1b001590 100644
--- a/agent-openai-advanced/agent_server/utils.py
+++ b/agent-openai-advanced/agent_server/utils.py
@@ -207,7 +207,7 @@ async def deduplicate_input(request: ResponsesAgentRequest, session: AsyncDatabr
         ):
             msg["content"] = [{"type": "output_text", "text": msg["content"], "annotations": []}]
     session_items = await session.get_items()
-    if len(session_items) >= len(messages) - 1:
+    if session_items and len(messages) > 1:
         return [messages[-1]]
     return messages
 
diff --git a/agent-openai-advanced/scripts/grant_lakebase_permissions.py b/agent-openai-advanced/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-openai-advanced/scripts/grant_lakebase_permissions.py
+++ b/agent-openai-advanced/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
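The agent-openai-advanced hunk above swaps a count comparison for a presence check. A small worked comparison of the two predicates, as an illustration only; the session and message shapes here are hypothetical:

def old_should_dedupe(session_items, messages):
    return len(session_items) >= len(messages) - 1


def new_should_dedupe(session_items, messages):
    return bool(session_items) and len(messages) > 1


# Replayed second turn: the UI resends [user, assistant, user], but suppose
# the session persisted only one combined item for the first exchange.
messages = [{"role": "user"}, {"role": "assistant"}, {"role": "user"}]
session_items = [{"item": "turn-1"}]
assert old_should_dedupe(session_items, messages) is False  # misses the dedupe
assert new_should_dedupe(session_items, messages) is True

# Fresh single-message turn: the old predicate fires (0 >= 0) but is a no-op,
# since [messages[-1]] equals messages; the new one simply stays out of the way.
assert old_should_dedupe([], [{"role": "user"}]) is True
assert new_should_dedupe([], [{"role": "user"}]) is False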
diff --git a/agent-openai-agents-sdk-multiagent/scripts/grant_lakebase_permissions.py b/agent-openai-agents-sdk-multiagent/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-openai-agents-sdk-multiagent/scripts/grant_lakebase_permissions.py
+++ b/agent-openai-agents-sdk-multiagent/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
diff --git a/agent-openai-agents-sdk/scripts/grant_lakebase_permissions.py b/agent-openai-agents-sdk/scripts/grant_lakebase_permissions.py
index 95b9e5f6..c5478833 100644
--- a/agent-openai-agents-sdk/scripts/grant_lakebase_permissions.py
+++ b/agent-openai-agents-sdk/scripts/grant_lakebase_permissions.py
@@ -47,6 +47,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
     "openai": {
@@ -57,6 +58,7 @@
         "agent_server": [
             "responses",
             "messages",
+            "conversation_aliases",
         ],
     },
 }
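Before the chatbot-side changes below, a rough sketch of the alias bookkeeping that the newly granted conversation_aliases table presumably backs on the agent-server side. The table shape and helper names here are guesses for illustration, not code from this patch:

ALIASES: dict[str, str] = {}  # original chat id -> current SDK session id (hypothetical shape)


def resolve_session_id(conversation_id: str) -> str:
    # Follow alias chains so clients can keep sending the original id forever.
    seen: set[str] = set()
    current = conversation_id
    while current in ALIASES and current not in seen:
        seen.add(current)
        current = ALIASES[current]
    return current


def rotate_session(conversation_id: str, new_session_id: str) -> None:
    ALIASES[resolve_session_id(conversation_id)] = new_session_id


rotate_session("chat-1", "sess-a")  # first rotation
rotate_session("chat-1", "sess-b")  # second rotation
assert resolve_session_id("chat-1") == "sess-b"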
diff --git a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
index f19dad35..b621e15b 100644
--- a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
+++ b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
@@ -71,7 +71,38 @@ export async function getWorkspaceHostname(): Promise<string> {
 
 // Environment variable to enable SSE logging
 const LOG_SSE_EVENTS = process.env.LOG_SSE_EVENTS === 'true';
-const API_PROXY = process.env.API_PROXY;
+// Read API_PROXY at call sites (not module load) so tests can flip it
+// per-case via process.env without forcing a re-import.
+const getApiProxy = () => process.env.API_PROXY;
+
+// Durable-execution support: when talking to a `LongRunningAgentServer`
+// agent (the case when `API_PROXY` is set in our advanced templates) we
+// 1. inject `background: true` so the server persists every SSE frame
+//    to its durable store and the retrieve endpoint can resume mid-stream;
+// 2. on connection close without `[DONE]`, transparently re-stream from
+//    the retrieve endpoint using the last seen sequence number.
+//
+// We deliberately do NOT track / substitute rotated `conversation_id`
+// values here. The bridge resolves rotation server-side via its
+// `conversation_aliases` table — the chatbot always sends the original
+// chat id and the bridge maps it to the post-rotation SDK session. This
+// is what lets the durable path survive chatbot restarts and multi-pod
+// chatbot deployments without any client-side state.
+const MAX_RESUME_ATTEMPTS = 5;
+
+function extractResponseId(json: Record<string, unknown> | null): string | null {
+  if (!json) return null;
+  if (typeof json.response_id === 'string') return json.response_id;
+  const resp = json.response as { id?: unknown } | undefined;
+  if (resp && typeof resp.id === 'string') return resp.id;
+  return null;
+}
+
+function buildRetrieveUrl(invocationsUrl: string, responseId: string): string {
+  // The bridge mounts GET /responses/{id} on the same origin as POST /invocations.
+  const base = invocationsUrl.replace(/\/invocations\/?$/, '');
+  return `${base}/responses/${encodeURIComponent(responseId)}`;
+}
 
 // Cache for endpoint details to check task type and OBO scopes
 const endpointDetailsCache = new Map<
@@ -87,7 +118,7 @@ const ENDPOINT_DETAILS_CACHE_DURATION = 5 * 60 * 1000; // 5 minutes
 function shouldInjectContext(): boolean {
   const servingEndpoint = process.env.DATABRICKS_SERVING_ENDPOINT;
   if (!servingEndpoint) {
-    return Boolean(API_PROXY);
+    return Boolean(getApiProxy());
   }
 
   const cached = endpointDetailsCache.get(servingEndpoint);
@@ -110,28 +141,36 @@ export const databricksFetch: typeof fetch = async (input, init) => {
   headers.delete(CONTEXT_HEADER_USER_ID);
   requestInit = { ...requestInit, headers };
 
-  // Inject context into request body if appropriate
-  if (
-    conversationId &&
-    userId &&
-    requestInit?.body &&
-    typeof requestInit.body === 'string'
-  ) {
-    if (shouldInjectContext()) {
-      try {
-        const body = JSON.parse(requestInit.body);
-        const enhancedBody = {
-          ...body,
-          context: {
-            ...body.context,
-            conversation_id: conversationId,
-            user_id: userId,
-          },
+  // Mutate the request body for durable execution (when we have a body to
+  // mutate). Two things happen here, both conditional:
+  // - Inject context.conversation_id / context.user_id from headers when the
+  //   endpoint expects it (existing behavior).
+  // - Set body.background = true on streaming requests when API_PROXY is
+  //   set, so the long-running server persists the stream to its store.
+  if (requestInit?.body && typeof requestInit.body === 'string') {
+    try {
+      const body = JSON.parse(requestInit.body);
+      let mutated = false;
+
+      if (conversationId && userId && shouldInjectContext()) {
+        body.context = {
+          ...(body.context ?? {}),
+          conversation_id: conversationId,
+          user_id: userId,
         };
-        requestInit = { ...requestInit, body: JSON.stringify(enhancedBody) };
-      } catch {
-        // If JSON parsing fails, pass through unchanged
+        mutated = true;
+      }
+
+      if (getApiProxy() && body.stream === true && body.background !== true) {
+        body.background = true;
+        mutated = true;
       }
+
+      if (mutated) {
+        requestInit = { ...requestInit, body: JSON.stringify(body) };
+      }
+    } catch {
+      // If JSON parsing fails, pass through unchanged
     }
   }
@@ -161,58 +200,28 @@ export const databricksFetch: typeof fetch = async (input, init) => {
 
   const response = await fetch(url, requestInit);
 
-  // If SSE logging is enabled and this is a streaming response, wrap the body to log events
-  if (LOG_SSE_EVENTS && response.body) {
+  // Only wrap the response for durable resume when API_PROXY is set —
+  // standard Databricks serving endpoints aren't long-running servers, so
+  // the resume path can't fire there and we'd just pay parse cost for
+  // every SSE chunk for no benefit.
+  if (getApiProxy() && response.body) {
     const contentType = response.headers.get('content-type') || '';
     const isSSE =
       contentType.includes('text/event-stream') ||
       contentType.includes('application/x-ndjson');
 
     if (isSSE) {
-      const originalBody = response.body;
-      const reader = originalBody.getReader();
-      const decoder = new TextDecoder();
-      let eventCounter = 0;
-
-      const loggingStream = new ReadableStream({
-        async pull(controller) {
-          const { done, value } = await reader.read();
-
-          if (done) {
-            console.log('[SSE] Stream ended');
-            controller.close();
-            return;
-          }
-
-          // Decode and log the chunk
-          const text = decoder.decode(value, { stream: true });
-          const lines = text.split('\n').filter((line) => line.trim());
-
-          for (const line of lines) {
-            eventCounter++;
-            if (line.startsWith('data:')) {
-              const data = line.slice(5).trim();
-              try {
-                const parsed = JSON.parse(data);
-                console.log(`[SSE #${eventCounter}]`, JSON.stringify(parsed));
-              } catch {
-                console.log(`[SSE #${eventCounter}] (raw)`, data);
-              }
-            } else if (line.trim()) {
-              console.log(`[SSE #${eventCounter}] (line)`, line);
-            }
-          }
-
-          // Pass the original data through
-          controller.enqueue(value);
-        },
-        cancel() {
-          reader.cancel();
-        },
-      });
-
-      // Create a new response with the logging stream
-      return new Response(loggingStream, {
+      // Pass only the Authorization header to the resume fetch — it's a
+      // simple GET, no content-type / content-length / mlflow trace
+      // headers needed, and copying the whole request init can carry
+      // along stale fields that confuse the retrieve endpoint.
+      const resumeHeaders = new Headers();
+      const reqHeaders = new Headers(requestInit?.headers);
+      const auth = reqHeaders.get('authorization');
+      if (auth) resumeHeaders.set('authorization', auth);
+
+      const wrapped = wrapDurableSseStream(response.body, url, resumeHeaders);
+      return new Response(wrapped, {
         status: response.status,
         statusText: response.statusText,
         headers: response.headers,
@@ -223,6 +232,113 @@
   return response;
 };
 
+/**
+ * Wrap a long-running-server SSE response so we can:
+ * - track the last sequence number and response_id we observed,
+ * - if the upstream stream closes before `[DONE]`, transparently re-stream
+ *   from `GET /responses/{id}?stream=true&starting_after=<lastSeq>`.
+ *
+ * Bytes are passed through untouched; we only sniff data frames.
+ */
+function wrapDurableSseStream(
+  initialBody: ReadableStream<Uint8Array>,
+  invocationsUrl: string,
+  resumeHeaders: Headers,
+): ReadableStream<Uint8Array> {
+  const decoder = new TextDecoder();
+  let buffer = '';
+  let eventCounter = 0;
+  let responseId: string | null = null;
+  let lastSeq = -1;
+  let sawDone = false;
+  let attemptsLeft = MAX_RESUME_ATTEMPTS;
+
+  function processChunk(value: Uint8Array): void {
+    buffer += decoder.decode(value, { stream: true });
+    while (true) {
+      const nl = buffer.indexOf('\n');
+      if (nl === -1) break;
+      const line = buffer.slice(0, nl);
+      buffer = buffer.slice(nl + 1);
+      const trimmed = line.trim();
+      if (!trimmed) continue;
+      if (!trimmed.startsWith('data:')) continue;
+      const data = trimmed.slice(5).trim();
+      if (data === '[DONE]') {
+        sawDone = true;
+        if (LOG_SSE_EVENTS) console.log(`[SSE #${++eventCounter}] [DONE]`);
+        continue;
+      }
+      let json: Record<string, unknown> | null = null;
+      try {
+        json = JSON.parse(data) as Record<string, unknown>;
+      } catch {
+        if (LOG_SSE_EVENTS) console.log(`[SSE #${++eventCounter}] (raw)`, data);
+        continue;
+      }
+      if (LOG_SSE_EVENTS) {
+        console.log(`[SSE #${++eventCounter}]`, JSON.stringify(json));
+      }
+      const rid = extractResponseId(json);
+      if (rid) responseId = rid;
+      const seq = json.sequence_number;
+      if (typeof seq === 'number' && seq > lastSeq) lastSeq = seq;
+    }
+  }
+
+  return new ReadableStream({
+    async start(controller) {
+      let currentBody: ReadableStream<Uint8Array> | null = initialBody;
+
+      while (currentBody) {
+        const reader = currentBody.getReader();
+        try {
+          while (true) {
+            const { done, value } = await reader.read();
+            if (done) break;
+            controller.enqueue(value);
+            processChunk(value);
+          }
+        } catch (err) {
+          if (LOG_SSE_EVENTS) console.warn('[SSE] read error', err);
+        } finally {
+          reader.releaseLock();
+        }
+
+        if (sawDone) break;
+        if (!responseId || attemptsLeft <= 0) break;
+
+        attemptsLeft -= 1;
+        const startingAfter = Math.max(lastSeq, 0);
+        const resumeUrl =
+          `${buildRetrieveUrl(invocationsUrl, responseId)}` +
+          `?stream=true&starting_after=${startingAfter}`;
+        console.log(
+          `[SSE] upstream closed without [DONE], resuming response_id=${responseId} from seq=${startingAfter}`,
+        );
+        try {
+          const resp = await fetch(resumeUrl, {
+            method: 'GET',
+            headers: resumeHeaders,
+          });
+          if (!resp.ok || !resp.body) {
+            console.warn(
+              `[SSE] resume request failed status=${resp.status}, giving up`,
+            );
+            break;
+          }
+          currentBody = resp.body;
+        } catch (err) {
+          console.warn('[SSE] resume fetch threw, giving up', err);
+          break;
+        }
+      }
+
+      controller.close();
+    },
+  });
+}
+
 type CachedProvider = ReturnType;
 let oauthProviderCache: CachedProvider | null = null;
 let oauthProviderCacheTime = 0;
@@ -249,7 +365,7 @@ async function getOrCreateDatabricksProvider(): Promise<CachedProvider> {
       // When using endpoints such as Agent Bricks or custom agents, we need to use remote tool calling to handle the tool calls
       useRemoteToolCalling: true,
       baseURL: `${hostname}/serving-endpoints`,
-      formatUrl: ({ baseUrl, path }) => API_PROXY ?? `${baseUrl}${path}`,
+      formatUrl: ({ baseUrl, path }) => getApiProxy() ?? `${baseUrl}${path}`,
       fetch: async (...[input, init]: Parameters<typeof fetch>) => {
         const headers = new Headers(init?.headers);
@@ -264,7 +380,7 @@ async function getOrCreateDatabricksProvider(): Promise<CachedProvider> {
           headers.set('Authorization', `Bearer ${currentToken}`);
         }
 
-        if (API_PROXY) {
+        if (getApiProxy()) {
           headers.set('x-mlflow-return-trace-id', 'true');
         }
@@ -390,7 +506,7 @@
     const provider = await getOrCreateDatabricksProvider();
 
     const model = await (async () => {
-      if (API_PROXY) {
+      if (getApiProxy()) {
        // For API proxy we always use the responses agent
         return provider.responses(id);
       }
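The resume handshake above, restated as a small Python generator so the control flow is easier to follow than in stream-wrapper form. Illustration only: http_sse_lines is a hypothetical callable that yields decoded data: payloads until the connection closes, while the endpoint shapes (POST /invocations, then GET /responses/{id}?stream=true&starting_after=N) follow the diff.

import json
from typing import Iterator

MAX_RESUME_ATTEMPTS = 5


def stream_with_resume(http_sse_lines, invocations_url: str, body: dict) -> Iterator[dict]:
    # Track the durable-store cursor exactly like wrapDurableSseStream does.
    last_seq, response_id = -1, None
    attempts = MAX_RESUME_ATTEMPTS
    source = http_sse_lines("POST", invocations_url, body)
    while True:
        for data in source:
            if data == "[DONE]":
                return  # clean end of stream
            event = json.loads(data)
            response_id = event.get("response_id") or response_id
            seq = event.get("sequence_number")
            if isinstance(seq, int) and seq > last_seq:
                last_seq = seq
            yield event
        # Source closed without [DONE]: re-stream what the durable store has.
        if response_id is None or attempts == 0:
            return
        attempts -= 1
        base = invocations_url.rsplit("/invocations", 1)[0]
        source = http_sse_lines(
            "GET",
            f"{base}/responses/{response_id}?stream=true&starting_after={max(last_seq, 0)}",
            None,
        )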
diff --git a/e2e-chatbot-app-next/tests/ai-sdk-provider/durable-fetch.test.ts b/e2e-chatbot-app-next/tests/ai-sdk-provider/durable-fetch.test.ts
new file mode 100644
index 00000000..556ba1d9
--- /dev/null
+++ b/e2e-chatbot-app-next/tests/ai-sdk-provider/durable-fetch.test.ts
@@ -0,0 +1,267 @@
+import { expect, test } from '@playwright/test';
+import { databricksFetch } from '@chat-template/ai-sdk-providers';
+
+/**
+ * Tests for the durable-execution glue inside `databricksFetch`:
+ * - `background: true` injection on streaming requests when `API_PROXY` is set
+ * - SSE response wrapping that auto-resumes from
+ *   `GET /responses/{id}?stream=true&starting_after=<lastSeq>` when the
+ *   upstream stream closes without `[DONE]`
+ *
+ * Each test stashes and restores the global fetch + the API_PROXY env var so
+ * tests don't leak state into each other.
+ */
+
+const ORIG_FETCH = globalThis.fetch;
+const ORIG_API_PROXY = process.env.API_PROXY;
+
+function sseChunk(obj: Record<string, unknown>): Uint8Array {
+  return new TextEncoder().encode(`data: ${JSON.stringify(obj)}\n\n`);
+}
+
+function sseDone(): Uint8Array {
+  return new TextEncoder().encode('data: [DONE]\n\n');
+}
+
+function makeSseResponse(chunks: Uint8Array[]): Response {
+  const stream = new ReadableStream({
+    start(controller) {
+      for (const c of chunks) controller.enqueue(c);
+      controller.close();
+    },
+  });
+  return new Response(stream, {
+    status: 200,
+    headers: { 'content-type': 'text/event-stream' },
+  });
+}
+
+async function readSseFrames(
+  body: ReadableStream<Uint8Array>,
+): Promise<string[]> {
+  const reader = body.getReader();
+  const decoder = new TextDecoder();
+  let buf = '';
+  const frames: string[] = [];
+  while (true) {
+    const { done, value } = await reader.read();
+    if (done) break;
+    buf += decoder.decode(value, { stream: true });
+  }
+  for (const line of buf.split('\n')) {
+    const t = line.trim();
+    if (t.startsWith('data:')) frames.push(t.slice(5).trim());
+  }
+  return frames;
+}
+
+test.afterEach(() => {
+  globalThis.fetch = ORIG_FETCH;
+  process.env.API_PROXY = ORIG_API_PROXY ?? '';
+});
+
+test.describe('background: true injection', () => {
+  test('injects background=true when API_PROXY is set + stream=true', async () => {
+    process.env.API_PROXY = 'http://localhost:8000/invocations';
+    let capturedBody: Record<string, unknown> | null = null;
+    globalThis.fetch = (async (_input, init) => {
+      capturedBody = JSON.parse((init?.body as string) ?? '{}');
+      return new Response('{}', {
+        status: 200,
+        headers: { 'content-type': 'application/json' },
+      });
+    }) as typeof fetch;
+
+    await databricksFetch('http://localhost:8000/invocations', {
+      method: 'POST',
+      body: JSON.stringify({ input: [], stream: true }),
+    });
+
+    expect(capturedBody?.background).toBe(true);
+    expect(capturedBody?.stream).toBe(true);
+  });
+
+  test('leaves background alone when API_PROXY is NOT set', async () => {
+    process.env.API_PROXY = '';
+    let capturedBody: Record<string, unknown> | null = null;
+    globalThis.fetch = (async (_input, init) => {
+      capturedBody = JSON.parse((init?.body as string) ?? '{}');
+      return new Response('{}', {
+        status: 200,
+        headers: { 'content-type': 'application/json' },
+      });
+    }) as typeof fetch;
+
+    await databricksFetch(
+      'http://example.com/serving-endpoints/x/invocations',
+      {
+        method: 'POST',
+        body: JSON.stringify({ input: [], stream: true }),
+      },
+    );
+
+    expect(capturedBody?.background).toBeUndefined();
+  });
+
+  test('leaves background alone for non-streaming requests', async () => {
+    process.env.API_PROXY = 'http://localhost:8000/invocations';
+    let capturedBody: Record<string, unknown> | null = null;
+    globalThis.fetch = (async (_input, init) => {
+      capturedBody = JSON.parse((init?.body as string) ?? '{}');
+      return new Response('{}', {
+        status: 200,
+        headers: { 'content-type': 'application/json' },
+      });
+    }) as typeof fetch;
+
+    await databricksFetch('http://localhost:8000/invocations', {
+      method: 'POST',
+      body: JSON.stringify({ input: [], stream: false }),
+    });
+
+    expect(capturedBody?.background).toBeUndefined();
+  });
+});
+
+test.describe('durable resume on stream close-without-DONE', () => {
+  test('fires GET retrieve with starting_after=<lastSeq> when SSE ends without DONE', async () => {
+    process.env.API_PROXY = 'http://localhost:8000/invocations';
+
+    const fetchCalls: { url: string; init: RequestInit | undefined }[] = [];
+    globalThis.fetch = (async (input, init) => {
+      const url = input.toString();
+      fetchCalls.push({ url, init });
+
+      if (init?.method === 'GET') {
+        // resume call — return the rest of the stream + DONE
+        return makeSseResponse([
+          sseChunk({
+            type: 'response.output_text.delta',
+            sequence_number: 3,
+            response_id: 'resp_abc',
+          }),
+          sseChunk({
+            type: 'response.completed',
+            sequence_number: 4,
+            response_id: 'resp_abc',
+          }),
+          sseDone(),
+        ]);
+      }
+      // initial POST — return a stream that closes WITHOUT [DONE]
+      return makeSseResponse([
+        sseChunk({
+          type: 'response.created',
+          sequence_number: 0,
+          response_id: 'resp_abc',
+        }),
+        sseChunk({
+          type: 'response.output_text.delta',
+          sequence_number: 1,
+          response_id: 'resp_abc',
+        }),
+        sseChunk({
+          type: 'response.output_text.delta',
+          sequence_number: 2,
+          response_id: 'resp_abc',
+        }),
+      ]);
+    }) as typeof fetch;
+
+    const response = await databricksFetch(
+      'http://localhost:8000/invocations',
+      {
+        method: 'POST',
+        body: JSON.stringify({ input: [], stream: true }),
+        headers: { Authorization: 'Bearer tok-123' },
+      },
+    );
+
+    // Drain the wrapped body — this is what triggers the resume internally.
+    if (!response.body) throw new Error('expected response.body');
+    const frames = await readSseFrames(response.body);
+
+    // Two fetches: initial POST + one GET resume.
+    expect(fetchCalls.length).toBe(2);
+    expect(fetchCalls[0]?.init?.method).toBe('POST');
+    expect(fetchCalls[1]?.init?.method).toBe('GET');
+    expect(fetchCalls[1]?.url).toContain('/responses/resp_abc');
+    expect(fetchCalls[1]?.url).toContain('starting_after=2');
+
+    // Frames from both halves end up downstream.
+    expect(frames.some((f) => f.includes('"sequence_number":1'))).toBe(true);
+    expect(frames.some((f) => f.includes('"sequence_number":3'))).toBe(true);
+    expect(frames).toContain('[DONE]');
+  });
+
+  test('resume request carries ONLY the Authorization header', async () => {
+    process.env.API_PROXY = 'http://localhost:8000/invocations';
+
+    const resumeHeaders = new Headers();
+    let resumeHeadersCaptured = false;
+    globalThis.fetch = (async (_input, init) => {
+      if (init?.method === 'GET') {
+        for (const [k, v] of new Headers(init.headers)) resumeHeaders.set(k, v);
+        resumeHeadersCaptured = true;
+        return makeSseResponse([sseDone()]);
+      }
+      return makeSseResponse([
+        sseChunk({
+          type: 'response.created',
+          sequence_number: 0,
+          response_id: 'resp_xyz',
+        }),
+      ]);
+    }) as typeof fetch;
+
+    const response = await databricksFetch(
+      'http://localhost:8000/invocations',
+      {
+        method: 'POST',
+        body: JSON.stringify({ input: [], stream: true }),
+        headers: {
+          Authorization: 'Bearer tok-xyz',
+          'content-type': 'application/json',
+          'x-mlflow-return-trace-id': 'true',
+          'x-some-other-header': 'leakage',
+        },
+      },
+    );
+    if (response.body) await readSseFrames(response.body);
+
+    expect(resumeHeadersCaptured).toBe(true);
+    expect(resumeHeaders.get('authorization')).toBe('Bearer tok-xyz');
+    expect(resumeHeaders.get('content-type')).toBeNull();
+    expect(resumeHeaders.get('x-mlflow-return-trace-id')).toBeNull();
+    expect(resumeHeaders.get('x-some-other-header')).toBeNull();
+  });
+
+  test('does NOT wrap SSE responses when API_PROXY is unset', async () => {
+    process.env.API_PROXY = '';
+
+    let fetchCount = 0;
+    globalThis.fetch = (async () => {
+      fetchCount += 1;
+      // Return a stream that closes without [DONE]. If we were wrapping,
+      // this would trigger a resume fetch and bump the count.
+      return makeSseResponse([
+        sseChunk({
+          type: 'response.created',
+          sequence_number: 0,
+          response_id: 'resp_no_wrap',
+        }),
+      ]);
+    }) as typeof fetch;
+
+    const response = await databricksFetch(
+      'http://example.com/serving-endpoints/x/invocations',
+      {
+        method: 'POST',
+        body: JSON.stringify({ input: [], stream: true }),
+      },
+    );
+    if (response.body) await readSseFrames(response.body);
+
+    expect(fetchCount).toBe(1);
+  });
+});
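Finally, a Python restatement of the body mutation that the background-injection tests above pin down. Illustration only: the real logic lives in databricksFetch, and the context injection there is additionally gated on shouldInjectContext(), elided here.

import json


def mutate_body(raw, api_proxy=None, conversation_id=None, user_id=None):
    body = json.loads(raw)
    if conversation_id and user_id:  # shouldInjectContext() gating elided
        body["context"] = {
            **body.get("context", {}),
            "conversation_id": conversation_id,
            "user_id": user_id,
        }
    if api_proxy and body.get("stream") is True and body.get("background") is not True:
        body["background"] = True
    return json.dumps(body)


out = json.loads(
    mutate_body(
        '{"input": [], "stream": true}',
        api_proxy="http://localhost:8000/invocations",
        conversation_id="chat-1",
        user_id="u-1",
    )
)
assert out["background"] is True
assert out["context"] == {"conversation_id": "chat-1", "user_id": "u-1"}

# Non-streaming requests and non-proxy targets pass through untouched.
out = json.loads(mutate_body('{"input": [], "stream": false}'))
assert "background" not in out and "context" not in out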