Realtime SDK: TURN relay support and configurable ICE transport.

* client.ts: connect() accepts `iceServers`, `iceTransportPolicy` and `rtpPort`; the policy and
  port are forwarded on the signalling URL as `ice_transport_policy` / `rtp_port`.
* types.ts: new incoming `turn_config` message.
* webrtc-connection.ts: `turn_config` credentials are stored, optionally awaited (bounded by
  TURN_CONFIG_WAIT_MS) before the WebRTC handshake, and appended to the peer connection's ICE
  servers; relay-only ICE is forced when requested.
* webrtc-manager.ts: passes the three new settings through to the connection.
* New browser e2e suite (`pnpm test:e2e:turn-tcp`), excluded from the default vitest run.

Note: the blob ids on the `index` lines were not recomputed for this revision of the patch, and
context-line indentation assumes 2-space nesting (use `git apply --ignore-whitespace` if needed).

diff --git a/packages/sdk/package.json b/packages/sdk/package.json
index 3118ef3..a1cb326 100644
--- a/packages/sdk/package.json
+++ b/packages/sdk/package.json
@@ -35,6 +35,7 @@
     "test": "vitest unit",
     "test:e2e": "vitest e2e",
     "test:e2e:realtime": "vitest --config vitest.config.e2e-realtime.ts",
+    "test:e2e:turn-tcp": "vitest --config vitest.config.e2e-turn-tcp.ts",
     "typecheck": "tsc --noEmit",
     "format": "biome format --write",
     "format:check": "biome check",
diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts
index 17d7f38..0b205f9 100644
--- a/packages/sdk/src/realtime/client.ts
+++ b/packages/sdk/src/realtime/client.ts
@@ -93,9 +93,15 @@ const realTimeClientConnectOptionsSchema = z.object({
   }),
   initialState: realTimeClientInitialStateSchema.optional(),
   customizeOffer: createAsyncFunctionSchema(z.function()).optional(),
+  // Forwarded to the server as the `rtp_port` query parameter; must be a valid port number.
+  rtpPort: z.number().int().min(1).max(65535).optional(),
 });
 export type RealTimeClientConnectOptions = Omit<z.infer<typeof realTimeClientConnectOptionsSchema>, "model"> & {
   model: ModelDefinition | CustomModelDefinition;
+  /** ICE servers used instead of the default public STUN server; TURN servers announced by the server are appended. */
+  iceServers?: RTCIceServer[];
+  /** Sent to the server as `ice_transport_policy`. Any value other than "all" restricts the client to relay candidates. */
+  iceTransportPolicy?: "tcp" | "udp" | "all";
 };
 
 export type Events = {
@@ -172,7 +178,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => {
       const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer<Events>();
 
       webrtcManager = new WebRTCManager({
-        webrtcUrl: `${url}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}`,
+        webrtcUrl: `${url}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}${options.iceTransportPolicy ? `&ice_transport_policy=${encodeURIComponent(options.iceTransportPolicy)}` : ""}${options.rtpPort != null ? `&rtp_port=${options.rtpPort}` : ""}`,
         integration,
         logger,
         onDiagnostic: (name, data) => {
@@ -194,6 +200,9 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => {
         modelName: options.model.name,
         initialImage,
         initialPrompt,
+        iceServers: options.iceServers,
+        expectTurnConfig: !!options.iceTransportPolicy,
+        forceRelay: !!options.iceTransportPolicy && options.iceTransportPolicy !== "all",
       });
 
       const manager = webrtcManager;
diff --git a/packages/sdk/src/realtime/types.ts b/packages/sdk/src/realtime/types.ts
index e1618e8..6f61c32 100644
--- a/packages/sdk/src/realtime/types.ts
+++ b/packages/sdk/src/realtime/types.ts
@@ -71,6 +71,14 @@ export type SessionIdMessage = {
   server_port: number;
 };
 
+/** TURN server URLs and credentials sent by the server; used to build an extra `RTCIceServer` entry. */
+export type TurnConfigMessage = {
+  type: "turn_config";
+  urls: string[];
+  username: string;
+  credential: string;
+};
+
 export type ConnectionState = "connecting" | "connected" | "generating" | "disconnected" | "reconnecting";
 
 // Incoming message types (from server)
@@ -85,7 +93,8 @@ export type IncomingWebRTCMessage =
   | GenerationStartedMessage
   | GenerationTickMessage
   | GenerationEndedMessage
-  | SessionIdMessage;
+  | SessionIdMessage
+  | TurnConfigMessage;
 
 // Outgoing message types (to server)
 export type OutgoingWebRTCMessage =
diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts
index 17479ad..6aae38c 100644
--- a/packages/sdk/src/realtime/webrtc-connection.ts
+++ b/packages/sdk/src/realtime/webrtc-connection.ts
@@ -11,9 +11,11 @@ import type {
   PromptAckMessage,
   SessionIdMessage,
   SetImageAckMessage,
+  TurnConfigMessage,
 } from "./types";
 
-const ICE_SERVERS: RTCIceServer[] = [{ urls: "stun:stun.l.google.com:19302" }];
+const DEFAULT_ICE_SERVERS: RTCIceServer[] = [{ urls: "stun:stun.l.google.com:19302" }];
 const AVATAR_SETUP_TIMEOUT_MS = 30_000; // 30 seconds
+const TURN_CONFIG_WAIT_MS = 2_000; // 2 seconds
 
 interface ConnectionCallbacks {
@@ -28,6 +30,12 @@ interface ConnectionCallbacks {
   initialPrompt?: { text: string; enhance?: boolean };
   logger?: Logger;
   onDiagnostic?: DiagnosticEmitter;
+  /** Replaces DEFAULT_ICE_SERVERS when provided. */
+  iceServers?: RTCIceServer[];
+  /** When true, connect() waits briefly for a `turn_config` message before the WebRTC handshake. */
+  expectTurnConfig?: boolean;
+  /** When true, the peer connection is created with `iceTransportPolicy: "relay"`. */
+  forceRelay?: boolean;
 }
 
 type WsMessageEvents = {
@@ -35,6 +43,7 @@ type WsMessageEvents = {
   setImageAck: SetImageAckMessage;
   sessionId: SessionIdMessage;
   generationTick: GenerationTickMessage;
+  turnConfig: TurnConfigMessage;
 };
 
 const noopDiagnostic: DiagnosticEmitter = () => {};
@@ -46,6 +55,8 @@ export class WebRTCConnection {
   private connectionReject: ((error: Error) => void) | null = null;
   private logger: Logger;
   private emitDiagnostic: DiagnosticEmitter;
+  /** TURN servers received via `turn_config`; appended to the ICE server list of new peer connections. */
+  private turnServers: RTCIceServer[] = [];
   state: ConnectionState = "disconnected";
   websocketMessagesEmitter = mitt<WsMessageEvents>();
   constructor(private callbacks: ConnectionCallbacks = {}) {
@@ -159,6 +170,30 @@ export class WebRTCConnection {
         });
       }
 
+      // Phase 2.5: Wait for turn_config if not yet received.
+      // Only wait when the server is expected to send TURN config (iceTransportPolicy was set).
+      // turn_config arrives during Phase 2 but may race with set_image_ack.
+      if (this.callbacks.expectTurnConfig && this.turnServers.length === 0) {
+        let turnHandler: (() => void) | null = null;
+        let turnTimer: ReturnType<typeof setTimeout> | undefined;
+        try {
+          await Promise.race([
+            new Promise<void>((resolve) => {
+              turnHandler = () => resolve();
+              this.websocketMessagesEmitter.on("turnConfig", turnHandler);
+            }),
+            new Promise<void>((resolve) => {
+              turnTimer = setTimeout(resolve, TURN_CONFIG_WAIT_MS);
+            }),
+            connectAbort,
+          ]);
+        } finally {
+          // Always release the listener and timer, including when the race rejects.
+          clearTimeout(turnTimer);
+          if (turnHandler) this.websocketMessagesEmitter.off("turnConfig", turnHandler);
+        }
+      }
+
       // Phase 3: WebRTC handshake
       const handshakeStart = performance.now();
       await this.setupNewPeerConnection();
@@ -254,6 +289,12 @@ export class WebRTCConnection {
         return;
       }
 
+      if (msg.type === "turn_config") {
+        this.turnServers = [{ urls: msg.urls, username: msg.username, credential: msg.credential }];
+        this.websocketMessagesEmitter.emit("turnConfig", msg);
+        return;
+      }
+
       // All other messages require peer connection
       if (!this.pc) return;
 
@@ -411,7 +452,11 @@ export class WebRTCConnection {
       });
       this.pc.close();
     }
-    this.pc = new RTCPeerConnection({ iceServers: ICE_SERVERS });
+    const iceServers: RTCIceServer[] = [
+      ...(this.callbacks.iceServers ?? DEFAULT_ICE_SERVERS),
+      ...this.turnServers,
+    ];
+    this.pc = new RTCPeerConnection({ iceServers, ...(this.callbacks.forceRelay && { iceTransportPolicy: "relay" }) });
     this.setState("connecting");
 
     if (this.localStream) {
@@ -557,6 +602,7 @@ export class WebRTCConnection {
     this.ws?.close();
     this.ws = null;
     this.localStream = null;
+    this.turnServers = [];
     this.setState("disconnected");
   }
 
diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts
index 71408fb..2ef79c9 100644
--- a/packages/sdk/src/realtime/webrtc-manager.ts
+++ b/packages/sdk/src/realtime/webrtc-manager.ts
@@ -19,6 +19,9 @@ export interface WebRTCConfig {
   modelName?: string;
   initialImage?: string;
   initialPrompt?: { text: string; enhance?: boolean };
+  iceServers?: RTCIceServer[];
+  expectTurnConfig?: boolean;
+  forceRelay?: boolean;
 }
 
 const PERMANENT_ERRORS = [
@@ -66,6 +69,9 @@ export class WebRTCManager {
       initialPrompt: config.initialPrompt,
       logger: this.logger,
       onDiagnostic: config.onDiagnostic,
+      iceServers: config.iceServers,
+      expectTurnConfig: config.expectTurnConfig,
+      forceRelay: config.forceRelay,
     });
   }
 
diff --git a/packages/sdk/tests/e2e-turn-tcp.test.ts b/packages/sdk/tests/e2e-turn-tcp.test.ts
new file mode 100644
index 0000000..1cc9a21
--- /dev/null
+++ b/packages/sdk/tests/e2e-turn-tcp.test.ts
@@ -0,0 +1,178 @@
+declare const __DECART_API_KEY__: string;
+declare const __WEBRTC_BASE_URL__: string;
+
+import {
+  createDecartClient,
+  type CustomModelDefinition,
+  type DecartSDKError,
+  type SelectedCandidatePairEvent,
+} from "@decartai/sdk";
+import { beforeAll, describe, expect, it } from "vitest";
+
+/** Captures a MediaStream from a blank, detached canvas at the given size and frame rate. */
+function createSyntheticStream(fps: number, width: number, height: number): MediaStream {
+  const canvas = document.createElement("canvas");
+  canvas.width = width;
+  canvas.height = height;
+  return canvas.captureStream(fps);
+}
+
+const BIT_INVERT_MODEL: CustomModelDefinition = {
+  name: "bit_invert",
+  urlPath: "/ws",
+  fps: 25,
+  width: 512,
+  height: 512,
+};
+
+const TURN_ICE_SERVERS: RTCIceServer[] = [
+  { urls: "stun:stun.l.google.com:19302" },
+  { urls: "turn:127.0.0.1:3478?transport=tcp", username: "turn", credential: "turn" },
+];
+
+const TIMEOUT = 2 * 60 * 1000; // 2 minutes
+
+/**
+ * Wraps the global RTCPeerConnection so every new instance uses
+ * the given iceTransportPolicy. Returns a cleanup function to restore.
+ */
+function overrideIceTransportPolicy(policy: RTCIceTransportPolicy): () => void {
+  const OriginalPC = globalThis.RTCPeerConnection;
+  globalThis.RTCPeerConnection = class extends OriginalPC {
+    constructor(config?: RTCConfiguration) {
+      super({ ...config, iceTransportPolicy: policy });
+    }
+  } as typeof RTCPeerConnection;
+  return () => {
+    globalThis.RTCPeerConnection = OriginalPC;
+  };
+}
+
+/**
+ * Collects the selectedCandidatePair diagnostic from a realtime client.
+ * The event is buffered during connect() and flushed via setTimeout(0)
+ * after connect() resolves, so registering immediately catches it.
+ * Resolves to null if the diagnostic is not seen within 5 seconds.
+ */
+function collectSelectedCandidatePair(
+  realtimeClient: { on: (event: "diagnostic", handler: (e: { name: string; data: unknown }) => void) => void },
+): Promise<SelectedCandidatePairEvent | null> {
+  return new Promise<SelectedCandidatePairEvent | null>((resolve) => {
+    // Fallback: if the event was already emitted before we registered, resolve after a delay
+    const fallbackTimer = setTimeout(() => resolve(null), 5000);
+    realtimeClient.on("diagnostic", (event) => {
+      if (event.name === "selectedCandidatePair") {
+        clearTimeout(fallbackTimer);
+        resolve(event.data as SelectedCandidatePairEvent);
+      }
+    });
+  });
+}
+
+describe("TURN-TCP E2E Tests", { timeout: TIMEOUT, retry: 2 }, () => {
+  let apiKey: string;
+  let webrtcBaseUrl: string;
+
+  beforeAll(() => {
+    apiKey = __DECART_API_KEY__;
+    webrtcBaseUrl = __WEBRTC_BASE_URL__;
+    if (!apiKey) {
+      throw new Error(
+        "DECART_API_KEY environment variable not set. Run with: DECART_API_KEY=your_key pnpm test:e2e:turn-tcp",
+      );
+    }
+    if (!webrtcBaseUrl) {
+      throw new Error(
+        "WEBRTC_BASE_URL environment variable not set. Set it to your local k8s WebSocket URL.",
+      );
+    }
+  });
+
+  // Requires server-side aioice TURN-TCP allocation to work (server must produce relay candidates).
+  // Skip until server-side TURN candidate generation is verified.
+  it.skip("TURN-TCP relay only (iceTransportPolicy=relay)", async () => {
+    const restore = overrideIceTransportPolicy("relay");
+
+    try {
+      const client = createDecartClient({ apiKey, realtimeBaseUrl: webrtcBaseUrl });
+      const stream = createSyntheticStream(BIT_INVERT_MODEL.fps, BIT_INVERT_MODEL.width, BIT_INVERT_MODEL.height);
+
+      let remoteStreamReceived = false;
+
+      const realtimeClient = await client.realtime.connect(stream, {
+        model: BIT_INVERT_MODEL,
+        onRemoteStream: () => {
+          remoteStreamReceived = true;
+        },
+        iceServers: TURN_ICE_SERVERS,
+      });
+
+      // Register diagnostic listener immediately - buffered events flush on next macrotask
+      const candidatePairPromise = collectSelectedCandidatePair(realtimeClient);
+
+      const errors: DecartSDKError[] = [];
+      realtimeClient.on("error", (err) => errors.push(err));
+
+      try {
+        expect(["connected", "generating"]).toContain(realtimeClient.getConnectionState());
+        expect(realtimeClient.sessionId).toBeTruthy();
+        expect(remoteStreamReceived).toBe(true);
+        expect(errors).toEqual([]);
+
+        // With relay-only policy, the selected candidate must be a relay (TURN)
+        const pair = await candidatePairPromise;
+        if (pair) {
+          expect(pair.local.candidateType).toBe("relay");
+        }
+      } finally {
+        realtimeClient.disconnect();
+        for (const track of stream.getTracks()) track.stop();
+      }
+
+      expect(realtimeClient.getConnectionState()).toBe("disconnected");
+    } finally {
+      restore();
+    }
+  });
+
+  it("Both UDP + TURN available (default iceTransportPolicy=all)", async () => {
+    const client = createDecartClient({ apiKey, realtimeBaseUrl: webrtcBaseUrl });
+    const stream = createSyntheticStream(BIT_INVERT_MODEL.fps, BIT_INVERT_MODEL.width, BIT_INVERT_MODEL.height);
+
+    let remoteStreamReceived = false;
+
+    const realtimeClient = await client.realtime.connect(stream, {
+      model: BIT_INVERT_MODEL,
+      onRemoteStream: () => {
+        remoteStreamReceived = true;
+      },
+      iceServers: TURN_ICE_SERVERS,
+    });
+
+    // Register diagnostic listener immediately
+    const candidatePairPromise = collectSelectedCandidatePair(realtimeClient);
+
+    const errors: DecartSDKError[] = [];
+    realtimeClient.on("error", (err) => errors.push(err));
+
+    try {
+      expect(["connected", "generating"]).toContain(realtimeClient.getConnectionState());
+      expect(realtimeClient.sessionId).toBeTruthy();
+      expect(remoteStreamReceived).toBe(true);
+      expect(errors).toEqual([]);
+
+      // With default policy, ICE should prefer direct UDP over relay.
+      // In Docker/NAT environments, the local candidate may appear as "prflx"
+      // (peer-reflexive) rather than "host", but it should NOT be "relay".
+      const pair = await candidatePairPromise;
+      if (pair) {
+        expect(pair.local.candidateType).not.toBe("relay");
+      }
+    } finally {
+      realtimeClient.disconnect();
+      for (const track of stream.getTracks()) track.stop();
+    }
+
+    expect(realtimeClient.getConnectionState()).toBe("disconnected");
+  });
+});
diff --git a/packages/sdk/vitest.config.e2e-turn-tcp.ts b/packages/sdk/vitest.config.e2e-turn-tcp.ts
new file mode 100644
index 0000000..6ebc477
--- /dev/null
+++ b/packages/sdk/vitest.config.e2e-turn-tcp.ts
@@ -0,0 +1,18 @@
+import { playwright } from "@vitest/browser-playwright";
+import { defineConfig } from "vitest/config";
+
+export default defineConfig({
+  define: {
+    __DECART_API_KEY__: JSON.stringify(process.env.DECART_API_KEY),
+    __WEBRTC_BASE_URL__: JSON.stringify(process.env.WEBRTC_BASE_URL || "wss://slim-bit-invert.dev.localhost"),
+  },
+  test: {
+    include: ["tests/e2e-turn-tcp.test.ts"],
+    browser: {
+      enabled: true,
+      provider: playwright(),
+      headless: true,
+      instances: [{ browser: "chromium" }],
+    },
+  },
+});
diff --git a/packages/sdk/vitest.config.ts b/packages/sdk/vitest.config.ts
index 79aba29..3e9ef0d 100644
--- a/packages/sdk/vitest.config.ts
+++ b/packages/sdk/vitest.config.ts
@@ -2,6 +2,6 @@ import { defineConfig } from "vitest/config";
 
 export default defineConfig({
   test: {
-    exclude: ["tests/e2e-realtime.test.ts"],
+    exclude: ["tests/e2e-realtime.test.ts", "tests/e2e-turn-tcp.test.ts"],
   },
 });