diff --git a/config/gni/devtools_grd_files.gni b/config/gni/devtools_grd_files.gni index 7f68872504..c917cc5797 100644 --- a/config/gni/devtools_grd_files.gni +++ b/config/gni/devtools_grd_files.gni @@ -13,6 +13,8 @@ grd_files_bundled_sources = [ "front_end/Images/3d-center.svg", "front_end/Images/3d-pan.svg", "front_end/Images/3d-rotate.svg", + "front_end/Images/browser-operator-logo.png", + "front_end/Images/demo.gif", "front_end/Images/Images.js", "front_end/Images/accelerometer-back.svg", "front_end/Images/accelerometer-bottom.png", @@ -654,6 +656,26 @@ grd_files_bundled_sources = [ "front_end/panels/ai_chat/ui/HelpDialog.js", "front_end/panels/ai_chat/ui/PromptEditDialog.js", "front_end/panels/ai_chat/ui/SettingsDialog.js", + "front_end/panels/ai_chat/ui/settings/types.js", + "front_end/panels/ai_chat/ui/settings/constants.js", + "front_end/panels/ai_chat/ui/settings/i18n-strings.js", + "front_end/panels/ai_chat/ui/settings/providerConfigs.js", + "front_end/panels/ai_chat/ui/settings/utils/validation.js", + "front_end/panels/ai_chat/ui/settings/utils/storage.js", + "front_end/panels/ai_chat/ui/settings/utils/styles.js", + "front_end/panels/ai_chat/ui/settings/components/ModelSelectorFactory.js", + "front_end/panels/ai_chat/ui/settings/components/SettingsHeader.js", + "front_end/panels/ai_chat/ui/settings/components/SettingsFooter.js", + "front_end/panels/ai_chat/ui/settings/components/AdvancedToggle.js", + "front_end/panels/ai_chat/ui/settings/providers/BaseProviderSettings.js", + "front_end/panels/ai_chat/ui/settings/providers/GenericProviderSettings.js", + "front_end/panels/ai_chat/ui/settings/providers/LiteLLMSettings.js", + "front_end/panels/ai_chat/ui/settings/providers/OpenRouterSettings.js", + "front_end/panels/ai_chat/ui/settings/advanced/BrowsingHistorySettings.js", + "front_end/panels/ai_chat/ui/settings/advanced/EvaluationSettings.js", + "front_end/panels/ai_chat/ui/settings/advanced/MCPSettings.js", + "front_end/panels/ai_chat/ui/settings/advanced/TracingSettings.js", + "front_end/panels/ai_chat/ui/settings/advanced/VectorDBSettings.js", "front_end/panels/ai_chat/ui/mcp/MCPConnectionsDialog.js", "front_end/panels/ai_chat/ui/mcp/MCPConnectorsCatalogDialog.js", "front_end/panels/ai_chat/ui/EvaluationDialog.js", @@ -682,18 +704,37 @@ grd_files_bundled_sources = [ "front_end/panels/ai_chat/core/Version.js", "front_end/panels/ai_chat/core/VersionChecker.js", "front_end/panels/ai_chat/core/LLMConfigurationManager.js", + "front_end/panels/ai_chat/core/CustomProviderManager.js", + "front_end/panels/ai_chat/guardrails/index.js", + "front_end/panels/ai_chat/guardrails/types.js", + "front_end/panels/ai_chat/guardrails/policies.js", + "front_end/panels/ai_chat/guardrails/PolicyEvaluator.js", + "front_end/panels/ai_chat/guardrails/GuardrailMiddleware.js", "front_end/panels/ai_chat/LLM/LLMTypes.js", "front_end/panels/ai_chat/LLM/LLMProvider.js", "front_end/panels/ai_chat/LLM/LLMProviderRegistry.js", "front_end/panels/ai_chat/LLM/LLMErrorHandler.js", "front_end/panels/ai_chat/LLM/LLMResponseParser.js", + "front_end/panels/ai_chat/LLM/FuzzyModelMatcher.js", "front_end/panels/ai_chat/LLM/OpenAIProvider.js", "front_end/panels/ai_chat/LLM/LiteLLMProvider.js", "front_end/panels/ai_chat/LLM/GroqProvider.js", "front_end/panels/ai_chat/LLM/OpenRouterProvider.js", "front_end/panels/ai_chat/LLM/BrowserOperatorProvider.js", + "front_end/panels/ai_chat/LLM/AnthropicProvider.js", + "front_end/panels/ai_chat/LLM/CerebrasProvider.js", + "front_end/panels/ai_chat/LLM/GenericOpenAIProvider.js", + "front_end/panels/ai_chat/LLM/GoogleAIProvider.js", "front_end/panels/ai_chat/LLM/LLMClient.js", "front_end/panels/ai_chat/LLM/MessageSanitizer.js", + "front_end/panels/ai_chat/memory/types.js", + "front_end/panels/ai_chat/memory/MemoryModule.js", + "front_end/panels/ai_chat/memory/MemoryBlockManager.js", + "front_end/panels/ai_chat/memory/MemoryAgentConfig.js", + "front_end/panels/ai_chat/memory/index.js", + "front_end/panels/ai_chat/memory/SearchMemoryTool.js", + "front_end/panels/ai_chat/memory/UpdateMemoryTool.js", + "front_end/panels/ai_chat/memory/ListMemoryBlocksTool.js", "front_end/panels/ai_chat/tools/Tools.js", "front_end/panels/ai_chat/tools/SequentialThinkingTool.js", "front_end/panels/ai_chat/tools/CombinedExtractionTool.js", @@ -721,10 +762,13 @@ grd_files_bundled_sources = [ "front_end/panels/ai_chat/tools/ExecuteCodeTool.js", "front_end/panels/ai_chat/tools/UpdateTodoTool.js", "front_end/panels/ai_chat/tools/VisualIndicatorTool.js", + "front_end/panels/ai_chat/tools/ReadabilityExtractorTool.js", "front_end/panels/ai_chat/common/utils.js", "front_end/panels/ai_chat/common/log.js", "front_end/panels/ai_chat/common/context.js", "front_end/panels/ai_chat/common/page.js", + "front_end/panels/ai_chat/utils/ContentChunker.js", + "front_end/panels/ai_chat/vendor/readability-source.js", "front_end/panels/ai_chat/core/structured_response.js", "front_end/panels/ai_chat/models/ChatTypes.js", "front_end/panels/ai_chat/ui/input/ChatInput.js", @@ -738,6 +782,7 @@ grd_files_bundled_sources = [ "front_end/panels/ai_chat/ui/message/GlobalActionsRow.js", "front_end/panels/ai_chat/ui/message/ToolResultMessage.js", "front_end/panels/ai_chat/ui/message/UserMessage.js", + "front_end/panels/ai_chat/ui/message/ApprovalRequestMessage.js", "front_end/panels/ai_chat/ui/model_selector/ModelSelector.js", "front_end/panels/ai_chat/ui/oauth/OAuthConnectPanel.js", "front_end/panels/ai_chat/ui/version/VersionBanner.js", @@ -784,6 +829,7 @@ grd_files_bundled_sources = [ "front_end/panels/ai_chat/evaluation/test-cases/schema-extractor-tests.js", "front_end/panels/ai_chat/evaluation/test-cases/streamlined-schema-extractor-tests.js", "front_end/panels/ai_chat/evaluation/test-cases/web-task-agent-tests.js", + "front_end/panels/ai_chat/evaluation/test-cases/html-to-markdown-tests.js", "front_end/panels/ai_chat/evaluation/utils/ErrorHandlingUtils.js", "front_end/panels/ai_chat/evaluation/utils/EvaluationTypes.js", "front_end/panels/ai_chat/evaluation/utils/PromptTemplates.js", diff --git a/config/gni/devtools_image_files.gni b/config/gni/devtools_image_files.gni index 8dcb26105c..3b6d896ee6 100644 --- a/config/gni/devtools_image_files.gni +++ b/config/gni/devtools_image_files.gni @@ -20,6 +20,8 @@ devtools_image_files = [ "touchCursor.png", "gdp-logo-light.png", "gdp-logo-dark.png", + "browser-operator-logo.png", + "demo.gif", ] devtools_svg_sources = [ diff --git a/front_end/Images/browser-operator-logo.png b/front_end/Images/browser-operator-logo.png new file mode 100644 index 0000000000..a97b47b597 Binary files /dev/null and b/front_end/Images/browser-operator-logo.png differ diff --git a/front_end/Images/demo.gif b/front_end/Images/demo.gif new file mode 100644 index 0000000000..799e849e33 Binary files /dev/null and b/front_end/Images/demo.gif differ diff --git a/front_end/panels/ai_chat/BUILD.gn b/front_end/panels/ai_chat/BUILD.gn index d8deb34ac0..244916bccb 100644 --- a/front_end/panels/ai_chat/BUILD.gn +++ b/front_end/panels/ai_chat/BUILD.gn @@ -24,6 +24,7 @@ devtools_module("ai_chat") { "ui/message/ModelMessage.ts", "ui/message/ToolResultMessage.ts", "ui/message/MessageCombiner.ts", + "ui/message/ApprovalRequestMessage.ts", "ui/message/StructuredResponseRender.ts", "ui/message/StructuredResponseController.ts", "ui/message/GlobalActionsRow.ts", @@ -40,6 +41,8 @@ devtools_module("ai_chat") { "ui/ToolDescriptionFormatter.ts", "ui/HelpDialog.ts", "ui/SettingsDialog.ts", + "ui/OnboardingDialog.ts", + "ui/onboardingStyles.ts", "ui/settings/types.ts", "ui/settings/constants.ts", "ui/settings/i18n-strings.ts", @@ -60,6 +63,7 @@ devtools_module("ai_chat") { "ui/settings/advanced/VectorDBSettings.ts", "ui/settings/advanced/TracingSettings.ts", "ui/settings/advanced/EvaluationSettings.ts", + "ui/settings/advanced/MemorySettings.ts", "ui/PromptEditDialog.ts", "ui/EvaluationDialog.ts", "ui/WebAppCodeViewer.ts", @@ -75,6 +79,11 @@ devtools_module("ai_chat") { "persistence/ConversationTypes.ts", "persistence/ConversationStorageManager.ts", "persistence/ConversationManager.ts", + "memory/types.ts", + "memory/MemoryModule.ts", + "memory/MemoryBlockManager.ts", + "memory/MemoryAgentConfig.ts", + "memory/index.ts", "core/Graph.ts", "core/State.ts", "core/Types.ts", @@ -95,6 +104,11 @@ devtools_module("ai_chat") { "core/ToolSurfaceProvider.ts", "core/StateGraph.ts", "core/Logger.ts", + "guardrails/index.ts", + "guardrails/types.ts", + "guardrails/policies.ts", + "guardrails/PolicyEvaluator.ts", + "guardrails/GuardrailMiddleware.ts", "core/AgentErrorHandler.ts", "core/Version.ts", "core/VersionChecker.ts", @@ -103,6 +117,7 @@ devtools_module("ai_chat") { "LLM/LLMProviderRegistry.ts", "LLM/LLMErrorHandler.ts", "LLM/LLMResponseParser.ts", + "LLM/FuzzyModelMatcher.ts", "LLM/OpenAIProvider.ts", "LLM/LiteLLMProvider.ts", "LLM/GroqProvider.ts", @@ -135,6 +150,9 @@ devtools_module("ai_chat") { "tools/DeleteFileTool.ts", "tools/ReadFileTool.ts", "tools/ListFilesTool.ts", + "memory/SearchMemoryTool.ts", + "memory/UpdateMemoryTool.ts", + "memory/ListMemoryBlocksTool.ts", "tools/UpdateTodoTool.ts", "tools/ExecuteCodeTool.ts", "tools/SequentialThinkingTool.ts", @@ -232,6 +250,7 @@ _ai_chat_sources = [ "ui/message/ModelMessage.ts", "ui/message/ToolResultMessage.ts", "ui/message/MessageCombiner.ts", + "ui/message/ApprovalRequestMessage.ts", "ui/message/StructuredResponseRender.ts", "ui/message/StructuredResponseController.ts", "ui/message/GlobalActionsRow.ts", @@ -278,6 +297,11 @@ _ai_chat_sources = [ "ui/mcp/MCPConnectorsCatalogDialog.ts", "ai_chat_impl.ts", "models/ChatTypes.ts", + "memory/types.ts", + "memory/MemoryModule.ts", + "memory/MemoryBlockManager.ts", + "memory/MemoryAgentConfig.ts", + "memory/index.ts", "core/Graph.ts", "core/State.ts", "core/Types.ts", @@ -298,6 +322,11 @@ _ai_chat_sources = [ "core/ToolSurfaceProvider.ts", "core/StateGraph.ts", "core/Logger.ts", + "guardrails/index.ts", + "guardrails/types.ts", + "guardrails/policies.ts", + "guardrails/PolicyEvaluator.ts", + "guardrails/GuardrailMiddleware.ts", "core/AgentErrorHandler.ts", "core/Version.ts", "core/VersionChecker.ts", @@ -306,6 +335,7 @@ _ai_chat_sources = [ "LLM/LLMProviderRegistry.ts", "LLM/LLMErrorHandler.ts", "LLM/LLMResponseParser.ts", + "LLM/FuzzyModelMatcher.ts", "LLM/OpenAIProvider.ts", "LLM/LiteLLMProvider.ts", "LLM/GroqProvider.ts", @@ -338,6 +368,9 @@ _ai_chat_sources = [ "tools/DeleteFileTool.ts", "tools/ReadFileTool.ts", "tools/ListFilesTool.ts", + "memory/SearchMemoryTool.ts", + "memory/UpdateMemoryTool.ts", + "memory/ListMemoryBlocksTool.ts", "tools/UpdateTodoTool.ts", "tools/ExecuteCodeTool.ts", "tools/SequentialThinkingTool.ts", @@ -497,6 +530,15 @@ ts_library("unittests") { "ui/message/__tests__/MessageCombiner.test.ts", "ui/message/__tests__/StructuredResponseController.test.ts", "LLM/__tests__/MessageSanitizer.test.ts", + "LLM/__tests__/LLMTestHelpers.ts", + "LLM/__tests__/LLMErrorHandler.test.ts", + "LLM/__tests__/LLMResponseParser.test.ts", + "LLM/__tests__/LLMClient.test.ts", + "LLM/__tests__/OpenAIProvider.test.ts", + "LLM/__tests__/AnthropicProvider.test.ts", + "LLM/__tests__/GoogleAIProvider.test.ts", + "LLM/__tests__/OpenAICompatibleProviders.test.ts", + "LLM/__tests__/LLMProviderRegistry.test.ts", "agent_framework/__tests__/AgentRunner.sanitizeToolResult.test.ts", "agent_framework/__tests__/AgentRunner.computeToolResultText.test.ts", "agent_framework/__tests__/AgentRunner.run.flows.test.ts", @@ -518,6 +560,12 @@ ts_library("unittests") { "tools/__tests__/ReadFileTool.test.ts", "tools/__tests__/ListFilesTool.test.ts", "tools/__tests__/FileStorageManager.test.ts", + "memory/__tests__/MemoryModule.test.ts", + "memory/__tests__/MemoryBlockManager.test.ts", + "memory/__tests__/SearchMemoryTool.test.ts", + "memory/__tests__/UpdateMemoryTool.test.ts", + "memory/__tests__/ListMemoryBlocksTool.test.ts", + "memory/__tests__/MemoryIntegration.test.ts", ] deps = [ diff --git a/front_end/panels/ai_chat/LLM/FuzzyModelMatcher.ts b/front_end/panels/ai_chat/LLM/FuzzyModelMatcher.ts new file mode 100644 index 0000000000..8640484b95 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/FuzzyModelMatcher.ts @@ -0,0 +1,185 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Fuzzy model name matcher for finding the closest available model + * when an exact match isn't found. + */ + +/** + * Calculate Levenshtein distance between two strings + */ +function levenshteinDistance(a: string, b: string): number { + const matrix: number[][] = []; + + // Initialize matrix + for (let i = 0; i <= b.length; i++) { + matrix[i] = [i]; + } + for (let j = 0; j <= a.length; j++) { + matrix[0][j] = j; + } + + // Fill matrix + for (let i = 1; i <= b.length; i++) { + for (let j = 1; j <= a.length; j++) { + if (b.charAt(i - 1) === a.charAt(j - 1)) { + matrix[i][j] = matrix[i - 1][j - 1]; + } else { + matrix[i][j] = Math.min( + matrix[i - 1][j - 1] + 1, // substitution + matrix[i][j - 1] + 1, // insertion + matrix[i - 1][j] + 1 // deletion + ); + } + } + } + + return matrix[b.length][a.length]; +} + +/** + * Calculate similarity score between two strings (0-1) + */ +function similarity(a: string, b: string): number { + const distance = levenshteinDistance(a, b); + const maxLen = Math.max(a.length, b.length); + return maxLen === 0 ? 1 : 1 - distance / maxLen; +} + +/** + * Normalize model name for comparison by removing dates, versions, and separators + */ +function normalizeModelName(name: string): string { + return name + .toLowerCase() + .replace(/[-_]/g, '') // Remove separators + .replace(/\d{4}-?\d{2}-?\d{2}$/g, '') // Remove date suffixes (2025-04-14 or 20250514) + .replace(/\d{8}$/g, '') // Remove date suffixes without dashes + .trim(); +} + +/** + * Check if target is a prefix of candidate (case-insensitive) + */ +function isPrefixMatch(target: string, candidate: string): boolean { + const normalizedTarget = target.toLowerCase().replace(/[._]/g, '-'); + const normalizedCandidate = candidate.toLowerCase().replace(/[._]/g, '-'); + return normalizedCandidate.startsWith(normalizedTarget); +} + +/** + * Find the closest matching model from available options + * + * Matching strategy (in priority order): + * 1. Exact match - return immediately + * 2. Prefix match - if target is prefix of an available model + * 3. Normalized match - strip dates/versions and compare base names + * 4. Levenshtein similarity - if similarity > threshold, return best match + * + * @param targetModel - The model name to find a match for + * @param availableModels - Array of available model names + * @param threshold - Minimum similarity score (0-1) for fuzzy matching (default: 0.5) + * @returns The closest matching model name, or null if no good match found + */ +export function findClosestModel( + targetModel: string, + availableModels: string[], + threshold: number = 0.5 +): string | null { + if (!targetModel || availableModels.length === 0) { + return null; + } + + // 1. Exact match + if (availableModels.includes(targetModel)) { + return targetModel; + } + + // 2. Prefix match - find models where target is a prefix + const prefixMatches = availableModels.filter(model => isPrefixMatch(targetModel, model)); + if (prefixMatches.length > 0) { + // Return the shortest prefix match (most specific) + return prefixMatches.sort((a, b) => a.length - b.length)[0]; + } + + // 3. Normalized match - compare base names without dates/versions + const normalizedTarget = normalizeModelName(targetModel); + for (const model of availableModels) { + if (normalizeModelName(model) === normalizedTarget) { + return model; + } + } + + // 4. Levenshtein similarity on normalized names + let bestMatch: string | null = null; + let bestScore = 0; + + for (const model of availableModels) { + const score = similarity(normalizedTarget, normalizeModelName(model)); + if (score > bestScore && score >= threshold) { + bestScore = score; + bestMatch = model; + } + } + + return bestMatch; +} + +/** + * Find closest model with detailed match info for logging + */ +export interface FuzzyMatchResult { + match: string | null; + matchType: 'exact' | 'prefix' | 'normalized' | 'similarity' | 'none'; + score: number; +} + +export function findClosestModelWithInfo( + targetModel: string, + availableModels: string[], + threshold: number = 0.5 +): FuzzyMatchResult { + if (!targetModel || availableModels.length === 0) { + return { match: null, matchType: 'none', score: 0 }; + } + + // 1. Exact match + if (availableModels.includes(targetModel)) { + return { match: targetModel, matchType: 'exact', score: 1 }; + } + + // 2. Prefix match + const prefixMatches = availableModels.filter(model => isPrefixMatch(targetModel, model)); + if (prefixMatches.length > 0) { + const match = prefixMatches.sort((a, b) => a.length - b.length)[0]; + return { match, matchType: 'prefix', score: targetModel.length / match.length }; + } + + // 3. Normalized match + const normalizedTarget = normalizeModelName(targetModel); + for (const model of availableModels) { + if (normalizeModelName(model) === normalizedTarget) { + return { match: model, matchType: 'normalized', score: 1 }; + } + } + + // 4. Levenshtein similarity + let bestMatch: string | null = null; + let bestScore = 0; + + for (const model of availableModels) { + const score = similarity(normalizedTarget, normalizeModelName(model)); + if (score > bestScore && score >= threshold) { + bestScore = score; + bestMatch = model; + } + } + + if (bestMatch) { + return { match: bestMatch, matchType: 'similarity', score: bestScore }; + } + + return { match: null, matchType: 'none', score: 0 }; +} diff --git a/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts b/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts index 0c187a8b4d..630e9caa4c 100644 --- a/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts +++ b/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts @@ -5,6 +5,7 @@ import { createLogger } from '../core/Logger.js'; import type { LLMProviderInterface } from './LLMProvider.js'; import type { LLMProvider, ModelInfo } from './LLMTypes.js'; +import { isCustomProvider } from './LLMTypes.js'; import { OpenAIProvider } from './OpenAIProvider.js'; import { LiteLLMProvider } from './LiteLLMProvider.js'; import { GroqProvider } from './GroqProvider.js'; @@ -13,6 +14,8 @@ import { BrowserOperatorProvider } from './BrowserOperatorProvider.js'; import { CerebrasProvider } from './CerebrasProvider.js'; import { AnthropicProvider } from './AnthropicProvider.js'; import { GoogleAIProvider } from './GoogleAIProvider.js'; +import { GenericOpenAIProvider } from './GenericOpenAIProvider.js'; +import { CustomProviderManager } from '../core/CustomProviderManager.js'; const logger = createLogger('LLMProviderRegistry'); @@ -116,25 +119,39 @@ export class LLMProviderRegistry { * Create a temporary provider instance for utility operations * Used when provider isn't registered yet (e.g., during setup/validation) */ - private static createTemporaryProvider(providerType: LLMProvider): LLMProviderInterface | null { + private static createTemporaryProvider( + providerType: LLMProvider, + apiKey: string = '', + endpoint?: string + ): LLMProviderInterface | null { try { + // Handle custom providers - create GenericOpenAIProvider with config from CustomProviderManager + if (isCustomProvider(providerType)) { + const config = CustomProviderManager.getProvider(providerType); + if (!config) { + logger.warn(`Custom provider ${providerType} not found in CustomProviderManager`); + return null; + } + return new GenericOpenAIProvider(config, apiKey || undefined); + } + switch (providerType) { case 'openai': - return new OpenAIProvider(''); + return new OpenAIProvider(apiKey); case 'litellm': - return new LiteLLMProvider('', ''); + return new LiteLLMProvider(apiKey, endpoint || ''); case 'groq': - return new GroqProvider(''); + return new GroqProvider(apiKey); case 'openrouter': - return new OpenRouterProvider(''); + return new OpenRouterProvider(apiKey); case 'browseroperator': - return new BrowserOperatorProvider(null, ''); + return new BrowserOperatorProvider(null, apiKey); case 'cerebras': - return new CerebrasProvider(''); + return new CerebrasProvider(apiKey); case 'anthropic': - return new AnthropicProvider(''); + return new AnthropicProvider(apiKey); case 'googleai': - return new GoogleAIProvider(''); + return new GoogleAIProvider(apiKey); default: logger.warn(`Unknown provider type: ${providerType}`); return null; @@ -148,16 +165,23 @@ export class LLMProviderRegistry { /** * Get or create a provider instance for utility operations * Prefers registered instance, falls back to temporary instance + * @param providerType The type of provider to get/create + * @param apiKey Optional API key for temporary provider creation + * @param endpoint Optional endpoint for temporary provider creation */ - private static getOrCreateProvider(providerType: LLMProvider): LLMProviderInterface | null { + private static getOrCreateProvider( + providerType: LLMProvider, + apiKey?: string, + endpoint?: string + ): LLMProviderInterface | null { // Try to get registered provider first const registered = this.getProvider(providerType); if (registered) { return registered; } - // Fall back to creating temporary instance - return this.createTemporaryProvider(providerType); + // Fall back to creating temporary instance with provided credentials + return this.createTemporaryProvider(providerType, apiKey || '', endpoint); } /** @@ -165,6 +189,13 @@ export class LLMProviderRegistry { * Returns the localStorage keys used by the provider for credentials */ static getProviderStorageKeys(providerType: LLMProvider): {apiKey?: string; endpoint?: string; [key: string]: string | undefined} { + // Handle custom providers - they use CustomProviderManager for storage + if (isCustomProvider(providerType)) { + return { + apiKey: CustomProviderManager.getApiKeyStorageKey(providerType), + }; + } + const provider = this.getOrCreateProvider(providerType); if (!provider) { logger.warn(`Provider ${providerType} not available`); @@ -177,6 +208,11 @@ export class LLMProviderRegistry { * Get API key from localStorage for a provider */ static getProviderApiKey(providerType: LLMProvider): string { + // Handle custom providers - they use CustomProviderManager for API key storage + if (isCustomProvider(providerType)) { + return CustomProviderManager.getApiKey(providerType) || ''; + } + const keys = this.getProviderStorageKeys(providerType); if (!keys.apiKey) { return ''; @@ -296,19 +332,65 @@ export class LLMProviderRegistry { apiKey: string, endpoint?: string ): Promise { - const provider = this.getOrCreateProvider(providerType); + // Handle custom providers - check if models were manually configured + if (isCustomProvider(providerType)) { + const config = CustomProviderManager.getProvider(providerType); + if (!config) { + logger.warn(`Custom provider ${providerType} not found`); + return []; + } + + // If models were manually added by user, return them as-is + if (config.modelsManuallyAdded && config.models.length > 0) { + logger.debug(`Returning ${config.models.length} manually configured models for ${providerType}`); + return config.models.map(modelId => ({ + id: modelId, + name: modelId, + provider: providerType, + })); + } + + // Otherwise, fetch from the custom provider's API (OpenAI-compatible) + logger.debug(`Fetching models from API for custom provider ${providerType}`); + const provider = new GenericOpenAIProvider(config, apiKey || undefined); + try { + if (typeof provider.fetchModels === 'function') { + const models = await provider.fetchModels(); + return models.map((m: any) => ({ + id: m.id || m.name, + name: m.name || m.id, + provider: providerType, + ...(m.capabilities ? { capabilities: m.capabilities } : {}), + })); + } + return await provider.getModels(); + } catch (error) { + logger.error(`Failed to fetch models for custom provider ${providerType}:`, error); + throw error; + } + } + + // Built-in providers: always create a fresh provider instance with the provided credentials for testing + // Don't use getOrCreateProvider() which returns the registered instance with old/no API key + const provider = this.createTemporaryProvider(providerType, apiKey, endpoint); if (!provider) { logger.warn(`Provider ${providerType} not available`); return []; } try { - // Use the provider's fetchModels method if available - if ('fetchModels' in provider && typeof provider.fetchModels === 'function') { - return await provider.fetchModels(apiKey, endpoint); + // Use fetchModels() if available - it throws on API errors (good for validation) + // Fall back to getModels() which may swallow errors and return defaults + if (typeof (provider as any).fetchModels === 'function') { + const models = await (provider as any).fetchModels(); + // Convert to ModelInfo format if needed + return models.map((m: any) => ({ + id: m.id || m.name, + name: m.name || m.id, + provider: providerType, + ...(m.capabilities ? { capabilities: m.capabilities } : {}), + })); } - - // Fallback to getModels return await provider.getModels(); } catch (error) { logger.error(`Failed to fetch models for ${providerType}:`, error); diff --git a/front_end/panels/ai_chat/LLM/OpenAIProvider.ts b/front_end/panels/ai_chat/LLM/OpenAIProvider.ts index dbfc03b3b3..0bb88ded0f 100644 --- a/front_end/panels/ai_chat/LLM/OpenAIProvider.ts +++ b/front_end/panels/ai_chat/LLM/OpenAIProvider.ts @@ -415,11 +415,89 @@ export class OpenAIProvider extends LLMBaseProvider { return this.callWithMessages(modelName, messages, options); } + /** + * Fetch available models from OpenAI API + * This method makes an actual API call and throws on error (good for validation) + */ + async fetchModels(): Promise> { + const response = await fetch('https://api.openai.com/v1/models', { + method: 'GET', + headers: { + 'Authorization': `Bearer ${this.apiKey}`, + }, + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({ error: { message: 'Unknown error' } })); + throw new Error(`OpenAI API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); + } + + const data = await response.json(); + + // Models to include (chat/reasoning models that work with Responses API) + const SUPPORTED_PREFIXES = ['gpt-4.1', 'gpt-4o', 'gpt-5', 'o1', 'o3', 'o4']; + + // Models to exclude (non-chat models like TTS, STT, embeddings, etc.) + const EXCLUDED_PATTERNS = ['transcribe', 'tts', 'audio', 'image', 'embedding', 'moderation', 'whisper', 'dall-e', 'realtime', 'codex', 'chat', 'search']; + + return data.data + .filter((model: any) => { + const id = model.id.toLowerCase(); + // Must start with a supported prefix + const hasPrefix = SUPPORTED_PREFIXES.some(prefix => id.startsWith(prefix)); + // Must not contain excluded patterns + const isExcluded = EXCLUDED_PATTERNS.some(pattern => id.includes(pattern)); + return hasPrefix && !isExcluded; + }) + .map((model: any) => ({ + id: model.id, + name: model.id + })); + } + /** * Get all OpenAI models supported by this provider */ async getModels(): Promise { - // Return hardcoded OpenAI models with their capabilities + try { + const models = await this.fetchModels(); + + return models.map(model => ({ + id: model.id, + name: model.name, + provider: 'openai' as LLMProvider, + capabilities: { + functionCalling: true, + reasoning: this.modelSupportsReasoning(model.id), + vision: this.modelSupportsVision(model.id), + structured: true + } + })); + } catch (error) { + logger.warn('Failed to fetch models from OpenAI API, using default list:', error); + return this.getDefaultModels(); + } + } + + /** + * Check if model supports reasoning (O-series and GPT-5) + */ + private modelSupportsReasoning(modelId: string): boolean { + return modelId.startsWith('o') || modelId.includes('gpt-5'); + } + + /** + * Check if model supports vision + */ + private modelSupportsVision(modelId: string): boolean { + // O3-mini doesn't support vision, most others do + return !modelId.includes('o3-mini'); + } + + /** + * Get default list of known OpenAI models (fallback) + */ + private getDefaultModels(): ModelInfo[] { return [ { id: 'gpt-4.1-2025-04-14', diff --git a/front_end/panels/ai_chat/LLM/__tests__/AnthropicProvider.test.ts b/front_end/panels/ai_chat/LLM/__tests__/AnthropicProvider.test.ts new file mode 100644 index 0000000000..8d9c7e224d --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/AnthropicProvider.test.ts @@ -0,0 +1,648 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { AnthropicProvider } from '../AnthropicProvider.js'; +import type { LLMMessage } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createMockAnthropicResponse, + createMock401Response, + createMock429Response, + createMock500Response, + createLocalStorageMock, + createFastRetryConfig, + createTestMessages, + createMockToolDefinition, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: AnthropicProvider', () => { + let provider: AnthropicProvider; + let fetchStub: sinon.SinonStub; + let localStorageMock: ReturnType; + + const TEST_API_KEY = 'sk-ant-test-api-key-12345'; + const MESSAGES_ENDPOINT = 'https://api.anthropic.com/v1/messages'; + const MODELS_ENDPOINT = 'https://api.anthropic.com/v1/models'; + + beforeEach(() => { + provider = new AnthropicProvider(TEST_API_KEY); + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.ANTHROPIC_API_KEY]: TEST_API_KEY, + }); + }); + + afterEach(() => { + if (fetchStub) { + fetchStub.restore(); + } + localStorageMock.restore(); + sinon.restore(); + }); + + // ============ Constructor Tests ============ + describe('constructor', () => { + it('should set provider name correctly', () => { + assert.strictEqual(provider.name, 'anthropic'); + }); + }); + + // ============ callWithMessages Tests ============ + describe('callWithMessages', () => { + it('should make POST request to correct endpoint', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + assert.isTrue(fetchStub.calledOnce); + const [url, options] = fetchStub.firstCall.args; + assert.strictEqual(url, MESSAGES_ENDPOINT); + assert.strictEqual(options.method, 'POST'); + }); + + it('should include x-api-key header (NOT Authorization)', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers['x-api-key'], TEST_API_KEY); + assert.isUndefined(options.headers.Authorization, 'Should use x-api-key, not Authorization'); + }); + + it('should include anthropic-version header', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers['anthropic-version'], '2023-06-01'); + }); + + it('should include anthropic-dangerous-direct-browser-access header', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers['anthropic-dangerous-direct-browser-access'], 'true'); + }); + + it('should use "messages" array in request body', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isDefined(body.messages, 'Should use "messages" array for Anthropic'); + assert.isArray(body.messages); + }); + + it('should include max_tokens (required parameter)', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.strictEqual(body.max_tokens, 4096, 'Should include max_tokens'); + }); + + it('should extract system prompt to separate parameter', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'system', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'Hello' }, + ]; + + await provider.callWithMessages('claude-sonnet-4-20250514', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + // System should be a separate parameter + assert.strictEqual(body.system, 'You are a helpful assistant.'); + + // Messages array should NOT contain system message + const systemInMessages = body.messages.find((m: any) => m.role === 'system'); + assert.isUndefined(systemInMessages, 'System message should not be in messages array'); + }); + + it('should handle text response correctly', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'The answer is 42.' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + assert.strictEqual(result.text, 'The answer is 42.'); + }); + + it('should handle tool_use response correctly', async () => { + const mockResponse = createMockAnthropicResponse({ + functionCall: { name: 'click_element', arguments: { selector: '#btn' } }, + }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages()); + + assert.isDefined(result.functionCall); + assert.strictEqual(result.functionCall!.name, 'click_element'); + assert.deepEqual(result.functionCall!.arguments, { selector: '#btn' }); + }); + + it('should convert tools to Anthropic format (input_schema)', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Done' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const tools = [createMockToolDefinition('click', 'Click an element')]; + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { tools }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + assert.isDefined(body.tools); + assert.isArray(body.tools); + // Anthropic uses input_schema instead of parameters + assert.isDefined(body.tools[0].input_schema); + assert.isDefined(body.tools[0].name); + assert.isDefined(body.tools[0].description); + }); + + it('should include temperature when provided', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { temperature: 0.7 }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.strictEqual(body.temperature, 0.7); + }); + + it('should add anthropic-beta header for reasoning', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('claude-sonnet-4.5-20250514', createTestMessages(), { + reasoningLevel: 'high', + }); + + const options = fetchStub.firstCall.args[1]; + assert.include(options.headers['anthropic-beta'], 'interleaved-thinking'); + }); + + it('should convert assistant tool_calls to tool_use blocks', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Done' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'user', content: 'Click the button' }, + { + role: 'assistant', + tool_calls: [ + { + id: 'toolu_123', + type: 'function', + function: { name: 'click', arguments: '{"selector":"#btn"}' }, + }, + ], + }, + { + role: 'tool', + tool_call_id: 'toolu_123', + content: 'Clicked successfully', + }, + ]; + + await provider.callWithMessages('claude-sonnet-4-20250514', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + // Assistant message should have tool_use block in content + const assistantMsg = body.messages.find((m: any) => m.role === 'assistant'); + assert.isDefined(assistantMsg); + assert.isArray(assistantMsg.content); + assert.strictEqual(assistantMsg.content[0].type, 'tool_use'); + assert.strictEqual(assistantMsg.content[0].id, 'toolu_123'); + assert.strictEqual(assistantMsg.content[0].name, 'click'); + }); + + it('should convert tool results to user message with tool_result type', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Done' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'user', content: 'Click the button' }, + { + role: 'assistant', + tool_calls: [ + { + id: 'toolu_123', + type: 'function', + function: { name: 'click', arguments: '{}' }, + }, + ], + }, + { + role: 'tool', + tool_call_id: 'toolu_123', + content: 'Clicked successfully', + }, + ]; + + await provider.callWithMessages('claude-sonnet-4-20250514', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + // Tool result should be a user message with tool_result content + const toolResultMsg = body.messages.find((m: any) => + m.role === 'user' && m.content?.[0]?.type === 'tool_result' + ); + assert.isDefined(toolResultMsg, 'Should have tool_result in user message'); + assert.strictEqual(toolResultMsg.content[0].tool_use_id, 'toolu_123'); + assert.strictEqual(toolResultMsg.content[0].content, 'Clicked successfully'); + }); + + it('should handle base64 image content', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'I see an image' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'What is this?' }, + { type: 'image_url', image_url: { url: 'data:image/png;base64,iVBORw0KGgo=' } }, + ], + }, + ]; + + await provider.callWithMessages('claude-sonnet-4-20250514', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + const userMsg = body.messages.find((m: any) => m.role === 'user'); + assert.isDefined(userMsg); + + // Should have image with source.type = 'base64' + const imageContent = userMsg.content.find((c: any) => c.type === 'image'); + assert.isDefined(imageContent); + assert.strictEqual(imageContent.source.type, 'base64'); + assert.strictEqual(imageContent.source.media_type, 'image/png'); + }); + + it('should handle URL image content', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'I see an image' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'What is this?' }, + { type: 'image_url', image_url: { url: 'https://example.com/image.png' } }, + ], + }, + ]; + + await provider.callWithMessages('claude-sonnet-4-20250514', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + const userMsg = body.messages.find((m: any) => m.role === 'user'); + const imageContent = userMsg.content.find((c: any) => c.type === 'image'); + assert.isDefined(imageContent); + assert.strictEqual(imageContent.source.type, 'url'); + assert.strictEqual(imageContent.source.url, 'https://example.com/image.png'); + }); + }); + + // ============ call Tests ============ + describe('call', () => { + it('should build messages from prompt and systemPrompt', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.call('claude-sonnet-4-20250514', 'User prompt', 'System prompt'); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + assert.strictEqual(body.system, 'System prompt'); + const userMsg = body.messages.find((m: any) => m.role === 'user'); + assert.isDefined(userMsg); + }); + }); + + // ============ getModels / fetchModels Tests ============ + describe('getModels', () => { + it('should fetch models from API', async () => { + const modelsResponse = { + data: [ + { id: 'claude-sonnet-4-20250514', display_name: 'Claude Sonnet 4', type: 'model' }, + { id: 'claude-opus-4-20250514', display_name: 'Claude Opus 4', type: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.isArray(models); + assert.isTrue(models.length > 0); + }); + + it('should return ModelInfo array with correct structure', async () => { + const modelsResponse = { + data: [ + { id: 'claude-sonnet-4-20250514', display_name: 'Claude Sonnet 4', type: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + if (models.length > 0) { + const model = models[0]; + assert.isDefined(model.id); + assert.isDefined(model.name); + assert.strictEqual(model.provider, 'anthropic'); + assert.isDefined(model.capabilities); + } + }); + + it('should fallback to defaults on API error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + const models = await provider.getModels(); + + assert.isArray(models); + assert.isTrue(models.length > 0); + }); + }); + + describe('fetchModels', () => { + it('should make GET request to models endpoint', async () => { + const modelsResponse = { + data: [{ id: 'claude-sonnet-4', display_name: 'Claude Sonnet 4', type: 'model' }], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + await provider.fetchModels(); + + assert.isTrue(fetchStub.calledOnce); + const [url, options] = fetchStub.firstCall.args; + assert.strictEqual(url, MODELS_ENDPOINT); + assert.strictEqual(options.method, 'GET'); + }); + + it('should include proper headers in models request', async () => { + const modelsResponse = { + data: [{ id: 'claude-sonnet-4', display_name: 'Claude Sonnet 4', type: 'model' }], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + await provider.fetchModels(); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers['x-api-key'], TEST_API_KEY); + assert.strictEqual(options.headers['anthropic-version'], '2023-06-01'); + }); + }); + + // ============ validateCredentials Tests ============ + describe('validateCredentials', () => { + it('should return valid when API key present', () => { + const result = provider.validateCredentials(); + assert.isTrue(result.isValid); + }); + + it('should return invalid with missingItems when no API key', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const newProvider = new AnthropicProvider(''); + const result = newProvider.validateCredentials(); + + assert.isFalse(result.isValid); + assert.isDefined(result.missingItems); + assert.include(result.missingItems!, 'API Key'); + }); + }); + + // ============ getCredentialStorageKeys Tests ============ + describe('getCredentialStorageKeys', () => { + it('should return correct storage keys', () => { + const keys = provider.getCredentialStorageKeys(); + assert.strictEqual(keys.apiKey, 'ai_chat_anthropic_api_key'); + }); + }); + + // ============ testConnection Tests ============ + describe('testConnection', () => { + it('should return success on valid response', async () => { + const mockResponse = createMockAnthropicResponse({ text: 'Connection successful!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.testConnection('claude-sonnet-4-20250514'); + + assert.isTrue(result.success); + assert.include(result.message, 'Successfully connected'); + }); + + it('should return failure on error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Connection failed')); + + const result = await provider.testConnection('claude-sonnet-4-20250514'); + + assert.isFalse(result.success); + assert.include(result.message, 'Connection failed'); + }); + }); + + // ============ Error Scenarios ============ + describe('error scenarios', () => { + it('should handle 401 Unauthorized', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock401Response('anthropic') + ); + + try { + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'Anthropic API error'); + } + }); + + it('should handle 429 Rate Limit', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock429Response('anthropic') + ); + + try { + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'rate limit'); + } + }); + + it('should handle 500 Internal Server Error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock500Response('anthropic') + ); + + try { + await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'server'); + } + }); + + it('should retry on transient errors', async () => { + let callCount = 0; + fetchStub = sinon.stub(globalThis, 'fetch').callsFake(async () => { + callCount++; + if (callCount < 3) { + return createMock500Response('anthropic'); + } + return new Response(JSON.stringify(createMockAnthropicResponse({ text: 'Success' })), { status: 200 }); + }); + + const result = await provider.callWithMessages('claude-sonnet-4-20250514', createTestMessages(), { + retryConfig: createFastRetryConfig(3), + }); + + assert.strictEqual(result.text, 'Success'); + assert.strictEqual(callCount, 3); + }); + }); + + // ============ Model Capability Detection ============ + describe('Model Capability Detection', () => { + it('should detect function calling support for Claude 3+ models', async () => { + const modelsResponse = { + data: [ + { id: 'claude-3-sonnet-20240229', display_name: 'Claude 3 Sonnet', type: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + const model = models.find(m => m.id.includes('claude-3')); + + if (model) { + assert.isTrue(model.capabilities?.functionCalling); + } + }); + + it('should detect reasoning support for Claude Sonnet 4.5', async () => { + const modelsResponse = { + data: [ + { id: 'claude-sonnet-4.5-20250514', display_name: 'Claude Sonnet 4.5', type: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + const model = models.find(m => m.id.includes('4.5')); + + if (model) { + assert.isTrue(model.capabilities?.reasoning); + } + }); + + it('should detect vision support for Claude 3+ models (except Haiku)', async () => { + const modelsResponse = { + data: [ + { id: 'claude-3-opus-20240229', display_name: 'Claude 3 Opus', type: 'model' }, + { id: 'claude-3-5-haiku-20241022', display_name: 'Claude 3.5 Haiku', type: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + const opus = models.find(m => m.id.includes('opus')); + const haiku = models.find(m => m.id.includes('haiku')); + + if (opus) { + assert.isTrue(opus.capabilities?.vision); + } + if (haiku) { + assert.isFalse(haiku.capabilities?.vision); + } + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/FuzzyModelMatcher.test.ts b/front_end/panels/ai_chat/LLM/__tests__/FuzzyModelMatcher.test.ts new file mode 100644 index 0000000000..19c0ea36f2 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/FuzzyModelMatcher.test.ts @@ -0,0 +1,162 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { findClosestModel, findClosestModelWithInfo } from '../FuzzyModelMatcher.js'; + +describe('ai_chat: FuzzyModelMatcher', () => { + describe('findClosestModel', () => { + describe('exact match', () => { + it('returns exact match when available', () => { + const available = ['gpt-4.1-2025-04-14', 'gpt-4.1-mini-2025-04-14']; + assert.strictEqual(findClosestModel('gpt-4.1-2025-04-14', available), 'gpt-4.1-2025-04-14'); + }); + + it('returns null for empty target', () => { + const available = ['gpt-4.1-2025-04-14']; + assert.isNull(findClosestModel('', available)); + }); + + it('returns null for empty available list', () => { + assert.isNull(findClosestModel('gpt-4.1', [])); + }); + }); + + describe('prefix match', () => { + it('matches when target is prefix of available model', () => { + const available = ['claude-sonnet-4-5-20250514', 'claude-haiku-4-5-20250514']; + assert.strictEqual(findClosestModel('claude-sonnet-4-5', available), 'claude-sonnet-4-5-20250514'); + }); + + it('matches gpt model prefix with date suffix', () => { + const available = ['gpt-4.1-2025-04-14', 'gpt-4.1-mini-2025-04-14']; + assert.strictEqual(findClosestModel('gpt-4.1', available), 'gpt-4.1-2025-04-14'); + }); + + it('returns shortest prefix match when multiple matches exist', () => { + const available = ['claude-sonnet-4', 'claude-sonnet-4-5', 'claude-sonnet-4-5-20250514']; + assert.strictEqual(findClosestModel('claude-sonnet', available), 'claude-sonnet-4'); + }); + + it('handles dot vs dash variations in prefix', () => { + const available = ['claude-sonnet-4-5-20250514']; + assert.strictEqual(findClosestModel('claude-sonnet-4.5', available), 'claude-sonnet-4-5-20250514'); + }); + }); + + describe('normalized match', () => { + it('matches models ignoring date suffix', () => { + const available = ['gemini-2.5-pro-20250514']; + // After normalization, 'gemini25pro' should match + assert.strictEqual(findClosestModel('gemini-2.5-pro', available), 'gemini-2.5-pro-20250514'); + }); + + it('matches models ignoring separators', () => { + const available = ['claude-sonnet-4-5-20250514']; + // 'claude_sonnet_4_5' normalized becomes 'claudesonnet45' + assert.strictEqual(findClosestModel('claude_sonnet_4_5', available), 'claude-sonnet-4-5-20250514'); + }); + }); + + describe('similarity match', () => { + it('matches similar model names above threshold', () => { + const available = ['gemini-2.5-pro', 'gemini-2.5-flash', 'gpt-4.1']; + // 'gemini-pro' should fuzzy match to 'gemini-2.5-pro' + const result = findClosestModel('gemini-pro', available); + assert.strictEqual(result, 'gemini-2.5-pro'); + }); + + it('returns null for dissimilar models below threshold', () => { + const available = ['gpt-4.1-2025-04-14', 'claude-sonnet-4-5-20250514']; + assert.isNull(findClosestModel('completely-different-model', available)); + }); + + it('respects custom threshold', () => { + const available = ['gpt-4.1-2025-04-14']; + // With very high threshold, even similar names won't match + assert.isNull(findClosestModel('gpt-4', available, 0.99)); + }); + }); + + describe('real-world model names', () => { + const anthropicModels = [ + 'claude-sonnet-4-5-20250514', + 'claude-sonnet-4-20250514', + 'claude-opus-4-20250514', + 'claude-haiku-4-20250514', + 'claude-3-5-sonnet-20241022', + ]; + + const googleModels = [ + 'gemini-2.5-pro', + 'gemini-2.5-flash', + 'gemini-2.0-flash', + 'gemini-1.5-pro', + ]; + + const openaiModels = [ + 'gpt-4.1-2025-04-14', + 'gpt-4.1-mini-2025-04-14', + 'gpt-4.1-nano-2025-04-14', + 'o4-mini-2025-04-16', + ]; + + it('matches Anthropic short names to full names', () => { + assert.strictEqual(findClosestModel('claude-sonnet-4-5', anthropicModels), 'claude-sonnet-4-5-20250514'); + assert.strictEqual(findClosestModel('claude-haiku-4', anthropicModels), 'claude-haiku-4-20250514'); + assert.strictEqual(findClosestModel('claude-opus-4', anthropicModels), 'claude-opus-4-20250514'); + }); + + it('matches Google AI model variations', () => { + assert.strictEqual(findClosestModel('gemini-2.5-pro', googleModels), 'gemini-2.5-pro'); + assert.strictEqual(findClosestModel('gemini-flash', googleModels), 'gemini-2.5-flash'); + }); + + it('matches OpenAI model variations', () => { + assert.strictEqual(findClosestModel('gpt-4.1', openaiModels), 'gpt-4.1-2025-04-14'); + assert.strictEqual(findClosestModel('gpt-4.1-mini', openaiModels), 'gpt-4.1-mini-2025-04-14'); + }); + }); + }); + + describe('findClosestModelWithInfo', () => { + it('returns exact match type', () => { + const available = ['gpt-4.1-2025-04-14']; + const result = findClosestModelWithInfo('gpt-4.1-2025-04-14', available); + assert.strictEqual(result.match, 'gpt-4.1-2025-04-14'); + assert.strictEqual(result.matchType, 'exact'); + assert.strictEqual(result.score, 1); + }); + + it('returns prefix match type', () => { + const available = ['claude-sonnet-4-5-20250514']; + const result = findClosestModelWithInfo('claude-sonnet-4-5', available); + assert.strictEqual(result.match, 'claude-sonnet-4-5-20250514'); + assert.strictEqual(result.matchType, 'prefix'); + assert.isAbove(result.score, 0); + }); + + it('returns normalized match type', () => { + const available = ['claude-sonnet-4-5-20250514']; + const result = findClosestModelWithInfo('claude_sonnet_4_5', available); + assert.strictEqual(result.match, 'claude-sonnet-4-5-20250514'); + assert.strictEqual(result.matchType, 'normalized'); + }); + + it('returns similarity match type', () => { + const available = ['gemini-2.5-pro']; + const result = findClosestModelWithInfo('gemini-pro', available); + assert.strictEqual(result.match, 'gemini-2.5-pro'); + assert.strictEqual(result.matchType, 'similarity'); + assert.isAbove(result.score, 0.5); + }); + + it('returns none match type when no match found', () => { + const available = ['gpt-4.1-2025-04-14']; + const result = findClosestModelWithInfo('completely-unrelated', available); + assert.isNull(result.match); + assert.strictEqual(result.matchType, 'none'); + assert.strictEqual(result.score, 0); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/GoogleAIProvider.test.ts b/front_end/panels/ai_chat/LLM/__tests__/GoogleAIProvider.test.ts new file mode 100644 index 0000000000..d707a2e1ac --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/GoogleAIProvider.test.ts @@ -0,0 +1,534 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { GoogleAIProvider } from '../GoogleAIProvider.js'; +import type { LLMMessage } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createLocalStorageMock, + createMockGoogleAIResponse, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: GoogleAIProvider', () => { + let provider: GoogleAIProvider; + let fetchStub: sinon.SinonStub; + let localStorageMock: ReturnType; + + beforeEach(() => { + provider = new GoogleAIProvider('test-api-key'); + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.GOOGLEAI_API_KEY]: 'test-api-key', + }); + }); + + afterEach(() => { + if (fetchStub) { + fetchStub.restore(); + } + localStorageMock.restore(); + sinon.restore(); + }); + + // ============ API Format Tests ============ + describe('API format', () => { + it('should use correct endpoint format with API key as query param', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + const fetchCall = fetchStub.firstCall; + const url = fetchCall.args[0]; + assert.include(url, 'https://generativelanguage.googleapis.com/v1beta'); + assert.include(url, 'models/gemini-2.5-pro:generateContent'); + assert.include(url, 'key=test-api-key'); + }); + + it('should normalize model name with models/ prefix', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + // Test without prefix + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + let url = fetchStub.firstCall.args[0]; + assert.include(url, 'models/gemini-2.5-pro:generateContent'); + + fetchStub.resetHistory(); + + // Test with prefix already present + await provider.callWithMessages('models/gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + url = fetchStub.firstCall.args[0]; + assert.include(url, 'models/gemini-2.5-pro:generateContent'); + // Should not have double prefix + assert.notInclude(url, 'models/models/'); + }); + + it('should send POST request with JSON content type only', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + const fetchCall = fetchStub.firstCall; + const options = fetchCall.args[1]; + assert.strictEqual(options.method, 'POST'); + assert.strictEqual(options.headers['Content-Type'], 'application/json'); + // Google AI uses API key in URL, no Authorization header + assert.isUndefined(options.headers['Authorization']); + }); + }); + + // ============ Message Conversion Tests ============ + describe('message conversion', () => { + it('should convert user messages to contents array with parts', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello world' } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isArray(body.contents); + assert.strictEqual(body.contents[0].role, 'user'); + assert.isArray(body.contents[0].parts); + assert.deepEqual(body.contents[0].parts[0], { text: 'Hello world' }); + }); + + it('should convert assistant messages to model role', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hi' }, + { role: 'assistant', content: 'Hello!' }, + { role: 'user', content: 'How are you?' } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + // First should be user (after any system processing) + assert.strictEqual(body.contents[0].role, 'user'); + assert.strictEqual(body.contents[1].role, 'model'); // assistant -> model + assert.strictEqual(body.contents[2].role, 'user'); + }); + + it('should add system prompt as first user message', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'system', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'Hello' } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + // System prompt should be prepended as first user message + assert.strictEqual(body.contents[0].role, 'user'); + assert.strictEqual(body.contents[0].parts[0].text, 'You are a helpful assistant.'); + // Original user message follows + assert.strictEqual(body.contents[1].role, 'user'); + assert.strictEqual(body.contents[1].parts[0].text, 'Hello'); + }); + + it('should convert tool calls to functionCall format', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'user', content: 'Search for cats' }, + { + role: 'assistant', + content: '', + tool_calls: [{ + id: 'call_123', + type: 'function', + function: { + name: 'search', + arguments: JSON.stringify({ query: 'cats' }) + } + }] + }, + { + role: 'tool', + tool_call_id: 'call_123', + name: 'search', + content: JSON.stringify({ results: ['cat1', 'cat2'] }) + } + ]; + + await provider.callWithMessages('gemini-2.5-pro', messages); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + // Assistant tool call becomes model with functionCall parts + const modelMessage = body.contents[1]; + assert.strictEqual(modelMessage.role, 'model'); + assert.isDefined(modelMessage.parts[0].functionCall); + assert.strictEqual(modelMessage.parts[0].functionCall.name, 'search'); + + // Tool response becomes function role with functionResponse + const toolMessage = body.contents[2]; + assert.strictEqual(toolMessage.role, 'function'); + assert.isDefined(toolMessage.parts[0].functionResponse); + assert.strictEqual(toolMessage.parts[0].functionResponse.name, 'search'); + }); + }); + + // ============ Image Handling Tests ============ + describe('image handling', () => { + it('should convert base64 data URL to inline_data format', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'I see an image' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const imageContent = [ + { type: 'text' as const, text: 'What is this?' }, + { + type: 'image_url' as const, + image_url: { url: 'data:image/png;base64,iVBORw0KGgoAAAANS' } + } + ]; + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: imageContent } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + const parts = body.contents[0].parts; + assert.strictEqual(parts[0].text, 'What is this?'); + assert.isDefined(parts[1].inline_data); + assert.strictEqual(parts[1].inline_data.mime_type, 'image/png'); + assert.strictEqual(parts[1].inline_data.data, 'iVBORw0KGgoAAAANS'); + }); + + it('should handle image URL (not base64) with warning', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const imageContent = [ + { type: 'text' as const, text: 'What is this?' }, + { + type: 'image_url' as const, + image_url: { url: 'https://example.com/image.png' } + } + ]; + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: imageContent } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + const parts = body.contents[0].parts; + // URL images get converted to text notice + assert.include(parts[1].text, 'not supported'); + }); + }); + + // ============ Tool Conversion Tests ============ + describe('tool conversion', () => { + it('should convert OpenAI tool format to function_declarations', async () => { + const mockResponse = createMockGoogleAIResponse({ + functionCall: { name: 'search', arguments: { query: 'test' } } + }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const tools = [{ + type: 'function', + function: { + name: 'search', + description: 'Search for something', + parameters: { + type: 'object', + properties: { + query: { type: 'string' } + }, + required: ['query'] + } + } + }]; + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Search for cats' } + ], { tools }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isDefined(body.tools); + assert.isArray(body.tools); + assert.isDefined(body.tools[0].function_declarations); + const funcDecl = body.tools[0].function_declarations[0]; + assert.strictEqual(funcDecl.name, 'search'); + assert.strictEqual(funcDecl.description, 'Search for something'); + }); + }); + + // ============ Response Processing Tests ============ + describe('response processing', () => { + it('should extract text from response', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Hello there!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const response = await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + assert.strictEqual(response.text, 'Hello there!'); + }); + + it('should extract function call from response', async () => { + const mockResponse = createMockGoogleAIResponse({ + functionCall: { name: 'search', arguments: { query: 'cats' } } + }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const response = await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Search for cats' } + ], { + tools: [{ + type: 'function', + function: { name: 'search', description: 'Search', parameters: {} } + }] + }); + + assert.isDefined(response.functionCall); + assert.strictEqual(response.functionCall!.name, 'search'); + assert.deepEqual(response.functionCall!.arguments, { query: 'cats' }); + }); + + it('should throw on empty candidates', async () => { + const emptyResponse = { candidates: [] }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(emptyResponse), { status: 200 }) + ); + + try { + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'No candidates'); + } + }); + }); + + // ============ Generation Config Tests ============ + describe('generation config', () => { + it('should include temperature in generationConfig', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ], { temperature: 0.8 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isDefined(body.generationConfig); + assert.strictEqual(body.generationConfig.temperature, 0.8); + }); + + it('should not include generationConfig when no options provided', async () => { + const mockResponse = createMockGoogleAIResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isUndefined(body.generationConfig); + }); + }); + + // ============ Models Fetching Tests ============ + describe('getModels', () => { + it('should fetch and filter models that support generateContent', async () => { + const modelsResponse = { + models: [ + { + name: 'models/gemini-2.5-pro', + displayName: 'Gemini 2.5 Pro', + description: 'Latest model', + supportedGenerationMethods: ['generateContent', 'countTokens'] + }, + { + name: 'models/embedding-001', + displayName: 'Embedding Model', + description: 'For embeddings', + supportedGenerationMethods: ['embedContent'] + } + ] + }; + + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + // Should only return model that supports generateContent + assert.strictEqual(models.length, 1); + assert.strictEqual(models[0].id, 'gemini-2.5-pro'); // Without models/ prefix + assert.strictEqual(models[0].name, 'Gemini 2.5 Pro'); + assert.strictEqual(models[0].provider, 'googleai'); + }); + + it('should return default models on API error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + const models = await provider.getModels(); + + assert.isArray(models); + assert.isTrue(models.length > 0); + assert.include(models.map(m => m.id), 'gemini-2.5-pro'); + }); + }); + + // ============ Error Handling Tests ============ + describe('error handling', () => { + it('should throw on API error', async () => { + const errorResponse = { + error: { + message: 'Invalid API key', + code: 401 + } + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(errorResponse), { status: 401, statusText: 'Unauthorized' }) + ); + + try { + await provider.callWithMessages('gemini-2.5-pro', [ + { role: 'user', content: 'Hello' } + ]); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'Google AI API error'); + assert.include((error as Error).message, 'Invalid API key'); + } + }); + }); + + // ============ Credential Validation Tests ============ + describe('validateCredentials', () => { + it('should return valid when API key exists', () => { + const result = provider.validateCredentials(); + + assert.isTrue(result.isValid); + assert.include(result.message, 'configured correctly'); + }); + + it('should return invalid when API key missing', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = provider.validateCredentials(); + + assert.isFalse(result.isValid); + assert.include(result.missingItems!, 'API Key'); + }); + }); + + // ============ Capability Detection Tests ============ + describe('capability detection', () => { + it('should detect function calling support for Gemini models', async () => { + const modelsResponse = { + models: [{ + name: 'models/gemini-2.5-flash', + displayName: 'Gemini 2.5 Flash', + supportedGenerationMethods: ['generateContent'] + }] + }; + + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + assert.isTrue(models[0].capabilities!.functionCalling); + }); + + it('should detect reasoning support for Gemini 2.x models', async () => { + const modelsResponse = { + models: [{ + name: 'models/gemini-2.5-pro', + displayName: 'Gemini 2.5 Pro', + supportedGenerationMethods: ['generateContent'] + }] + }; + + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + assert.isTrue(models[0].capabilities!.reasoning); + }); + + it('should detect vision support for non-text Gemini models', async () => { + const modelsResponse = { + models: [ + { + name: 'models/gemini-2.5-pro', + displayName: 'Gemini 2.5 Pro', + supportedGenerationMethods: ['generateContent'] + }, + { + name: 'models/gemini-text-only', + displayName: 'Gemini Text', + supportedGenerationMethods: ['generateContent'] + } + ] + }; + + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + const proModel = models.find(m => m.id === 'gemini-2.5-pro'); + const textModel = models.find(m => m.id === 'gemini-text-only'); + + assert.isTrue(proModel!.capabilities!.vision); + assert.isFalse(textModel!.capabilities!.vision); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/LLMClient.test.ts b/front_end/panels/ai_chat/LLM/__tests__/LLMClient.test.ts new file mode 100644 index 0000000000..dc9956b5d4 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/LLMClient.test.ts @@ -0,0 +1,695 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { LLMClient } from '../LLMClient.js'; +import { LLMProviderRegistry } from '../LLMProviderRegistry.js'; +import type { LLMProvider } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createMockProvider, + createLocalStorageMock, + createMockOpenAIResponse, + createMockAnthropicResponse, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: LLMClient', () => { + let localStorageMock: ReturnType; + let fetchStub: sinon.SinonStub; + + // Helper to reset singleton between tests + function resetLLMClient(): void { + // Access private static instance to reset + (LLMClient as any).instance = null; + } + + beforeEach(() => { + resetLLMClient(); + LLMProviderRegistry.clear(); + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.OPENAI_API_KEY]: 'sk-test-key', + [STORAGE_KEYS.ANTHROPIC_API_KEY]: 'sk-ant-test-key', + [STORAGE_KEYS.LITELLM_ENDPOINT]: 'http://localhost:4000', + [STORAGE_KEYS.LITELLM_API_KEY]: 'test-litellm-key', + }); + }); + + afterEach(() => { + resetLLMClient(); + LLMProviderRegistry.clear(); + localStorageMock.restore(); + if (fetchStub) { + fetchStub.restore(); + } + sinon.restore(); + }); + + // ============ Singleton Tests ============ + describe('getInstance', () => { + it('should return singleton instance', () => { + const instance1 = LLMClient.getInstance(); + const instance2 = LLMClient.getInstance(); + + assert.strictEqual(instance1, instance2); + }); + + it('should create new instance after reset', () => { + const instance1 = LLMClient.getInstance(); + resetLLMClient(); + const instance2 = LLMClient.getInstance(); + + assert.notStrictEqual(instance1, instance2); + }); + }); + + // ============ Initialization Tests ============ + describe('initialize', () => { + it('should initialize with OpenAI provider', async () => { + // Mock fetch for provider initialization + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + await client.initialize({ + providers: [ + { provider: 'openai', apiKey: 'sk-test-key' } + ] + }); + + assert.isTrue(LLMProviderRegistry.hasProvider('openai')); + }); + + it('should initialize with multiple providers', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + await client.initialize({ + providers: [ + { provider: 'openai', apiKey: 'sk-test-key' }, + { provider: 'anthropic', apiKey: 'sk-ant-test-key' }, + { provider: 'groq', apiKey: 'gsk-test-key' }, + ] + }); + + assert.isTrue(LLMProviderRegistry.hasProvider('openai')); + assert.isTrue(LLMProviderRegistry.hasProvider('anthropic')); + assert.isTrue(LLMProviderRegistry.hasProvider('groq')); + }); + + it('should initialize LiteLLM provider with endpoint', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + await client.initialize({ + providers: [ + { provider: 'litellm', apiKey: 'test-key', providerURL: 'http://localhost:4000' } + ] + }); + + assert.isTrue(LLMProviderRegistry.hasProvider('litellm')); + }); + + it('should clear existing providers on re-initialization', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + + // First initialization + await client.initialize({ + providers: [ + { provider: 'openai', apiKey: 'sk-test-key' }, + { provider: 'anthropic', apiKey: 'sk-ant-test-key' } + ] + }); + + assert.isTrue(LLMProviderRegistry.hasProvider('anthropic')); + + // Second initialization without Anthropic + await client.initialize({ + providers: [ + { provider: 'openai', apiKey: 'sk-test-key' } + ] + }); + + assert.isTrue(LLMProviderRegistry.hasProvider('openai')); + assert.isFalse(LLMProviderRegistry.hasProvider('anthropic')); + }); + + it('should skip unknown provider types', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + await client.initialize({ + providers: [ + { provider: 'unknown_provider' as any, apiKey: 'test-key' } + ] + }); + + // Should not throw and should have no providers + const stats = client.getStats(); + assert.strictEqual(stats.providersCount, 0); + }); + }); + + // ============ Call Method Tests ============ + describe('call', () => { + it('should throw error when not initialized', async () => { + const client = LLMClient.getInstance(); + + try { + await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [{ role: 'user', content: 'Hello' }], + systemPrompt: 'You are a helpful assistant.' + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'must be initialized'); + } + }); + + it('should throw error for unavailable provider', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ data: [] }), { status: 200 }) + ); + + const client = LLMClient.getInstance(); + await client.initialize({ + providers: [ + { provider: 'openai', apiKey: 'sk-test-key' } + ] + }); + + try { + await client.call({ + provider: 'anthropic' as LLMProvider, + model: 'claude-sonnet-4-20250514', + messages: [{ role: 'user', content: 'Hello' }], + systemPrompt: 'You are a helpful assistant.' + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'not available'); + } + }); + + it('should make call with registered provider', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: 'Hello! How can I help you?', + rawResponse: {} + } + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + // Mark as initialized + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const response = await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [{ role: 'user', content: 'Hello' }], + systemPrompt: 'You are a helpful assistant.' + }); + + assert.strictEqual(response.text, 'Hello! How can I help you?'); + assert.isTrue(mockProvider.callWithMessages.calledOnce); + }); + + it('should prepend system prompt when not present', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: 'Response', + rawResponse: {} + } + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [{ role: 'user', content: 'Hello' }], + systemPrompt: 'You are a helpful assistant.' + }); + + // Check that system prompt was added + const callArgs = mockProvider.callWithMessages.firstCall.args; + const messages = callArgs[1]; + assert.strictEqual(messages[0].role, 'system'); + assert.strictEqual(messages[0].content, 'You are a helpful assistant.'); + }); + + it('should not duplicate system prompt when already present', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: 'Response', + rawResponse: {} + } + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [ + { role: 'system', content: 'Existing system prompt' }, + { role: 'user', content: 'Hello' } + ], + systemPrompt: 'New system prompt' + }); + + // Check that existing system prompt is preserved + const callArgs = mockProvider.callWithMessages.firstCall.args; + const messages = callArgs[1]; + assert.strictEqual(messages[0].content, 'Existing system prompt'); + // Should have only 2 messages, not 3 + assert.strictEqual(messages.length, 2); + }); + + it('should pass tools in options', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: '', + functionCall: { name: 'test_tool', arguments: {} }, + rawResponse: {} + } + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const tools = [{ + name: 'test_tool', + description: 'A test tool', + parameters: { type: 'object', properties: {} } + }]; + + await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [{ role: 'user', content: 'Use the test tool' }], + systemPrompt: 'You are a helpful assistant.', + tools + }); + + const callArgs = mockProvider.callWithMessages.firstCall.args; + const options = callArgs[2]; + assert.deepEqual(options.tools, tools); + }); + + it('should pass temperature in options', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: 'Response', + rawResponse: {} + } + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + await client.call({ + provider: 'openai', + model: 'gpt-4.1', + messages: [{ role: 'user', content: 'Hello' }], + systemPrompt: 'You are a helpful assistant.', + temperature: 0.7 + }); + + const callArgs = mockProvider.callWithMessages.firstCall.args; + const options = callArgs[2]; + assert.strictEqual(options.temperature, 0.7); + }); + }); + + // ============ getAvailableModels Tests ============ + describe('getAvailableModels', () => { + it('should throw when not initialized', async () => { + const client = LLMClient.getInstance(); + + try { + await client.getAvailableModels(); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'must be initialized'); + } + }); + + it('should return models from all providers', async () => { + const openaiProvider = createMockProvider({ + name: 'openai', + models: [{ id: 'gpt-4.1', name: 'GPT-4.1' }] + }); + const anthropicProvider = createMockProvider({ + name: 'anthropic', + models: [{ id: 'claude-sonnet-4', name: 'Claude Sonnet 4' }] + }); + + LLMProviderRegistry.registerProvider('openai', openaiProvider); + LLMProviderRegistry.registerProvider('anthropic', anthropicProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const models = await client.getAvailableModels(); + + assert.strictEqual(models.length, 2); + const modelIds = models.map(m => m.id); + assert.include(modelIds, 'gpt-4.1'); + assert.include(modelIds, 'claude-sonnet-4'); + }); + }); + + // ============ getModelsByProvider Tests ============ + describe('getModelsByProvider', () => { + it('should return models for specific provider', async () => { + const openaiProvider = createMockProvider({ + name: 'openai', + models: [ + { id: 'gpt-4.1', name: 'GPT-4.1' }, + { id: 'gpt-4.1-mini', name: 'GPT-4.1 Mini' } + ] + }); + + LLMProviderRegistry.registerProvider('openai', openaiProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const models = await client.getModelsByProvider('openai'); + + assert.strictEqual(models.length, 2); + }); + }); + + // ============ testConnection Tests ============ + describe('testConnection', () => { + it('should return failure for unavailable provider', async () => { + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const result = await client.testConnection('openai', 'gpt-4.1'); + + assert.isFalse(result.success); + assert.include(result.message, 'not available'); + }); + + it('should use provider testConnection if available', async () => { + const mockProvider = createMockProvider({ name: 'openai' }); + mockProvider.testConnection = sinon.stub().resolves({ + success: true, + message: 'Connection successful' + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const result = await client.testConnection('openai', 'gpt-4.1'); + + assert.isTrue(result.success); + assert.isTrue(mockProvider.testConnection.calledWith('gpt-4.1')); + }); + + it('should fallback to test call when provider has no testConnection', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + response: { + text: 'OK', + rawResponse: {} + } + }); + // Remove testConnection + delete mockProvider.testConnection; + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const result = await client.testConnection('openai', 'gpt-4.1'); + + assert.isTrue(result.success); + assert.include(result.message, 'Connected successfully'); + }); + + it('should return failure on test call error', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + callError: new Error('Connection failed') + }); + delete mockProvider.testConnection; + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + const result = await client.testConnection('openai', 'gpt-4.1'); + + assert.isFalse(result.success); + assert.include(result.message, 'Connection failed'); + }); + }); + + // ============ refreshProviderModels Tests ============ + describe('refreshProviderModels', () => { + it('should refresh models for specific provider', async () => { + const mockProvider = createMockProvider({ + name: 'openai', + models: [{ id: 'gpt-4.1', name: 'GPT-4.1' }] + }); + + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + await client.refreshProviderModels('openai'); + + assert.isTrue(mockProvider.getModels.called); + }); + + it('should refresh all providers when no provider specified', async () => { + const openaiProvider = createMockProvider({ name: 'openai', models: [] }); + const anthropicProvider = createMockProvider({ name: 'anthropic', models: [] }); + + LLMProviderRegistry.registerProvider('openai', openaiProvider); + LLMProviderRegistry.registerProvider('anthropic', anthropicProvider); + + const client = LLMClient.getInstance(); + (client as any).initialized = true; + + await client.refreshProviderModels(); + + assert.isTrue(openaiProvider.getModels.called); + assert.isTrue(anthropicProvider.getModels.called); + }); + }); + + // ============ registerCustomModel Tests ============ + describe('registerCustomModel', () => { + it('should save custom model to localStorage', () => { + const client = LLMClient.getInstance(); + + const modelInfo = client.registerCustomModel('my-custom-model', 'My Custom Model'); + + assert.strictEqual(modelInfo.id, 'my-custom-model'); + assert.strictEqual(modelInfo.name, 'My Custom Model'); + assert.strictEqual(modelInfo.provider, 'litellm'); + + const stored = JSON.parse(localStorageMock.store.get('ai_chat_custom_models') || '[]'); + assert.strictEqual(stored.length, 1); + assert.strictEqual(stored[0].id, 'my-custom-model'); + }); + + it('should use model ID as name when name not provided', () => { + const client = LLMClient.getInstance(); + + const modelInfo = client.registerCustomModel('my-custom-model'); + + assert.strictEqual(modelInfo.name, 'my-custom-model'); + }); + + it('should append to existing custom models', () => { + localStorageMock.store.set('ai_chat_custom_models', JSON.stringify([ + { id: 'existing-model', name: 'Existing', provider: 'litellm' } + ])); + + const client = LLMClient.getInstance(); + client.registerCustomModel('new-model', 'New Model'); + + const stored = JSON.parse(localStorageMock.store.get('ai_chat_custom_models') || '[]'); + assert.strictEqual(stored.length, 2); + }); + }); + + // ============ getStats Tests ============ + describe('getStats', () => { + it('should return initialized status and provider count', async () => { + const client = LLMClient.getInstance(); + + // Before initialization + let stats = client.getStats(); + assert.isFalse(stats.initialized); + assert.strictEqual(stats.providersCount, 0); + + // After initialization + LLMProviderRegistry.registerProvider('openai', createMockProvider({ name: 'openai' })); + (client as any).initialized = true; + + stats = client.getStats(); + assert.isTrue(stats.initialized); + assert.strictEqual(stats.providersCount, 1); + }); + }); + + // ============ Static Method Tests ============ + describe('static methods', () => { + describe('fetchLiteLLMModels', () => { + it('should delegate to LLMProviderRegistry', async () => { + const modelsResponse = { data: [{ id: 'model-1' }] }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await LLMClient.fetchLiteLLMModels('test-key', 'http://localhost:4000'); + + assert.isArray(models); + }); + }); + + describe('testLiteLLMConnection', () => { + it('should delegate to LLMProviderRegistry', async () => { + const modelsResponse = { data: [{ id: 'model-1' }] }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const result = await LLMClient.testLiteLLMConnection('test-key', 'model-1', 'http://localhost:4000'); + + assert.isDefined(result.success); + assert.isDefined(result.message); + }); + }); + + describe('testBrowserOperatorConnection', () => { + it('should check health endpoint', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify({ status: 'ok' }), { status: 200 }) + ); + + const result = await LLMClient.testBrowserOperatorConnection('http://localhost:3000/v1'); + + assert.isTrue(result.success); + assert.include(result.message, 'Connected to BrowserOperator'); + assert.isTrue(fetchStub.calledWith('http://localhost:3000/health')); + }); + + it('should handle health check failure', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response('', { status: 500, statusText: 'Internal Server Error' }) + ); + + const result = await LLMClient.testBrowserOperatorConnection('http://localhost:3000/v1'); + + assert.isFalse(result.success); + assert.include(result.message, 'Health check failed'); + }); + + it('should handle network error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Connection refused')); + + const result = await LLMClient.testBrowserOperatorConnection('http://localhost:3000/v1'); + + assert.isFalse(result.success); + assert.include(result.message, 'Connection refused'); + }); + }); + + describe('validateProviderCredentials', () => { + it('should validate standard provider credentials', () => { + const result = LLMClient.validateProviderCredentials('openai'); + + assert.isTrue(result.isValid); + }); + + it('should fail validation for missing credentials', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = LLMClient.validateProviderCredentials('openai'); + + assert.isFalse(result.isValid); + assert.isDefined(result.missingItems); + }); + }); + + describe('getProviderCredentials', () => { + it('should return credentials for standard provider', () => { + const result = LLMClient.getProviderCredentials('openai'); + + assert.isTrue(result.canProceed); + assert.strictEqual(result.apiKey, 'sk-test-key'); + }); + + it('should return canProceed false for missing credentials', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = LLMClient.getProviderCredentials('openai'); + + assert.isFalse(result.canProceed); + assert.isNull(result.apiKey); + }); + }); + }); + + // ============ parseResponse Tests ============ + describe('parseResponse', () => { + it('should delegate to LLMResponseParser', () => { + const client = LLMClient.getInstance(); + + const response = { + text: '{"type": "final_answer", "result": "Hello"}', + rawResponse: {} + }; + + const parsed = client.parseResponse(response); + + assert.isDefined(parsed); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/LLMErrorHandler.test.ts b/front_end/panels/ai_chat/LLM/__tests__/LLMErrorHandler.test.ts new file mode 100644 index 0000000000..5e5cd1a12e --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/LLMErrorHandler.test.ts @@ -0,0 +1,554 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { LLMErrorClassifier, LLMRetryManager, LLMErrorUtils } from '../LLMErrorHandler.js'; +import { LLMErrorType } from '../LLMTypes.js'; +import { + createFastRetryConfig, + DEFAULT_RETRY_CONFIG, + RATE_LIMIT_RETRY_CONFIG, + NETWORK_ERROR_RETRY_CONFIG, +} from './LLMTestHelpers.js'; + +describe('ai_chat: LLMErrorHandler', () => { + // ============ LLMErrorClassifier Tests ============ + describe('LLMErrorClassifier', () => { + describe('classifyError', () => { + // Rate limit errors + it('should classify "rate limit" errors as RATE_LIMIT_ERROR', () => { + const error = new Error('Rate limit exceeded'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.RATE_LIMIT_ERROR); + }); + + it('should classify "429" errors as RATE_LIMIT_ERROR', () => { + const error = new Error('API error: 429 Too Many Requests'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.RATE_LIMIT_ERROR); + }); + + it('should classify "too many requests" errors as RATE_LIMIT_ERROR', () => { + const error = new Error('Too many requests, please slow down'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.RATE_LIMIT_ERROR); + }); + + it('should classify "quota exceeded" errors as RATE_LIMIT_ERROR', () => { + const error = new Error('Quota exceeded for the day'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.RATE_LIMIT_ERROR); + }); + + it('should classify "rate_limit_exceeded" errors as RATE_LIMIT_ERROR', () => { + const error = new Error('Error code: rate_limit_exceeded'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.RATE_LIMIT_ERROR); + }); + + // Auth errors + it('should classify "unauthorized" errors as AUTH_ERROR', () => { + const error = new Error('Unauthorized access'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + it('should classify "401" errors as AUTH_ERROR', () => { + const error = new Error('HTTP 401: Authentication required'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + it('should classify "403" errors as AUTH_ERROR', () => { + const error = new Error('HTTP 403: Forbidden'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + it('should classify "invalid api key" errors as AUTH_ERROR', () => { + const error = new Error('Invalid API key provided'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + it('should classify "authentication" errors as AUTH_ERROR', () => { + const error = new Error('Authentication failed'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + it('should classify "forbidden" errors as AUTH_ERROR', () => { + const error = new Error('Access forbidden'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.AUTH_ERROR); + }); + + // Server errors + it('should classify "500" errors as SERVER_ERROR', () => { + const error = new Error('HTTP 500: Internal Server Error'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.SERVER_ERROR); + }); + + it('should classify "502" errors as SERVER_ERROR', () => { + const error = new Error('HTTP 502: Bad Gateway'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.SERVER_ERROR); + }); + + it('should classify "503" errors as SERVER_ERROR', () => { + const error = new Error('HTTP 503: Service Unavailable'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.SERVER_ERROR); + }); + + it('should classify "504" errors as NETWORK_ERROR', () => { + // 504 Gateway Timeout is classified as NETWORK_ERROR by the implementation + const error = new Error('HTTP 504: Gateway Timeout'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "internal server error" as SERVER_ERROR', () => { + const error = new Error('Internal server error occurred'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.SERVER_ERROR); + }); + + it('should classify "service unavailable" as SERVER_ERROR', () => { + const error = new Error('Service unavailable, try again later'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.SERVER_ERROR); + }); + + // Network errors + it('should classify "fetch" errors as NETWORK_ERROR', () => { + const error = new Error('fetch failed'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "network" errors as NETWORK_ERROR', () => { + const error = new Error('Network error occurred'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "timeout" errors as NETWORK_ERROR', () => { + const error = new Error('Request timeout'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "econnreset" errors as NETWORK_ERROR', () => { + const error = new Error('ECONNRESET: Connection reset'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "enotfound" errors as NETWORK_ERROR', () => { + const error = new Error('ENOTFOUND: DNS lookup failed'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "connection" errors as NETWORK_ERROR', () => { + const error = new Error('Connection refused'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "socket" errors as NETWORK_ERROR', () => { + const error = new Error('Socket hang up'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + it('should classify "aborted" errors as NETWORK_ERROR', () => { + const error = new Error('Request aborted'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.NETWORK_ERROR); + }); + + // JSON parse errors + it('should classify "json parsing failed" as JSON_PARSE_ERROR', () => { + const error = new Error('JSON parsing failed'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.JSON_PARSE_ERROR); + }); + + it('should classify "invalid json" as JSON_PARSE_ERROR', () => { + const error = new Error('Invalid JSON in response'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.JSON_PARSE_ERROR); + }); + + it('should classify "unexpected token" as JSON_PARSE_ERROR', () => { + const error = new Error('Unexpected token < in JSON at position 0'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.JSON_PARSE_ERROR); + }); + + it('should classify "syntaxerror" as JSON_PARSE_ERROR', () => { + const error = new Error('SyntaxError: Unexpected end of JSON input'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.JSON_PARSE_ERROR); + }); + + it('should classify "json parse" as JSON_PARSE_ERROR', () => { + const error = new Error('JSON parse error'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.JSON_PARSE_ERROR); + }); + + // Quota errors + it('should classify "insufficient quota" as QUOTA_ERROR', () => { + const error = new Error('Insufficient quota for this operation'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.QUOTA_ERROR); + }); + + it('should classify "billing" errors as QUOTA_ERROR', () => { + const error = new Error('Billing error: payment required'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.QUOTA_ERROR); + }); + + it('should classify "usage limit" as QUOTA_ERROR', () => { + const error = new Error('Usage limit reached'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.QUOTA_ERROR); + }); + + it('should classify "quota_exceeded" as QUOTA_ERROR', () => { + const error = new Error('Error: quota_exceeded'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.QUOTA_ERROR); + }); + + it('should classify "insufficient_quota" as QUOTA_ERROR', () => { + const error = new Error('insufficient_quota error'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.QUOTA_ERROR); + }); + + // Unknown errors + it('should classify unknown errors as UNKNOWN_ERROR', () => { + const error = new Error('Something went wrong'); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.UNKNOWN_ERROR); + }); + + it('should handle empty error messages', () => { + const error = new Error(''); + assert.strictEqual(LLMErrorClassifier.classifyError(error), LLMErrorType.UNKNOWN_ERROR); + }); + }); + + describe('shouldRetry', () => { + it('should return false for AUTH_ERROR', () => { + assert.isFalse(LLMErrorClassifier.shouldRetry(LLMErrorType.AUTH_ERROR)); + }); + + it('should return false for QUOTA_ERROR', () => { + assert.isFalse(LLMErrorClassifier.shouldRetry(LLMErrorType.QUOTA_ERROR)); + }); + + it('should return true for RATE_LIMIT_ERROR', () => { + assert.isTrue(LLMErrorClassifier.shouldRetry(LLMErrorType.RATE_LIMIT_ERROR)); + }); + + it('should return true for NETWORK_ERROR', () => { + assert.isTrue(LLMErrorClassifier.shouldRetry(LLMErrorType.NETWORK_ERROR)); + }); + + it('should return true for SERVER_ERROR', () => { + assert.isTrue(LLMErrorClassifier.shouldRetry(LLMErrorType.SERVER_ERROR)); + }); + + it('should return true for JSON_PARSE_ERROR', () => { + assert.isTrue(LLMErrorClassifier.shouldRetry(LLMErrorType.JSON_PARSE_ERROR)); + }); + + it('should return true for UNKNOWN_ERROR', () => { + assert.isTrue(LLMErrorClassifier.shouldRetry(LLMErrorType.UNKNOWN_ERROR)); + }); + }); + + describe('getRetryConfig', () => { + it('should return default config for unknown error types', () => { + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.UNKNOWN_ERROR); + assert.strictEqual(config.maxRetries, DEFAULT_RETRY_CONFIG.maxRetries); + assert.strictEqual(config.baseDelayMs, DEFAULT_RETRY_CONFIG.baseDelayMs); + }); + + it('should return rate limit specific config (60s delay)', () => { + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.RATE_LIMIT_ERROR); + assert.strictEqual(config.maxRetries, RATE_LIMIT_RETRY_CONFIG.maxRetries); + assert.strictEqual(config.baseDelayMs, RATE_LIMIT_RETRY_CONFIG.baseDelayMs); + assert.strictEqual(config.backoffMultiplier, 1, 'Rate limit should not use exponential backoff'); + }); + + it('should return network error specific config', () => { + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.NETWORK_ERROR); + assert.strictEqual(config.maxRetries, NETWORK_ERROR_RETRY_CONFIG.maxRetries); + assert.strictEqual(config.baseDelayMs, NETWORK_ERROR_RETRY_CONFIG.baseDelayMs); + }); + + it('should merge custom config overrides', () => { + const customConfig = { maxRetries: 5, baseDelayMs: 500 }; + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.SERVER_ERROR, customConfig); + assert.strictEqual(config.maxRetries, 5); + assert.strictEqual(config.baseDelayMs, 500); + }); + + it('should use default config for SERVER_ERROR', () => { + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.SERVER_ERROR); + assert.strictEqual(config.maxRetries, DEFAULT_RETRY_CONFIG.maxRetries); + }); + + it('should use default config for JSON_PARSE_ERROR', () => { + const config = LLMErrorClassifier.getRetryConfig(LLMErrorType.JSON_PARSE_ERROR); + assert.strictEqual(config.maxRetries, DEFAULT_RETRY_CONFIG.maxRetries); + }); + }); + }); + + // ============ LLMRetryManager Tests ============ + describe('LLMRetryManager', () => { + describe('executeWithRetry', () => { + it('should return immediately on success', async () => { + const manager = new LLMRetryManager({ enableLogging: false }); + let callCount = 0; + + const result = await manager.executeWithRetry(async () => { + callCount++; + return 'success'; + }); + + assert.strictEqual(result, 'success'); + assert.strictEqual(callCount, 1); + }); + + it('should retry on retryable errors', async function() { + this.timeout(15000); // Allow more time for retries + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(3), + }); + let callCount = 0; + + const result = await manager.executeWithRetry(async () => { + callCount++; + if (callCount < 3) { + throw new Error('Network error'); + } + return 'success'; + }, { customRetryConfig: createFastRetryConfig(3) }); + + assert.strictEqual(result, 'success'); + assert.strictEqual(callCount, 3); + }); + + it('should not retry on AUTH_ERROR', async () => { + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(3), + }); + let callCount = 0; + + try { + await manager.executeWithRetry(async () => { + callCount++; + throw new Error('401 Unauthorized'); + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.strictEqual(callCount, 1, 'Should not retry AUTH_ERROR'); + assert.include((error as Error).message, 'Unauthorized'); + } + }); + + it('should not retry on QUOTA_ERROR', async () => { + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(3), + }); + let callCount = 0; + + try { + await manager.executeWithRetry(async () => { + callCount++; + throw new Error('Insufficient quota'); + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.strictEqual(callCount, 1, 'Should not retry QUOTA_ERROR'); + assert.include((error as Error).message, 'quota'); + } + }); + + it('should respect maxRetries limit', async () => { + const maxRetries = 2; + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(maxRetries), + }); + let callCount = 0; + + try { + await manager.executeWithRetry(async () => { + callCount++; + throw new Error('Server error 500'); + }); + assert.fail('Should have thrown'); + } catch (error) { + // Should try once + maxRetries = 3 total calls + assert.strictEqual(callCount, maxRetries + 1); + } + }); + + it('should call onRetry callback on each retry', async function() { + this.timeout(15000); // Allow more time for retries + const retryAttempts: number[] = []; + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(3), + onRetry: (attempt) => { + retryAttempts.push(attempt); + }, + }); + let callCount = 0; + + try { + await manager.executeWithRetry(async () => { + callCount++; + throw new Error('Network error'); + }, { customRetryConfig: createFastRetryConfig(3) }); + } catch { + // Expected + } + + // Should have retry callbacks for attempts 1, 2, 3 + assert.deepEqual(retryAttempts, [1, 2, 3]); + }); + + it('should throw last error after max retries exceeded', async function() { + this.timeout(15000); // Allow more time for retries + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(2), + }); + + try { + await manager.executeWithRetry(async () => { + throw new Error('Persistent network error'); + }, { customRetryConfig: createFastRetryConfig(2) }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'Persistent network error'); + } + }); + + it('should respect maxTotalTimeMs limit', async () => { + const manager = new LLMRetryManager({ + enableLogging: false, + defaultConfig: createFastRetryConfig(10), + maxTotalTimeMs: 50, // Very short timeout + }); + let callCount = 0; + + try { + await manager.executeWithRetry(async () => { + callCount++; + // Add a small delay to ensure time passes + await new Promise(resolve => setTimeout(resolve, 20)); + throw new Error('Network error'); + }); + assert.fail('Should have thrown'); + } catch (error) { + // Should stop before maxRetries due to time limit + assert.isBelow(callCount, 10); + } + }); + }); + + describe('simpleRetry static method', () => { + it('should work as static convenience method', async () => { + let callCount = 0; + + const result = await LLMRetryManager.simpleRetry(async () => { + callCount++; + if (callCount < 2) { + throw new Error('Network error'); + } + return 'success'; + }, createFastRetryConfig(3)); + + assert.strictEqual(result, 'success'); + assert.strictEqual(callCount, 2); + }); + }); + }); + + // ============ LLMErrorUtils Tests ============ + describe('LLMErrorUtils', () => { + describe('isRetryable', () => { + it('should return true for retryable errors', () => { + assert.isTrue(LLMErrorUtils.isRetryable(new Error('Network error'))); + assert.isTrue(LLMErrorUtils.isRetryable(new Error('Server error 500'))); + assert.isTrue(LLMErrorUtils.isRetryable(new Error('Rate limit exceeded'))); + }); + + it('should return false for non-retryable errors', () => { + assert.isFalse(LLMErrorUtils.isRetryable(new Error('401 Unauthorized'))); + assert.isFalse(LLMErrorUtils.isRetryable(new Error('Insufficient quota'))); + }); + }); + + describe('getErrorMessage', () => { + it('should return user-friendly message for RATE_LIMIT_ERROR', () => { + const error = new Error('Rate limit exceeded'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'rate limit'); + assert.include(message.toLowerCase(), 'wait'); + }); + + it('should return user-friendly message for AUTH_ERROR', () => { + const error = new Error('401 Unauthorized'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'authentication'); + assert.include(message.toLowerCase(), 'api key'); + }); + + it('should return user-friendly message for NETWORK_ERROR', () => { + const error = new Error('Network error'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'network'); + assert.include(message.toLowerCase(), 'connection'); + }); + + it('should return user-friendly message for SERVER_ERROR', () => { + const error = new Error('Server error 500'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'server'); + assert.include(message.toLowerCase(), 'unavailable'); + }); + + it('should return user-friendly message for QUOTA_ERROR', () => { + const error = new Error('Insufficient quota'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'quota'); + }); + + it('should return user-friendly message for JSON_PARSE_ERROR', () => { + const error = new Error('JSON parsing failed'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message.toLowerCase(), 'parse'); + }); + + it('should return original message for UNKNOWN_ERROR', () => { + const error = new Error('Something specific went wrong'); + const message = LLMErrorUtils.getErrorMessage(error); + assert.include(message, 'Something specific went wrong'); + }); + }); + + describe('enhanceError', () => { + it('should add operation context to error', () => { + const error = new Error('Network error'); + const enhanced = LLMErrorUtils.enhanceError(error, { operation: 'OpenAI call' }); + assert.include(enhanced.message, 'OpenAI call'); + assert.include(enhanced.message, 'NETWORK_ERROR'); + }); + + it('should add attempt number to error', () => { + const error = new Error('Server error 500'); + const enhanced = LLMErrorUtils.enhanceError(error, { attempt: 3 }); + assert.include(enhanced.message, 'attempt 3'); + }); + + it('should preserve original error', () => { + const error = new Error('Original error'); + const enhanced = LLMErrorUtils.enhanceError(error, { operation: 'test' }); + assert.strictEqual((enhanced as any).originalError, error); + }); + + it('should add error type to enhanced error', () => { + const error = new Error('Rate limit exceeded'); + const enhanced = LLMErrorUtils.enhanceError(error, { operation: 'test' }); + assert.strictEqual((enhanced as any).errorType, LLMErrorType.RATE_LIMIT_ERROR); + }); + + it('should include operation and attempt context', () => { + const error = new Error('Network error'); + const enhanced = LLMErrorUtils.enhanceError(error, { operation: 'fetchModels', attempt: 2 }); + assert.deepEqual((enhanced as any).context, { operation: 'fetchModels', attempt: 2 }); + }); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/LLMProviderRegistry.test.ts b/front_end/panels/ai_chat/LLM/__tests__/LLMProviderRegistry.test.ts new file mode 100644 index 0000000000..d1c3f1080f --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/LLMProviderRegistry.test.ts @@ -0,0 +1,441 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { LLMProviderRegistry } from '../LLMProviderRegistry.js'; +import type { LLMProvider } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createMockProvider, + createLocalStorageMock, + createMockOpenAIResponse, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: LLMProviderRegistry', () => { + let localStorageMock: ReturnType; + let fetchStub: sinon.SinonStub; + + beforeEach(() => { + LLMProviderRegistry.clear(); + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.OPENAI_API_KEY]: 'sk-test-key', + [STORAGE_KEYS.ANTHROPIC_API_KEY]: 'sk-ant-test-key', + [STORAGE_KEYS.LITELLM_ENDPOINT]: 'http://localhost:4000', + [STORAGE_KEYS.LITELLM_API_KEY]: 'test-litellm-key', + }); + }); + + afterEach(() => { + LLMProviderRegistry.clear(); + localStorageMock.restore(); + if (fetchStub) { + fetchStub.restore(); + } + sinon.restore(); + }); + + // ============ registerProvider / getProvider Tests ============ + describe('registerProvider / getProvider', () => { + it('should register and retrieve provider', () => { + const mockProvider = createMockProvider({ name: 'openai' }); + LLMProviderRegistry.registerProvider('openai', mockProvider); + + const retrieved = LLMProviderRegistry.getProvider('openai'); + assert.strictEqual(retrieved, mockProvider); + }); + + it('should overwrite existing provider', () => { + const provider1 = createMockProvider({ name: 'openai' }); + const provider2 = createMockProvider({ name: 'openai' }); + + LLMProviderRegistry.registerProvider('openai', provider1); + LLMProviderRegistry.registerProvider('openai', provider2); + + const retrieved = LLMProviderRegistry.getProvider('openai'); + assert.strictEqual(retrieved, provider2); + }); + + it('should return undefined for unregistered provider', () => { + const retrieved = LLMProviderRegistry.getProvider('openai'); + assert.isUndefined(retrieved); + }); + }); + + // ============ hasProvider Tests ============ + describe('hasProvider', () => { + it('should return true for registered provider', () => { + const mockProvider = createMockProvider({ name: 'openai' }); + LLMProviderRegistry.registerProvider('openai', mockProvider); + + assert.isTrue(LLMProviderRegistry.hasProvider('openai')); + }); + + it('should return false for unregistered provider', () => { + assert.isFalse(LLMProviderRegistry.hasProvider('openai')); + }); + }); + + // ============ getAllModels Tests ============ + describe('getAllModels', () => { + it('should aggregate models from all providers', async () => { + const openaiProvider = createMockProvider({ + name: 'openai', + models: [ + { id: 'gpt-4.1', name: 'GPT-4.1' }, + { id: 'gpt-4.1-mini', name: 'GPT-4.1 Mini' }, + ], + }); + const anthropicProvider = createMockProvider({ + name: 'anthropic', + models: [ + { id: 'claude-sonnet-4', name: 'Claude Sonnet 4' }, + ], + }); + + LLMProviderRegistry.registerProvider('openai', openaiProvider); + LLMProviderRegistry.registerProvider('anthropic', anthropicProvider); + + const allModels = await LLMProviderRegistry.getAllModels(); + + assert.strictEqual(allModels.length, 3); + const modelIds = allModels.map(m => m.id); + assert.include(modelIds, 'gpt-4.1'); + assert.include(modelIds, 'gpt-4.1-mini'); + assert.include(modelIds, 'claude-sonnet-4'); + }); + + it('should handle provider errors gracefully', async () => { + const goodProvider = createMockProvider({ + name: 'openai', + models: [{ id: 'gpt-4.1', name: 'GPT-4.1' }], + }); + const badProvider = createMockProvider({ + name: 'anthropic', + callError: new Error('API error'), + }); + // Override getModels to throw + badProvider.getModels = async () => { throw new Error('API error'); }; + + LLMProviderRegistry.registerProvider('openai', goodProvider); + LLMProviderRegistry.registerProvider('anthropic', badProvider); + + const allModels = await LLMProviderRegistry.getAllModels(); + + // Should still return models from good provider + assert.strictEqual(allModels.length, 1); + assert.strictEqual(allModels[0].id, 'gpt-4.1'); + }); + + it('should return empty array when no providers registered', async () => { + const allModels = await LLMProviderRegistry.getAllModels(); + assert.deepEqual(allModels, []); + }); + }); + + // ============ getModelsByProvider Tests ============ + describe('getModelsByProvider', () => { + it('should return models for specific provider', async () => { + const openaiProvider = createMockProvider({ + name: 'openai', + models: [ + { id: 'gpt-4.1', name: 'GPT-4.1' }, + ], + }); + LLMProviderRegistry.registerProvider('openai', openaiProvider); + + const models = await LLMProviderRegistry.getModelsByProvider('openai'); + + assert.strictEqual(models.length, 1); + assert.strictEqual(models[0].id, 'gpt-4.1'); + }); + + it('should return empty array for unregistered provider', async () => { + const models = await LLMProviderRegistry.getModelsByProvider('openai'); + assert.deepEqual(models, []); + }); + + it('should handle provider getModels errors', async () => { + const badProvider = createMockProvider({ name: 'openai' }); + badProvider.getModels = async () => { throw new Error('API error'); }; + LLMProviderRegistry.registerProvider('openai', badProvider); + + const models = await LLMProviderRegistry.getModelsByProvider('openai'); + assert.deepEqual(models, []); + }); + }); + + // ============ getRegisteredProviders Tests ============ + describe('getRegisteredProviders', () => { + it('should return list of registered provider names', () => { + LLMProviderRegistry.registerProvider('openai', createMockProvider({ name: 'openai' })); + LLMProviderRegistry.registerProvider('anthropic', createMockProvider({ name: 'anthropic' })); + + const providers = LLMProviderRegistry.getRegisteredProviders(); + + assert.includeMembers(providers, ['openai', 'anthropic']); + }); + + it('should return empty array when no providers registered', () => { + const providers = LLMProviderRegistry.getRegisteredProviders(); + assert.deepEqual(providers, []); + }); + }); + + // ============ clear Tests ============ + describe('clear', () => { + it('should remove all providers', () => { + LLMProviderRegistry.registerProvider('openai', createMockProvider({ name: 'openai' })); + LLMProviderRegistry.registerProvider('anthropic', createMockProvider({ name: 'anthropic' })); + + LLMProviderRegistry.clear(); + + assert.isFalse(LLMProviderRegistry.hasProvider('openai')); + assert.isFalse(LLMProviderRegistry.hasProvider('anthropic')); + }); + }); + + // ============ getStats Tests ============ + describe('getStats', () => { + it('should return provider count and list', () => { + LLMProviderRegistry.registerProvider('openai', createMockProvider({ name: 'openai' })); + LLMProviderRegistry.registerProvider('anthropic', createMockProvider({ name: 'anthropic' })); + + const stats = LLMProviderRegistry.getStats(); + + assert.strictEqual(stats.providersCount, 2); + assert.includeMembers(stats.providers, ['openai', 'anthropic']); + }); + + it('should return zero count when empty', () => { + const stats = LLMProviderRegistry.getStats(); + + assert.strictEqual(stats.providersCount, 0); + assert.deepEqual(stats.providers, []); + }); + }); + + // ============ getProviderStorageKeys Tests ============ + describe('getProviderStorageKeys', () => { + it('should return correct keys for OpenAI', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('openai'); + assert.strictEqual(keys.apiKey, 'ai_chat_api_key'); + }); + + it('should return correct keys for Anthropic', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('anthropic'); + assert.strictEqual(keys.apiKey, 'ai_chat_anthropic_api_key'); + }); + + it('should return correct keys for LiteLLM (has endpoint)', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('litellm'); + assert.strictEqual(keys.apiKey, 'ai_chat_litellm_api_key'); + assert.strictEqual(keys.endpoint, 'ai_chat_litellm_endpoint'); + }); + + it('should return correct keys for Groq', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('groq'); + assert.strictEqual(keys.apiKey, 'ai_chat_groq_api_key'); + }); + }); + + // ============ getProviderApiKey / saveProviderApiKey Tests ============ + describe('getProviderApiKey / saveProviderApiKey', () => { + it('should read API key from localStorage', () => { + const apiKey = LLMProviderRegistry.getProviderApiKey('openai'); + assert.strictEqual(apiKey, 'sk-test-key'); + }); + + it('should return empty string when no API key', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const apiKey = LLMProviderRegistry.getProviderApiKey('openai'); + assert.strictEqual(apiKey, ''); + }); + + it('should write API key to localStorage', () => { + LLMProviderRegistry.saveProviderApiKey('openai', 'sk-new-key'); + + const stored = localStorageMock.store.get('ai_chat_api_key'); + assert.strictEqual(stored, 'sk-new-key'); + }); + + it('should remove API key when null passed', () => { + LLMProviderRegistry.saveProviderApiKey('openai', null); + + const stored = localStorageMock.store.get('ai_chat_api_key'); + assert.isUndefined(stored); + }); + }); + + // ============ getProviderEndpoint / saveProviderEndpoint Tests ============ + describe('getProviderEndpoint / saveProviderEndpoint', () => { + it('should read endpoint from localStorage', () => { + const endpoint = LLMProviderRegistry.getProviderEndpoint('litellm'); + assert.strictEqual(endpoint, 'http://localhost:4000'); + }); + + it('should return undefined if no endpoint key', () => { + const endpoint = LLMProviderRegistry.getProviderEndpoint('openai'); + assert.isUndefined(endpoint); + }); + + it('should save endpoint to localStorage', () => { + LLMProviderRegistry.saveProviderEndpoint('litellm', 'http://new-endpoint:5000'); + + const stored = localStorageMock.store.get('ai_chat_litellm_endpoint'); + assert.strictEqual(stored, 'http://new-endpoint:5000'); + }); + }); + + // ============ validateProviderCredentials Tests ============ + describe('validateProviderCredentials', () => { + it('should validate OpenAI credentials (requires API key)', () => { + const result = LLMProviderRegistry.validateProviderCredentials('openai'); + assert.isTrue(result.isValid); + }); + + it('should fail validation when API key missing', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = LLMProviderRegistry.validateProviderCredentials('openai'); + assert.isFalse(result.isValid); + assert.isDefined(result.missingItems); + }); + + it('should validate LiteLLM credentials (requires endpoint)', () => { + const result = LLMProviderRegistry.validateProviderCredentials('litellm'); + // LiteLLM requires endpoint + assert.isDefined(result.isValid); + }); + }); + + // ============ getProviderCredentials Tests ============ + describe('getProviderCredentials', () => { + it('should return canProceed true when valid', () => { + const result = LLMProviderRegistry.getProviderCredentials('openai'); + assert.isTrue(result.canProceed); + assert.strictEqual(result.apiKey, 'sk-test-key'); + }); + + it('should return canProceed false when invalid', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = LLMProviderRegistry.getProviderCredentials('openai'); + assert.isFalse(result.canProceed); + assert.isNull(result.apiKey); + }); + + it('should return apiKey and endpoint from storage', () => { + const result = LLMProviderRegistry.getProviderCredentials('litellm'); + + if (result.canProceed) { + assert.strictEqual(result.apiKey, 'test-litellm-key'); + assert.strictEqual(result.endpoint, 'http://localhost:4000'); + } + }); + }); + + // ============ fetchProviderModels Tests ============ + describe('fetchProviderModels', () => { + it('should fetch from provider API', async () => { + const modelsResponse = { + data: [ + { id: 'gpt-4.1', object: 'model' }, + { id: 'gpt-4.1-mini', object: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await LLMProviderRegistry.fetchProviderModels('openai', 'sk-test-key'); + + assert.isArray(models); + assert.isTrue(models.length > 0); + }); + + it('should throw on API errors', async () => { + const errorResponse = { error: { message: 'Unauthorized' } }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(errorResponse), { status: 401 }) + ); + + try { + await LLMProviderRegistry.fetchProviderModels('openai', 'invalid-key'); + assert.fail('Should have thrown'); + } catch (error) { + assert.isDefined(error); + } + }); + }); + + // ============ testProviderConnection Tests ============ + describe('testProviderConnection', () => { + it('should return success on successful fetch', async () => { + const modelsResponse = { + data: [{ id: 'gpt-4.1', object: 'model' }], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const result = await LLMProviderRegistry.testProviderConnection('openai', 'sk-test-key'); + + assert.isTrue(result.success); + assert.include(result.message, 'Successfully connected'); + }); + + it('should return failure with error message on failure', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + const result = await LLMProviderRegistry.testProviderConnection('openai', 'sk-test-key'); + + assert.isFalse(result.success); + assert.include(result.message, 'Network error'); + }); + }); + + // ============ Provider Creation Tests ============ + describe('Provider Creation (via storage keys)', () => { + it('should create OpenAI provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('openai'); + assert.isDefined(keys.apiKey); + }); + + it('should create Anthropic provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('anthropic'); + assert.isDefined(keys.apiKey); + }); + + it('should create LiteLLM provider with endpoint', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('litellm'); + assert.isDefined(keys.apiKey); + assert.isDefined(keys.endpoint); + }); + + it('should create Groq provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('groq'); + assert.isDefined(keys.apiKey); + }); + + it('should create OpenRouter provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('openrouter'); + assert.isDefined(keys.apiKey); + }); + + it('should create Cerebras provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('cerebras'); + assert.isDefined(keys.apiKey); + }); + + it('should create Google AI provider correctly', () => { + const keys = LLMProviderRegistry.getProviderStorageKeys('googleai'); + assert.isDefined(keys.apiKey); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/LLMResponseParser.test.ts b/front_end/panels/ai_chat/LLM/__tests__/LLMResponseParser.test.ts new file mode 100644 index 0000000000..388fc87db1 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/LLMResponseParser.test.ts @@ -0,0 +1,411 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { LLMResponseParser } from '../LLMResponseParser.js'; +import type { UnifiedLLMResponse } from '../LLMTypes.js'; + +describe('ai_chat: LLMResponseParser', () => { + // ============ parseStrictJSON Tests ============ + describe('parseStrictJSON', () => { + it('should parse valid JSON directly', () => { + const json = '{"name": "test", "value": 123}'; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { name: 'test', value: 123 }); + }); + + it('should handle JSON with leading/trailing whitespace', () => { + const json = ' \n {"key": "value"} \n '; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { key: 'value' }); + }); + + it('should strip markdown code blocks (```json)', () => { + const json = '```json\n{"name": "test"}\n```'; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { name: 'test' }); + }); + + it('should strip plain markdown code blocks (```)', () => { + const json = '```\n{"name": "test"}\n```'; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { name: 'test' }); + }); + + it('should extract JSON from surrounding text', () => { + const text = 'Here is the response: {"action": "click"} that you requested.'; + const result = LLMResponseParser.parseStrictJSON(text); + assert.deepEqual(result, { action: 'click' }); + }); + + it('should extract JSON array from surrounding text', () => { + const text = 'The results are: [1, 2, 3] as expected.'; + const result = LLMResponseParser.parseStrictJSON(text); + assert.deepEqual(result, [1, 2, 3]); + }); + + it('should throw on invalid JSON after cleanup', () => { + const invalidJson = 'This is not JSON at all'; + assert.throws(() => { + LLMResponseParser.parseStrictJSON(invalidJson); + }, /Unable to parse JSON/); + }); + + it('should handle nested JSON objects', () => { + const json = '{"outer": {"inner": {"value": true}}}'; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { outer: { inner: { value: true } } }); + }); + + it('should handle JSON with special characters', () => { + const json = '{"message": "Hello\\nWorld\\t!"}'; + const result = LLMResponseParser.parseStrictJSON(json); + assert.deepEqual(result, { message: 'Hello\nWorld\t!' }); + }); + }); + + // ============ parseResponse Tests ============ + describe('parseResponse', () => { + it('should return tool_call for functionCall responses', () => { + const response: UnifiedLLMResponse = { + functionCall: { + name: 'click_element', + arguments: { selector: '#button' }, + }, + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'tool_call'); + if (result.type === 'tool_call') { + assert.strictEqual(result.name, 'click_element'); + assert.deepEqual(result.args, { selector: '#button' }); + } + }); + + it('should return final_answer for text responses', () => { + const response: UnifiedLLMResponse = { + text: 'The answer is 42.', + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'final_answer'); + if (result.type === 'final_answer') { + assert.strictEqual(result.answer, 'The answer is 42.'); + } + }); + + it('should parse JSON tool call from text (fallback)', () => { + const response: UnifiedLLMResponse = { + text: '{"action":"tool","toolName":"navigate","toolArgs":{"url":"https://example.com"}}', + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'tool_call'); + if (result.type === 'tool_call') { + assert.strictEqual(result.name, 'navigate'); + assert.deepEqual(result.args, { url: 'https://example.com' }); + } + }); + + it('should return error for empty response', () => { + const response: UnifiedLLMResponse = {}; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'error'); + if (result.type === 'error') { + assert.include(result.error, 'No valid response'); + } + }); + + it('should prioritize functionCall over text', () => { + const response: UnifiedLLMResponse = { + text: 'Some text', + functionCall: { + name: 'test_tool', + arguments: { arg: 'value' }, + }, + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'tool_call'); + if (result.type === 'tool_call') { + assert.strictEqual(result.name, 'test_tool'); + } + }); + + it('should handle text with action:tool but missing toolName', () => { + const response: UnifiedLLMResponse = { + text: '{"action":"tool","description":"some action"}', + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'final_answer'); + }); + + it('should treat non-tool JSON as final answer', () => { + const response: UnifiedLLMResponse = { + text: '{"result": "success", "data": [1, 2, 3]}', + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'final_answer'); + }); + + it('should handle tool call with empty toolArgs', () => { + const response: UnifiedLLMResponse = { + text: '{"action":"tool","toolName":"get_time"}', + }; + + const result = LLMResponseParser.parseResponse(response); + assert.strictEqual(result.type, 'tool_call'); + if (result.type === 'tool_call') { + assert.strictEqual(result.name, 'get_time'); + assert.deepEqual(result.args, {}); + } + }); + }); + + // ============ parseJSONWithFallbacks Tests ============ + describe('parseJSONWithFallbacks', () => { + it('should try direct parsing first', () => { + const json = '{"valid": true}'; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { valid: true }); + }); + + it('should try trim and parse', () => { + const json = ' {"valid": true} '; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { valid: true }); + }); + + it('should handle markdown code blocks', () => { + const json = '```json\n{"code": "block"}\n```'; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { code: 'block' }); + }); + + it('should extract JSON from text', () => { + const text = 'The data is {"extracted": true} from here.'; + const result = LLMResponseParser.parseJSONWithFallbacks(text); + assert.deepEqual(result, { extracted: true }); + }); + + it('should fix common JSON issues (single quotes)', () => { + const json = "{'key': 'value'}"; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { key: 'value' }); + }); + + it('should fix trailing commas', () => { + const json = '{"a": 1, "b": 2, }'; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { a: 1, b: 2 }); + }); + + it('should throw after all strategies fail', () => { + const invalid = 'This is definitely not JSON { broken }'; + assert.throws(() => { + LLMResponseParser.parseJSONWithFallbacks(invalid); + }, /JSON parsing failed/); + }); + + it('should handle arrays', () => { + const json = '[1, 2, "three", {"four": 4}]'; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, [1, 2, 'three', { four: 4 }]); + }); + + it('should handle nested objects', () => { + const json = '{"level1": {"level2": {"level3": "deep"}}}'; + const result = LLMResponseParser.parseJSONWithFallbacks(json); + assert.deepEqual(result, { level1: { level2: { level3: 'deep' } } }); + }); + }); + + // ============ validateStrictJSON Tests ============ + describe('validateStrictJSON', () => { + it('should return isValid true for valid JSON', () => { + const result = LLMResponseParser.validateStrictJSON('{"valid": true}'); + assert.isTrue(result.isValid); + assert.isDefined(result.cleaned); + }); + + it('should return cleaned JSON string', () => { + const result = LLMResponseParser.validateStrictJSON(' {"key": "value"} '); + assert.isTrue(result.isValid); + // Implementation trims whitespace but preserves JSON formatting + assert.strictEqual(result.cleaned, '{"key": "value"}'); + }); + + it('should return isValid false with error for invalid JSON', () => { + const result = LLMResponseParser.validateStrictJSON('not json'); + assert.isFalse(result.isValid); + assert.isDefined(result.error); + }); + + it('should clean up JSON with fallbacks', () => { + const result = LLMResponseParser.validateStrictJSON('```json\n{"code": true}\n```'); + assert.isTrue(result.isValid); + }); + }); + + // ============ extractStructuredData Tests ============ + describe('extractStructuredData', () => { + it('should extract JSON if present', () => { + const text = '{"name": "test", "value": 123}'; + const result = LLMResponseParser.extractStructuredData(text, ['name', 'value']); + assert.deepEqual(result, { name: 'test', value: 123 }); + }); + + it('should extract fields using pattern matching', () => { + const text = 'name: "John", age: 30, city: "NYC"'; + const result = LLMResponseParser.extractStructuredData(text, ['name', 'age', 'city']); + assert.strictEqual(result.name, 'John'); + assert.strictEqual(result.age, '30'); + assert.strictEqual(result.city, 'NYC'); + }); + + it('should handle missing fields gracefully', () => { + const text = 'name: "Test"'; + const result = LLMResponseParser.extractStructuredData(text, ['name', 'missing']); + assert.strictEqual(result.name, 'Test'); + assert.isUndefined(result.missing); + }); + + it('should prefer JSON parsing over pattern matching', () => { + const text = '{"name": "JSON"} but also name: "Pattern"'; + const result = LLMResponseParser.extractStructuredData(text, ['name']); + assert.strictEqual(result.name, 'JSON'); + }); + }); + + // ============ enhanceResponse Tests ============ + describe('enhanceResponse', () => { + it('should add parsedJson when strictJsonMode enabled', () => { + const response: UnifiedLLMResponse = { + text: '{"parsed": true}', + }; + const enhanced = LLMResponseParser.enhanceResponse(response, { strictJsonMode: true }); + assert.deepEqual(enhanced.parsedJson, { parsed: true }); + }); + + it('should not modify response when strictJsonMode disabled', () => { + const response: UnifiedLLMResponse = { + text: '{"ignored": true}', + }; + const enhanced = LLMResponseParser.enhanceResponse(response, { strictJsonMode: false }); + assert.isUndefined(enhanced.parsedJson); + }); + + it('should extract expectedFields from text', () => { + const response: UnifiedLLMResponse = { + text: 'action: click, selector: "#btn"', + }; + const enhanced = LLMResponseParser.enhanceResponse(response, { + expectedFields: ['action', 'selector'], + }); + assert.isDefined(enhanced.parsedJson); + assert.strictEqual(enhanced.parsedJson.action, 'click'); + }); + + it('should handle both strictJsonMode and expectedFields', () => { + const response: UnifiedLLMResponse = { + text: '{"action": "click"}', + }; + const enhanced = LLMResponseParser.enhanceResponse(response, { + strictJsonMode: true, + expectedFields: ['action'], + }); + assert.deepEqual(enhanced.parsedJson, { action: 'click' }); + }); + + it('should not throw on invalid JSON with strictJsonMode', () => { + const response: UnifiedLLMResponse = { + text: 'Not valid JSON', + }; + // Should not throw, just log error + const enhanced = LLMResponseParser.enhanceResponse(response, { strictJsonMode: true }); + assert.isUndefined(enhanced.parsedJson); + }); + }); + + // ============ isValidJSON Tests ============ + describe('isValidJSON', () => { + it('should return true for valid JSON object', () => { + assert.isTrue(LLMResponseParser.isValidJSON('{"valid": true}')); + }); + + it('should return true for valid JSON array', () => { + assert.isTrue(LLMResponseParser.isValidJSON('[1, 2, 3]')); + }); + + it('should return true for JSON string', () => { + assert.isTrue(LLMResponseParser.isValidJSON('"string"')); + }); + + it('should return true for JSON number', () => { + assert.isTrue(LLMResponseParser.isValidJSON('123')); + }); + + it('should return true for JSON boolean', () => { + assert.isTrue(LLMResponseParser.isValidJSON('true')); + assert.isTrue(LLMResponseParser.isValidJSON('false')); + }); + + it('should return true for JSON null', () => { + assert.isTrue(LLMResponseParser.isValidJSON('null')); + }); + + it('should return false for invalid JSON', () => { + assert.isFalse(LLMResponseParser.isValidJSON('not json')); + assert.isFalse(LLMResponseParser.isValidJSON('{invalid}')); + assert.isFalse(LLMResponseParser.isValidJSON("{'single': 'quotes'}")); + }); + + it('should handle whitespace correctly', () => { + assert.isTrue(LLMResponseParser.isValidJSON(' {"key": "value"} ')); + }); + }); + + // ============ getJSONParsingSuggestions Tests ============ + describe('getJSONParsingSuggestions', () => { + it('should suggest starting with { or [', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions('abc{"key": "value"}'); + assert.includeMembers(suggestions, ['Response should start with { or [']); + }); + + it('should suggest ending with } or ]', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions('{"key": "value"}abc'); + assert.includeMembers(suggestions, ['Response should end with } or ]']); + }); + + it('should suggest using double quotes', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions("{'key': 'value'}"); + assert.includeMembers(suggestions, ['Use double quotes (") instead of single quotes (\')']); + }); + + it('should suggest removing trailing commas', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions('{"a": 1, }'); + assert.includeMembers(suggestions, ['Remove trailing commas before } or ]']); + }); + + it('should suggest quoting object keys', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions('{key: "value"}'); + assert.includeMembers(suggestions, ['Ensure all object keys are quoted']); + }); + + it('should return empty array for valid JSON', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions('{"valid": true}'); + // Valid JSON might still trigger unquoted keys suggestion due to regex + // but main structural suggestions should not appear + assert.notIncludeMembers(suggestions, ['Response should start with { or [']); + assert.notIncludeMembers(suggestions, ['Response should end with } or ]']); + }); + + it('should return multiple suggestions for multiple issues', () => { + const suggestions = LLMResponseParser.getJSONParsingSuggestions("abc{'key': 'value', }xyz"); + assert.isAtLeast(suggestions.length, 3); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/LLMTestHelpers.ts b/front_end/panels/ai_chat/LLM/__tests__/LLMTestHelpers.ts new file mode 100644 index 0000000000..f6c03d301e --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/LLMTestHelpers.ts @@ -0,0 +1,778 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Test utilities for LLM module testing + * Provides mock factories, fetch stubs, localStorage helpers, and assertion utilities + */ + +import type { LLMMessage, LLMResponse, LLMErrorType, RetryConfig } from '../LLMTypes.js'; +import { LLMErrorType as ErrorType } from '../LLMTypes.js'; + +// Declare require for CommonJS module imports used in test helpers +declare const require: (module: string) => any; + +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// ============ MOCK RESPONSE FACTORIES ============ + +export interface MockResponseOptions { + text?: string; + functionCall?: { name: string; arguments: any }; + usage?: { input_tokens: number; output_tokens: number }; +} + +/** + * Creates a mock OpenAI Responses API response + * Note: OpenAI uses Responses API (/v1/responses), NOT Chat Completions! + */ +export function createMockOpenAIResponse(options: MockResponseOptions = {}): any { + const output: any[] = []; + + if (options.functionCall) { + output.push({ + type: 'function_call', + name: options.functionCall.name, + arguments: JSON.stringify(options.functionCall.arguments), + call_id: `call_${Date.now()}_${Math.random().toString(36).substring(2, 8)}`, + }); + } else if (options.text !== undefined) { + output.push({ + type: 'message', + content: [{ type: 'output_text', text: options.text }], + }); + } + + return { + output, + usage: options.usage || { input_tokens: 100, output_tokens: 50 }, + }; +} + +/** + * Creates a mock Anthropic Messages API response + */ +export function createMockAnthropicResponse(options: MockResponseOptions = {}): any { + const content: any[] = []; + + if (options.functionCall) { + content.push({ + type: 'tool_use', + id: `toolu_${Date.now()}_${Math.random().toString(36).substring(2, 8)}`, + name: options.functionCall.name, + input: options.functionCall.arguments, + }); + } else if (options.text !== undefined) { + content.push({ + type: 'text', + text: options.text, + }); + } + + return { + content, + usage: { + input_tokens: options.usage?.input_tokens || 100, + output_tokens: options.usage?.output_tokens || 50, + }, + stop_reason: options.functionCall ? 'tool_use' : 'end_turn', + }; +} + +/** + * Creates a mock Google AI (Gemini) response + */ +export function createMockGoogleAIResponse(options: MockResponseOptions = {}): any { + const parts: any[] = []; + + if (options.functionCall) { + parts.push({ + functionCall: { + name: options.functionCall.name, + args: options.functionCall.arguments, + }, + }); + } else if (options.text !== undefined) { + parts.push({ text: options.text }); + } + + return { + candidates: [ + { + content: { + parts, + role: 'model', + }, + finishReason: options.functionCall ? 'FUNCTION_CALL' : 'STOP', + }, + ], + usageMetadata: { + promptTokenCount: options.usage?.input_tokens || 100, + candidatesTokenCount: options.usage?.output_tokens || 50, + }, + }; +} + +/** + * Creates a mock OpenAI-compatible response (for Groq, OpenRouter, Cerebras, LiteLLM) + */ +export function createMockOpenAICompatibleResponse(options: MockResponseOptions = {}): any { + const message: any = { role: 'assistant' }; + + if (options.functionCall) { + message.tool_calls = [ + { + id: `call_${Date.now()}`, + type: 'function', + function: { + name: options.functionCall.name, + arguments: JSON.stringify(options.functionCall.arguments), + }, + }, + ]; + } else { + message.content = options.text || ''; + } + + return { + choices: [{ message, finish_reason: options.functionCall ? 'tool_calls' : 'stop' }], + usage: { + prompt_tokens: options.usage?.input_tokens || 100, + completion_tokens: options.usage?.output_tokens || 50, + total_tokens: (options.usage?.input_tokens || 100) + (options.usage?.output_tokens || 50), + }, + }; +} + +// ============ ERROR RESPONSE FACTORIES ============ + +export interface ErrorResponseOptions { + message?: string; + code?: string; + retryAfter?: number; +} + +/** + * Creates a mock 401 Unauthorized response (AUTH_ERROR - NOT retryable) + */ +export function createMock401Response(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `Invalid API key for ${provider}`, + type: 'authentication_error', + code: options.code || 'invalid_api_key', + }, + }; + + return new Response(JSON.stringify(errorBody), { + status: 401, + statusText: 'Unauthorized', + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a mock 403 Forbidden response (AUTH_ERROR - NOT retryable) + */ +export function createMock403Response(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `Access forbidden for ${provider}`, + type: 'permission_error', + code: options.code || 'forbidden', + }, + }; + + return new Response(JSON.stringify(errorBody), { + status: 403, + statusText: 'Forbidden', + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a mock 429 Rate Limit response (RATE_LIMIT_ERROR - retryable with 60s delay) + */ +export function createMock429Response(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `Rate limit exceeded for ${provider}`, + type: 'rate_limit_error', + code: options.code || 'rate_limit_exceeded', + }, + }; + + const headers: Record = { 'Content-Type': 'application/json' }; + if (options.retryAfter) { + headers['Retry-After'] = String(options.retryAfter); + } + + return new Response(JSON.stringify(errorBody), { + status: 429, + statusText: 'Too Many Requests', + headers, + }); +} + +/** + * Creates a mock 500 Internal Server Error response (SERVER_ERROR - retryable) + */ +export function createMock500Response(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `Internal server error from ${provider}`, + type: 'server_error', + code: options.code || 'internal_error', + }, + }; + + return new Response(JSON.stringify(errorBody), { + status: 500, + statusText: 'Internal Server Error', + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a mock 503 Service Unavailable response (SERVER_ERROR - retryable) + */ +export function createMock503Response(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `${provider} service temporarily unavailable`, + type: 'server_error', + code: options.code || 'service_unavailable', + }, + }; + + return new Response(JSON.stringify(errorBody), { + status: 503, + statusText: 'Service Unavailable', + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a mock quota exceeded response (QUOTA_ERROR - NOT retryable) + */ +export function createMockQuotaExceededResponse(provider: string, options: ErrorResponseOptions = {}): Response { + const errorBody = { + error: { + message: options.message || `Insufficient quota for ${provider}`, + type: 'insufficient_quota', + code: options.code || 'quota_exceeded', + }, + }; + + return new Response(JSON.stringify(errorBody), { + status: 402, + statusText: 'Payment Required', + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a network error (NETWORK_ERROR - retryable) + */ +export function createMockNetworkError(message?: string): Error { + return new Error(message || 'fetch failed: Network error'); +} + +/** + * Creates a timeout error (NETWORK_ERROR - retryable) + */ +export function createMockTimeoutError(message?: string): Error { + return new Error(message || 'Request timeout'); +} + +/** + * Creates a JSON parse error (JSON_PARSE_ERROR - retryable) + */ +export function createMockJSONParseError(message?: string): Error { + return new Error(message || 'JSON parsing failed: Unexpected token'); +} + +// ============ FETCH STUB HELPERS ============ + +export interface FetchStubConfig { + url?: string | RegExp; + method?: 'GET' | 'POST'; + response?: any; + responseStatus?: number; + responseHeaders?: Record; + delay?: number; + error?: Error; +} + +/** + * Creates a sinon stub for fetch with configurable response + */ +export function createFetchStub(config: FetchStubConfig): sinon.SinonStub { + const stub = sinon.stub(globalThis, 'fetch'); + + if (config.error) { + stub.rejects(config.error); + } else { + const responseInit: ResponseInit = { + status: config.responseStatus || 200, + statusText: config.responseStatus === 200 ? 'OK' : 'Error', + headers: { + 'Content-Type': 'application/json', + ...config.responseHeaders, + }, + }; + + const response = new Response(JSON.stringify(config.response || {}), responseInit); + + if (config.delay) { + stub.callsFake(async () => { + await new Promise(resolve => setTimeout(resolve, config.delay)); + return response; + }); + } else { + stub.resolves(response); + } + } + + return stub; +} + +/** + * Creates a fetch stub that returns different responses sequentially + * Useful for testing retry logic + */ +export function createSequentialFetchStub( + responses: Array<{ response?: any; status?: number; error?: Error; delay?: number }> +): sinon.SinonStub { + const stub = sinon.stub(globalThis, 'fetch'); + let callIndex = 0; + + stub.callsFake(async () => { + const responseConfig = responses[Math.min(callIndex, responses.length - 1)]; + callIndex++; + + if (responseConfig.delay) { + await new Promise(resolve => setTimeout(resolve, responseConfig.delay)); + } + + if (responseConfig.error) { + throw responseConfig.error; + } + + return new Response(JSON.stringify(responseConfig.response || {}), { + status: responseConfig.status || 200, + statusText: responseConfig.status === 200 ? 'OK' : 'Error', + headers: { 'Content-Type': 'application/json' }, + }); + }); + + return stub; +} + +/** + * Restores the fetch stub + */ +export function restoreFetch(stub: sinon.SinonStub): void { + stub.restore(); +} + +// ============ LOCALSTORAGE HELPERS ============ + +export interface LocalStorageMockConfig { + [key: string]: string | null; +} + +/** + * Common localStorage keys used by LLM providers + */ +export const STORAGE_KEYS = { + // OpenAI + OPENAI_API_KEY: 'ai_chat_api_key', + + // Anthropic + ANTHROPIC_API_KEY: 'ai_chat_anthropic_api_key', + + // LiteLLM + LITELLM_ENDPOINT: 'ai_chat_litellm_endpoint', + LITELLM_API_KEY: 'ai_chat_litellm_api_key', + + // Groq + GROQ_API_KEY: 'ai_chat_groq_api_key', + + // OpenRouter + OPENROUTER_API_KEY: 'ai_chat_openrouter_api_key', + + // Cerebras + CEREBRAS_API_KEY: 'ai_chat_cerebras_api_key', + + // Google AI + GOOGLEAI_API_KEY: 'ai_chat_googleai_api_key', + + // BrowserOperator + BROWSEROPERATOR_API_KEY: 'ai_chat_browseroperator_api_key', + + // General + PROVIDER: 'ai_chat_provider', + MODEL_SELECTION: 'ai_chat_model_selection', + CUSTOM_MODELS: 'ai_chat_custom_models', + ALL_MODEL_OPTIONS: 'ai_chat_all_model_options', +} as const; + +/** + * Creates a mock localStorage with configurable initial values + */ +export function createLocalStorageMock(config: LocalStorageMockConfig = {}): { + mock: Storage; + store: Map; + getItem: sinon.SinonStub; + setItem: sinon.SinonStub; + removeItem: sinon.SinonStub; + clear: sinon.SinonStub; + restore: () => void; +} { + const store = new Map(); + + // Initialize with config values + for (const [key, value] of Object.entries(config)) { + if (value !== null) { + store.set(key, value); + } + } + + const originalLocalStorage = globalThis.localStorage; + + const getItem = sinon.stub().callsFake((key: string) => store.get(key) || null); + const setItem = sinon.stub().callsFake((key: string, value: string) => store.set(key, value)); + const removeItem = sinon.stub().callsFake((key: string) => store.delete(key)); + const clear = sinon.stub().callsFake(() => store.clear()); + + const mockStorage: Storage = { + getItem: getItem as unknown as Storage['getItem'], + setItem: setItem as unknown as Storage['setItem'], + removeItem: removeItem as unknown as Storage['removeItem'], + clear: clear as unknown as Storage['clear'], + get length() { return store.size; }, + key(index: number) { return Array.from(store.keys())[index] || null; }, + }; + + Object.defineProperty(globalThis, 'localStorage', { + value: mockStorage, + writable: true, + configurable: true, + }); + + return { + mock: mockStorage, + store, + getItem, + setItem, + removeItem, + clear, + restore: () => { + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + }, + }; +} + +// ============ SINGLETON RESET HELPERS ============ + +/** + * Resets the LLMClient singleton for testing + */ +export function resetLLMClient(): void { + // Access private instance via type assertion + const LLMClientModule = require('../LLMClient.js'); + if (LLMClientModule.LLMClient) { + (LLMClientModule.LLMClient as any).instance = null; + } +} + +/** + * Resets the LLMProviderRegistry for testing + */ +export function resetLLMProviderRegistry(): void { + const RegistryModule = require('../LLMProviderRegistry.js'); + if (RegistryModule.LLMProviderRegistry) { + RegistryModule.LLMProviderRegistry.clear(); + } +} + +// ============ TEST MESSAGE FACTORIES ============ + +/** + * Creates test messages for LLM calls + */ +export function createTestMessages(options: { + includeSystem?: boolean; + includeTools?: boolean; + includeImages?: boolean; + systemPrompt?: string; + userMessage?: string; +} = {}): LLMMessage[] { + const messages: LLMMessage[] = []; + + if (options.includeSystem !== false) { + messages.push({ + role: 'system', + content: options.systemPrompt || 'You are a helpful assistant.', + }); + } + + if (options.includeImages) { + messages.push({ + role: 'user', + content: [ + { type: 'text', text: options.userMessage || 'What is in this image?' }, + { type: 'image_url', image_url: { url: 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUg...' } }, + ], + }); + } else { + messages.push({ + role: 'user', + content: options.userMessage || 'Hello, how are you?', + }); + } + + if (options.includeTools) { + messages.push({ + role: 'assistant', + tool_calls: [ + { + id: 'call_test123', + type: 'function', + function: { + name: 'test_tool', + arguments: JSON.stringify({ arg1: 'value1' }), + }, + }, + ], + }); + messages.push({ + role: 'tool', + tool_call_id: 'call_test123', + content: 'Tool result: success', + }); + } + + return messages; +} + +/** + * Creates a tool call message + */ +export function createToolCallMessage(toolName: string, args: any): LLMMessage { + return { + role: 'assistant', + tool_calls: [ + { + id: `call_${Date.now()}`, + type: 'function', + function: { + name: toolName, + arguments: JSON.stringify(args), + }, + }, + ], + }; +} + +/** + * Creates a tool result message + */ +export function createToolResultMessage(callId: string, result: string): LLMMessage { + return { + role: 'tool', + tool_call_id: callId, + content: result, + }; +} + +// ============ ASSERTION HELPERS ============ + +/** + * Asserts that an LLM response matches expected values + */ +export function assertLLMResponse(response: LLMResponse, expected: Partial): void { + if (expected.text !== undefined) { + assert.strictEqual(response.text, expected.text, 'Response text mismatch'); + } + + if (expected.functionCall !== undefined) { + assert.isDefined(response.functionCall, 'Expected functionCall but got none'); + assert.strictEqual(response.functionCall!.name, expected.functionCall.name, 'Function name mismatch'); + assert.deepEqual(response.functionCall!.arguments, expected.functionCall.arguments, 'Function arguments mismatch'); + } + + if (expected.rawResponse !== undefined) { + assert.isDefined(response.rawResponse, 'Expected rawResponse but got none'); + } +} + +/** + * Asserts that an error is of the expected LLM error type + */ +export function assertErrorType(error: Error, expectedType: LLMErrorType): void { + const { LLMErrorClassifier } = require('../LLMErrorHandler.js'); + const actualType = LLMErrorClassifier.classifyError(error); + assert.strictEqual(actualType, expectedType, `Expected error type ${expectedType} but got ${actualType}`); +} + +/** + * Asserts that an error is retryable + */ +export function assertRetryable(error: Error): void { + const { LLMErrorClassifier } = require('../LLMErrorHandler.js'); + const errorType = LLMErrorClassifier.classifyError(error); + const shouldRetry = LLMErrorClassifier.shouldRetry(errorType); + assert.isTrue(shouldRetry, `Expected error to be retryable but ${errorType} is not retryable`); +} + +/** + * Asserts that an error is NOT retryable + */ +export function assertNotRetryable(error: Error): void { + const { LLMErrorClassifier } = require('../LLMErrorHandler.js'); + const errorType = LLMErrorClassifier.classifyError(error); + const shouldRetry = LLMErrorClassifier.shouldRetry(errorType); + assert.isFalse(shouldRetry, `Expected error to NOT be retryable but ${errorType} is retryable`); +} + +// ============ MOCK PROVIDER FACTORY ============ + +/** + * Creates a mock LLM provider for testing with sinon stubs + */ +export function createMockProvider(config: { + name: string; + models?: Array<{ id: string; name: string }>; + response?: LLMResponse; + callResponse?: LLMResponse; // Alias for response + callError?: Error; +}): any { + const mockResponse = config.response || config.callResponse || { text: 'Mock response', rawResponse: {} }; + + // Create stub functions + const callWithMessagesStub = sinon.stub(); + if (config.callError) { + callWithMessagesStub.rejects(config.callError); + } else { + callWithMessagesStub.resolves(mockResponse); + } + + const callStub = sinon.stub(); + if (config.callError) { + callStub.rejects(config.callError); + } else { + callStub.resolves(mockResponse); + } + + const getModelsStub = sinon.stub().resolves( + (config.models || [{ id: 'mock-model', name: 'Mock Model' }]).map(m => ({ + ...m, + provider: config.name, + capabilities: { functionCalling: true, reasoning: false, vision: false, structured: true }, + })) + ); + + const fetchModelsStub = sinon.stub().resolves( + config.models || [{ id: 'mock-model', name: 'Mock Model' }] + ); + + return { + name: config.name, + + callWithMessages: callWithMessagesStub, + call: callStub, + getModels: getModelsStub, + fetchModels: fetchModelsStub, + + validateCredentials: sinon.stub().returns({ isValid: true, message: 'Mock credentials valid' }), + + getCredentialStorageKeys: sinon.stub().returns({ apiKey: `ai_chat_${config.name}_api_key` }), + + parseResponse: sinon.stub().callsFake((response: LLMResponse) => { + if (response.functionCall) { + return { type: 'tool_call', name: response.functionCall.name, args: response.functionCall.arguments }; + } + return { type: 'final_answer', answer: response.text || '' }; + }), + + testConnection: sinon.stub().resolves({ success: true, message: 'Connection successful' }), + }; +} + +// ============ RETRY CONFIG HELPERS ============ + +/** + * Default retry config values (from LLMErrorHandler.ts) + */ +export const DEFAULT_RETRY_CONFIG: RetryConfig = { + maxRetries: 2, + baseDelayMs: 1000, + maxDelayMs: 10000, + backoffMultiplier: 2, + jitterMs: 500, +}; + +/** + * Rate limit retry config (60 seconds base delay!) + */ +export const RATE_LIMIT_RETRY_CONFIG: RetryConfig = { + maxRetries: 3, + baseDelayMs: 60000, // 60 seconds! + maxDelayMs: 300000, // 5 minutes max + backoffMultiplier: 1, // No exponential backoff for rate limits + jitterMs: 5000, +}; + +/** + * Network error retry config + */ +export const NETWORK_ERROR_RETRY_CONFIG: RetryConfig = { + maxRetries: 3, + baseDelayMs: 2000, + maxDelayMs: 30000, + backoffMultiplier: 2, + jitterMs: 1000, +}; + +/** + * Creates a fast retry config for testing (short delays) + */ +export function createFastRetryConfig(maxRetries: number = 2): RetryConfig { + return { + maxRetries, + baseDelayMs: 10, // Very short for tests + maxDelayMs: 50, + backoffMultiplier: 2, + jitterMs: 5, + }; +} + +// ============ TOOL DEFINITION HELPERS ============ + +/** + * Creates a mock tool definition in OpenAI format + */ +export function createMockToolDefinition(name: string, description: string, parameters?: any): any { + return { + type: 'function', + function: { + name, + description, + parameters: parameters || { + type: 'object', + properties: { + input: { type: 'string', description: 'Input parameter' }, + }, + required: ['input'], + }, + }, + }; +} diff --git a/front_end/panels/ai_chat/LLM/__tests__/OpenAICompatibleProviders.test.ts b/front_end/panels/ai_chat/LLM/__tests__/OpenAICompatibleProviders.test.ts new file mode 100644 index 0000000000..cbd0f22b9e --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/OpenAICompatibleProviders.test.ts @@ -0,0 +1,725 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Tests for OpenAI-compatible providers: LiteLLM, Groq, OpenRouter, Cerebras + * These providers share the same Chat Completions API format. + */ + +import { LiteLLMProvider } from '../LiteLLMProvider.js'; +import { GroqProvider } from '../GroqProvider.js'; +import { OpenRouterProvider } from '../OpenRouterProvider.js'; +import { CerebrasProvider } from '../CerebrasProvider.js'; +import type { LLMMessage } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createLocalStorageMock, + createMockOpenAICompatibleResponse, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: OpenAI-Compatible Providers', () => { + let fetchStub: sinon.SinonStub; + let localStorageMock: ReturnType; + + beforeEach(() => { + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.LITELLM_ENDPOINT]: 'http://localhost:4000', + [STORAGE_KEYS.LITELLM_API_KEY]: 'test-litellm-key', + [STORAGE_KEYS.GROQ_API_KEY]: 'gsk-test-key', + [STORAGE_KEYS.OPENROUTER_API_KEY]: 'sk-or-test-key', + [STORAGE_KEYS.CEREBRAS_API_KEY]: 'csk-test-key', + }); + }); + + afterEach(() => { + if (fetchStub) { + fetchStub.restore(); + } + localStorageMock.restore(); + sinon.restore(); + }); + + // ============ LiteLLM Provider Tests ============ + describe('LiteLLMProvider', () => { + let provider: LiteLLMProvider; + + beforeEach(() => { + provider = new LiteLLMProvider('test-api-key', 'http://localhost:4000'); + }); + + describe('endpoint configuration', () => { + it('should use provided endpoint', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', [ + { role: 'user', content: 'Hello' } + ]); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'http://localhost:4000/v1/chat/completions'); + }); + + it('should fallback to localStorage endpoint', async () => { + const providerNoEndpoint = new LiteLLMProvider('test-key', undefined); + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await providerNoEndpoint.callWithMessages('model', [ + { role: 'user', content: 'Hi' } + ]); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'http://localhost:4000/v1/chat/completions'); + }); + + it('should throw when no endpoint configured', async () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + const providerNoEndpoint = new LiteLLMProvider('test-key', undefined); + + try { + await providerNoEndpoint.callWithMessages('model', [ + { role: 'user', content: 'Hi' } + ]); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'endpoint not configured'); + } + }); + }); + + describe('authentication', () => { + it('should include Bearer token when API key provided', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('model', [{ role: 'user', content: 'Hi' }]); + + const headers = fetchStub.firstCall.args[1].headers; + assert.strictEqual(headers.Authorization, 'Bearer test-api-key'); + }); + + it('should not include Authorization when no API key', async () => { + const providerNoKey = new LiteLLMProvider(null, 'http://localhost:4000'); + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await providerNoKey.callWithMessages('model', [{ role: 'user', content: 'Hi' }]); + + const headers = fetchStub.firstCall.args[1].headers; + assert.isUndefined(headers.Authorization); + }); + }); + + describe('credential validation', () => { + it('should require endpoint for LiteLLM', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); + + const result = provider.validateCredentials(); + + assert.isFalse(result.isValid); + assert.include(result.missingItems!, 'Endpoint URL'); + }); + + it('should pass validation with endpoint (API key optional)', () => { + const result = provider.validateCredentials(); + + assert.isTrue(result.isValid); + }); + }); + + describe('custom models from localStorage', () => { + it('should include custom models in getModels', async () => { + localStorageMock.store.set('ai_chat_custom_models', JSON.stringify([ + { id: 'my-custom-model', name: 'My Custom Model' } + ])); + + // Mock API response with empty data (no models from API) + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('API unavailable')); + + const models = await provider.getModels(); + + const customModel = models.find(m => m.id === 'my-custom-model'); + assert.isDefined(customModel); + assert.strictEqual(customModel!.name, 'My Custom Model'); + }); + }); + }); + + // ============ Groq Provider Tests ============ + describe('GroqProvider', () => { + let provider: GroqProvider; + + beforeEach(() => { + provider = new GroqProvider('gsk-test-key'); + }); + + describe('endpoint configuration', () => { + it('should use Groq API endpoint', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('llama-3.3-70b-versatile', [ + { role: 'user', content: 'Hello' } + ]); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'https://api.groq.com/openai/v1/chat/completions'); + }); + }); + + describe('message conversion', () => { + it('should stringify tool content for tool role', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'user', content: 'Search for cats' }, + { + role: 'assistant', + content: '', + tool_calls: [{ + id: 'call_123', + type: 'function', + function: { name: 'search', arguments: '{"query":"cats"}' } + }] + }, + { + role: 'tool', + tool_call_id: 'call_123', + name: 'search', + content: { results: ['cat1', 'cat2'] } as any // Object content + } + ]; + + await provider.callWithMessages('llama-3.3-70b-versatile', messages); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + const toolMessage = body.messages[2]; + // Content should be stringified + assert.strictEqual(typeof toolMessage.content, 'string'); + assert.include(toolMessage.content, 'cat1'); + }); + + it('should set tool_choice to auto when tools provided', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('llama-3.3-70b-versatile', [ + { role: 'user', content: 'Search for cats' } + ], { + tools: [{ + type: 'function', + function: { name: 'search', description: 'Search', parameters: {} } + }] + }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.tool_choice, 'auto'); + }); + }); + + describe('model capabilities', () => { + it('should detect function calling for supported models', async () => { + const modelsResponse = { + object: 'list', + data: [ + { id: 'llama-3.3-70b-versatile', object: 'model', created: 1, owned_by: 'groq', active: true, context_window: 8192 }, + { id: 'llama-3.2-90b-vision-preview', object: 'model', created: 1, owned_by: 'groq', active: true, context_window: 8192 } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + const llamaModel = models.find(m => m.id === 'llama-3.3-70b-versatile'); + assert.isTrue(llamaModel!.capabilities!.functionCalling); + + const visionModel = models.find(m => m.id === 'llama-3.2-90b-vision-preview'); + assert.isTrue(visionModel!.capabilities!.vision); + }); + + it('should filter inactive models', async () => { + const modelsResponse = { + object: 'list', + data: [ + { id: 'active-model', object: 'model', created: 1, owned_by: 'groq', active: true, context_window: 8192 }, + { id: 'inactive-model', object: 'model', created: 1, owned_by: 'groq', active: false, context_window: 8192 } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.strictEqual(models.length, 1); + assert.strictEqual(models[0].id, 'active-model'); + }); + }); + }); + + // ============ OpenRouter Provider Tests ============ + describe('OpenRouterProvider', () => { + let provider: OpenRouterProvider; + + beforeEach(() => { + provider = new OpenRouterProvider('sk-or-test-key'); + }); + + describe('endpoint configuration', () => { + it('should use OpenRouter API endpoint', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('openai/gpt-4o', [ + { role: 'user', content: 'Hello' } + ]); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'https://openrouter.ai/api/v1/chat/completions'); + }); + }); + + describe('special headers', () => { + it('should include HTTP-Referer and X-Title headers', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('openai/gpt-4o', [ + { role: 'user', content: 'Hello' } + ]); + + const headers = fetchStub.firstCall.args[1].headers; + assert.strictEqual(headers['HTTP-Referer'], 'https://browseroperator.io'); + assert.strictEqual(headers['X-Title'], 'Browser Operator'); + }); + }); + + describe('temperature handling', () => { + it('should exclude temperature for O-series models', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('openai/o3', [ + { role: 'user', content: 'Hello' } + ], { temperature: 0.7 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isUndefined(body.temperature); + }); + + it('should exclude temperature for GPT-5 models', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('openai/gpt-5', [ + { role: 'user', content: 'Hello' } + ], { temperature: 0.7 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isUndefined(body.temperature); + }); + + it('should include temperature for other models', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('openai/gpt-4o', [ + { role: 'user', content: 'Hello' } + ], { temperature: 0.7 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.temperature, 0.7); + }); + }); + + describe('model fetching', () => { + it('should use tools filter when fetching models', async () => { + const modelsResponse = { + data: [ + { + id: 'openai/gpt-4o', + name: 'GPT-4o', + context_length: 128000, + architecture: { modality: 'multimodal', tokenizer: 'gpt-4' }, + pricing: { prompt: '0.005', completion: '0.015' }, + top_provider: { context_length: 128000 } + } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + await provider.fetchModels(); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'supported_parameters=tools'); + }); + + it('should detect vision from multimodal architecture', async () => { + const modelsResponse = { + data: [ + { + id: 'openai/gpt-4o', + name: 'GPT-4o', + context_length: 128000, + architecture: { modality: 'multimodal', tokenizer: 'gpt-4' }, + pricing: { prompt: '0.005', completion: '0.015' }, + top_provider: { context_length: 128000 } + }, + { + id: 'meta/llama-3.1', + name: 'Llama 3.1', + context_length: 8000, + architecture: { modality: 'text', tokenizer: 'llama' }, + pricing: { prompt: '0.001', completion: '0.002' }, + top_provider: { context_length: 8000 } + } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + const gpt4o = models.find(m => m.id === 'openai/gpt-4o'); + const llama = models.find(m => m.id === 'meta/llama-3.1'); + + assert.isTrue(gpt4o!.capabilities!.vision); + assert.isFalse(llama!.capabilities!.vision); + }); + + it('should detect reasoning for O-series models', async () => { + const modelsResponse = { + data: [ + { + id: 'openai/o1-preview', + name: 'O1 Preview', + context_length: 128000, + architecture: { modality: 'text', tokenizer: 'gpt-4' }, + pricing: { prompt: '0.015', completion: '0.060' }, + top_provider: { context_length: 128000 } + } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.isTrue(models[0].capabilities!.reasoning); + }); + }); + + describe('supportsVision API check', () => { + it('should check vision support via API with caching', async () => { + const visionModelsResponse = { + data: [ + { + id: 'openai/gpt-4o', + name: 'GPT-4o', + context_length: 128000, + architecture: { modality: 'multimodal', tokenizer: 'gpt-4', input_modalities: ['text', 'image'] }, + pricing: { prompt: '0.005', completion: '0.015' }, + top_provider: { context_length: 128000 } + } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(visionModelsResponse), { status: 200 }) + ); + + // First call - should fetch + const result1 = await provider.supportsVision('openai/gpt-4o'); + assert.isTrue(result1); + + // Second call - should use cache + const result2 = await provider.supportsVision('openai/gpt-4o'); + assert.isTrue(result2); + + // Should only fetch once (cached) + assert.strictEqual(fetchStub.callCount, 1); + }); + }); + }); + + // ============ Cerebras Provider Tests ============ + describe('CerebrasProvider', () => { + let provider: CerebrasProvider; + + beforeEach(() => { + provider = new CerebrasProvider('csk-test-key'); + }); + + describe('endpoint configuration', () => { + it('should use Cerebras API endpoint', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('llama-3.3-70b', [ + { role: 'user', content: 'Hello' } + ]); + + const url = fetchStub.firstCall.args[0]; + assert.include(url, 'https://api.cerebras.ai/v1/chat/completions'); + }); + }); + + describe('temperature handling', () => { + it('should clamp temperature to 0-1.5 range', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + // Test clamping high value + await provider.callWithMessages('llama-3.3-70b', [ + { role: 'user', content: 'Hello' } + ], { temperature: 2.0 }); + + let body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.temperature, 1.5); + + fetchStub.resetHistory(); + + // Test clamping negative value + await provider.callWithMessages('llama-3.3-70b', [ + { role: 'user', content: 'Hello' } + ], { temperature: -0.5 }); + + body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.temperature, 0); + }); + + it('should preserve valid temperature values', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('llama-3.3-70b', [ + { role: 'user', content: 'Hello' } + ], { temperature: 0.8 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.temperature, 0.8); + }); + }); + + describe('model capabilities', () => { + it('should detect function calling for supported Cerebras models', async () => { + const modelsResponse = { + object: 'list', + data: [ + { id: 'llama-3.3-70b', object: 'model', created: 1, owned_by: 'cerebras' }, + { id: 'qwen-3-32b', object: 'model', created: 1, owned_by: 'cerebras' } + ] + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.isTrue(models.every(m => m.capabilities!.functionCalling === true)); + // Cerebras doesn't support vision + assert.isTrue(models.every(m => m.capabilities!.vision === false)); + }); + }); + + describe('tool handling', () => { + it('should set tool_choice to auto when tools provided', async () => { + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('llama-3.3-70b', [ + { role: 'user', content: 'Search' } + ], { + tools: [{ + type: 'function', + function: { name: 'search', description: 'Search', parameters: {} } + }] + }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.tool_choice, 'auto'); + }); + }); + }); + + // ============ Common Tests for All Providers ============ + describe('Common behavior', () => { + const providers = [ + { name: 'LiteLLM', create: () => new LiteLLMProvider('key', 'http://localhost:4000') }, + { name: 'Groq', create: () => new GroqProvider('key') }, + { name: 'OpenRouter', create: () => new OpenRouterProvider('key') }, + { name: 'Cerebras', create: () => new CerebrasProvider('key') } + ]; + + providers.forEach(({ name, create }) => { + describe(`${name}Provider`, () => { + it('should use OpenAI-compatible Chat Completions format', async () => { + const provider = create(); + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('model', [ + { role: 'system', content: 'You are helpful' }, + { role: 'user', content: 'Hello' } + ]); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + + // Should have model and messages + assert.isDefined(body.model); + assert.isArray(body.messages); + + // Messages should be in OpenAI format + assert.strictEqual(body.messages[0].role, 'system'); + assert.strictEqual(body.messages[0].content, 'You are helpful'); + assert.strictEqual(body.messages[1].role, 'user'); + + fetchStub.restore(); + }); + + it('should extract text from choice.message.content', async () => { + const provider = create(); + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Hello world!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const response = await provider.callWithMessages('model', [ + { role: 'user', content: 'Hi' } + ]); + + assert.strictEqual(response.text, 'Hello world!'); + fetchStub.restore(); + }); + + it('should extract function call from tool_calls', async () => { + const provider = create(); + const mockResponse = createMockOpenAICompatibleResponse({ + functionCall: { + name: 'search', + arguments: { query: 'cats' } + } + }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const response = await provider.callWithMessages('model', [ + { role: 'user', content: 'Search for cats' } + ], { + tools: [{ + type: 'function', + function: { name: 'search', description: 'Search', parameters: {} } + }] + }); + + assert.isDefined(response.functionCall); + assert.strictEqual(response.functionCall!.name, 'search'); + assert.deepEqual(response.functionCall!.arguments, { query: 'cats' }); + fetchStub.restore(); + }); + + it('should throw on empty choices', async () => { + const provider = create(); + const emptyResponse = { choices: [] }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(emptyResponse), { status: 200 }) + ); + + try { + await provider.callWithMessages('model', [{ role: 'user', content: 'Hi' }]); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'No choices'); + } + fetchStub.restore(); + }); + + it('should include tools with proper parameters format', async () => { + const provider = create(); + const mockResponse = createMockOpenAICompatibleResponse({ text: 'Response' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const tools = [{ + type: 'function', + function: { + name: 'test_tool', + description: 'A test tool' + // No parameters - should add default + } + }]; + + await provider.callWithMessages('model', [ + { role: 'user', content: 'Use tool' } + ], { tools: tools as any }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + + // Should have default parameters + assert.isDefined(body.tools[0].function.parameters); + assert.strictEqual(body.tools[0].function.parameters.type, 'object'); + fetchStub.restore(); + }); + + it('should return default models on API error', async () => { + const provider = create(); + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + const models = await provider.getModels(); + + assert.isArray(models); + assert.isTrue(models.length > 0); + fetchStub.restore(); + }); + }); + }); + }); +}); diff --git a/front_end/panels/ai_chat/LLM/__tests__/OpenAIProvider.test.ts b/front_end/panels/ai_chat/LLM/__tests__/OpenAIProvider.test.ts new file mode 100644 index 0000000000..eb8fd4186e --- /dev/null +++ b/front_end/panels/ai_chat/LLM/__tests__/OpenAIProvider.test.ts @@ -0,0 +1,716 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { OpenAIProvider } from '../OpenAIProvider.js'; +import type { LLMMessage, LLMCallOptions } from '../LLMTypes.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); +import { + createMockOpenAIResponse, + createMock401Response, + createMock429Response, + createMock500Response, + createMock503Response, + createLocalStorageMock, + createFastRetryConfig, + createTestMessages, + createMockToolDefinition, + STORAGE_KEYS, +} from './LLMTestHelpers.js'; + +describe('ai_chat: OpenAIProvider', () => { + let provider: OpenAIProvider; + let fetchStub: sinon.SinonStub; + let localStorageMock: ReturnType; + + const TEST_API_KEY = 'sk-test-api-key-12345'; + const API_ENDPOINT = 'https://api.openai.com/v1/responses'; + const MODELS_ENDPOINT = 'https://api.openai.com/v1/models'; + + beforeEach(() => { + provider = new OpenAIProvider(TEST_API_KEY); + localStorageMock = createLocalStorageMock({ + [STORAGE_KEYS.OPENAI_API_KEY]: TEST_API_KEY, + }); + }); + + afterEach(() => { + if (fetchStub) { + fetchStub.restore(); + } + localStorageMock.restore(); + sinon.restore(); + }); + + // ============ Constructor Tests ============ + describe('constructor', () => { + it('should set provider name correctly', () => { + assert.strictEqual(provider.name, 'openai'); + }); + + it('should store API key', () => { + // API key is private, but we can verify it works by checking requests + const newProvider = new OpenAIProvider('sk-different-key'); + assert.strictEqual(newProvider.name, 'openai'); + }); + }); + + // ============ callWithMessages Tests ============ + describe('callWithMessages', () => { + it('should make POST request to correct endpoint', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages()); + + assert.isTrue(fetchStub.calledOnce); + const [url, options] = fetchStub.firstCall.args; + assert.strictEqual(url, API_ENDPOINT); + assert.strictEqual(options.method, 'POST'); + }); + + it('should include Authorization header', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers.Authorization, `Bearer ${TEST_API_KEY}`); + }); + + it('should include Content-Type header', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + assert.strictEqual(options.headers['Content-Type'], 'application/json'); + }); + + it('should use "input" array in request body (Responses API format)', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello!' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages()); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isDefined(body.input, 'Should use "input" array for Responses API'); + assert.isArray(body.input); + }); + + it('should handle text response correctly', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'The answer is 42.' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('gpt-4.1', createTestMessages()); + + assert.strictEqual(result.text, 'The answer is 42.'); + }); + + it('should handle function call response correctly', async () => { + const mockResponse = createMockOpenAIResponse({ + functionCall: { name: 'click_element', arguments: { selector: '#btn' } }, + }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('gpt-4.1', createTestMessages()); + + assert.isDefined(result.functionCall); + assert.strictEqual(result.functionCall!.name, 'click_element'); + assert.deepEqual(result.functionCall!.arguments, { selector: '#btn' }); + }); + + it('should include tools in request when provided', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Done' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const tools = [createMockToolDefinition('click', 'Click an element')]; + await provider.callWithMessages('gpt-4.1', createTestMessages(), { tools }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isDefined(body.tools); + assert.isArray(body.tools); + }); + + it('should include temperature for GPT models', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages(), { temperature: 0.7 }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.strictEqual(body.temperature, 0.7); + }); + + it('should NOT include temperature for O-series models', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('o4-mini', createTestMessages(), { temperature: 0.7 }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isUndefined(body.temperature, 'O-series models should not have temperature'); + }); + + it('should NOT include temperature for GPT-5 models', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-5', createTestMessages(), { temperature: 0.7 }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isUndefined(body.temperature, 'GPT-5 models should not have temperature'); + }); + + it('should include reasoning.effort for O-series models when reasoningLevel provided', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('o4-mini', createTestMessages(), { reasoningLevel: 'high' }); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + assert.isDefined(body.reasoning); + assert.strictEqual(body.reasoning.effort, 'high'); + }); + + it('should convert messages to Responses API format', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'system', content: 'You are a helper.' }, + { role: 'user', content: 'Hi there' }, + ]; + + await provider.callWithMessages('gpt-4.1', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + assert.isArray(body.input); + assert.strictEqual(body.input[0].role, 'system'); + assert.strictEqual(body.input[1].role, 'user'); + }); + + it('should handle multimodal content with images', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'I see an image' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'What is this?' }, + { type: 'image_url', image_url: { url: 'data:image/png;base64,abc123' } }, + ], + }, + ]; + + await provider.callWithMessages('gpt-4.1', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + // For GPT models, content should be converted to Responses API format + const userMessage = body.input.find((m: any) => m.role === 'user'); + assert.isDefined(userMessage); + }); + + it('should convert tool calls to Responses API format (function_call type)', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Done' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const messages: LLMMessage[] = [ + { role: 'system', content: 'Helper' }, + { role: 'user', content: 'Click button' }, + { + role: 'assistant', + tool_calls: [ + { + id: 'call_123', + type: 'function', + function: { name: 'click', arguments: '{"selector":"#btn"}' }, + }, + ], + }, + { + role: 'tool', + tool_call_id: 'call_123', + content: 'Clicked successfully', + }, + ]; + + await provider.callWithMessages('gpt-4.1', messages); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + // Assistant tool call should be converted to function_call type + const functionCall = body.input.find((m: any) => m.type === 'function_call'); + assert.isDefined(functionCall, 'Should have function_call type message'); + assert.strictEqual(functionCall.name, 'click'); + + // Tool result should be converted to function_call_output type + const functionOutput = body.input.find((m: any) => m.type === 'function_call_output'); + assert.isDefined(functionOutput, 'Should have function_call_output type message'); + assert.strictEqual(functionOutput.call_id, 'call_123'); + }); + + it('should throw on API errors', async () => { + const errorResponse = { + error: { message: 'Invalid request', type: 'invalid_request_error' }, + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(errorResponse), { status: 400 }) + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'OpenAI API error'); + } + }); + + it('should store raw response', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('gpt-4.1', createTestMessages()); + + assert.isDefined(result.rawResponse); + assert.deepEqual(result.rawResponse, mockResponse); + }); + }); + + // ============ call Tests ============ + describe('call', () => { + it('should build messages from prompt and systemPrompt', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.call('gpt-4.1', 'User prompt', 'System prompt'); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + assert.isArray(body.input); + const systemMsg = body.input.find((m: any) => m.role === 'system'); + const userMsg = body.input.find((m: any) => m.role === 'user'); + + assert.isDefined(systemMsg); + assert.isDefined(userMsg); + }); + + it('should work without system prompt', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.call('gpt-4.1', 'User prompt', ''); + + const options = fetchStub.firstCall.args[1]; + const body = JSON.parse(options.body); + + const systemMsg = body.input.find((m: any) => m.role === 'system'); + assert.isUndefined(systemMsg, 'Should not include empty system prompt'); + }); + }); + + // ============ getModels / fetchModels Tests ============ + describe('getModels', () => { + it('should fetch models from API', async () => { + const modelsResponse = { + data: [ + { id: 'gpt-4.1-2025-04-14', object: 'model' }, + { id: 'o4-mini-2025-04-16', object: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.isArray(models); + assert.isTrue(models.length > 0); + }); + + it('should return ModelInfo array with correct structure', async () => { + const modelsResponse = { + data: [ + { id: 'gpt-4.1-2025-04-14', object: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + assert.isArray(models); + if (models.length > 0) { + const model = models[0]; + assert.isDefined(model.id); + assert.isDefined(model.name); + assert.strictEqual(model.provider, 'openai'); + assert.isDefined(model.capabilities); + } + }); + + it('should fallback to defaults on API error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + const models = await provider.getModels(); + + // Should return default models + assert.isArray(models); + assert.isTrue(models.length > 0); + }); + + it('should filter out non-chat models', async () => { + const modelsResponse = { + data: [ + { id: 'gpt-4.1-2025-04-14', object: 'model' }, + { id: 'text-embedding-3-large', object: 'model' }, + { id: 'whisper-1', object: 'model' }, + { id: 'dall-e-3', object: 'model' }, + ], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + const models = await provider.getModels(); + + const modelIds = models.map(m => m.id); + assert.notInclude(modelIds, 'text-embedding-3-large'); + assert.notInclude(modelIds, 'whisper-1'); + assert.notInclude(modelIds, 'dall-e-3'); + }); + }); + + describe('fetchModels', () => { + it('should make GET request to models endpoint', async () => { + const modelsResponse = { + data: [{ id: 'gpt-4.1', object: 'model' }], + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(modelsResponse), { status: 200 }) + ); + + await provider.fetchModels(); + + assert.isTrue(fetchStub.calledOnce); + const [url, options] = fetchStub.firstCall.args; + assert.strictEqual(url, MODELS_ENDPOINT); + assert.strictEqual(options.method, 'GET'); + }); + + it('should throw on API errors', async () => { + const errorResponse = { error: { message: 'Unauthorized' } }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(errorResponse), { status: 401 }) + ); + + try { + await provider.fetchModels(); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'OpenAI API error'); + } + }); + }); + + // ============ validateCredentials Tests ============ + describe('validateCredentials', () => { + it('should return valid when API key present', () => { + const result = provider.validateCredentials(); + assert.isTrue(result.isValid); + }); + + it('should return invalid with missingItems when no API key', () => { + localStorageMock.restore(); + localStorageMock = createLocalStorageMock({}); // No API key + + const newProvider = new OpenAIProvider(''); + // Need to check localStorage in validateCredentials + const result = newProvider.validateCredentials(); + + // The provider was initialized with an API key, but validateCredentials + // checks localStorage for the key + assert.isFalse(result.isValid); + assert.isDefined(result.missingItems); + assert.include(result.missingItems!, 'API Key'); + }); + }); + + // ============ getCredentialStorageKeys Tests ============ + describe('getCredentialStorageKeys', () => { + it('should return correct storage keys', () => { + const keys = provider.getCredentialStorageKeys(); + assert.strictEqual(keys.apiKey, 'ai_chat_api_key'); + }); + }); + + // ============ parseResponse Tests ============ + describe('parseResponse', () => { + it('should delegate to LLMResponseParser', () => { + const response = { text: 'Hello', rawResponse: {} }; + const parsed = provider.parseResponse(response); + + assert.isDefined(parsed); + assert.strictEqual(parsed.type, 'final_answer'); + }); + + it('should parse function call responses', () => { + const response = { + functionCall: { name: 'test', arguments: { arg: 1 } }, + rawResponse: {}, + }; + const parsed = provider.parseResponse(response); + + assert.strictEqual(parsed.type, 'tool_call'); + }); + }); + + // ============ Model Family Tests ============ + describe('Model Family Handling', () => { + it('should identify O-series models correctly', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + // o1, o3, o4 models should be identified as O-series + await provider.callWithMessages('o1-preview', createTestMessages(), { temperature: 0.5 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isUndefined(body.temperature, 'O1 should not have temperature'); + }); + + it('should identify GPT models correctly', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-4.1', createTestMessages(), { temperature: 0.5 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.strictEqual(body.temperature, 0.5, 'GPT should have temperature'); + }); + + it('should treat GPT-5 like O-series for parameters', async () => { + const mockResponse = createMockOpenAIResponse({ text: 'Hello' }); + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + await provider.callWithMessages('gpt-5-2025-08-07', createTestMessages(), { temperature: 0.5 }); + + const body = JSON.parse(fetchStub.firstCall.args[1].body); + assert.isUndefined(body.temperature, 'GPT-5 should not have temperature'); + }); + }); + + // ============ Error Scenarios ============ + describe('error scenarios', () => { + it('should handle 401 Unauthorized', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock401Response('openai') + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'api'); + } + }); + + it('should handle 429 Rate Limit', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock429Response('openai') + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'rate limit'); + } + }); + + it('should handle 500 Internal Server Error', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock500Response('openai') + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'server'); + } + }); + + it('should handle 503 Service Unavailable', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + createMock503Response('openai') + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message.toLowerCase(), 'unavailable'); + } + }); + + it('should handle network errors', async () => { + fetchStub = sinon.stub(globalThis, 'fetch').rejects(new Error('Network error')); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'Network error'); + } + }); + + it('should retry on transient errors with retryConfig', async () => { + let callCount = 0; + fetchStub = sinon.stub(globalThis, 'fetch').callsFake(async () => { + callCount++; + if (callCount < 3) { + return createMock500Response('openai'); + } + return new Response(JSON.stringify(createMockOpenAIResponse({ text: 'Success' })), { status: 200 }); + }); + + const result = await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: createFastRetryConfig(3), + }); + + assert.strictEqual(result.text, 'Success'); + assert.strictEqual(callCount, 3); + }); + }); + + // ============ Response Processing Tests ============ + describe('Response Processing', () => { + it('should extract reasoning info from O-series responses', async () => { + const mockResponse = { + output: [ + { type: 'message', content: [{ type: 'output_text', text: 'Result' }] }, + ], + reasoning: { + summary: ['Step 1', 'Step 2'], + effort: 'high', + }, + usage: { input_tokens: 100, output_tokens: 50 }, + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('o4-mini', createTestMessages()); + + assert.isDefined(result.reasoning); + assert.deepEqual(result.reasoning!.summary, ['Step 1', 'Step 2']); + assert.strictEqual(result.reasoning!.effort, 'high'); + }); + + it('should handle empty output array', async () => { + const mockResponse = { + output: [], + usage: { input_tokens: 100, output_tokens: 0 }, + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + try { + await provider.callWithMessages('gpt-4.1', createTestMessages(), { + retryConfig: { maxRetries: 0, baseDelayMs: 0, maxDelayMs: 0, backoffMultiplier: 1, jitterMs: 0 }, + }); + assert.fail('Should have thrown'); + } catch (error) { + assert.include((error as Error).message, 'No output'); + } + }); + + it('should find function_call in any position in output array', async () => { + const mockResponse = { + output: [ + { type: 'message', content: [{ type: 'output_text', text: 'Thinking...' }] }, + { type: 'function_call', name: 'click', arguments: '{"selector": "#btn"}', call_id: 'call_1' }, + ], + usage: { input_tokens: 100, output_tokens: 50 }, + }; + fetchStub = sinon.stub(globalThis, 'fetch').resolves( + new Response(JSON.stringify(mockResponse), { status: 200 }) + ); + + const result = await provider.callWithMessages('gpt-4.1', createTestMessages()); + + assert.isDefined(result.functionCall); + assert.strictEqual(result.functionCall!.name, 'click'); + }); + }); +}); diff --git a/front_end/panels/ai_chat/agent_framework/AgentRunner.ts b/front_end/panels/ai_chat/agent_framework/AgentRunner.ts index 71a8a541c1..c8805a4d34 100644 --- a/front_end/panels/ai_chat/agent_framework/AgentRunner.ts +++ b/front_end/panels/ai_chat/agent_framework/AgentRunner.ts @@ -16,6 +16,7 @@ import { AgentRunnerEventBus } from './AgentRunnerEventBus.js'; import { callLLMWithTracing } from '../tools/LLMTracingWrapper.js'; import { sanitizeMessagesForModel } from '../LLM/MessageSanitizer.js'; import { FileStorageManager } from '../tools/FileStorageManager.js'; +// Note: Guardrails are now handled by GuardrailMiddleware in AgentNodes.ts (single evaluation point) const logger = createLogger('AgentRunner'); @@ -430,12 +431,13 @@ export class AgentRunner { hooks: AgentRunnerHooks, executingAgent: ConfigurableAgentTool | null, parentSession?: AgentSession, // For natural nesting - overrides?: { sessionId?: string; parentSessionId?: string; traceId?: string }, + overrides?: { sessionId?: string; parentSessionId?: string; traceId?: string; background?: boolean }, abortSignal?: AbortSignal ): Promise { const agentName = executingAgent?.name || 'Unknown'; logger.info(`Starting execution loop for agent: ${agentName}`); const { apiKey, modelName, systemPrompt, tools, maxIterations, temperature, agentDescriptor } = config; + const isBackground = overrides?.background === true; const { prepareInitialMessages, createSuccessResult, createErrorResult, afterExecute } = hooks; @@ -463,8 +465,8 @@ export class AgentRunner { // Use local session variable instead of static let currentSession = agentSession; - // Emit session started event - if (AgentRunner.eventBus) { + // Emit session started event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_started', sessionId: agentSession.sessionId, @@ -485,20 +487,20 @@ export class AgentRunner { currentSession.messages.push(fullMessage); - // Emit progress events based on message type - if (AgentRunner.eventBus && fullMessage.type === 'tool_call') { + // Emit progress events based on message type (skip for background agents) + if (AgentRunner.eventBus && !isBackground && fullMessage.type === 'tool_call') { AgentRunner.eventBus.emitProgress({ type: 'tool_started', sessionId: currentSession.sessionId, parentSessionId: currentSession.parentSessionId, agentName: currentSession.agentName, timestamp: new Date(), - data: { + data: { session: currentSession, toolCall: fullMessage } }); - } else if (AgentRunner.eventBus && fullMessage.type === 'tool_result') { + } else if (AgentRunner.eventBus && !isBackground && fullMessage.type === 'tool_result') { AgentRunner.eventBus.emitProgress({ type: 'tool_completed', sessionId: currentSession.sessionId, @@ -591,8 +593,8 @@ export class AgentRunner { currentSession.endTime = new Date(); currentSession.terminationReason = 'error'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: currentSession.sessionId, @@ -836,8 +838,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'error'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, @@ -1017,8 +1019,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'handed_off'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, @@ -1104,8 +1106,8 @@ export class AgentRunner { } }); - // Emit child agent starting - if (AgentRunner.eventBus) { + // Emit child agent starting (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'child_agent_started', sessionId: currentSession.sessionId, @@ -1124,6 +1126,9 @@ export class AgentRunner { try { logger.info(`${agentName} Executing tool: ${toolToExecute.name}`); const execTracingContext = getCurrentTracingContext(); + + // Note: Guardrails are now handled at the AgentNodes layer (single evaluation point) + // Execute tool directly toolResultData = await toolToExecute.execute(toolArgs as any, ({ apiKey: config.apiKey, provider: config.provider, @@ -1326,8 +1331,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'final_answer'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, @@ -1388,8 +1393,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'error'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, @@ -1461,8 +1466,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'handed_off'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, @@ -1485,8 +1490,8 @@ export class AgentRunner { agentSession.endTime = new Date(); agentSession.terminationReason = 'max_iterations'; - // Emit session completed event - if (AgentRunner.eventBus) { + // Emit session completed event (skip for background agents) + if (AgentRunner.eventBus && !isBackground) { AgentRunner.eventBus.emitProgress({ type: 'session_completed', sessionId: agentSession.sessionId, diff --git a/front_end/panels/ai_chat/agent_framework/AgentRunnerEventBus.ts b/front_end/panels/ai_chat/agent_framework/AgentRunnerEventBus.ts index 3094835e71..a68593953b 100644 --- a/front_end/panels/ai_chat/agent_framework/AgentRunnerEventBus.ts +++ b/front_end/panels/ai_chat/agent_framework/AgentRunnerEventBus.ts @@ -5,7 +5,7 @@ import * as Common from '../../../core/common/common.js'; export interface AgentRunnerProgressEvent { - type: 'session_started' | 'tool_started' | 'tool_completed' | 'session_updated' | 'child_agent_started' | 'session_completed'; + type: 'session_started' | 'tool_started' | 'tool_completed' | 'session_updated' | 'child_agent_started' | 'session_completed' | 'approval_requested'; sessionId: string; parentSessionId?: string; agentName: string; diff --git a/front_end/panels/ai_chat/agent_framework/AgentSessionTypes.ts b/front_end/panels/ai_chat/agent_framework/AgentSessionTypes.ts index 41aae1eba0..1c1586ba0d 100644 --- a/front_end/panels/ai_chat/agent_framework/AgentSessionTypes.ts +++ b/front_end/panels/ai_chat/agent_framework/AgentSessionTypes.ts @@ -4,6 +4,7 @@ import type { AgentDescriptor } from '../core/AgentDescriptorRegistry.js'; import type { AgentToolConfig } from './ConfigurableAgentTool.js'; +import type { RiskLevel } from '../models/ChatTypes.js'; /** * Agent session represents a complete execution context for an agent @@ -46,10 +47,10 @@ export interface AgentSession { export interface AgentMessage { id: string; timestamp: Date; - type: 'reasoning' | 'tool_call' | 'tool_result' | 'handoff' | 'final_answer'; - + type: 'reasoning' | 'tool_call' | 'tool_result' | 'handoff' | 'final_answer' | 'approval_request'; + // Message Content (varies by type) - content: ReasoningMessage | ToolCallMessage | ToolResultMessage | HandoffMessage | FinalAnswerMessage; + content: ReasoningMessage | ToolCallMessage | ToolResultMessage | HandoffMessage | FinalAnswerMessage | ApprovalRequestContent; } /** @@ -105,6 +106,22 @@ export interface FinalAnswerMessage { summary?: string; } +/** + * Approval request for human-in-the-loop + */ +export interface ApprovalRequestContent { + type: 'approval_request'; + approvalId: string; + toolName: string; + toolArgs: Record; + riskLevel: RiskLevel; + description: string; + reasoning?: string; + policyMatched?: string; + status: 'pending' | 'approved' | 'rejected'; + feedback?: string; +} + /** * Default UI configuration for agents */ diff --git a/front_end/panels/ai_chat/agent_framework/ConfigurableAgentTool.ts b/front_end/panels/ai_chat/agent_framework/ConfigurableAgentTool.ts index a908677c3a..2a48355b9d 100644 --- a/front_end/panels/ai_chat/agent_framework/ConfigurableAgentTool.ts +++ b/front_end/panels/ai_chat/agent_framework/ConfigurableAgentTool.ts @@ -30,6 +30,8 @@ export interface CallCtx { overrideTraceId?: string, abortSignal?: AbortSignal, agentDescriptor?: AgentDescriptor, + /** If true, don't emit UI progress events (for background agents) */ + background?: boolean, } /** @@ -595,6 +597,7 @@ export class ConfigurableAgentTool implements Tool new BookmarkStoreTool()); ToolRegistry.registerToolFactory('document_search', () => new DocumentSearchTool()); + + // Register memory tools + ToolRegistry.registerToolFactory('search_memory', () => new SearchMemoryTool()); + ToolRegistry.registerToolFactory('update_memory', () => new UpdateMemoryTool()); + ToolRegistry.registerToolFactory('list_memory_blocks', () => new ListMemoryBlocksTool()); // Create and register Direct URL Navigator Agent const directURLNavigatorAgentConfig = createDirectURLNavigatorAgentConfig(); @@ -131,4 +137,14 @@ export function initializeConfiguredAgents(): void { const ecommerceProductInfoAgent = new ConfigurableAgentTool(ecommerceProductInfoAgentConfig); ToolRegistry.registerToolFactory('ecommerce_product_info_fetcher_tool', () => ecommerceProductInfoAgent); + // Create and register Memory Agent (background memory consolidation) + const memoryAgentConfig = createMemoryAgentConfig('extraction'); + const memoryAgent = new ConfigurableAgentTool(memoryAgentConfig); + ToolRegistry.registerToolFactory('memory_agent', () => memoryAgent); + + // Create and register Search Memory Agent (read-only memory search for orchestrators) + const searchMemoryAgentConfig = createMemoryAgentConfig('search'); + const searchMemoryAgent = new ConfigurableAgentTool(searchMemoryAgentConfig); + ToolRegistry.registerToolFactory('search_memory_agent', () => searchMemoryAgent); + } diff --git a/front_end/panels/ai_chat/agent_framework/implementation/agents/MemoryAgent.ts b/front_end/panels/ai_chat/agent_framework/implementation/agents/MemoryAgent.ts new file mode 100644 index 0000000000..acd65e5167 --- /dev/null +++ b/front_end/panels/ai_chat/agent_framework/implementation/agents/MemoryAgent.ts @@ -0,0 +1,176 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { AgentToolConfig } from '../../ConfigurableAgentTool.js'; +import { ChatMessageEntity } from '../../../models/ChatTypes.js'; +import type { ChatMessage } from '../../../models/ChatTypes.js'; +import type { ConfigurableAgentArgs } from '../../ConfigurableAgentTool.js'; +import { MODEL_SENTINELS } from '../../../core/Constants.js'; +import { AGENT_VERSION } from './AgentVersion.js'; + +const MEMORY_AGENT_PROMPT = `You are a Memory Consolidation Agent that runs in the background after conversations end. + +## Your Purpose +Extract and organize important information from completed conversations into persistent memory blocks that will help the assistant in future conversations. + +## Memory Block Types + +| Block | Purpose | Max Size | +|-------|---------|----------| +| user | User identity, preferences, communication style | 20000 chars | +| facts | Factual information learned from conversations | 20000 chars | +| project_ | Project-specific context (up to 4 projects) | 20000 chars each | + +## Workflow + +1. **List current memory** using list_memory_blocks +2. **Analyze** the conversation for extractable information +3. **Check for duplicates** before adding new facts +4. **Update blocks** with consolidated, organized content +5. **Verify** changes are correct and within limits + +## What to Extract + +### High Priority (Always Extract) +- User's name, role, job title +- Explicit preferences ("I prefer...", "I like...", "Always use...") +- Project names, tech stacks, goals +- Recurring patterns in requests + +### Medium Priority (Extract if Relevant) +- Problem-solving approaches that worked +- Tools/libraries the user uses frequently +- Team members or collaborators mentioned + +### Skip (Do Not Extract) +- One-time troubleshooting details +- Temporary debugging information +- Generic conversation pleasantries +- Information already in memory + +## Writing Guidelines + +### Be Specific with Dates +❌ "Recently discussed migration" +✅ "2025-01-15: Discussed database migration to PostgreSQL" + +### Be Concise +❌ "The user mentioned that they have a strong preference for using TypeScript in their projects because they find it helps catch errors" +✅ "Prefers TypeScript for type safety" + +### Use Bullet Points +\`\`\` +- Name: Alex Chen +- Role: Senior Frontend Engineer +- Prefers: TypeScript, React, Tailwind CSS +- Dislikes: Inline styles, any types +\`\`\` + +### Consolidate Related Info +If user block has: +\`\`\` +- Likes dark mode +- Uses VS Code +- Prefers dark themes +\`\`\` + +Consolidate to: +\`\`\` +- Prefers dark mode/themes +- Uses VS Code +\`\`\` + +## Examples + +### Example 1: User Preferences +**Conversation excerpt:** +> User: "Hey, I'm Sarah. Can you help me debug this React component? I always use functional components with hooks, never class components." + +**Memory update (user block):** +\`\`\` +- Name: Sarah +- React: Functional components + hooks only, no class components +\`\`\` + +### Example 2: Project Context +**Conversation excerpt:** +> User: "Working on our e-commerce platform. We're using Next.js 14 with App Router, Prisma for the database, and Stripe for payments." + +**Memory update (project_ecommerce block):** +\`\`\` +Project: E-commerce Platform +Stack: Next.js 14 (App Router), Prisma, Stripe +\`\`\` + +### Example 3: Skip Extraction +**Conversation excerpt:** +> User: "Getting a 404 error on /api/users endpoint" +> Assistant: "The route file is missing, create app/api/users/route.ts" +> User: "Fixed, thanks!" + +**Action:** No extraction needed - one-time debugging, no lasting value. + +## Output +After processing, briefly state what was updated or why nothing was updated. +`; + +/** + * Create the configuration for the Memory Agent + */ +export function createMemoryAgentConfig(): AgentToolConfig { + return { + name: 'memory_agent', + version: AGENT_VERSION, + description: 'Background memory consolidation agent that extracts facts from conversations and maintains organized memory blocks.', + + ui: { + displayName: 'Memory Agent', + avatar: '🧠', + color: '#8b5cf6', + backgroundColor: '#f5f3ff' + }, + + systemPrompt: MEMORY_AGENT_PROMPT, + + tools: [ + 'search_memory', + 'update_memory', + 'list_memory_blocks', + ], + + schema: { + type: 'object', + properties: { + conversation_summary: { + type: 'string', + description: 'Summary of the conversation to analyze for memory extraction' + }, + reasoning: { + type: 'string', + description: 'Why this extraction is being run' + } + }, + required: ['conversation_summary', 'reasoning'] + }, + + prepareMessages: (args: ConfigurableAgentArgs): ChatMessage[] => { + return [{ + entity: ChatMessageEntity.USER, + text: `## Conversation to Analyze + +${args.conversation_summary || ''} + +## Reason for Extraction +${args.reasoning || 'Automatic extraction after session completion'} + +Please analyze this conversation and update memory blocks as appropriate.`, + }]; + }, + + maxIterations: 5, + modelName: MODEL_SENTINELS.USE_MINI, // Cost-effective for background task + temperature: 0.1, + handoffs: [], + }; +} diff --git a/front_end/panels/ai_chat/agent_framework/implementation/agents/SearchMemoryAgent.ts b/front_end/panels/ai_chat/agent_framework/implementation/agents/SearchMemoryAgent.ts new file mode 100644 index 0000000000..a4ba71a0f0 --- /dev/null +++ b/front_end/panels/ai_chat/agent_framework/implementation/agents/SearchMemoryAgent.ts @@ -0,0 +1,111 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { AgentToolConfig, ConfigurableAgentArgs } from '../../ConfigurableAgentTool.js'; +import { ChatMessageEntity } from '../../../models/ChatTypes.js'; +import type { ChatMessage } from '../../../models/ChatTypes.js'; +import { MODEL_SENTINELS } from '../../../core/Constants.js'; +import { AGENT_VERSION } from './AgentVersion.js'; + +const SEARCH_MEMORY_AGENT_PROMPT = `You are a Memory Retrieval Agent. Your job is to find and summarize relevant information from stored memory to help the assistant respond to the user. + +## Memory Structure + +| Block | Contains | +|-------|----------| +| user | User identity, preferences, communication style | +| facts | Factual information from past conversations | +| project_* | Project-specific context (tech stack, goals, current work) | + +## Workflow + +1. Use list_memory_blocks to retrieve all stored memory +2. Scan each block for information relevant to the query +3. Return a concise summary of relevant findings + +## Response Format + +### When Memory Exists +Return relevant information organized by category: + +\`\`\` +**User Context:** +- Name: Sarah, Senior Frontend Engineer +- Prefers TypeScript, functional React components + +**Relevant Project:** +- E-commerce Platform: Next.js 14, Prisma, Stripe + +**Related Facts:** +- 2025-01-10: Migrated auth to NextAuth.js +\`\`\` + +### When No Memory Exists +Simply respond: "No relevant memory found." + +### When Memory is Empty +Simply respond: "No memory stored yet." + +## Guidelines + +- Only include information relevant to the query +- Don't dump entire blocks - summarize what's useful +- Prioritize recent information over old +- If query is vague, return user preferences + active project context +`; + +/** + * Create the configuration for the Search Memory Agent. + * This agent provides read-only memory search capability to orchestrator agents. + */ +export function createSearchMemoryAgentConfig(): AgentToolConfig { + return { + name: 'search_memory_agent', + version: AGENT_VERSION, + description: 'Search user memory for relevant information. Use when you need to recall user preferences, past facts, or project context.', + + ui: { + displayName: 'Search Memory', + avatar: '🔍', + color: '#10b981', + backgroundColor: '#ecfdf5' + }, + + systemPrompt: SEARCH_MEMORY_AGENT_PROMPT, + + tools: [ + 'list_memory_blocks', // Returns all memory block contents directly + ], + + schema: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'What to search for in memory (user preferences, facts, project info)' + }, + context: { + type: 'string', + description: 'Why this search is needed (helps with relevance)' + } + }, + required: ['query'] + }, + + prepareMessages: (args: ConfigurableAgentArgs): ChatMessage[] => { + return [{ + entity: ChatMessageEntity.USER, + text: `Search memory for: ${args.query || ''} +${args.context ? `\nContext: ${args.context}` : ''} + +Please search memory and return any relevant information.`, + }]; + }, + + maxIterations: 2, // Just need to list and respond + modelName: MODEL_SENTINELS.USE_NANO, // Fast, cheap model for simple searches + temperature: 0, + handoffs: [], + }; +} diff --git a/front_end/panels/ai_chat/core/AgentNodes.ts b/front_end/panels/ai_chat/core/AgentNodes.ts index aba0b7af4a..dd935e363f 100644 --- a/front_end/panels/ai_chat/core/AgentNodes.ts +++ b/front_end/panels/ai_chat/core/AgentNodes.ts @@ -3,8 +3,9 @@ // found in the LICENSE file. import type { getTools } from '../tools/Tools.js'; -import { ChatMessageEntity, type ModelChatMessage, type ToolResultMessage, type ChatMessage, type AgentSessionMessage } from '../models/ChatTypes.js'; +import { ChatMessageEntity, type ModelChatMessage, type ToolResultMessage, type ChatMessage, type AgentSessionMessage, type ApprovalRequestMessage } from '../models/ChatTypes.js'; import { ConfigurableAgentTool, ToolRegistry } from '../agent_framework/ConfigurableAgentTool.js'; +import { getGuardrailMiddleware, type ExecutionContext, type ApprovalRequest } from '../guardrails/index.js'; import { LLMClient } from '../LLM/LLMClient.js'; import type { LLMMessage } from '../LLM/LLMTypes.js'; @@ -713,6 +714,110 @@ export function createToolExecutorNode(state: AgentState, provider: LLMProvider, return newState; } + // ===== GUARDRAIL EVALUATION ===== + // Use the unified GuardrailMiddleware for evaluation and approval + try { + const currentUrl = state.currentPageUrl || ''; + let currentDomain = ''; + try { + if (currentUrl) { + currentDomain = new URL(currentUrl).hostname; + } + } catch { + // Ignore URL parsing errors + } + + const context: ExecutionContext = { + currentUrl, + currentDomain, + userGoal: state.messages.find(m => m.entity === ChatMessageEntity.USER) + ? (state.messages.find(m => m.entity === ChatMessageEntity.USER) as any)?.text + : undefined, + }; + + // Helper to create approval message from middleware request + const createApprovalMessage = (request: ApprovalRequest): ApprovalRequestMessage => ({ + entity: ChatMessageEntity.APPROVAL_REQUEST, + approvalId: request.id, + toolName, + toolArgs, + description: request.decision.suggestedMessage, + status: request.status, + riskLevel: request.decision.riskLevel, + reasoning: request.decision.reasoning, + policyMatched: request.decision.policyMatched, + toolCallId, + uiLane: 'chat', + }); + + // Use middleware gate - handles evaluation, approval request, and waiting + const gateResult = await getGuardrailMiddleware().gate( + { name: toolName, args: toolArgs, callId: toolCallId || '' }, + context, + (request) => { + // Callback when approval is needed - add message to state for UI + messages.push(createApprovalMessage(request)); + logger.info('Tool requires approval, waiting for user decision', { + approvalId: request.id, + toolName, + riskLevel: request.decision.riskLevel, + }); + }, + selectedTool.approvalConfig + ); + + // If not approved, create rejection message and return + if (!gateResult.proceed) { + // Update approval message status if it exists + const lastApprovalMsg = messages.findLast( + m => m.entity === ChatMessageEntity.APPROVAL_REQUEST + ) as ApprovalRequestMessage | undefined; + if (lastApprovalMsg) { + lastApprovalMsg.status = 'rejected'; + if (gateResult.feedback) { + lastApprovalMsg.feedback = gateResult.feedback; + } + } + + // Create error result with feedback so agent can adapt + const rejectionText = gateResult.feedback + ? `REJECTED by user: ${gateResult.feedback}. Please try a different approach.` + : 'REJECTED by user. Please try a different approach or ask for clarification.'; + + const toolResultMessage: ToolResultMessage = { + entity: ChatMessageEntity.TOOL_RESULT, + toolName, + resultText: rejectionText, + isError: true, + toolCallId, + error: rejectionText, + uiLane: 'chat', + }; + messages.push(toolResultMessage); + + logger.info('Tool execution rejected', { toolName, feedback: gateResult.feedback }); + return { ...state, messages: [...messages] }; + } + + // If we had an approval message, update its status to approved + const approvedMsg = messages.findLast( + m => m.entity === ChatMessageEntity.APPROVAL_REQUEST + ) as ApprovalRequestMessage | undefined; + if (approvedMsg && approvedMsg.status === 'pending') { + approvedMsg.status = 'approved'; + } + + logger.debug('Guardrail gate passed', { toolName, decision: gateResult.decision?.decision }); + } catch (guardrailError) { + // Guardrail evaluation failed - log and continue with execution + // We don't want guardrail errors to block tool execution + logger.warn('Guardrail evaluation failed, continuing with execution', { + toolName, + error: guardrailError, + }); + } + // ===== END GUARDRAIL EVALUATION ===== + // Create span for tool execution const tracingContext = state.context?.tracingContext; let spanId: string | undefined; @@ -1067,3 +1172,4 @@ export function createFinalNode(): Runnable { }(); return finalNode; } + diff --git a/front_end/panels/ai_chat/core/AgentService.ts b/front_end/panels/ai_chat/core/AgentService.ts index 0106071fbe..5edf570ade 100644 --- a/front_end/panels/ai_chat/core/AgentService.ts +++ b/front_end/panels/ai_chat/core/AgentService.ts @@ -7,7 +7,7 @@ import * as Common from '../../../core/common/common.js'; import * as i18n from '../../../core/i18n/i18n.js'; import * as SDK from '../../../core/sdk/sdk.js'; import * as UI from '../../../ui/legacy/legacy.js'; -import { type ChatMessage, ChatMessageEntity, type ImageInputData, type ModelChatMessage } from '../models/ChatTypes.js'; +import { type ChatMessage, ChatMessageEntity, type ImageInputData, type ModelChatMessage, type ApprovalRequestMessage } from '../models/ChatTypes.js'; import {createAgentGraph} from './Graph.js'; import { createLogger } from './Logger.js'; @@ -29,6 +29,9 @@ import { BUILD_CONFIG } from './BuildConfig.js'; import { VisualIndicatorManager } from '../tools/VisualIndicatorTool.js'; import { ConversationManager } from '../persistence/ConversationManager.js'; import type { ConversationMetadata } from '../persistence/ConversationTypes.js'; +import { ToolRegistry } from '../agent_framework/ConfigurableAgentTool.js'; +import { MemoryModule } from '../memory/index.js'; +import { getGuardrailMiddleware, GuardrailEvents, type ApprovalResolvedEvent } from '../guardrails/index.js'; // Cache break: 2025-09-17T17:54:00Z - Force rebuild with AUTOMATED_MODE bypass const logger = createLogger('AgentService'); @@ -46,6 +49,7 @@ export enum Events { CHILD_AGENT_STARTED = 'child-agent-started', CONVERSATION_CHANGED = 'conversation-changed', CONVERSATION_SAVED = 'conversation-saved', + APPROVAL_REQUESTED = 'approval-requested', } /** @@ -61,6 +65,7 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ [Events.CHILD_AGENT_STARTED]: { parentSession: AgentSession, childAgentName: string, childSessionId: string }, [Events.CONVERSATION_CHANGED]: string | null, [Events.CONVERSATION_SAVED]: string, + [Events.APPROVAL_REQUESTED]: ApprovalRequestMessage, }> { static instance: AgentService; @@ -160,11 +165,23 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ // Subscribe to AgentRunner events AgentRunnerEventBus.getInstance().addEventListener('agent-progress', this.#handleAgentProgress.bind(this)); + // Subscribe to GuardrailMiddleware events for UI updates after user action + getGuardrailMiddleware().addEventListener( + GuardrailEvents.APPROVAL_RESOLVED, + this.#handleApprovalResolved.bind(this) + ); + // Initialize visual indicator system with reference to AgentService VisualIndicatorManager.getInstance().initialize(this); // Subscribe to configuration changes this.#configManager.addChangeListener(this.#handleConfigurationChange.bind(this)); + + // Process any old conversations that missed memory extraction + // Delay to avoid blocking startup and ensure tools are registered + setTimeout(() => { + this.processUnprocessedConversations(); + }, 5000); } /** @@ -889,6 +906,9 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ * Starts a new conversation */ async newConversation(): Promise { + // Capture conversation ID BEFORE clearing (for async memory extraction) + const endingConversationId = this.#currentConversationId; + // Abort any running execution this.cancelRun(); @@ -914,6 +934,126 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ this.dispatchEventToListeners(Events.CONVERSATION_CHANGED, null); logger.info('Started new conversation'); + + // Fire off memory extraction in background (non-blocking) + if (endingConversationId) { + this.#processConversationMemory(endingConversationId); + } + } + + /** + * Processes memory for a conversation. Uses claim mechanism to prevent + * concurrent processing of the same conversation. + */ + async #processConversationMemory(conversationId: string): Promise { + logger.info('[Memory] Starting processing for conversation', {conversationId}); + // Check if memory is enabled in settings + if (!MemoryModule.getInstance().isEnabled()) { + logger.info('[Memory] Skipping - memory disabled in settings'); + return; + } + + // Try to claim - if another instance is processing, skip + const claimed = await this.#conversationManager.tryClaimForMemoryProcessing(conversationId); + if (!claimed) { + logger.info('[Memory] Skipping - already processing or completed', {conversationId}); + return; + } + + try { + // Load the conversation to get messages + const loaded = await this.#conversationManager.loadConversation(conversationId); + if (!loaded || loaded.state.messages.length < 4) { + // Mark as completed (nothing to extract) + await this.#conversationManager.markMemoryCompleted(conversationId); + logger.info('[Memory] Skipping - conversation too short', {conversationId, messageCount: loaded?.state.messages.length || 0}); + return; + } + + // Format conversation summary + const conversationSummary = loaded.state.messages + .filter(m => m.entity === ChatMessageEntity.USER || m.entity === ChatMessageEntity.MODEL) + .slice(-20) + .map(m => { + const role = m.entity === ChatMessageEntity.USER ? 'User' : 'Assistant'; + const text = m.entity === ChatMessageEntity.USER + ? (m as {text: string}).text + : ((m as ModelChatMessage).answer || ''); + return `${role}: ${text}`; + }) + .join('\n'); + + const memoryAgent = ToolRegistry.getToolInstance('memory_agent'); + if (!memoryAgent) { + await this.#conversationManager.markMemoryFailed(conversationId); + logger.warn('[Memory] memory_agent not found in registry'); + return; + } + + const config = this.#configManager.getConfiguration(); + logger.info('[Memory] Processing conversation', { + conversationId, + provider: config.provider, + model: config.mainModel, + miniModel: config.miniModel, + summaryLength: conversationSummary.length + }); + + const result = await memoryAgent.execute({ + conversation_summary: conversationSummary, + reasoning: 'Extracting facts from conversation', + }, { + apiKey: config.apiKey, + provider: config.provider, + model: config.mainModel, + miniModel: config.miniModel, + nanoModel: config.nanoModel, + background: true, // Don't show in UI + }); + + logger.info('[Memory] Agent execution result', { + conversationId, + success: result.success, + outputLength: result.output?.length || 0, + outputPreview: result.output?.substring(0, 500), + error: result.error, + terminationReason: result.terminationReason, + toolCallsCount: result.toolCalls?.length || 0, + toolCalls: result.toolCalls?.map((tc: any) => ({ name: tc.name, args: tc.args })) || [], + }); + + await this.#conversationManager.markMemoryCompleted(conversationId); + logger.info('[Memory] Completed', {conversationId}); + + } catch (err) { + logger.error('[Memory] Failed:', err); + await this.#conversationManager.markMemoryFailed(conversationId); + } + } + + /** + * Processes any old conversations that never had memory extracted. + * Call this on initialization or periodically. + */ + async processUnprocessedConversations(): Promise { + const pending = await this.#conversationManager.getConversationsNeedingMemoryProcessing(); + + // Skip the currently active conversation and limit to avoid overload + const toProcess = pending + .filter(conv => conv.id !== this.#currentConversationId) + .slice(0, 3); + + for (const conv of toProcess) { + // Don't await - process in parallel + this.#processConversationMemory(conv.id); + } + + if (pending.length > 0) { + logger.info('[Memory] Processing unprocessed conversations', { + total: pending.length, + processing: toProcess.length, + }); + } } /** @@ -1113,6 +1253,26 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ } } break; + case 'approval_requested': + // NOTE: Approval requests are now added directly to AgentSession.messages + // in AgentRunner for correct timeline ordering. This case is kept for + // backwards compatibility and logging only. + { + const approvalData = progressEvent.data as { + approvalId: string; + toolName: string; + toolArgs: Record; + guardrailDecision: import('../guardrails/types.js').GuardrailDecision; + }; + + logger.info('[AgentService] Approval requested (handled in session timeline):', { + approvalId: approvalData.approvalId, + toolName: approvalData.toolName, + riskLevel: approvalData.guardrailDecision.riskLevel, + }); + } + break; + case 'session_completed': // Get the completed session from the event data or active sessions const completedSession = progressEvent.data?.session || @@ -1152,6 +1312,22 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ } } + /** + * Handle approval resolution events from ApprovalManager. + * NOTE: The actual UI update now happens via session_updated events from AgentRunner, + * since approval messages are stored in AgentSession.messages for correct timeline ordering. + * This handler is kept for logging and potential future use. + */ + #handleApprovalResolved(event: Common.EventTarget.EventTargetEvent): void { + const { approvalId, result } = event.data; + + logger.info('[AgentService] Approval resolved (UI handled via session_updated):', { + approvalId, + approved: result.approved, + feedback: result.feedback, + }); + } + // Upsert helper: ensures the chat transcript reflects the latest AgentSession state in real-time #upsertAgentSessionInMessages(session: AgentSession): void { // If this is a child session, update the parent container too @@ -1216,6 +1392,7 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ }, 5000); } } + } // Define UI strings object to manage i18n strings diff --git a/front_end/panels/ai_chat/core/BaseOrchestratorAgent.ts b/front_end/panels/ai_chat/core/BaseOrchestratorAgent.ts index 008e7aeeef..0c0ccb3847 100644 --- a/front_end/panels/ai_chat/core/BaseOrchestratorAgent.ts +++ b/front_end/panels/ai_chat/core/BaseOrchestratorAgent.ts @@ -35,7 +35,7 @@ import { ListFilesTool, type Tool } from '../tools/Tools.js'; -// Imports from their own files +import { MemoryModule } from '../memory/index.js'; // Initialize configured agents initializeConfiguredAgents(); @@ -50,7 +50,7 @@ export enum BaseOrchestratorAgentType { SHOPPING = 'shopping' } -// System prompts for each agent type +// System prompts for each agent type (WITHOUT memory instructions - added dynamically) export const SYSTEM_PROMPTS = { [BaseOrchestratorAgentType.SEARCH]: `You are an search browser agent specialized in pinpoint web fact-finding. Always delegate investigative work to the 'search_agent' tool so it can gather verified, structured results (emails, team rosters, niche professionals, etc.). @@ -323,6 +323,7 @@ export const AGENT_CONFIGS: {[key: string]: AgentConfig} = { new DeleteFileTool(), new ReadFileTool(), new ListFilesTool(), + ToolRegistry.getToolInstance('search_memory_agent') || (() => { throw new Error('search_memory_agent tool not found'); })(), ] }, [BaseOrchestratorAgentType.DEEP_RESEARCH]: { @@ -347,6 +348,7 @@ export const AGENT_CONFIGS: {[key: string]: AgentConfig} = { new DeleteFileTool(), new ReadFileTool(), new ListFilesTool(), + ToolRegistry.getToolInstance('search_memory_agent') || (() => { throw new Error('search_memory_agent tool not found'); })(), ] }, // [BaseOrchestratorAgentType.SHOPPING]: { @@ -389,14 +391,17 @@ AgentDescriptorRegistry.registerSource({ /** * Get the system prompt for a specific agent type + * Memory instructions are dynamically prepended if memory is enabled */ export function getSystemPrompt(agentType: string): string { + const memoryPrefix = MemoryModule.getInstance().getInstructions(); + // Check if there's a custom prompt for this agent type if (hasCustomPrompt(agentType)) { - return getAgentPrompt(agentType); + return memoryPrefix + getAgentPrompt(agentType); } - - return AGENT_CONFIGS[agentType]?.systemPrompt || + + return memoryPrefix + (AGENT_CONFIGS[agentType]?.systemPrompt || // Default system prompt if agent type not found `You are a browser agent for helping users with tasks. And, you are an expert task orchestrator agent focused on high-level task strategy, planning, efficient delegation to specialized web agents, and final result synthesis. Your core goal is to provide maximally helpful task completion by orchestrating an effective execution process. @@ -517,14 +522,18 @@ After specialized agents complete their tasks: 2. Identify patterns, best options, and key insights 3. Note any remaining gaps or follow-up needs 4. Create a comprehensive response following the appropriate format -`; +`); } /** * Get available tools for a specific agent type + * Conditionally includes search_memory_agent if memory is enabled */ export function getAgentTools(agentType: string): Array> { - return AGENT_CONFIGS[agentType]?.availableTools || [ + const memoryModule = MemoryModule.getInstance(); + + // Get base tools from config or use default list + const baseTools = AGENT_CONFIGS[agentType]?.availableTools || [ ToolRegistry.getToolInstance('search_agent') || (() => { throw new Error('search_agent tool not found'); })(), ToolRegistry.getToolInstance('web_task_agent') || (() => { throw new Error('web_task_agent tool not found'); })(), ToolRegistry.getToolInstance('document_search') || (() => { throw new Error('document_search tool not found'); })(), @@ -541,6 +550,22 @@ export function getAgentTools(agentType: string): Array> { new ReadFileTool(), new ListFilesTool(), ]; + + // Filter out search_memory_agent if memory is disabled, or add it if enabled and not present + if (memoryModule.shouldIncludeMemoryTool()) { + // Check if search_memory_agent is already in the list + const hasMemoryAgent = baseTools.some(tool => tool.name === 'search_memory_agent'); + if (!hasMemoryAgent) { + const memoryAgent = ToolRegistry.getToolInstance('search_memory_agent'); + if (memoryAgent) { + return [...baseTools, memoryAgent]; + } + } + return baseTools; + } + + // Memory disabled - filter out search_memory_agent + return baseTools.filter(tool => tool.name !== 'search_memory_agent'); } // Custom event for agent type selection diff --git a/front_end/panels/ai_chat/core/CustomProviderManager.ts b/front_end/panels/ai_chat/core/CustomProviderManager.ts index d922e7b73b..4f101f3e97 100644 --- a/front_end/panels/ai_chat/core/CustomProviderManager.ts +++ b/front_end/panels/ai_chat/core/CustomProviderManager.ts @@ -14,6 +14,7 @@ export interface CustomProviderConfig { name: string; // Display name (e.g., "Z.AI") baseURL: string; // Base URL (e.g., "https://api.z.ai/api/coding/paas/v4") models: string[]; // Available models + modelsManuallyAdded: boolean; // True if user manually configured models, false if fetched from API enabled: boolean; // Whether the provider is enabled createdAt: number; // Timestamp when created updatedAt: number; // Timestamp when last updated @@ -62,8 +63,10 @@ export class CustomProviderManager { } } - if (!config.models || config.models.length === 0) { - errors.push('At least one model is required'); + // Models are optional - they can be fetched from the API if not manually specified + // Only validate if modelsManuallyAdded is true and models are empty + if (config.modelsManuallyAdded && (!config.models || config.models.length === 0)) { + errors.push('At least one model is required when manually adding models'); } return { @@ -145,10 +148,19 @@ export class CustomProviderManager { /** * Add a new custom provider + * @param config Provider configuration (modelsManuallyAdded defaults to true if models are provided) */ - static addProvider(config: Omit): CustomProviderConfig { + static addProvider(config: Omit & { modelsManuallyAdded?: boolean }): CustomProviderConfig { + // Determine if models were manually added (default to true if models are provided) + const modelsManuallyAdded = config.modelsManuallyAdded ?? (config.models && config.models.length > 0); + + const fullConfig = { + ...config, + modelsManuallyAdded, + }; + // Validate config - const validation = CustomProviderManager.validateConfig(config); + const validation = CustomProviderManager.validateConfig(fullConfig); if (!validation.valid) { throw new Error(`Invalid provider configuration: ${validation.errors.join(', ')}`); } @@ -163,7 +175,7 @@ export class CustomProviderManager { const now = Date.now(); const newProvider: CustomProviderConfig = { - ...config, + ...fullConfig, id, createdAt: now, updatedAt: now, diff --git a/front_end/panels/ai_chat/core/GraphConfigs.ts b/front_end/panels/ai_chat/core/GraphConfigs.ts index effcbb4d74..328063310b 100644 --- a/front_end/panels/ai_chat/core/GraphConfigs.ts +++ b/front_end/panels/ai_chat/core/GraphConfigs.ts @@ -8,10 +8,11 @@ import { NodeType } from './Types.js'; /** * Defines the default agent graph configuration. + * Flow: AGENT → TOOL_EXECUTOR → AGENT → ... → FINAL → __end__ + * Memory is accessed on-demand via search_memory_agent tool. */ export const defaultAgentGraphConfig: GraphConfig = { name: 'defaultAgentGraph', - // Revert to using NodeType enum members entryPoint: NodeType.AGENT.toString(), nodes: [ { name: NodeType.AGENT.toString(), type: 'agent' }, diff --git a/front_end/panels/ai_chat/core/GraphHelpers.ts b/front_end/panels/ai_chat/core/GraphHelpers.ts index b6e02aa420..12e0518430 100644 --- a/front_end/panels/ai_chat/core/GraphHelpers.ts +++ b/front_end/panels/ai_chat/core/GraphHelpers.ts @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import type { getTools } from '../tools/Tools.js'; +import type { getTools, Tool } from '../tools/Tools.js'; import { ChatMessageEntity, type ChatMessage } from '../models/ChatTypes.js'; import * as BaseOrchestratorAgent from './BaseOrchestratorAgent.js'; @@ -11,6 +11,7 @@ import { LLMConfigurationManager } from './LLMConfigurationManager.js'; import { enhancePromptWithPageContext } from './PageInfoManager.js'; import type { AgentState } from './State.js'; import { NodeType } from './Types.js'; +import { compileGuardrailPrompt, hasGuardrailContent, getGuardrailMiddleware } from '../guardrails/index.js'; const logger = createLogger('GraphHelpers'); @@ -29,7 +30,7 @@ export function createSystemPrompt(state: AgentState): string { } // Add a new async version that will be used instead -export async function createSystemPromptAsync(state: AgentState): Promise { +export async function createSystemPromptAsync(state: AgentState, tools?: Tool[]): Promise { const { selectedAgentType } = state; // Check if there's a system prompt override from conversation state @@ -43,7 +44,27 @@ export async function createSystemPromptAsync(state: AgentState): Promise = { + openai: { + main: 'gpt-4.1-2025-04-14', + mini: 'gpt-4.1-mini-2025-04-14', + nano: 'gpt-4.1-nano-2025-04-14' + }, + litellm: { + main: '', // Will use first available model + mini: '', + nano: '' + }, + groq: { + main: 'openai/gpt-oss-120b', + mini: 'openai/gpt-oss-120b', + nano: 'openai/gpt-oss-120b' + }, + openrouter: { + main: 'anthropic/claude-sonnet-4.5', + mini: 'google/gemini-2.5-flash', + nano: 'openai/gpt-oss-120b:exacto' + }, + browseroperator: { + main: 'main', + mini: 'mini', + nano: 'nano' + }, + cerebras: { + main: 'gpt-oss-120b', + mini: 'gpt-oss-120b', + nano: 'llama3.1-8b' + }, + anthropic: { + main: 'claude-sonnet-4-5', + mini: 'claude-haiku-4-5', + nano: 'claude-haiku-4-5' + }, + googleai: { + main: 'gemini-2.5-pro', + mini: 'gemini-2.5-flash', + nano: 'gemini-2.5-flash' + } +}; + +/** + * Default OpenAI models (static list for providers without fetch capability) + */ +export const DEFAULT_OPENAI_MODELS: ModelOption[] = [ + {value: 'o4-mini-2025-04-16', label: 'O4 Mini', type: 'openai'}, + {value: 'o3-mini-2025-01-31', label: 'O3 Mini', type: 'openai'}, + {value: 'gpt-5-2025-08-07', label: 'GPT-5', type: 'openai'}, + {value: 'gpt-5-mini-2025-08-07', label: 'GPT-5 Mini', type: 'openai'}, + {value: 'gpt-5-nano-2025-08-07', label: 'GPT-5 Nano', type: 'openai'}, + {value: 'gpt-4.1-2025-04-14', label: 'GPT-4.1', type: 'openai'}, + {value: 'gpt-4.1-mini-2025-04-14', label: 'GPT-4.1 Mini', type: 'openai'}, + {value: 'gpt-4.1-nano-2025-04-14', label: 'GPT-4.1 Nano', type: 'openai'}, +]; + +/** + * Placeholder constants for model options + */ +export const MODEL_PLACEHOLDERS = { + ADD_CUSTOM: 'add_custom_model', + NO_MODELS: 'no_models_available', +}; + /** * Configuration interface for LLM settings */ @@ -38,6 +116,10 @@ const STORAGE_KEYS = { GROQ_API_KEY: 'ai_chat_groq_api_key', OPENROUTER_API_KEY: 'ai_chat_openrouter_api_key', BROWSEROPERATOR_API_KEY: 'ai_chat_browseroperator_api_key', + // Model options storage + ALL_MODEL_OPTIONS: 'ai_chat_all_model_options', + MODEL_OPTIONS: 'ai_chat_model_options', // Legacy, for backward compatibility + CUSTOM_MODELS: 'ai_chat_custom_models', // For LiteLLM custom models } as const; /** @@ -49,9 +131,15 @@ export class LLMConfigurationManager { private overrideConfig?: Partial; // Override for automated mode private changeListeners: Array<() => void> = []; + // Model options state - organized by provider + private modelOptionsByProvider: Map = new Map(); + private modelOptionsInitialized = false; + private constructor() { // Listen for localStorage changes from other tabs (manual mode) window.addEventListener('storage', this.handleStorageChange.bind(this)); + // Initialize model options from localStorage + this.loadModelOptionsFromStorage(); } /** @@ -77,6 +165,7 @@ export class LLMConfigurationManager { /** * Get the main model with override fallback + * Note: For default fallback, ensure models are fetched and selected in the UI */ getMainModel(): string { if (this.overrideConfig?.mainModel) { @@ -154,6 +243,303 @@ export class LLMConfigurationManager { }; } + // ============================================================================ + // Model Options Management + // ============================================================================ + + /** + * Load model options from localStorage into memory + */ + private loadModelOptionsFromStorage(): void { + try { + // Load from comprehensive storage + const allOptionsJson = localStorage.getItem(STORAGE_KEYS.ALL_MODEL_OPTIONS); + if (allOptionsJson) { + const allOptions: ModelOption[] = JSON.parse(allOptionsJson); + // Group by provider + this.modelOptionsByProvider.clear(); + for (const option of allOptions) { + const providerOptions = this.modelOptionsByProvider.get(option.type) || []; + providerOptions.push(option); + this.modelOptionsByProvider.set(option.type, providerOptions); + } + logger.debug('Loaded model options from storage', { + providers: Array.from(this.modelOptionsByProvider.keys()), + totalModels: allOptions.length + }); + } else { + // Initialize with defaults + this.modelOptionsByProvider.set('openai', [...DEFAULT_OPENAI_MODELS]); + logger.debug('Initialized with default OpenAI models'); + } + this.modelOptionsInitialized = true; + } catch (error) { + logger.error('Failed to load model options from storage:', error); + // Initialize with defaults on error + this.modelOptionsByProvider.set('openai', [...DEFAULT_OPENAI_MODELS]); + this.modelOptionsInitialized = true; + } + } + + /** + * Get all model options across all providers + */ + getAllModelOptions(): ModelOption[] { + const allOptions: ModelOption[] = []; + for (const options of this.modelOptionsByProvider.values()) { + allOptions.push(...options); + } + return allOptions; + } + + /** + * Get model options for a specific provider + * @param provider Provider ID (e.g., 'openai', 'groq'). If not provided, uses current provider. + */ + getModelOptions(provider?: string): ModelOption[] { + const targetProvider = provider || this.getProvider(); + return this.modelOptionsByProvider.get(targetProvider) || []; + } + + /** + * Get model options for the currently selected provider + */ + getModelOptionsForCurrentProvider(): ModelOption[] { + return this.getModelOptions(this.getProvider()); + } + + /** + * Set model options for a provider + * @param provider Provider ID + * @param models Array of model options + */ + setModelOptions(provider: string, models: ModelOption[]): void { + logger.info(`Setting ${models.length} models for provider ${provider}`); + this.modelOptionsByProvider.set(provider, models); + this.persistModelOptionsToStorage(); + this.notifyListeners(); + } + + /** + * Clear model options for a provider, or all providers if not specified + * @param provider Optional provider ID to clear + */ + clearModelOptions(provider?: string): void { + if (provider) { + this.modelOptionsByProvider.delete(provider); + logger.debug(`Cleared model options for provider ${provider}`); + } else { + this.modelOptionsByProvider.clear(); + logger.debug('Cleared all model options'); + } + this.persistModelOptionsToStorage(); + this.notifyListeners(); + } + + /** + * Add a custom model option (primarily for LiteLLM) + * @param modelName The model name to add + * @param provider The provider type (defaults to current provider) + */ + addCustomModelOption(modelName: string, provider?: string): void { + const targetProvider = provider || this.getProvider(); + const providerOptions = this.modelOptionsByProvider.get(targetProvider) || []; + + // Check if model already exists + if (providerOptions.some(m => m.value === modelName)) { + logger.debug(`Model ${modelName} already exists for provider ${targetProvider}`); + return; + } + + // Create label - just use the model name (consumers can format as needed) + const newOption: ModelOption = { + value: modelName, + label: modelName, + type: targetProvider + }; + + providerOptions.push(newOption); + this.modelOptionsByProvider.set(targetProvider, providerOptions); + + // Also save to custom models list for LiteLLM + if (targetProvider === 'litellm') { + this.saveCustomModelToStorage(modelName); + } + + this.persistModelOptionsToStorage(); + this.notifyListeners(); + + logger.info(`Added custom model ${modelName} for provider ${targetProvider}`); + } + + /** + * Remove a custom model option + * @param modelName The model name to remove + * @param provider The provider type (defaults to current provider) + */ + removeCustomModelOption(modelName: string, provider?: string): void { + const targetProvider = provider || this.getProvider(); + const providerOptions = this.modelOptionsByProvider.get(targetProvider) || []; + + const filteredOptions = providerOptions.filter(m => m.value !== modelName); + if (filteredOptions.length === providerOptions.length) { + logger.debug(`Model ${modelName} not found for provider ${targetProvider}`); + return; + } + + this.modelOptionsByProvider.set(targetProvider, filteredOptions); + + // Also remove from custom models list for LiteLLM + if (targetProvider === 'litellm') { + this.removeCustomModelFromStorage(modelName); + } + + this.persistModelOptionsToStorage(); + this.notifyListeners(); + + logger.info(`Removed custom model ${modelName} from provider ${targetProvider}`); + } + + /** + * Validate a model selection against available options + * @param model The model value to validate + * @param provider Optional provider to validate against (defaults to current) + */ + validateModelSelection(model: string, provider?: string): boolean { + if (!model) return false; + const options = this.getModelOptions(provider); + return options.some(opt => opt.value === model); + } + + /** + * Validate and fix model selections for the current provider + * Returns the corrected selections + */ + validateAndFixModelSelections(): { main: string; mini: string; nano: string } { + const provider = this.getProvider(); + const available = this.getModelOptionsForCurrentProvider(); + const defaults = DEFAULT_PROVIDER_MODELS[provider] || {}; + + const availableValues = available.filter(m => m.type === provider).map(m => m.value); + + const validateModel = (stored: string, defaultValue: string | undefined): string => { + // 1. Check exact match for stored model + if (stored && available.some(m => m.value === stored && m.type === provider)) { + return stored; + } + + // 2. Try fuzzy match for stored model + if (stored) { + const fuzzyMatch = findClosestModel(stored, availableValues); + if (fuzzyMatch) { + logger.info(`Fuzzy matched model '${stored}' to '${fuzzyMatch}'`); + return fuzzyMatch; + } + } + + // 3. Check exact match for provider default + if (defaultValue && available.some(m => m.value === defaultValue)) { + return defaultValue; + } + + // 4. Try fuzzy match for provider default + if (defaultValue) { + const fuzzyDefault = findClosestModel(defaultValue, availableValues); + if (fuzzyDefault) { + logger.info(`Fuzzy matched default '${defaultValue}' to '${fuzzyDefault}'`); + return fuzzyDefault; + } + } + + // 5. Fall back to first available + return available.length > 0 ? available[0].value : ''; + }; + + const currentMain = localStorage.getItem(STORAGE_KEYS.MODEL_SELECTION) || ''; + const currentMini = localStorage.getItem(STORAGE_KEYS.MINI_MODEL) || ''; + const currentNano = localStorage.getItem(STORAGE_KEYS.NANO_MODEL) || ''; + + const main = validateModel(currentMain, defaults.main); + const mini = validateModel(currentMini, defaults.mini); + const nano = validateModel(currentNano, defaults.nano); + + // Persist corrections if needed + if (main !== currentMain) { + localStorage.setItem(STORAGE_KEYS.MODEL_SELECTION, main); + logger.info(`Corrected main model from '${currentMain}' to '${main}'`); + } + if (mini !== currentMini) { + if (mini) { + localStorage.setItem(STORAGE_KEYS.MINI_MODEL, mini); + } else { + localStorage.removeItem(STORAGE_KEYS.MINI_MODEL); + } + logger.info(`Corrected mini model from '${currentMini}' to '${mini}'`); + } + if (nano !== currentNano) { + if (nano) { + localStorage.setItem(STORAGE_KEYS.NANO_MODEL, nano); + } else { + localStorage.removeItem(STORAGE_KEYS.NANO_MODEL); + } + logger.info(`Corrected nano model from '${currentNano}' to '${nano}'`); + } + + return { main, mini, nano }; + } + + /** + * Persist model options to localStorage + */ + private persistModelOptionsToStorage(): void { + try { + const allOptions = this.getAllModelOptions(); + localStorage.setItem(STORAGE_KEYS.ALL_MODEL_OPTIONS, JSON.stringify(allOptions)); + + // Also update legacy storage for backward compatibility + const currentProviderOptions = this.getModelOptionsForCurrentProvider(); + localStorage.setItem(STORAGE_KEYS.MODEL_OPTIONS, JSON.stringify(currentProviderOptions)); + } catch (error) { + logger.error('Failed to persist model options to storage:', error); + } + } + + /** + * Save a custom model name to the LiteLLM custom models list + */ + private saveCustomModelToStorage(modelName: string): void { + try { + const customModelsJson = localStorage.getItem(STORAGE_KEYS.CUSTOM_MODELS); + const customModels: string[] = customModelsJson ? JSON.parse(customModelsJson) : []; + if (!customModels.includes(modelName)) { + customModels.push(modelName); + localStorage.setItem(STORAGE_KEYS.CUSTOM_MODELS, JSON.stringify(customModels)); + } + } catch (error) { + logger.error('Failed to save custom model to storage:', error); + } + } + + /** + * Remove a custom model name from the LiteLLM custom models list + */ + private removeCustomModelFromStorage(modelName: string): void { + try { + const customModelsJson = localStorage.getItem(STORAGE_KEYS.CUSTOM_MODELS); + if (customModelsJson) { + const customModels: string[] = JSON.parse(customModelsJson); + const filtered = customModels.filter(m => m !== modelName); + localStorage.setItem(STORAGE_KEYS.CUSTOM_MODELS, JSON.stringify(filtered)); + } + } catch (error) { + logger.error('Failed to remove custom model from storage:', error); + } + } + + // ============================================================================ + // Override Configuration + // ============================================================================ + /** * Set override configuration (for automated mode per-request overrides) */ diff --git a/front_end/panels/ai_chat/core/PageInfoManager.ts b/front_end/panels/ai_chat/core/PageInfoManager.ts index 7d1d58be14..96c1c52ef3 100644 --- a/front_end/panels/ai_chat/core/PageInfoManager.ts +++ b/front_end/panels/ai_chat/core/PageInfoManager.ts @@ -6,6 +6,7 @@ import * as SDK from '../../../core/sdk/sdk.js'; import * as Utils from '../common/utils.js'; // Path relative to core/ assuming utils.ts will be in common/ later, this will be common/utils.js import { VisitHistoryManager } from '../tools/VisitHistoryManager.js'; // Path relative to core/ assuming VisitHistoryManager.ts will be in core/ import { FileStorageManager } from '../tools/FileStorageManager.js'; +import { MemoryBlockManager } from '../memory/index.js'; import { createLogger } from './Logger.js'; const logger = createLogger('PageInfoManager'); @@ -199,6 +200,15 @@ export async function enhancePromptWithPageContext(basePrompt: string): Promise< logger.warn('Failed to fetch files for context:', error); } + // Get memory context (global across sessions) + let memoryContext = ''; + try { + const memoryManager = new MemoryBlockManager(); + memoryContext = await memoryManager.compileMemoryContext(); + } catch (error) { + logger.warn('Failed to fetch memory context:', error); + } + // If no page info is available, return the original prompt if (!pageInfo) { return basePrompt; @@ -213,6 +223,7 @@ export async function enhancePromptWithPageContext(basePrompt: string): Promise< ${new Date().toLocaleDateString()} + ${memoryContext} ${pageInfo.title} diff --git a/front_end/panels/ai_chat/guardrails/GuardrailMiddleware.ts b/front_end/panels/ai_chat/guardrails/GuardrailMiddleware.ts new file mode 100644 index 0000000000..9f67003178 --- /dev/null +++ b/front_end/panels/ai_chat/guardrails/GuardrailMiddleware.ts @@ -0,0 +1,415 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * GuardrailMiddleware - Single entry point for guardrail evaluation and approval. + * Merges evaluation and approval management into one clean interface. + */ + +import * as Common from '../../../core/common/common.js'; +import { createLogger } from '../core/Logger.js'; +import { PolicyEvaluator } from './PolicyEvaluator.js'; +import type { + ToolCall, + ExecutionContext, + GuardrailConfig, + GuardrailDecision, + ApprovalRequest, + ApprovalResult, + GateResult, + ToolApprovalConfig, +} from './types.js'; +import { DEFAULT_GUARDRAIL_CONFIG } from './types.js'; + +const logger = createLogger('GuardrailMiddleware'); + +// ============================================================================ +// Events +// ============================================================================ + +export enum GuardrailEvents { + APPROVAL_REQUESTED = 'approval-requested', + APPROVAL_RESOLVED = 'approval-resolved', + APPROVAL_TIMEOUT = 'approval-timeout', +} + +export interface ApprovalRequestedEvent { + request: ApprovalRequest; +} + +export interface ApprovalResolvedEvent { + approvalId: string; + result: ApprovalResult; +} + +// ============================================================================ +// Pending Approval State +// ============================================================================ + +interface PendingApproval { + resolve: (result: ApprovalResult) => void; + reject: (error: Error) => void; + request: ApprovalRequest; + timeoutId?: ReturnType; +} + +// ============================================================================ +// GuardrailMiddleware +// ============================================================================ + +/** + * GuardrailMiddleware - Single entry point for all guardrail operations. + * Combines evaluation and approval gate into one async flow. + */ +export class GuardrailMiddleware extends Common.ObjectWrapper.ObjectWrapper<{ + [GuardrailEvents.APPROVAL_REQUESTED]: ApprovalRequestedEvent, + [GuardrailEvents.APPROVAL_RESOLVED]: ApprovalResolvedEvent, + [GuardrailEvents.APPROVAL_TIMEOUT]: { approvalId: string }, +}> { + private static instance: GuardrailMiddleware | null = null; + + private config: GuardrailConfig; + private evaluator: PolicyEvaluator; + private pendingApprovals = new Map(); + + private constructor(config: Partial = {}) { + super(); + this.config = { ...DEFAULT_GUARDRAIL_CONFIG, ...config }; + this.evaluator = new PolicyEvaluator(this.config); + } + + /** + * Get singleton instance + */ + static getInstance(): GuardrailMiddleware { + if (!GuardrailMiddleware.instance) { + GuardrailMiddleware.instance = new GuardrailMiddleware(); + } + return GuardrailMiddleware.instance; + } + + /** + * Reset the singleton (for testing) + */ + static resetInstance(): void { + if (GuardrailMiddleware.instance) { + GuardrailMiddleware.instance.cancelAllPending(); + } + GuardrailMiddleware.instance = null; + } + + /** + * Update configuration + */ + updateConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + this.evaluator.updateConfig(this.config); + logger.info('GuardrailMiddleware config updated', { enabled: this.config.enabled }); + } + + /** + * Get current configuration + */ + getConfig(): GuardrailConfig { + return { ...this.config }; + } + + /** + * Get the policy evaluator (for prompt compilation) + */ + getEvaluator(): PolicyEvaluator { + return this.evaluator; + } + + // ========================================================================== + // Main Gate API + // ========================================================================== + + /** + * Main entry point - evaluate tool call and gate execution if needed. + * + * This is the single integration point that replaces both: + * - AgentRunner guardrail check + * - AgentNodes guardrail check + * + * @param toolCall - The tool call to evaluate + * @param context - Execution context (current URL, domain, user goal) + * @param onApprovalNeeded - Callback when approval is needed (for UI updates) + * @param toolApprovalConfig - Tool-level approval configuration (optional) + * @returns GateResult with proceed flag and optional feedback + */ + async gate( + toolCall: ToolCall, + context: ExecutionContext, + onApprovalNeeded?: (request: ApprovalRequest) => void, + toolApprovalConfig?: ToolApprovalConfig + ): Promise { + // Step 1: Evaluate tool call against policies + const decision = await this.evaluator.evaluate(toolCall, context, toolApprovalConfig); + + // Step 2: If safe, proceed immediately + if (!decision.requiresApproval) { + return { proceed: true, decision }; + } + + // Step 3: Create approval request + const request = this.createApprovalRequest(toolCall, decision); + + // Step 4: Notify UI (if callback provided) + if (onApprovalNeeded) { + onApprovalNeeded(request); + } + + // Step 5: Dispatch event for other listeners + this.dispatchEventToListeners(GuardrailEvents.APPROVAL_REQUESTED, { request }); + + // Step 6: Wait for user response + const result = await this.waitForApproval(request.id); + + // Step 7: Return result with feedback for agent + return { + proceed: result.approved, + feedback: result.feedback, + decision, + }; + } + + /** + * Simplified gate that evaluates and returns decision without blocking for approval. + * Useful for checking if a tool would require approval without actually waiting. + */ + async evaluate( + toolCall: ToolCall, + context: ExecutionContext, + toolApprovalConfig?: ToolApprovalConfig + ): Promise { + return this.evaluator.evaluate(toolCall, context, toolApprovalConfig); + } + + // ========================================================================== + // Approval Management + // ========================================================================== + + /** + * Wait for approval with timeout + */ + private waitForApproval(approvalId: string): Promise { + return new Promise((resolve, reject) => { + const pending = this.pendingApprovals.get(approvalId); + if (!pending) { + // Approval was already resolved or doesn't exist + reject(new Error(`Approval ${approvalId} not found`)); + return; + } + + // Update resolve/reject functions + pending.resolve = resolve; + pending.reject = reject; + + // Set up timeout + const timeoutId = setTimeout(() => { + this.handleTimeout(approvalId); + }, this.config.approvalTimeoutMs); + + pending.timeoutId = timeoutId; + }); + } + + /** + * Create an approval request + */ + private createApprovalRequest(toolCall: ToolCall, decision: GuardrailDecision): ApprovalRequest { + const id = this.generateApprovalId(); + const request: ApprovalRequest = { + id, + toolCall, + decision, + status: 'pending', + timestamp: Date.now(), + }; + + // Store pending approval (without resolve/reject yet - set in waitForApproval) + this.pendingApprovals.set(id, { + resolve: () => {}, + reject: () => {}, + request, + }); + + return request; + } + + /** + * Resolve an approval with user's decision. + * Called by UI when user clicks approve/reject. + */ + resolveApproval(approvalId: string, approved: boolean, feedback?: string): void { + const pending = this.pendingApprovals.get(approvalId); + if (!pending) { + logger.warn('Approval not found', { approvalId }); + return; + } + + // Clear timeout + if (pending.timeoutId) { + clearTimeout(pending.timeoutId); + } + + // Calculate response time + const responseTimeMs = Date.now() - pending.request.timestamp; + + const result: ApprovalResult = { + approved, + feedback, + responseTimeMs, + }; + + logger.info('Approval resolved', { approvalId, approved, responseTimeMs }); + + // Update request status + pending.request.status = approved ? 'approved' : 'rejected'; + pending.request.feedback = feedback; + + // Remove from pending + this.pendingApprovals.delete(approvalId); + + // Resolve the promise + pending.resolve(result); + + // Dispatch event + this.dispatchEventToListeners(GuardrailEvents.APPROVAL_RESOLVED, { + approvalId, + result, + }); + } + + /** + * Approve an action + */ + approve(approvalId: string): void { + this.resolveApproval(approvalId, true); + } + + /** + * Reject an action with optional feedback + */ + reject(approvalId: string, feedback?: string): void { + this.resolveApproval(approvalId, false, feedback); + } + + /** + * Cancel a pending approval + */ + cancelApproval(approvalId: string, reason?: string): void { + const pending = this.pendingApprovals.get(approvalId); + if (!pending) { + return; + } + + if (pending.timeoutId) { + clearTimeout(pending.timeoutId); + } + + this.pendingApprovals.delete(approvalId); + pending.reject(new Error(reason || 'Approval cancelled')); + + logger.info('Approval cancelled', { approvalId, reason }); + } + + /** + * Cancel all pending approvals + */ + cancelAllPending(): void { + for (const approvalId of this.pendingApprovals.keys()) { + this.cancelApproval(approvalId, 'All approvals cancelled'); + } + } + + /** + * Handle timeout + */ + private handleTimeout(approvalId: string): void { + const pending = this.pendingApprovals.get(approvalId); + if (!pending) { + return; + } + + logger.warn('Approval timed out', { approvalId, toolName: pending.request.toolCall.name }); + + // Update request status + pending.request.status = 'rejected'; + + // Remove from pending + this.pendingApprovals.delete(approvalId); + + // Reject with timeout error + pending.reject(new Error('Approval request timed out')); + + // Dispatch timeout event + this.dispatchEventToListeners(GuardrailEvents.APPROVAL_TIMEOUT, { approvalId }); + } + + // ========================================================================== + // Utility Methods + // ========================================================================== + + /** + * Check if an approval is pending + */ + hasPendingApproval(approvalId: string): boolean { + return this.pendingApprovals.has(approvalId); + } + + /** + * Get pending approval details + */ + getPendingApproval(approvalId: string): ApprovalRequest | undefined { + return this.pendingApprovals.get(approvalId)?.request; + } + + /** + * Get all pending approval IDs + */ + getPendingApprovalIds(): string[] { + return Array.from(this.pendingApprovals.keys()); + } + + /** + * Get count of pending approvals + */ + getPendingCount(): number { + return this.pendingApprovals.size; + } + + /** + * Generate a unique approval ID + */ + private generateApprovalId(): string { + return `approval-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`; + } + + /** + * Generate approval ID (static method for external use) + */ + static generateApprovalId(): string { + return `approval-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`; + } +} + +// ============================================================================ +// Convenience Exports +// ============================================================================ + +/** + * Get the singleton GuardrailMiddleware instance + */ +export function getGuardrailMiddleware(): GuardrailMiddleware { + return GuardrailMiddleware.getInstance(); +} + +/** + * Reset the singleton (for testing) + */ +export function resetGuardrailMiddleware(): void { + GuardrailMiddleware.resetInstance(); +} diff --git a/front_end/panels/ai_chat/guardrails/PolicyEvaluator.ts b/front_end/panels/ai_chat/guardrails/PolicyEvaluator.ts new file mode 100644 index 0000000000..66cce0fdfb --- /dev/null +++ b/front_end/panels/ai_chat/guardrails/PolicyEvaluator.ts @@ -0,0 +1,379 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * PolicyEvaluator - Consolidated policy evaluation engine. + * Merges rule-based and LLM-based evaluation from GuardrailEvaluator. + */ + +import { createLogger } from '../core/Logger.js'; +import { LLMClient } from '../LLM/LLMClient.js'; +import type { + ToolCall, + ExecutionContext, + GuardrailDecision, + GuardrailConfig, + ToolApprovalConfig, + RiskLevel, + Policy, +} from './types.js'; +import { DEFAULT_GUARDRAIL_CONFIG, RISK_LEVEL_ORDER } from './types.js'; +import { + POLICIES, + getPoliciesForTool, + evaluateNavigation, + evaluateDataEntry, + evaluateClick, + evaluateScript, +} from './policies.js'; + +const logger = createLogger('PolicyEvaluator'); + +/** + * PolicyEvaluator - Evaluates tool calls against policies + */ +export class PolicyEvaluator { + private config: GuardrailConfig; + private llmClient: LLMClient; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_GUARDRAIL_CONFIG, ...config }; + this.llmClient = LLMClient.getInstance(); + } + + /** + * Update configuration + */ + updateConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + logger.info('PolicyEvaluator config updated', { enabled: this.config.enabled }); + } + + /** + * Get current configuration + */ + getConfig(): GuardrailConfig { + return { ...this.config }; + } + + /** + * Get active policies + */ + getActivePolicies(): Policy[] { + return POLICIES; + } + + /** + * Main evaluation entry point + */ + async evaluate( + toolCall: ToolCall, + context: ExecutionContext, + toolApprovalConfig?: ToolApprovalConfig + ): Promise { + // Check if guardrails are disabled + if (!this.config.enabled) { + return this.createSafeDecision('Guardrails disabled'); + } + + // Check tool-level approval config first + if (toolApprovalConfig?.requiresApproval) { + const riskLevel = toolApprovalConfig.riskLevel || 'medium'; + const message = toolApprovalConfig.approvalMessage || + `The tool "${toolCall.name}" requires human approval by default.`; + return this.createViolationDecision( + riskLevel, + 'Tool requires approval by default', + message, + 'tool_approval_config' + ); + } + + // Check explicit allow list + if (this.config.alwaysApprove.includes(toolCall.name)) { + return this.createSafeDecision('Tool in allow list'); + } + + // Check explicit require list + if (this.config.alwaysRequire.includes(toolCall.name)) { + return this.createViolationDecision( + 'high', + 'Tool requires approval', + `The tool "${toolCall.name}" is configured to always require human approval.`, + 'always_require_list' + ); + } + + // Run rule-based evaluation (fast path) + const ruleDecision = this.evaluateByRules(toolCall, context); + + // If definitive decision or LLM disabled, apply threshold and return + if (ruleDecision.isDefinitive || !this.config.llm.enabled) { + return this.applyThreshold(ruleDecision); + } + + // For escalate cases, use LLM evaluation if enabled and configured + if (this.config.llm.enabled && this.config.llm.apiKey && this.config.llm.provider) { + try { + const llmDecision = await this.evaluateWithLLM(toolCall, context); + return this.applyThreshold(llmDecision); + } catch (error) { + logger.warn('LLM guardrail evaluation failed, falling back to rule-based', { error }); + return this.applyThreshold(ruleDecision); + } + } + + return this.applyThreshold(ruleDecision); + } + + /** + * Rule-based evaluation (fast path) + */ + private evaluateByRules(toolCall: ToolCall, context: ExecutionContext): GuardrailDecision { + const { name, args } = toolCall; + + // Navigation checks + if (name === 'navigate_url' || name === 'navigate') { + const url = (args.url as string) || ''; + return evaluateNavigation(url, context); + } + + // Type/input checks + if (name === 'type' || name === 'fill_form') { + return evaluateDataEntry(args, context); + } + + // Click checks + if (name === 'click') { + return evaluateClick(args); + } + + // JavaScript execution + if (name === 'execute_javascript' || name === 'evaluate') { + return evaluateScript(args); + } + + // Default: safe + return this.createSafeDecision('Standard action - no specific risks identified'); + } + + /** + * LLM-based evaluation for nuanced cases + */ + private async evaluateWithLLM( + toolCall: ToolCall, + context: ExecutionContext + ): Promise { + const applicablePolicies = getPoliciesForTool(toolCall.name); + const policyPrompt = this.buildPolicyPrompt(applicablePolicies); + + const prompt = `You are a safety guardrail evaluator for a browser automation agent. Use the provided policies to classify this action. + +## POLICIES +${policyPrompt} + +## ACTION TO EVALUATE +Tool: ${toolCall.name} +Arguments: ${JSON.stringify(toolCall.args, null, 2)} +Current URL: ${context.currentUrl || 'unknown'} +User's stated goal: ${context.userGoal || 'unknown'} + +## INSTRUCTIONS +1. Identify which policy applies to this action +2. Reason through the policy criteria step by step +3. Classify as: safe (auto-approve), violation (require approval), or escalate (human judgment needed) + +## OUTPUT FORMAT (JSON only, no markdown) +{ + "reasoning": "Step-by-step analysis of policy criteria...", + "policyMatched": "policy_name", + "decision": "safe|violation|escalate", + "riskLevel": "none|low|medium|high|critical", + "suggestedMessage": "Human-readable explanation for the user" +}`; + + // Reasoning effort controls response detail level + // (maxTokens could be used with providers that support it) + const _reasoningEffort = this.config.llm.reasoningEffort; + + const response = await this.llmClient.call({ + provider: this.config.llm.provider!, + model: this.config.llm.model || 'gpt-4o-mini', + messages: [{ role: 'user', content: prompt }], + systemPrompt: 'You are a safety classifier. Respond only with valid JSON.', + temperature: 0.1, + }); + + // Extract JSON from response + const content = response.text || ''; + const jsonMatch = content.match(/\{[\s\S]*\}/); + if (!jsonMatch) { + throw new Error('Could not parse LLM response as JSON'); + } + + const parsed = JSON.parse(jsonMatch[0]); + return { + requiresApproval: parsed.decision !== 'safe', + riskLevel: parsed.riskLevel || 'medium', + decision: parsed.decision || 'escalate', + reasoning: parsed.reasoning || 'LLM evaluation', + policyMatched: parsed.policyMatched, + suggestedMessage: parsed.suggestedMessage || 'Please review this action.', + isDefinitive: true, + }; + } + + /** + * Build policy prompt for LLM evaluation + */ + private buildPolicyPrompt(policies: Policy[]): string { + return policies.map(p => ` +### ${p.name}: ${p.description} +**Instructions**: ${p.instructions} +**Definitions**: ${Object.entries(p.definitions).map(([k, v]) => `- ${k}: ${v}`).join('\n')} +**Violations (require approval)**: +${p.violations.map(v => `- ${v}`).join('\n')} +**Safe content (auto-approve)**: +${p.safeContent.map(s => `- ${s}`).join('\n')} +**Escalate to human**: +${p.escalateCriteria.map(e => `- ${e}`).join('\n')} +`).join('\n---\n'); + } + + /** + * Apply threshold to decision + */ + private applyThreshold(decision: GuardrailDecision): GuardrailDecision { + const decisionRiskOrder = RISK_LEVEL_ORDER[decision.riskLevel]; + const thresholdOrder = RISK_LEVEL_ORDER[this.config.approvalThreshold]; + + // If decision is below threshold and not a violation, don't require approval + if (decisionRiskOrder < thresholdOrder && decision.decision !== 'violation') { + return { + ...decision, + requiresApproval: false, + }; + } + + return decision; + } + + /** + * Create a safe decision + */ + private createSafeDecision(reasoning: string): GuardrailDecision { + return { + requiresApproval: false, + riskLevel: 'none', + decision: 'safe', + reasoning, + suggestedMessage: 'Action approved automatically.', + isDefinitive: true, + }; + } + + /** + * Create a violation decision + */ + private createViolationDecision( + riskLevel: RiskLevel, + reasoning: string, + suggestedMessage: string, + policyMatched?: string + ): GuardrailDecision { + return { + requiresApproval: true, + riskLevel, + decision: 'violation', + reasoning, + policyMatched, + suggestedMessage, + isDefinitive: true, + }; + } +} + +// ============================================================================ +// Prompt Compiler (merged from GuardrailPromptCompiler) +// ============================================================================ + +export interface ToolWithApprovalConfig { + name: string; + approvalConfig?: ToolApprovalConfig; +} + +/** + * Compile guardrail policies and tool approval requirements into a system prompt block. + * This makes the LLM aware of constraints and helps it proactively avoid triggering approvals. + */ +export function compileGuardrailPrompt( + activePolicies: Policy[] = POLICIES, + tools: ToolWithApprovalConfig[] = [] +): string { + const sections: string[] = []; + + // Section 1: Tools requiring approval + const approvalTools = tools.filter(t => t.approvalConfig?.requiresApproval); + if (approvalTools.length > 0) { + const toolList = approvalTools.map(t => { + const message = t.approvalConfig?.approvalMessage || 'Requires approval'; + const risk = t.approvalConfig?.riskLevel || 'medium'; + return `- **${t.name}**: ${message} (Risk: ${risk})`; + }).join('\n'); + + sections.push(`## Tool Approval Requirements + +The following tools require human approval before execution: +${toolList} + +When using these tools, the user will be prompted to approve or reject the action. Consider alternatives if possible, or explain your reasoning clearly before invoking them.`); + } + + // Section 2: Active policies (summarized for LLM awareness) + if (activePolicies.length > 0) { + const policySummaries = activePolicies.map(p => { + const violations = p.violations.slice(0, 2).join('; ') || 'Various high-risk actions'; + const safe = p.safeContent.slice(0, 2).join('; ') || 'Standard actions'; + return `### ${p.name} +${p.description} +- **Actions that may require approval**: ${violations} +- **Safe actions**: ${safe}`; + }).join('\n\n'); + + sections.push(`## Safety Policies + +The following safety policies are active and may trigger approval requests: + +${policySummaries}`); + } + + // Section 3: General guidance + sections.push(`## Approval Guidance + +To minimize approval interruptions and provide a smoother user experience: +1. Prefer actions within the current domain over external navigation +2. Avoid entering sensitive data (passwords, payment info) unless explicitly requested by the user +3. Use read-only operations when possible before requesting write operations +4. Explain your reasoning before taking high-risk actions so users understand the context +5. If an action requires approval, clearly state what you're about to do and why`); + + // Only return content if there's something meaningful + if (sections.length === 1) { + return ''; + } + + return sections.join('\n\n'); +} + +/** + * Check if any guardrail content should be injected + */ +export function hasGuardrailContent( + policies: Policy[] = POLICIES, + tools: ToolWithApprovalConfig[] = [] +): boolean { + const hasApprovalTools = tools.some(t => t.approvalConfig?.requiresApproval); + const hasPolicies = policies.length > 0; + return hasApprovalTools || hasPolicies; +} diff --git a/front_end/panels/ai_chat/guardrails/index.ts b/front_end/panels/ai_chat/guardrails/index.ts new file mode 100644 index 0000000000..36ac54803e --- /dev/null +++ b/front_end/panels/ai_chat/guardrails/index.ts @@ -0,0 +1,55 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Guardrails Module - Public API + * + * Provides a unified interface for: + * - Policy-based tool evaluation + * - Human-in-the-loop approval gating + * - LLM-enhanced safety assessment + */ + +// Types +export type { + RiskLevel, + DecisionType, + ReasoningEffort, + ToolCall, + ToolApprovalConfig, + ExecutionContext, + GuardrailDecision, + ApprovalRequest, + ApprovalResult, + GateResult, + LLMConfig, + GuardrailConfig, + PolicyExample, + Policy, +} from './types.js'; + +export { + DEFAULT_GUARDRAIL_CONFIG, + RISK_LEVEL_ORDER, + isRiskAtOrAbove, +} from './types.js'; + +// Policy definitions +export { POLICIES, getPoliciesForTool, getPolicy } from './policies.js'; + +// Policy evaluator +export { PolicyEvaluator, compileGuardrailPrompt, hasGuardrailContent } from './PolicyEvaluator.js'; + +// Main middleware +export { + GuardrailMiddleware, + GuardrailEvents, + getGuardrailMiddleware, + resetGuardrailMiddleware, +} from './GuardrailMiddleware.js'; + +export type { + ApprovalRequestedEvent, + ApprovalResolvedEvent, +} from './GuardrailMiddleware.js'; diff --git a/front_end/panels/ai_chat/guardrails/policies.ts b/front_end/panels/ai_chat/guardrails/policies.ts new file mode 100644 index 0000000000..f578bfa8ba --- /dev/null +++ b/front_end/panels/ai_chat/guardrails/policies.ts @@ -0,0 +1,619 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Guardrail Policies - Policy definitions for tool execution safety. + * Consolidated from GuardrailPolicies.ts. + */ + +import type { Policy, ExecutionContext, GuardrailDecision, RiskLevel } from './types.js'; + +// ============================================================================ +// Policy Definitions +// ============================================================================ + +export const POLICIES: Policy[] = [ + { + name: 'external_navigation', + description: 'Controls navigation to external websites', + instructions: 'Evaluate if the navigation target is safe and expected based on the current context.', + definitions: { + 'external_domain': 'A domain different from the current page domain', + 'known_safe': 'Domains explicitly allowlisted by the user or commonly trusted (google.com, github.com, etc.)', + 'suspicious_domain': 'Domains with unusual TLDs, misspellings of known brands, or unknown reputation', + }, + violations: [ + 'Navigating to a domain that appears to be a phishing attempt (misspelled brand names)', + 'Navigating to domains with suspicious TLDs (.xyz, .top, .click) unless explicitly requested', + 'Following redirect chains that unexpectedly change domains', + ], + safeContent: [ + 'Navigating within the same domain', + 'Navigating to well-known, trusted domains (google.com, github.com, stackoverflow.com)', + 'Navigating to URLs explicitly mentioned by the user in their request', + ], + escalateCriteria: [ + 'Unknown external domain that may be legitimate but cannot be verified', + 'Redirect chains that cross domain boundaries', + 'Navigation to download pages or file hosting sites', + ], + examples: [ + { + input: { + toolName: 'navigate_url', + args: { url: 'https://google.com/search?q=test' }, + context: 'User asked to search for something', + }, + decision: 'safe', + reasoning: 'Google.com is a well-known trusted domain and the user requested a search.', + }, + { + input: { + toolName: 'navigate_url', + args: { url: 'https://amaz0n-deals.xyz/login' }, + context: 'Agent found a link claiming to offer Amazon deals', + }, + decision: 'violation', + reasoning: 'This appears to be a phishing domain - misspelled brand name (amaz0n) with suspicious TLD (.xyz).', + }, + { + input: { + toolName: 'navigate_url', + args: { url: 'https://acme-corp.com/products' }, + context: 'User asked to check a company website', + }, + decision: 'escalate', + reasoning: 'Unknown domain - cannot verify if this is the legitimate company website without user confirmation.', + }, + ], + applicableTools: ['navigate_url', 'click'], + }, + + { + name: 'sensitive_data_entry', + description: 'Controls typing sensitive information into form fields', + instructions: 'Evaluate if the data being entered is sensitive and if the target field/site is appropriate.', + definitions: { + 'sensitive_data': 'Passwords, credit card numbers, SSN, API keys, or other credentials', + 'pii': 'Personally identifiable information like full name, address, phone, email', + 'secure_context': 'HTTPS connection to a verified, legitimate domain', + }, + violations: [ + 'Entering passwords or credentials on non-HTTPS sites', + 'Typing credit card information on unverified e-commerce sites', + 'Entering API keys or secrets into any form field', + 'Typing SSN or government ID numbers', + ], + safeContent: [ + 'Typing search queries into search boxes', + 'Entering non-sensitive form data (comments, messages, etc.)', + 'Typing into text editors or note-taking applications', + ], + escalateCriteria: [ + 'Entering email addresses (PII but often required)', + 'Filling login forms on legitimate but unfamiliar sites', + 'Auto-filling saved credentials on recognized sites', + ], + examples: [ + { + input: { + toolName: 'type', + args: { selector: '#search', text: 'best restaurants near me' }, + context: 'User asked to search for restaurants', + }, + decision: 'safe', + reasoning: 'Typing a search query is not sensitive data.', + }, + { + input: { + toolName: 'type', + args: { selector: '#password', text: 'mySecretPassword123' }, + context: 'Agent attempting to log into a website', + }, + decision: 'violation', + reasoning: 'Password entry requires explicit user approval to prevent credential theft.', + }, + { + input: { + toolName: 'type', + args: { selector: '#email', text: 'user@example.com' }, + context: 'Filling out a newsletter signup form', + }, + decision: 'escalate', + reasoning: 'Email is PII - user should confirm they want to share this information.', + }, + ], + applicableTools: ['type', 'fill_form'], + }, + + { + name: 'form_submission', + description: 'Controls form submissions that may have side effects', + instructions: 'Evaluate if the form submission could have significant consequences.', + definitions: { + 'transactional_form': 'Forms that trigger purchases, transfers, or irreversible actions', + 'data_collection_form': 'Forms that collect and submit personal information', + 'safe_form': 'Search forms, filters, or navigation controls', + }, + violations: [ + 'Submitting payment or checkout forms', + 'Submitting forms that delete data or accounts', + 'Submitting legal agreements or contracts', + ], + safeContent: [ + 'Submitting search queries', + 'Applying filters or sort options', + 'Navigating via form-based menus', + ], + escalateCriteria: [ + 'Submitting contact or registration forms', + 'Submitting feedback or review forms', + 'Submitting any form with filled personal information', + ], + examples: [ + { + input: { + toolName: 'click', + args: { selector: 'button[type="submit"]', elementInfo: 'Search button' }, + context: 'Submitting a search form', + }, + decision: 'safe', + reasoning: 'Search form submission has no side effects beyond navigation.', + }, + { + input: { + toolName: 'click', + args: { selector: '#place-order', elementInfo: 'Place Order - $99.99' }, + context: 'Checkout page with items in cart', + }, + decision: 'violation', + reasoning: 'Order placement is a financial transaction requiring explicit user approval.', + }, + ], + applicableTools: ['click', 'submit_form'], + }, + + { + name: 'script_execution', + description: 'Controls execution of JavaScript code', + instructions: 'Evaluate if the JavaScript code could cause harm or unexpected behavior.', + definitions: { + 'read_only_script': 'Scripts that only query data without modifying state', + 'mutating_script': 'Scripts that modify DOM, storage, or make network requests', + 'dangerous_script': 'Scripts that access credentials, make purchases, or delete data', + }, + violations: [ + 'Scripts that access localStorage/sessionStorage credentials', + 'Scripts that make fetch/XHR requests to external domains', + 'Scripts that modify cookie values', + 'Scripts that trigger downloads or file access', + ], + safeContent: [ + 'Scripts that query DOM elements for information', + 'Scripts that read page content or metadata', + 'Scripts that check element visibility or state', + ], + escalateCriteria: [ + 'Scripts that modify DOM elements', + 'Scripts that scroll or interact with the page', + 'Scripts with complex logic that is hard to verify', + ], + examples: [ + { + input: { + toolName: 'execute_javascript', + args: { code: 'document.querySelectorAll("a").length' }, + context: 'Counting links on a page', + }, + decision: 'safe', + reasoning: 'Read-only DOM query with no side effects.', + }, + { + input: { + toolName: 'execute_javascript', + args: { code: 'localStorage.getItem("auth_token")' }, + context: 'Agent trying to get authentication info', + }, + decision: 'violation', + reasoning: 'Accessing stored credentials is a security-sensitive operation.', + }, + ], + applicableTools: ['execute_javascript', 'evaluate'], + }, + + { + name: 'file_operations', + description: 'Controls file downloads and uploads', + instructions: 'Evaluate if the file operation is safe and expected.', + definitions: { + 'safe_download': 'Downloads from trusted sources explicitly requested by user', + 'risky_download': 'Executable files, scripts, or downloads from unknown sources', + 'upload': 'Any file upload operation', + }, + violations: [ + 'Downloading executable files (.exe, .dmg, .sh, .bat)', + 'Uploading files without explicit user consent', + 'Downloading from suspicious or unknown domains', + ], + safeContent: [ + 'Downloading documents explicitly requested by user (PDF, images)', + 'Downloading from verified, trusted sources', + ], + escalateCriteria: [ + 'Any file download not explicitly requested', + 'Downloads from unfamiliar but potentially legitimate sources', + ], + examples: [ + { + input: { + toolName: 'click', + args: { selector: 'a[href$=".exe"]', elementInfo: 'Download installer' }, + context: 'Software download page', + }, + decision: 'violation', + reasoning: 'Executable downloads require explicit user approval due to security risks.', + }, + ], + applicableTools: ['click', 'download_file'], + }, +]; + +// ============================================================================ +// Policy Lookup Utilities +// ============================================================================ + +/** + * Get policies applicable to a specific tool + */ +export function getPoliciesForTool(toolName: string): Policy[] { + return POLICIES.filter(policy => { + if (!policy.applicableTools || policy.applicableTools.length === 0) { + return true; + } + return policy.applicableTools.includes(toolName); + }); +} + +/** + * Get a specific policy by name + */ +export function getPolicy(name: string): Policy | undefined { + return POLICIES.find(p => p.name === name); +} + +// ============================================================================ +// Rule-Based Evaluation Functions +// ============================================================================ + +/** Trusted domains that don't require approval */ +const TRUSTED_DOMAINS = [ + 'google.com', 'www.google.com', + 'github.com', 'www.github.com', + 'stackoverflow.com', 'www.stackoverflow.com', + 'wikipedia.org', 'en.wikipedia.org', + 'youtube.com', 'www.youtube.com', + 'amazon.com', 'www.amazon.com', +]; + +/** Suspicious TLDs often associated with spam/phishing */ +const SUSPICIOUS_TLDS = ['.xyz', '.top', '.click', '.work', '.tk', '.ml']; + +/** Brand misspellings commonly used in phishing */ +const BRAND_MISSPELLINGS = ['amaz0n', 'g00gle', 'faceb00k', 'paypa1', 'micros0ft']; + +/** + * Evaluate navigation actions + */ +export function evaluateNavigation(url: string, context: ExecutionContext): GuardrailDecision { + try { + const targetUrl = new URL(url); + const targetDomain = targetUrl.hostname; + + // Same domain navigation is safe + if (context.currentDomain && targetDomain === context.currentDomain) { + return createSafeDecision('Navigation within same domain'); + } + + // Trusted domains are safe with low risk + if (TRUSTED_DOMAINS.some(d => targetDomain.endsWith(d))) { + return { + requiresApproval: false, + riskLevel: 'low', + decision: 'safe', + reasoning: 'Navigation to well-known trusted domain', + policyMatched: 'external_navigation', + suggestedMessage: `Navigating to trusted domain: ${targetDomain}`, + isDefinitive: true, + }; + } + + // Suspicious TLDs require approval + if (SUSPICIOUS_TLDS.some(tld => targetDomain.endsWith(tld))) { + return createViolationDecision( + 'high', + 'Navigation to domain with suspicious TLD', + `The agent wants to navigate to ${url}. This domain uses a suspicious TLD commonly associated with spam or phishing.`, + 'external_navigation' + ); + } + + // Brand misspellings are critical violations + if (BRAND_MISSPELLINGS.some(m => targetDomain.includes(m))) { + return createViolationDecision( + 'critical', + 'Possible phishing domain detected', + `The agent wants to navigate to ${url}. This domain appears to be a phishing attempt with a misspelled brand name.`, + 'external_navigation' + ); + } + + // Unknown external domain - escalate for human judgment + return { + requiresApproval: true, + riskLevel: 'medium', + decision: 'escalate', + reasoning: 'Unknown external domain - cannot verify legitimacy without user input', + policyMatched: 'external_navigation', + suggestedMessage: `The agent wants to navigate to an external domain: ${targetDomain}. Please confirm this is the intended destination.`, + isDefinitive: false, + }; + + } catch { + // Invalid URL - escalate + return { + requiresApproval: true, + riskLevel: 'medium', + decision: 'escalate', + reasoning: 'Could not parse URL for safety evaluation', + policyMatched: 'external_navigation', + suggestedMessage: `The agent wants to navigate to: ${url}. Please verify this URL is correct.`, + isDefinitive: false, + }; + } +} + +/** + * Evaluate text input actions + */ +export function evaluateDataEntry( + args: Record, + _context: ExecutionContext +): GuardrailDecision { + const text = (args.text as string) || ''; + const selector = (args.selector as string) || ''; + const elementInfo = (args.elementInfo as string) || ''; + + // Password fields + if ( + selector.includes('password') || + elementInfo.toLowerCase().includes('password') || + selector.includes('type="password"') + ) { + return createViolationDecision( + 'high', + 'Typing into password field', + 'The agent wants to enter text into a password field. Please confirm you want to allow credential entry.', + 'sensitive_data_entry' + ); + } + + // Credit card patterns + const ccPattern = /\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b/; + if (ccPattern.test(text)) { + return createViolationDecision( + 'critical', + 'Credit card number detected', + 'The agent wants to enter what appears to be a credit card number. This requires explicit approval.', + 'sensitive_data_entry' + ); + } + + // SSN pattern + const ssnPattern = /\b\d{3}[-]?\d{2}[-]?\d{4}\b/; + if (ssnPattern.test(text)) { + return createViolationDecision( + 'critical', + 'SSN pattern detected', + 'The agent wants to enter what appears to be a Social Security Number. This requires explicit approval.', + 'sensitive_data_entry' + ); + } + + // API key patterns + const apiKeyPatterns = [ + /sk-[a-zA-Z0-9]{32,}/, // OpenAI + /ghp_[a-zA-Z0-9]{36}/, // GitHub + /AKIA[0-9A-Z]{16}/, // AWS + ]; + if (apiKeyPatterns.some(p => p.test(text))) { + return createViolationDecision( + 'critical', + 'API key or secret detected', + 'The agent wants to enter what appears to be an API key or secret. This is highly sensitive.', + 'sensitive_data_entry' + ); + } + + // Email addresses - escalate (PII but often necessary) + const emailPattern = /\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b/; + if (emailPattern.test(text)) { + return { + requiresApproval: true, + riskLevel: 'low', + decision: 'escalate', + reasoning: 'Email address is PII - user should confirm sharing', + policyMatched: 'sensitive_data_entry', + suggestedMessage: `The agent wants to enter an email address: ${text.match(emailPattern)?.[0]}. Is this correct?`, + isDefinitive: false, + }; + } + + // Standard text input is safe + return createSafeDecision('Standard text input - no sensitive data detected'); +} + +/** + * Evaluate click actions + */ +export function evaluateClick(args: Record): GuardrailDecision { + const selector = (args.selector as string) || ''; + const elementInfo = (args.elementInfo as string) || ''; + const text = elementInfo.toLowerCase(); + + // Purchase/checkout buttons + const purchaseKeywords = ['place order', 'buy now', 'purchase', 'checkout', 'pay now', 'submit order', 'confirm purchase']; + if (purchaseKeywords.some(k => text.includes(k))) { + return createViolationDecision( + 'critical', + 'Purchase/checkout action detected', + `The agent wants to click: "${elementInfo}". This appears to be a purchase action that requires your explicit approval.`, + 'form_submission' + ); + } + + // Delete/destructive actions + const deleteKeywords = ['delete', 'remove', 'cancel account', 'deactivate', 'unsubscribe']; + if (deleteKeywords.some(k => text.includes(k))) { + return createViolationDecision( + 'high', + 'Destructive action detected', + `The agent wants to click: "${elementInfo}". This appears to be a destructive action.`, + 'form_submission' + ); + } + + // Downloads + if ( + selector.includes('.exe') || + selector.includes('.dmg') || + selector.includes('.msi') || + selector.includes('download') || + text.includes('download') + ) { + return { + requiresApproval: true, + riskLevel: 'medium', + decision: 'escalate', + reasoning: 'Download action detected', + policyMatched: 'file_operations', + suggestedMessage: 'The agent wants to initiate a download. Please confirm this is intended.', + isDefinitive: false, + }; + } + + // Form submit buttons (generally safe) + if ( + selector.includes('type="submit"') || + text.includes('submit') || + text.includes('send') + ) { + return { + requiresApproval: false, + riskLevel: 'low', + decision: 'safe', + reasoning: 'Form submission - not a high-risk transaction', + policyMatched: 'form_submission', + suggestedMessage: 'Standard form submission', + isDefinitive: true, + }; + } + + return createSafeDecision('Standard click action'); +} + +/** + * Evaluate JavaScript execution + */ +export function evaluateScript(args: Record): GuardrailDecision { + const code = (args.code as string) || (args.expression as string) || ''; + + // Credential access patterns + const credentialPatterns = [ + /localStorage\.getItem\s*\(\s*['"`].*(?:token|auth|session|key|password)/i, + /sessionStorage\.getItem/i, + /document\.cookie/i, + ]; + if (credentialPatterns.some(p => p.test(code))) { + return createViolationDecision( + 'high', + 'Credential access in script', + 'The script attempts to access stored credentials or sensitive data.', + 'script_execution' + ); + } + + // Network requests + if (/fetch\s*\(|XMLHttpRequest|\.ajax\(/.test(code)) { + return { + requiresApproval: true, + riskLevel: 'medium', + decision: 'escalate', + reasoning: 'Script makes network requests', + policyMatched: 'script_execution', + suggestedMessage: 'The script makes network requests. Please review before allowing.', + isDefinitive: false, + }; + } + + // DOM modification (generally safe) + if (/\.innerHTML|\.outerHTML|document\.write|\.insertAdjacentHTML/.test(code)) { + return { + requiresApproval: false, + riskLevel: 'low', + decision: 'safe', + reasoning: 'Script modifies DOM - generally safe', + policyMatched: 'script_execution', + suggestedMessage: 'Script modifies page content', + isDefinitive: true, + }; + } + + // Read-only queries are safe + if (/querySelector|querySelectorAll|getElementById|getElementsBy|\.textContent|\.innerText/.test(code)) { + return createSafeDecision('Read-only DOM query'); + } + + // Unknown script - escalate + return { + requiresApproval: true, + riskLevel: 'medium', + decision: 'escalate', + reasoning: 'Script with unknown effects', + policyMatched: 'script_execution', + suggestedMessage: 'Please review this script before execution.', + isDefinitive: false, + }; +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +function createSafeDecision(reasoning: string): GuardrailDecision { + return { + requiresApproval: false, + riskLevel: 'none', + decision: 'safe', + reasoning, + suggestedMessage: 'Action approved automatically.', + isDefinitive: true, + }; +} + +function createViolationDecision( + riskLevel: RiskLevel, + reasoning: string, + suggestedMessage: string, + policyMatched?: string +): GuardrailDecision { + return { + requiresApproval: true, + riskLevel, + decision: 'violation', + reasoning, + policyMatched, + suggestedMessage, + isDefinitive: true, + }; +} diff --git a/front_end/panels/ai_chat/guardrails/types.ts b/front_end/panels/ai_chat/guardrails/types.ts new file mode 100644 index 0000000000..0e3f8eb3c9 --- /dev/null +++ b/front_end/panels/ai_chat/guardrails/types.ts @@ -0,0 +1,267 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Unified types for the guardrail system. + * Single source of truth - replaces fragmented types across multiple files. + */ + +import type { LLMProvider } from '../LLM/LLMTypes.js'; + +// ============================================================================ +// Core Types +// ============================================================================ + +/** + * Risk levels from lowest to highest + */ +export type RiskLevel = 'none' | 'low' | 'medium' | 'high' | 'critical'; + +/** + * Decision types following GPT-OSS Safeguard patterns + */ +export type DecisionType = 'safe' | 'violation' | 'escalate'; + +/** + * Reasoning effort levels - controls LLM token budget and depth + */ +export type ReasoningEffort = 'low' | 'medium' | 'high'; + +// ============================================================================ +// Tool Call Types +// ============================================================================ + +/** + * Unified tool call representation + */ +export interface ToolCall { + /** Tool name */ + name: string; + /** Tool arguments */ + args: Record; + /** Unique call ID */ + callId: string; +} + +/** + * Tool-level approval configuration (from tool definitions) + */ +export interface ToolApprovalConfig { + /** Whether this tool always requires approval */ + requiresApproval?: boolean; + /** Default risk level for this tool */ + riskLevel?: RiskLevel; + /** Custom approval message */ + approvalMessage?: string; +} + +// ============================================================================ +// Execution Context +// ============================================================================ + +/** + * Context provided during evaluation - always complete, no undefined values + */ +export interface ExecutionContext { + /** Current page URL (required) */ + currentUrl: string; + /** Current page domain (extracted from URL) */ + currentDomain: string; + /** User's stated goal/task (optional) */ + userGoal?: string; + /** Previous actions in this session */ + recentActions?: string[]; +} + +// ============================================================================ +// Guardrail Decision +// ============================================================================ + +/** + * Result of guardrail evaluation + */ +export interface GuardrailDecision { + /** Whether human approval is required */ + requiresApproval: boolean; + /** Risk level assessment */ + riskLevel: RiskLevel; + /** Decision type */ + decision: DecisionType; + /** Chain-of-thought reasoning (transparent audit trail) */ + reasoning: string; + /** Which policy triggered this decision */ + policyMatched?: string; + /** Human-readable message for the approval UI */ + suggestedMessage: string; + /** Whether this is a definitive decision (no need for LLM escalation) */ + isDefinitive?: boolean; +} + +// ============================================================================ +// Approval Types +// ============================================================================ + +/** + * Unified approval request (replaces dual message types) + */ +export interface ApprovalRequest { + /** Unique approval ID */ + id: string; + /** The tool call that requires approval */ + toolCall: ToolCall; + /** Guardrail decision details */ + decision: GuardrailDecision; + /** Current status */ + status: 'pending' | 'approved' | 'rejected'; + /** User feedback (especially on rejection) */ + feedback?: string; + /** When the request was created */ + timestamp: number; +} + +/** + * Result of an approval decision + */ +export interface ApprovalResult { + /** Whether the action was approved */ + approved: boolean; + /** Optional feedback from user (especially on rejection) */ + feedback?: string; + /** Time taken for user to respond (ms) */ + responseTimeMs?: number; +} + +/** + * Result of the guardrail gate + */ +export interface GateResult { + /** Whether to proceed with tool execution */ + proceed: boolean; + /** User feedback if rejected (passed to agent) */ + feedback?: string; + /** The guardrail decision that was made */ + decision?: GuardrailDecision; +} + +// ============================================================================ +// Configuration +// ============================================================================ + +/** + * LLM configuration for guardrail evaluation + */ +export interface LLMConfig { + /** Enable LLM-based evaluation for nuanced decisions */ + enabled: boolean; + /** LLM provider to use */ + provider?: LLMProvider; + /** Model to use (defaults to gpt-4o-mini) */ + model?: string; + /** API key for LLM calls */ + apiKey?: string; + /** Reasoning effort level */ + reasoningEffort: ReasoningEffort; +} + +/** + * Main guardrail configuration + */ +export interface GuardrailConfig { + /** Enable guardrail evaluation */ + enabled: boolean; + /** Approval threshold - require approval for this risk level and above */ + approvalThreshold: RiskLevel; + /** Tools that never require approval (bypass guardrails) */ + alwaysApprove: string[]; + /** Tools that always require approval */ + alwaysRequire: string[]; + /** LLM evaluation settings */ + llm: LLMConfig; + /** Default timeout for approvals in ms */ + approvalTimeoutMs: number; +} + +/** + * Default configuration + */ +export const DEFAULT_GUARDRAIL_CONFIG: GuardrailConfig = { + enabled: true, + approvalThreshold: 'medium', + alwaysApprove: [ + 'get_page_content', + 'get_element_info', + 'get_accessibility_tree', + 'wait', + 'screenshot', + ], + alwaysRequire: [], + llm: { + enabled: false, + reasoningEffort: 'medium', + }, + approvalTimeoutMs: 5 * 60 * 1000, // 5 minutes +}; + +// ============================================================================ +// Policy Types +// ============================================================================ + +/** + * Policy example for training the LLM guardrail + */ +export interface PolicyExample { + input: { + toolName: string; + args: Record; + context?: string; + }; + decision: DecisionType; + reasoning: string; +} + +/** + * A guardrail policy definition + */ +export interface Policy { + /** Unique identifier for the policy */ + name: string; + /** Human-readable description */ + description: string; + /** Instructions for evaluating this policy */ + instructions: string; + /** Key terms and their definitions */ + definitions: Record; + /** Criteria that trigger approval requirement */ + violations: string[]; + /** Criteria that allow auto-approval */ + safeContent: string[]; + /** Ambiguous cases requiring human judgment */ + escalateCriteria: string[]; + /** Few-shot examples for LLM evaluation */ + examples: PolicyExample[]; + /** Tools this policy applies to (empty = all tools) */ + applicableTools?: string[]; +} + +// ============================================================================ +// Risk Level Utilities +// ============================================================================ + +/** + * Risk level ordering for comparison + */ +export const RISK_LEVEL_ORDER: Record = { + 'none': 0, + 'low': 1, + 'medium': 2, + 'high': 3, + 'critical': 4, +}; + +/** + * Compare risk levels + */ +export function isRiskAtOrAbove(level: RiskLevel, threshold: RiskLevel): boolean { + return RISK_LEVEL_ORDER[level] >= RISK_LEVEL_ORDER[threshold]; +} diff --git a/front_end/panels/ai_chat/memory/ListMemoryBlocksTool.ts b/front_end/panels/ai_chat/memory/ListMemoryBlocksTool.ts new file mode 100644 index 0000000000..964aea41e5 --- /dev/null +++ b/front_end/panels/ai_chat/memory/ListMemoryBlocksTool.ts @@ -0,0 +1,93 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { createLogger } from '../core/Logger.js'; +import type { Tool, LLMContext } from '../tools/Tools.js'; +import { MemoryBlockManager } from './MemoryBlockManager.js'; + +const logger = createLogger('Tool:ListMemoryBlocks'); + +export interface ListMemoryBlocksArgs { + // No arguments needed +} + +export interface ListMemoryBlocksResult { + success: boolean; + blocks: Array<{ + type: string; + label: string; + content: string; + charCount: number; + charLimit: number; + updatedAt: string; + }>; + summary: { + totalBlocks: number; + totalChars: number; + maxChars: number; + }; + error?: string; +} + +/** + * Tool for listing all memory blocks with their content and metadata. + * Useful for the MemoryAgent to see current memory state before making updates. + */ +export class ListMemoryBlocksTool implements Tool { + name = 'list_memory_blocks'; + description = 'List all memory blocks with their current content and metadata (size, limits, last updated). Use this to see the current state of memory before making updates.'; + + schema = { + type: 'object', + properties: {}, + required: [] + }; + + async execute(_args: ListMemoryBlocksArgs, _ctx?: LLMContext): Promise { + logger.info('Executing list memory blocks'); + + try { + const manager = new MemoryBlockManager(); + const blocks = await manager.getAllBlocks(); + + const formattedBlocks = blocks.map(b => ({ + type: b.type, + label: b.label, + content: b.content, + charCount: b.content.length, + charLimit: b.charLimit, + updatedAt: new Date(b.updatedAt).toISOString() + })); + + // Calculate max capacity: 20000 (user) + 20000 (facts) + 4*20000 (projects) + const maxChars = 120000; + + const summary = { + totalBlocks: blocks.length, + totalChars: blocks.reduce((sum, b) => sum + b.content.length, 0), + maxChars + }; + + logger.info('Listed memory blocks', { blockCount: blocks.length, totalChars: summary.totalChars }); + + return { + success: true, + blocks: formattedBlocks, + summary + }; + } catch (error: any) { + logger.error('Failed to list memory blocks', { error: error?.message }); + return { + success: false, + blocks: [], + summary: { + totalBlocks: 0, + totalChars: 0, + maxChars: 9500 + }, + error: error?.message || 'Failed to list memory blocks.' + }; + } + } +} diff --git a/front_end/panels/ai_chat/memory/MemoryAgentConfig.ts b/front_end/panels/ai_chat/memory/MemoryAgentConfig.ts new file mode 100644 index 0000000000..e449545c38 --- /dev/null +++ b/front_end/panels/ai_chat/memory/MemoryAgentConfig.ts @@ -0,0 +1,284 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { AgentToolConfig, ConfigurableAgentArgs } from '../agent_framework/ConfigurableAgentTool.js'; +import { ChatMessageEntity } from '../models/ChatTypes.js'; +import type { ChatMessage } from '../models/ChatTypes.js'; +import { MODEL_SENTINELS } from '../core/Constants.js'; +import { AGENT_VERSION } from '../agent_framework/implementation/agents/AgentVersion.js'; + +/** + * Memory agent mode determines behavior and configuration. + */ +export type MemoryAgentMode = 'extraction' | 'search'; + +// Extraction mode prompt - runs after conversations to consolidate facts +const EXTRACTION_PROMPT = `You are a Memory Consolidation Agent that runs in the background after conversations end. + +## Your Purpose +Extract and organize important information from completed conversations into persistent memory blocks that will help the assistant in future conversations. + +## Memory Block Types + +| Block | Purpose | Max Size | +|-------|---------|----------| +| user | User identity, preferences, communication style | 20000 chars | +| facts | Factual information learned from conversations | 20000 chars | +| project_ | Project-specific context (up to 4 projects) | 20000 chars each | + +## Workflow + +1. **List current memory** using list_memory_blocks +2. **Analyze** the conversation for extractable information +3. **Check for duplicates** before adding new facts +4. **Update blocks** with consolidated, organized content +5. **Verify** changes are correct and within limits + +## What to Extract + +### High Priority (Always Extract) +- User's name, role, job title +- Explicit preferences ("I prefer...", "I like...", "Always use...") +- Project names, tech stacks, goals +- Recurring patterns in requests + +### Medium Priority (Extract if Relevant) +- Problem-solving approaches that worked +- Tools/libraries the user uses frequently +- Team members or collaborators mentioned + +### Skip (Do Not Extract) +- One-time troubleshooting details +- Temporary debugging information +- Generic conversation pleasantries +- Information already in memory + +## Writing Guidelines + +### Be Specific with Dates +❌ "Recently discussed migration" +✅ "2025-01-15: Discussed database migration to PostgreSQL" + +### Be Concise +❌ "The user mentioned that they have a strong preference for using TypeScript in their projects because they find it helps catch errors" +✅ "Prefers TypeScript for type safety" + +### Use Bullet Points +\`\`\` +- Name: Alex Chen +- Role: Senior Frontend Engineer +- Prefers: TypeScript, React, Tailwind CSS +- Dislikes: Inline styles, any types +\`\`\` + +### Consolidate Related Info +If user block has: +\`\`\` +- Likes dark mode +- Uses VS Code +- Prefers dark themes +\`\`\` + +Consolidate to: +\`\`\` +- Prefers dark mode/themes +- Uses VS Code +\`\`\` + +## Examples + +### Example 1: User Preferences +**Conversation excerpt:** +> User: "Hey, I'm Sarah. Can you help me debug this React component? I always use functional components with hooks, never class components." + +**Memory update (user block):** +\`\`\` +- Name: Sarah +- React: Functional components + hooks only, no class components +\`\`\` + +### Example 2: Project Context +**Conversation excerpt:** +> User: "Working on our e-commerce platform. We're using Next.js 14 with App Router, Prisma for the database, and Stripe for payments." + +**Memory update (project_ecommerce block):** +\`\`\` +Project: E-commerce Platform +Stack: Next.js 14 (App Router), Prisma, Stripe +\`\`\` + +### Example 3: Skip Extraction +**Conversation excerpt:** +> User: "Getting a 404 error on /api/users endpoint" +> Assistant: "The route file is missing, create app/api/users/route.ts" +> User: "Fixed, thanks!" + +**Action:** No extraction needed - one-time debugging, no lasting value. + +## Output +After processing, briefly state what was updated or why nothing was updated. +`; + +// Search mode prompt - read-only queries for orchestrators +const SEARCH_PROMPT = `You are a Memory Retrieval Agent. Your job is to find and summarize relevant information from stored memory to help the assistant respond to the user. + +## Memory Structure + +| Block | Contains | +|-------|----------| +| user | User identity, preferences, communication style | +| facts | Factual information from past conversations | +| project_* | Project-specific context (tech stack, goals, current work) | + +## Workflow + +1. Use list_memory_blocks to retrieve all stored memory +2. Scan each block for information relevant to the query +3. Return a concise summary of relevant findings + +## Response Format + +### When Memory Exists +Return relevant information organized by category: + +\`\`\` +**User Context:** +- Name: Sarah, Senior Frontend Engineer +- Prefers TypeScript, functional React components + +**Relevant Project:** +- E-commerce Platform: Next.js 14, Prisma, Stripe + +**Related Facts:** +- 2025-01-10: Migrated auth to NextAuth.js +\`\`\` + +### When No Memory Exists +Simply respond: "No relevant memory found." + +### When Memory is Empty +Simply respond: "No memory stored yet." + +## Guidelines + +- Only include information relevant to the query +- Don't dump entire blocks - summarize what's useful +- Prioritize recent information over old +- If query is vague, return user preferences + active project context +`; + +/** + * Create a memory agent configuration with the specified mode. + * + * @param mode - 'extraction' for background consolidation, 'search' for read-only queries + * @returns AgentToolConfig for the memory agent + */ +export function createMemoryAgentConfig(mode: MemoryAgentMode): AgentToolConfig { + if (mode === 'extraction') { + return createExtractionConfig(); + } + return createSearchConfig(); +} + +function createExtractionConfig(): AgentToolConfig { + return { + name: 'memory_agent', + version: AGENT_VERSION, + description: 'Background memory consolidation agent that extracts facts from conversations and maintains organized memory blocks.', + + ui: { + displayName: 'Memory Agent', + avatar: '🧠', + color: '#8b5cf6', + backgroundColor: '#f5f3ff' + }, + + systemPrompt: EXTRACTION_PROMPT, + + tools: ['search_memory', 'update_memory', 'list_memory_blocks'], + + schema: { + type: 'object', + properties: { + conversation_summary: { + type: 'string', + description: 'Summary of the conversation to analyze for memory extraction' + }, + reasoning: { + type: 'string', + description: 'Why this extraction is being run' + } + }, + required: ['conversation_summary', 'reasoning'] + }, + + prepareMessages: (args: ConfigurableAgentArgs): ChatMessage[] => { + return [{ + entity: ChatMessageEntity.USER, + text: `## Conversation to Analyze + +${args.conversation_summary || ''} + +## Reason for Extraction +${args.reasoning || 'Automatic extraction after session completion'} + +Please analyze this conversation and update memory blocks as appropriate.`, + }]; + }, + + maxIterations: 5, + modelName: MODEL_SENTINELS.USE_MINI, // Cost-effective for background task + temperature: 0.1, + handoffs: [], + }; +} + +function createSearchConfig(): AgentToolConfig { + return { + name: 'search_memory_agent', + version: AGENT_VERSION, + description: 'Search user memory for relevant information. Use when you need to recall user preferences, past facts, or project context.', + + ui: { + displayName: 'Search Memory', + avatar: '🔍', + color: '#10b981', + backgroundColor: '#ecfdf5' + }, + + systemPrompt: SEARCH_PROMPT, + + tools: ['list_memory_blocks'], + + schema: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'What to search for in memory (user preferences, facts, project info)' + }, + context: { + type: 'string', + description: 'Why this search is needed (helps with relevance)' + } + }, + required: ['query'] + }, + + prepareMessages: (args: ConfigurableAgentArgs): ChatMessage[] => { + return [{ + entity: ChatMessageEntity.USER, + text: `Search memory for: ${args.query || ''} +${args.context ? `\nContext: ${args.context}` : ''} + +Please search memory and return any relevant information.`, + }]; + }, + + maxIterations: 2, // Just need to list and respond + modelName: MODEL_SENTINELS.USE_NANO, // Fast, cheap model for simple searches + temperature: 0, + handoffs: [], + }; +} diff --git a/front_end/panels/ai_chat/memory/MemoryBlockManager.ts b/front_end/panels/ai_chat/memory/MemoryBlockManager.ts new file mode 100644 index 0000000000..a46ab68f66 --- /dev/null +++ b/front_end/panels/ai_chat/memory/MemoryBlockManager.ts @@ -0,0 +1,248 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { FileStorageManager } from '../tools/FileStorageManager.js'; +import { createLogger } from '../core/Logger.js'; +import { MemoryModule } from './MemoryModule.js'; +import type { BlockType, MemoryBlock, MemorySearchResult } from './types.js'; + +const logger = createLogger('MemoryBlockManager'); + +/** + * Manages memory blocks stored as files via FileStorageManager. + * Memory is global (shared across all conversations) using a reserved session ID. + * + * Block types: + * - user: User preferences, name, coding style (20000 chars) + * - facts: Recent extracted facts (20000 chars) + * - project: Project-specific context (20000 chars each, max 4) + */ +export class MemoryBlockManager { + private fileManager: FileStorageManager; + private memoryModule: MemoryModule; + + constructor() { + this.fileManager = FileStorageManager.getInstance(); + this.memoryModule = MemoryModule.getInstance(); + } + + /** + * Execute a function with the global memory session, restoring the previous session after. + */ + private async withGlobalSession(fn: () => Promise): Promise { + const prevSession = this.fileManager.getSessionId(); + this.fileManager.setSessionId(this.memoryModule.getSessionId()); + try { + return await fn(); + } finally { + this.fileManager.setSessionId(prevSession); + } + } + + // --- Block CRUD --- + + /** + * Get a memory block by type and optional project name. + */ + async getBlock(type: BlockType, projectName?: string): Promise { + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + const file = await this.fileManager.readFile(filename); + if (!file) { + return null; + } + + return { + filename, + type, + label: this.getLabel(type, projectName), + description: this.getDescription(type), + content: file.content, + charLimit: this.memoryModule.getBlockLimit(type), + updatedAt: file.updatedAt, + }; + }); + } + + /** + * Update or create a memory block. + */ + async updateBlock(type: BlockType, content: string, projectName?: string): Promise { + const limit = this.memoryModule.getBlockLimit(type); + if (content.length > limit) { + throw new Error(`Content exceeds ${limit} char limit (got ${content.length})`); + } + + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + const exists = await this.fileManager.readFile(filename); + + if (exists) { + await this.fileManager.updateFile(filename, content, false); + logger.info('Updated memory block', { type, filename }); + } else { + // Check project limit before creating new project block + if (type === 'project') { + const projects = await this.listProjectBlocks(); + if (projects.length >= this.memoryModule.getMaxProjectBlocks()) { + throw new Error(`Max ${this.memoryModule.getMaxProjectBlocks()} project blocks allowed`); + } + } + await this.fileManager.createFile(filename, content, 'text/markdown'); + logger.info('Created memory block', { type, filename }); + } + }); + } + + /** + * Delete a memory block. + */ + async deleteBlock(type: BlockType, projectName?: string): Promise { + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + try { + await this.fileManager.deleteFile(filename); + logger.info('Deleted memory block', { type, filename }); + } catch (error) { + // Ignore if file doesn't exist + logger.debug('Block not found for deletion', { type, filename }); + } + }); + } + + // --- Queries --- + + /** + * Get all memory blocks. + */ + async getAllBlocks(): Promise { + return this.withGlobalSession(async () => { + const files = await this.fileManager.listFiles(); + const blocks: MemoryBlock[] = []; + + for (const file of files) { + if (!file.fileName.startsWith('memory_')) { + continue; + } + + const fullFile = await this.fileManager.readFile(file.fileName); + if (!fullFile) { + continue; + } + + const { type, projectName } = this.parseFilename(file.fileName); + blocks.push({ + filename: file.fileName, + type, + label: this.getLabel(type, projectName), + description: this.getDescription(type), + content: fullFile.content, + charLimit: this.memoryModule.getBlockLimit(type), + updatedAt: file.updatedAt, + }); + } + + return blocks; + }); + } + + /** + * List only project blocks. + */ + async listProjectBlocks(): Promise { + const all = await this.getAllBlocks(); + return all.filter(b => b.type === 'project'); + } + + /** + * Search across all blocks for matching lines. + */ + async searchBlocks(query: string): Promise { + const blocks = await this.getAllBlocks(); + const results: MemorySearchResult[] = []; + const queryLower = query.toLowerCase(); + + for (const block of blocks) { + const lines = block.content.split('\n'); + const matches = lines.filter(line => + line.toLowerCase().includes(queryLower) + ); + if (matches.length > 0) { + results.push({ block, matches }); + } + } + + return results; + } + + // --- Helpers --- + + private getFilename(type: BlockType, projectName?: string): string { + if (type === 'project' && projectName) { + const safeName = projectName.toLowerCase().replace(/[^a-z0-9]/g, '_'); + return `memory_project_${safeName}.md`; + } + return `memory_${type}.md`; + } + + private parseFilename(filename: string): { type: BlockType; projectName?: string } { + if (filename === 'memory_user.md') { + return { type: 'user' }; + } + if (filename === 'memory_facts.md') { + return { type: 'facts' }; + } + if (filename.startsWith('memory_project_')) { + const projectName = filename.replace('memory_project_', '').replace('.md', ''); + return { type: 'project', projectName }; + } + return { type: 'facts' }; // fallback + } + + private getLabel(type: BlockType, projectName?: string): string { + if (type === 'project') { + return `project:${projectName}`; + } + return type; + } + + private getDescription(type: BlockType): string { + switch (type) { + case 'user': + return 'User preferences, name, coding style, and personal context'; + case 'facts': + return 'Recent facts extracted from conversations'; + case 'project': + return 'Project-specific context, tech stack, and goals'; + } + } + + // --- Memory Compilation (for prompt injection) --- + + /** + * Compile all memory blocks into XML context for prompt injection. + */ + async compileMemoryContext(): Promise { + const blocks = await this.getAllBlocks(); + if (blocks.length === 0) { + return ''; + } + + let context = '\n'; + + for (const block of blocks) { + if (!block.content.trim()) { + continue; + } + + context += `<${block.label}>\n`; + context += `${block.description}\n`; + context += `\n${block.content}\n\n`; + context += `\n`; + } + + context += ''; + return context; + } +} diff --git a/front_end/panels/ai_chat/memory/MemoryModule.ts b/front_end/panels/ai_chat/memory/MemoryModule.ts new file mode 100644 index 0000000000..97de7d3623 --- /dev/null +++ b/front_end/panels/ai_chat/memory/MemoryModule.ts @@ -0,0 +1,127 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { MemoryConfig } from './types.js'; + +/** + * Memory Module - Central facade for the memory system. + * + * Provides: + * - Configuration constants + * - Settings management (enable/disable) + * - Memory instructions for prompts + * - Tool availability checks + */ + +// Memory instructions prepended to orchestrator prompts when memory is enabled +const MEMORY_INSTRUCTIONS_TEXT = ` +You have a persistent memory system that remembers information across conversations. + +Memory is organized into blocks: +- **user**: Information about the user (name, preferences, working style) +- **facts**: Important facts learned from past conversations +- **project_***: Project-specific context (tech stack, goals, current work) + +**To access memory, use the 'search_memory_agent' tool.** Call it when: +- The user asks about something you might have discussed before +- You need to recall user preferences or past context +- The conversation involves a project you may have worked on + +Memory is updated automatically after conversations end. + + +`; + +// Default configuration +const DEFAULT_CONFIG: MemoryConfig = { + blockLimits: { + user: 20000, + facts: 20000, + project: 20000, + }, + maxProjectBlocks: 4, + sessionId: '__global_memory__', + enabledKey: 'ai_chat_memory_enabled', +}; + +/** + * Singleton class for memory system configuration and settings. + */ +export class MemoryModule { + private static instance: MemoryModule | null = null; + private config: MemoryConfig; + + private constructor() { + this.config = { ...DEFAULT_CONFIG }; + } + + /** + * Get the singleton instance. + */ + static getInstance(): MemoryModule { + if (!MemoryModule.instance) { + MemoryModule.instance = new MemoryModule(); + } + return MemoryModule.instance; + } + + /** + * Get the memory configuration. + */ + getConfig(): MemoryConfig { + return this.config; + } + + /** + * Check if memory is enabled in settings. + * Memory is enabled by default (returns true if not explicitly set to 'false'). + */ + isEnabled(): boolean { + return localStorage.getItem(this.config.enabledKey) !== 'false'; + } + + /** + * Enable or disable memory. + */ + setEnabled(enabled: boolean): void { + localStorage.setItem(this.config.enabledKey, enabled.toString()); + } + + /** + * Get memory instructions for prompt injection. + * Returns empty string if memory is disabled. + */ + getInstructions(): string { + return this.isEnabled() ? MEMORY_INSTRUCTIONS_TEXT : ''; + } + + /** + * Check if memory tool should be included in agent tools. + * Shorthand for isEnabled() - useful for tool filtering. + */ + shouldIncludeMemoryTool(): boolean { + return this.isEnabled(); + } + + /** + * Get the character limit for a specific block type. + */ + getBlockLimit(type: 'user' | 'facts' | 'project'): number { + return this.config.blockLimits[type]; + } + + /** + * Get the maximum number of project blocks allowed. + */ + getMaxProjectBlocks(): number { + return this.config.maxProjectBlocks; + } + + /** + * Get the session ID used for global memory storage. + */ + getSessionId(): string { + return this.config.sessionId; + } +} diff --git a/front_end/panels/ai_chat/memory/SearchMemoryTool.ts b/front_end/panels/ai_chat/memory/SearchMemoryTool.ts new file mode 100644 index 0000000000..c293087207 --- /dev/null +++ b/front_end/panels/ai_chat/memory/SearchMemoryTool.ts @@ -0,0 +1,72 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { createLogger } from '../core/Logger.js'; +import type { Tool, LLMContext } from '../tools/Tools.js'; +import { MemoryBlockManager } from './MemoryBlockManager.js'; + +const logger = createLogger('Tool:SearchMemory'); + +export interface SearchMemoryArgs { + query: string; +} + +export interface SearchMemoryResult { + success: boolean; + results: Array<{ + block: string; + matches: string[]; + }>; + count: number; + error?: string; +} + +/** + * Tool for searching across all memory blocks. + */ +export class SearchMemoryTool implements Tool { + name = 'search_memory'; + description = 'Search across all memory blocks (user preferences, facts, projects) for relevant information. Returns matching lines from each block.'; + + schema = { + type: 'object', + properties: { + query: { + type: 'string', + description: 'Search query to find in memory blocks' + } + }, + required: ['query'] + }; + + async execute(args: SearchMemoryArgs, _ctx?: LLMContext): Promise { + logger.info('Executing search memory', { query: args.query }); + + try { + const manager = new MemoryBlockManager(); + const searchResults = await manager.searchBlocks(args.query); + + const results = searchResults.map(r => ({ + block: r.block.label, + matches: r.matches.slice(0, 5) // Limit to 5 matches per block + })); + + logger.info('Search completed', { resultCount: results.length }); + + return { + success: true, + results, + count: results.length + }; + } catch (error: any) { + logger.error('Failed to search memory', { error: error?.message }); + return { + success: false, + results: [], + count: 0, + error: error?.message || 'Failed to search memory.' + }; + } + } +} diff --git a/front_end/panels/ai_chat/memory/UpdateMemoryTool.ts b/front_end/panels/ai_chat/memory/UpdateMemoryTool.ts new file mode 100644 index 0000000000..8b029724f5 --- /dev/null +++ b/front_end/panels/ai_chat/memory/UpdateMemoryTool.ts @@ -0,0 +1,95 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { createLogger } from '../core/Logger.js'; +import type { Tool, LLMContext } from '../tools/Tools.js'; +import { MemoryBlockManager } from './MemoryBlockManager.js'; +import type { BlockType } from './types.js'; + +const logger = createLogger('Tool:UpdateMemory'); + +export interface UpdateMemoryArgs { + blockType: BlockType; + content: string; + projectName?: string; +} + +export interface UpdateMemoryResult { + success: boolean; + message: string; + error?: string; +} + +/** + * Tool for updating memory blocks. + */ +export class UpdateMemoryTool implements Tool { + name = 'update_memory'; + description = `Update a memory block with new content. Block types: +- "user": User preferences, name, coding style (max 20000 chars) +- "facts": Recent facts extracted from conversations (max 20000 chars) +- "project": Project-specific context (max 20000 chars each, max 4 projects) + +For project blocks, you must also provide projectName.`; + + schema = { + type: 'object', + properties: { + blockType: { + type: 'string', + enum: ['user', 'facts', 'project'], + description: 'Type of memory block to update' + }, + content: { + type: 'string', + description: 'New content for the block (replaces existing content)' + }, + projectName: { + type: 'string', + description: 'Project name (required when blockType is "project")' + } + }, + required: ['blockType', 'content'] + }; + + async execute(args: UpdateMemoryArgs, _ctx?: LLMContext): Promise { + logger.info('Executing update memory', { + blockType: args.blockType, + contentLength: args.content.length, + projectName: args.projectName + }); + + try { + // Validate project name for project blocks + if (args.blockType === 'project' && !args.projectName) { + return { + success: false, + message: 'projectName is required for project blocks', + error: 'projectName is required for project blocks' + }; + } + + const manager = new MemoryBlockManager(); + await manager.updateBlock(args.blockType, args.content, args.projectName); + + const label = args.blockType === 'project' + ? `project:${args.projectName}` + : args.blockType; + + logger.info('Memory block updated', { label }); + + return { + success: true, + message: `Updated ${label} block (${args.content.length} chars)` + }; + } catch (error: any) { + logger.error('Failed to update memory block', { error: error?.message }); + return { + success: false, + message: error?.message || 'Failed to update memory block', + error: error?.message || 'Failed to update memory block' + }; + } + } +} diff --git a/front_end/panels/ai_chat/memory/__tests__/ListMemoryBlocksTool.test.ts b/front_end/panels/ai_chat/memory/__tests__/ListMemoryBlocksTool.test.ts new file mode 100644 index 0000000000..d8c833e355 --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/ListMemoryBlocksTool.test.ts @@ -0,0 +1,246 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { ListMemoryBlocksTool } from '../ListMemoryBlocksTool.js'; +import { FileStorageManager } from '../../tools/FileStorageManager.js'; +import { MemoryModule } from '../MemoryModule.js'; +import type { StoredFile, FileSummary } from '../../tools/FileStorageManager.js'; + +// Mock FileStorageManager +class MockFileStorageManager { + private files: Map = new Map(); + private currentSessionId = 'test-session'; + + getSessionId(): string { + return this.currentSessionId; + } + + setSessionId(sessionId: string): void { + this.currentSessionId = sessionId; + } + + async readFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + return this.files.get(key) || null; + } + + async createFile(fileName: string, content: string, mimeType: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const now = Date.now(); + const file: StoredFile = { + id: `id-${fileName}`, + sessionId: this.currentSessionId, + fileName, + content, + mimeType, + createdAt: now, + updatedAt: now, + size: content.length, + }; + this.files.set(key, file); + return file; + } + + async updateFile(fileName: string, content: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const existing = this.files.get(key); + if (!existing) { + throw new Error(`File "${fileName}" not found.`); + } + const updated: StoredFile = { ...existing, content, updatedAt: Date.now(), size: content.length }; + this.files.set(key, updated); + return updated; + } + + async deleteFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + this.files.delete(key); + } + + async listFiles(): Promise { + const summaries: FileSummary[] = []; + for (const [key, file] of this.files.entries()) { + if (key.startsWith(`${this.currentSessionId}:`)) { + summaries.push({ + fileName: file.fileName, + size: file.size, + mimeType: file.mimeType, + createdAt: file.createdAt, + updatedAt: file.updatedAt, + }); + } + } + return summaries; + } + + clearAllFiles(): void { + this.files.clear(); + } +} + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('ListMemoryBlocksTool', () => { + let mockFileStorageManager: MockFileStorageManager; + let originalFileStorageGetInstance: typeof FileStorageManager.getInstance; + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + let tool: ListMemoryBlocksTool; + + beforeEach(() => { + // Reset MemoryModule singleton + (MemoryModule as any).instance = null; + + // Mock localStorage + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Mock FileStorageManager.getInstance + mockFileStorageManager = new MockFileStorageManager(); + originalFileStorageGetInstance = FileStorageManager.getInstance; + FileStorageManager.getInstance = () => mockFileStorageManager as unknown as FileStorageManager; + + tool = new ListMemoryBlocksTool(); + }); + + afterEach(() => { + FileStorageManager.getInstance = originalFileStorageGetInstance; + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + mockFileStorageManager.clearAllFiles(); + }); + + describe('tool metadata', () => { + it('has correct name', () => { + assert.strictEqual(tool.name, 'list_memory_blocks'); + }); + + it('has description', () => { + assert.isString(tool.description); + assert.isTrue(tool.description.length > 0); + }); + + it('has empty required array in schema', () => { + assert.deepEqual(tool.schema.type, 'object'); + assert.deepEqual(tool.schema.required, []); + }); + }); + + describe('execute', () => { + it('returns empty blocks array when none exist', async () => { + const result = await tool.execute({}); + + assert.isTrue(result.success); + assert.isArray(result.blocks); + assert.lengthOf(result.blocks, 0); + assert.strictEqual(result.summary.totalBlocks, 0); + assert.strictEqual(result.summary.totalChars, 0); + }); + + it('returns formatted blocks with metadata', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User preferences', 'text/markdown'); + + const result = await tool.execute({}); + + assert.isTrue(result.success); + assert.lengthOf(result.blocks, 1); + + const block = result.blocks[0]; + assert.strictEqual(block.type, 'user'); + assert.strictEqual(block.label, 'user'); + assert.strictEqual(block.content, 'User preferences'); + assert.strictEqual(block.charCount, 16); + assert.strictEqual(block.charLimit, 20000); + assert.isString(block.updatedAt); + }); + + it('returns all blocks with correct types', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User data', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'Some facts', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_app.md', 'App context', 'text/markdown'); + + const result = await tool.execute({}); + + assert.isTrue(result.success); + assert.lengthOf(result.blocks, 3); + + const types = result.blocks.map(b => b.type); + assert.isTrue(types.includes('user')); + assert.isTrue(types.includes('facts')); + assert.isTrue(types.includes('project')); + }); + + it('includes summary with correct totals', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', '12345', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', '67890', 'text/markdown'); + + const result = await tool.execute({}); + + assert.strictEqual(result.summary.totalBlocks, 2); + assert.strictEqual(result.summary.totalChars, 10); + }); + + it('calculates maxChars as 120000', async () => { + const result = await tool.execute({}); + + assert.strictEqual(result.summary.maxChars, 120000); + }); + + it('returns error result on failure', async () => { + // Force an error + const originalListFiles = mockFileStorageManager.listFiles.bind(mockFileStorageManager); + mockFileStorageManager.listFiles = async () => { + throw new Error('Storage failure'); + }; + + const result = await tool.execute({}); + + assert.isFalse(result.success); + assert.lengthOf(result.blocks, 0); + assert.isString(result.error); + assert.isTrue(result.error!.includes('Storage failure')); + + // Note: error case has different maxChars in implementation (9500) + assert.strictEqual(result.summary.totalBlocks, 0); + assert.strictEqual(result.summary.totalChars, 0); + + mockFileStorageManager.listFiles = originalListFiles; + }); + + it('formats updatedAt as ISO string', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'Content', 'text/markdown'); + + const result = await tool.execute({}); + + const block = result.blocks[0]; + // Should be valid ISO date string + const date = new Date(block.updatedAt); + assert.isFalse(isNaN(date.getTime())); + assert.isTrue(block.updatedAt.includes('T')); + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/__tests__/MemoryBlockManager.test.ts b/front_end/panels/ai_chat/memory/__tests__/MemoryBlockManager.test.ts new file mode 100644 index 0000000000..66af99e59e --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/MemoryBlockManager.test.ts @@ -0,0 +1,451 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { MemoryBlockManager } from '../MemoryBlockManager.js'; +import { FileStorageManager } from '../../tools/FileStorageManager.js'; +import { MemoryModule } from '../MemoryModule.js'; +import type { StoredFile, FileSummary } from '../../tools/FileStorageManager.js'; + +// Mock FileStorageManager +class MockFileStorageManager { + private files: Map = new Map(); + private currentSessionId = 'test-session'; + + getSessionId(): string { + return this.currentSessionId; + } + + setSessionId(sessionId: string): void { + this.currentSessionId = sessionId; + } + + async readFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + return this.files.get(key) || null; + } + + async createFile(fileName: string, content: string, mimeType: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (this.files.has(key)) { + throw new Error(`File "${fileName}" already exists.`); + } + const now = Date.now(); + const file: StoredFile = { + id: `id-${fileName}`, + sessionId: this.currentSessionId, + fileName, + content, + mimeType, + createdAt: now, + updatedAt: now, + size: content.length, + }; + this.files.set(key, file); + return file; + } + + async updateFile(fileName: string, content: string, _append = false): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const existing = this.files.get(key); + if (!existing) { + throw new Error(`File "${fileName}" not found.`); + } + const updated: StoredFile = { + ...existing, + content, + updatedAt: Date.now(), + size: content.length, + }; + this.files.set(key, updated); + return updated; + } + + async deleteFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (!this.files.has(key)) { + throw new Error(`File "${fileName}" not found.`); + } + this.files.delete(key); + } + + async listFiles(): Promise { + const summaries: FileSummary[] = []; + for (const [key, file] of this.files.entries()) { + if (key.startsWith(`${this.currentSessionId}:`)) { + summaries.push({ + fileName: file.fileName, + size: file.size, + mimeType: file.mimeType, + createdAt: file.createdAt, + updatedAt: file.updatedAt, + }); + } + } + return summaries; + } + + // Test helper to clear all files + clearAllFiles(): void { + this.files.clear(); + } +} + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('MemoryBlockManager', () => { + let mockFileStorageManager: MockFileStorageManager; + let originalFileStorageGetInstance: typeof FileStorageManager.getInstance; + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + + beforeEach(() => { + // Reset MemoryModule singleton + (MemoryModule as any).instance = null; + + // Mock localStorage + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Mock FileStorageManager.getInstance + mockFileStorageManager = new MockFileStorageManager(); + originalFileStorageGetInstance = FileStorageManager.getInstance; + FileStorageManager.getInstance = () => mockFileStorageManager as unknown as FileStorageManager; + }); + + afterEach(() => { + // Restore FileStorageManager.getInstance + FileStorageManager.getInstance = originalFileStorageGetInstance; + + // Restore localStorage + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + + // Clear files + mockFileStorageManager.clearAllFiles(); + }); + + describe('getBlock', () => { + it('returns null when block does not exist', async () => { + const manager = new MemoryBlockManager(); + const block = await manager.getBlock('user'); + assert.isNull(block); + }); + + it('returns MemoryBlock when user file exists', async () => { + // Create file in mock storage with global session + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User preferences here', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const block = await manager.getBlock('user'); + + assert.isNotNull(block); + assert.strictEqual(block!.type, 'user'); + assert.strictEqual(block!.content, 'User preferences here'); + assert.strictEqual(block!.label, 'user'); + assert.strictEqual(block!.filename, 'memory_user.md'); + assert.strictEqual(block!.charLimit, 20000); + }); + + it('returns project block with projectName', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_project_my_app.md', 'Project context', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const block = await manager.getBlock('project', 'my_app'); + + assert.isNotNull(block); + assert.strictEqual(block!.type, 'project'); + assert.strictEqual(block!.content, 'Project context'); + assert.strictEqual(block!.label, 'project:my_app'); + }); + }); + + describe('updateBlock', () => { + it('creates new block when it does not exist', async () => { + const manager = new MemoryBlockManager(); + await manager.updateBlock('user', 'New user content'); + + mockFileStorageManager.setSessionId('__global_memory__'); + const file = await mockFileStorageManager.readFile('memory_user.md'); + assert.isNotNull(file); + assert.strictEqual(file!.content, 'New user content'); + }); + + it('updates existing block', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_facts.md', 'Old facts', 'text/markdown'); + + const manager = new MemoryBlockManager(); + await manager.updateBlock('facts', 'Updated facts'); + + const file = await mockFileStorageManager.readFile('memory_facts.md'); + assert.strictEqual(file!.content, 'Updated facts'); + }); + + it('throws when content exceeds limit', async () => { + const manager = new MemoryBlockManager(); + const oversizedContent = 'x'.repeat(20001); + + try { + await manager.updateBlock('user', oversizedContent); + assert.fail('Expected error to be thrown'); + } catch (error: any) { + assert.isTrue(error.message.includes('exceeds')); + assert.isTrue(error.message.includes('20000')); + } + }); + + it('throws when max project blocks reached', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + // Create 4 project blocks (the max) + await mockFileStorageManager.createFile('memory_project_one.md', 'Project 1', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_two.md', 'Project 2', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_three.md', 'Project 3', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_four.md', 'Project 4', 'text/markdown'); + + const manager = new MemoryBlockManager(); + + try { + await manager.updateBlock('project', 'Project 5 content', 'five'); + assert.fail('Expected error to be thrown'); + } catch (error: any) { + assert.isTrue(error.message.includes('Max')); + assert.isTrue(error.message.includes('4')); + } + }); + + it('allows updating existing project block when at max', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_project_one.md', 'Project 1', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_two.md', 'Project 2', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_three.md', 'Project 3', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_four.md', 'Project 4', 'text/markdown'); + + const manager = new MemoryBlockManager(); + // Should succeed since we're updating existing block + await manager.updateBlock('project', 'Updated Project 1', 'one'); + + const file = await mockFileStorageManager.readFile('memory_project_one.md'); + assert.strictEqual(file!.content, 'Updated Project 1'); + }); + }); + + describe('deleteBlock', () => { + it('removes existing block', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User data', 'text/markdown'); + + const manager = new MemoryBlockManager(); + await manager.deleteBlock('user'); + + const file = await mockFileStorageManager.readFile('memory_user.md'); + assert.isNull(file); + }); + + it('silently handles non-existent block', async () => { + const manager = new MemoryBlockManager(); + // Should not throw + await manager.deleteBlock('facts'); + }); + }); + + describe('getAllBlocks', () => { + it('returns empty array when no blocks exist', async () => { + const manager = new MemoryBlockManager(); + const blocks = await manager.getAllBlocks(); + assert.isArray(blocks); + assert.lengthOf(blocks, 0); + }); + + it('returns all memory blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User content', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'Facts content', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const blocks = await manager.getAllBlocks(); + + assert.lengthOf(blocks, 2); + const types = blocks.map(b => b.type); + assert.isTrue(types.includes('user')); + assert.isTrue(types.includes('facts')); + }); + + it('filters non-memory files', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User content', 'text/markdown'); + await mockFileStorageManager.createFile('other_file.txt', 'Other content', 'text/plain'); + + const manager = new MemoryBlockManager(); + const blocks = await manager.getAllBlocks(); + + assert.lengthOf(blocks, 1); + assert.strictEqual(blocks[0].type, 'user'); + }); + }); + + describe('listProjectBlocks', () => { + it('returns only project blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User content', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'Facts content', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_app.md', 'App content', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_web.md', 'Web content', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const projectBlocks = await manager.listProjectBlocks(); + + assert.lengthOf(projectBlocks, 2); + assert.isTrue(projectBlocks.every(b => b.type === 'project')); + }); + + it('returns empty array when no project blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User content', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const projectBlocks = await manager.listProjectBlocks(); + + assert.lengthOf(projectBlocks, 0); + }); + }); + + describe('searchBlocks', () => { + it('returns empty when no matches', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User preferences', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const results = await manager.searchBlocks('nonexistent'); + + assert.isArray(results); + assert.lengthOf(results, 0); + }); + + it('finds case-insensitive matches', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User likes TypeScript', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const results = await manager.searchBlocks('typescript'); + + assert.lengthOf(results, 1); + assert.strictEqual(results[0].block.type, 'user'); + assert.lengthOf(results[0].matches, 1); + assert.isTrue(results[0].matches[0].includes('TypeScript')); + }); + + it('returns matches from multiple blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'Prefers React framework', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'Uses React and Vue', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const results = await manager.searchBlocks('react'); + + assert.lengthOf(results, 2); + }); + + it('returns matching lines only', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'Line 1: hello\nLine 2: world\nLine 3: hello again', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const results = await manager.searchBlocks('hello'); + + assert.lengthOf(results, 1); + assert.lengthOf(results[0].matches, 2); + }); + }); + + describe('compileMemoryContext', () => { + it('returns empty string when no blocks', async () => { + const manager = new MemoryBlockManager(); + const context = await manager.compileMemoryContext(); + assert.strictEqual(context, ''); + }); + + it('returns XML with all blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User data', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'Some facts', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const context = await manager.compileMemoryContext(); + + assert.isTrue(context.startsWith('')); + assert.isTrue(context.endsWith('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('User data')); + assert.isTrue(context.includes('Some facts')); + }); + + it('skips empty blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User data', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', ' ', 'text/markdown'); + + const manager = new MemoryBlockManager(); + const context = await manager.compileMemoryContext(); + + assert.isTrue(context.includes('')); + assert.isFalse(context.includes('')); + }); + }); + + describe('session management', () => { + it('uses global session for all operations', async () => { + // Start with a different session + mockFileStorageManager.setSessionId('other-session'); + await mockFileStorageManager.createFile('memory_user.md', 'Other session data', 'text/markdown'); + + // Create block via manager (should use global session) + const manager = new MemoryBlockManager(); + await manager.updateBlock('user', 'Global session data'); + + // Verify data is in global session + mockFileStorageManager.setSessionId('__global_memory__'); + const globalFile = await mockFileStorageManager.readFile('memory_user.md'); + assert.strictEqual(globalFile!.content, 'Global session data'); + + // Verify other session data is unchanged + mockFileStorageManager.setSessionId('other-session'); + const otherFile = await mockFileStorageManager.readFile('memory_user.md'); + assert.strictEqual(otherFile!.content, 'Other session data'); + }); + + it('restores original session after operation', async () => { + mockFileStorageManager.setSessionId('original-session'); + + const manager = new MemoryBlockManager(); + await manager.updateBlock('user', 'Some content'); + + // Session should be restored + assert.strictEqual(mockFileStorageManager.getSessionId(), 'original-session'); + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/__tests__/MemoryIntegration.test.ts b/front_end/panels/ai_chat/memory/__tests__/MemoryIntegration.test.ts new file mode 100644 index 0000000000..87f4247f26 --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/MemoryIntegration.test.ts @@ -0,0 +1,442 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Integration tests for the memory module. + * These tests verify that components work together correctly. + */ + +import { MemoryModule } from '../MemoryModule.js'; +import { MemoryBlockManager } from '../MemoryBlockManager.js'; +import { SearchMemoryTool } from '../SearchMemoryTool.js'; +import { UpdateMemoryTool } from '../UpdateMemoryTool.js'; +import { ListMemoryBlocksTool } from '../ListMemoryBlocksTool.js'; +import { FileStorageManager } from '../../tools/FileStorageManager.js'; +import type { StoredFile, FileSummary } from '../../tools/FileStorageManager.js'; + +// Mock FileStorageManager with full functionality +class MockFileStorageManager { + private files: Map = new Map(); + private currentSessionId = 'test-session'; + + getSessionId(): string { + return this.currentSessionId; + } + + setSessionId(sessionId: string): void { + this.currentSessionId = sessionId; + } + + async readFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + return this.files.get(key) || null; + } + + async createFile(fileName: string, content: string, mimeType: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (this.files.has(key)) { + throw new Error(`File "${fileName}" already exists.`); + } + const now = Date.now(); + const file: StoredFile = { + id: `id-${Math.random().toString(36).substr(2, 9)}`, + sessionId: this.currentSessionId, + fileName, + content, + mimeType, + createdAt: now, + updatedAt: now, + size: content.length, + }; + this.files.set(key, file); + return file; + } + + async updateFile(fileName: string, content: string, _append = false): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const existing = this.files.get(key); + if (!existing) { + throw new Error(`File "${fileName}" not found.`); + } + const updated: StoredFile = { + ...existing, + content, + updatedAt: Date.now(), + size: content.length, + }; + this.files.set(key, updated); + return updated; + } + + async deleteFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (!this.files.has(key)) { + throw new Error(`File "${fileName}" not found.`); + } + this.files.delete(key); + } + + async listFiles(): Promise { + const summaries: FileSummary[] = []; + for (const [key, file] of this.files.entries()) { + if (key.startsWith(`${this.currentSessionId}:`)) { + summaries.push({ + fileName: file.fileName, + size: file.size, + mimeType: file.mimeType, + createdAt: file.createdAt, + updatedAt: file.updatedAt, + }); + } + } + return summaries; + } + + clearAllFiles(): void { + this.files.clear(); + } +} + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('Memory Module Integration', () => { + let mockFileStorageManager: MockFileStorageManager; + let originalFileStorageGetInstance: typeof FileStorageManager.getInstance; + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + + beforeEach(() => { + // Reset MemoryModule singleton + (MemoryModule as any).instance = null; + + // Mock localStorage + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Mock FileStorageManager.getInstance + mockFileStorageManager = new MockFileStorageManager(); + originalFileStorageGetInstance = FileStorageManager.getInstance; + FileStorageManager.getInstance = () => mockFileStorageManager as unknown as FileStorageManager; + }); + + afterEach(() => { + FileStorageManager.getInstance = originalFileStorageGetInstance; + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + mockFileStorageManager.clearAllFiles(); + }); + + describe('End-to-End Workflows', () => { + it('creates user block and searches for content', async () => { + const updateTool = new UpdateMemoryTool(); + const searchTool = new SearchMemoryTool(); + + // Create block + const updateResult = await updateTool.execute({ + blockType: 'user', + content: 'Tyson prefers TypeScript and React', + }); + assert.isTrue(updateResult.success); + + // Search for content + const searchResult = await searchTool.execute({ query: 'TypeScript' }); + assert.isTrue(searchResult.success); + assert.strictEqual(searchResult.count, 1); + assert.isTrue(searchResult.results[0].matches[0].includes('TypeScript')); + }); + + it('creates multiple project blocks and lists them', async () => { + const updateTool = new UpdateMemoryTool(); + const listTool = new ListMemoryBlocksTool(); + + // Create project blocks + await updateTool.execute({ + blockType: 'project', + content: 'Browser extension project', + projectName: 'extension', + }); + await updateTool.execute({ + blockType: 'project', + content: 'Mobile app project', + projectName: 'mobile', + }); + + // List all blocks + const listResult = await listTool.execute({}); + assert.isTrue(listResult.success); + assert.strictEqual(listResult.summary.totalBlocks, 2); + + const labels = listResult.blocks.map(b => b.label); + assert.isTrue(labels.includes('project:extension')); + assert.isTrue(labels.includes('project:mobile')); + }); + + it('updates existing block and verifies content replaced', async () => { + const updateTool = new UpdateMemoryTool(); + const listTool = new ListMemoryBlocksTool(); + + // Create initial block + await updateTool.execute({ + blockType: 'facts', + content: 'Original facts here', + }); + + // Update block + await updateTool.execute({ + blockType: 'facts', + content: 'Updated facts content', + }); + + // Verify updated content + const listResult = await listTool.execute({}); + const factsBlock = listResult.blocks.find(b => b.type === 'facts'); + assert.strictEqual(factsBlock!.content, 'Updated facts content'); + }); + + it('enforces max project blocks limit', async () => { + const updateTool = new UpdateMemoryTool(); + + // Create 4 project blocks (the max) + for (let i = 1; i <= 4; i++) { + const result = await updateTool.execute({ + blockType: 'project', + content: `Project ${i} content`, + projectName: `project${i}`, + }); + assert.isTrue(result.success); + } + + // Try to create 5th project - should fail + const fifthResult = await updateTool.execute({ + blockType: 'project', + content: 'Project 5 content', + projectName: 'project5', + }); + assert.isFalse(fifthResult.success); + assert.isTrue(fifthResult.error!.includes('Max')); + }); + + it('enforces character limit', async () => { + const updateTool = new UpdateMemoryTool(); + const oversizedContent = 'x'.repeat(20001); + + const result = await updateTool.execute({ + blockType: 'user', + content: oversizedContent, + }); + + assert.isFalse(result.success); + assert.isTrue(result.error!.includes('exceeds')); + }); + + it('compiles memory context with multiple block types', async () => { + const updateTool = new UpdateMemoryTool(); + const manager = new MemoryBlockManager(); + + // Create blocks of different types + await updateTool.execute({ + blockType: 'user', + content: 'User: Tyson', + }); + await updateTool.execute({ + blockType: 'facts', + content: 'Fact: Uses VSCode', + }); + await updateTool.execute({ + blockType: 'project', + content: 'Project: Browser', + projectName: 'browser', + }); + + // Compile context + const context = await manager.compileMemoryContext(); + + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('User: Tyson')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('Fact: Uses VSCode')); + assert.isTrue(context.includes('')); + assert.isTrue(context.includes('Project: Browser')); + }); + }); + + describe('Tool Interactions', () => { + it('UpdateMemoryTool creates block that ListMemoryBlocksTool shows', async () => { + const updateTool = new UpdateMemoryTool(); + const listTool = new ListMemoryBlocksTool(); + + // Initially no blocks + const initialList = await listTool.execute({}); + assert.lengthOf(initialList.blocks, 0); + + // Create block + await updateTool.execute({ + blockType: 'user', + content: 'New user data', + }); + + // Now shows in list + const afterList = await listTool.execute({}); + assert.lengthOf(afterList.blocks, 1); + assert.strictEqual(afterList.blocks[0].content, 'New user data'); + }); + + it('UpdateMemoryTool creates block that SearchMemoryTool finds', async () => { + const updateTool = new UpdateMemoryTool(); + const searchTool = new SearchMemoryTool(); + + // Create block with specific content + await updateTool.execute({ + blockType: 'facts', + content: 'The quick brown fox jumps over the lazy dog', + }); + + // Search finds it + const searchResult = await searchTool.execute({ query: 'brown fox' }); + assert.isTrue(searchResult.success); + assert.strictEqual(searchResult.count, 1); + }); + + it('multiple tools operating on same block sequentially', async () => { + const updateTool = new UpdateMemoryTool(); + const searchTool = new SearchMemoryTool(); + const listTool = new ListMemoryBlocksTool(); + + // Create + await updateTool.execute({ + blockType: 'user', + content: 'Version 1', + }); + + // Search + let searchResult = await searchTool.execute({ query: 'Version 1' }); + assert.strictEqual(searchResult.count, 1); + + // Update + await updateTool.execute({ + blockType: 'user', + content: 'Version 2', + }); + + // Old search fails + searchResult = await searchTool.execute({ query: 'Version 1' }); + assert.strictEqual(searchResult.count, 0); + + // New search succeeds + searchResult = await searchTool.execute({ query: 'Version 2' }); + assert.strictEqual(searchResult.count, 1); + + // List shows updated content + const listResult = await listTool.execute({}); + assert.strictEqual(listResult.blocks[0].content, 'Version 2'); + }); + }); + + describe('Session Isolation', () => { + it('global memory session is isolated from other sessions', async () => { + const manager = new MemoryBlockManager(); + + // Create block via manager (uses global session) + await manager.updateBlock('user', 'Global memory data'); + + // Switch to different session and create file directly + mockFileStorageManager.setSessionId('conversation-123'); + await mockFileStorageManager.createFile('memory_user.md', 'Conversation data', 'text/markdown'); + + // Manager should only see global memory + const blocks = await manager.getAllBlocks(); + assert.lengthOf(blocks, 1); + assert.strictEqual(blocks[0].content, 'Global memory data'); + }); + }); + + describe('MemoryModule Integration', () => { + it('getInstructions returns content when enabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(true); + + const instructions = module.getInstructions(); + assert.isTrue(instructions.includes('')); + assert.isTrue(instructions.includes('persistent memory system')); + }); + + it('getInstructions returns empty when disabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(false); + + const instructions = module.getInstructions(); + assert.strictEqual(instructions, ''); + }); + + it('block limits are correctly enforced', async () => { + const updateTool = new UpdateMemoryTool(); + const module = MemoryModule.getInstance(); + + // Get the actual limit + const limit = module.getBlockLimit('user'); + assert.strictEqual(limit, 20000); + + // Content at limit should succeed + const atLimitResult = await updateTool.execute({ + blockType: 'user', + content: 'x'.repeat(20000), + }); + assert.isTrue(atLimitResult.success); + + // Content over limit should fail + const overLimitResult = await updateTool.execute({ + blockType: 'facts', + content: 'x'.repeat(20001), + }); + assert.isFalse(overLimitResult.success); + }); + }); + + describe('Error Propagation', () => { + it('MemoryBlockManager errors propagate through tools', async () => { + const updateTool = new UpdateMemoryTool(); + + // Create 4 project blocks + for (let i = 0; i < 4; i++) { + await updateTool.execute({ + blockType: 'project', + content: `Project ${i}`, + projectName: `proj${i}`, + }); + } + + // 5th should fail with proper error message + const result = await updateTool.execute({ + blockType: 'project', + content: 'Too many', + projectName: 'toomany', + }); + + assert.isFalse(result.success); + assert.isString(result.error); + assert.isTrue(result.message.includes('Max') || result.message.includes('4')); + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/__tests__/MemoryModule.test.ts b/front_end/panels/ai_chat/memory/__tests__/MemoryModule.test.ts new file mode 100644 index 0000000000..fa97bf4f3f --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/MemoryModule.test.ts @@ -0,0 +1,191 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { MemoryModule } from '../MemoryModule.js'; + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('MemoryModule', () => { + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + + beforeEach(() => { + // Save original localStorage and replace with mock + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Reset singleton instance by accessing private static property + (MemoryModule as any).instance = null; + }); + + afterEach(() => { + // Restore original localStorage + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + }); + + describe('getInstance', () => { + it('returns a MemoryModule instance', () => { + const instance = MemoryModule.getInstance(); + assert.isNotNull(instance); + assert.isFunction(instance.isEnabled); + assert.isFunction(instance.getConfig); + }); + + it('returns the same instance on multiple calls', () => { + const instance1 = MemoryModule.getInstance(); + const instance2 = MemoryModule.getInstance(); + assert.strictEqual(instance1, instance2); + }); + }); + + describe('isEnabled', () => { + it('returns true by default when localStorage has no value', () => { + const module = MemoryModule.getInstance(); + assert.isTrue(module.isEnabled()); + }); + + it('returns true when localStorage value is not "false"', () => { + mockLocalStorage.setItem('ai_chat_memory_enabled', 'true'); + const module = MemoryModule.getInstance(); + assert.isTrue(module.isEnabled()); + }); + + it('returns false when localStorage has "false"', () => { + mockLocalStorage.setItem('ai_chat_memory_enabled', 'false'); + const module = MemoryModule.getInstance(); + assert.isFalse(module.isEnabled()); + }); + }); + + describe('setEnabled', () => { + it('persists true to localStorage', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(true); + assert.strictEqual(mockLocalStorage.getItem('ai_chat_memory_enabled'), 'true'); + }); + + it('persists false to localStorage', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(false); + assert.strictEqual(mockLocalStorage.getItem('ai_chat_memory_enabled'), 'false'); + }); + + it('updates isEnabled result after setEnabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(false); + assert.isFalse(module.isEnabled()); + module.setEnabled(true); + assert.isTrue(module.isEnabled()); + }); + }); + + describe('getInstructions', () => { + it('returns memory instructions when enabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(true); + const instructions = module.getInstructions(); + assert.isString(instructions); + assert.isTrue(instructions.length > 0); + assert.isTrue(instructions.includes('')); + assert.isTrue(instructions.includes('persistent memory system')); + }); + + it('returns empty string when disabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(false); + const instructions = module.getInstructions(); + assert.strictEqual(instructions, ''); + }); + }); + + describe('shouldIncludeMemoryTool', () => { + it('returns true when enabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(true); + assert.isTrue(module.shouldIncludeMemoryTool()); + }); + + it('returns false when disabled', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(false); + assert.isFalse(module.shouldIncludeMemoryTool()); + }); + + it('matches isEnabled behavior', () => { + const module = MemoryModule.getInstance(); + module.setEnabled(true); + assert.strictEqual(module.shouldIncludeMemoryTool(), module.isEnabled()); + module.setEnabled(false); + assert.strictEqual(module.shouldIncludeMemoryTool(), module.isEnabled()); + }); + }); + + describe('getBlockLimit', () => { + it('returns 20000 for user block type', () => { + const module = MemoryModule.getInstance(); + assert.strictEqual(module.getBlockLimit('user'), 20000); + }); + + it('returns 20000 for facts block type', () => { + const module = MemoryModule.getInstance(); + assert.strictEqual(module.getBlockLimit('facts'), 20000); + }); + + it('returns 20000 for project block type', () => { + const module = MemoryModule.getInstance(); + assert.strictEqual(module.getBlockLimit('project'), 20000); + }); + }); + + describe('getMaxProjectBlocks', () => { + it('returns 4', () => { + const module = MemoryModule.getInstance(); + assert.strictEqual(module.getMaxProjectBlocks(), 4); + }); + }); + + describe('getSessionId', () => { + it('returns __global_memory__', () => { + const module = MemoryModule.getInstance(); + assert.strictEqual(module.getSessionId(), '__global_memory__'); + }); + }); + + describe('getConfig', () => { + it('returns a complete config object', () => { + const module = MemoryModule.getInstance(); + const config = module.getConfig(); + + assert.isObject(config); + assert.deepEqual(config.blockLimits, { + user: 20000, + facts: 20000, + project: 20000, + }); + assert.strictEqual(config.maxProjectBlocks, 4); + assert.strictEqual(config.sessionId, '__global_memory__'); + assert.strictEqual(config.enabledKey, 'ai_chat_memory_enabled'); + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/__tests__/SearchMemoryTool.test.ts b/front_end/panels/ai_chat/memory/__tests__/SearchMemoryTool.test.ts new file mode 100644 index 0000000000..164bca91b7 --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/SearchMemoryTool.test.ts @@ -0,0 +1,223 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { SearchMemoryTool } from '../SearchMemoryTool.js'; +import { MemoryBlockManager } from '../MemoryBlockManager.js'; +import { FileStorageManager } from '../../tools/FileStorageManager.js'; +import { MemoryModule } from '../MemoryModule.js'; +import type { StoredFile, FileSummary } from '../../tools/FileStorageManager.js'; + +// Mock FileStorageManager +class MockFileStorageManager { + private files: Map = new Map(); + private currentSessionId = 'test-session'; + + getSessionId(): string { + return this.currentSessionId; + } + + setSessionId(sessionId: string): void { + this.currentSessionId = sessionId; + } + + async readFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + return this.files.get(key) || null; + } + + async createFile(fileName: string, content: string, mimeType: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const now = Date.now(); + const file: StoredFile = { + id: `id-${fileName}`, + sessionId: this.currentSessionId, + fileName, + content, + mimeType, + createdAt: now, + updatedAt: now, + size: content.length, + }; + this.files.set(key, file); + return file; + } + + async updateFile(fileName: string, content: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const existing = this.files.get(key); + if (!existing) { + throw new Error(`File "${fileName}" not found.`); + } + const updated: StoredFile = { ...existing, content, updatedAt: Date.now(), size: content.length }; + this.files.set(key, updated); + return updated; + } + + async deleteFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + this.files.delete(key); + } + + async listFiles(): Promise { + const summaries: FileSummary[] = []; + for (const [key, file] of this.files.entries()) { + if (key.startsWith(`${this.currentSessionId}:`)) { + summaries.push({ + fileName: file.fileName, + size: file.size, + mimeType: file.mimeType, + createdAt: file.createdAt, + updatedAt: file.updatedAt, + }); + } + } + return summaries; + } + + clearAllFiles(): void { + this.files.clear(); + } +} + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('SearchMemoryTool', () => { + let mockFileStorageManager: MockFileStorageManager; + let originalFileStorageGetInstance: typeof FileStorageManager.getInstance; + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + let tool: SearchMemoryTool; + + beforeEach(() => { + // Reset MemoryModule singleton + (MemoryModule as any).instance = null; + + // Mock localStorage + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Mock FileStorageManager.getInstance + mockFileStorageManager = new MockFileStorageManager(); + originalFileStorageGetInstance = FileStorageManager.getInstance; + FileStorageManager.getInstance = () => mockFileStorageManager as unknown as FileStorageManager; + + tool = new SearchMemoryTool(); + }); + + afterEach(() => { + FileStorageManager.getInstance = originalFileStorageGetInstance; + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + mockFileStorageManager.clearAllFiles(); + }); + + describe('tool metadata', () => { + it('has correct name', () => { + assert.strictEqual(tool.name, 'search_memory'); + }); + + it('has description', () => { + assert.isString(tool.description); + assert.isTrue(tool.description.length > 0); + }); + + it('has correct schema with required query field', () => { + assert.deepEqual(tool.schema.type, 'object'); + assert.isObject(tool.schema.properties); + assert.isObject((tool.schema.properties as any).query); + assert.deepEqual(tool.schema.required, ['query']); + }); + }); + + describe('execute', () => { + it('returns success with matching results', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User likes TypeScript', 'text/markdown'); + + const result = await tool.execute({ query: 'typescript' }); + + assert.isTrue(result.success); + assert.strictEqual(result.count, 1); + assert.lengthOf(result.results, 1); + assert.strictEqual(result.results[0].block, 'user'); + assert.isTrue(result.results[0].matches[0].includes('TypeScript')); + }); + + it('returns empty results when no matches', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'User preferences', 'text/markdown'); + + const result = await tool.execute({ query: 'nonexistent' }); + + assert.isTrue(result.success); + assert.strictEqual(result.count, 0); + assert.lengthOf(result.results, 0); + }); + + it('limits to 5 matches per block', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + // Create content with more than 5 matching lines + const content = Array(10).fill('match line').join('\n'); + await mockFileStorageManager.createFile('memory_user.md', content, 'text/markdown'); + + const result = await tool.execute({ query: 'match' }); + + assert.isTrue(result.success); + assert.strictEqual(result.count, 1); + assert.lengthOf(result.results[0].matches, 5); + }); + + it('returns results from multiple blocks', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'Uses React', 'text/markdown'); + await mockFileStorageManager.createFile('memory_facts.md', 'React is popular', 'text/markdown'); + + const result = await tool.execute({ query: 'react' }); + + assert.isTrue(result.success); + assert.strictEqual(result.count, 2); + const blocks = result.results.map(r => r.block); + assert.isTrue(blocks.includes('user')); + assert.isTrue(blocks.includes('facts')); + }); + + it('returns error result on failure', async () => { + // Force an error by making FileStorageManager throw + const originalListFiles = mockFileStorageManager.listFiles.bind(mockFileStorageManager); + mockFileStorageManager.listFiles = async () => { + throw new Error('Database error'); + }; + + const result = await tool.execute({ query: 'test' }); + + assert.isFalse(result.success); + assert.strictEqual(result.count, 0); + assert.lengthOf(result.results, 0); + assert.isString(result.error); + assert.isTrue(result.error!.includes('Database error')); + + // Restore + mockFileStorageManager.listFiles = originalListFiles; + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/__tests__/UpdateMemoryTool.test.ts b/front_end/panels/ai_chat/memory/__tests__/UpdateMemoryTool.test.ts new file mode 100644 index 0000000000..14ab8ca574 --- /dev/null +++ b/front_end/panels/ai_chat/memory/__tests__/UpdateMemoryTool.test.ts @@ -0,0 +1,270 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { UpdateMemoryTool } from '../UpdateMemoryTool.js'; +import { FileStorageManager } from '../../tools/FileStorageManager.js'; +import { MemoryModule } from '../MemoryModule.js'; +import type { StoredFile, FileSummary } from '../../tools/FileStorageManager.js'; + +// Mock FileStorageManager +class MockFileStorageManager { + private files: Map = new Map(); + private currentSessionId = 'test-session'; + + getSessionId(): string { + return this.currentSessionId; + } + + setSessionId(sessionId: string): void { + this.currentSessionId = sessionId; + } + + async readFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + return this.files.get(key) || null; + } + + async createFile(fileName: string, content: string, mimeType: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (this.files.has(key)) { + throw new Error(`File "${fileName}" already exists.`); + } + const now = Date.now(); + const file: StoredFile = { + id: `id-${fileName}`, + sessionId: this.currentSessionId, + fileName, + content, + mimeType, + createdAt: now, + updatedAt: now, + size: content.length, + }; + this.files.set(key, file); + return file; + } + + async updateFile(fileName: string, content: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + const existing = this.files.get(key); + if (!existing) { + throw new Error(`File "${fileName}" not found.`); + } + const updated: StoredFile = { ...existing, content, updatedAt: Date.now(), size: content.length }; + this.files.set(key, updated); + return updated; + } + + async deleteFile(fileName: string): Promise { + const key = `${this.currentSessionId}:${fileName}`; + if (!this.files.has(key)) { + throw new Error(`File "${fileName}" not found.`); + } + this.files.delete(key); + } + + async listFiles(): Promise { + const summaries: FileSummary[] = []; + for (const [key, file] of this.files.entries()) { + if (key.startsWith(`${this.currentSessionId}:`)) { + summaries.push({ + fileName: file.fileName, + size: file.size, + mimeType: file.mimeType, + createdAt: file.createdAt, + updatedAt: file.updatedAt, + }); + } + } + return summaries; + } + + clearAllFiles(): void { + this.files.clear(); + } +} + +// Mock localStorage +const createLocalStorageMock = () => { + const store: Record = {}; + return { + getItem: (key: string): string | null => store[key] ?? null, + setItem: (key: string, value: string): void => { store[key] = value; }, + removeItem: (key: string): void => { delete store[key]; }, + clear: (): void => { Object.keys(store).forEach(k => delete store[k]); }, + get length(): number { return Object.keys(store).length; }, + key: (index: number): string | null => Object.keys(store)[index] ?? null, + }; +}; + +describe('UpdateMemoryTool', () => { + let mockFileStorageManager: MockFileStorageManager; + let originalFileStorageGetInstance: typeof FileStorageManager.getInstance; + let originalLocalStorage: Storage; + let mockLocalStorage: ReturnType; + let tool: UpdateMemoryTool; + + beforeEach(() => { + // Reset MemoryModule singleton + (MemoryModule as any).instance = null; + + // Mock localStorage + originalLocalStorage = globalThis.localStorage; + mockLocalStorage = createLocalStorageMock(); + Object.defineProperty(globalThis, 'localStorage', { + value: mockLocalStorage, + writable: true, + configurable: true, + }); + + // Mock FileStorageManager.getInstance + mockFileStorageManager = new MockFileStorageManager(); + originalFileStorageGetInstance = FileStorageManager.getInstance; + FileStorageManager.getInstance = () => mockFileStorageManager as unknown as FileStorageManager; + + tool = new UpdateMemoryTool(); + }); + + afterEach(() => { + FileStorageManager.getInstance = originalFileStorageGetInstance; + Object.defineProperty(globalThis, 'localStorage', { + value: originalLocalStorage, + writable: true, + configurable: true, + }); + mockFileStorageManager.clearAllFiles(); + }); + + describe('tool metadata', () => { + it('has correct name', () => { + assert.strictEqual(tool.name, 'update_memory'); + }); + + it('has description', () => { + assert.isString(tool.description); + assert.isTrue(tool.description.length > 0); + assert.isTrue(tool.description.includes('user')); + assert.isTrue(tool.description.includes('facts')); + assert.isTrue(tool.description.includes('project')); + }); + + it('has correct schema', () => { + assert.deepEqual(tool.schema.type, 'object'); + assert.isObject(tool.schema.properties); + assert.isObject((tool.schema.properties as any).blockType); + assert.isObject((tool.schema.properties as any).content); + assert.isObject((tool.schema.properties as any).projectName); + assert.deepEqual(tool.schema.required, ['blockType', 'content']); + }); + + it('schema has blockType enum', () => { + const blockType = (tool.schema.properties as any).blockType; + assert.deepEqual(blockType.enum, ['user', 'facts', 'project']); + }); + }); + + describe('execute', () => { + it('returns error when project block missing projectName', async () => { + const result = await tool.execute({ + blockType: 'project', + content: 'Project content', + }); + + assert.isFalse(result.success); + assert.isTrue(result.message.includes('projectName')); + assert.isTrue(result.error!.includes('projectName')); + }); + + it('updates user block successfully', async () => { + const result = await tool.execute({ + blockType: 'user', + content: 'User preferences here', + }); + + assert.isTrue(result.success); + assert.isTrue(result.message.includes('user')); + assert.isTrue(result.message.includes('22 chars')); + + // Verify file was created + mockFileStorageManager.setSessionId('__global_memory__'); + const file = await mockFileStorageManager.readFile('memory_user.md'); + assert.strictEqual(file!.content, 'User preferences here'); + }); + + it('updates facts block successfully', async () => { + const result = await tool.execute({ + blockType: 'facts', + content: 'Some important facts', + }); + + assert.isTrue(result.success); + assert.isTrue(result.message.includes('facts')); + + mockFileStorageManager.setSessionId('__global_memory__'); + const file = await mockFileStorageManager.readFile('memory_facts.md'); + assert.strictEqual(file!.content, 'Some important facts'); + }); + + it('updates project block with projectName', async () => { + const result = await tool.execute({ + blockType: 'project', + content: 'My app project context', + projectName: 'my-app', + }); + + assert.isTrue(result.success); + assert.isTrue(result.message.includes('project:my-app')); + + mockFileStorageManager.setSessionId('__global_memory__'); + const file = await mockFileStorageManager.readFile('memory_project_my_app.md'); + assert.strictEqual(file!.content, 'My app project context'); + }); + + it('returns error when content exceeds limit', async () => { + const oversizedContent = 'x'.repeat(20001); + + const result = await tool.execute({ + blockType: 'user', + content: oversizedContent, + }); + + assert.isFalse(result.success); + assert.isTrue(result.message.includes('exceeds')); + assert.isString(result.error); + }); + + it('returns error when max project blocks reached', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + // Create 4 project blocks + await mockFileStorageManager.createFile('memory_project_one.md', 'Project 1', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_two.md', 'Project 2', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_three.md', 'Project 3', 'text/markdown'); + await mockFileStorageManager.createFile('memory_project_four.md', 'Project 4', 'text/markdown'); + + const result = await tool.execute({ + blockType: 'project', + content: 'Fifth project', + projectName: 'five', + }); + + assert.isFalse(result.success); + assert.isTrue(result.message.includes('Max')); + }); + + it('updates existing block rather than creating new', async () => { + mockFileStorageManager.setSessionId('__global_memory__'); + await mockFileStorageManager.createFile('memory_user.md', 'Old content', 'text/markdown'); + + const result = await tool.execute({ + blockType: 'user', + content: 'New content', + }); + + assert.isTrue(result.success); + + const file = await mockFileStorageManager.readFile('memory_user.md'); + assert.strictEqual(file!.content, 'New content'); + }); + }); +}); diff --git a/front_end/panels/ai_chat/memory/index.ts b/front_end/panels/ai_chat/memory/index.ts new file mode 100644 index 0000000000..e275cd312e --- /dev/null +++ b/front_end/panels/ai_chat/memory/index.ts @@ -0,0 +1,28 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Memory Module - Public API + * + * This module provides a consolidated memory system for the AI Chat panel. + * All memory-related functionality should be accessed through this index. + */ + +// Core exports +export { MemoryModule } from './MemoryModule.js'; +export { MemoryBlockManager } from './MemoryBlockManager.js'; +export { createMemoryAgentConfig } from './MemoryAgentConfig.js'; + +// Tool exports +export { SearchMemoryTool } from './SearchMemoryTool.js'; +export { UpdateMemoryTool } from './UpdateMemoryTool.js'; +export { ListMemoryBlocksTool } from './ListMemoryBlocksTool.js'; + +// Type exports +export type { + BlockType, + MemoryBlock, + MemorySearchResult, + MemoryConfig, +} from './types.js'; diff --git a/front_end/panels/ai_chat/memory/types.ts b/front_end/panels/ai_chat/memory/types.ts new file mode 100644 index 0000000000..337c2bdce1 --- /dev/null +++ b/front_end/panels/ai_chat/memory/types.ts @@ -0,0 +1,87 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Memory System Types + * + * Shared type definitions for the memory module. + */ + +/** + * Block types for memory storage. + * - user: User preferences, name, coding style + * - facts: Recent extracted facts from conversations + * - project: Project-specific context (max 4 blocks) + */ +export type BlockType = 'user' | 'facts' | 'project'; + +/** + * A memory block stored in the system. + */ +export interface MemoryBlock { + filename: string; + type: BlockType; + label: string; + description: string; + content: string; + charLimit: number; + updatedAt: number; +} + +/** + * Result from searching memory blocks. + */ +export interface MemorySearchResult { + block: MemoryBlock; + matches: string[]; +} + +/** + * Configuration constants for the memory system. + */ +export interface MemoryConfig { + /** Character limit per block type */ + blockLimits: { + user: number; + facts: number; + project: number; + }; + /** Maximum number of project blocks allowed */ + maxProjectBlocks: number; + /** Session ID used for global memory storage */ + sessionId: string; + /** LocalStorage key for memory enabled setting */ + enabledKey: string; +} + +/** + * Operations supported by the unified memory tool. + */ +export type MemoryOperation = 'search' | 'update' | 'list' | 'delete'; + +/** + * Arguments for the unified memory tool. + */ +export interface MemoryToolArgs { + /** The operation to perform */ + operation: MemoryOperation; + /** Search query (for search operation) */ + query?: string; + /** Block type (for update/delete operations) */ + blockType?: BlockType; + /** Content to write (for update operation) */ + content?: string; + /** Project name (required for project block operations) */ + projectName?: string; +} + +/** + * Result from the unified memory tool. + */ +export interface MemoryToolResult { + success: boolean; + operation: MemoryOperation; + data?: unknown; + error?: string; +} diff --git a/front_end/panels/ai_chat/models/ChatTypes.ts b/front_end/panels/ai_chat/models/ChatTypes.ts index eab00a0ede..72f6c2ea95 100644 --- a/front_end/panels/ai_chat/models/ChatTypes.ts +++ b/front_end/panels/ai_chat/models/ChatTypes.ts @@ -10,6 +10,7 @@ export enum ChatMessageEntity { MODEL = 'model', TOOL_RESULT = 'tool_result', AGENT_SESSION = 'agent_session', + APPROVAL_REQUEST = 'approval_request', } // Base structure for all chat messages @@ -71,8 +72,39 @@ export interface AgentSessionMessage extends BaseChatMessage { summary?: string; } +// Approval status for human-in-the-loop decisions +export type ApprovalStatus = 'pending' | 'approved' | 'rejected'; + +// Risk levels for guardrail evaluation +export type RiskLevel = 'none' | 'low' | 'medium' | 'high' | 'critical'; + +// Approval request message for human-in-the-loop approval flow +export interface ApprovalRequestMessage extends BaseChatMessage { + entity: ChatMessageEntity.APPROVAL_REQUEST; + /** Unique identifier for this approval request */ + approvalId: string; + /** Name of the tool requesting approval */ + toolName: string; + /** Arguments passed to the tool */ + toolArgs: Record; + /** Human-readable description of what the tool wants to do */ + description: string; + /** Current approval status */ + status: ApprovalStatus; + /** Risk level from guardrail evaluation */ + riskLevel: RiskLevel; + /** Chain-of-thought reasoning from guardrail */ + reasoning?: string; + /** Policy that triggered this approval request */ + policyMatched?: string; + /** User feedback on rejection */ + feedback?: string; + /** Links to original tool call */ + toolCallId?: string; +} + export type ChatMessage = - UserChatMessage|ModelChatMessage|ToolResultMessage|AgentSessionMessage; + UserChatMessage|ModelChatMessage|ToolResultMessage|AgentSessionMessage|ApprovalRequestMessage; // View state for the chat container export enum State { diff --git a/front_end/panels/ai_chat/persistence/ConversationManager.ts b/front_end/panels/ai_chat/persistence/ConversationManager.ts index 50ff286f1a..4a81903a27 100644 --- a/front_end/panels/ai_chat/persistence/ConversationManager.ts +++ b/front_end/panels/ai_chat/persistence/ConversationManager.ts @@ -27,6 +27,9 @@ export class ConversationManager { private static instance: ConversationManager|null = null; private storageManager: ConversationStorageManager; + // 30 minutes timeout for stale 'processing' status + private static readonly PROCESSING_TIMEOUT_MS = 30 * 60 * 1000; + private constructor() { this.storageManager = ConversationStorageManager.getInstance(); logger.info('Initialized ConversationManager'); @@ -181,4 +184,112 @@ export class ConversationManager { await this.storageManager.clearAllConversations(); logger.info('Cleared all conversations'); } + + // ==================== Memory Processing Methods ==================== + + /** + * Attempts to claim a conversation for memory processing. + * Returns true if claimed successfully, false if already processing. + * Uses 'processing' status as a lock to prevent concurrent processing. + * + * Will re-claim if: + * - Status is 'failed' (retry) + * - Status is 'processing' but started > 30 min ago (stale/crashed) + */ + async tryClaimForMemoryProcessing(conversationId: string): Promise { + const conversation = await this.storageManager.loadConversation(conversationId); + if (!conversation) { + logger.info('[Memory] Claim failed - conversation not found', { conversationId }); + return false; + } + + logger.info('[Memory] Current memory status', { + conversationId, + memoryStatus: conversation.memoryStatus, + memoryProcessedAt: conversation.memoryProcessedAt, + memoryProcessingStartedAt: conversation.memoryProcessingStartedAt, + }); + + // Already completed - don't reprocess + if (conversation.memoryStatus === 'completed') { + logger.info('[Memory] Claim failed - already completed', { conversationId }); + return false; + } + + // Currently processing - check if stale (> 30 min) + if (conversation.memoryStatus === 'processing') { + const startedAt = conversation.memoryProcessingStartedAt || 0; + const elapsed = Date.now() - startedAt; + if (elapsed < ConversationManager.PROCESSING_TIMEOUT_MS) { + // Still within timeout, don't re-claim + return false; + } + // Stale processing - allow re-claim + logger.warn('Re-claiming stale processing conversation', { + conversationId, + elapsedMs: elapsed, + }); + } + + // Claim it by setting to 'processing' with timestamp + conversation.memoryStatus = 'processing'; + conversation.memoryProcessingStartedAt = Date.now(); + await this.storageManager.saveConversation(conversation); + logger.info('Claimed conversation for memory processing', {conversationId}); + return true; + } + + /** + * Marks memory processing as completed. + */ + async markMemoryCompleted(conversationId: string): Promise { + const conversation = await this.storageManager.loadConversation(conversationId); + if (conversation) { + conversation.memoryStatus = 'completed'; + conversation.memoryProcessedAt = Date.now(); + await this.storageManager.saveConversation(conversation); + logger.info('Marked memory as completed', {conversationId}); + } + } + + /** + * Marks memory processing as failed (can be retried later). + */ + async markMemoryFailed(conversationId: string): Promise { + const conversation = await this.storageManager.loadConversation(conversationId); + if (conversation) { + conversation.memoryStatus = 'failed'; + await this.storageManager.saveConversation(conversation); + logger.warn('Marked memory as failed', {conversationId}); + } + } + + /** + * Returns conversations that need memory processing. + * Includes: + * - pending, failed, or undefined status (old conversations) + * - 'processing' that started > 30 min ago (stale/crashed) + */ + async getConversationsNeedingMemoryProcessing(): Promise { + const all = await this.listConversations(); + const now = Date.now(); + + return all.filter(c => { + // Not started, pending, or failed - needs processing + if (!c.memoryStatus || + c.memoryStatus === 'pending' || + c.memoryStatus === 'failed') { + return true; + } + + // Stale processing (> 30 min) - needs retry + if (c.memoryStatus === 'processing') { + const startedAt = c.memoryProcessingStartedAt || 0; + const elapsed = now - startedAt; + return elapsed >= ConversationManager.PROCESSING_TIMEOUT_MS; + } + + return false; + }); + } } diff --git a/front_end/panels/ai_chat/persistence/ConversationTypes.ts b/front_end/panels/ai_chat/persistence/ConversationTypes.ts index 71092ef63c..82058cc8de 100644 --- a/front_end/panels/ai_chat/persistence/ConversationTypes.ts +++ b/front_end/panels/ai_chat/persistence/ConversationTypes.ts @@ -7,6 +7,15 @@ import type {ChatMessage} from '../models/ChatTypes.js'; import {ChatMessageEntity} from '../models/ChatTypes.js'; import type {AgentSession} from '../agent_framework/AgentSessionTypes.js'; +/** + * Memory processing status for conversation + */ +export type MemoryProcessingStatus = + | 'pending' // Not yet processed + | 'processing' // Currently being processed (prevents concurrent runs) + | 'completed' // Successfully processed + | 'failed'; // Failed (can retry) + /** * Represents a fully stored conversation with all state and metadata */ @@ -32,6 +41,11 @@ export interface StoredConversation { // Total number of messages in the conversation messageCount: number; + + // Memory extraction status + memoryStatus?: MemoryProcessingStatus; + memoryProcessedAt?: number; // Unix timestamp when completed + memoryProcessingStartedAt?: number; // Unix timestamp when processing started (for timeout detection) } /** @@ -44,6 +58,8 @@ export interface ConversationMetadata { updatedAt: number; preview?: string; messageCount: number; + memoryStatus?: MemoryProcessingStatus; + memoryProcessingStartedAt?: number; // Needed to detect stale processing } /** @@ -117,7 +133,7 @@ export interface SerializableAgentSession { export interface SerializableAgentMessage { id: string; timestamp: number; // Unix timestamp - type: 'reasoning' | 'tool_call' | 'tool_result' | 'handoff' | 'final_answer'; + type: 'reasoning' | 'tool_call' | 'tool_result' | 'handoff' | 'final_answer' | 'approval_request'; content: any; // Keep the content structure as-is } @@ -284,5 +300,7 @@ export function extractMetadata(conversation: StoredConversation): ConversationM updatedAt: conversation.updatedAt, preview: conversation.preview, messageCount: conversation.messageCount, + memoryStatus: conversation.memoryStatus, + memoryProcessingStartedAt: conversation.memoryProcessingStartedAt, }; } diff --git a/front_end/panels/ai_chat/persistence/MemoryBlockManager.ts b/front_end/panels/ai_chat/persistence/MemoryBlockManager.ts new file mode 100644 index 0000000000..f292eac7b4 --- /dev/null +++ b/front_end/panels/ai_chat/persistence/MemoryBlockManager.ts @@ -0,0 +1,284 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { FileStorageManager } from '../tools/FileStorageManager.js'; +import { createLogger } from '../core/Logger.js'; + +const logger = createLogger('MemoryBlockManager'); + +// Constants +const MEMORY_SESSION_ID = '__global_memory__'; +const BLOCK_LIMITS = { + user: 20000, + facts: 20000, + project: 20000, +} as const; +const MAX_PROJECT_BLOCKS = 4; + +export type BlockType = 'user' | 'facts' | 'project'; + +export interface MemoryBlock { + filename: string; + type: BlockType; + label: string; + description: string; + content: string; + charLimit: number; + updatedAt: number; +} + +export interface MemorySearchResult { + block: MemoryBlock; + matches: string[]; +} + +/** + * Manages memory blocks stored as files via FileStorageManager. + * Memory is global (shared across all conversations) using a reserved session ID. + * + * Block types: + * - user: User preferences, name, coding style (20000 chars) + * - facts: Recent extracted facts (20000 chars) + * - project: Project-specific context (20000 chars each, max 4) + */ +export class MemoryBlockManager { + private fileManager: FileStorageManager; + + constructor() { + this.fileManager = FileStorageManager.getInstance(); + } + + /** + * Execute a function with the global memory session, restoring the previous session after. + */ + private async withGlobalSession(fn: () => Promise): Promise { + const prevSession = this.fileManager.getSessionId(); + this.fileManager.setSessionId(MEMORY_SESSION_ID); + try { + return await fn(); + } finally { + this.fileManager.setSessionId(prevSession); + } + } + + // --- Block CRUD --- + + /** + * Get a memory block by type and optional project name. + */ + async getBlock(type: BlockType, projectName?: string): Promise { + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + const file = await this.fileManager.readFile(filename); + if (!file) { + return null; + } + + return { + filename, + type, + label: this.getLabel(type, projectName), + description: this.getDescription(type), + content: file.content, + charLimit: BLOCK_LIMITS[type], + updatedAt: file.updatedAt, + }; + }); + } + + /** + * Update or create a memory block. + */ + async updateBlock(type: BlockType, content: string, projectName?: string): Promise { + const limit = BLOCK_LIMITS[type]; + if (content.length > limit) { + throw new Error(`Content exceeds ${limit} char limit (got ${content.length})`); + } + + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + const exists = await this.fileManager.readFile(filename); + + if (exists) { + await this.fileManager.updateFile(filename, content, false); + logger.info('Updated memory block', { type, filename }); + } else { + // Check project limit before creating new project block + if (type === 'project') { + const projects = await this.listProjectBlocks(); + if (projects.length >= MAX_PROJECT_BLOCKS) { + throw new Error(`Max ${MAX_PROJECT_BLOCKS} project blocks allowed`); + } + } + await this.fileManager.createFile(filename, content, 'text/markdown'); + logger.info('Created memory block', { type, filename }); + } + }); + } + + /** + * Delete a memory block. + */ + async deleteBlock(type: BlockType, projectName?: string): Promise { + return this.withGlobalSession(async () => { + const filename = this.getFilename(type, projectName); + try { + await this.fileManager.deleteFile(filename); + logger.info('Deleted memory block', { type, filename }); + } catch (error) { + // Ignore if file doesn't exist + logger.debug('Block not found for deletion', { type, filename }); + } + }); + } + + // --- Queries --- + + /** + * Get all memory blocks. + */ + async getAllBlocks(): Promise { + return this.withGlobalSession(async () => { + const files = await this.fileManager.listFiles(); + const blocks: MemoryBlock[] = []; + + for (const file of files) { + if (!file.fileName.startsWith('memory_')) { + continue; + } + + const fullFile = await this.fileManager.readFile(file.fileName); + if (!fullFile) { + continue; + } + + const { type, projectName } = this.parseFilename(file.fileName); + blocks.push({ + filename: file.fileName, + type, + label: this.getLabel(type, projectName), + description: this.getDescription(type), + content: fullFile.content, + charLimit: BLOCK_LIMITS[type], + updatedAt: file.updatedAt, + }); + } + + return blocks; + }); + } + + /** + * List only project blocks. + */ + async listProjectBlocks(): Promise { + const all = await this.getAllBlocks(); + return all.filter(b => b.type === 'project'); + } + + /** + * Search across all blocks for matching lines. + */ + async searchBlocks(query: string): Promise { + const blocks = await this.getAllBlocks(); + const results: MemorySearchResult[] = []; + const queryLower = query.toLowerCase(); + + for (const block of blocks) { + const lines = block.content.split('\n'); + const matches = lines.filter(line => + line.toLowerCase().includes(queryLower) + ); + if (matches.length > 0) { + results.push({ block, matches }); + } + } + + return results; + } + + // --- Helpers --- + + private getFilename(type: BlockType, projectName?: string): string { + if (type === 'project' && projectName) { + const safeName = projectName.toLowerCase().replace(/[^a-z0-9]/g, '_'); + return `memory_project_${safeName}.md`; + } + return `memory_${type}.md`; + } + + private parseFilename(filename: string): { type: BlockType; projectName?: string } { + if (filename === 'memory_user.md') { + return { type: 'user' }; + } + if (filename === 'memory_facts.md') { + return { type: 'facts' }; + } + if (filename.startsWith('memory_project_')) { + const projectName = filename.replace('memory_project_', '').replace('.md', ''); + return { type: 'project', projectName }; + } + return { type: 'facts' }; // fallback + } + + private getLabel(type: BlockType, projectName?: string): string { + if (type === 'project') { + return `project:${projectName}`; + } + return type; + } + + private getDescription(type: BlockType): string { + switch (type) { + case 'user': + return 'User preferences, name, coding style, and personal context'; + case 'facts': + return 'Recent facts extracted from conversations'; + case 'project': + return 'Project-specific context, tech stack, and goals'; + } + } + + // --- Memory Compilation (for prompt injection) --- + + /** + * Compile all memory blocks into XML context for prompt injection. + */ + async compileMemoryContext(): Promise { + const blocks = await this.getAllBlocks(); + if (blocks.length === 0) { + return ''; + } + + let context = '\n'; + + for (const block of blocks) { + if (!block.content.trim()) { + continue; + } + + context += `<${block.label}>\n`; + context += `${block.description}\n`; + context += `\n${block.content}\n\n`; + context += `\n`; + } + + context += ''; + return context; + } + + /** + * Get the character limit for a block type. + */ + static getBlockLimit(type: BlockType): number { + return BLOCK_LIMITS[type]; + } + + /** + * Get the maximum number of project blocks allowed. + */ + static getMaxProjectBlocks(): number { + return MAX_PROJECT_BLOCKS; + } +} diff --git a/front_end/panels/ai_chat/tools/DeleteFileTool.ts b/front_end/panels/ai_chat/tools/DeleteFileTool.ts index 6743497758..895cb8f3b7 100644 --- a/front_end/panels/ai_chat/tools/DeleteFileTool.ts +++ b/front_end/panels/ai_chat/tools/DeleteFileTool.ts @@ -38,6 +38,13 @@ export class DeleteFileTool implements Tool { required: ['fileName', 'reasoning'] }; + // Destructive action - requires human approval by default + approvalConfig = { + requiresApproval: true, + riskLevel: 'high' as const, + approvalMessage: 'This tool will permanently delete a file. Please confirm before proceeding.', + }; + async execute(args: DeleteFileArgs, _ctx?: LLMContext): Promise { logger.info('Executing delete file', { fileName: args.fileName }); const manager = FileStorageManager.getInstance(); diff --git a/front_end/panels/ai_chat/tools/ExecuteCodeTool.ts b/front_end/panels/ai_chat/tools/ExecuteCodeTool.ts index ff55982f6f..4bc302b5a9 100644 --- a/front_end/panels/ai_chat/tools/ExecuteCodeTool.ts +++ b/front_end/panels/ai_chat/tools/ExecuteCodeTool.ts @@ -67,6 +67,13 @@ Examples: required: ['code', 'reasoning'] }; + // High-risk tool - requires human approval by default + approvalConfig = { + requiresApproval: true, + riskLevel: 'high' as const, + approvalMessage: 'This tool will execute JavaScript code on the page. Please review the code before approving.', + }; + async execute(args: ExecuteCodeArgs, _ctx?: LLMContext): Promise { const { code, reasoning } = args; diff --git a/front_end/panels/ai_chat/tools/FileStorageManager.ts b/front_end/panels/ai_chat/tools/FileStorageManager.ts index 3e431cbb20..fea32ff4ac 100644 --- a/front_end/panels/ai_chat/tools/FileStorageManager.ts +++ b/front_end/panels/ai_chat/tools/FileStorageManager.ts @@ -49,8 +49,8 @@ export class FileStorageManager { private dbInitializationPromise: Promise | null = null; private constructor() { - this.sessionId = 'default'; // Will be set to conversation ID when conversation is created/loaded - logger.info('Initialized FileStorageManager with default session'); + this.sessionId = `temp-${this.generateUUID()}`; // Unique per session, will be set to conversation ID when conversation is created/loaded + logger.info('Initialized FileStorageManager with session', { sessionId: this.sessionId }); } static getInstance(): FileStorageManager { diff --git a/front_end/panels/ai_chat/tools/Tools.ts b/front_end/panels/ai_chat/tools/Tools.ts index e282b068a8..0fef74f6ab 100644 --- a/front_end/panels/ai_chat/tools/Tools.ts +++ b/front_end/panels/ai_chat/tools/Tools.ts @@ -37,6 +37,23 @@ import { RenderWebAppTool, type RenderWebAppArgs, type RenderWebAppResult } from import { GetWebAppDataTool, type GetWebAppDataArgs, type GetWebAppDataResult } from './GetWebAppDataTool.js'; import { RemoveWebAppTool, type RemoveWebAppArgs, type RemoveWebAppResult } from './RemoveWebAppTool.js'; +/** + * Risk levels for tool approval (shared with GuardrailEvaluator) + */ +export type ToolRiskLevel = 'none' | 'low' | 'medium' | 'high' | 'critical'; + +/** + * Approval configuration for tools + */ +export interface ToolApprovalConfig { + /** Whether this tool requires human approval by default */ + requiresApproval?: boolean; + /** Default risk level for this tool */ + riskLevel?: ToolRiskLevel; + /** Custom approval message shown to user */ + approvalMessage?: string; +} + /** * Base interface for all tools */ @@ -49,6 +66,8 @@ export interface Tool, TResult = unknown> { properties: Record, required?: string[], }; + /** Optional approval configuration - if set, tool may require human approval */ + approvalConfig?: ToolApprovalConfig; } /** @@ -62,6 +81,8 @@ export interface LLMContext { miniModel?: string; nanoModel?: string; abortSignal?: AbortSignal; + /** If true, don't emit UI progress events (for background tools/agents) */ + background?: boolean; } /** diff --git a/front_end/panels/ai_chat/ui/AIChatPanel.ts b/front_end/panels/ai_chat/ui/AIChatPanel.ts index c2311e360b..23c0f5d99f 100644 --- a/front_end/panels/ai_chat/ui/AIChatPanel.ts +++ b/front_end/panels/ai_chat/ui/AIChatPanel.ts @@ -9,7 +9,13 @@ import * as SDK from '../../../core/sdk/sdk.js'; import * as UI from '../../../ui/legacy/legacy.js'; import {AgentService, Events as AgentEvents} from '../core/AgentService.js'; import { LLMClient } from '../LLM/LLMClient.js'; -import { LLMConfigurationManager } from '../core/LLMConfigurationManager.js'; +import { + LLMConfigurationManager, + type ModelOption, + DEFAULT_PROVIDER_MODELS, + DEFAULT_OPENAI_MODELS, + MODEL_PLACEHOLDERS as CONFIG_MODEL_PLACEHOLDERS, +} from '../core/LLMConfigurationManager.js'; import { LLMProviderRegistry } from '../LLM/LLMProviderRegistry.js'; import { createLogger } from '../core/Logger.js'; import { CustomProviderManager } from '../core/CustomProviderManager.js'; @@ -18,6 +24,7 @@ import type { ProviderType } from './settings/types.js'; import { isEvaluationEnabled, getEvaluationConfig } from '../common/EvaluationConfig.js'; import { EvaluationAgent } from '../evaluation/remote/EvaluationAgent.js'; import { BUILD_CONFIG } from '../core/BuildConfig.js'; +import { OnboardingDialog, createSetupRequiredBanner } from './OnboardingDialog.js'; // Import of LiveAgentSessionComponent is not required here; the element is // registered by ChatView where it is used. @@ -92,73 +99,12 @@ import { MCPConnectorsCatalogDialog } from './mcp/MCPConnectorsCatalogDialog.js' import { ConversationHistoryList } from './ConversationHistoryList.js'; -// Model type definition -export interface ModelOption { - value: string; - label: string; - type: string; // Supports standard providers and custom providers (e.g., 'custom:my-provider') -} - -// Add model options constant - these are the default OpenAI models -const DEFAULT_OPENAI_MODELS: ModelOption[] = [ - {value: 'o4-mini-2025-04-16', label: 'O4 Mini', type: 'openai'}, - {value: 'o3-mini-2025-01-31', label: 'O3 Mini', type: 'openai'}, - {value: 'gpt-5-2025-08-07', label: 'GPT-5', type: 'openai'}, - {value: 'gpt-5-mini-2025-08-07', label: 'GPT-5 Mini', type: 'openai'}, - {value: 'gpt-5-nano-2025-08-07', label: 'GPT-5 Nano', type: 'openai'}, - {value: 'gpt-4.1-2025-04-14', label: 'GPT-4.1', type: 'openai'}, - {value: 'gpt-4.1-mini-2025-04-14', label: 'GPT-4.1 Mini', type: 'openai'}, - {value: 'gpt-4.1-nano-2025-04-14', label: 'GPT-4.1 Nano', type: 'openai'}, -]; - -// Default model selections for each provider -export const DEFAULT_PROVIDER_MODELS: Record = { - openai: { - main: 'gpt-4.1-2025-04-14', - mini: 'gpt-4.1-mini-2025-04-14', - nano: 'gpt-4.1-nano-2025-04-14' - }, - litellm: { - main: '', // Will use first available model - mini: '', - nano: '' - }, - groq: { - main: 'meta-llama/llama-4-scout-17b-16e-instruct', - mini: 'qwen/qwen3-32b', - nano: 'llama-3.1-8b-instant' - }, - openrouter: { - main: 'anthropic/claude-sonnet-4', - mini: 'google/gemini-2.5-flash', - nano: 'google/gemini-2.5-flash-lite-preview-06-17' - }, - browseroperator: { - main: 'main', - mini: 'mini', - nano: 'nano' - }, - cerebras: { - main: 'llama-3.3-70b', - mini: 'llama-3.3-8b', - nano: 'llama-3.3-8b' - }, - anthropic: { - main: 'claude-sonnet-4-20250514', - mini: 'claude-haiku-3-5-20241022', - nano: 'claude-haiku-3-5-20241022' - }, - googleai: { - main: 'gemini-2.0-flash-exp', - mini: 'gemini-2.0-flash-thinking-exp-01-21', - nano: 'gemini-2.0-flash-thinking-exp-01-21' - } -}; - -// This will hold the current active model options -let MODEL_OPTIONS: ModelOption[] = [...DEFAULT_OPENAI_MODELS]; +// Re-export ModelOption type for backward compatibility +export type { ModelOption }; +// Re-export DEFAULT_PROVIDER_MODELS for backward compatibility +export { DEFAULT_PROVIDER_MODELS }; -// Model selector localStorage keys +// Model selector localStorage keys (kept for local usage) const MODEL_SELECTION_KEY = 'ai_chat_model_selection'; const MINI_MODEL_STORAGE_KEY = 'ai_chat_mini_model'; const NANO_MODEL_STORAGE_KEY = 'ai_chat_nano_model'; @@ -168,6 +114,25 @@ const PROVIDER_SELECTION_KEY = 'ai_chat_provider'; const LITELLM_ENDPOINT_KEY = 'ai_chat_litellm_endpoint'; const LITELLM_API_KEY_STORAGE_KEY = 'ai_chat_litellm_api_key'; +// Local MODEL_OPTIONS reference that syncs with LLMConfigurationManager +// This maintains backward compatibility while delegating to the centralized manager +let MODEL_OPTIONS: ModelOption[] = [...DEFAULT_OPENAI_MODELS]; + +// Helper to get MODEL_OPTIONS from LLMConfigurationManager +function getModelOptions(): ModelOption[] { + return LLMConfigurationManager.getInstance().getModelOptionsForCurrentProvider(); +} + +// Helper to get all model options across all providers +function getAllModelOptions(): ModelOption[] { + return LLMConfigurationManager.getInstance().getAllModelOptions(); +} + +// Sync local MODEL_OPTIONS with LLMConfigurationManager +function syncModelOptions(): void { + MODEL_OPTIONS = LLMConfigurationManager.getInstance().getModelOptionsForCurrentProvider(); +} + const UIStrings = { /** *@description Text for the AI welcome message @@ -377,263 +342,82 @@ export class AIChatPanel extends UI.Panel.Panel { * @returns Array of model options */ static getModelOptions(provider?: ProviderType): ModelOption[] { - // Try to get from all_model_options first (comprehensive list) - const allModelOptionsStr = localStorage.getItem('ai_chat_all_model_options'); - if (allModelOptionsStr) { - try { - const allModelOptions = JSON.parse(allModelOptionsStr); - // If provider is specified, filter by it - return provider ? allModelOptions.filter((opt: ModelOption) => opt.type === provider) : allModelOptions; - } catch (error) { - console.warn('Failed to parse ai_chat_all_model_options from localStorage, removing corrupted data:', error); - localStorage.removeItem('ai_chat_all_model_options'); - } - } - - // Fallback to legacy model_options if all_model_options doesn't exist - const modelOptionsStr = localStorage.getItem('ai_chat_model_options'); - if (modelOptionsStr) { - try { - const modelOptions = JSON.parse(modelOptionsStr); - // If we got legacy options, migrate them to all_model_options for future use - localStorage.setItem('ai_chat_all_model_options', modelOptionsStr); - // Apply provider filter if needed - return provider ? modelOptions.filter((opt: ModelOption) => opt.type === provider) : modelOptions; - } catch (error) { - console.warn('Failed to parse ai_chat_model_options from localStorage, removing corrupted data:', error); - localStorage.removeItem('ai_chat_model_options'); - } + const configManager = LLMConfigurationManager.getInstance(); + if (provider) { + return configManager.getModelOptions(provider); } - - // If nothing is found, return default OpenAI models - return provider === 'litellm' ? [] : DEFAULT_OPENAI_MODELS; + return configManager.getAllModelOptions(); } /** * Updates model options with new provider models + * Delegates to centralized LLMConfigurationManager * @param providerModels Models fetched from any provider (LiteLLM, Groq, etc.) - * @param hadWildcard Whether LiteLLM returned a wildcard model + * @param _hadWildcard Whether LiteLLM returned a wildcard model (kept for backward compatibility) * @returns Updated model options */ - static updateModelOptions(providerModels: ModelOption[] = [], hadWildcard = false): ModelOption[] { - // Get the selected provider (for context, but we store all models regardless) - const selectedProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - - // Get existing models from localStorage - let existingAllModels: ModelOption[] = []; - try { - existingAllModels = JSON.parse(localStorage.getItem('ai_chat_all_model_options') || '[]'); - } catch (error) { - console.warn('Failed to parse ai_chat_all_model_options from localStorage, using empty array:', error); - localStorage.removeItem('ai_chat_all_model_options'); - } - - // Get existing custom models (if any) - these are for LiteLLM only - let savedCustomModels: string[] = []; - try { - savedCustomModels = JSON.parse(localStorage.getItem('ai_chat_custom_models') || '[]'); - } catch (error) { - console.warn('Failed to parse ai_chat_custom_models from localStorage, using empty array:', error); - localStorage.removeItem('ai_chat_custom_models'); - } - const customModels = savedCustomModels.map((model: string) => ({ - value: model, - label: `LiteLLM: ${model}`, - type: 'litellm' as const - })); - - // Define standard provider types - const STANDARD_PROVIDER_TYPES: ProviderType[] = [ - 'openai', 'litellm', 'groq', 'openrouter', 'browseroperator', - 'cerebras', 'anthropic', 'googleai' - ]; - - // Get custom providers dynamically - const customProviders = CustomProviderManager.listEnabledProviders().map(p => p.id); - - // Combine standard and custom providers - const ALL_PROVIDER_TYPES = [...STANDARD_PROVIDER_TYPES, ...customProviders]; - - // Build a map of provider type -> models for generic handling - const modelsByProvider = new Map(); - - // Initialize with existing models for each provider - for (const providerType of ALL_PROVIDER_TYPES) { - const existingModels = existingAllModels.filter((m: ModelOption) => m.type === providerType); - modelsByProvider.set(providerType, existingModels); - } - - // Special case: OpenAI always uses DEFAULT_OPENAI_MODELS to ensure latest hardcoded list - modelsByProvider.set('openai', DEFAULT_OPENAI_MODELS); - - // Load models from custom providers - for (const customProviderId of customProviders) { - const customProvider = CustomProviderManager.getProvider(customProviderId); - if (customProvider && customProvider.models && customProvider.models.length > 0) { - const customProviderModels = customProvider.models.map(modelId => ({ - value: modelId, - label: `${customProvider.name}: ${modelId}`, - type: customProviderId as ProviderType - })); - modelsByProvider.set(customProviderId as ProviderType, customProviderModels); - } - } + static updateModelOptions(providerModels: ModelOption[] = [], _hadWildcard = false): ModelOption[] { + const configManager = LLMConfigurationManager.getInstance(); - // Update models for the provider type we're adding (if any) + // Determine provider from the models if (providerModels.length > 0) { - const firstModelType = providerModels[0].type; - - if (firstModelType === 'litellm') { - // Special case: LiteLLM includes custom models - modelsByProvider.set('litellm', [...customModels, ...providerModels]); - } else { - // For all other providers, just replace with new models - modelsByProvider.set(firstModelType, providerModels); - } - } - - // Create comprehensive model list from all providers - const allModels: ModelOption[] = []; - for (const providerType of ALL_PROVIDER_TYPES) { - const models = modelsByProvider.get(providerType) || []; - allModels.push(...models); - } - - // Save comprehensive list to localStorage - localStorage.setItem('ai_chat_all_model_options', JSON.stringify(allModels)); - - // Set MODEL_OPTIONS based on currently selected provider - MODEL_OPTIONS = modelsByProvider.get(selectedProvider as ProviderType) || []; - - // Add placeholder if no models available for the selected provider - if (MODEL_OPTIONS.length === 0) { - // Special case for LiteLLM with wildcard - if (selectedProvider === 'litellm' && hadWildcard) { - MODEL_OPTIONS.push({ - value: MODEL_PLACEHOLDERS.ADD_CUSTOM, - label: 'LiteLLM: Please add custom models in settings', - type: 'litellm' as const - }); - } else { - // Generic placeholder for all other providers - const providerLabel = selectedProvider.charAt(0).toUpperCase() + selectedProvider.slice(1); - MODEL_OPTIONS.push({ - value: MODEL_PLACEHOLDERS.NO_MODELS, - label: `${providerLabel}: Please configure in settings`, - type: selectedProvider as ProviderType - }); - } + const provider = providerModels[0].type; + configManager.setModelOptions(provider, providerModels); } - // Save MODEL_OPTIONS to localStorage for backwards compatibility - localStorage.setItem('ai_chat_model_options', JSON.stringify(MODEL_OPTIONS)); + // Sync local MODEL_OPTIONS + syncModelOptions(); - // Build log info dynamically for all providers - const logInfo: Record = { - provider: selectedProvider, - totalModelOptions: MODEL_OPTIONS.length, - allModelsLength: allModels.length - }; - for (const providerType of ALL_PROVIDER_TYPES) { - const models = modelsByProvider.get(providerType) || []; - logInfo[`${providerType}Models`] = models.length; - } + logger.info('Updated model options via configManager:', { + provider: configManager.getProvider(), + modelCount: MODEL_OPTIONS.length + }); - logger.info('Updated model options:', logInfo); - - return allModels; + return configManager.getAllModelOptions(); } /** * Adds a custom model to the options + * Delegates to centralized LLMConfigurationManager * @param modelName Name of the model to add * @param modelType Type of the model ('openai' or 'litellm') * @returns Updated model options */ static addCustomModelOption(modelName: string, modelType?: ProviderType): ModelOption[] { - // Default to litellm if not specified - const finalModelType = modelType || 'litellm'; + const configManager = LLMConfigurationManager.getInstance(); + configManager.addCustomModelOption(modelName, modelType); - // Get existing custom models - const savedCustomModels = JSON.parse(localStorage.getItem('ai_chat_custom_models') || '[]'); - - // Check if the model already exists - if (savedCustomModels.includes(modelName)) { - logger.info(`Custom model ${modelName} already exists, not adding again`); - return AIChatPanel.getModelOptions(); - } - - // Add the new model to custom models - savedCustomModels.push(modelName); - localStorage.setItem('ai_chat_custom_models', JSON.stringify(savedCustomModels)); - - // Create the model option object - const newOption: ModelOption = { - value: modelName, - label: finalModelType === 'litellm' ? `LiteLLM: ${modelName}` : - finalModelType === 'groq' ? `Groq: ${modelName}` : - finalModelType === 'openrouter' ? `OpenRouter: ${modelName}` : - `OpenAI: ${modelName}`, - type: finalModelType - }; - - // Get all existing model options - const allModelOptions = AIChatPanel.getModelOptions(); - - // Add the new option - const updatedOptions = [...allModelOptions, newOption]; - localStorage.setItem('ai_chat_all_model_options', JSON.stringify(updatedOptions)); - - // Update MODEL_OPTIONS for backwards compatibility if provider matches - const currentProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - if ((currentProvider === 'openai' && modelType === 'openai') || - (currentProvider === 'litellm' && modelType === 'litellm') || - (currentProvider === 'groq' && modelType === 'groq')) { - MODEL_OPTIONS = [...MODEL_OPTIONS, newOption]; - localStorage.setItem('ai_chat_model_options', JSON.stringify(MODEL_OPTIONS)); - } - - return updatedOptions; + // Sync local MODEL_OPTIONS + syncModelOptions(); + + return configManager.getAllModelOptions(); } /** * Clears cached model data to force refresh from defaults + * Delegates to centralized LLMConfigurationManager */ static clearModelCache(): void { - localStorage.removeItem('ai_chat_all_model_options'); - localStorage.removeItem('ai_chat_model_options'); - logger.info('Cleared model cache - will use DEFAULT_OPENAI_MODELS on next refresh'); + const configManager = LLMConfigurationManager.getInstance(); + configManager.clearModelOptions(); + syncModelOptions(); + logger.info('Cleared model cache via configManager'); } /** * Removes a custom model from the options + * Delegates to centralized LLMConfigurationManager * @param modelName Name of the model to remove * @returns Updated model options */ static removeCustomModelOption(modelName: string): ModelOption[] { - // Get existing custom models - const savedCustomModels = JSON.parse(localStorage.getItem('ai_chat_custom_models') || '[]'); - - // Check if the model exists - if (!savedCustomModels.includes(modelName)) { - logger.info(`Custom model ${modelName} not found, nothing to remove`); - return AIChatPanel.getModelOptions(); - } - - // Remove the model from custom models - const updatedCustomModels = savedCustomModels.filter((model: string) => model !== modelName); - localStorage.setItem('ai_chat_custom_models', JSON.stringify(updatedCustomModels)); - - // Get all existing model options and remove the specified one - const allModelOptions = AIChatPanel.getModelOptions(); - const updatedOptions = allModelOptions.filter(option => option.value !== modelName); - localStorage.setItem('ai_chat_all_model_options', JSON.stringify(updatedOptions)); - - // Update MODEL_OPTIONS for backwards compatibility - MODEL_OPTIONS = MODEL_OPTIONS.filter(option => option.value !== modelName); - localStorage.setItem('ai_chat_model_options', JSON.stringify(MODEL_OPTIONS)); - - return updatedOptions; + const configManager = LLMConfigurationManager.getInstance(); + configManager.removeCustomModelOption(modelName); + + // Sync local MODEL_OPTIONS + syncModelOptions(); + + return configManager.getAllModelOptions(); } static readonly panelName = 'ai-chat'; @@ -847,94 +631,21 @@ export class AIChatPanel extends UI.Panel.Panel { * Sets up model options based on provider and stored preferences */ #setupModelOptions(): void { - // Get the selected provider - const selectedProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - - // Initialize MODEL_OPTIONS based on the selected provider - this.#updateModelOptions([], false); - - // Load custom models - const savedCustomModels = JSON.parse(localStorage.getItem('ai_chat_custom_models') || '[]'); - - // If we have custom models and using LiteLLM, add them - if (savedCustomModels.length > 0 && selectedProvider === 'litellm') { - // Add custom models to MODEL_OPTIONS - const customOptions = savedCustomModels.map((model: string) => ({ - value: model, - label: `LiteLLM: ${model}`, - type: 'litellm' as const - })); - MODEL_OPTIONS = [...MODEL_OPTIONS, ...customOptions]; + const configManager = LLMConfigurationManager.getInstance(); - // Save MODEL_OPTIONS to localStorage - localStorage.setItem('ai_chat_model_options', JSON.stringify(MODEL_OPTIONS)); - } - - this.#loadModelSelections(); - - // Validate models after loading - this.#validateAndFixModelSelections(); - } + // Sync local MODEL_OPTIONS from the centralized manager + syncModelOptions(); - /** - * Loads model selections from localStorage - */ - #loadModelSelections(): void { - // Get the current provider - const currentProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - const providerDefaults = DEFAULT_PROVIDER_MODELS[currentProvider] || DEFAULT_PROVIDER_MODELS.openai; - - // Load the selected model - const storedModel = localStorage.getItem(MODEL_SELECTION_KEY); - - if (MODEL_OPTIONS.length === 0) { - logger.warn('No model options available when loading model selections'); - return; - } - - if (storedModel && MODEL_OPTIONS.some(option => option.value === storedModel)) { - this.#selectedModel = storedModel; - } else if (MODEL_OPTIONS.length > 0) { - // Check if provider default main model is available - if (providerDefaults.main && MODEL_OPTIONS.some(option => option.value === providerDefaults.main)) { - this.#selectedModel = providerDefaults.main; - } else { - // Otherwise, use the first available model - this.#selectedModel = MODEL_OPTIONS[0].value; - } - localStorage.setItem(MODEL_SELECTION_KEY, this.#selectedModel); - } - - // Load mini model - check that it belongs to current provider - const storedMiniModel = localStorage.getItem(MINI_MODEL_STORAGE_KEY); - const storedMiniModelOption = storedMiniModel ? MODEL_OPTIONS.find(option => option.value === storedMiniModel) : null; - if (storedMiniModelOption && storedMiniModelOption.type === currentProvider && storedMiniModel) { - this.#miniModel = storedMiniModel; - } else if (providerDefaults.mini && MODEL_OPTIONS.some(option => option.value === providerDefaults.mini)) { - // Use provider default mini model if available - this.#miniModel = providerDefaults.mini; - localStorage.setItem(MINI_MODEL_STORAGE_KEY, this.#miniModel); - } else { - this.#miniModel = ''; - localStorage.removeItem(MINI_MODEL_STORAGE_KEY); - } + // Validate and fix model selections using centralized manager + const corrected = configManager.validateAndFixModelSelections(); - // Load nano model - check that it belongs to current provider - const storedNanoModel = localStorage.getItem(NANO_MODEL_STORAGE_KEY); - const storedNanoModelOption = storedNanoModel ? MODEL_OPTIONS.find(option => option.value === storedNanoModel) : null; - if (storedNanoModelOption && storedNanoModelOption.type === currentProvider && storedNanoModel) { - this.#nanoModel = storedNanoModel; - } else if (providerDefaults.nano && MODEL_OPTIONS.some(option => option.value === providerDefaults.nano)) { - // Use provider default nano model if available - this.#nanoModel = providerDefaults.nano; - localStorage.setItem(NANO_MODEL_STORAGE_KEY, this.#nanoModel); - } else { - this.#nanoModel = ''; - localStorage.removeItem(NANO_MODEL_STORAGE_KEY); - } - - logger.info('Loaded model selections:', { - provider: currentProvider, + // Apply the corrected values to instance state + this.#selectedModel = corrected.main; + this.#miniModel = corrected.mini; + this.#nanoModel = corrected.nano; + + logger.info('Setup model options:', { + provider: configManager.getProvider(), selectedModel: this.#selectedModel, miniModel: this.#miniModel, nanoModel: this.#nanoModel @@ -952,99 +663,37 @@ export class AIChatPanel extends UI.Panel.Panel { /** * Validates and fixes model selections to ensure they exist in the current provider * Returns true if all models are valid, false if any needed to be fixed + * Delegates to centralized LLMConfigurationManager */ #validateAndFixModelSelections(): boolean { - logger.info('=== VALIDATING MODEL SELECTIONS ==='); - - const currentProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - const providerDefaults = DEFAULT_PROVIDER_MODELS[currentProvider] || DEFAULT_PROVIDER_MODELS.openai; - const availableModels = AIChatPanel.getModelOptions(currentProvider as 'openai' | 'litellm' | 'groq' | 'openrouter' | 'browseroperator'); - - let allValid = true; - - // Log current state - logger.info('Current state:', { - provider: currentProvider, - selectedModel: this.#selectedModel, - miniModel: this.#miniModel, - nanoModel: this.#nanoModel, - availableModelsCount: availableModels.length - }); - - // If no models available for provider, we have a problem - if (availableModels.length === 0) { - logger.error(`No models available for provider ${currentProvider}`); - return false; - } - - // Validate main model - const mainModelValid = availableModels.some(m => m.value === this.#selectedModel); - if (!mainModelValid) { - logger.warn(`Main model ${this.#selectedModel} not valid for ${currentProvider}, resetting...`); - allValid = false; - - // Try provider default first - if (providerDefaults.main && availableModels.some(m => m.value === providerDefaults.main)) { - this.#selectedModel = providerDefaults.main; - logger.info(`Reset main model to provider default: ${providerDefaults.main}`); - } else { - // Fall back to first available model - this.#selectedModel = availableModels[0].value; - logger.info(`Reset main model to first available: ${this.#selectedModel}`); - } - localStorage.setItem(MODEL_SELECTION_KEY, this.#selectedModel); - } - - // Validate mini model - if (this.#miniModel) { - const miniModelValid = availableModels.some(m => m.value === this.#miniModel); - if (!miniModelValid) { - logger.warn(`Mini model ${this.#miniModel} not valid for ${currentProvider}, resetting...`); - allValid = false; - - // Try provider default first - if (providerDefaults.mini && availableModels.some(m => m.value === providerDefaults.mini)) { - this.#miniModel = providerDefaults.mini; - logger.info(`Reset mini model to provider default: ${providerDefaults.mini}`); - localStorage.setItem(MINI_MODEL_STORAGE_KEY, this.#miniModel); - } else { - // Clear mini model to fall back to main model - this.#miniModel = ''; - logger.info('Cleared mini model to fall back to main model'); - localStorage.removeItem(MINI_MODEL_STORAGE_KEY); - } - } - } - - // Validate nano model - if (this.#nanoModel) { - const nanoModelValid = availableModels.some(m => m.value === this.#nanoModel); - if (!nanoModelValid) { - logger.warn(`Nano model ${this.#nanoModel} not valid for ${currentProvider}, resetting...`); - allValid = false; - - // Try provider default first - if (providerDefaults.nano && availableModels.some(m => m.value === providerDefaults.nano)) { - this.#nanoModel = providerDefaults.nano; - logger.info(`Reset nano model to provider default: ${providerDefaults.nano}`); - localStorage.setItem(NANO_MODEL_STORAGE_KEY, this.#nanoModel); - } else { - // Clear nano model to fall back to mini/main model - this.#nanoModel = ''; - logger.info('Cleared nano model to fall back to mini/main model'); - localStorage.removeItem(NANO_MODEL_STORAGE_KEY); - } - } + const configManager = LLMConfigurationManager.getInstance(); + + // Track previous values to determine if changes were made + const prevMain = this.#selectedModel; + const prevMini = this.#miniModel; + const prevNano = this.#nanoModel; + + // Delegate to centralized validation + const corrected = configManager.validateAndFixModelSelections(); + + // Apply corrected values to instance state + this.#selectedModel = corrected.main; + this.#miniModel = corrected.mini; + this.#nanoModel = corrected.nano; + + // Return true if no changes were needed + const allValid = prevMain === corrected.main && + prevMini === corrected.mini && + prevNano === corrected.nano; + + if (!allValid) { + logger.info('Model selections were fixed:', { + main: { from: prevMain, to: corrected.main }, + mini: { from: prevMini, to: corrected.mini }, + nano: { from: prevNano, to: corrected.nano } + }); } - - // Log final state - logger.info('Validation complete:', { - allValid, - finalSelectedModel: this.#selectedModel, - finalMiniModel: this.#miniModel, - finalNanoModel: this.#nanoModel - }); - + return allValid; } @@ -1230,57 +879,6 @@ export class AIChatPanel extends UI.Panel.Panel { AIChatPanel.updateModelOptions(litellmModels, hadWildcard); } - /** - * Refreshes Groq models from the API - */ - async #refreshGroqModels(): Promise { - try { - const groqApiKey = localStorage.getItem('ai_chat_groq_api_key'); - - if (!groqApiKey) { - logger.info('No Groq API key configured, skipping model refresh'); - return; - } - - const { models: groqModels } = await this.#fetchGroqModels(groqApiKey); - this.#updateModelOptions(groqModels, false); - - // Update MODEL_OPTIONS to reflect the fetched models - this.performUpdate(); - } catch (error) { - logger.error('Failed to refresh Groq models:', error); - // Clear Groq models on error - AIChatPanel.updateModelOptions([], false); - this.performUpdate(); - } - } - - /** - * Fetches Groq models from the API - * @param apiKey API key to use for the request - * @returns Object containing models - */ - async #fetchGroqModels(apiKey: string): Promise<{models: ModelOption[]}> { - try { - // Fetch models from Groq - const models = await LLMClient.fetchGroqModels(apiKey); - - // Transform the models to the format we need - const groqModels = models.map(model => ({ - value: model.id, - label: `Groq: ${model.id}`, - type: 'groq' as const - })); - - logger.info(`Fetched ${groqModels.length} Groq models`); - return { models: groqModels }; - } catch (error) { - logger.error('Failed to fetch Groq models:', error); - // Return empty array on error - return { models: [] }; - } - } - /** * Determines the status of the selected model * @param modelValue The model value to check @@ -1958,6 +1556,57 @@ export class AIChatPanel extends UI.Panel.Panel { override wasShown(): void { this.performUpdate(); this.#chatView?.focus(); + + // Show onboarding for first-time users + if (OnboardingDialog.shouldShowOnboarding()) { + OnboardingDialog.show(async () => { + // Fetch models for the newly selected provider + await this.#refreshModelsForCurrentProvider(); + // Sync MODEL_OPTIONS and validate model selections + this.#setupModelOptions(); + // Re-initialize agent service with newly selected provider + this.#initializeAgentService(); + // Refresh UI after onboarding completes + this.performUpdate(); + }); + return; + } + + // Refresh models when panel is shown to ensure we have the latest available models + void this.#refreshModelsForCurrentProvider(); + } + + /** + * Fetches and caches models for the current provider + * Uses LLMProviderRegistry directly (doesn't require LLMClient initialization) + */ + async #refreshModelsForCurrentProvider(): Promise { + const configManager = LLMConfigurationManager.getInstance(); + const provider = configManager.getProvider(); + + try { + const apiKey = LLMProviderRegistry.getProviderApiKey(provider as LLMProvider); + if (!apiKey) { + logger.debug(`No API key for provider ${provider}, skipping model refresh`); + return; + } + + const models = await LLMProviderRegistry.fetchProviderModels(provider as LLMProvider, apiKey); + + // Convert ModelInfo[] to ModelOption[] for UI caching + const modelOptions = models.map(m => ({ + value: m.id, + label: m.name || m.id, + type: provider + })); + + // Store in the configuration manager's cache + configManager.setModelOptions(provider, modelOptions); + logger.info(`Fetched and cached ${modelOptions.length} models for provider ${provider}`); + } catch (error) { + logger.error(`Failed to refresh models for provider ${provider}:`, error); + // Don't clear cache on error - keep existing cached models if available + } } /** @@ -2188,10 +1837,11 @@ export class AIChatPanel extends UI.Panel.Panel { this.#isProcessing = false; this.#selectedAgentType = null; // Reset selected agent type - // Reset file storage session ID to default for new chat + // Reset file storage session ID to a new unique ID for new chat const {FileStorageManager} = await import('../tools/FileStorageManager.js'); - FileStorageManager.getInstance().setSessionId('default'); - logger.info('Reset file storage sessionId to default for new chat'); + const newSessionId = `temp-${crypto.randomUUID()}`; + FileStorageManager.getInstance().setSessionId(newSessionId); + logger.info('Set file storage sessionId for new chat', { sessionId: newSessionId }); // Create new EvaluationAgent for new chat session this.#createEvaluationAgentIfNeeded(); @@ -2503,43 +2153,34 @@ export class AIChatPanel extends UI.Panel.Panel { * Handles changes made in the settings dialog */ async #handleSettingsChanged(): Promise { - // Get the selected provider - const prevProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - const newProvider = localStorage.getItem(PROVIDER_SELECTION_KEY) || 'openai'; - - logger.info(`Provider changing from ${prevProvider} to ${newProvider}`); - - // Load saved settings + const configManager = LLMConfigurationManager.getInstance(); + const newProvider = configManager.getProvider(); + + logger.info(`Settings changed, current provider: ${newProvider}`); + + // Load saved settings (for instance properties) this.#apiKey = localStorage.getItem('ai_chat_api_key'); this.#liteLLMApiKey = localStorage.getItem(LITELLM_API_KEY_STORAGE_KEY); this.#liteLLMEndpoint = localStorage.getItem(LITELLM_ENDPOINT_KEY); - - // Reset model options based on the new provider - if (newProvider === 'litellm') { - // First update model options with empty models - this.#updateModelOptions([], false); - - // Then refresh LiteLLM models - await this.#refreshLiteLLMModels(); - } else if (newProvider === 'groq') { - // For Groq, update model options and refresh models if API key exists - this.#updateModelOptions([], false); - - const groqApiKey = localStorage.getItem('ai_chat_groq_api_key'); - if (groqApiKey) { - await this.#refreshGroqModels(); - } - } else { - // For OpenAI, just update model options with empty LiteLLM models - this.#updateModelOptions([], false); - } - - this.#updateModelSelections(); - - // Validate models after updating selections - this.#validateAndFixModelSelections(); - + + // Fetch models for the current provider + await this.#refreshModelsForCurrentProvider(); + + // Sync local MODEL_OPTIONS with the centralized manager + syncModelOptions(); + + // Use the centralized validation method (single source of truth) + const corrected = configManager.validateAndFixModelSelections(); + + // Update instance properties with corrected values + this.#selectedModel = corrected.main; + this.#miniModel = corrected.mini; + this.#nanoModel = corrected.nano; + + logger.info('Model selections after validation:', corrected); + this.#initializeAgentService(); + // Re-initialize MCP based on latest settings try { await MCPRegistry.init(); @@ -2588,15 +2229,21 @@ export class AIChatPanel extends UI.Panel.Panel { // Check if the current selected model is valid for the new provider const selectedModelOption = MODEL_OPTIONS.find(opt => opt.value === this.#selectedModel); - if (!selectedModelOption || selectedModelOption.type !== currentProvider) { - logger.info(`Selected model ${this.#selectedModel} is not valid for provider ${currentProvider}`); - + if (!this.#selectedModel || !selectedModelOption || selectedModelOption.type !== currentProvider) { + logger.info(`Selected model ${this.#selectedModel} is not valid for provider ${currentProvider}, selecting default`); + // Try to use provider default main model first if (providerDefaults.main && MODEL_OPTIONS.some(option => option.value === providerDefaults.main)) { this.#selectedModel = providerDefaults.main; + logger.info(`Set main model to provider default: ${providerDefaults.main}`); } else if (MODEL_OPTIONS.length > 0) { // Otherwise, use the first available model this.#selectedModel = MODEL_OPTIONS[0].value; + logger.info(`Set main model to first available: ${this.#selectedModel}`); + } else { + // No models available + this.#selectedModel = ''; + logger.warn(`No models available for provider ${currentProvider}`); } localStorage.setItem(MODEL_SELECTION_KEY, this.#selectedModel); } diff --git a/front_end/panels/ai_chat/ui/ChatView.ts b/front_end/panels/ai_chat/ui/ChatView.ts index a06d980e13..87c051f549 100644 --- a/front_end/panels/ai_chat/ui/ChatView.ts +++ b/front_end/panels/ai_chat/ui/ChatView.ts @@ -21,6 +21,7 @@ import './message/MessageList.js'; import { renderUserMessage } from './message/UserMessage.js'; import { renderModelMessage } from './message/ModelMessage.js'; import { renderToolResultMessage } from './message/ToolResultMessage.js'; +import { renderApprovalRequestMessage } from './message/ApprovalRequestMessage.js'; import './version/VersionBanner.js'; import { renderGlobalActionsRow } from './message/GlobalActionsRow.js'; import { renderStructuredResponse as renderStructuredResponseUI } from './message/StructuredResponseRender.js'; @@ -34,7 +35,7 @@ import './TodoListDisplay.js'; import './FileListDisplay.js'; // Shared chat types -import type { ChatMessage, ModelChatMessage, ToolResultMessage, AgentSessionMessage, ImageInputData } from '../models/ChatTypes.js'; +import type { ChatMessage, ModelChatMessage, ToolResultMessage, AgentSessionMessage, ImageInputData, ApprovalRequestMessage as ApprovalRequestMessageType } from '../models/ChatTypes.js'; import { ChatMessageEntity, State } from '../models/ChatTypes.js'; const logger = createLogger('ChatView'); @@ -747,6 +748,11 @@ export class ChatView extends HTMLElement { `; } + case ChatMessageEntity.APPROVAL_REQUEST: + { + const approvalMessage = message as ApprovalRequestMessageType; + return renderApprovalRequestMessage(approvalMessage); + } default: // Should not happen, but render a fallback return html`
Unknown message type
`; diff --git a/front_end/panels/ai_chat/ui/FileContentViewer.ts b/front_end/panels/ai_chat/ui/FileContentViewer.ts index 613a805bbb..fcfe00e80d 100644 --- a/front_end/panels/ai_chat/ui/FileContentViewer.ts +++ b/front_end/panels/ai_chat/ui/FileContentViewer.ts @@ -5,11 +5,30 @@ import { createLogger } from '../core/Logger.js'; import type { FileSummary } from '../tools/FileStorageManager.js'; import * as Marked from '../../../third_party/marked/marked.js'; +import * as SDK from '../../../core/sdk/sdk.js'; const logger = createLogger('FileContentViewer'); type FileType = 'code' | 'json' | 'markdown' | 'text' | 'html' | 'css'; +/** + * Options for FileContentViewer + */ +export interface FileContentViewerOptions { + /** Whether the content can be edited */ + editable?: boolean; + /** Callback when content is saved (only used if editable is true) */ + onSave?: (newContent: string) => Promise; +} + +/** + * Result returned from FileContentViewer.show() + */ +export interface FileContentViewerResult { + /** The unique ID of the webapp iframe for CDP operations */ + webappId: string; +} + /** * FileContentViewer - Full-screen file viewer using RenderWebAppTool * @@ -17,20 +36,24 @@ type FileType = 'code' | 'json' | 'markdown' | 'text' | 'html' | 'css'; * - Syntax-aware formatting * - Copy and download functionality * - Clean, modern design + * - Optional edit mode with save callback */ export class FileContentViewer { /** * Display file content in full-screen view + * @returns The webappId for CDP operations (e.g., polling for saves) */ - static async show(file: FileSummary, content: string): Promise { + static async show(file: FileSummary, content: string, options?: FileContentViewerOptions): Promise { try { // Import RenderWebAppTool const { RenderWebAppTool } = await import('../tools/RenderWebAppTool.js'); + const editable = options?.editable ?? false; + // Build viewer components - const viewerHTML = await FileContentViewer.buildHTML(file, content); - const viewerCSS = FileContentViewer.buildCSS(); - const viewerJS = FileContentViewer.buildJS(file.fileName, content); + const viewerHTML = await FileContentViewer.buildHTML(file, content, editable); + const viewerCSS = FileContentViewer.buildCSS(editable); + const viewerJS = FileContentViewer.buildJS(file.fileName, content, editable); // Use RenderWebAppTool to display full-screen viewer const tool = new RenderWebAppTool(); @@ -43,15 +66,64 @@ export class FileContentViewer { if ('error' in result) { logger.error('Failed to open file viewer:', result.error); - } else { - logger.info('File viewer opened successfully', { fileName: file.fileName }); + return null; } + + logger.info('File viewer opened successfully', { fileName: file.fileName, editable, webappId: result.webappId }); + return { webappId: result.webappId }; } catch (error) { logger.error('Error opening file viewer:', error); throw error; } } + /** + * Poll for saved content from an editable viewer iframe + * Uses CDP to read data attributes from the iframe's body + * @param webappId The ID of the webapp iframe + * @returns The saved content if available, null otherwise + */ + static async checkForSavedContent(webappId: string): Promise { + try { + const target = SDK.TargetManager.TargetManager.instance().primaryPageTarget(); + if (!target) { + return null; + } + + const runtimeAgent = target.runtimeAgent(); + + // Check for data-memory-saved attribute and retrieve content + const result = await runtimeAgent.invoke_evaluate({ + expression: ` + (() => { + const iframe = document.getElementById('${webappId}'); + if (!iframe) return null; + try { + const doc = iframe.contentDocument || iframe.contentWindow.document; + if (doc.body.getAttribute('data-memory-saved') === 'true') { + const encoded = doc.body.getAttribute('data-memory-content'); + if (encoded) { + // Decode base64 with Unicode support + return decodeURIComponent(escape(atob(encoded))); + } + } + } catch (e) { + // Cross-origin access denied + return null; + } + return null; + })() + `, + returnByValue: true, + }); + + return result.result?.value || null; + } catch (error) { + logger.error('Error checking for saved content:', error); + return null; + } + } + /** * Detect file type based on extension */ @@ -240,7 +312,7 @@ export class FileContentViewer { /** * Build HTML structure */ - private static async buildHTML(file: FileSummary, content: string): Promise { + private static async buildHTML(file: FileSummary, content: string, editable: boolean = false): Promise { const fileType = FileContentViewer.detectFileType(file.fileName); const icon = FileContentViewer.getFileIcon(fileType); const typeLabel = FileContentViewer.getFileTypeLabel(fileType); @@ -257,7 +329,7 @@ export class FileContentViewer { // For markdown: hidden div with original source + visible rendered HTML contentHTML = ` -
${sanitizedHTML}
+
${sanitizedHTML}
`; } else { // For code files: use escapeHTML helper and add id @@ -265,8 +337,12 @@ export class FileContentViewer { contentHTML = `
${safeContent}
`; } + // Content click handler for editable mode + const contentClickHandler = editable ? 'onclick="enterEditMode()"' : ''; + const contentCursor = editable ? 'cursor: pointer;' : ''; + return ` -
+
@@ -298,11 +374,18 @@ export class FileContentViewer { Download +
-
+
${contentHTML}
@@ -312,7 +395,76 @@ export class FileContentViewer { /** * Build CSS styles */ - private static buildCSS(): string { + private static buildCSS(editable: boolean = false): string { + const editStyles = editable ? ` + /* Click-to-edit styles */ + .content-container { + transition: background-color 0.2s ease; + } + + .content-container:hover .markdown-content, + .content-container:hover .file-content { + background: rgba(25, 118, 210, 0.04); + } + + /* Saved indicator */ + .save-indicator { + display: flex; + align-items: center; + gap: 4px; + padding: 8px 12px; + font-size: 13px; + font-weight: 500; + color: #4caf50; + opacity: 0; + transition: opacity 0.3s ease; + } + + .save-indicator.visible { + opacity: 1; + } + + .edit-textarea { + width: calc(100% - 48px); + height: calc(100vh - 140px); + padding: 32px; + font-family: 'SF Mono', 'Monaco', 'Menlo', 'Consolas', 'Courier New', monospace; + font-size: 14px; + line-height: 1.6; + color: #202124; + border: 2px solid #1976d2; + resize: none; + background: rgba(255, 255, 255, 0.98); + margin: 24px; + border-radius: 16px; + box-sizing: border-box; + outline: none; + } + + .edit-textarea:focus { + border-color: #1565c0; + box-shadow: 0 0 0 3px rgba(25, 118, 210, 0.2); + } + + @media (prefers-color-scheme: dark) { + .content-container:hover .markdown-content, + .content-container:hover .file-content { + background: rgba(100, 181, 246, 0.08); + } + + .edit-textarea { + background: rgba(41, 42, 45, 0.98); + color: #e8eaed; + border-color: #1976d2; + } + + .edit-textarea:focus { + border-color: #64b5f6; + box-shadow: 0 0 0 3px rgba(100, 181, 246, 0.2); + } + } + ` : ''; + return ` * { margin: 0; @@ -451,6 +603,19 @@ export class FileContentViewer { box-shadow: 0 6px 16px rgba(25, 118, 210, 0.4); } + .close-btn { + background: rgba(244, 67, 54, 0.1); + border-color: rgba(244, 67, 54, 0.3); + color: #f44336; + } + + .close-btn:hover { + background: #f44336; + border-color: #f44336; + color: white; + box-shadow: 0 4px 12px rgba(244, 67, 54, 0.3); + } + /* Content */ .content-container { overflow: auto; @@ -681,6 +846,18 @@ export class FileContentViewer { background: linear-gradient(135deg, #1565c0, #0d47a1); } + .close-btn { + background: rgba(244, 67, 54, 0.15); + border-color: rgba(244, 67, 54, 0.4); + color: #ef5350; + } + + .close-btn:hover { + background: #f44336; + border-color: #f44336; + color: white; + } + .content-container { background: transparent; } @@ -776,13 +953,15 @@ export class FileContentViewer { background: rgba(255, 255, 255, 0.3); } } + + ${editStyles} `; } /** * Build JavaScript functionality */ - private static buildJS(fileName: string, content: string): string { + private static buildJS(fileName: string, content: string, editable: boolean = false): string { return ` const FILE_NAME = ${JSON.stringify(fileName)}; @@ -794,7 +973,9 @@ export class FileContentViewer { const originalText = textSpan.textContent; try { - const content = document.getElementById('file-content').textContent; + // If in edit mode, use textarea value; otherwise use file-content + const editTextarea = document.getElementById('edit-textarea'); + const content = editTextarea ? editTextarea.value : document.getElementById('file-content').textContent; // Try modern Clipboard API first if (navigator.clipboard && navigator.clipboard.writeText) { @@ -838,7 +1019,9 @@ export class FileContentViewer { window.downloadFile = function(event) { event.preventDefault(); try { - const content = document.getElementById('file-content').textContent; + // If in edit mode, use textarea value; otherwise use file-content + const editTextarea = document.getElementById('edit-textarea'); + const content = editTextarea ? editTextarea.value : document.getElementById('file-content').textContent; const blob = new Blob([content], { type: 'text/plain;charset=utf-8' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); @@ -854,9 +1037,83 @@ export class FileContentViewer { } }; + window.closeViewer = function(event) { + event.preventDefault(); + // Remove the iframe from the page + const iframe = window.frameElement; + if (iframe) { + iframe.remove(); + } + }; + // Prevent default drag and drop document.addEventListener('dragover', (e) => e.preventDefault()); document.addEventListener('drop', (e) => e.preventDefault()); - `; + ` + (editable ? ` + // Click-to-edit functionality with auto-save + let isEditing = false; + + window.enterEditMode = function() { + if (isEditing) return; // Already in edit mode + isEditing = true; + + const contentMain = document.getElementById('content-main'); + const markdownView = document.getElementById('markdown-view'); + const fileContent = document.getElementById('file-content'); + const originalContent = fileContent ? fileContent.textContent : ''; + + // Hide rendered markdown/code view + if (markdownView) { + markdownView.style.display = 'none'; + } + if (fileContent && fileContent.tagName === 'PRE') { + fileContent.style.display = 'none'; + } + + // Create and show textarea + const textarea = document.createElement('textarea'); + textarea.id = 'edit-textarea'; + textarea.className = 'edit-textarea'; + textarea.value = originalContent; + contentMain.appendChild(textarea); + textarea.focus(); + + // Remove click handler and cursor style + contentMain.onclick = null; + contentMain.style.cursor = 'default'; + + // Auto-save on input (debounced 500ms) + let saveTimeout = null; + textarea.addEventListener('input', function() { + if (saveTimeout) clearTimeout(saveTimeout); + saveTimeout = setTimeout(() => { + try { + // Store in data attribute for CDP retrieval + document.body.setAttribute('data-memory-saved', 'true'); + document.body.setAttribute('data-memory-content', btoa(unescape(encodeURIComponent(textarea.value)))); + // Show saved indicator + showSavedIndicator(); + } catch (error) { + console.error('Auto-save failed:', error); + } + }, 500); + }); + }; + + // Show "Saved" indicator in header + function showSavedIndicator() { + let indicator = document.getElementById('save-indicator'); + if (!indicator) { + indicator = document.createElement('div'); + indicator.id = 'save-indicator'; + indicator.className = 'save-indicator'; + indicator.innerHTML = '✓ Saved'; + const headerActions = document.querySelector('.header-actions'); + if (headerActions) headerActions.prepend(indicator); + } + indicator.classList.add('visible'); + setTimeout(() => indicator.classList.remove('visible'), 1500); + } + ` : ''); } } diff --git a/front_end/panels/ai_chat/ui/LiveAgentSessionComponent.ts b/front_end/panels/ai_chat/ui/LiveAgentSessionComponent.ts index 92b5ecc1d5..3902d59747 100644 --- a/front_end/panels/ai_chat/ui/LiveAgentSessionComponent.ts +++ b/front_end/panels/ai_chat/ui/LiveAgentSessionComponent.ts @@ -4,8 +4,9 @@ import * as Lit from '../../../ui/lit/lit.js'; import { createLogger } from '../core/Logger.js'; -import type { AgentSession, AgentMessage } from '../agent_framework/AgentSessionTypes.js'; +import type { AgentSession, AgentMessage, ApprovalRequestContent } from '../agent_framework/AgentSessionTypes.js'; import { getAgentUIConfig } from '../agent_framework/AgentSessionTypes.js'; +import { getGuardrailMiddleware } from '../guardrails/index.js'; import { ToolCallComponent } from './ToolCallComponent.js'; import { AgentSessionHeaderComponent } from './AgentSessionHeaderComponent.js'; import { ToolDescriptionFormatter } from './ToolDescriptionFormatter.js'; @@ -385,6 +386,107 @@ export class LiveAgentSessionComponent extends HTMLElement { margin-bottom: 8px; color: var(--sys-color-on-surface); } + + /* Approval request styles */ + .tool-status-marker.pending { + color: #f5a623; + animation: dotPulse 1.5s ease-in-out infinite; + } + .tool-status-marker.approved { + color: var(--sys-color-green-bright); + } + .tool-status-marker.rejected { + color: var(--sys-color-error); + } + + .risk-badge { + padding: 1px 4px; + border-radius: 3px; + font-size: 9px; + font-weight: 600; + text-transform: uppercase; + margin-right: 4px; + } + .risk-none { background: #e8f5e9; color: #2e7d32; } + .risk-low { background: #e3f2fd; color: #1565c0; } + .risk-medium { background: #fff3e0; color: #e65100; } + .risk-high { background: #ffebee; color: #c62828; } + .risk-critical { background: #f3e5f5; color: #6a1b9a; } + + .approval-details { + margin-top: 4px; + padding-left: 16px; + } + + .approval-description { + color: var(--sys-color-on-surface-variant); + font-size: 12px; + margin-bottom: 4px; + } + + .approval-tool-info { + margin: 4px 0; + padding: 4px 6px; + background: var(--sys-color-surface); + border-radius: 4px; + border: 1px solid var(--sys-color-divider); + } + + .tool-args-preview { + font-family: var(--monospace-font-family); + font-size: 10px; + max-height: 60px; + overflow-y: auto; + white-space: pre-wrap; + word-break: break-all; + color: var(--sys-color-on-surface-variant); + margin: 0; + } + + .approval-actions { + display: flex; + gap: 8px; + margin-top: 6px; + align-items: center; + } + + .approve-btn, .reject-btn { + padding: 3px 10px; + border-radius: 4px; + font-size: 11px; + font-weight: 500; + cursor: pointer; + transition: all 0.15s ease; + } + + .approve-btn { + background: var(--sys-color-green-bright); + color: white; + border: none; + } + + .approve-btn:hover { + filter: brightness(0.9); + } + + .reject-btn { + background: transparent; + color: var(--sys-color-error); + border: 1px solid var(--sys-color-error); + } + + .reject-btn:hover { + background: var(--sys-color-error); + color: white; + } + + .feedback-display { + margin-top: 2px; + padding-left: 16px; + font-size: 11px; + font-style: italic; + color: var(--sys-color-on-surface-variant); + }
${reasoningHtml} @@ -413,6 +515,28 @@ export class LiveAgentSessionComponent extends HTMLElement { }); } + // Wire approval button clicks to GuardrailMiddleware + const approveButtons = this.shadow.querySelectorAll('.approve-btn[data-approval-id]'); + const rejectButtons = this.shadow.querySelectorAll('.reject-btn[data-approval-id]'); + + approveButtons.forEach(btn => { + btn.addEventListener('click', () => { + const approvalId = btn.getAttribute('data-approval-id'); + if (approvalId) { + getGuardrailMiddleware().approve(approvalId); + } + }); + }); + + rejectButtons.forEach(btn => { + btn.addEventListener('click', () => { + const approvalId = btn.getAttribute('data-approval-id'); + if (approvalId) { + getGuardrailMiddleware().reject(approvalId); + } + }); + }); + // No expand/collapse listeners for reasoning (simplified) // Render nested sessions as real elements with live interactivity @@ -480,9 +604,54 @@ export class LiveAgentSessionComponent extends HTMLElement { if (rc?.toolCallId) resultMap.set(rc.toolCallId, rc); } - // Walk messages in order and render tool calls and inline handoff anchors + // Walk messages in order and render tool calls, approval requests, and inline handoff anchors for (const m of (this._session?.messages || [])) { - if (m.type === 'tool_call') { + if (m.type === 'approval_request') { + // Render approval request inline in timeline + const content = m.content as ApprovalRequestContent; + const status = content.status; + const statusMarkerClass = status === 'pending' ? 'pending' : + status === 'approved' ? 'approved' : 'rejected'; + const statusLabel = status === 'pending' ? 'Approval Required' : + status === 'approved' ? 'Approved' : 'Rejected'; + const riskClass = `risk-${content.riskLevel}`; + const riskLabel = content.riskLevel === 'none' ? 'Safe' : + content.riskLevel === 'low' ? 'Low' : + content.riskLevel === 'medium' ? 'Med' : + content.riskLevel === 'high' ? 'High' : 'Crit'; + const icon = ToolDescriptionFormatter.getToolIcon(content.toolName); + const toolNameDisplay = ToolDescriptionFormatter.formatToolName(content.toolName); + + html += ` +
+
+
+ ${statusLabel} + + ${status === 'pending' ? `${riskLabel}` : ''} + ${icon} ${toolNameDisplay} + +
+ +
+ ${status === 'pending' ? ` +
+
${content.description}
+
+
${JSON.stringify(content.toolArgs, null, 2)}
+
+
+ + +
+
+ ` : ''} + ${status === 'rejected' && content.feedback ? ` + + ` : ''} +
+ `; + } else if (m.type === 'tool_call') { const toolContent = m.content as any; const toolName = toolContent.toolName; const toolArgs = toolContent.toolArgs || {}; diff --git a/front_end/panels/ai_chat/ui/OnboardingDialog.ts b/front_end/panels/ai_chat/ui/OnboardingDialog.ts new file mode 100644 index 0000000000..ad919c9624 --- /dev/null +++ b/front_end/panels/ai_chat/ui/OnboardingDialog.ts @@ -0,0 +1,849 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import * as UI from '../../../ui/legacy/legacy.js'; +import * as Geometry from '../../../models/geometry/geometry.js'; +import { applyOnboardingStyles } from './onboardingStyles.js'; +import { LLMProviderRegistry } from '../LLM/LLMProviderRegistry.js'; +import { OpenRouterOAuth } from '../auth/OpenRouterOAuth.js'; +import { createLogger } from '../core/Logger.js'; + +const logger = createLogger('OnboardingDialog'); + +const ONBOARDING_COMPLETE_KEY = 'ai_chat_onboarding_complete'; +const SETUP_SKIPPED_KEY = 'ai_chat_setup_skipped'; + +/** + * Provider information for the onboarding wizard + */ +interface ProviderInfo { + id: string; + name: string; + description: string; + getKeyUrl: string; +} + +const PROVIDERS: ProviderInfo[] = [ + { + id: 'openai', + name: 'OpenAI', + description: 'GPT-4.1 and latest OpenAI models', + getKeyUrl: 'https://platform.openai.com/api-keys', + }, + { + id: 'anthropic', + name: 'Anthropic', + description: 'Claude models with extended thinking', + getKeyUrl: 'https://console.anthropic.com/settings/keys', + }, + { + id: 'googleai', + name: 'Google AI', + description: 'Gemini models for multimodal tasks', + getKeyUrl: 'https://aistudio.google.com/apikey', + }, + { + id: 'groq', + name: 'Groq', + description: 'Ultra-fast inference with Groq hardware', + getKeyUrl: 'https://console.groq.com/keys', + }, + { + id: 'cerebras', + name: 'Cerebras', + description: 'High-performance AI inference', + getKeyUrl: 'https://cloud.cerebras.ai/', + }, + { + id: 'openrouter', + name: 'OpenRouter', + description: 'Access 100+ models with unified API', + getKeyUrl: 'https://openrouter.ai/keys', + }, +]; + +/** + * Feature information for the overview step + */ +interface FeatureInfo { + title: string; + description: string; + icon: string; +} + +const FEATURES: FeatureInfo[] = [ + { + title: 'Multi-Agent Framework', + description: 'Specialized agents automatically handle different tasks like browsing, extraction, and analysis', + icon: '🤝', + }, + { + title: 'Web Automation', + description: 'Click, type, navigate, and interact with any webpage through natural language', + icon: '🌐', + }, + { + title: 'Data Extraction', + description: 'Extract structured data from websites using schemas or natural language descriptions', + icon: '📊', + }, + { + title: 'External Tools (MCP)', + description: 'Connect external tools and data sources via Model Context Protocol', + icon: '🔌', + }, + { + title: 'Conversation History', + description: 'Your conversations persist between sessions for easy reference', + icon: '💬', + }, +]; + +type OnboardingStep = 'welcome' | 'provider' | 'apikey' | 'features' | 'ready'; + +const STEPS: OnboardingStep[] = ['welcome', 'provider', 'apikey', 'features', 'ready']; + +/** + * Onboarding wizard dialog for first-time users + */ +export class OnboardingDialog { + private dialog: UI.Dialog.Dialog | null = null; + private currentStep: OnboardingStep = 'welcome'; + private selectedProvider: ProviderInfo | null = null; + private apiKey: string = ''; + private onComplete: (() => void) | null = null; + + // DOM elements + private contentElement: HTMLElement | null = null; + private stepIndicators: HTMLElement[] = []; + private apiKeyStatusDiv: HTMLElement | null = null; + + // OAuth event handler for cleanup + private handleOAuthSuccess: (() => void) | null = null; + + /** + * Check if onboarding should be shown + */ + static shouldShowOnboarding(): boolean { + return localStorage.getItem(ONBOARDING_COMPLETE_KEY) !== 'true'; + } + + /** + * Check if user skipped setup (for showing banner) + */ + static wasSetupSkipped(): boolean { + return localStorage.getItem(SETUP_SKIPPED_KEY) === 'true'; + } + + /** + * Clear the skipped flag (when user completes setup) + */ + static clearSkippedFlag(): void { + localStorage.removeItem(SETUP_SKIPPED_KEY); + } + + /** + * Show the onboarding dialog + */ + static show(onComplete?: () => void): void { + const instance = new OnboardingDialog(); + instance.onComplete = onComplete || null; + instance.showDialog(); + } + + private showDialog(): void { + this.dialog = new UI.Dialog.Dialog(); + this.dialog.setDimmed(true); + this.dialog.setSizeBehavior(UI.GlassPane.SizeBehavior.SET_EXACT_SIZE); + this.dialog.setMaxContentSize(new Geometry.Size(window.innerWidth, window.innerHeight)); + + const container = document.createElement('div'); + container.className = 'onboarding-dialog'; + this.dialog.contentElement.appendChild(container); + + applyOnboardingStyles(container); + this.buildDialog(container); + + // Setup OAuth success listener for OpenRouter + this.handleOAuthSuccess = () => { + logger.info('OAuth success received, completing onboarding'); + this.complete(); + }; + window.addEventListener('openrouter-oauth-success', this.handleOAuthSuccess); + + this.dialog.show(); + } + + private buildDialog(container: HTMLElement): void { + const dialogContainer = document.createElement('div'); + dialogContainer.className = 'onboarding-container'; + container.appendChild(dialogContainer); + + // Step indicators + const indicatorsContainer = document.createElement('div'); + indicatorsContainer.className = 'step-indicators'; + dialogContainer.appendChild(indicatorsContainer); + + this.stepIndicators = []; + for (let i = 0; i < STEPS.length; i++) { + const indicator = document.createElement('div'); + indicator.className = 'step-indicator'; + indicatorsContainer.appendChild(indicator); + this.stepIndicators.push(indicator); + } + + // Content area + this.contentElement = document.createElement('div'); + this.contentElement.className = 'onboarding-content'; + dialogContainer.appendChild(this.contentElement); + + // Footer + const footer = document.createElement('div'); + footer.className = 'onboarding-footer'; + dialogContainer.appendChild(footer); + + const footerLeft = document.createElement('div'); + footerLeft.className = 'footer-left'; + footer.appendChild(footerLeft); + + const footerRight = document.createElement('div'); + footerRight.className = 'footer-right'; + footer.appendChild(footerRight); + + // Back button + const backButton = document.createElement('button'); + backButton.className = 'btn btn-secondary'; + backButton.textContent = 'Back'; + backButton.style.display = 'none'; + backButton.addEventListener('click', () => this.goBack()); + footerLeft.appendChild(backButton); + + // Next/Done button + const nextButton = document.createElement('button'); + nextButton.className = 'btn btn-primary'; + nextButton.textContent = 'Get Started'; + nextButton.addEventListener('click', () => this.goNext()); + footerRight.appendChild(nextButton); + + // Skip button (in footer, shown on all steps except ready) + const skipButton = document.createElement('button'); + skipButton.className = 'btn btn-text'; + skipButton.textContent = 'Skip'; + skipButton.addEventListener('click', () => this.skipSetup()); + footerLeft.appendChild(skipButton); + + // Store references + (this as any).backButton = backButton; + (this as any).nextButton = nextButton; + (this as any).skipButton = skipButton; + + this.renderCurrentStep(); + } + + private updateStepIndicators(): void { + const currentIndex = STEPS.indexOf(this.currentStep); + this.stepIndicators.forEach((indicator, index) => { + indicator.classList.remove('active', 'completed'); + if (index < currentIndex) { + indicator.classList.add('completed'); + } else if (index === currentIndex) { + indicator.classList.add('active'); + } + }); + } + + private updateButtons(): void { + const backButton = (this as any).backButton as HTMLButtonElement; + const nextButton = (this as any).nextButton as HTMLButtonElement; + const skipButton = (this as any).skipButton as HTMLButtonElement; + + // Show/hide back button + backButton.style.display = this.currentStep === 'welcome' ? 'none' : 'block'; + + // Show/hide skip button (hide on ready step) + skipButton.style.display = this.currentStep === 'ready' ? 'none' : 'block'; + + // Update next button text + switch (this.currentStep) { + case 'welcome': + nextButton.textContent = 'Get Started'; + nextButton.disabled = false; + break; + case 'provider': + nextButton.textContent = 'Next'; + nextButton.disabled = !this.selectedProvider; + break; + case 'apikey': + nextButton.textContent = 'Next'; + nextButton.disabled = false; // Can skip + break; + case 'features': + nextButton.textContent = 'Next'; + nextButton.disabled = false; + break; + case 'ready': + nextButton.textContent = 'Start Chatting'; + nextButton.disabled = false; + break; + } + } + + private renderCurrentStep(): void { + if (!this.contentElement) return; + this.contentElement.innerHTML = ''; + this.updateStepIndicators(); + this.updateButtons(); + + switch (this.currentStep) { + case 'welcome': + this.renderWelcomeStep(); + break; + case 'provider': + this.renderProviderStep(); + break; + case 'apikey': + this.renderApiKeyStep(); + break; + case 'features': + this.renderFeaturesStep(); + break; + case 'ready': + this.renderReadyStep(); + break; + } + } + + private renderWelcomeStep(): void { + const content = this.contentElement!; + + // Browser Operator logo + const logoContainer = document.createElement('div'); + logoContainer.className = 'welcome-icon'; + const logo = document.createElement('img'); + logo.src = '/bundled/Images/browser-operator-logo.png'; + logo.alt = 'Browser Operator'; + logo.style.cssText = 'width: 64px; height: 64px; border-radius: 12px;'; + logoContainer.appendChild(logo); + content.appendChild(logoContainer); + + const title = document.createElement('h2'); + title.className = 'step-title'; + title.textContent = 'Welcome to Browser Operator'; + content.appendChild(title); + + const description = document.createElement('p'); + description.className = 'step-description'; + description.textContent = 'Your intelligent partner for research, analysis, and automation.'; + content.appendChild(description); + + // Demo gif with link to docs + const demoContainer = document.createElement('div'); + demoContainer.className = 'video-placeholder'; + demoContainer.style.cssText = 'border: none; background: transparent; display: flex; flex-direction: column; align-items: center;'; + + const demoGif = document.createElement('img'); + demoGif.src = '/bundled/Images/demo.gif'; + demoGif.alt = 'Browser Operator Demo'; + demoGif.style.cssText = 'width: 100%; max-width: 400px; border-radius: 8px;'; + demoContainer.appendChild(demoGif); + + const clickLink = document.createElement('a'); + clickLink.href = 'https://docs.browseroperator.io/getting-started/'; + clickLink.target = '_top'; + clickLink.textContent = 'View the getting started guide →'; + clickLink.style.cssText = 'display: block; margin-top: 12px; color: var(--color-primary); text-decoration: none; font-size: 13px;'; + clickLink.addEventListener('mouseenter', () => { clickLink.style.textDecoration = 'underline'; }); + clickLink.addEventListener('mouseleave', () => { clickLink.style.textDecoration = 'none'; }); + demoContainer.appendChild(clickLink); + + content.appendChild(demoContainer); + } + + private renderProviderStep(): void { + const content = this.contentElement!; + + const title = document.createElement('h2'); + title.className = 'step-title'; + title.textContent = 'Choose Your AI Provider'; + content.appendChild(title); + + const description = document.createElement('p'); + description.className = 'step-description'; + description.textContent = 'Select the AI provider you\'d like to use. You can change this later in Settings.'; + content.appendChild(description); + + const grid = document.createElement('div'); + grid.className = 'provider-grid'; + content.appendChild(grid); + + for (const provider of PROVIDERS) { + const card = document.createElement('div'); + card.className = 'provider-card'; + if (this.selectedProvider?.id === provider.id) { + card.classList.add('selected'); + } + + card.addEventListener('click', () => { + this.selectedProvider = provider; + // Update selection visually + grid.querySelectorAll('.provider-card').forEach(c => c.classList.remove('selected')); + card.classList.add('selected'); + // Auto-advance to next step + this.goNext(); + }); + + const header = document.createElement('div'); + header.className = 'provider-card-header'; + + const name = document.createElement('span'); + name.className = 'provider-name'; + name.textContent = provider.name; + header.appendChild(name); + + card.appendChild(header); + + const desc = document.createElement('div'); + desc.className = 'provider-description'; + desc.textContent = provider.description; + card.appendChild(desc); + + grid.appendChild(card); + } + } + + private renderApiKeyStep(): void { + const content = this.contentElement!; + const provider = this.selectedProvider!; + + const title = document.createElement('h2'); + title.className = 'step-title'; + title.textContent = `Set Up ${provider.name}`; + content.appendChild(title); + + const description = document.createElement('p'); + description.className = 'step-description'; + + // OpenRouter: OAuth only (no API key input) + if (provider.id === 'openrouter') { + description.textContent = 'Sign in with your OpenRouter account to get started.'; + content.appendChild(description); + + const oauthButton = document.createElement('button'); + oauthButton.className = 'btn btn-primary'; + oauthButton.textContent = 'Sign in with OpenRouter'; + oauthButton.style.cssText = 'background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); width: 100%; max-width: 300px; margin-top: 16px;'; + oauthButton.addEventListener('click', async () => { + oauthButton.disabled = true; + oauthButton.textContent = 'Redirecting to OpenRouter...'; + try { + await OpenRouterOAuth.startAuthFlow(); + // Success is handled by the event listener in showDialog + } catch (error) { + oauthButton.disabled = false; + oauthButton.textContent = 'Sign in with OpenRouter'; + logger.error('OpenRouter OAuth failed:', error); + } + }); + content.appendChild(oauthButton); + + // Link to getting started guide (for OpenRouter too) + const guideLink = document.createElement('a'); + guideLink.href = 'https://docs.browseroperator.io/getting-started/'; + guideLink.target = '_top'; + guideLink.textContent = 'View the getting started guide →'; + guideLink.style.cssText = 'display: block; margin-top: 24px; color: var(--color-primary); text-decoration: none; font-size: 13px; text-align: center;'; + guideLink.addEventListener('mouseenter', () => { guideLink.style.textDecoration = 'underline'; }); + guideLink.addEventListener('mouseleave', () => { guideLink.style.textDecoration = 'none'; }); + content.appendChild(guideLink); + + return; // Don't show API key form for OpenRouter + } + + // Other providers: show API key form + description.textContent = `Enter your ${provider.name} API key to get started. Your key is stored locally and never sent to our servers.`; + content.appendChild(description); + + const form = document.createElement('div'); + form.className = 'api-key-form'; + content.appendChild(form); + + // API Key input + const formGroup = document.createElement('div'); + formGroup.className = 'form-group'; + form.appendChild(formGroup); + + const label = document.createElement('label'); + label.className = 'form-label'; + label.textContent = 'API Key'; + formGroup.appendChild(label); + + const input = document.createElement('input'); + input.className = 'form-input'; + input.type = 'password'; + input.placeholder = 'Enter your API key'; + input.value = this.apiKey; + input.addEventListener('input', (e) => { + this.apiKey = (e.target as HTMLInputElement).value; + }); + formGroup.appendChild(input); + + const hint = document.createElement('div'); + hint.className = 'form-hint'; + hint.innerHTML = `Don't have an API key? Get one here`; + formGroup.appendChild(hint); + + // Test button + const testButton = document.createElement('button'); + testButton.className = 'test-button'; + testButton.textContent = 'Test Connection'; + testButton.addEventListener('click', () => this.testConnection(testButton, statusDiv)); + form.appendChild(testButton); + + // Status message (also used for inline errors from goNext) + const statusDiv = document.createElement('div'); + statusDiv.className = 'test-status'; + form.appendChild(statusDiv); + this.apiKeyStatusDiv = statusDiv; + + // Link to getting started guide + const guideLink = document.createElement('a'); + guideLink.href = 'https://docs.browseroperator.io/getting-started/'; + guideLink.target = '_top'; + guideLink.textContent = 'View the getting started guide →'; + guideLink.style.cssText = 'display: block; margin-top: 24px; color: var(--color-primary); text-decoration: none; font-size: 13px; text-align: center;'; + guideLink.addEventListener('mouseenter', () => { guideLink.style.textDecoration = 'underline'; }); + guideLink.addEventListener('mouseleave', () => { guideLink.style.textDecoration = 'none'; }); + content.appendChild(guideLink); + } + + private async testConnection(button: HTMLButtonElement, statusDiv: HTMLElement): Promise { + if (!this.apiKey.trim()) { + statusDiv.className = 'test-status visible error'; + statusDiv.textContent = 'Please enter an API key'; + return; + } + + button.disabled = true; + button.textContent = 'Testing...'; + statusDiv.className = 'test-status'; + + try { + // Save API key temporarily for testing + const provider = this.selectedProvider!; + LLMProviderRegistry.saveProviderApiKey(provider.id as any, this.apiKey); + + // Try to fetch models to verify the key works + const providerInstance = LLMProviderRegistry.getProvider(provider.id as any); + if (providerInstance && typeof (providerInstance as any).fetchModels === 'function') { + await (providerInstance as any).fetchModels(this.apiKey); + } + + statusDiv.className = 'test-status visible success'; + statusDiv.textContent = 'Connection successful! Your API key is valid.'; + logger.info(`API key validated for ${provider.id}`); + } catch (error) { + statusDiv.className = 'test-status visible error'; + statusDiv.textContent = `Connection failed: ${error instanceof Error ? error.message : 'Unknown error'}`; + logger.error('API key validation failed:', error); + } finally { + button.disabled = false; + button.textContent = 'Test Connection'; + } + } + + private renderFeaturesStep(): void { + const content = this.contentElement!; + + const title = document.createElement('h2'); + title.className = 'step-title'; + title.textContent = 'What You Can Do'; + content.appendChild(title); + + const description = document.createElement('p'); + description.className = 'step-description'; + description.textContent = 'Browser Operator Agent comes packed with powerful features to help you with daily tasks from the web.'; + content.appendChild(description); + + const list = document.createElement('div'); + list.className = 'feature-list'; + content.appendChild(list); + + for (const feature of FEATURES) { + const item = document.createElement('div'); + item.className = 'feature-item'; + + const icon = document.createElement('span'); + icon.className = 'feature-icon'; + icon.textContent = feature.icon; + item.appendChild(icon); + + const contentDiv = document.createElement('div'); + contentDiv.className = 'feature-content'; + + const featureTitle = document.createElement('h4'); + featureTitle.className = 'feature-title'; + featureTitle.textContent = feature.title; + contentDiv.appendChild(featureTitle); + + const featureDesc = document.createElement('p'); + featureDesc.className = 'feature-description'; + featureDesc.textContent = feature.description; + contentDiv.appendChild(featureDesc); + + item.appendChild(contentDiv); + list.appendChild(item); + } + } + + private renderReadyStep(): void { + const content = this.contentElement!; + + const icon = document.createElement('div'); + icon.className = 'ready-icon'; + icon.textContent = '🎉'; + content.appendChild(icon); + + const title = document.createElement('h2'); + title.className = 'step-title'; + title.textContent = 'You\'re All Set!'; + content.appendChild(title); + + const description = document.createElement('p'); + description.className = 'step-description'; + description.textContent = 'You\'re ready to start using Browser Operator. Here are some quick tips to get you started:'; + content.appendChild(description); + + const tipsList = document.createElement('div'); + tipsList.className = 'tips-list'; + content.appendChild(tipsList); + + const tips = [ + { icon: '💡', text: 'Type naturally - describe what you want to do' }, + { icon: '⚙️', text: 'Click the gear icon for advanced settings' }, + { icon: '🔧', text: 'Configure MCP to add external tools' }, + { icon: '📜', text: 'Your conversation history is saved automatically' }, + ]; + + for (const tip of tips) { + const item = document.createElement('div'); + item.className = 'tip-item'; + + const tipIcon = document.createElement('span'); + tipIcon.className = 'tip-icon'; + tipIcon.textContent = tip.icon; + item.appendChild(tipIcon); + + const tipText = document.createElement('span'); + tipText.textContent = tip.text; + item.appendChild(tipText); + + tipsList.appendChild(item); + } + } + + private goBack(): void { + const currentIndex = STEPS.indexOf(this.currentStep); + if (currentIndex > 0) { + this.currentStep = STEPS[currentIndex - 1]; + this.renderCurrentStep(); + } + } + + private async goNext(): Promise { + const currentIndex = STEPS.indexOf(this.currentStep); + + // Validation for provider step + if (this.currentStep === 'provider' && !this.selectedProvider) { + return; + } + + // Require and test API key before advancing from apikey step (non-OpenRouter) + if (this.currentStep === 'apikey' && this.selectedProvider && this.selectedProvider.id !== 'openrouter') { + const statusDiv = this.apiKeyStatusDiv; + + // Require API key + if (!this.apiKey.trim()) { + if (statusDiv) { + statusDiv.className = 'test-status visible error'; + statusDiv.textContent = 'Please enter an API key'; + setTimeout(() => { + statusDiv.className = 'test-status'; + statusDiv.textContent = ''; + }, 5000); + } + return; + } + + const nextButton = (this as any).nextButton as HTMLButtonElement; + nextButton.disabled = true; + nextButton.textContent = 'Testing...'; + + // Clear any previous error + if (statusDiv) { + statusDiv.className = 'test-status'; + statusDiv.textContent = ''; + } + + try { + const provider = this.selectedProvider; + + // Use LLMProviderRegistry.testProviderConnection which works for all providers + const result = await LLMProviderRegistry.testProviderConnection( + provider.id as any, + this.apiKey + ); + + if (!result.success) { + throw new Error(result.message); + } + + // Success - save API key and configuration + LLMProviderRegistry.saveProviderApiKey(provider.id as any, this.apiKey); + this.saveConfiguration(); + logger.info(`API key validated for ${provider.id}`); + + // Show success screen and auto-close after 5 seconds + this.showSuccessAndClose(); + return; + } catch (error) { + // Failed - show inline error, don't advance + nextButton.disabled = false; + nextButton.textContent = 'Next'; + const errorMsg = error instanceof Error ? error.message : 'Connection failed'; + logger.error('API key validation failed:', error); + if (statusDiv) { + statusDiv.className = 'test-status visible error'; + statusDiv.textContent = `Invalid API key: ${errorMsg}`; + setTimeout(() => { + statusDiv.className = 'test-status'; + statusDiv.textContent = ''; + }, 5000); + } + return; + } + + nextButton.disabled = false; + nextButton.textContent = 'Next'; + } + + if (currentIndex < STEPS.length - 1) { + this.currentStep = STEPS[currentIndex + 1]; + this.renderCurrentStep(); + } else { + // Complete onboarding + this.complete(); + } + } + + private skipSetup(): void { + // Mark as skipped + localStorage.setItem(SETUP_SKIPPED_KEY, 'true'); + // Complete onboarding immediately + this.complete(); + } + + private showSuccessAndClose(): void { + // Jump to ready step to show the existing UI + this.currentStep = 'ready'; + this.renderCurrentStep(); + + // Hide footer buttons + const footer = this.contentElement?.parentElement?.querySelector('.onboarding-footer') as HTMLElement; + if (footer) { + footer.style.display = 'none'; + } + + // Add loading indicator to content + const content = this.contentElement!; + const loadingDiv = document.createElement('div'); + loadingDiv.style.cssText = 'margin-top: 24px; text-align: center; color: var(--color-text-secondary);'; + + const loadingText = document.createElement('div'); + loadingText.style.marginBottom = '8px'; + loadingText.textContent = 'Starting Browser Operator Agent...'; + loadingDiv.appendChild(loadingText); + + const spinner = document.createElement('div'); + spinner.className = 'loading-spinner'; + // Inline styles as fallback in case CSS class doesn't load + spinner.style.cssText = 'width: 24px; height: 24px; border: 3px solid rgba(128, 128, 128, 0.3); border-top-color: var(--color-primary, #00a4fe); border-radius: 50%; animation: spin 1s linear infinite; margin: 0 auto;'; + loadingDiv.appendChild(spinner); + + // Add keyframes animation inline + const styleEl = document.createElement('style'); + styleEl.textContent = '@keyframes spin { to { transform: rotate(360deg); } }'; + loadingDiv.appendChild(styleEl); + + content.appendChild(loadingDiv); + + // Auto-close after 5 seconds + setTimeout(() => { + this.complete(); + }, 5000); + } + + private saveConfiguration(): void { + if (!this.selectedProvider || !this.apiKey.trim()) return; + + const provider = this.selectedProvider; + + // Save provider selection + localStorage.setItem('ai_chat_provider', provider.id); + + // Save API key + LLMProviderRegistry.saveProviderApiKey(provider.id as any, this.apiKey); + + logger.info(`Saved configuration for provider: ${provider.id}`); + } + + private complete(): void { + // Mark onboarding as complete + localStorage.setItem(ONBOARDING_COMPLETE_KEY, 'true'); + + // Clear skipped flag if they completed with an API key + if (this.apiKey.trim()) { + localStorage.removeItem(SETUP_SKIPPED_KEY); + } + + // Clean up OAuth listener + if (this.handleOAuthSuccess) { + window.removeEventListener('openrouter-oauth-success', this.handleOAuthSuccess); + this.handleOAuthSuccess = null; + } + + // Close dialog + if (this.dialog) { + this.dialog.hide(); + this.dialog = null; + } + + // Call completion callback + if (this.onComplete) { + this.onComplete(); + } + + logger.info('Onboarding completed'); + } +} + +/** + * Create and return a setup required banner element + */ +export function createSetupRequiredBanner(onSettingsClick: () => void): HTMLElement { + const banner = document.createElement('div'); + banner.className = 'setup-required-banner'; + + const text = document.createElement('div'); + text.className = 'setup-banner-text'; + text.innerHTML = '⚠️ API key not configured. Set up a provider to start chatting.'; + banner.appendChild(text); + + const button = document.createElement('button'); + button.className = 'setup-banner-button'; + button.textContent = 'Open Settings'; + button.addEventListener('click', onSettingsClick); + banner.appendChild(button); + + return banner; +} diff --git a/front_end/panels/ai_chat/ui/SettingsDialog.ts b/front_end/panels/ai_chat/ui/SettingsDialog.ts index d593bcb1ff..3cd9f40cfb 100644 --- a/front_end/panels/ai_chat/ui/SettingsDialog.ts +++ b/front_end/panels/ai_chat/ui/SettingsDialog.ts @@ -32,6 +32,7 @@ import { BrowsingHistorySettings } from './settings/advanced/BrowsingHistorySett import { VectorDBSettings } from './settings/advanced/VectorDBSettings.js'; import { TracingSettings } from './settings/advanced/TracingSettings.js'; import { EvaluationSettings } from './settings/advanced/EvaluationSettings.js'; +import { MemorySettings } from './settings/advanced/MemorySettings.js'; import './model_selector/ModelSelector.js'; @@ -463,6 +464,10 @@ export class SettingsDialog { advancedToggleContainer.appendChild(advancedToggleLabel); // Create advanced feature sections + const memorySection = document.createElement('div'); + memorySection.className = 'settings-section memory-section'; + contentDiv.appendChild(memorySection); + const historySection = document.createElement('div'); historySection.className = 'settings-section history-section'; contentDiv.appendChild(historySection); @@ -480,12 +485,14 @@ export class SettingsDialog { contentDiv.appendChild(evaluationSection); // Instantiate advanced feature settings classes + const memorySettings = new MemorySettings(memorySection); const browsingHistorySettings = new BrowsingHistorySettings(historySection); const vectorDBSettings = new VectorDBSettings(vectorDBSection); const tracingSettings = new TracingSettings(tracingSection); const evaluationSettings = new EvaluationSettings(evaluationSection); // Render advanced features + memorySettings.render(); browsingHistorySettings.render(); vectorDBSettings.render(); tracingSettings.render(); @@ -493,6 +500,7 @@ export class SettingsDialog { // Store advanced features for cleanup const advancedFeatures = [ + memorySettings, browsingHistorySettings, vectorDBSettings, tracingSettings, @@ -503,6 +511,7 @@ export class SettingsDialog { // Advanced Settings Toggle Logic function toggleAdvancedSections(show: boolean): void { const display = show ? 'block' : 'none'; + memorySection.style.display = display; historySection.style.display = display; vectorDBSection.style.display = display; tracingSection.style.display = display; diff --git a/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessions.test.ts b/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessions.test.ts index a36544a64c..563aae5ef2 100644 --- a/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessions.test.ts +++ b/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessions.test.ts @@ -3,6 +3,18 @@ import '../ChatView.js'; import {raf, doubleRaf} from '../../../../testing/DOMHelpers.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// Helper to stub fetch for VersionChecker network calls +function stubFetch(): sinon.SinonStub { + return sinon.stub(globalThis, 'fetch').resolves(new Response(JSON.stringify({ + tag_name: 'v1.0.0', + html_url: 'https://example.com/release', + body: 'Test release', + }), { status: 200 })); +} + // Local enums/types to avoid TS enum imports in strip mode const ChatMessageEntity = { USER: 'user', @@ -55,6 +67,10 @@ function queryLive(view: HTMLElement): HTMLElement[] { } describe('ChatView Agent Sessions: nesting & handoffs', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('renders nested child session inside parent timeline', async () => { const parent = makeSession('p1', {nestedSessions: [makeSession('c1')]}); const view = document.createElement('devtools-chat-view') as any; @@ -282,6 +298,10 @@ describe('ChatView Agent Sessions: nesting & handoffs', () => { }); describe('ChatView Agent Sessions: pruning and resilience', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('reorder does not recreate or prune', async () => { const s1 = makeSession('s1'); const s2 = makeSession('s2'); @@ -309,6 +329,10 @@ describe('ChatView Agent Sessions: pruning and resilience', () => { }); describe('ChatView visibility rules: agent-managed tool calls/results are hidden', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('hides model tool call + result for configurable agent; live timeline shows instead', async () => { const view = document.createElement('devtools-chat-view') as any; document.body.appendChild(view); @@ -383,6 +407,10 @@ describe('ChatView visibility rules: agent-managed tool calls/results are hidden }); describe('LiveAgentSessionComponent timeline rendering and interactions', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('single tool session shows single-tool mode and hides spine', async () => { const session = makeSession('s1', {messages: [makeToolCall('tc1', 'fetch', {url: 'x'}), makeToolResult('tc1', 'fetch', true, {ok: true})]}); const view = document.createElement('devtools-chat-view') as any; diff --git a/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessionsOrder.test.ts b/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessionsOrder.test.ts index c1b8b3eb7a..307c1f35f1 100644 --- a/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessionsOrder.test.ts +++ b/front_end/panels/ai_chat/ui/__tests__/ChatViewAgentSessionsOrder.test.ts @@ -3,6 +3,18 @@ import '../ChatView.js'; import {raf} from '../../../../testing/DOMHelpers.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// Helper to stub fetch for VersionChecker network calls +function stubFetch(): sinon.SinonStub { + return sinon.stub(globalThis, 'fetch').resolves(new Response(JSON.stringify({ + tag_name: 'v1.0.0', + html_url: 'https://example.com/release', + body: 'Test release', + }), { status: 200 })); +} + // Minimal local constants to avoid importing enums in strip mode const ChatMessageEntity = { USER: 'user', @@ -38,6 +50,10 @@ function queryLive(view: HTMLElement): HTMLElement[] { } describe('ChatView Agent Sessions: sequential top-level sessions', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('renders two top-level agent sessions in order with first completed and second running', async () => { // First session has a completed tool (call + result) const s1 = makeSession('s1', { diff --git a/front_end/panels/ai_chat/ui/__tests__/ChatViewInputClear.test.ts b/front_end/panels/ai_chat/ui/__tests__/ChatViewInputClear.test.ts index cc32f1ff8f..4829aa761e 100644 --- a/front_end/panels/ai_chat/ui/__tests__/ChatViewInputClear.test.ts +++ b/front_end/panels/ai_chat/ui/__tests__/ChatViewInputClear.test.ts @@ -3,12 +3,28 @@ import '../ChatView.js'; import {raf} from '../../../../testing/DOMHelpers.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// Helper to stub fetch for VersionChecker network calls +function stubFetch(): sinon.SinonStub { + return sinon.stub(globalThis, 'fetch').resolves(new Response(JSON.stringify({ + tag_name: 'v1.0.0', + html_url: 'https://example.com/release', + body: 'Test release', + }), { status: 200 })); +} + // Minimal enums const ChatMessageEntity = { USER: 'user' } as const; function makeUser(text: string): any { return { entity: ChatMessageEntity.USER, text } as any; } describe('ChatView input clearing (expanded view)', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + function getTextarea(view: HTMLElement): HTMLTextAreaElement { const shadow = view.shadowRoot!; const bar = shadow.querySelector('ai-input-bar') as HTMLElement; diff --git a/front_end/panels/ai_chat/ui/__tests__/ChatViewPrune.test.ts b/front_end/panels/ai_chat/ui/__tests__/ChatViewPrune.test.ts index 1426792783..759173f6aa 100644 --- a/front_end/panels/ai_chat/ui/__tests__/ChatViewPrune.test.ts +++ b/front_end/panels/ai_chat/ui/__tests__/ChatViewPrune.test.ts @@ -3,6 +3,18 @@ import '../ChatView.js'; import {raf} from '../../../../testing/DOMHelpers.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// Helper to stub fetch for VersionChecker network calls +function stubFetch(): sinon.SinonStub { + return sinon.stub(globalThis, 'fetch').resolves(new Response(JSON.stringify({ + tag_name: 'v1.0.0', + html_url: 'https://example.com/release', + body: 'Test release', + }), { status: 200 })); +} + // Minimal local enum constants to avoid importing TS enums in tests const ChatMessageEntity = { USER: 'user', @@ -22,6 +34,10 @@ function makeAgentSessionMessage(sessionId: string): any { } describe('ChatView pruneLiveAgentSessions', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('prunes cached sessions when they disappear from messages', async () => { const view = document.createElement('devtools-chat-view') as any; document.body.appendChild(view); diff --git a/front_end/panels/ai_chat/ui/__tests__/ChatViewSequentialSessionsTransition.test.ts b/front_end/panels/ai_chat/ui/__tests__/ChatViewSequentialSessionsTransition.test.ts index a1bd67b2a4..cba6e88a30 100644 --- a/front_end/panels/ai_chat/ui/__tests__/ChatViewSequentialSessionsTransition.test.ts +++ b/front_end/panels/ai_chat/ui/__tests__/ChatViewSequentialSessionsTransition.test.ts @@ -3,6 +3,18 @@ import '../ChatView.js'; import {raf} from '../../../../testing/DOMHelpers.js'; +// Use global sinon provided by Karma framework +declare const sinon: typeof import('sinon'); + +// Helper to stub fetch for VersionChecker network calls +function stubFetch(): sinon.SinonStub { + return sinon.stub(globalThis, 'fetch').resolves(new Response(JSON.stringify({ + tag_name: 'v1.0.0', + html_url: 'https://example.com/release', + body: 'Test release', + }), { status: 200 })); +} + // Minimal local constants to avoid importing enums in strip mode const ChatMessageEntity = { USER: 'user', @@ -38,6 +50,10 @@ function queryLive(view: HTMLElement): HTMLElement[] { } describe('ChatView Agent Sessions: transition from completed to new running session', () => { + let fetchStub: sinon.SinonStub; + beforeEach(() => { fetchStub = stubFetch(); }); + afterEach(() => { fetchStub.restore(); }); + it('renders first completed session then second running session when added to messages', async () => { const s1 = makeSession('s1', { agentReasoning: 'First agent session', diff --git a/front_end/panels/ai_chat/ui/message/ApprovalRequestMessage.ts b/front_end/panels/ai_chat/ui/message/ApprovalRequestMessage.ts new file mode 100644 index 0000000000..a21d442f2a --- /dev/null +++ b/front_end/panels/ai_chat/ui/message/ApprovalRequestMessage.ts @@ -0,0 +1,367 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * ApprovalRequestMessage - Inline UI component for human-in-the-loop approval requests + * + * Uses the same styling as agent-execution-timeline: + * - .timeline-item, .tool-line, .tool-left, .tool-name-badge, .tool-status-marker + * - ● bullet markers for status (pending=orange, approved=green, rejected=red) + * - Compact layout matching tool execution entries + */ + +import * as Lit from '../../../../ui/lit/lit.js'; +import type { ApprovalRequestMessage as ApprovalRequestMessageType, RiskLevel } from '../../models/ChatTypes.js'; +import { getGuardrailMiddleware } from '../../guardrails/index.js'; + +const {html, nothing} = Lit; + +/** + * Get CSS class for risk level badge + */ +function getRiskLevelClass(riskLevel: RiskLevel): string { + switch (riskLevel) { + case 'none': + return 'risk-none'; + case 'low': + return 'risk-low'; + case 'medium': + return 'risk-medium'; + case 'high': + return 'risk-high'; + case 'critical': + return 'risk-critical'; + default: + return 'risk-medium'; + } +} + +/** + * Get human-readable risk level label (short form) + */ +function getRiskLevelLabel(riskLevel: RiskLevel): string { + switch (riskLevel) { + case 'none': + return 'Safe'; + case 'low': + return 'Low'; + case 'medium': + return 'Med'; + case 'high': + return 'High'; + case 'critical': + return 'Crit'; + default: + return '?'; + } +} + +/** + * Render an approval request message (timeline style) + */ +export function renderApprovalRequestMessage( + msg: ApprovalRequestMessageType, + onApprove?: (approvalId: string) => void, + onReject?: (approvalId: string, feedback?: string) => void +): Lit.TemplateResult { + const isPending = msg.status === 'pending'; + const isApproved = msg.status === 'approved'; + const isRejected = msg.status === 'rejected'; + + // Handle approve click + const handleApprove = () => { + if (onApprove) { + onApprove(msg.approvalId); + } else { + getGuardrailMiddleware().approve(msg.approvalId); + } + }; + + // Handle reject click + const handleReject = () => { + if (onReject) { + onReject(msg.approvalId); + } else { + getGuardrailMiddleware().reject(msg.approvalId); + } + }; + + // Handle reject with feedback + const handleRejectWithFeedback = (e: Event) => { + const container = (e.target as HTMLElement).closest('.timeline-item'); + const textarea = container?.querySelector('.feedback-textarea') as HTMLTextAreaElement; + const feedback = textarea?.value || ''; + + if (onReject) { + onReject(msg.approvalId, feedback); + } else { + getGuardrailMiddleware().reject(msg.approvalId, feedback); + } + }; + + // Toggle feedback section visibility + const toggleFeedback = (e: Event) => { + const container = (e.target as HTMLElement).closest('.timeline-item'); + const feedbackSection = container?.querySelector('.feedback-section') as HTMLElement; + if (feedbackSection) { + feedbackSection.style.display = feedbackSection.style.display === 'none' ? 'block' : 'none'; + } + }; + + return html` + + +
+
+
+ + ${isPending ? 'Approval Required' : isApproved ? 'Approved' : 'Rejected'} + + + ${isPending ? html`${getRiskLevelLabel(msg.riskLevel)}` : nothing} + ${msg.toolName} + +
+ +
+ + ${isPending ? html` +
+
${msg.description}
+ +
+
${JSON.stringify(msg.toolArgs, null, 2)}
+
+ +
+ + + +
+ + +
+ ` : nothing} + + ${isRejected && msg.feedback ? html` + + ` : nothing} +
+ `; +} diff --git a/front_end/panels/ai_chat/ui/onboardingStyles.ts b/front_end/panels/ai_chat/ui/onboardingStyles.ts new file mode 100644 index 0000000000..270146b800 --- /dev/null +++ b/front_end/panels/ai_chat/ui/onboardingStyles.ts @@ -0,0 +1,478 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * CSS styles for the onboarding wizard dialog + */ +export function getOnboardingStyles(): string { + return ` + .onboarding-dialog { + width: 100%; + height: 100%; + display: flex; + align-items: center; + justify-content: center; + background: var(--color-background); + color: var(--color-text-primary); + } + + .onboarding-container { + max-width: 600px; + width: 90%; + max-height: 85vh; + display: flex; + flex-direction: column; + background: var(--color-background-elevation-1); + border-radius: 12px; + box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2); + overflow: hidden; + } + + /* Step Indicators */ + .step-indicators { + display: flex; + justify-content: center; + gap: 12px; + padding: 20px; + background: var(--color-background-elevation-2); + border-bottom: 1px solid var(--color-details-hairline); + } + + .step-indicator { + width: 10px; + height: 10px; + border-radius: 50%; + background: var(--color-details-hairline); + transition: all 0.3s ease; + } + + .step-indicator.active { + background: var(--color-primary); + transform: scale(1.2); + } + + .step-indicator.completed { + background: var(--sys-color-accent-green, #4caf50); + } + + /* Content Area */ + .onboarding-content { + flex: 1; + padding: 32px; + overflow-y: auto; + display: flex; + flex-direction: column; + align-items: center; + } + + .step-title { + font-size: 24px; + font-weight: 600; + margin: 0 0 12px 0; + text-align: center; + color: var(--color-text-primary); + } + + .step-description { + font-size: 14px; + color: var(--color-text-secondary); + text-align: center; + margin: 0 0 24px 0; + max-width: 450px; + line-height: 1.5; + } + + /* Welcome Step */ + .welcome-icon { + font-size: 64px; + margin-bottom: 16px; + } + + .video-placeholder { + width: 100%; + max-width: 400px; + aspect-ratio: 16/9; + background: var(--color-background-elevation-0); + border-radius: 8px; + display: flex; + align-items: center; + justify-content: center; + margin-top: 16px; + border: 2px dashed var(--color-details-hairline); + } + + .video-placeholder-text { + color: var(--color-text-secondary); + font-size: 14px; + } + + /* Provider Grid */ + .provider-grid { + display: grid; + grid-template-columns: repeat(2, 1fr); + gap: 12px; + width: 100%; + max-width: 500px; + } + + .provider-card { + padding: 16px; + border: 2px solid var(--color-details-hairline); + border-radius: 8px; + cursor: pointer; + transition: all 0.2s ease; + background: var(--color-background); + text-align: left; + } + + .provider-card:hover { + border-color: var(--color-primary); + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(0, 164, 254, 0.15); + } + + .provider-card.selected { + border-color: var(--color-primary); + background: var(--color-primary-container, rgba(0, 164, 254, 0.1)); + } + + .provider-card-header { + display: flex; + align-items: center; + gap: 10px; + margin-bottom: 8px; + } + + .provider-icon { + font-size: 24px; + } + + .provider-name { + font-size: 16px; + font-weight: 600; + color: var(--color-text-primary); + } + + .provider-description { + font-size: 12px; + color: var(--color-text-secondary); + line-height: 1.4; + } + + /* API Key Form */ + .api-key-form { + width: 100%; + max-width: 450px; + } + + .form-group { + margin-bottom: 20px; + } + + .form-label { + display: block; + font-size: 14px; + font-weight: 500; + margin-bottom: 8px; + color: var(--color-text-primary); + } + + .form-input { + width: 100%; + padding: 12px 16px; + border: 1px solid var(--color-details-hairline); + border-radius: 6px; + font-size: 14px; + background: var(--color-background); + color: var(--color-text-primary); + box-sizing: border-box; + transition: border-color 0.2s ease; + } + + .form-input:focus { + outline: none; + border-color: var(--color-primary); + box-shadow: 0 0 0 2px rgba(0, 164, 254, 0.2); + } + + .form-hint { + display: flex; + align-items: center; + gap: 8px; + margin-top: 8px; + font-size: 12px; + color: var(--color-text-secondary); + } + + .form-hint a { + color: var(--color-primary); + text-decoration: none; + } + + .form-hint a:hover { + text-decoration: underline; + } + + .test-button { + display: inline-flex; + align-items: center; + gap: 8px; + padding: 10px 20px; + background: var(--color-background-elevation-2); + border: 1px solid var(--color-details-hairline); + border-radius: 6px; + color: var(--color-text-primary); + font-size: 14px; + cursor: pointer; + transition: all 0.2s ease; + } + + .test-button:hover { + background: var(--color-background-elevation-0); + } + + .test-button:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .test-status { + margin-top: 16px; + padding: 12px 16px; + border-radius: 6px; + font-size: 13px; + display: none; + } + + .test-status.visible { + display: block; + } + + .test-status.success { + background: var(--sys-color-green-container, rgba(76, 175, 80, 0.1)); + color: var(--sys-color-on-green-container, #2e7d32); + border: 1px solid var(--sys-color-accent-green, #4caf50); + } + + .test-status.error { + background: var(--sys-color-error-container, rgba(244, 67, 54, 0.1)); + color: var(--sys-color-on-error-container, #c62828); + border: 1px solid var(--sys-color-error, #f44336); + } + + .skip-link { + margin-top: 24px; + color: var(--color-text-secondary); + font-size: 13px; + cursor: pointer; + text-decoration: underline; + background: none; + border: none; + } + + .skip-link:hover { + color: var(--color-text-primary); + } + + /* Features List */ + .feature-list { + width: 100%; + max-width: 450px; + display: flex; + flex-direction: column; + gap: 12px; + } + + .feature-item { + display: flex; + align-items: flex-start; + gap: 16px; + padding: 16px; + background: var(--color-background); + border-radius: 8px; + border: 1px solid var(--color-details-hairline); + } + + .feature-icon { + font-size: 28px; + flex-shrink: 0; + } + + .feature-content { + flex: 1; + } + + .feature-title { + font-size: 15px; + font-weight: 600; + margin: 0 0 4px 0; + color: var(--color-text-primary); + } + + .feature-description { + font-size: 13px; + color: var(--color-text-secondary); + margin: 0; + line-height: 1.4; + } + + /* Ready Step */ + .ready-icon { + font-size: 72px; + margin-bottom: 16px; + } + + .tips-list { + width: 100%; + max-width: 400px; + margin-top: 24px; + text-align: left; + } + + .tip-item { + display: flex; + align-items: center; + gap: 12px; + padding: 12px 0; + border-bottom: 1px solid var(--color-details-hairline); + font-size: 14px; + color: var(--color-text-secondary); + } + + .tip-item:last-child { + border-bottom: none; + } + + .tip-icon { + font-size: 18px; + } + + /* Footer */ + .onboarding-footer { + display: flex; + justify-content: space-between; + align-items: center; + padding: 16px 32px; + background: var(--color-background-elevation-2); + border-top: 1px solid var(--color-details-hairline); + } + + .footer-left { + display: flex; + gap: 12px; + } + + .footer-right { + display: flex; + gap: 12px; + } + + .btn { + padding: 10px 24px; + border-radius: 6px; + font-size: 14px; + font-weight: 500; + cursor: pointer; + transition: all 0.2s ease; + border: none; + } + + .btn-secondary { + background: var(--color-background); + border: 1px solid var(--color-details-hairline); + color: var(--color-text-primary); + } + + .btn-secondary:hover { + background: var(--color-background-elevation-0); + } + + .btn-primary { + background: var(--color-primary); + color: white; + } + + .btn-primary:hover { + background: var(--color-primary-variant, #0093e0); + transform: translateY(-1px); + box-shadow: 0 4px 12px rgba(0, 164, 254, 0.3); + } + + .btn-primary:disabled { + opacity: 0.5; + cursor: not-allowed; + transform: none; + box-shadow: none; + } + + .btn-text { + background: none; + border: none; + color: var(--color-text-secondary); + padding: 10px 16px; + } + + .btn-text:hover { + color: var(--color-text-primary); + } + + /* Setup Required Banner (for main UI) */ + .setup-required-banner { + display: flex; + align-items: center; + justify-content: space-between; + padding: 12px 16px; + background: var(--sys-color-yellow-container, rgba(255, 193, 7, 0.1)); + border: 1px solid var(--sys-color-accent-yellow, #ffc107); + border-radius: 8px; + margin: 12px; + } + + .setup-banner-text { + display: flex; + align-items: center; + gap: 8px; + font-size: 14px; + color: var(--color-text-primary); + } + + .setup-banner-button { + padding: 6px 16px; + background: var(--color-primary); + color: white; + border: none; + border-radius: 4px; + font-size: 13px; + cursor: pointer; + } + + .setup-banner-button:hover { + background: var(--color-primary-variant, #0093e0); + } + + /* Loading spinner */ + .loading-spinner { + width: 24px; + height: 24px; + border: 3px solid var(--color-details-hairline); + border-top-color: var(--color-primary); + border-radius: 50%; + animation: spin 1s linear infinite; + margin: 0 auto; + } + + @keyframes spin { + to { transform: rotate(360deg); } + } + `; +} + +/** + * Apply onboarding styles to a dialog element + */ +export function applyOnboardingStyles(dialogElement: HTMLElement): void { + const styleElement = document.createElement('style'); + styleElement.textContent = getOnboardingStyles(); + dialogElement.appendChild(styleElement); +} diff --git a/front_end/panels/ai_chat/ui/settings/advanced/MemorySettings.ts b/front_end/panels/ai_chat/ui/settings/advanced/MemorySettings.ts new file mode 100644 index 0000000000..ccbe83f3ec --- /dev/null +++ b/front_end/panels/ai_chat/ui/settings/advanced/MemorySettings.ts @@ -0,0 +1,354 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import { i18nString, UIStrings } from '../i18n-strings.js'; +import { MEMORY_ENABLED_KEY } from '../constants.js'; +import { MemoryBlockManager } from '../../../memory/MemoryBlockManager.js'; +import { FileContentViewer } from '../../FileContentViewer.js'; +import type { MemoryBlock } from '../../../memory/types.js'; +import type { FileSummary } from '../../../tools/FileStorageManager.js'; + +/** + * Memory System Settings + * + * Allows enabling/disabling the memory system that extracts and stores + * facts from conversations for use in future sessions. + * Also displays stored memory blocks with the ability to view their contents. + */ +export class MemorySettings { + private container: HTMLElement; + private memoryEnabledCheckbox: HTMLInputElement | null = null; + private blockListContainer: HTMLElement | null = null; + private blockManager: MemoryBlockManager; + + constructor(container: HTMLElement) { + this.container = container; + this.blockManager = new MemoryBlockManager(); + } + + render(): void { + // Clear any existing content + this.container.innerHTML = ''; + this.container.className = 'settings-section memory-section'; + + // Title + const memoryTitle = document.createElement('h3'); + memoryTitle.textContent = i18nString(UIStrings.memoryLabel); + memoryTitle.classList.add('settings-subtitle'); + this.container.appendChild(memoryTitle); + + // Memory enabled checkbox + const memoryEnabledContainer = document.createElement('div'); + memoryEnabledContainer.className = 'tracing-enabled-container'; + this.container.appendChild(memoryEnabledContainer); + + this.memoryEnabledCheckbox = document.createElement('input'); + this.memoryEnabledCheckbox.type = 'checkbox'; + this.memoryEnabledCheckbox.id = 'memory-enabled'; + this.memoryEnabledCheckbox.className = 'tracing-checkbox'; + // Default to enabled (true) if not set + const storedValue = localStorage.getItem(MEMORY_ENABLED_KEY); + this.memoryEnabledCheckbox.checked = storedValue !== 'false'; + memoryEnabledContainer.appendChild(this.memoryEnabledCheckbox); + + const memoryEnabledLabel = document.createElement('label'); + memoryEnabledLabel.htmlFor = 'memory-enabled'; + memoryEnabledLabel.className = 'tracing-label'; + memoryEnabledLabel.textContent = i18nString(UIStrings.memoryEnabled); + memoryEnabledContainer.appendChild(memoryEnabledLabel); + + const memoryEnabledHint = document.createElement('div'); + memoryEnabledHint.className = 'settings-hint'; + memoryEnabledHint.textContent = i18nString(UIStrings.memoryEnabledHint); + this.container.appendChild(memoryEnabledHint); + + // Toggle memory and save to localStorage + this.memoryEnabledCheckbox.addEventListener('change', () => { + localStorage.setItem(MEMORY_ENABLED_KEY, this.memoryEnabledCheckbox!.checked.toString()); + this.updateBlockListVisibility(); + }); + + // Memory blocks list container + this.blockListContainer = document.createElement('div'); + this.blockListContainer.className = 'memory-blocks-container'; + this.container.appendChild(this.blockListContainer); + + // Initial render of block list + this.updateBlockListVisibility(); + this.renderMemoryBlocks(); + } + + /** + * Update visibility of block list based on memory enabled state + */ + private updateBlockListVisibility(): void { + if (this.blockListContainer && this.memoryEnabledCheckbox) { + this.blockListContainer.style.display = this.memoryEnabledCheckbox.checked ? 'block' : 'none'; + } + } + + /** + * Render the list of memory blocks + */ + private async renderMemoryBlocks(): Promise { + if (!this.blockListContainer) { + return; + } + + this.blockListContainer.innerHTML = ''; + + // Add a subtitle for the blocks section + const blocksTitle = document.createElement('div'); + blocksTitle.className = 'memory-blocks-title'; + blocksTitle.textContent = 'Stored Memory'; + this.blockListContainer.appendChild(blocksTitle); + + try { + const blocks = await this.blockManager.getAllBlocks(); + + if (blocks.length === 0) { + const emptyMessage = document.createElement('div'); + emptyMessage.className = 'memory-blocks-empty'; + emptyMessage.textContent = 'No memory blocks stored yet. Memory will be extracted from conversations automatically.'; + this.blockListContainer.appendChild(emptyMessage); + return; + } + + // Create block list + const blockList = document.createElement('div'); + blockList.className = 'memory-blocks-list'; + + for (const block of blocks) { + const blockItem = this.createBlockItem(block); + blockList.appendChild(blockItem); + } + + this.blockListContainer.appendChild(blockList); + } catch (error) { + const errorMessage = document.createElement('div'); + errorMessage.className = 'memory-blocks-error'; + errorMessage.textContent = 'Failed to load memory blocks.'; + this.blockListContainer.appendChild(errorMessage); + } + } + + /** + * Create a clickable block item + */ + private createBlockItem(block: MemoryBlock): HTMLElement { + const item = document.createElement('div'); + item.className = 'memory-block-item'; + item.setAttribute('role', 'button'); + item.setAttribute('tabindex', '0'); + + // Icon based on block type + const icon = document.createElement('span'); + icon.className = 'memory-block-icon'; + icon.textContent = this.getBlockIcon(block.type); + item.appendChild(icon); + + // Block info + const info = document.createElement('div'); + info.className = 'memory-block-info'; + + const label = document.createElement('div'); + label.className = 'memory-block-label'; + label.textContent = block.label; + info.appendChild(label); + + const meta = document.createElement('div'); + meta.className = 'memory-block-meta'; + meta.textContent = this.formatBlockMeta(block); + info.appendChild(meta); + + item.appendChild(info); + + // Delete button + const deleteBtn = document.createElement('button'); + deleteBtn.className = 'memory-block-delete'; + deleteBtn.textContent = '🗑️'; + deleteBtn.title = 'Delete this memory block'; + deleteBtn.addEventListener('click', (e: MouseEvent) => { + e.stopPropagation(); // Don't trigger view + this.confirmAndDeleteBlock(block); + }); + item.appendChild(deleteBtn); + + // Click handler to view content + const handleClick = (): void => { + this.viewBlock(block); + }; + + item.addEventListener('click', handleClick); + item.addEventListener('keydown', (e: KeyboardEvent) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + handleClick(); + } + }); + + return item; + } + + /** + * Confirm and delete a memory block + */ + private async confirmAndDeleteBlock(block: MemoryBlock): Promise { + const confirmed = confirm(`Delete "${block.label}" memory block? This cannot be undone.`); + if (!confirmed) { + return; + } + + // Extract projectName from filename for project blocks + let projectName: string | undefined; + if (block.type === 'project') { + projectName = block.filename.replace('memory_project_', '').replace('.md', ''); + } + + try { + await this.blockManager.deleteBlock(block.type, projectName); + await this.renderMemoryBlocks(); // Refresh list + } catch (error) { + console.error('Failed to delete memory block:', error); + alert('Failed to delete memory block.'); + } + } + + /** + * Get icon for block type + */ + private getBlockIcon(type: string): string { + switch (type) { + case 'user': + return '👤'; + case 'facts': + return '💡'; + case 'project': + return '📁'; + default: + return '📄'; + } + } + + /** + * Format block metadata string + */ + private formatBlockMeta(block: MemoryBlock): string { + const charCount = this.formatCharCount(block.content.length); + const updated = this.formatDate(block.updatedAt); + return `${charCount} • Updated ${updated}`; + } + + /** + * Format character count + */ + private formatCharCount(chars: number): string { + if (chars < 1000) { + return `${chars} chars`; + } + return `${(chars / 1000).toFixed(1)}K chars`; + } + + /** + * Format date for display + */ + private formatDate(timestamp: number): string { + const date = new Date(timestamp); + const now = new Date(); + const diffMs = now.getTime() - date.getTime(); + const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24)); + + if (diffDays === 0) { + return 'today'; + } else if (diffDays === 1) { + return 'yesterday'; + } else if (diffDays < 7) { + return `${diffDays} days ago`; + } else { + return date.toLocaleDateString(); + } + } + + /** + * Open FileContentViewer to display block content with edit capability + */ + private async viewBlock(block: MemoryBlock): Promise { + // Create a FileSummary-compatible object for FileContentViewer + const fileSummary: FileSummary = { + fileName: block.filename, + size: block.content.length, + mimeType: 'text/markdown', + createdAt: block.updatedAt, + updatedAt: block.updatedAt, + }; + + const result = await FileContentViewer.show(fileSummary, block.content, { + editable: true, + }); + + // If we got a webappId, start polling for saved content + if (result?.webappId) { + this.startPollingForSave(result.webappId, block); + } + } + + /** Active polling interval ID for save detection */ + private pollIntervalId: ReturnType | null = null; + + /** + * Start polling for saved content from the file viewer iframe + */ + private startPollingForSave(webappId: string, block: MemoryBlock): void { + // Clear any existing polling + this.stopPollingForSave(); + + // Poll every 500ms for saved content + this.pollIntervalId = setInterval(async () => { + const savedContent = await FileContentViewer.checkForSavedContent(webappId); + if (savedContent !== null) { + this.stopPollingForSave(); + await this.saveBlock(block, savedContent); + } + }, 500); + } + + /** + * Stop polling for saved content + */ + private stopPollingForSave(): void { + if (this.pollIntervalId !== null) { + clearInterval(this.pollIntervalId); + this.pollIntervalId = null; + } + } + + /** + * Save edited block content + */ + private async saveBlock(block: MemoryBlock, newContent: string): Promise { + // Extract projectName from filename for project blocks + let projectName: string | undefined; + if (block.type === 'project') { + projectName = block.filename.replace('memory_project_', '').replace('.md', ''); + } + + try { + await this.blockManager.updateBlock(block.type, newContent, projectName); + // Refresh the block list to show updated content + await this.renderMemoryBlocks(); + } catch (error) { + console.error('Failed to save memory block:', error); + alert('Failed to save memory block. Content may exceed the character limit.'); + } + } + + save(): void { + // Memory settings are auto-saved on checkbox change + } + + cleanup(): void { + // Stop any active polling + this.stopPollingForSave(); + } +} diff --git a/front_end/panels/ai_chat/ui/settings/constants.ts b/front_end/panels/ai_chat/ui/settings/constants.ts index 915c4f412e..f1618434b1 100644 --- a/front_end/panels/ai_chat/ui/settings/constants.ts +++ b/front_end/panels/ai_chat/ui/settings/constants.ts @@ -45,3 +45,8 @@ export const MILVUS_OPENAI_KEY = 'ai_chat_milvus_openai_key'; * Advanced settings toggle key */ export const ADVANCED_SETTINGS_ENABLED_KEY = 'ai_chat_advanced_settings_enabled'; + +/** + * Memory system toggle key + */ +export const MEMORY_ENABLED_KEY = 'ai_chat_memory_enabled'; diff --git a/front_end/panels/ai_chat/ui/settings/i18n-strings.ts b/front_end/panels/ai_chat/ui/settings/i18n-strings.ts index 4693bac547..21e000a994 100644 --- a/front_end/panels/ai_chat/ui/settings/i18n-strings.ts +++ b/front_end/panels/ai_chat/ui/settings/i18n-strings.ts @@ -509,6 +509,18 @@ export const UIStrings = { *@description Models count hint text */ modelsCountHint: '{n} model(s)', + /** + *@description Memory section label + */ + memoryLabel: 'Memory', + /** + *@description Memory enabled label + */ + memoryEnabled: 'Enable Memory', + /** + *@description Memory enabled hint + */ + memoryEnabledHint: 'Automatically extract and remember facts from conversations for future sessions', }; /** diff --git a/front_end/panels/ai_chat/ui/settings/providerConfigs.ts b/front_end/panels/ai_chat/ui/settings/providerConfigs.ts index 0ad4c532d1..93294eb396 100644 --- a/front_end/panels/ai_chat/ui/settings/providerConfigs.ts +++ b/front_end/panels/ai_chat/ui/settings/providerConfigs.ts @@ -13,6 +13,17 @@ import { GOOGLEAI_API_KEY_STORAGE_KEY, } from './constants.js'; +// Lazy i18n helpers to avoid calling i18nString at module load time +const lazyString = (getter: () => string): string => { + // Return a proxy string that evaluates lazily + // For now, use fallback strings - i18n will be available when UI renders + try { + return getter(); + } catch { + return ''; + } +}; + /** * OpenAI provider configuration * - API key only @@ -23,8 +34,8 @@ export const OpenAIConfig: ProviderConfig = { id: 'openai', displayName: 'OpenAI', apiKeyStorageKey: OPENAI_API_KEY_STORAGE_KEY, - apiKeyLabel: i18nString(UIStrings.apiKeyLabel), - apiKeyHint: i18nString(UIStrings.apiKeyHint), + get apiKeyLabel() { return lazyString(() => i18nString(UIStrings.apiKeyLabel)); }, + get apiKeyHint() { return lazyString(() => i18nString(UIStrings.apiKeyHint)); }, apiKeyPlaceholder: 'Enter your OpenAI API key', hasModelSelectors: true, hasFetchButton: false, @@ -40,8 +51,8 @@ export const BrowserOperatorConfig: ProviderConfig = { id: 'browseroperator', displayName: 'BrowserOperator', apiKeyStorageKey: BROWSEROPERATOR_API_KEY_STORAGE_KEY, - apiKeyLabel: i18nString(UIStrings.browseroperatorApiKeyLabel), - apiKeyHint: i18nString(UIStrings.browseroperatorApiKeyHint), + get apiKeyLabel() { return lazyString(() => i18nString(UIStrings.browseroperatorApiKeyLabel)); }, + get apiKeyHint() { return lazyString(() => i18nString(UIStrings.browseroperatorApiKeyHint)); }, apiKeyPlaceholder: 'Enter your BrowserOperator API key (optional)', hasModelSelectors: false, hasFetchButton: false, @@ -59,12 +70,12 @@ export const GroqConfig: ProviderConfig = { id: 'groq', displayName: 'Groq', apiKeyStorageKey: GROQ_API_KEY_STORAGE_KEY, - apiKeyLabel: i18nString(UIStrings.groqApiKeyLabel), - apiKeyHint: i18nString(UIStrings.groqApiKeyHint), + get apiKeyLabel() { return lazyString(() => i18nString(UIStrings.groqApiKeyLabel)); }, + get apiKeyHint() { return lazyString(() => i18nString(UIStrings.groqApiKeyHint)); }, apiKeyPlaceholder: 'Enter your Groq API key', hasModelSelectors: true, hasFetchButton: true, - fetchButtonLabel: i18nString(UIStrings.fetchGroqModelsButton), + get fetchButtonLabel() { return lazyString(() => i18nString(UIStrings.fetchGroqModelsButton)); }, fetchMethodName: 'fetchGroqModels', useNameAsLabel: false, }; diff --git a/front_end/panels/ai_chat/ui/settings/utils/styles.ts b/front_end/panels/ai_chat/ui/settings/utils/styles.ts index 756582bf78..b80fe97af5 100644 --- a/front_end/panels/ai_chat/ui/settings/utils/styles.ts +++ b/front_end/panels/ai_chat/ui/settings/utils/styles.ts @@ -501,6 +501,104 @@ export function getSettingsStyles(): string { from { transform: rotate(0deg); } to { transform: rotate(360deg); } } + + /* Memory blocks styles */ + .memory-blocks-container { + margin-top: 16px; + padding-top: 12px; + border-top: 1px solid var(--color-details-hairline); + } + + .memory-blocks-title { + font-size: 13px; + font-weight: 500; + color: var(--color-text-secondary); + margin-bottom: 8px; + text-transform: uppercase; + letter-spacing: 0.5px; + } + + .memory-blocks-list { + display: flex; + flex-direction: column; + gap: 4px; + } + + .memory-block-item { + display: flex; + align-items: center; + gap: 10px; + padding: 10px 12px; + border-radius: 6px; + background-color: var(--color-background-elevation-1); + cursor: pointer; + transition: background-color 0.15s ease; + } + + .memory-block-delete { + background: none; + border: none; + cursor: pointer; + padding: 4px 8px; + font-size: 16px; + opacity: 0.4; + transition: opacity 0.15s ease; + flex-shrink: 0; + border-radius: 4px; + } + + .memory-block-delete:hover { + opacity: 1; + background-color: var(--color-background-elevation-2); + } + + .memory-block-item:hover { + background-color: var(--color-background-elevation-2); + } + + .memory-block-item:focus { + outline: 2px solid var(--color-primary); + outline-offset: -2px; + } + + .memory-block-icon { + font-size: 18px; + line-height: 1; + flex-shrink: 0; + } + + .memory-block-info { + flex: 1; + min-width: 0; + } + + .memory-block-label { + font-size: 14px; + font-weight: 500; + color: var(--color-text-primary); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + + .memory-block-meta { + font-size: 12px; + color: var(--color-text-secondary); + margin-top: 2px; + } + + .memory-blocks-empty { + font-size: 13px; + color: var(--color-text-secondary); + font-style: italic; + padding: 8px 0; + } + + .memory-blocks-error { + font-size: 13px; + color: var(--color-accent-red); + padding: 8px 0; + } `; } diff --git a/front_end/panels/ai_chat/ui/settings/utils/validation.ts b/front_end/panels/ai_chat/ui/settings/utils/validation.ts index 9c37e555ce..e048b1ecd6 100644 --- a/front_end/panels/ai_chat/ui/settings/utils/validation.ts +++ b/front_end/panels/ai_chat/ui/settings/utils/validation.ts @@ -4,6 +4,7 @@ import { DEFAULT_PROVIDER_MODELS } from '../../AIChatPanel.js'; import type { ModelOption, ProviderType, ModelTier } from '../types.js'; +import { findClosestModel } from '../../../LLM/FuzzyModelMatcher.js'; /** * Get a valid model for a specific provider, falling back to defaults if needed @@ -14,21 +15,31 @@ export function getValidModelForProvider( provider: ProviderType, modelType: ModelTier, ): string { - // Check if current model is valid for this provider + const availableValues = providerModels.map(m => m.value); + + // 1. Check if current model is valid (exact match only) if (providerModels.some(model => model.value === currentModel)) { return currentModel; } - // Get defaults from AIChatPanel's DEFAULT_PROVIDER_MODELS + // 2. Get defaults from AIChatPanel's DEFAULT_PROVIDER_MODELS const defaults = DEFAULT_PROVIDER_MODELS[provider] || DEFAULT_PROVIDER_MODELS.openai; const defaultModel = modelType === 'mini' ? defaults.mini : defaults.nano; - // Return default if it exists in provider models + // 3. Check exact match for default if (defaultModel && providerModels.some(model => model.value === defaultModel)) { return defaultModel; } - // If no valid model found, return empty string to indicate no selection + // 4. Try fuzzy match for default only + if (defaultModel) { + const fuzzyDefault = findClosestModel(defaultModel, availableValues); + if (fuzzyDefault) { + return fuzzyDefault; + } + } + + // 5. If no valid model found, return empty string to indicate no selection // The UI should handle this by showing a placeholder or the first available option return ''; }