Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 305 additions & 1 deletion packages/test/src/test/ai-provider/WebBrowserProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const {
WebBrowser_ToolCalling,
sessions,
chatHistory,
chromeHelpers,
probe,
} = _testOnly;

Expand Down Expand Up @@ -756,6 +757,136 @@ describe("WebBrowser_StructuredGeneration validation", () => {
});
});

// --------------------------------------------------------------------------
// WebBrowser_Chat session cache (HIGH-1)
// --------------------------------------------------------------------------

/**
* Fake `LanguageModel` for chat tests. Each `create()` returns a fresh
* session whose `promptStreaming` emits one canned text snapshot per turn.
* The factory records each call's options so we can inspect what was
* passed.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function makeFakeChatModel(repliesPerTurn: readonly string[]): any {
let turn = 0;
const sessions: Array<{ destroy: ReturnType<typeof vi.fn>; promptStreaming: ReturnType<typeof vi.fn> }> =
[];
const factory = {
availability: vi.fn().mockResolvedValue("available"),
create: vi.fn(async () => {
const promptStreaming = vi.fn(() => {
const value = repliesPerTurn[turn++] ?? "";
return new ReadableStream<string>({
start(controller) {
controller.enqueue(value);
controller.close();
},
});
});
const session = { destroy: vi.fn(), promptStreaming };
sessions.push(session);
return session;
}),
};
return { factory, sessions };
}

describe("WebBrowser_Chat session cache", () => {
const sid = "chat-test-1";
const userMsg = (text: string): ChatMessage => ({
role: "user",
content: [{ type: "text", text }],
});
const assistantMsg = (text: string): ChatMessage => ({
role: "assistant",
content: [{ type: "text", text }],
});

afterEach(() => {
sessions.deleteChromeSession?.(sid);
});

it("reuses the cached session across consecutive turns (one factory.create)", async () => {
const { factory, sessions: fakeSessions } = makeFakeChatModel(["hi back", "sure"]);
const restore = installLanguageModelGlobal(factory);
try {
const emit = vi.fn();
const turn1: ChatMessage[] = [userMsg("hi")];
await WebBrowser_TextGeneration_Unified(
{ messages: turn1 },
undefined,
new AbortController().signal,
emit,
undefined,
sid
);
// After turn 1 cache should be at messages.length + 1 == 2.
expect(_testOnly.sessions.getChromeSession(sid)?.messageCount).toBe(2);
const turn2: ChatMessage[] = [
userMsg("hi"),
assistantMsg("hi back"),
userMsg("how are you?"),
];
await WebBrowser_TextGeneration_Unified(
{ messages: turn2 },
undefined,
new AbortController().signal,
emit,
undefined,
sid
);
// Same session reused: only one factory.create call total.
expect(factory.create).toHaveBeenCalledTimes(1);
// promptStreaming called twice on the SAME session reference.
expect(fakeSessions).toHaveLength(1);
expect(fakeSessions[0]?.promptStreaming).toHaveBeenCalledTimes(2);
// After turn 2, cache watermark = messages.length + 1 = 4.
expect(_testOnly.sessions.getChromeSession(sid)?.messageCount).toBe(4);
} finally {
restore();
}
});

it("rebuilds the session when messageCount diverges (e.g. retroactive edit)", async () => {
const { factory, sessions: fakeSessions } = makeFakeChatModel(["a", "b"]);
const restore = installLanguageModelGlobal(factory);
try {
const emit = vi.fn();
await WebBrowser_TextGeneration_Unified(
{ messages: [userMsg("first")] },
undefined,
new AbortController().signal,
emit,
undefined,
sid
);
// Cache is at messageCount=2 after turn 1.
expect(_testOnly.sessions.getChromeSession(sid)?.messageCount).toBe(2);
// Now simulate a retroactive history mutation by shrinking the history:
// the caller resends a single user message (messages.length=1, so
// lastUserIdx=0 and expectedPriorCount=1), but the cache still has
// messageCount=2 from the previous turn. The mismatch (1 !== 2) forces
// the run-fn to destroy the cached session and rebuild from scratch.
await WebBrowser_TextGeneration_Unified(
{ messages: [userMsg("reset")] },
undefined,
new AbortController().signal,
emit,
undefined,
sid
);
expect(factory.create).toHaveBeenCalledTimes(2);
// First session was destroyed during the divergence rebuild.
expect(fakeSessions[0]?.destroy).toHaveBeenCalled();
// Watermark after the rebuilt turn = messages.length + 1 = 2.
expect(_testOnly.sessions.getChromeSession(sid)?.messageCount).toBe(2);
} finally {
restore();
}
});
});

// --------------------------------------------------------------------------
// ToolCalling session lifecycle
// --------------------------------------------------------------------------
Expand Down Expand Up @@ -843,7 +974,7 @@ describe("WebBrowser_ToolCalling session lifecycle", () => {
undefined,
sid
);
// Same tool set, same conversation thread → cache reuse, one create().
// Tool-calling intentionally rebuilds per turn — two creates expected.
expect(factory.create).toHaveBeenCalledTimes(2);
expect(sessions.getChromeSession(sid)).toBeUndefined();
} finally {
Expand Down Expand Up @@ -971,3 +1102,176 @@ describe("WebBrowser_ToolCalling argument validation", () => {
}
});
});

// --------------------------------------------------------------------------
// ToolCalling prototype-pollution sanitization (HIGH-2)
// --------------------------------------------------------------------------

describe("WebBrowser_ToolCalling sanitizes captured args", () => {
const looseTool: ToolDefinition = {
name: "loose",
description: "loose",
// Permissive schema so the validator doesn't reject the cleaned object.
inputSchema: { type: "object", additionalProperties: true },
};

it("strips __proto__ and constructor keys from captured tool args", async () => {
// Build a payload as if the model hallucinated a prototype-pollution attempt.
const polluted: Record<string, unknown> = { ok: true };
// Use Object.defineProperty so `__proto__` is captured as a real own key,
// not as the actual prototype link — mirrors what JSON.parse can do.
Object.defineProperty(polluted, "__proto__", {
value: { polluted: true },
enumerable: true,
configurable: true,
writable: true,
});
Object.defineProperty(polluted, "constructor", {
value: { evil: 1 },
enumerable: true,
configurable: true,
writable: true,
});
const { factory } = makeFakeToolCallingModel({ loose: polluted });
const restore = installLanguageModelGlobal(factory);
try {
const events: Array<{ type: string; port?: string; objectDelta?: unknown }> = [];
const emit = (e: unknown): void => {
events.push(e as { type: string; port?: string; objectDelta?: unknown });
};
await WebBrowser_ToolCalling(
asTCI({ prompt: "go", tools: [looseTool] }),
undefined,
new AbortController().signal,
emit
);
const tcEvent = events.find((e) => e.type === "object-delta" && e.port === "toolCalls");
const calls = (tcEvent?.objectDelta as Array<{ name: string; input: Record<string, unknown> }>) ?? [];
expect(calls).toHaveLength(1);
const input = calls[0]!.input;
// Legitimate key preserved.
expect(input.ok).toBe(true);
// Forbidden keys scrubbed.
expect(Object.prototype.hasOwnProperty.call(input, "__proto__")).toBe(false);
expect(Object.prototype.hasOwnProperty.call(input, "constructor")).toBe(false);
// Prototype is plain Object.prototype — not the tainted attacker object.
expect(Object.getPrototypeOf(input)).toBe(Object.prototype);
// And the actual Object prototype was not polluted as a side-effect.
expect(({} as Record<string, unknown>).polluted).toBeUndefined();
} finally {
restore();
}
});

it("strips forbidden keys recursively in nested objects and arrays", async () => {
const inner: Record<string, unknown> = { ok: true };
Object.defineProperty(inner, "constructor", {
value: { x: 1 },
enumerable: true,
configurable: true,
writable: true,
});
const outer: Record<string, unknown> = {
list: [inner],
};
Object.defineProperty(outer, "__proto__", {
value: { p: 1 },
enumerable: true,
configurable: true,
writable: true,
});
const payload: Record<string, unknown> = { outer };
const { factory } = makeFakeToolCallingModel({ loose: payload });
const restore = installLanguageModelGlobal(factory);
try {
const events: Array<{ type: string; port?: string; objectDelta?: unknown }> = [];
const emit = (e: unknown): void => {
events.push(e as { type: string; port?: string; objectDelta?: unknown });
};
await WebBrowser_ToolCalling(
asTCI({ prompt: "go", tools: [looseTool] }),
undefined,
new AbortController().signal,
emit
);
const tcEvent = events.find((e) => e.type === "object-delta" && e.port === "toolCalls");
const calls = (tcEvent?.objectDelta as Array<{ input: Record<string, unknown> }>) ?? [];
expect(calls).toHaveLength(1);
const input = calls[0]!.input;
const o = input.outer as Record<string, unknown>;
expect(Object.prototype.hasOwnProperty.call(o, "__proto__")).toBe(false);
expect(Object.getPrototypeOf(o)).toBe(Object.prototype);
const list = o.list as Array<Record<string, unknown>>;
expect(Array.isArray(list)).toBe(true);
expect(list).toHaveLength(1);
const first = list[0]!;
expect(first.ok).toBe(true);
expect(Object.prototype.hasOwnProperty.call(first, "constructor")).toBe(false);
expect(Object.getPrototypeOf(first)).toBe(Object.prototype);
} finally {
restore();
}
});
});

// --------------------------------------------------------------------------
// snapshotStreamToTextDeltas reset semantics (HIGH-3)
// --------------------------------------------------------------------------

/** Drain an async iterable of stream events into an array. */
async function drain<T>(it: AsyncIterable<T>): Promise<T[]> {
const out: T[] = [];
for await (const e of it) out.push(e);
return out;
}

/** Build a ReadableStream that emits the given strings in order. */
function streamOf(values: readonly string[]): ReadableStream<string> {
return new ReadableStream<string>({
start(controller) {
for (const v of values) controller.enqueue(v);
controller.close();
},
});
}

describe("snapshotStreamToTextDeltas", () => {
it("emits incremental deltas on prefix-extending snapshots", async () => {
const events = await drain(
chromeHelpers.snapshotStreamToTextDeltas(
streamOf(["hel", "hello", "hello world"]),
"text"
)
);
const deltas = events
.filter((e) => (e as { type: string }).type === "text-delta")
.map((e) => (e as { textDelta: string }).textDelta);
expect(deltas).toEqual(["hel", "lo", " world"]);
});

it("resets on a non-prefix snapshot", async () => {
const events = await drain(
chromeHelpers.snapshotStreamToTextDeltas(
streamOf(["hello world", "hello sailor"]),
"text"
)
);
const deltas = events
.filter((e) => (e as { type: string }).type === "text-delta")
.map((e) => (e as { textDelta: string }).textDelta);
expect(deltas).toEqual(["hello world", "hello sailor"]);
// No buggy concatenation anywhere in the emitted stream.
for (const d of deltas) {
expect(d).not.toContain("hello worldhello sailor");
}
});

it("does not emit an empty delta on identical snapshots", async () => {
const events = await drain(
chromeHelpers.snapshotStreamToTextDeltas(streamOf(["hi", "hi"]), "text")
);
const deltas = events.filter((e) => (e as { type: string }).type === "text-delta");
expect(deltas).toHaveLength(1);
expect((deltas[0] as { textDelta: string }).textDelta).toBe("hi");
});
});
34 changes: 18 additions & 16 deletions providers/chrome-ai/src/ai/common/WebBrowser_Chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,15 @@ export const WebBrowser_Chat: AiProviderRunFn<
throw new Error("WebBrowser_Chat: trailing user message has no text content");
}

// History the session should already have heard by the time we prompt.
// After this turn the session will additionally contain the trailing user
// turn + the assistant response we generate — i.e. `messages.length + 1`
// messages, which is the watermark we cache for the next call.
const priorHistory = messages.slice(0, lastUserIdx);
const { initialPrompts, fingerprint: historyFingerprint } =
buildInitialPromptsFromHistory(priorHistory);

// Cache hygiene: only reuse the cached session if its watermark exactly
// matches the history we'd otherwise re-feed. Out-of-sync caches (task
// reset mid-conversation, retroactive edits to `messages`) are torn down
// and rebuilt.
// Cache reuse requires: same sessionId, AND the cache's high-water mark
// equals the number of messages we expect Chrome to have heard BEFORE
// this turn (everything up to but not including the trailing user
// message). This is robust against retroactive edits to `messages` and
// against task resets that re-run from a smaller history.
let cached = sessionId ? getChromeSession(sessionId) : undefined;
if (sessionId !== undefined && cached && cached.historyFingerprint !== historyFingerprint) {
const expectedPriorCount = lastUserIdx;
if (sessionId !== undefined && cached && cached.messageCount !== expectedPriorCount) {
// History diverged — tear down the stale session and rebuild.
deleteChromeSession(sessionId);
cached = undefined;
}
Expand All @@ -76,6 +71,11 @@ export const WebBrowser_Chat: AiProviderRunFn<
if (cached) {
session = cached.session;
} else {
// Fresh session: replay all prior history via initialPrompts so the
// model has full context for the trailing user turn.
const { initialPrompts } = buildInitialPromptsFromHistory(
messages.slice(0, lastUserIdx)
);
session = await factory.create({
signal,
// `temperature` is `@deprecated` for non-extension contexts in the
Expand All @@ -89,6 +89,9 @@ export const WebBrowser_Chat: AiProviderRunFn<

let cacheWritten = false;
try {
// `promptStreaming` both runs the turn AND mutates the session's
// internal history so the next call's "prior count" is
// `messages.length + 1`.
const stream = session.promptStreaming(promptText, { signal });
for await (const e of snapshotStreamToTextDeltas<AiChatProviderOutput>(stream, "text")) {
emit(e);
Expand All @@ -99,9 +102,8 @@ export const WebBrowser_Chat: AiProviderRunFn<
// to the cache; `WebBrowserProvider.disposeSession` (wired into
// ResourceScope by AiChatTask) reclaims it at end of run.
setChromeSession(sessionId, {
session,
messageCount: messages.length + 1,
historyFingerprint,
session,
messageCount: messages.length + 1,
});
cacheWritten = true;
}
Expand Down
Loading
Loading