diff --git a/helm/spritz/templates/ui-deployment.yaml b/helm/spritz/templates/ui-deployment.yaml index a96c49d..3a4a6cc 100644 --- a/helm/spritz/templates/ui-deployment.yaml +++ b/helm/spritz/templates/ui-deployment.yaml @@ -32,6 +32,8 @@ spec: env: - name: SPRITZ_API_BASE_URL value: {{ .Values.ui.apiBaseUrl | default (include "spritz.routeModel.apiPathPrefix" .) | quote }} + - name: SPRITZ_UI_WEBSOCKET_BASE_URL + value: {{ .Values.ui.websocketBaseUrl | quote }} - name: SPRITZ_UI_CHAT_PATH_PREFIX value: {{ include "spritz.routeModel.chatPathPrefix" . | quote }} - name: SPRITZ_UI_OWNER_ID diff --git a/helm/spritz/values.yaml b/helm/spritz/values.yaml index 78752bf..e95ab8c 100644 --- a/helm/spritz/values.yaml +++ b/helm/spritz/values.yaml @@ -304,6 +304,7 @@ ui: ingress: enabled: true apiBaseUrl: "/api" + websocketBaseUrl: "" ownerId: "" assetVersion: "" presets: [] diff --git a/ui/entrypoint.sh b/ui/entrypoint.sh index bfdf98a..43922bb 100755 --- a/ui/entrypoint.sh +++ b/ui/entrypoint.sh @@ -2,6 +2,7 @@ set -eu API_BASE_URL="${SPRITZ_API_BASE_URL:-}" +WEBSOCKET_BASE_URL="${SPRITZ_UI_WEBSOCKET_BASE_URL:-}" CHAT_PATH_PREFIX="${SPRITZ_UI_CHAT_PATH_PREFIX:-}" OWNER_ID="${SPRITZ_UI_OWNER_ID:-}" AUTH_MODE="${SPRITZ_UI_AUTH_MODE:-}" @@ -46,6 +47,7 @@ escape_sed() { } API_BASE_URL_ESCAPED="$(escape_sed "$API_BASE_URL")" +WEBSOCKET_BASE_URL_ESCAPED="$(escape_sed "$WEBSOCKET_BASE_URL")" CHAT_PATH_PREFIX_ESCAPED="$(escape_sed "$CHAT_PATH_PREFIX")" OWNER_ID_ESCAPED="$(escape_sed "$OWNER_ID")" AUTH_MODE_ESCAPED="$(escape_sed "$AUTH_MODE")" @@ -76,6 +78,7 @@ BRANDING_CONFIG_ESCAPED="$(escape_sed "$BRANDING_CONFIG_VALUE")" ASSET_VERSION_ESCAPED="$(escape_sed "$ASSET_VERSION")" sed "s|__SPRITZ_API_BASE_URL__|${API_BASE_URL_ESCAPED}|g" "${HTML_DIR}/config.js" \ + | sed "s|__SPRITZ_UI_WEBSOCKET_BASE_URL__|${WEBSOCKET_BASE_URL_ESCAPED}|g" \ | sed "s|__SPRITZ_UI_CHAT_PATH_PREFIX__|${CHAT_PATH_PREFIX_ESCAPED}|g" \ | sed "s|__SPRITZ_OWNER_ID__|${OWNER_ID_ESCAPED}|g" \ | sed "s|__SPRITZ_UI_AUTH_MODE__|${AUTH_MODE_ESCAPED}|g" \ diff --git a/ui/public/config.js b/ui/public/config.js index c7a6c95..a06ac5d 100644 --- a/ui/public/config.js +++ b/ui/public/config.js @@ -1,5 +1,6 @@ window.SPRITZ_CONFIG = { apiBaseUrl: '__SPRITZ_API_BASE_URL__', + websocketBaseUrl: '__SPRITZ_UI_WEBSOCKET_BASE_URL__', chatPathPrefix: '__SPRITZ_UI_CHAT_PATH_PREFIX__', ownerId: '__SPRITZ_OWNER_ID__', presets: __SPRITZ_UI_PRESETS__, diff --git a/ui/src/lib/api.ts b/ui/src/lib/api.ts index 1c913a8..30088d3 100644 --- a/ui/src/lib/api.ts +++ b/ui/src/lib/api.ts @@ -102,6 +102,25 @@ export function getAuthToken(): string { return readTokenFromStorage(authTokenStorageKeys); } +/** + * Attempts the configured bearer refresh flow for direct WebSocket connections. + * Returns the latest stored bearer token and whether a refresh succeeded. + */ +export async function refreshAuthTokenForWebSocket(): Promise<{ + token: string; + refreshed: boolean; +}> { + const token = getAuthToken(); + if (!shouldAttemptAuthRefresh()) { + return { token, refreshed: false }; + } + const refreshResult = await runAuthRefresh(); + return { + token: getAuthToken(), + refreshed: refreshResult.ok, + }; +} + function getAuthRefreshToken(): string { return readTokenFromStorage(authRefreshTokenStorageKeys); } diff --git a/ui/src/lib/config.test.ts b/ui/src/lib/config.test.ts index 467d4c6..721672d 100644 --- a/ui/src/lib/config.test.ts +++ b/ui/src/lib/config.test.ts @@ -29,4 +29,14 @@ describe('resolveConfig', () => { expect(config.branding.theme.background).toBe(''); expect(config.branding.terminal.background).toBe(''); }); + + it('preserves websocket base overrides separately from the api base url', () => { + const config = resolveConfig({ + apiBaseUrl: 'https://api.example.com/base', + websocketBaseUrl: 'https://ws.example.com/base', + }); + + expect(config.apiBaseUrl).toBe('https://api.example.com/base'); + expect(config.websocketBaseUrl).toBe('https://ws.example.com/base'); + }); }); diff --git a/ui/src/lib/config.ts b/ui/src/lib/config.ts index 21a8a0a..27022c8 100644 --- a/ui/src/lib/config.ts +++ b/ui/src/lib/config.ts @@ -80,6 +80,7 @@ export interface Preset { export interface SpritzConfig { apiBaseUrl: string; + websocketBaseUrl: string; chatPathPrefix: string; ownerId: string; presets: Preset[] | string; @@ -100,6 +101,7 @@ declare global { export function resolveConfig(raw: RawSpritzConfig = {}): SpritzConfig { return { apiBaseUrl: raw.apiBaseUrl || '', + websocketBaseUrl: raw.websocketBaseUrl || '', chatPathPrefix: raw.chatPathPrefix || '/c', ownerId: raw.ownerId || '', presets: raw.presets || [], diff --git a/ui/src/lib/network.ts b/ui/src/lib/network.ts new file mode 100644 index 0000000..89bdd53 --- /dev/null +++ b/ui/src/lib/network.ts @@ -0,0 +1,81 @@ +const URL_PARSE_BASE = 'http://spritz.local'; + +function resolveLocationHref(locationHref?: string): string { + if (locationHref) return locationHref; + if (typeof window !== 'undefined' && window.location?.href) { + return window.location.href; + } + return `${URL_PARSE_BASE}/`; +} + +function normalizeApiBaseUrl(apiBaseUrl: string, locationHref?: string): URL { + const trimmed = String(apiBaseUrl || '').trim(); + const base = trimmed || '/'; + return new URL(base, resolveLocationHref(locationHref)); +} + +function normalizeWebSocketBaseUrl( + apiBaseUrl: string, + websocketBaseUrl?: string, + locationHref?: string, +): URL { + const location = new URL(resolveLocationHref(locationHref)); + const explicitBase = String(websocketBaseUrl || '').trim(); + if (explicitBase) { + return new URL(explicitBase, location.href); + } + const apiUrl = normalizeApiBaseUrl(apiBaseUrl, locationHref); + const sameHostUrl = new URL(location.origin); + sameHostUrl.pathname = apiUrl.pathname; + sameHostUrl.search = apiUrl.search; + sameHostUrl.hash = apiUrl.hash; + return sameHostUrl; +} + +function normalizeRelativePath(path: string): URL { + const trimmed = String(path || '').trim(); + const normalized = trimmed.startsWith('/') ? trimmed : `/${trimmed}`; + return new URL(normalized || '/', URL_PARSE_BASE); +} + +function joinPaths(basePath: string, relativePath: string): string { + const normalizedBase = `/${String(basePath || '').replace(/^\/+|\/+$/g, '')}`; + const normalizedRelative = String(relativePath || '').replace(/^\/+/, ''); + if (!normalizedRelative) return normalizedBase === '/' ? '/' : normalizedBase; + if (normalizedBase === '/') return `/${normalizedRelative}`; + return `${normalizedBase}/${normalizedRelative}`; +} + +export function buildApiWebSocketUrl( + apiBaseUrl: string, + path: string, + options?: { + bearerToken?: string; + bearerTokenParam?: string; + websocketBaseUrl?: string; + locationHref?: string; + }, +): string { + const url = normalizeWebSocketBaseUrl( + apiBaseUrl, + options?.websocketBaseUrl, + options?.locationHref, + ); + const relative = normalizeRelativePath(path); + url.pathname = joinPaths(url.pathname, relative.pathname); + url.search = relative.search; + url.hash = relative.hash; + if (url.protocol === 'https:') { + url.protocol = 'wss:'; + } else if (url.protocol === 'http:') { + url.protocol = 'ws:'; + } + const bearerToken = String(options?.bearerToken || '').trim(); + if (bearerToken) { + url.searchParams.set( + String(options?.bearerTokenParam || 'token').trim() || 'token', + bearerToken, + ); + } + return url.toString(); +} diff --git a/ui/src/pages/chat.test.tsx b/ui/src/pages/chat.test.tsx index 8918ae5..b9c4da4 100644 --- a/ui/src/pages/chat.test.tsx +++ b/ui/src/pages/chat.test.tsx @@ -1,18 +1,38 @@ import type React from 'react'; import { describe, it, expect, beforeEach, vi } from 'vite-plus/test'; -import { render, screen, waitFor } from '@testing-library/react'; +import { act, render, screen, waitFor } from '@testing-library/react'; import userEvent from '@testing-library/user-event'; import { MemoryRouter, Route, Routes } from 'react-router-dom'; import { createMockStorage } from '@/test/helpers'; -import { ConfigProvider, config } from '@/lib/config'; +import { ConfigProvider, config, resolveConfig, type RawSpritzConfig } from '@/lib/config'; import { NoticeProvider } from '@/components/notice-banner'; import { ChatPage } from './chat'; -const { requestMock, sendPromptMock, emitUpdate, emitReplayState, setUpdateHandler, setReplayStateHandler } = vi.hoisted(() => { +const { + requestMock, + sendPromptMock, + emitUpdate, + emitReplayState, + setUpdateHandler, + setReplayStateHandler, + getAuthTokenMock, + setAuthToken, + refreshAuthTokenForWebSocketMock, + setRefreshAuthResult, + setACPStartReady, + getACPStartReady, + captureACPOptions, + getLastACPOptions, + resetACPMockState, +} = vi.hoisted(() => { let updateHandler: | ((update: Record, options?: { historical?: boolean }) => void) | undefined; let replayStateHandler: ((replaying: boolean) => void) | undefined; + let authToken = ''; + let refreshResult = { token: '', refreshed: false }; + let acpStartReady = true; + let lastACPOptions: Record | null = null; return { requestMock: vi.fn(), sendPromptMock: vi.fn(), @@ -30,11 +50,38 @@ const { requestMock, sendPromptMock, emitUpdate, emitReplayState, setUpdateHandl setReplayStateHandler: (handler?: (replaying: boolean) => void) => { replayStateHandler = handler; }, + getAuthTokenMock: () => authToken, + setAuthToken: (value: string) => { + authToken = value; + }, + refreshAuthTokenForWebSocketMock: vi.fn(async () => refreshResult), + setRefreshAuthResult: (value: { token: string; refreshed: boolean }) => { + refreshResult = value; + }, + setACPStartReady: (value: boolean) => { + acpStartReady = value; + }, + captureACPOptions: (options: Record) => { + lastACPOptions = options; + }, + getLastACPOptions: () => lastACPOptions, + resetACPMockState: () => { + authToken = ''; + refreshResult = { token: '', refreshed: false }; + acpStartReady = true; + lastACPOptions = null; + updateHandler = undefined; + replayStateHandler = undefined; + }, + getACPStartReady: () => acpStartReady, }; }); vi.mock('@/lib/api', () => ({ request: requestMock, + getAuthToken: getAuthTokenMock, + refreshAuthTokenForWebSocket: refreshAuthTokenForWebSocketMock, + authBearerTokenParam: 'token', })); vi.mock('@/lib/acp-client', () => ({ @@ -49,20 +96,25 @@ vi.mock('@/lib/acp-client', () => ({ return ''; }, createACPClient: ({ + wsUrl, onReadyChange, onStatus, onUpdate, onReplayStateChange, + ...rest }: { + wsUrl: string; onReadyChange?: (ready: boolean) => void; onStatus?: (status: string) => void; onUpdate?: (update: Record, options?: { historical?: boolean }) => void; onReplayStateChange?: (replaying: boolean) => void; }) => { + captureACPOptions({ wsUrl, ...rest }); setUpdateHandler(onUpdate); setReplayStateHandler(onReplayStateChange); return { start: vi.fn(async () => { + if (!getACPStartReady()) return; onStatus?.('Connected'); onReadyChange?.(true); }), @@ -199,10 +251,11 @@ function createDeferred() { return { promise, resolve, reject }; } -async function renderChat(route: string) { +async function renderChat(route: string, rawConfig?: RawSpritzConfig) { + const resolvedConfig = rawConfig ? resolveConfig({ ...config, ...rawConfig }) : config; render( - + } /> @@ -227,12 +280,106 @@ describe('ChatPage draft persistence', () => { }); requestMock.mockReset(); sendPromptMock.mockReset(); - setUpdateHandler(undefined); - setReplayStateHandler(undefined); + refreshAuthTokenForWebSocketMock.mockClear(); + resetACPMockState(); sendPromptMock.mockResolvedValue({}); setupRequestMock(); }); + it('keeps ACP websocket connections on the current host by default', async () => { + setAuthToken('external-ui-token'); + + await renderChat('/c/covo/conv-1', { + apiBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + await waitFor(() => { + expect(getLastACPOptions()?.wsUrl).toBe( + 'ws://localhost:3000/api/acp/conversations/conv-1/connect?token=external-ui-token', + ); + }); + }); + + it('uses an explicit websocket base url for cross-host ACP websocket connections', async () => { + setAuthToken('external-ui-token'); + + await renderChat('/c/covo/conv-1', { + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + await waitFor(() => { + expect(getLastACPOptions()?.wsUrl).toBe( + 'wss://spritz.example.com/api/acp/conversations/conv-1/connect?token=external-ui-token', + ); + }); + }); + + it('refreshes bearer auth and reconnects ACP websocket when the socket closes before ready', async () => { + setACPStartReady(false); + setAuthToken('expired-token'); + setRefreshAuthResult({ token: 'refreshed-token', refreshed: true }); + refreshAuthTokenForWebSocketMock.mockImplementation(async () => { + setAuthToken('refreshed-token'); + return { token: 'refreshed-token', refreshed: true }; + }); + + const resolvedConfig = resolveConfig({ + ...config, + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + refresh: { + enabled: 'true', + url: '/oauth/refresh', + tokenStorageKeys: 'spritz-refresh-token', + }, + }, + }); + + render( + + + + + } /> + + + + , + ); + + await waitFor(() => { + expect(getLastACPOptions()?.wsUrl).toBe( + 'wss://spritz.example.com/api/acp/conversations/conv-1/connect?token=expired-token', + ); + }); + + const onClose = getLastACPOptions()?.onClose as (() => void) | undefined; + expect(onClose).toEqual(expect.any(Function)); + + act(() => { + onClose?.(); + }); + + await waitFor(() => { + expect(refreshAuthTokenForWebSocketMock).toHaveBeenCalledTimes(1); + expect(getLastACPOptions()?.wsUrl).toBe( + 'wss://spritz.example.com/api/acp/conversations/conv-1/connect?token=refreshed-token', + ); + }); + }); + it('restores the draft after remounting the same conversation route', async () => { const user = userEvent.setup(); const firstRender = render( diff --git a/ui/src/pages/chat.tsx b/ui/src/pages/chat.tsx index d4cc8d4..55edd30 100644 --- a/ui/src/pages/chat.tsx +++ b/ui/src/pages/chat.tsx @@ -2,10 +2,11 @@ import { useState, useEffect, useCallback, useRef } from 'react'; import { useParams, useNavigate } from 'react-router-dom'; import { toast } from 'sonner'; import { MenuIcon, RotateCwIcon, ExternalLinkIcon } from 'lucide-react'; -import { request } from '@/lib/api'; +import { request, getAuthToken, refreshAuthTokenForWebSocket, authBearerTokenParam } from '@/lib/api'; import { cn } from '@/lib/utils'; import { useConfig } from '@/lib/config'; import { createACPClient } from '@/lib/acp-client'; +import { buildApiWebSocketUrl } from '@/lib/network'; import { createTranscript, applySessionUpdate, finalizeStreaming, finalizeHistoricalThinking, getPreviewText, isTranscriptBearingUpdate } from '@/lib/acp-transcript'; import { readCachedTranscript, writeCachedTranscript, evictCachedTranscript } from '@/lib/acp-cache'; import { readChatDraft, writeChatDraft, clearChatDraft } from '@/lib/chat-draft'; @@ -156,6 +157,7 @@ export function ChatPage() { replaySawTranscriptUpdateRef.current = false; const apiBase = config.apiBaseUrl || ''; + const websocketBase = config.websocketBaseUrl || ''; function needsBootstrap(conv: ConversationInfo, force?: boolean): boolean { if (force) return true; @@ -164,14 +166,18 @@ export function ChatPage() { return String(conv.status?.bindingState || '').trim().toLowerCase() !== 'active'; } - async function connect(options: { forceBootstrap?: boolean } = {}) { + async function connect(options: { + forceBootstrap?: boolean; + allowAuthRefreshRetry?: boolean; + } = {}) { if (cancelled) return; + const { forceBootstrap = false, allowAuthRefreshRetry = true } = options; let effectiveConversation = selectedConversation!; let effectiveSessionId = String(effectiveConversation.spec?.sessionId || '').trim(); // Step 1: Bootstrap if needed - if (needsBootstrap(effectiveConversation, options.forceBootstrap)) { + if (needsBootstrap(effectiveConversation, forceBootstrap)) { setStatus('Bootstrapping…'); let bootstrapData: Record; try { @@ -211,17 +217,44 @@ export function ChatPage() { if (cancelled) return; // Step 2: Connect WebSocket - const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - const wsHost = window.location.host; - const wsUrl = `${wsProtocol}//${wsHost}${apiBase}/acp/conversations/${encodeURIComponent(conversationId)}/connect`; + const wsUrl = buildApiWebSocketUrl( + apiBase, + `/acp/conversations/${encodeURIComponent(conversationId)}/connect`, + { + bearerToken: getAuthToken(), + bearerTokenParam: authBearerTokenParam, + websocketBaseUrl: websocketBase, + }, + ); replaySawTranscriptUpdateRef.current = false; + let socketReady = false; + + function scheduleReconnect() { + if (cancelled) return; + setStatus('Disconnected. Reconnecting…'); + if (reconnectTimerRef.current) { + clearTimeout(reconnectTimerRef.current); + } + reconnectTimerRef.current = setTimeout(() => { + if (cancelled) return; + retryCount++; + connect({ forceBootstrap: retryCount > 1 }).catch((err) => { + if (!cancelled) setStatus(err instanceof Error ? err.message : 'Reconnect failed'); + }); + }, RECONNECT_DELAY_MS); + } const client = createACPClient({ wsUrl, conversation: effectiveConversation, onStatus: (text) => { if (!cancelled) setStatus(text); }, - onReadyChange: (ready) => { if (!cancelled) setClientReady(ready); }, + onReadyChange: (ready) => { + if (ready) { + socketReady = true; + } + if (!cancelled) setClientReady(ready); + }, onReplayStateChange: (replaying) => { if (cancelled) return; if (replaying) { @@ -288,15 +321,24 @@ export function ChatPage() { }, onClose: () => { if (cancelled) return; - setStatus('Disconnected. Reconnecting…'); - // Auto-reconnect after delay (matching staging behavior) - reconnectTimerRef.current = setTimeout(() => { - if (cancelled) return; - retryCount++; - connect({ forceBootstrap: retryCount > 1 }).catch((err) => { - if (!cancelled) setStatus(err instanceof Error ? err.message : 'Reconnect failed'); - }); - }, RECONNECT_DELAY_MS); + if (!socketReady && allowAuthRefreshRetry) { + void (async () => { + try { + const refreshed = await refreshAuthTokenForWebSocket(); + if (cancelled) return; + if (refreshed.refreshed && refreshed.token) { + clientRef.current = null; + await connect({ forceBootstrap, allowAuthRefreshRetry: false }); + return; + } + } catch { + // Fall through to the normal reconnect timer when refresh fails. + } + scheduleReconnect(); + })(); + return; + } + scheduleReconnect(); }, }); @@ -357,7 +399,7 @@ export function ChatPage() { setClientReady(false); setPromptInFlight(false); }; - }, [selectedConversation?.metadata?.name, config.apiBaseUrl]); + }, [selectedConversation?.metadata?.name, config.apiBaseUrl, config.websocketBaseUrl]); useEffect(() => { if (!selectedSpritzName || !selectedConversationId) { diff --git a/ui/src/pages/terminal.test.tsx b/ui/src/pages/terminal.test.tsx index edba7ee..cdde95b 100644 --- a/ui/src/pages/terminal.test.tsx +++ b/ui/src/pages/terminal.test.tsx @@ -1,14 +1,42 @@ import { describe, it, expect, beforeEach, vi } from 'vite-plus/test'; -import { render } from '@testing-library/react'; +import { act, render, waitFor } from '@testing-library/react'; import { MemoryRouter, Route, Routes } from 'react-router-dom'; import { ConfigProvider, resolveConfig } from '@/lib/config'; import { TerminalPage } from './terminal'; import { FakeWebSocket } from '@/test/helpers'; -const { terminalConstructor, fitAddonConstructor } = vi.hoisted(() => ({ - terminalConstructor: vi.fn(), - fitAddonConstructor: vi.fn(), -})); +const { + terminalConstructor, + fitAddonConstructor, + getAuthTokenMock, + setAuthToken, + emitTerminalData, + refreshAuthTokenForWebSocketMock, + setRefreshAuthResult, + setOnDataHandler, +} = vi.hoisted(() => { + let authToken = ''; + let refreshResult = { token: '', refreshed: false }; + let onDataHandler: ((data: string) => void) | null = null; + return { + terminalConstructor: vi.fn(), + fitAddonConstructor: vi.fn(), + getAuthTokenMock: () => authToken, + setAuthToken: (value: string) => { + authToken = value; + }, + emitTerminalData: (data: string) => { + onDataHandler?.(data); + }, + refreshAuthTokenForWebSocketMock: vi.fn(async () => refreshResult), + setRefreshAuthResult: (value: { token: string; refreshed: boolean }) => { + refreshResult = value; + }, + setOnDataHandler: (handler: ((data: string) => void) | null) => { + onDataHandler = handler; + }, + }; +}); vi.mock('@xterm/xterm', () => ({ Terminal: function MockTerminal(options: unknown) { @@ -17,7 +45,12 @@ vi.mock('@xterm/xterm', () => ({ loadAddon: vi.fn(), open: vi.fn(), write: vi.fn(), - onData: vi.fn(() => ({ dispose: vi.fn() })), + onData: vi.fn((handler: (data: string) => void) => { + setOnDataHandler(handler); + return { + dispose: vi.fn(), + }; + }), onBinary: vi.fn(() => ({ dispose: vi.fn() })), onResize: vi.fn(() => ({ dispose: vi.fn() })), dispose: vi.fn(), @@ -36,16 +69,44 @@ vi.mock('@xterm/addon-fit', () => ({ })); vi.mock('@/lib/api', () => ({ - getAuthToken: () => '', + getAuthToken: getAuthTokenMock, + refreshAuthTokenForWebSocket: refreshAuthTokenForWebSocketMock, authBearerTokenParam: 'token', })); describe('TerminalPage branding', () => { + let lastSocket: FakeWebSocket | null = null; + let sockets: FakeWebSocket[] = []; + let deferCloseEvents = false; + beforeEach(() => { terminalConstructor.mockReset(); fitAddonConstructor.mockReset(); + refreshAuthTokenForWebSocketMock.mockClear(); + setAuthToken(''); + setRefreshAuthResult({ token: '', refreshed: false }); + setOnDataHandler(null); + lastSocket = null; + sockets = []; + deferCloseEvents = false; Object.defineProperty(globalThis, 'WebSocket', { - value: FakeWebSocket, + value: class extends FakeWebSocket { + constructor(url: string) { + super(url); + sockets.push(this); + lastSocket = this; + } + + close() { + this.readyState = FakeWebSocket.CLOSED; + const fireClose = () => this.onclose?.(new CloseEvent('close')); + if (deferCloseEvents) { + queueMicrotask(fireClose); + return; + } + fireClose(); + } + }, writable: true, }); }); @@ -80,4 +141,168 @@ describe('TerminalPage branding', () => { })); expect(fitAddonConstructor).toHaveBeenCalled(); }); + + it('keeps terminal websocket connections on the current host by default', () => { + setAuthToken('external-ui-token'); + const config = resolveConfig({ + apiBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + render( + + + + } /> + + + , + ); + + expect(lastSocket?.url).toBe( + 'ws://localhost:3000/api/spritzes/example-instance/terminal?token=external-ui-token', + ); + }); + + it('uses an explicit websocket base url for cross-host terminal websocket connections', () => { + setAuthToken('external-ui-token'); + const config = resolveConfig({ + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + render( + + + + } /> + + + , + ); + + expect(lastSocket?.url).toBe( + 'wss://spritz.example.com/api/spritzes/example-instance/terminal?token=external-ui-token', + ); + }); + + it('refreshes bearer auth and reconnects when the initial terminal websocket closes before opening', async () => { + setAuthToken('expired-token'); + setRefreshAuthResult({ token: 'refreshed-token', refreshed: true }); + refreshAuthTokenForWebSocketMock.mockImplementation(async () => { + setAuthToken('refreshed-token'); + return { token: 'refreshed-token', refreshed: true }; + }); + + const config = resolveConfig({ + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://spritz.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + refresh: { + enabled: 'true', + url: '/oauth/refresh', + tokenStorageKeys: 'spritz-refresh-token', + }, + }, + }); + + render( + + + + } /> + + + , + ); + + expect(lastSocket?.url).toBe( + 'wss://spritz.example.com/api/spritzes/example-instance/terminal?token=expired-token', + ); + + act(() => { + lastSocket?.close(); + }); + + await waitFor(() => { + expect(refreshAuthTokenForWebSocketMock).toHaveBeenCalledTimes(1); + expect(lastSocket?.url).toBe( + 'wss://spritz.example.com/api/spritzes/example-instance/terminal?token=refreshed-token', + ); + }); + }); + + it('keeps the active terminal socket when an earlier socket closes late', async () => { + deferCloseEvents = true; + setAuthToken('external-ui-token'); + + const firstConfig = resolveConfig({ + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://first.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + const secondConfig = resolveConfig({ + apiBaseUrl: 'https://spritz.example.com/api', + websocketBaseUrl: 'https://second.example.com/api', + auth: { + mode: 'bearer', + tokenStorageKeys: 'spritz-token', + }, + }); + + const view = render( + + + + } /> + + + , + ); + + expect(sockets[0]?.url).toBe( + 'wss://first.example.com/api/spritzes/example-instance/terminal?token=external-ui-token', + ); + + view.rerender( + + + + } /> + + + , + ); + + expect(sockets[1]?.url).toBe( + 'wss://second.example.com/api/spritzes/example-instance/terminal?token=external-ui-token', + ); + + await act(async () => { + await Promise.resolve(); + }); + + act(() => { + sockets[1]?.simulateOpen(); + }); + + act(() => { + emitTerminalData('pwd\n'); + }); + + expect(sockets[1]?.sent).toContain('pwd\n'); + }); }); diff --git a/ui/src/pages/terminal.tsx b/ui/src/pages/terminal.tsx index bf48dc7..01aba24 100644 --- a/ui/src/pages/terminal.tsx +++ b/ui/src/pages/terminal.tsx @@ -4,8 +4,9 @@ import { Terminal } from '@xterm/xterm'; import { FitAddon } from '@xterm/addon-fit'; import '@xterm/xterm/css/xterm.css'; import { useConfig } from '@/lib/config'; -import { getAuthToken, authBearerTokenParam } from '@/lib/api'; +import { getAuthToken, refreshAuthTokenForWebSocket, authBearerTokenParam } from '@/lib/api'; import { buildTerminalTheme } from '@/lib/branding'; +import { buildApiWebSocketUrl } from '@/lib/network'; import { chatPath } from '@/lib/urls'; import { cn } from '@/lib/utils'; import { Button } from '@/components/ui/button'; @@ -13,17 +14,6 @@ import { ArrowLeftIcon } from 'lucide-react'; type ConnectionStatus = 'connecting' | 'connected' | 'disconnected' | 'error'; -function buildTerminalWsUrl(apiBaseUrl: string, name: string): string { - const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - const wsHost = window.location.host; - const apiBase = apiBaseUrl || ''; - const token = getAuthToken(); - const params = new URLSearchParams(); - if (token) params.set(authBearerTokenParam, token); - const qs = params.toString(); - return `${wsProtocol}//${wsHost}${apiBase}/spritzes/${encodeURIComponent(name)}/terminal${qs ? `?${qs}` : ''}`; -} - export function TerminalPage() { const { name } = useParams<{ name: string }>(); const config = useConfig(); @@ -37,6 +27,8 @@ export function TerminalPage() { useEffect(() => { if (!name || !terminalRef.current) return; + const instanceName = name; + let disposed = false; const term = new Terminal({ cursorBlink: true, @@ -55,13 +47,37 @@ export function TerminalPage() { xtermRef.current = term; fitAddonRef.current = fitAddon; - function connect() { + function scheduleReconnect() { + if (disposed) return; + if (reconnectTimerRef.current) clearTimeout(reconnectTimerRef.current); + term.write('\r\n\x1b[33m--- Connection closed. Reconnecting in 3s... ---\x1b[0m\r\n'); + reconnectTimerRef.current = setTimeout(() => { + void connect(); + }, 3000); + } + + async function connect(options: { allowAuthRefreshRetry?: boolean } = {}) { + if (disposed) return; + const { allowAuthRefreshRetry = true } = options; setStatus('connecting'); - const ws = new WebSocket(buildTerminalWsUrl(config.apiBaseUrl, name!)); + const bearerToken = getAuthToken(); + const ws = new WebSocket( + buildApiWebSocketUrl( + config.apiBaseUrl, + `/spritzes/${encodeURIComponent(instanceName)}/terminal`, + { + bearerToken, + bearerTokenParam: authBearerTokenParam, + websocketBaseUrl: config.websocketBaseUrl, + }, + ), + ); ws.binaryType = 'arraybuffer'; wsRef.current = ws; + let opened = false; ws.onopen = () => { + opened = true; setStatus('connected'); const dims = fitAddon.proposeDimensions(); const cols = dims?.cols ?? 80; @@ -78,9 +94,25 @@ export function TerminalPage() { }; ws.onclose = () => { + if (wsRef.current === ws) { + wsRef.current = null; + } + if (disposed) return; + if (!opened && allowAuthRefreshRetry) { + void (async () => { + const refreshed = await refreshAuthTokenForWebSocket(); + if (disposed) return; + if (refreshed.refreshed && refreshed.token) { + void connect({ allowAuthRefreshRetry: false }); + return; + } + setStatus('disconnected'); + scheduleReconnect(); + })(); + return; + } setStatus('disconnected'); - term.write('\r\n\x1b[33m--- Connection closed. Reconnecting in 3s... ---\x1b[0m\r\n'); - reconnectTimerRef.current = setTimeout(connect, 3000); + scheduleReconnect(); }; ws.onerror = () => { @@ -88,7 +120,7 @@ export function TerminalPage() { }; } - connect(); + void connect(); const inputDisposable = term.onData((data) => { if (wsRef.current?.readyState === WebSocket.OPEN) { @@ -114,6 +146,7 @@ export function TerminalPage() { window.addEventListener('resize', handleWindowResize); return () => { + disposed = true; inputDisposable.dispose(); binaryDisposable.dispose(); resizeDisposable.dispose(); @@ -125,7 +158,14 @@ export function TerminalPage() { fitAddonRef.current = null; wsRef.current = null; }; - }, [name, config.apiBaseUrl, terminalTheme.background, terminalTheme.cursor, terminalTheme.foreground]); + }, [ + name, + config.apiBaseUrl, + config.websocketBaseUrl, + terminalTheme.background, + terminalTheme.cursor, + terminalTheme.foreground, + ]); if (!name) { return (