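diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index e7ef7836..d0c023ab 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -130,3 +130,14 @@ custom providers come from: - Do not pile up legacy-field compatibility logic in `config` - Do not mix selection correction into snapshot validation - Do not leak provider/catalog/runtime semantics back into `config` + +## Validation constraints for custom provider `models` + +`~/.neocode/providers//provider.yaml` may declare `models` to fill in model metadata when catalog/discovery cannot provide a complete `ContextWindow` or `MaxOutputTokens`. + +The constraints are: + +- `models[].id` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if explicitly provided, must be greater than `0`. +- Duplicate model `id`s fail immediately when the custom provider is loaded; the lenient silently-drop behavior is not kept. +- This metadata is never written back to `config.yaml`; it is declared only in the custom provider file and takes part in runtime resolution through the existing catalog merge path. diff --git a/docs/context-compact.md b/docs/context-compact.md index 87fded85..86c20a63 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -1,5 +1,11 @@ # Context Compact +## Auto Compact Failure Fallback + +- When `context.auto_compact.input_token_threshold <= 0`, the system tries to derive the threshold automatically from the current model's `ContextWindow`. +- If the current provider selection is invalid, the catalog snapshot query fails, or the model's window metadata is missing, the system falls back directly to `fallback_input_token_threshold`. +- A failed derivation never silently disables auto compact; the runtime still receives a usable fallback threshold. + This document describes the configuration, execution path, and summary conventions of context compact in NeoCode. ## Overview @@ -23,7 +29,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` - `manual_strategy` @@ -39,7 +47,7 @@ context: - `auto_compact.enabled` Controls whether token-threshold-based automatic compaction is enabled; off by default. - `auto_compact.input_token_threshold` - Triggers automatic compaction when the session's cumulative input token count reaches this threshold; default 100000. + Triggers automatic compaction when the session's cumulative input token count reaches this threshold; default `0` (derive automatically), falling back to `fallback_input_token_threshold` (default `100000`) when derivation fails. ## Automatic compaction @@ -154,3 +162,12 @@ compact-related runtime events include: - `trigger_mode` - `transcript_id` - `transcript_path` + +## Auto Compact threshold resolution + +- When `context.auto_compact.input_token_threshold > 0`, the explicit manual threshold is used directly. +- 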
When `context.auto_compact.input_token_threshold <= 0`, the system derives the threshold automatically for the currently selected provider/model. +- The derivation formula is `resolved_threshold = context_window - reserve_tokens`. +- `reserve_tokens` defaults to `13000` and reserves headroom for output, tool calls, and the system prompt. +- If the current model has no usable `ContextWindow`, or the window value is less than or equal to `reserve_tokens`, the system falls back to `fallback_input_token_threshold`. +- `fallback_input_token_threshold` defaults to `100000` and keeps the main path running stably when model window metadata is missing. diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index 95d94b01..22d32636 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -270,3 +270,34 @@ func DefaultProviders() []ProviderConfig { ``` All built-in providers are registered centrally in code. The candidate models shown in the model selector are assembled from the default model, dynamic discovery results, and the local cache. + +## Filling in model metadata for custom providers + +For a custom provider that reuses the `openaicompat` driver, if the upstream `GET /models` cannot return reliable context window information, you can edit: + +```text +~/.neocode/providers//provider.yaml +``` + +and declare `models` explicitly: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + api_style: chat_completions +``` + +The constraints are: + +- `models[].id` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if explicitly configured, must be greater than `0`. +- Duplicate model `id`s within the same `provider.yaml` raise an error at load time. +- This metadata enters the unified model catalog merge path, and the precedence remains "configured model metadata takes priority over discovery/default". diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index b2a407c0..c03f3209 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -38,6 +38,8 @@ selected_provider: openai current_model: gpt-5.4 shell: bash tool_timeout_sec: 20 +runtime: + max_no_progress_streak: 3 tools: webfetch: @@ -56,7 +58,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` ### Basic fields @@ -79,6 +83,14 @@ context: | 
`context.compact.micro_compact_disabled` | Whether to turn off micro compact, which is enabled by default | | `context.auto_compact.enabled` | Whether automatic compaction is enabled | | `context.auto_compact.input_token_threshold` | Input token threshold for automatic compaction | +| `context.auto_compact.reserve_tokens` | Token headroom reserved during automatic threshold derivation (`resolved_threshold = context_window - reserve_tokens`) | +| `context.auto_compact.fallback_input_token_threshold` | Fallback threshold used when automatic derivation fails | + +### `runtime` fields + +| Field | Description | +|------|------| +| `runtime.max_no_progress_streak` | Circuit-breaker threshold for consecutive "no-progress" rounds, default `3`; when the streak reaches `limit-1` (round 2 by default) a system-level corrective prompt is injected into the model once, and when it reaches `limit` (round 3 by default) the run is terminated | ### `tools` fields @@ -134,6 +146,13 @@ openai_compatible: api_style: chat_completions ``` +## Auto Compact failure handling and validation notes + +- When `context.auto_compact.input_token_threshold <= 0`, if the current provider selection is invalid, the catalog snapshot query fails, or the model lacks a usable `ContextWindow`, the system falls back to `fallback_input_token_threshold` and does not silently disable auto compact. +- `models[].id` in `~/.neocode/providers//provider.yaml` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if explicitly configured, must be greater than `0`. +- Duplicate model `id`s in `models` raise an error when `provider.yaml` is loaded. + File path: ```text @@ -222,3 +241,26 @@ config: environment variable OPENAI_API_KEY is empty - [Adding a Provider](./adding-providers.md) - [Configuration Management Detailed Design](../config-management-detail-design.md) - [Context Compact](../context-compact.md) + +## Additional notes on Auto Compact + +- When `context.auto_compact.input_token_threshold > 0`, the system uses that explicit threshold directly. +- When `context.auto_compact.input_token_threshold <= 0`, the system derives the input threshold automatically from the `ContextWindow` of the current `current_model`. +- The derivation formula is `context_window - reserve_tokens`. +- `reserve_tokens` defaults to `13000`. +- If the current provider/model has no usable `ContextWindow` metadata, the system falls back to `fallback_input_token_threshold`. +- A custom provider can fill in model metadata through the `models` field in `~/.neocode/providers//provider.yaml`, for example: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + 
api_style: chat_completions +``` diff --git a/go.mod b/go.mod index b5b9bda2..8017ac76 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module neo-code go 1.25.0 require ( - github.com/atotto/clipboard v0.1.4 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/glamour v1.0.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 + golang.design/x/clipboard v0.7.1 golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 gopkg.in/yaml.v3 v3.0.1 @@ -18,6 +18,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/alecthomas/chroma/v2 v2.20.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.4.3 // indirect @@ -57,6 +58,9 @@ require ( github.com/yuin/goldmark v1.7.13 // indirect github.com/yuin/goldmark-emoji v1.0.6 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 // indirect + golang.org/x/image v0.28.0 // indirect + golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f // indirect golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index 59439be3..c4388c9e 100644 --- a/go.sum +++ b/go.sum @@ -130,8 +130,16 @@ github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9 github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.design/x/clipboard v0.7.1 h1:OEG3CmcYRBNnRwpDp7+uWLiZi3hrMRJpE9JkkkYtz2c= +golang.design/x/clipboard v0.7.1/go.mod h1:i5SiIqj0wLFw9P/1D7vfILFK0KHMk7ydE72HRrUIgkg= golang.org/x/exp 
v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 h1:Wdx0vgH5Wgsw+lF//LJKmWOJBLWX6nprsMqnf99rYDE= +golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= +golang.org/x/image v0.28.0 h1:gdem5JW1OLS4FbkWgLO+7ZeFzYtL3xClb97GaUzYMFE= +golang.org/x/image v0.28.0/go.mod h1:GUJYXtnGKEUgggyzh+Vxt+AviiCcyiwpsl8iQ8MvwGY= +golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f h1:/n+PL2HlfqeSiDCuhdBbRNlGS/g2fM4OHufalHaTVG8= +golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f/go.mod h1:ESkJ836Z6LpG6mTVAhA48LpfW/8fNR0ifStlH2axyfg= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index f5cd22c4..0f64a269 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -168,6 +168,15 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er contextBuilder, ) runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir())) + runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolution, err := configstate.ResolveAutoCompactThreshold(ctx, cfg, modelCatalogs) + if err != nil { + return 0, err + } + return resolution.Threshold, nil + }, + )) // Inject the memory extraction hook: when AutoExtract is enabled and memoSvc is available, memories are extracted asynchronously after the ReAct loop completes. if memoSvc != nil && cfg.Memo.AutoExtract { @@ -306,3 +315,9 @@ type textGenAdapter func(ctx context.Context, prompt string, msgs []providertype func (f textGenAdapter) Generate(ctx context.Context, prompt string, msgs 
[]providertypes.Message) (string, error) { return f(ctx, prompt, msgs) } + +type runtimeAutoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) + +func (f runtimeAutoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + return f(ctx, cfg) +} diff --git a/internal/config/config.go b/internal/config/config.go index 7d0d5c07..87ce5d0f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ type Config struct { Workdir string `yaml:"-"` Shell string `yaml:"shell"` ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` Context ContextConfig `yaml:"context,omitempty"` Tools ToolsConfig `yaml:"tools,omitempty"` Memo MemoConfig `yaml:"memo,omitempty"` @@ -32,6 +33,7 @@ func StaticDefaults() *Config { Workdir: DefaultWorkdir, Shell: defaultShell(), ToolTimeoutSec: DefaultToolTimeoutSec, + Runtime: defaultRuntimeConfig(), Context: defaultContextConfig(), Tools: ToolsConfig{ WebFetch: defaultWebFetchConfig(), @@ -48,6 +50,7 @@ func (c *Config) Clone() Config { clone := *c clone.Providers = cloneProviders(c.Providers) + clone.Runtime = c.Runtime.Clone() clone.Context = c.Context.Clone() clone.Tools = c.Tools.Clone() clone.Memo = c.Memo.Clone() @@ -69,6 +72,7 @@ func (c *Config) applyStaticDefaults(defaults Config) { if c.ToolTimeoutSec <= 0 { c.ToolTimeoutSec = defaults.ToolTimeoutSec } + c.Runtime.ApplyDefaults(defaults.Runtime) c.Context.ApplyDefaults(defaults.Context) c.Tools.ApplyDefaults(defaults.Tools) c.Memo.ApplyDefaults(defaults.Memo) @@ -122,6 +126,9 @@ func (c *Config) ValidateSnapshot() error { if err := c.Tools.Validate(); err != nil { return fmt.Errorf("config: tools: %w", err) } + if err := c.Runtime.Validate(); err != nil { + return fmt.Errorf("config: runtime: %w", err) + } if err := c.Context.Validate(); err != nil { return fmt.Errorf("config: context: %w", err) } diff --git 
a/internal/config/config_test.go b/internal/config/config_test.go index f7b5de4e..0a55a580 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1242,6 +1242,14 @@ func TestAutoCompactConfigDefaults(t *testing.T) { t.Fatalf("expected input_token_threshold=%d, got %d", DefaultAutoCompactInputTokenThreshold, cfg.Context.AutoCompact.InputTokenThreshold) } + if cfg.Context.AutoCompact.ReserveTokens != DefaultAutoCompactReserveTokens { + t.Fatalf("expected reserve_tokens=%d, got %d", + DefaultAutoCompactReserveTokens, cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != DefaultAutoCompactFallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d, got %d", + DefaultAutoCompactFallbackInputTokenThreshold, cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } if cfg.Context.AutoCompact.Enabled != false { t.Fatalf("expected enabled=false, got %v", cfg.Context.AutoCompact.Enabled) @@ -1253,13 +1261,20 @@ func TestAutoCompactConfigApplyDefaults(t *testing.T) { cfg := AutoCompactConfig{} defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) - if cfg.InputTokenThreshold != 50000 { - t.Fatalf("expected threshold=50000, got %d", cfg.InputTokenThreshold) + if cfg.InputTokenThreshold != 0 { + t.Fatalf("expected threshold to remain implicit 0, got %d", cfg.InputTokenThreshold) + } + if cfg.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.FallbackInputTokenThreshold) } } @@ -1267,10 +1282,14 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - InputTokenThreshold: 200000, + InputTokenThreshold: 200000, + ReserveTokens: 5000, + 
FallbackInputTokenThreshold: 80000, } defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) @@ -1278,13 +1297,19 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { if cfg.InputTokenThreshold != 200000 { t.Fatalf("expected explicit threshold=200000 to be preserved, got %d", cfg.InputTokenThreshold) } + if cfg.ReserveTokens != 5000 { + t.Fatalf("expected explicit reserve_tokens=5000 to be preserved, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 80000 { + t.Fatalf("expected explicit fallback_input_token_threshold=80000 to be preserved, got %d", cfg.FallbackInputTokenThreshold) + } } func TestAutoCompactConfigApplyDefaultsNilReceiver(t *testing.T) { t.Parallel() var cfg *AutoCompactConfig - cfg.ApplyDefaults(AutoCompactConfig{InputTokenThreshold: 50000}) + cfg.ApplyDefaults(AutoCompactConfig{ReserveTokens: 13000, FallbackInputTokenThreshold: 100000}) } func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { @@ -1293,7 +1318,8 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { cfg := ContextConfig{} cfg.ApplyDefaults(ContextConfig{ AutoCompact: AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1303,8 +1329,14 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { }, }) - if cfg.AutoCompact.InputTokenThreshold != 50000 { - t.Fatalf("expected auto compact threshold=50000, got %d", cfg.AutoCompact.InputTokenThreshold) + if cfg.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected auto compact threshold to remain implicit 0, got %d", cfg.AutoCompact.InputTokenThreshold) + } + if cfg.AutoCompact.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got 
%d", cfg.AutoCompact.ReserveTokens) + } + if cfg.AutoCompact.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.AutoCompact.FallbackInputTokenThreshold) } } @@ -1312,16 +1344,15 @@ func TestAutoCompactConfigValidateEnabledWithoutThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "input_token_threshold") { - t.Fatalf("expected error about input_token_threshold, got %v", err) + if err != nil { + t.Fatalf("expected validation to allow implicit threshold, got %v", err) } } @@ -1329,8 +1360,10 @@ func TestAutoCompactConfigValidateDisabledWithoutThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: false, - InputTokenThreshold: 0, + Enabled: false, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 0, } err := cfg.Validate() @@ -1343,8 +1376,10 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, + Enabled: true, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() @@ -1353,12 +1388,46 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { } } +func TestAutoCompactConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 100000, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "reserve_tokens") { + t.Fatalf("expected reserve_tokens validation error, got %v", err) + } +} + +func 
TestAutoCompactConfigValidateRejectsNonPositiveFallbackThreshold(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 0, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "fallback_input_token_threshold") { + t.Fatalf("expected fallback_input_token_threshold validation error, got %v", err) + } +} + func TestAutoCompactConfigClone(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 75000, + Enabled: true, + InputTokenThreshold: 75000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cloned := cfg.Clone() @@ -1370,6 +1439,13 @@ func TestAutoCompactConfigClone(t *testing.T) { t.Fatalf("expected threshold=%d to be cloned, got %d", cfg.InputTokenThreshold, cloned.InputTokenThreshold) } + if cfg.ReserveTokens != cloned.ReserveTokens { + t.Fatalf("expected reserve_tokens=%d to be cloned, got %d", cfg.ReserveTokens, cloned.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != cloned.FallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d to be cloned, got %d", + cfg.FallbackInputTokenThreshold, cloned.FallbackInputTokenThreshold) + } cloned.InputTokenThreshold = 100000 if cfg.InputTokenThreshold == cloned.InputTokenThreshold { @@ -1382,8 +1458,10 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { ctx := ContextConfig{ AutoCompact: AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1394,11 +1472,8 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { } err := ctx.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "auto_compact") { - t.Fatalf("expected error to 
contain 'auto_compact', got %v", err) + if err != nil { + t.Fatalf("expected context validation to allow implicit threshold, got %v", err) } } @@ -1592,6 +1667,10 @@ func TestValidateSnapshotPropagatesCompactError(t *testing.T) { SupportedContentTypes: []string{"text/html"}, }, }, + Runtime: RuntimeConfig{ + MaxNoProgressStreak: 3, + MaxRepeatCycleStreak: 3, + }, Context: ContextConfig{ Compact: CompactConfig{ ManualStrategy: "invalid_strategy", @@ -1682,6 +1761,52 @@ func TestMarshalPersistedConfigEndsWithNewline(t *testing.T) { } } +func TestParseCurrentConfigRoundTripRuntimeConfig(t *testing.T) { + t.Parallel() + + snapshot := testDefaultConfig().Clone() + snapshot.Runtime.MaxNoProgressStreak = 5 + + data, err := marshalPersistedConfig(snapshot) + if err != nil { + t.Fatalf("marshalPersistedConfig() error = %v", err) + } + + parsed, err := parseCurrentConfig(data, StaticDefaults().Context, StaticDefaults().Memo) + if err != nil { + t.Fatalf("parseCurrentConfig() error = %v", err) + } + if parsed.Runtime.MaxNoProgressStreak != 5 { + t.Fatalf("expected max_no_progress_streak=5, got %d", parsed.Runtime.MaxNoProgressStreak) + } +} + +func TestParseCurrentConfigInvalidRuntimeValueDefaultsBeforeValidation(t *testing.T) { + t.Parallel() + + raw := []byte(` +selected_provider: openai +current_model: gpt-4.1 +shell: bash +runtime: + max_no_progress_streak: -2 +`) + + parsed, err := parseCurrentConfig(raw, StaticDefaults().Context, StaticDefaults().Memo) + if err != nil { + t.Fatalf("parseCurrentConfig() error = %v", err) + } + parsed.Providers = cloneProviders(testDefaultConfig().Providers) + parsed.applyStaticDefaults(*StaticDefaults()) + if err := parsed.ValidateSnapshot(); err != nil { + t.Fatalf("ValidateSnapshot() error = %v", err) + } + if parsed.Runtime.MaxNoProgressStreak != DefaultMaxNoProgressStreak { + t.Fatalf("expected default max_no_progress_streak=%d, got %d", + DefaultMaxNoProgressStreak, parsed.Runtime.MaxNoProgressStreak) + } +} + func 
TestAssembleProvidersAcceptsEmptyNameProvider(t *testing.T) { t.Parallel() diff --git a/internal/config/context.go b/internal/config/context.go index c0c2ffb5..57552b32 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -7,11 +7,13 @@ import ( ) const ( - DefaultCompactManualKeepRecentMessages = 10 - DefaultCompactMaxSummaryChars = 1200 - DefaultAutoCompactInputTokenThreshold = 100000 - DefaultMicroCompactRetainedToolSpans = 2 - DefaultCompactReadTimeMaxMessageSpans = 24 + DefaultCompactManualKeepRecentMessages = 10 + DefaultCompactMaxSummaryChars = 1200 + DefaultAutoCompactInputTokenThreshold = 0 + DefaultAutoCompactReserveTokens = 13000 + DefaultAutoCompactFallbackInputTokenThreshold = 100000 + DefaultMicroCompactRetainedToolSpans = 2 + DefaultCompactReadTimeMaxMessageSpans = 24 CompactManualStrategyKeepRecent = "keep_recent" CompactManualStrategyFullReplace = "full_replace" @@ -34,8 +36,10 @@ type CompactConfig struct { // AutoCompactConfig controls automatic context compression triggered by token thresholds. 
type AutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } // defaultContextConfig returns the default values for context compaction settings. @@ -48,7 +52,9 @@ func defaultContextConfig() ContextConfig { func defaultAutoCompactConfig() AutoCompactConfig { return AutoCompactConfig{ - InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + ReserveTokens: DefaultAutoCompactReserveTokens, + FallbackInputTokenThreshold: DefaultAutoCompactFallbackInputTokenThreshold, } } @@ -119,8 +125,11 @@ func (c *AutoCompactConfig) ApplyDefaults(defaults AutoCompactConfig) { if c == nil { return } - if c.InputTokenThreshold <= 0 { - c.InputTokenThreshold = defaults.InputTokenThreshold + if c.ReserveTokens <= 0 { + c.ReserveTokens = defaults.ReserveTokens + } + if c.FallbackInputTokenThreshold <= 0 { + c.FallbackInputTokenThreshold = defaults.FallbackInputTokenThreshold } } @@ -157,8 +166,14 @@ func (c CompactConfig) Validate() error { // Validate checks whether the auto_compact configuration is valid. func (c AutoCompactConfig) Validate() error { - if c.Enabled && c.InputTokenThreshold <= 0 { - return errors.New("input_token_threshold must be greater than 0 when enabled") + if !c.Enabled { + return nil + } + if c.ReserveTokens <= 0 { + return errors.New("reserve_tokens must be greater than 0 when enabled") + } + if c.FallbackInputTokenThreshold <= 0 { + return errors.New("fallback_input_token_threshold must be greater than 0 when enabled") } return nil } diff --git a/internal/config/loader.go b/internal/config/loader.go index c7fba9b4..5e35c3a6 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -26,6 +26,7 @@ type persistedConfig struct { CurrentModel string 
`yaml:"current_model,omitempty"` Shell string `yaml:"shell"` ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` Context persistedContextConfig `yaml:"context,omitempty"` Tools ToolsConfig `yaml:"tools,omitempty"` Memo persistedMemoConfig `yaml:"memo,omitempty"` @@ -47,8 +48,10 @@ type persistedCompactConfig struct { } type persistedAutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } type persistedMemoConfig struct { @@ -204,6 +207,7 @@ func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults CurrentModel: strings.TrimSpace(file.CurrentModel), Shell: strings.TrimSpace(file.Shell), ToolTimeoutSec: file.ToolTimeoutSec, + Runtime: file.Runtime, Context: fromPersistedContextConfig(file.Context, contextDefaults), Tools: file.Tools, Memo: fromPersistedMemoConfig(file.Memo, memoDefaults), @@ -218,6 +222,7 @@ func marshalPersistedConfig(snapshot Config) ([]byte, error) { CurrentModel: snapshot.CurrentModel, Shell: snapshot.Shell, ToolTimeoutSec: snapshot.ToolTimeoutSec, + Runtime: snapshot.Runtime, Context: newPersistedContextConfig(snapshot.Context), Tools: snapshot.Tools, Memo: newPersistedMemoConfig(snapshot.Memo), @@ -246,8 +251,10 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, AutoCompact: persistedAutoCompactConfig{ - Enabled: cfg.AutoCompact.Enabled, - InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + Enabled: cfg.AutoCompact.Enabled, + InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + ReserveTokens: cfg.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: 
cfg.AutoCompact.FallbackInputTokenThreshold, }, } } @@ -265,8 +272,10 @@ func fromPersistedContextConfig(file persistedContextConfig, defaults ContextCon MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, AutoCompact: AutoCompactConfig{ - Enabled: file.AutoCompact.Enabled, - InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + Enabled: file.AutoCompact.Enabled, + InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + ReserveTokens: file.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: file.AutoCompact.FallbackInputTokenThreshold, }, } out.Compact.ApplyDefaults(defaults.Compact) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 86c87754..fb89594e 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -251,6 +251,11 @@ shell: powershell name: company-gateway driver: openaicompat api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 openai_compatible: base_url: https://llm.example.com/v1 api_style: chat_completions @@ -290,8 +295,11 @@ openai_compatible: if customProvider.Model != "" { t.Fatalf("expected custom provider default model to be empty, got %q", customProvider.Model) } - if len(customProvider.Models) != 0 { - t.Fatalf("expected custom provider models to come only from remote discovery, got %+v", customProvider.Models) + if len(customProvider.Models) != 1 { + t.Fatalf("expected custom provider model metadata from provider.yaml, got %+v", customProvider.Models) + } + if customProvider.Models[0].ID != "deepseek-coder" || customProvider.Models[0].ContextWindow != 131072 { + t.Fatalf("expected parsed model metadata, got %+v", customProvider.Models[0]) } } @@ -421,6 +429,183 @@ models: } } +func TestLoaderRejectsCustomProviderModelWithoutID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), 
providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - name: DeepSeek Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty model id rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidContextWindow(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + context_window: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected invalid context_window rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidMaxOutputTokens(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir 
custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + max_output_tokens: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected invalid max_output_tokens rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderDuplicateModelID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + - id: DeepSeek-Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate model id rejection, got %v", err) + } +} + +func TestLoaderParsesAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +context: + auto_compact: + enabled: true + input_token_threshold: 0 + reserve_tokens: 9000 + fallback_input_token_threshold: 88000 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := 
loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.Context.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected implicit threshold 0, got %d", cfg.Context.AutoCompact.InputTokenThreshold) + } + if cfg.Context.AutoCompact.ReserveTokens != 9000 { + t.Fatalf("expected reserve_tokens=9000, got %d", cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != 88000 { + t.Fatalf("expected fallback_input_token_threshold=88000, got %d", cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } +} + +func TestLoaderSavePersistsAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 9000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + if err := loader.Save(context.Background(), &cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("read config: %v", err) + } + text := string(data) + if strings.Contains(text, "input_token_threshold: 100000") { + t.Fatalf("expected implicit threshold to avoid legacy default, got:\n%s", text) + } + if !strings.Contains(text, "reserve_tokens: 9000") { + t.Fatalf("expected reserve_tokens to persist, got:\n%s", text) + } + if !strings.Contains(text, "fallback_input_token_threshold: 88000") { + t.Fatalf("expected fallback_input_token_threshold to persist, got:\n%s", text) + } +} + func TestLoaderRejectsCustomProviderNameConflictingWithBuiltin(t *testing.T) { t.Parallel() diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index f7a2b2ac..43448ea6 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -10,6 +10,7 @@ import ( "gopkg.in/yaml.v3" 
"neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" ) const ( @@ -22,11 +23,19 @@ type customProviderFile struct { Driver string `yaml:"driver"` APIKeyEnv string `yaml:"api_key_env"` BaseURL string `yaml:"base_url,omitempty"` + Models []customProviderModelFile `yaml:"models,omitempty"` OpenAICompatible customOpenAICompatibleFile `yaml:"openai_compatible,omitempty"` Gemini customGeminiProviderFile `yaml:"gemini,omitempty"` Anthropic customAnthropicProviderFile `yaml:"anthropic,omitempty"` } +type customProviderModelFile struct { + ID string `yaml:"id"` + Name string `yaml:"name,omitempty"` + ContextWindow *int `yaml:"context_window,omitempty"` + MaxOutputTokens *int `yaml:"max_output_tokens,omitempty"` +} + type customOpenAICompatibleFile struct { BaseURL string `yaml:"base_url"` APIStyle string `yaml:"api_style,omitempty"` @@ -106,6 +115,11 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { } settings := resolveCustomProviderSettings(file) + models, err := customProviderModels(file.Models) + if err != nil { + return ProviderConfig{}, fmt.Errorf("config: custom provider %q: %w", filepath.Base(providerDir), err) + } + cfg := ProviderConfig{ Name: strings.TrimSpace(file.Name), Driver: strings.TrimSpace(file.Driver), @@ -114,6 +128,7 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { APIStyle: settings.APIStyle, DeploymentMode: settings.DeploymentMode, APIVersion: settings.APIVersion, + Models: models, Source: ProviderSourceCustom, } @@ -128,6 +143,48 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { return cfg, nil } +// customProviderModels 校验并收敛 custom provider.yaml 中声明的模型元数据。 +func customProviderModels(models []customProviderModelFile) ([]providertypes.ModelDescriptor, error) { + if len(models) == 0 { + return nil, nil + } + + descriptors := make([]providertypes.ModelDescriptor, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for index, model := range 
models { + id := strings.TrimSpace(model.ID) + if id == "" { + return nil, fmt.Errorf("models[%d].id is empty", index) + } + + key := provider.NormalizeKey(id) + if _, exists := seen[key]; exists { + return nil, fmt.Errorf("models[%d].id %q is duplicated", index, id) + } + seen[key] = struct{}{} + + descriptor := providertypes.ModelDescriptor{ + ID: id, + Name: strings.TrimSpace(model.Name), + } + if model.ContextWindow != nil { + if *model.ContextWindow <= 0 { + return nil, fmt.Errorf("models[%d].context_window must be greater than 0", index) + } + descriptor.ContextWindow = *model.ContextWindow + } + if model.MaxOutputTokens != nil { + if *model.MaxOutputTokens <= 0 { + return nil, fmt.Errorf("models[%d].max_output_tokens must be greater than 0", index) + } + descriptor.MaxOutputTokens = *model.MaxOutputTokens + } + descriptors = append(descriptors, descriptor) + } + + return providertypes.MergeModelDescriptors(descriptors), nil +} + // resolveCustomProviderSettings 根据 driver 只提取当前协议真正生效的配置字段,避免误吃其他协议块的值。 // 已知 driver 仅从协议块读取 base_url;未知 driver 使用顶层 base_url 作为唯一入口。 func resolveCustomProviderSettings(file customProviderFile) customProviderSettings { diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 22525bb5..f364b7ce 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -194,6 +194,78 @@ func TestCloneProviderConfigModelDescriptorsIndependence(t *testing.T) { } } +func TestCustomProviderModelsParsesSupportedMetadata(t *testing.T) { + t.Parallel() + + contextWindow := 131072 + maxOutputTokens := 8192 + models, err := customProviderModels([]customProviderModelFile{ + { + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: &contextWindow, + MaxOutputTokens: &maxOutputTokens, + }, + }) + if err != nil { + t.Fatalf("customProviderModels() error = %v", err) + } + + if len(models) != 1 { + t.Fatalf("expected one parsed model, got %+v", models) + } + if models[0].ID != "deepseek-coder" || 
models[0].ContextWindow != 131072 || models[0].MaxOutputTokens != 8192 { + t.Fatalf("unexpected parsed model descriptor: %+v", models[0]) + } +} + +func TestCustomProviderModelsRejectsEmptyID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{{Name: "Missing ID"}}) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty id validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveContextWindow(t *testing.T) { + t.Parallel() + + contextWindow := 0 + _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + ContextWindow: &contextWindow, + }}) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected context_window validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveMaxOutputTokens(t *testing.T) { + t.Parallel() + + maxOutputTokens := 0 + _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + MaxOutputTokens: &maxOutputTokens, + }}) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected max_output_tokens validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsDuplicateID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{ + {ID: "deepseek-coder"}, + {ID: " DeepSeek-Coder "}, + }) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate id validation error, got %v", err) + } +} + func TestProviderByNameCaseInsensitive(t *testing.T) { t.Parallel() diff --git a/internal/config/runtime.go b/internal/config/runtime.go new file mode 100644 index 00000000..a71b6de6 --- /dev/null +++ b/internal/config/runtime.go @@ -0,0 +1,53 @@ +package config + +import ( + "errors" +) + +const ( + DefaultMaxNoProgressStreak = 3 + DefaultMaxRepeatCycleStreak = 3 +) + +// RuntimeConfig 定义 runtime 层的可调参数。 +type 
RuntimeConfig struct { + MaxNoProgressStreak int `yaml:"max_no_progress_streak,omitempty"` + MaxRepeatCycleStreak int `yaml:"max_repeat_cycle_streak,omitempty"` +} + +// defaultRuntimeConfig 返回 runtime 配置的静态默认值。 +func defaultRuntimeConfig() RuntimeConfig { + return RuntimeConfig{ + MaxNoProgressStreak: DefaultMaxNoProgressStreak, + MaxRepeatCycleStreak: DefaultMaxRepeatCycleStreak, + } +} + +// Clone 复制 runtime 配置,避免调用方共享可变状态。 +func (c RuntimeConfig) Clone() RuntimeConfig { + return c +} + +// ApplyDefaults 在配置缺失或非法时回填默认阈值。 +func (c *RuntimeConfig) ApplyDefaults(defaults RuntimeConfig) { + if c == nil { + return + } + if c.MaxNoProgressStreak <= 0 { + c.MaxNoProgressStreak = defaults.MaxNoProgressStreak + } + if c.MaxRepeatCycleStreak <= 0 { + c.MaxRepeatCycleStreak = defaults.MaxRepeatCycleStreak + } +} + +// Validate 校验 runtime 配置是否满足最小约束。 +func (c RuntimeConfig) Validate() error { + if c.MaxNoProgressStreak <= 0 { + return errors.New("max_no_progress_streak must be greater than 0") + } + if c.MaxRepeatCycleStreak <= 0 { + return errors.New("max_repeat_cycle_streak must be greater than 0") + } + return nil +} diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go new file mode 100644 index 00000000..bd493427 --- /dev/null +++ b/internal/config/runtime_test.go @@ -0,0 +1,69 @@ +package config + +import "testing" + +func TestRuntimeConfigClone(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{MaxNoProgressStreak: 7, MaxRepeatCycleStreak: 4} + cloned := cfg.Clone() + if cloned.MaxNoProgressStreak != 7 { + t.Fatalf("expected cloned MaxNoProgressStreak=7, got %d", cloned.MaxNoProgressStreak) + } + if cloned.MaxRepeatCycleStreak != 4 { + t.Fatalf("expected cloned MaxRepeatCycleStreak=4, got %d", cloned.MaxRepeatCycleStreak) + } +} + +func TestRuntimeConfigApplyDefaults(t *testing.T) { + t.Parallel() + + defaults := RuntimeConfig{MaxNoProgressStreak: 3, MaxRepeatCycleStreak: 5} + + cfg := RuntimeConfig{MaxNoProgressStreak: 0, 
MaxRepeatCycleStreak: 0} + cfg.ApplyDefaults(defaults) + if cfg.MaxNoProgressStreak != 3 { + t.Fatalf("expected defaulted MaxNoProgressStreak=3, got %d", cfg.MaxNoProgressStreak) + } + if cfg.MaxRepeatCycleStreak != 5 { + t.Fatalf("expected defaulted MaxRepeatCycleStreak=5, got %d", cfg.MaxRepeatCycleStreak) + } + + cfg = RuntimeConfig{MaxNoProgressStreak: 5, MaxRepeatCycleStreak: 8} + cfg.ApplyDefaults(defaults) + if cfg.MaxNoProgressStreak != 5 { + t.Fatalf("expected existing MaxNoProgressStreak=5 to be preserved, got %d", cfg.MaxNoProgressStreak) + } + if cfg.MaxRepeatCycleStreak != 8 { + t.Fatalf("expected existing MaxRepeatCycleStreak=8 to be preserved, got %d", cfg.MaxRepeatCycleStreak) + } + + cfg = RuntimeConfig{MaxNoProgressStreak: 2, MaxRepeatCycleStreak: -1} + cfg.ApplyDefaults(defaults) + if cfg.MaxRepeatCycleStreak != 5 { + t.Fatalf("expected negative MaxRepeatCycleStreak=-1 to be replaced by default=5, got %d", cfg.MaxRepeatCycleStreak) + } + + var nilCfg *RuntimeConfig + nilCfg.ApplyDefaults(defaults) +} + +func TestRuntimeConfigValidate(t *testing.T) { + t.Parallel() + + if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: 1}).Validate(); err != nil { + t.Fatalf("expected valid config, got %v", err) + } + + for _, bad := range []int{0, -1, -99} { + if err := (RuntimeConfig{MaxNoProgressStreak: bad}).Validate(); err == nil { + t.Fatalf("expected validation error for MaxNoProgressStreak=%d", bad) + } + } + + for _, bad := range []int{0, -1, -99} { + if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: bad}).Validate(); err == nil { + t.Fatalf("expected validation error for MaxRepeatCycleStreak=%d", bad) + } + } +} diff --git a/internal/config/state/auto_compact_threshold.go b/internal/config/state/auto_compact_threshold.go new file mode 100644 index 00000000..490a2d54 --- /dev/null +++ b/internal/config/state/auto_compact_threshold.go @@ -0,0 +1,90 @@ +package state + +import ( + "context" + "strings" + + 
"neo-code/internal/config" + "neo-code/internal/provider" +) + +// AutoCompactThresholdSource 标识自动压缩阈值最终采用的来源。 +type AutoCompactThresholdSource string + +const ( + AutoCompactThresholdSourceDisabled AutoCompactThresholdSource = "disabled" + AutoCompactThresholdSourceExplicit AutoCompactThresholdSource = "explicit" + AutoCompactThresholdSourceDerived AutoCompactThresholdSource = "derived" + AutoCompactThresholdSourceFallback AutoCompactThresholdSource = "fallback" +) + +// AutoCompactThresholdResolution 描述自动压缩阈值的解析结果,供 runtime 直接消费。 +type AutoCompactThresholdResolution struct { + Threshold int + Source AutoCompactThresholdSource + ContextWindow int + ModelID string +} + +// fallbackAutoCompactThresholdResolution 构造自动推导失败时使用的保底阈值结果。 +func fallbackAutoCompactThresholdResolution(cfg config.Config) AutoCompactThresholdResolution { + return AutoCompactThresholdResolution{ + Threshold: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + Source: AutoCompactThresholdSourceFallback, + ModelID: strings.TrimSpace(cfg.CurrentModel), + } +} + +// ResolveAutoCompactThreshold 基于当前选择的 provider/model 和模型目录快照解析最终阈值。 +func ResolveAutoCompactThreshold( + ctx context.Context, + cfg config.Config, + catalogs ModelCatalog, +) (AutoCompactThresholdResolution, error) { + autoCompact := cfg.Context.AutoCompact + if !autoCompact.Enabled { + return AutoCompactThresholdResolution{Source: AutoCompactThresholdSourceDisabled}, nil + } + + if autoCompact.InputTokenThreshold > 0 { + return AutoCompactThresholdResolution{ + Threshold: autoCompact.InputTokenThreshold, + Source: AutoCompactThresholdSourceExplicit, + ModelID: strings.TrimSpace(cfg.CurrentModel), + }, nil + } + + resolution := fallbackAutoCompactThresholdResolution(cfg) + providerCfg, err := selectedProviderConfig(cfg) + if err != nil { + return resolution, nil + } + if catalogs == nil { + return resolution, nil + } + + input, err := catalogInputFromProvider(providerCfg) + if err != nil { + return resolution, nil + } + + models, err 
:= catalogs.ListProviderModelsSnapshot(ctx, input) + if err != nil { + return resolution, nil + } + + modelID := provider.NormalizeKey(cfg.CurrentModel) + for _, model := range models { + if provider.NormalizeKey(model.ID) != modelID { + continue + } + resolution.ContextWindow = model.ContextWindow + if model.ContextWindow > autoCompact.ReserveTokens { + resolution.Threshold = model.ContextWindow - autoCompact.ReserveTokens + resolution.Source = AutoCompactThresholdSourceDerived + } + return resolution, nil + } + + return resolution, nil +} diff --git a/internal/config/state/auto_compact_threshold_test.go b/internal/config/state/auto_compact_threshold_test.go new file mode 100644 index 00000000..24bd9efc --- /dev/null +++ b/internal/config/state/auto_compact_threshold_test.go @@ -0,0 +1,166 @@ +package state + +import ( + "context" + "errors" + "testing" + + configpkg "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" +) + +func TestResolveAutoCompactThresholdDisabled(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = false + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 0 || resolution.Source != AutoCompactThresholdSourceDisabled { + t.Fatalf("expected disabled resolution, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdExplicitWins(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 42000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 42000 || resolution.Source != AutoCompactThresholdSourceExplicit { + t.Fatalf("expected explicit resolution, got %+v", 
resolution) + } +} + +func TestResolveAutoCompactThresholdDerivedFromContextWindow(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.CurrentModel = "deepseek-coder" + cfg.Providers[0].Model = "deepseek-coder" + cfg.Providers[0].Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + ContextWindow: 131072, + }} + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: cfg.Providers[0].Models, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 118072 || resolution.Source != AutoCompactThresholdSourceDerived { + t.Fatalf("expected derived threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenWindowTooSmall(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "small-model" + cfg.Providers[0].Model = "small-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ + ID: "small-model", + ContextWindow: 8000, + }}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected fallback threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenModelMissing(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + 
cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "missing-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ID: "other-model", ContextWindow: 131072}}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected missing model to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenSelectedProviderInvalid(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.SelectedProvider = "missing-provider" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid selection to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenCatalogInputResolutionFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.Providers[0].BaseURL = "" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid catalog input to use fallback, got %+v", resolution) + } +} + +func 
TestResolveAutoCompactThresholdFallsBackWhenSnapshotLookupFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotErr: errors.New("snapshot failed"), + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected snapshot error to use fallback, got %+v", resolution) + } +} diff --git a/internal/context/builder_test.go b/internal/context/builder_test.go index efe32e55..b17921e7 100644 --- a/internal/context/builder_test.go +++ b/internal/context/builder_test.go @@ -777,3 +777,45 @@ func TestProjectToolMessagesForModelKeepsBuilderProjectionBehavior(t *testing.T) t.Fatal("expected source messages to remain unchanged") } } + +func TestDefaultBuilderBuildProjectsMetadataOnlyToolResult(t *testing.T) { + t.Parallel() + + builder := NewBuilder() + result, err := builder.Build(stdcontext.Background(), BuildInput{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "inspect README"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: " ", + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + }, + Metadata: testMetadata(t.TempDir()), + }) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + if len(result.Messages) != 3 { + t.Fatalf("len(result.Messages) = %d, want 3", len(result.Messages)) + } + toolMessage := result.Messages[2] + if toolMessage.Role != providertypes.RoleTool { + 
t.Fatalf("expected tool message at index 2, got %+v", toolMessage) + } + if !strings.Contains(toolMessage.Content, "tool result") || + !strings.Contains(toolMessage.Content, "tool: filesystem_read_file") || + !strings.Contains(toolMessage.Content, "meta.path: README.md") { + t.Fatalf("expected metadata-only tool result to be projected, got %q", toolMessage.Content) + } + if toolMessage.ToolMetadata != nil { + t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) + } +} diff --git a/internal/context/projection.go b/internal/context/projection.go index 86c5a6dd..85e7ecd0 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -145,7 +145,10 @@ func isInjectableToolMessage(message providertypes.Message) bool { return false } content := strings.TrimSpace(message.Content) - return content != "" && content != microCompactClearedMessage + if content == microCompactClearedMessage { + return false + } + return content != "" || len(message.ToolMetadata) > 0 } // recentWindowMessageBudget 计算 recent window 可保留的消息总数硬上限,避免窗口体积失控。 diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 84116afa..5b8f428a 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -43,10 +43,10 @@ func TestProjectToolMessagesForModelSkipsMessagesThatCannotBeProjected(t *testin t.Fatalf("non-tool message should remain unchanged, got %+v", projected[0]) } if projected[1].Content != "tool output" || projected[1].ToolMetadata != nil { - t.Fatalf("tool without metadata-free projection should remain unchanged, got %+v", projected[1]) + t.Fatalf("tool without projection metadata should remain unchanged, got %+v", projected[1]) } - if projected[2].Content != " " || projected[2].ToolMetadata == nil { - t.Fatalf("empty tool content should not be projected, got %+v", projected[2]) + if !strings.Contains(projected[2].Content, "tool result") || projected[2].ToolMetadata != 
nil { + t.Fatalf("metadata-only tool message should be projected, got %+v", projected[2]) } if projected[3].Content != microCompactClearedMessage || projected[3].ToolMetadata == nil { t.Fatalf("cleared tool content should not be projected, got %+v", projected[3]) @@ -221,7 +221,7 @@ func TestMatchedToolCallSpanRejectsInvalidAssistantStates(t *testing.T) { } } -func TestMatchedToolCallSpanRequiresInjectableResponsesAndSkipsDuplicates(t *testing.T) { +func TestMatchedToolCallSpanRequiresProjectableResponsesAndSkipsDuplicates(t *testing.T) { t.Parallel() messages := []providertypes.Message{ @@ -244,11 +244,35 @@ func TestMatchedToolCallSpanRequiresInjectableResponsesAndSkipsDuplicates(t *tes if len(span) != 3 { t.Fatalf("len(span) = %d, want 3 (%+v)", len(span), span) } - if span[0] != 0 || span[1] != 2 || span[2] != 5 { + if span[0] != 0 || span[1] != 1 || span[2] != 5 { t.Fatalf("unexpected span indexes %+v", span) } } +func TestMatchedToolCallSpanAcceptsMetadataOnlyResponses(t *testing.T) { + t.Parallel() + + messages := []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "webfetch", Arguments: `{}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: " ", + ToolMetadata: map[string]string{"tool_name": "webfetch", "status_code": "200"}, + }, + } + + span := matchedToolCallSpan(messages, 0) + if len(span) != 2 || span[0] != 0 || span[1] != 1 { + t.Fatalf("unexpected metadata-only span %+v", span) + } +} + func TestMatchedToolCallSpanRejectsResponsesWithoutProjectionMetadata(t *testing.T) { t.Parallel() @@ -285,6 +309,11 @@ func TestIsInjectableToolMessage(t *testing.T) { message: providertypes.Message{Role: providertypes.RoleTool, Content: " "}, want: false, }, + { + name: "metadata-only", + message: providertypes.Message{Role: providertypes.RoleTool, Content: " ", ToolMetadata: map[string]string{"tool_name": "bash"}}, + want: true, + }, { name: "cleared", 
message: providertypes.Message{Role: providertypes.RoleTool, Content: microCompactClearedMessage}, diff --git a/internal/gateway/adapters/urlscheme/dispatcher.go b/internal/gateway/adapters/urlscheme/dispatcher.go index 2abf68a3..32fa9d2f 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher.go +++ b/internal/gateway/adapters/urlscheme/dispatcher.go @@ -1,11 +1,13 @@ package urlscheme import ( + "bytes" "context" "encoding/json" "errors" "fmt" "net" + "strings" "sync/atomic" "time" @@ -110,28 +112,69 @@ func (d *Dispatcher) Dispatch(ctx context.Context, request DispatchRequest) (Dis Payload: intent, } + requestIDRaw, err := marshalJSONRawMessage(requestFrame.RequestID) + if err != nil { + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("encode request id: %v", err)) + } + requestParamsRaw, err := marshalJSONRawMessage(intent) + if err != nil { + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("encode request params: %v", err)) + } + rpcRequest := protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: requestIDRaw, + Method: protocol.MethodWakeOpenURL, + Params: requestParamsRaw, + } + if err := ensureDispatchContextActive(ctx); err != nil { return DispatchResult{}, toDispatchError(err) } encoder := json.NewEncoder(conn) - if err := encoder.Encode(requestFrame); err != nil { + if err := encoder.Encode(rpcRequest); err != nil { if ctx != nil && ctx.Err() != nil { ctxErr := ctx.Err() return DispatchResult{}, toDispatchError(ctxErr) } - return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("write request frame: %v", err)) + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("write request rpc: %v", err)) } - var responseFrame gateway.MessageFrame + var rpcResponse protocol.JSONRPCResponse if err := ensureDispatchContextActive(ctx); err != nil { return DispatchResult{}, toDispatchError(err) } decoder := json.NewDecoder(conn) - if err := 
decoder.Decode(&responseFrame); err != nil { + if err := decoder.Decode(&rpcResponse); err != nil { if ctx != nil && ctx.Err() != nil { ctxErr := ctx.Err() return DispatchResult{}, toDispatchError(ctxErr) } + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode response rpc: %v", err)) + } + if strings.TrimSpace(rpcResponse.JSONRPC) != protocol.JSONRPCVersion { + return DispatchResult{}, newDispatchError( + ErrorCodeUnexpectedResponse, + "unexpected response jsonrpc version", + ) + } + if !rawJSONMessageEqual(rpcResponse.ID, rpcRequest.ID) { + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "rpc correlation failed: id mismatch") + } + if rpcResponse.Error != nil && rpcResponse.Result != nil { + return DispatchResult{}, newDispatchError( + ErrorCodeUnexpectedResponse, + "unexpected response payload: both result and error are present", + ) + } + if rpcResponse.Error != nil { + return DispatchResult{}, toDispatchErrorFromJSONRPC(rpcResponse.Error) + } + if rpcResponse.Result == nil { + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "gateway response missing result payload") + } + + responseFrame, err := decodeResponseFrameResult(rpcResponse.Result) + if err != nil { return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode response frame: %v", err)) } if responseFrame.Action != requestFrame.Action || responseFrame.RequestID != requestFrame.RequestID { @@ -201,6 +244,60 @@ func watchDispatchCancellation(ctx context.Context, conn net.Conn) func() { } } +// toDispatchErrorFromJSONRPC 将 JSON-RPC 错误对象映射为 url-dispatch 稳定错误。 +func toDispatchErrorFromJSONRPC(rpcError *protocol.JSONRPCError) error { + if rpcError == nil { + return newDispatchError(ErrorCodeUnexpectedResponse, "gateway returned empty rpc error payload") + } + + code := strings.TrimSpace(protocol.GatewayCodeFromJSONRPCError(rpcError)) + if code == "" { + code = 
mapJSONRPCCodeToDispatchCode(rpcError.Code) + } + message := strings.TrimSpace(rpcError.Message) + if message == "" { + message = "gateway returned empty rpc error message" + } + return newDispatchError(code, message) +} + +// mapJSONRPCCodeToDispatchCode 为缺少 gateway_code 的响应提供兜底错误码映射。 +func mapJSONRPCCodeToDispatchCode(code int) string { + switch code { + case protocol.JSONRPCCodeMethodNotFound: + return gateway.ErrorCodeUnsupportedAction.String() + case protocol.JSONRPCCodeInvalidRequest, protocol.JSONRPCCodeInvalidParams, protocol.JSONRPCCodeParseError: + return gateway.ErrorCodeInvalidFrame.String() + case protocol.JSONRPCCodeInternalError: + return gateway.ErrorCodeInternalError.String() + default: + return ErrorCodeInternal + } +} + +// decodeResponseFrameResult 将 JSON-RPC result 安全解码回 MessageFrame。 +func decodeResponseFrameResult(result json.RawMessage) (gateway.MessageFrame, error) { + var frame gateway.MessageFrame + if err := json.Unmarshal(result, &frame); err != nil { + return gateway.MessageFrame{}, err + } + return frame, nil +} + +// rawJSONMessageEqual 比较两段 JSON 原文在去除首尾空白后的字节是否一致。 +func rawJSONMessageEqual(left, right json.RawMessage) bool { + return bytes.Equal(bytes.TrimSpace(left), bytes.TrimSpace(right)) +} + +// marshalJSONRawMessage 将任意对象编码为 json.RawMessage,便于构造 JSON-RPC 请求字段。 +func marshalJSONRawMessage(payload any) (json.RawMessage, error) { + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + return json.RawMessage(raw), nil +} + // toDispatchError 将不同来源错误转换为统一结构化错误。 func toDispatchError(err error) error { if err == nil { diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index e371a4e2..75f1427c 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -12,6 +12,7 @@ import ( "time" "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" 
"neo-code/internal/gateway/transport" ) @@ -40,25 +41,50 @@ func TestDispatcherDispatchSuccess(t *testing.T) { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - if err := decoder.Decode(&requestFrame); err != nil { - t.Errorf("decode request frame: %v", err) + var rpcRequest protocol.JSONRPCRequest + if err := decoder.Decode(&rpcRequest); err != nil { + t.Errorf("decode request rpc: %v", err) return } - if requestFrame.Action != gateway.FrameActionWakeOpenURL { - t.Errorf("request action = %q, want %q", requestFrame.Action, gateway.FrameActionWakeOpenURL) + if rpcRequest.Method != protocol.MethodWakeOpenURL { + t.Errorf("request method = %q, want %q", rpcRequest.Method, protocol.MethodWakeOpenURL) + return + } + if rpcRequest.JSONRPC != protocol.JSONRPCVersion { + t.Errorf("request jsonrpc = %q, want %q", rpcRequest.JSONRPC, protocol.JSONRPCVersion) + return + } + if len(bytes.TrimSpace(rpcRequest.ID)) == 0 { + t.Error("request id should not be empty") + return + } + var params protocol.WakeIntent + if err := json.Unmarshal(rpcRequest.Params, ¶ms); err != nil { + t.Errorf("decode request params: %v", err) + return + } + if params.Action != protocol.WakeActionReview { + t.Errorf("request params action = %q, want %q", params.Action, protocol.WakeActionReview) + return + } + if got := params.Params["path"]; got != "README.md" { + t.Errorf("request params[path] = %q, want %q", got, "README.md") return } - if err := encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeAck, - Action: gateway.FrameActionWakeOpenURL, - RequestID: requestFrame.RequestID, - Payload: map[string]any{ - "message": "wake intent accepted", - }, + if err := encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-1", + Payload: 
map[string]any{ + "message": "wake intent accepted", + }, + }), }); err != nil { - t.Errorf("encode response frame: %v", err) + t.Errorf("encode response rpc: %v", err) } }() @@ -94,16 +120,16 @@ func TestDispatcherDispatchReturnsGatewayError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeError, - Action: requestFrame.Action, - RequestID: requestFrame.RequestID, - Error: &gateway.FrameError{ - Code: gateway.ErrorCodeInvalidAction.String(), - Message: "unsupported wake action", - }, + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Error: protocol.NewJSONRPCError( + protocol.JSONRPCCodeInvalidParams, + "unsupported wake action", + gateway.ErrorCodeInvalidAction.String(), + ), }) }() @@ -139,12 +165,16 @@ func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeEvent, - Action: requestFrame.Action, - RequestID: requestFrame.RequestID, + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-3", + }), }) }() @@ -179,12 +209,16 @@ func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame 
- _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeAck, - Action: requestFrame.Action, - RequestID: "wake-mismatch", + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-mismatch", + }), }) }() @@ -302,9 +336,9 @@ func TestDispatcherDispatchInterruptsBlockedReadOnContextCancel(t *testing.T) { defer close(serverDone) decoder := json.NewDecoder(serverConn) - var frame gateway.MessageFrame - if err := decoder.Decode(&frame); err != nil { - t.Errorf("decode request frame: %v", err) + var rpcRequest protocol.JSONRPCRequest + if err := decoder.Decode(&rpcRequest); err != nil { + t.Errorf("decode request rpc: %v", err) return } close(requestArrived) @@ -573,12 +607,12 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { } }) - t.Run("gateway error frame missing payload", func(t *testing.T) { + t.Run("gateway response missing result payload", func(t *testing.T) { dispatcher := &Dispatcher{ resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, dialFn: func(string) (net.Conn, error) { return &stubDispatchConn{ - readBuffer: bytes.NewBufferString(`{"type":"error","action":"wake.openUrl","request_id":"wake-14"}` + "\n"), + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-14"}` + "\n"), }, nil }, requestIDFn: func() string { return "wake-14" }, @@ -588,7 +622,158 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { RawURL: "neocode://review?path=README.md", }) if err == nil { - t.Fatal("expected missing error payload branch") + t.Fatal("expected missing result payload branch") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, 
want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response rpc version mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"1.0","id":"wake-15","result":{}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-15" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc version mismatch error") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response rpc id mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-other","result":{}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-16" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc id mismatch error") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response contains both result and error", func(t 
*testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString( + `{"jsonrpc":"2.0","id":"wake-17","result":{},"error":{"code":-32603,"message":"boom"}}` + "\n", + ), + }, nil + }, + requestIDFn: func() string { return "wake-17" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected both result and error payload failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("rpc error without gateway_code uses fallback code map", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString( + `{"jsonrpc":"2.0","id":"wake-18","error":{"code":-32601,"message":"method not found"}}` + "\n", + ), + }, nil + }, + requestIDFn: func() string { return "wake-18" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc error mapping failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != gateway.ErrorCodeUnsupportedAction.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeUnsupportedAction.String()) + } + }) + + t.Run("rpc error with empty message uses fallback text", func(t *testing.T) { + dispatcher := &Dispatcher{ + 
resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-19","error":{"code":-32603,"message":""}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-19" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc error mapping failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != gateway.ErrorCodeInternalError.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeInternalError.String()) + } + if !strings.Contains(dispatchErr.Message, "empty rpc error message") { + t.Fatalf("error message = %q, want fallback text", dispatchErr.Message) + } + }) + + t.Run("decode response frame failed", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-20","result":"not-frame"}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-20" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected decode frame failure") } var dispatchErr *DispatchError if !errors.As(err, &dispatchErr) { @@ -600,6 +785,49 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { }) } +func TestDispatcherJSONRPCHelpers(t *testing.T) { + marshalErr := toDispatchErrorFromJSONRPC(&protocol.JSONRPCError{ + Code: protocol.JSONRPCCodeInternalError, + Message: "boom", + }) + var dispatchErr *DispatchError + if !errors.As(marshalErr, 
&dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", marshalErr) + } + if dispatchErr.Code != gateway.ErrorCodeInternalError.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeInternalError.String()) + } + + emptyErr := toDispatchErrorFromJSONRPC(nil) + if !errors.As(emptyErr, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", emptyErr) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + + if mapJSONRPCCodeToDispatchCode(protocol.JSONRPCCodeMethodNotFound) != gateway.ErrorCodeUnsupportedAction.String() { + t.Fatal("method_not_found should map to unsupported_action") + } + if mapJSONRPCCodeToDispatchCode(protocol.JSONRPCCodeInvalidParams) != gateway.ErrorCodeInvalidFrame.String() { + t.Fatal("invalid_params should map to invalid_frame") + } + if mapJSONRPCCodeToDispatchCode(123456) != ErrorCodeInternal { + t.Fatal("unknown rpc code should map to internal_error") + } + + if _, err := decodeResponseFrameResult(json.RawMessage(`"not-frame"`)); err == nil { + t.Fatal("expected decodeResponseFrameResult unmarshal failure") + } + if _, err := decodeResponseFrameResult(json.RawMessage(`{"type":"ack","action":"wake.openUrl","request_id":"x"`)); err == nil { + t.Fatal("expected decodeResponseFrameResult malformed json failure") + } + + if _, err := marshalJSONRawMessage(make(chan int)); err == nil { + t.Fatal("expected marshalJSONRawMessage failure") + } +} + type stubDispatchConn struct { readBuffer *bytes.Buffer writeErr error @@ -661,3 +889,13 @@ func (a stubDispatchAddr) Network() string { func (a stubDispatchAddr) String() string { return string(a) } + +func mustMarshalRawJSON(t *testing.T, payload any) json.RawMessage { + t.Helper() + + raw, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal raw json: %v", err) + } + return json.RawMessage(raw) +} diff --git a/internal/gateway/handlers/wake.go 
b/internal/gateway/handlers/wake.go index a8e4df55..2d72ce2d 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -88,7 +88,16 @@ func isSafeReviewPath(path string) bool { if trimmed == "" { return false } - if filepath.IsAbs(trimmed) { + if hasWindowsDriveLetterPrefix(trimmed) { + return false + } + if filepath.VolumeName(trimmed) != "" { + return false + } + if hasBlockedWindowsPathPrefix(trimmed) { + return false + } + if isAbsoluteReviewPath(trimmed) { return false } if containsParentTraversalSegment(trimmed) { @@ -104,6 +113,31 @@ func isSafeReviewPath(path string) bool { return true } +// hasWindowsDriveLetterPrefix 检查是否为 Windows 盘符前缀路径,如 C:foo。 +func hasWindowsDriveLetterPrefix(path string) bool { + trimmed := strings.TrimSpace(path) + if len(trimmed) < 2 { + return false + } + drive := trimmed[0] + return ((drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')) && trimmed[1] == ':' +} + +// hasBlockedWindowsPathPrefix 检查是否命中 Windows 设备路径前缀,避免绕过常规路径校验。 +func hasBlockedWindowsPathPrefix(path string) bool { + normalized := strings.ReplaceAll(strings.TrimSpace(path), "/", "\\") + return strings.HasPrefix(normalized, `\\?\`) || strings.HasPrefix(normalized, `\\.\`) +} + +// isAbsoluteReviewPath 统一识别跨平台绝对路径,避免在 Windows 下漏判 Unix 风格前导斜杠。 +func isAbsoluteReviewPath(path string) bool { + if filepath.IsAbs(path) { + return true + } + normalized := normalizePath(path) + return strings.HasPrefix(normalized, "/") || strings.HasPrefix(normalized, "\\") +} + // containsParentTraversalSegment 按路径段语义识别目录回退段,避免子串匹配导致误伤。 func containsParentTraversalSegment(path string) bool { normalized := normalizePath(path) diff --git a/internal/gateway/handlers/wake_test.go b/internal/gateway/handlers/wake_test.go index f5184a26..3941bf6b 100644 --- a/internal/gateway/handlers/wake_test.go +++ b/internal/gateway/handlers/wake_test.go @@ -59,6 +59,9 @@ func TestWakeOpenURLHandlerHandleUnsafePath(t *testing.T) { "../../etc/passwd", "/etc/passwd", 
"..\\Windows\\system32", + "C:foo", + `\\?\C:\Windows\System32`, + `\\.\pipe\neocode`, } handler := NewWakeOpenURLHandler() @@ -90,6 +93,9 @@ func TestIsSafeReviewPath(t *testing.T) { {name: "parent traversal", path: "../secret.txt", want: false}, {name: "parent traversal nested", path: "a/../../secret.txt", want: false}, {name: "absolute unix path", path: "/tmp/file", want: false}, + {name: "windows drive relative path", path: "C:foo", want: false}, + {name: "windows device path namespace", path: `\\?\C:\tmp\file`, want: false}, + {name: "windows device pipe namespace", path: `\\.\pipe\name`, want: false}, {name: "empty", path: "", want: false}, {name: "dot current dir", path: ".", want: false}, } diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go new file mode 100644 index 00000000..eb6ad305 --- /dev/null +++ b/internal/gateway/protocol/jsonrpc.go @@ -0,0 +1,291 @@ +package protocol + +import ( + "bytes" + "encoding/json" + "strings" +) + +const ( + // JSONRPCVersion 表示当前网关控制面固定使用的 JSON-RPC 协议版本。 + JSONRPCVersion = "2.0" +) + +const ( + // MethodGatewayPing 表示网关探活方法。 + MethodGatewayPing = "gateway.ping" + // MethodWakeOpenURL 表示 URL Scheme 唤醒方法。 + MethodWakeOpenURL = "wake.openUrl" +) + +const ( + // JSONRPCCodeParseError 表示请求体不是合法 JSON。 + JSONRPCCodeParseError = -32700 + // JSONRPCCodeInvalidRequest 表示请求结构不符合 JSON-RPC 规范。 + JSONRPCCodeInvalidRequest = -32600 + // JSONRPCCodeMethodNotFound 表示方法未注册。 + JSONRPCCodeMethodNotFound = -32601 + // JSONRPCCodeInvalidParams 表示参数不合法。 + JSONRPCCodeInvalidParams = -32602 + // JSONRPCCodeInternalError 表示服务端内部错误。 + JSONRPCCodeInternalError = -32603 +) + +const ( + // GatewayCodeInvalidFrame 表示请求帧结构非法。 + GatewayCodeInvalidFrame = "invalid_frame" + // GatewayCodeInvalidAction 表示动作参数非法。 + GatewayCodeInvalidAction = "invalid_action" + // GatewayCodeInvalidMultimodalPayload 表示多模态负载非法。 + GatewayCodeInvalidMultimodalPayload = "invalid_multimodal_payload" + // GatewayCodeMissingRequiredField 
表示缺少必填字段。 + GatewayCodeMissingRequiredField = "missing_required_field" + // GatewayCodeUnsupportedAction 表示动作尚未实现。 + GatewayCodeUnsupportedAction = "unsupported_action" + // GatewayCodeInternalError 表示网关内部错误。 + GatewayCodeInternalError = "internal_error" + // GatewayCodeUnsafePath 表示路径存在安全风险。 + GatewayCodeUnsafePath = "unsafe_path" +) + +// JSONRPCRequest 表示控制面接收到的 JSON-RPC 请求。 +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +// JSONRPCResponse 表示控制面输出的 JSON-RPC 响应。 +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` +} + +// JSONRPCError 表示 JSON-RPC 错误负载。 +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data *JSONRPCErrorData `json:"data,omitempty"` +} + +// JSONRPCErrorData 表示网关扩展错误字段。 +type JSONRPCErrorData struct { + GatewayCode string `json:"gateway_code,omitempty"` +} + +// NormalizedRequest 表示从 JSON-RPC 归一化后的内部请求模型。 +type NormalizedRequest struct { + ID json.RawMessage + RequestID string + Action string + SessionID string + Workdir string + Payload any +} + +// NormalizeJSONRPCRequest 将 JSON-RPC 请求归一化为内部请求模型,并做方法级参数解析。 +func NormalizeJSONRPCRequest(request JSONRPCRequest) (NormalizedRequest, *JSONRPCError) { + normalized := NormalizedRequest{} + + requestID, idErr := normalizeJSONRPCID(request.ID) + normalized.RequestID = requestID + if idErr != nil { + return normalized, idErr + } + normalized.ID = cloneJSONRawMessage(request.ID) + + if strings.TrimSpace(request.JSONRPC) != JSONRPCVersion { + return normalized, NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid jsonrpc version", + GatewayCodeInvalidFrame, + ) + } + + method := strings.TrimSpace(request.Method) + if method == "" { + return normalized, NewJSONRPCError( + 
JSONRPCCodeInvalidRequest, + "missing required field: method", + GatewayCodeMissingRequiredField, + ) + } + + switch method { + case MethodGatewayPing: + normalized.Action = "ping" + return normalized, nil + case MethodWakeOpenURL: + intent, parseErr := decodeWakeIntentParams(request.Params) + if parseErr != nil { + return normalized, parseErr + } + normalized.Action = MethodWakeOpenURL + normalized.SessionID = strings.TrimSpace(intent.SessionID) + normalized.Workdir = strings.TrimSpace(intent.Workdir) + normalized.Payload = intent + return normalized, nil + default: + return normalized, NewJSONRPCError( + JSONRPCCodeMethodNotFound, + "method not found", + GatewayCodeUnsupportedAction, + ) + } +} + +// NewJSONRPCResultResponse 创建 JSON-RPC 成功响应,并将 result 编码为 RawMessage。 +func NewJSONRPCResultResponse(id json.RawMessage, result any) (JSONRPCResponse, *JSONRPCError) { + rawResult, err := json.Marshal(result) + if err != nil { + return JSONRPCResponse{}, NewJSONRPCError( + JSONRPCCodeInternalError, + "failed to encode jsonrpc result", + GatewayCodeInternalError, + ) + } + + return JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: cloneJSONRawMessage(id), + Result: json.RawMessage(rawResult), + }, nil +} + +// NewJSONRPCErrorResponse 创建 JSON-RPC 错误响应。 +func NewJSONRPCErrorResponse(id json.RawMessage, rpcError *JSONRPCError) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: cloneJSONRawMessage(id), + Error: rpcError, + } +} + +// NewJSONRPCError 创建带 gateway_code 的 JSON-RPC 错误对象。 +func NewJSONRPCError(code int, message, gatewayCode string) *JSONRPCError { + errorPayload := &JSONRPCError{ + Code: code, + Message: message, + } + if strings.TrimSpace(gatewayCode) != "" { + errorPayload.Data = &JSONRPCErrorData{GatewayCode: gatewayCode} + } + return errorPayload +} + +// GatewayCodeFromJSONRPCError 从 JSON-RPC 错误负载中提取稳定 gateway_code。 +func GatewayCodeFromJSONRPCError(rpcError *JSONRPCError) string { + if rpcError == nil || rpcError.Data == nil { + 
return "" + } + return strings.TrimSpace(rpcError.Data.GatewayCode) +} + +// MapGatewayCodeToJSONRPCCode 将稳定网关错误码映射到 JSON-RPC 错误码。 +func MapGatewayCodeToJSONRPCCode(gatewayCode string) int { + switch strings.TrimSpace(gatewayCode) { + case GatewayCodeUnsupportedAction: + return JSONRPCCodeMethodNotFound + case GatewayCodeInvalidAction, + GatewayCodeInvalidFrame, + GatewayCodeInvalidMultimodalPayload, + GatewayCodeMissingRequiredField, + GatewayCodeUnsafePath: + return JSONRPCCodeInvalidParams + case GatewayCodeInternalError: + return JSONRPCCodeInternalError + default: + return JSONRPCCodeInternalError + } +} + +// normalizeJSONRPCID 校验并提取请求 ID,确保控制面请求具备可关联标识。 +func normalizeJSONRPCID(id json.RawMessage) (string, *JSONRPCError) { + trimmed := bytes.TrimSpace(id) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "missing required field: id", + GatewayCodeMissingRequiredField, + ) + } + + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + + switch value := decoded.(type) { + case string: + identifier := strings.TrimSpace(value) + if identifier == "" { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + return identifier, nil + case float64: + identifier := strings.TrimSpace(string(trimmed)) + if identifier == "" { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + return identifier, nil + default: + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } +} + +// decodeWakeIntentParams 对 wake.openUrl 的 params 执行延迟反序列化与最小校验。 +func decodeWakeIntentParams(raw json.RawMessage) (WakeIntent, *JSONRPCError) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || 
bytes.Equal(trimmed, []byte("null")) { + return WakeIntent{}, NewJSONRPCError( + JSONRPCCodeInvalidParams, + "missing required field: params", + GatewayCodeMissingRequiredField, + ) + } + + var intent WakeIntent + if err := json.Unmarshal(trimmed, &intent); err != nil { + return WakeIntent{}, NewJSONRPCError( + JSONRPCCodeInvalidParams, + "invalid params for wake.openUrl", + GatewayCodeInvalidFrame, + ) + } + intent.Action = strings.ToLower(strings.TrimSpace(intent.Action)) + intent.SessionID = strings.TrimSpace(intent.SessionID) + intent.Workdir = strings.TrimSpace(intent.Workdir) + if len(intent.Params) == 0 { + intent.Params = nil + } + return intent, nil +} + +// cloneJSONRawMessage 复制 RawMessage,避免共享底层切片导致的并发风险。 +func cloneJSONRawMessage(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 { + return nil + } + cloned := make([]byte, len(raw)) + copy(cloned, raw) + return json.RawMessage(cloned) +} diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go new file mode 100644 index 00000000..e0cccc1f --- /dev/null +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -0,0 +1,276 @@ +package protocol + +import ( + "encoding/json" + "testing" +) + +func TestNormalizeJSONRPCRequestPing(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"ping-1"`), + Method: MethodGatewayPing, + Params: json.RawMessage(`{}`), + }) + if rpcErr != nil { + t.Fatalf("normalize ping request: %v", rpcErr) + } + if normalized.RequestID != "ping-1" { + t.Fatalf("request_id = %q, want %q", normalized.RequestID, "ping-1") + } + if normalized.Action != "ping" { + t.Fatalf("action = %q, want %q", normalized.Action, "ping") + } +} + +func TestNormalizeJSONRPCRequestPingWithNumericID(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`123`), + Method: MethodGatewayPing, + Params: 
json.RawMessage(`{}`), + }) + if rpcErr != nil { + t.Fatalf("normalize ping request with numeric id: %v", rpcErr) + } + if normalized.RequestID != "123" { + t.Fatalf("request_id = %q, want %q", normalized.RequestID, "123") + } +} + +func TestNormalizeJSONRPCRequestWakeOpenURL(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"wake-1"`), + Method: MethodWakeOpenURL, + Params: json.RawMessage(`{ + "action":"review", + "session_id":"session-1", + "workdir":"/tmp/repo", + "params":{"path":"README.md"} + }`), + }) + if rpcErr != nil { + t.Fatalf("normalize wake request: %v", rpcErr) + } + if normalized.Action != MethodWakeOpenURL { + t.Fatalf("action = %q, want %q", normalized.Action, MethodWakeOpenURL) + } + if normalized.SessionID != "session-1" { + t.Fatalf("session_id = %q, want %q", normalized.SessionID, "session-1") + } + if normalized.Workdir != "/tmp/repo" { + t.Fatalf("workdir = %q, want %q", normalized.Workdir, "/tmp/repo") + } + intent, ok := normalized.Payload.(WakeIntent) + if !ok { + t.Fatalf("payload type = %T, want WakeIntent", normalized.Payload) + } + if intent.Params["path"] != "README.md" { + t.Fatalf("intent.params[path] = %q, want %q", intent.Params["path"], "README.md") + } +} + +func TestNormalizeJSONRPCRequestErrors(t *testing.T) { + testCases := []struct { + name string + request JSONRPCRequest + wantCode int + wantGatewayCode string + }{ + { + name: "missing id", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "invalid version", + request: JSONRPCRequest{ + JSONRPC: "1.0", + ID: json.RawMessage(`"x"`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "invalid id object", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: 
json.RawMessage(`{}`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "invalid id array", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`[]`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "invalid id boolean", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`true`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "missing method", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "method not found", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: "gateway.unknown", + }, + wantCode: JSONRPCCodeMethodNotFound, + wantGatewayCode: GatewayCodeUnsupportedAction, + }, + { + name: "wake missing params", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: MethodWakeOpenURL, + }, + wantCode: JSONRPCCodeInvalidParams, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "wake invalid params", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: MethodWakeOpenURL, + Params: json.RawMessage(`{invalid}`), + }, + wantCode: JSONRPCCodeInvalidParams, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + _, rpcErr := NormalizeJSONRPCRequest(tc.request) + if rpcErr == nil { + t.Fatal("expected rpc error") + } + if rpcErr.Code != tc.wantCode { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, tc.wantCode) + } + if gatewayCode := GatewayCodeFromJSONRPCError(rpcErr); gatewayCode != tc.wantGatewayCode { 
+ t.Fatalf("gateway_code = %q, want %q", gatewayCode, tc.wantGatewayCode) + } + }) + } +} + +func TestNormalizeJSONRPCRequestInvalidIDReturnsNullResponseID(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`{}`), + Method: MethodGatewayPing, + }) + if rpcErr == nil { + t.Fatal("expected rpc error") + } + if rpcErr.Code != JSONRPCCodeInvalidRequest { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, JSONRPCCodeInvalidRequest) + } + if normalized.ID != nil { + t.Fatalf("normalized id = %s, want nil", string(normalized.ID)) + } +} + +func TestJSONRPCHelpers(t *testing.T) { + response, rpcErr := NewJSONRPCResultResponse(json.RawMessage(`"req-1"`), map[string]string{"message": "ok"}) + if rpcErr != nil { + t.Fatalf("new jsonrpc result response: %v", rpcErr) + } + if response.JSONRPC != JSONRPCVersion { + t.Fatalf("jsonrpc = %q, want %q", response.JSONRPC, JSONRPCVersion) + } + if string(response.ID) != `"req-1"` { + t.Fatalf("id = %s, want %s", response.ID, `"req-1"`) + } + var result map[string]string + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("decode result raw message: %v", err) + } + if result["message"] != "ok" { + t.Fatalf(`result["message"] = %q, want %q`, result["message"], "ok") + } + + _, rpcErr = NewJSONRPCResultResponse(json.RawMessage(`"req-chan"`), map[string]any{"bad": make(chan int)}) + if rpcErr == nil { + t.Fatal("expected result encode error") + } + if rpcErr.Code != JSONRPCCodeInternalError { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, JSONRPCCodeInternalError) + } + if gatewayCode := GatewayCodeFromJSONRPCError(rpcErr); gatewayCode != GatewayCodeInternalError { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, GatewayCodeInternalError) + } + + rpcErr = NewJSONRPCError(JSONRPCCodeInternalError, "boom", GatewayCodeInternalError) + errorResponse := NewJSONRPCErrorResponse(json.RawMessage(`"req-2"`), rpcErr) + if 
errorResponse.Error == nil { + t.Fatal("error response should include rpc error payload") + } + if GatewayCodeFromJSONRPCError(errorResponse.Error) != GatewayCodeInternalError { + t.Fatalf("gateway_code = %q, want %q", GatewayCodeFromJSONRPCError(errorResponse.Error), GatewayCodeInternalError) + } + if GatewayCodeFromJSONRPCError(nil) != "" { + t.Fatal("gateway_code for nil rpc error should be empty") + } + + if MapGatewayCodeToJSONRPCCode(GatewayCodeUnsupportedAction) != JSONRPCCodeMethodNotFound { + t.Fatal("unsupported_action should map to method_not_found") + } + if MapGatewayCodeToJSONRPCCode(GatewayCodeInvalidAction) != JSONRPCCodeInvalidParams { + t.Fatal("invalid_action should map to invalid_params") + } + if MapGatewayCodeToJSONRPCCode("unknown") != JSONRPCCodeInternalError { + t.Fatal("unknown code should map to internal_error") + } +} + +func TestNewJSONRPCErrorResponseWithNilIDEncodesNull(t *testing.T) { + response := NewJSONRPCErrorResponse(nil, NewJSONRPCError(JSONRPCCodeParseError, "parse error", GatewayCodeInvalidFrame)) + encoded, err := json.Marshal(response) + if err != nil { + t.Fatalf("marshal error response: %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(encoded, &payload); err != nil { + t.Fatalf("unmarshal encoded response: %v", err) + } + if _, ok := payload["id"]; !ok { + t.Fatal("encoded response should contain id field") + } + if payload["id"] != nil { + t.Fatalf("encoded response id = %#v, want nil", payload["id"]) + } +} diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 718dc430..001e34fd 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "neo-code/internal/gateway/protocol" "neo-code/internal/gateway/transport" ) @@ -261,7 +262,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor return } - frame, err := decodeFrame(reader) + rpcRequest, err := decodeRPCRequest(reader) if err != nil { if 
errors.Is(err, io.EOF) { return @@ -275,20 +276,31 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor } if errors.Is(err, errFrameTooLarge) { s.logger.Printf("decode frame failed: %v", err) - _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError( - ErrorCodeInvalidFrame, - fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), - ))) + _ = s.writeRPCResponse(conn, encoder, protocol.NewJSONRPCErrorResponse( + nil, + protocol.NewJSONRPCError( + protocol.JSONRPCCodeInvalidRequest, + fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), + protocol.GatewayCodeInvalidFrame, + ), + )) return } s.logger.Printf("decode frame failed: %v", err) - _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) + _ = s.writeRPCResponse(conn, encoder, protocol.NewJSONRPCErrorResponse( + nil, + protocol.NewJSONRPCError( + protocol.JSONRPCCodeParseError, + "invalid json-rpc request", + protocol.GatewayCodeInvalidFrame, + ), + )) return } - response := s.dispatchFrame(ctx, frame, runtimePort) - if !s.writeFrame(conn, encoder, response) { + rpcResponse := s.dispatchRPCRequest(ctx, rpcRequest, runtimePort) + if !s.writeRPCResponse(conn, encoder, rpcResponse) { return } } @@ -310,13 +322,13 @@ func (s *Server) applyWriteDeadline(conn net.Conn) error { return conn.SetWriteDeadline(time.Now().Add(s.writeTimeout)) } -// writeFrame 统一处理响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 -func (s *Server) writeFrame(conn net.Conn, encoder *json.Encoder, frame MessageFrame) bool { +// writeRPCResponse 统一处理 JSON-RPC 响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 +func (s *Server) writeRPCResponse(conn net.Conn, encoder *json.Encoder, response protocol.JSONRPCResponse) bool { if err := s.applyWriteDeadline(conn); err != nil { s.logger.Printf("set write deadline failed: %v", err) return false } - if err := encoder.Encode(frame); err != nil { + if err := encoder.Encode(response); err != nil { 
s.logger.Printf("write frame failed: %v", err) return false } @@ -329,27 +341,70 @@ func isTimeoutError(err error) bool { return errors.As(err, &netErr) && netErr.Timeout() } -// decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。 -func decodeFrame(reader *bufio.Reader) (MessageFrame, error) { +// decodeRPCRequest 从连接读取一条 JSON-RPC 请求并执行长度与格式校验。 +func decodeRPCRequest(reader *bufio.Reader) (protocol.JSONRPCRequest, error) { payload, err := readFramePayload(reader, MaxFrameSize) if err != nil { - return MessageFrame{}, err + return protocol.JSONRPCRequest{}, err } limitedReader := &io.LimitedReader{R: bytes.NewReader(payload), N: MaxFrameSize} decoder := json.NewDecoder(limitedReader) - var frame MessageFrame - if err := decoder.Decode(&frame); err != nil { - return MessageFrame{}, err + var request protocol.JSONRPCRequest + if err := decoder.Decode(&request); err != nil { + return protocol.JSONRPCRequest{}, err } var trailing any if err := decoder.Decode(&trailing); !errors.Is(err, io.EOF) { - return MessageFrame{}, fmt.Errorf("frame contains trailing json values") + return protocol.JSONRPCRequest{}, fmt.Errorf("frame contains trailing json values") } - return frame, nil + return request, nil +} + +// dispatchRPCRequest 将 JSON-RPC 请求归一化为 MessageFrame,并复用既有分发逻辑处理。 +func (s *Server) dispatchRPCRequest( + ctx context.Context, + request protocol.JSONRPCRequest, + runtimePort RuntimePort, +) protocol.JSONRPCResponse { + normalized, rpcErr := protocol.NormalizeJSONRPCRequest(request) + if rpcErr != nil { + return protocol.NewJSONRPCErrorResponse(normalized.ID, rpcErr) + } + + frame := MessageFrame{ + Type: FrameTypeRequest, + Action: FrameAction(normalized.Action), + RequestID: normalized.RequestID, + SessionID: normalized.SessionID, + Workdir: normalized.Workdir, + Payload: normalized.Payload, + } + + responseFrame := s.dispatchFrame(ctx, frame, runtimePort) + if responseFrame.Type != FrameTypeError { + rpcResponse, encodeErr := protocol.NewJSONRPCResultResponse(normalized.ID, 
responseFrame) + if encodeErr != nil { + return protocol.NewJSONRPCErrorResponse(normalized.ID, encodeErr) + } + return rpcResponse + } + + frameErr := responseFrame.Error + if frameErr == nil { + frameErr = NewFrameError(ErrorCodeInternalError, "gateway response missing error payload") + } + return protocol.NewJSONRPCErrorResponse( + normalized.ID, + protocol.NewJSONRPCError( + protocol.MapGatewayCodeToJSONRPCCode(frameErr.Code), + frameErr.Message, + frameErr.Code, + ), + ) } // readFramePayload 按换行边界读取单条帧,并限制单帧最大字节数。 diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index d5eea3d5..52466de9 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -12,6 +12,8 @@ import ( "sync" "testing" "time" + + "neo-code/internal/gateway/protocol" ) func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { @@ -236,9 +238,9 @@ func TestCloseReturnsContextErrorWhenWaitCanceled(t *testing.T) { server.wg.Done() } -func TestDecodeFrameTrailingJSON(t *testing.T) { - reader := bufio.NewReader(strings.NewReader(`{"type":"request","action":"ping"} {"extra":1}` + "\n")) - _, err := decodeFrame(reader) +func TestDecodeRPCRequestTrailingJSON(t *testing.T) { + reader := bufio.NewReader(strings.NewReader(`{"jsonrpc":"2.0","id":"x","method":"gateway.ping"} {"extra":1}` + "\n")) + _, err := decodeRPCRequest(reader) if err == nil || !strings.Contains(err.Error(), "trailing") { t.Fatalf("expected trailing json error, got %v", err) } @@ -289,6 +291,82 @@ func TestDispatchFrameValidationError(t *testing.T) { } } +func TestDispatchRPCRequestNormalizeError(t *testing.T) { + server := &Server{} + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc normalize error") + } + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest 
{ + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeMissingRequiredField.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeMissingRequiredField.String()) + } +} + +func TestDispatchRPCRequestInvalidIDReturnsNullID(t *testing.T) { + server := &Server{} + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`{"bad":"id"}`), + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc normalize error") + } + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) + } + encoded, err := json.Marshal(response) + if err != nil { + t.Fatalf("marshal response: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(encoded, &payload); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if id, ok := payload["id"]; !ok || id != nil { + t.Fatalf("encoded response id = %#v, want nil", payload["id"]) + } +} + +func TestDispatchRPCRequestConvertsFrameErrorWithoutPayload(t *testing.T) { + server := &Server{} + originalHandlers := requestFrameHandlers + requestFrameHandlers = map[FrameAction]requestFrameHandler{ + FrameActionPing: func(frame MessageFrame) MessageFrame { + return MessageFrame{ + Type: FrameTypeError, + Action: frame.Action, + RequestID: frame.RequestID, + Error: nil, + } + }, + } + t.Cleanup(func() { + requestFrameHandlers = originalHandlers + }) + + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"rpc-err-1"`), + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc error response") + } + if 
response.Error.Code != protocol.JSONRPCCodeInternalError { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInternalError) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInternalError.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInternalError.String()) + } +} + func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { server := &Server{logger: log.New(io.Discard, "", 0)} serverConn, clientConn := net.Pipe() @@ -299,15 +377,22 @@ func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { }() _, _ = io.WriteString(clientConn, "\n") - _, _ = io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"empty-then-ping"}`+"\n") + _, _ = io.WriteString(clientConn, `{"jsonrpc":"2.0","id":"empty-then-ping","method":"gateway.ping","params":{}}`+"\n") decoder := json.NewDecoder(clientConn) - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - if response.Type != FrameTypeAck || response.Action != FrameActionPing { - t.Fatalf("unexpected response after empty frame: %#v", response) + if response.Error != nil { + t.Fatalf("unexpected rpc error after empty frame: %+v", response.Error) + } + resultFrame, err := decodeJSONRPCResultFrame(response) + if err != nil { + t.Fatalf("decode result frame: %v", err) + } + if resultFrame.Type != FrameTypeAck || resultFrame.Action != FrameActionPing { + t.Fatalf("unexpected response after empty frame: %#v", resultFrame) } _ = clientConn.Close() @@ -329,15 +414,21 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { _, _ = io.WriteString(clientConn, "{invalid-json}\n") decoder := json.NewDecoder(clientConn) - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - if response.Type != 
FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + if response.Error == nil { + t.Fatal("response rpc error is nil") } - if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { - t.Fatalf("response error = %#v, want invalid frame", response.Error) + if response.Error.Code != protocol.JSONRPCCodeParseError { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeParseError) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) + } + if got := strings.TrimSpace(string(response.ID)); got != "null" { + t.Fatalf("response id = %q, want %q", got, "null") } _ = clientConn.Close() @@ -412,7 +503,7 @@ func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) { server.handleConnection(context.Background(), serverConn, nil) }() - _, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n") + _, err := io.WriteString(clientConn, `{"jsonrpc":"2.0","id":"write-timeout","method":"gateway.ping","params":{}}`+"\n") if err != nil { t.Fatalf("write request: %v", err) } diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index aa7cae26..69c3b154 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -11,6 +11,8 @@ import ( "strings" "testing" "time" + + "neo-code/internal/gateway/protocol" ) func TestServerHandleConnectionPing(t *testing.T) { @@ -28,32 +30,40 @@ func TestServerHandleConnectionPing(t *testing.T) { encoder := json.NewEncoder(clientConn) decoder := json.NewDecoder(clientConn) - if err := encoder.Encode(MessageFrame{ - Type: FrameTypeRequest, - Action: FrameActionPing, - RequestID: "req-1", + if err := encoder.Encode(protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: 
json.RawMessage(`"req-1"`), + Method: protocol.MethodGatewayPing, + Params: json.RawMessage(`{}`), }); err != nil { t.Fatalf("encode request: %v", err) } - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } + if response.Error != nil { + t.Fatalf("unexpected rpc error: %+v", response.Error) + } + resultFrame, err := decodeJSONRPCResultFrame(response) + if err != nil { + t.Fatalf("decode result frame: %v", err) + } - if response.Type != FrameTypeAck { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + if resultFrame.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", resultFrame.Type, FrameTypeAck) } - if response.Action != FrameActionPing { - t.Fatalf("response action = %q, want %q", response.Action, FrameActionPing) + if resultFrame.Action != FrameActionPing { + t.Fatalf("response action = %q, want %q", resultFrame.Action, FrameActionPing) } - if response.RequestID != "req-1" { - t.Fatalf("response request_id = %q, want %q", response.RequestID, "req-1") + if resultFrame.RequestID != "req-1" { + t.Fatalf("response request_id = %q, want %q", resultFrame.RequestID, "req-1") } - payloadMap, ok := response.Payload.(map[string]any) + payloadMap, ok := resultFrame.Payload.(map[string]any) if !ok { - t.Fatalf("response payload type = %T, want map[string]any", response.Payload) + t.Fatalf("response payload type = %T, want map[string]any", resultFrame.Payload) } if got, _ := payloadMap["message"].(string); got != "pong" { t.Fatalf("response payload message = %q, want %q", got, "pong") @@ -82,28 +92,27 @@ func TestServerHandleConnectionUnsupportedAction(t *testing.T) { encoder := json.NewEncoder(clientConn) decoder := json.NewDecoder(clientConn) - if err := encoder.Encode(MessageFrame{ - Type: FrameTypeRequest, - Action: FrameActionRun, - RequestID: "req-2", - InputText: "hello", + if err := encoder.Encode(protocol.JSONRPCRequest{ + 
JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-2"`), + Method: "gateway.run", + Params: json.RawMessage(`{"input_text":"hello"}`), }); err != nil { t.Fatalf("encode request: %v", err) } - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - - if response.Type != FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) - } if response.Error == nil { - t.Fatal("response error is nil") + t.Fatal("response rpc error is nil") + } + if response.Error.Code != protocol.JSONRPCCodeMethodNotFound { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeMethodNotFound) } - if response.Error.Code != ErrorCodeUnsupportedAction.String() { - t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeUnsupportedAction.String()) + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeUnsupportedAction.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeUnsupportedAction.String()) } _ = clientConn.Close() @@ -129,7 +138,7 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { decoder := json.NewDecoder(clientConn) oversizedPayload := strings.Repeat("a", int(MaxFrameSize)+128) requestFrame := fmt.Sprintf( - `{"type":"request","action":"ping","request_id":"req-oversize","input_text":"%s"}`+"\n", + `{"jsonrpc":"2.0","id":"req-oversize","method":"gateway.ping","params":{"input_text":"%s"}}`+"\n", oversizedPayload, ) @@ -139,18 +148,18 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { writeDone <- err }() - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode oversized response: %v", err) } - if response.Type != FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) - } if response.Error 
== nil { - t.Fatal("response error is nil") + t.Fatal("response rpc error is nil") + } + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) } - if response.Error.Code != ErrorCodeInvalidFrame.String() { - t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeInvalidFrame.String()) + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) } if !strings.Contains(response.Error.Message, "frame exceeds max size") { t.Fatalf("error message = %q, want contains %q", response.Error.Message, "frame exceeds max size") @@ -189,3 +198,14 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { t.Fatal("handleConnection did not exit") } } + +func decodeJSONRPCResultFrame(response protocol.JSONRPCResponse) (MessageFrame, error) { + if response.Result == nil { + return MessageFrame{}, errors.New("rpc result is nil") + } + var frame MessageFrame + if err := json.Unmarshal(response.Result, &frame); err != nil { + return MessageFrame{}, err + } + return frame, nil +} diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 7d1befea..10bc7864 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -299,6 +299,48 @@ func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) { } } +func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) { + generator := &stubTextGenerator{response: `[]`} + extractor := NewLLMExtractor(generator) + + _, err := extractor.Extract(context.Background(), []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "remember this"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call_1", Name: "filesystem_read_file", Arguments: 
`{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call_1", + Content: "", + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + }) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if len(generator.messages) != 3 { + t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages)) + } + toolMessage := generator.messages[2] + if toolMessage.Role != providertypes.RoleTool { + t.Fatalf("expected projected tool message, got %#v", toolMessage) + } + if !strings.Contains(toolMessage.Content, "tool result") || + !strings.Contains(toolMessage.Content, "tool: filesystem_read_file") || + !strings.Contains(toolMessage.Content, "meta.path: README.md") { + t.Fatalf("expected metadata-only tool text, got %q", toolMessage.Content) + } + if strings.Contains(toolMessage.Content, "content:\n") { + t.Fatalf("expected metadata-only projection to omit content section, got %q", toolMessage.Content) + } + if toolMessage.ToolMetadata != nil { + t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) + } +} + func TestLLMExtractorExtractSkipsOrphanAndClearedToolMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} extractor := NewLLMExtractor(generator) diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index df826fc7..3488dd3a 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -103,6 +103,34 @@ func TestListProviderModelsMergesConfiguredMetadataAfterDiscovery(t *testing.T) } } +func TestListProviderModelsUsesConfiguredContextWindowWhenDiscoveryMissesIt(t *testing.T) { + t.Setenv(testAPIKeyEnv, "test-key") + + registry := newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + return []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: 
"Server DeepSeek", + ContextWindow: 0, + }}, nil + }) + + service := NewService("", registry, newMemoryStore()) + providerCfg := customGatewayProvider() + providerCfg.Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: 131072, + }} + + models, err := service.ListProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) + if err != nil { + t.Fatalf("ListProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ContextWindow != 131072 { + t.Fatalf("expected configured context window to fill discovery gap, got %+v", models) + } +} + func TestListProviderModelsSnapshotReturnsDefaultAndRefreshesInBackgroundOnMiss(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") diff --git a/internal/runtime/controlplane/progress.go b/internal/runtime/controlplane/progress.go index a1d26875..784496ce 100644 --- a/internal/runtime/controlplane/progress.go +++ b/internal/runtime/controlplane/progress.go @@ -23,17 +23,40 @@ type ProgressScore struct { // ProgressState 汇总当前运行期 progress 控制面状态。 type ProgressState struct { - LastScore ProgressScore `json:"last_score"` + LastScore ProgressScore `json:"last_score"` + LastSignature string `json:"last_signature,omitempty"` } // ApplyProgressEvidence 根据证据更新分值与 streak。 -func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord) ProgressState { +func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord, currentSignature string) ProgressState { next := state.LastScore - if len(records) == 0 { - next.NoProgressStreak++ + hasToolAttempt := currentSignature != "" + isRepeated := hasToolAttempt && state.LastSignature != "" && currentSignature == state.LastSignature + + if hasToolAttempt { + if isRepeated { + next.RepeatCycleStreak++ + } else { + next.RepeatCycleStreak = 1 + } } else { + next.RepeatCycleStreak = 0 + } + + nextSignature := "" + if hasToolAttempt { + nextSignature = currentSignature + } + + if len(records) > 
0 && !isRepeated { next.NoProgressStreak = 0 next.ScoreDelta++ + } else { + next.NoProgressStreak++ + } + + return ProgressState{ + LastScore: next, + LastSignature: nextSignature, } - return ProgressState{LastScore: next} } diff --git a/internal/runtime/controlplane/progress_test.go b/internal/runtime/controlplane/progress_test.go index cd25006e..f457a0be 100644 --- a/internal/runtime/controlplane/progress_test.go +++ b/internal/runtime/controlplane/progress_test.go @@ -4,10 +4,15 @@ import "testing" func TestApplyProgressEvidenceNoEvidenceIncrementsNoProgress(t *testing.T) { t.Parallel() - state := ProgressState{} - next := ApplyProgressEvidence(state, nil) - if next.LastScore.NoProgressStreak != 1 { - t.Fatalf("expected no_progress_streak 1, got %d", next.LastScore.NoProgressStreak) + got := ApplyProgressEvidence(ProgressState{}, nil, "") + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 1, + RepeatCycleStreak: 0, + }, + } + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } @@ -16,14 +21,19 @@ func TestApplyProgressEvidenceOnlyNonDupResetsNoProgressStreak(t *testing.T) { state := ProgressState{ LastScore: ProgressScore{NoProgressStreak: 3}, } - next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, - }) - if next.LastScore.NoProgressStreak != 0 { - t.Fatalf("expected streak reset to 0, got %d", next.LastScore.NoProgressStreak) + }, "sig1") + want := ProgressState{ + LastScore: ProgressScore{ + ScoreDelta: 1, + NoProgressStreak: 0, + RepeatCycleStreak: 1, + }, + LastSignature: "sig1", } - if next.LastScore.ScoreDelta != 1 { - t.Fatalf("expected score_delta 1, got %d", next.LastScore.ScoreDelta) + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } @@ -32,11 +42,52 @@ func TestApplyProgressEvidenceMixedResetsNoProgress(t *testing.T) { state := ProgressState{ LastScore: ProgressScore{NoProgressStreak: 2}, } 
- next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, {Kind: ProgressEvidenceKind("other_evidence")}, - }) - if next.LastScore.NoProgressStreak != 0 { - t.Fatalf("expected streak reset, got %d", next.LastScore.NoProgressStreak) + }, "sig1") + if got.LastScore.NoProgressStreak != 0 { + t.Fatalf("expected streak reset, got %d", got.LastScore.NoProgressStreak) + } +} + +func TestApplyProgressEvidenceRepeatCycle(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 1, RepeatCycleStreak: 2}, + LastSignature: "sig1", + } + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + }, "sig1") + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 2, + RepeatCycleStreak: 3, + }, + LastSignature: "sig1", + } + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) + } +} + +func TestApplyProgressEvidenceRepeatCycleOnFailureKeepsSignatureTracking(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 2, RepeatCycleStreak: 1}, + LastSignature: "sig1", + } + + got := ApplyProgressEvidence(state, nil, "sig1") + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 3, + RepeatCycleStreak: 2, + }, + LastSignature: "sig1", + } + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } diff --git a/internal/runtime/errors_test.go b/internal/runtime/errors_test.go new file mode 100644 index 00000000..9ae1c427 --- /dev/null +++ b/internal/runtime/errors_test.go @@ -0,0 +1,43 @@ +package runtime + +import ( + "bytes" + "context" + "errors" + "log" + "testing" + + "neo-code/internal/provider" +) + +func TestHandleRunErrorProviderErrorDoesNotWriteStdLog(t *testing.T) { + service := &Service{} + providerErr := &provider.ProviderError{ + StatusCode: 401, + Code: "auth_failed", + Message: "Incorrect 
API key provided", + Retryable: false, + } + + var buf bytes.Buffer + oldWriter := log.Writer() + oldFlags := log.Flags() + oldPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + t.Cleanup(func() { + log.SetOutput(oldWriter) + log.SetFlags(oldFlags) + log.SetPrefix(oldPrefix) + }) + + err := service.handleRunError(context.Background(), "run-1", "session-1", providerErr) + if !errors.Is(err, providerErr) { + t.Fatalf("expected provider error passthrough, got %v", err) + } + if got := buf.String(); got != "" { + t.Fatalf("expected no std log output, got %q", got) + } + +} diff --git a/internal/runtime/events_subagent.go b/internal/runtime/events_subagent.go new file mode 100644 index 00000000..094564ab --- /dev/null +++ b/internal/runtime/events_subagent.go @@ -0,0 +1,33 @@ +package runtime + +import "neo-code/internal/subagent" + +// EventPermissionRequest 为兼容旧事件名保留,语义等同 EventPermissionRequested。 +const EventPermissionRequest EventType = EventPermissionRequested + +// EventCompactDone 为兼容旧事件名保留,语义等同 EventCompactApplied。 +const EventCompactDone EventType = EventCompactApplied + +// SubAgentEventPayload 描述子代理执行生命周期的事件载荷。 +type SubAgentEventPayload struct { + Role subagent.Role `json:"role"` + TaskID string `json:"task_id"` + State subagent.State `json:"state"` + StopReason subagent.StopReason `json:"stop_reason,omitempty"` + Step int `json:"step,omitempty"` + Delta string `json:"delta,omitempty"` + Error string `json:"error,omitempty"` +} + +const ( + // EventSubAgentStarted 在子代理任务启动后触发。 + EventSubAgentStarted EventType = "subagent_started" + // EventSubAgentProgress 在子代理执行每一步后触发。 + EventSubAgentProgress EventType = "subagent_progress" + // EventSubAgentCompleted 在子代理成功结束后触发。 + EventSubAgentCompleted EventType = "subagent_completed" + // EventSubAgentFailed 在子代理失败结束后触发。 + EventSubAgentFailed EventType = "subagent_failed" + // EventSubAgentCanceled 在子代理被取消后触发。 + EventSubAgentCanceled EventType = "subagent_canceled" +) diff --git 
a/internal/runtime/run.go b/internal/runtime/run.go index a56cda08..ea3930c5 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -2,6 +2,9 @@ package runtime import ( "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" "fmt" "strings" @@ -18,6 +21,37 @@ import ( "neo-code/internal/tools" ) +const selfHealingReminder = "System Reminder: You have made multiple consecutive attempts without making substantial progress. Please stop your current repetitive or ineffective strategy. Carefully review the previous errors, change your approach, or ask the user directly for help." +const selfHealingRepeatReminder = "System Reminder: You are repeatedly calling the same tool with the exact same arguments. This is an infinite loop. Please change your parameters, try a different tool, or ask the user for help." + +// computeToolSignature 计算单轮执行的工具签名,用于循环检测。 +func computeToolSignature(calls []providertypes.ToolCall) string { + if len(calls) == 0 { + return "" + } + var sb strings.Builder + for _, call := range calls { + sb.WriteString(call.Name) + sb.WriteString(":") + + // 尝试将 JSON 参数进行规范化序列化,以消除空格、换行和字段顺序带来的哈希差异 + var parsed interface{} + if err := json.Unmarshal([]byte(call.Arguments), &parsed); err == nil { + if canonicalBytes, err := json.Marshal(parsed); err == nil { + sb.WriteString(string(canonicalBytes)) + } else { + sb.WriteString(call.Arguments) // 序列化失败,降级为原始字符串 + } + } else { + sb.WriteString(call.Arguments) // 解析失败,降级为原始字符串 + } + + sb.WriteString(";") + } + hash := sha256.Sum256([]byte(sb.String())) + return hex.EncodeToString(hash[:]) +} + // Run 执行一次完整的 ReAct 闭环:保存用户输入、驱动模型、执行工具并发出事件。 // 已有会话会先加锁再加载/更新,确保同一会话并发 Run 不会出现状态覆盖; // 新会话在创建后再绑定会话锁,不同会话可并行执行。 @@ -44,32 +78,18 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { initialCfg := s.configManager.Get() sessionID := strings.TrimSpace(input.SessionID) - releaseSessionLock := func() {} + releaseSessionLock := s.bindSessionLock(sessionID) defer func() { 
releaseSessionLock() }() - if sessionID != "" { - sessionMu, releaseLockRef := s.acquireSessionLock(sessionID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } - } - session, err := s.loadOrCreateSession(ctx, input.SessionID, input.Content, initialCfg.Workdir, input.Workdir) if err != nil { return s.handleRunError(ctx, input.RunID, input.SessionID, err) } if sessionID == "" { - sessionMu, releaseLockRef := s.acquireSessionLock(session.ID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } + releaseSessionLock = s.bindSessionLock(session.ID) } state := newRunState(input.RunID, session) @@ -131,33 +151,39 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.transitionRunPhase(ctx, &state, controlplane.PhaseVerify) var evidence []controlplane.ProgressEvidenceRecord - hasProgress := false toolCallCount := len(turnResult.assistant.ToolCalls) + currentSignature := computeToolSignature(turnResult.assistant.ToolCalls) + state.mu.Lock() if len(state.session.Messages) >= toolCallCount { for i := len(state.session.Messages) - toolCallCount; i < len(state.session.Messages); i++ { - msg := state.session.Messages[i] - if msg.Role == providertypes.RoleTool && !msg.IsError { - hasProgress = true + if msg := state.session.Messages[i]; msg.Role == providertypes.RoleTool && !msg.IsError { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) break } } } - state.mu.Unlock() - if hasProgress { - evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) - } - - state.mu.Lock() - state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence) + state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence, currentSignature) streak := state.progress.LastScore.NoProgressStreak + repeatStreak := state.progress.LastScore.RepeatCycleStreak 
currentScore := state.progress.LastScore state.mu.Unlock() s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) - if streak >= noProgressStreakLimit { + repeatLimit := snapshot.config.Runtime.MaxRepeatCycleStreak + if repeatLimit <= 0 { + repeatLimit = config.DefaultMaxRepeatCycleStreak + } + + if repeatStreak >= repeatLimit { + err = ErrRepeatCycleLimit + return err + } + + limit := snapshot.noProgressStreakLimit + if streak >= limit { err = ErrNoProgressStreakLimit return err } @@ -190,7 +216,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - AutoCompactThreshold: autoCompactThreshold(cfg), + AutoCompactThreshold: s.autoCompactThresholdForState(ctx, cfg, state), MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, @@ -221,22 +247,68 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur return turnSnapshot{}, false, err } + state.mu.Lock() + streak := state.progress.LastScore.NoProgressStreak + repeatStreak := state.progress.LastScore.RepeatCycleStreak + state.mu.Unlock() + + limit := resolveNoProgressStreakLimit(cfg.Runtime) + repeatLimit := resolveRepeatCycleStreakLimit(cfg.Runtime) + systemPrompt := builtContext.SystemPrompt + + if repeatStreak == repeatLimit-1 { + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + systemPrompt = selfHealingRepeatReminder + } else { + systemPrompt = trimmed + "\n\n" + selfHealingRepeatReminder + } + } else if streak == limit-1 { + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + systemPrompt = selfHealingReminder + } else { + systemPrompt =
trimmed + "\n\n" + selfHealingReminder + } + } + model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ - config: cfg, - providerConfig: resolvedProvider.ToRuntimeConfig(), - model: model, - workdir: activeWorkdir, - toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + config: cfg, + providerConfig: resolvedProvider.ToRuntimeConfig(), + model: model, + workdir: activeWorkdir, + toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + noProgressStreakLimit: limit, request: providertypes.GenerateRequest{ Model: model, - SystemPrompt: builtContext.SystemPrompt, + SystemPrompt: systemPrompt, Messages: builtContext.Messages, Tools: toolSpecs, }, }, false, nil } +// resolveNoProgressStreakLimit 统一解析熔断阈值,避免运行期出现无效值导致分支行为不一致。 +func resolveNoProgressStreakLimit(rc config.RuntimeConfig) int { + if rc.MaxNoProgressStreak <= 0 { + return config.DefaultMaxNoProgressStreak + } + return rc.MaxNoProgressStreak +} + +// resolveRepeatCycleStreakLimit 统一解析重复调用循环阈值。 +func resolveRepeatCycleStreakLimit(rc config.RuntimeConfig) int { + if rc.MaxRepeatCycleStreak <= 0 { + return config.DefaultMaxRepeatCycleStreak + } + return rc.MaxRepeatCycleStreak +} + // callProviderWithRetry 使用冻结后的 turnSnapshot 执行 provider 调用与必要重试。 func (s *Service) callProviderWithRetry( ctx context.Context, @@ -331,11 +403,39 @@ func (s *Service) applyCompactForState( } // autoCompactThreshold 返回当前配置下的自动 compact 触发阈值。 -func autoCompactThreshold(cfg config.Config) int { - if cfg.Context.AutoCompact.Enabled && cfg.Context.AutoCompact.InputTokenThreshold > 0 { +func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) int { + return s.autoCompactThresholdForState(ctx, cfg, nil) +} + +// autoCompactThresholdForState 返回当前配置下的自动 compact 触发阈值,并在单次 run 内按关键输入缓存结果。 +func (s *Service) autoCompactThresholdForState(ctx context.Context, cfg config.Config, state *runState) int { + if !cfg.Context.AutoCompact.Enabled { + return 0 + } + if
cfg.Context.AutoCompact.InputTokenThreshold > 0 { return cfg.Context.AutoCompact.InputTokenThreshold } - return 0 + + key := autoCompactCacheKeyFromConfig(cfg) + if state != nil && state.autoCompactCache.valid && state.autoCompactCache.key == key { + return state.autoCompactCache.threshold + } + + threshold := fallbackAutoCompactThreshold(cfg) + if s != nil && s.autoCompactThresholdResolver != nil { + resolvedThreshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) + if err == nil && resolvedThreshold > 0 { + threshold = resolvedThreshold + } + } + if state != nil { + state.autoCompactCache = autoCompactThresholdCache{ + key: key, + threshold: threshold, + valid: true, + } + } + return threshold } // degradeKeepRecentMessages gradually reduces the number of retained messages as the reactive compact attempt count grows. @@ -348,3 +448,49 @@ func degradeKeepRecentMessages(base int, attempt int) int { } return base } + +// fallbackAutoCompactThreshold returns the floor threshold that remains usable when automatic derivation fails. +func fallbackAutoCompactThreshold(cfg config.Config) int { + if cfg.Context.AutoCompact.FallbackInputTokenThreshold > 0 { + return cfg.Context.AutoCompact.FallbackInputTokenThreshold + } + return 0 +} + +// bindSessionLock acquires and holds the lock for the given session and returns the matching release function. +func (s *Service) bindSessionLock(sessionID string) func() { + id := strings.TrimSpace(sessionID) + if id == "" { + return func() {} + } + sessionMu, releaseLockRef := s.acquireSessionLock(id) + sessionMu.Lock() + return func() { + sessionMu.Unlock() + releaseLockRef() + } +} + +// withSelfHealingReminder injects the self-healing reminder on the last turn before the no-progress limit is reached, keeping the prompt-concatenation rule in one place. 
+func withSelfHealingReminder(systemPrompt string, streak int, limit int) string { + if streak != limit-1 { + return systemPrompt + } + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + return selfHealingReminder + } + return trimmed + "\n\n" + selfHealingReminder +} + +// autoCompactCacheKeyFromConfig extracts the config dimensions that affect auto-compact threshold resolution; it is used to decide cache hits within a run. +func autoCompactCacheKeyFromConfig(cfg config.Config) autoCompactThresholdCacheKey { + return 
autoCompactThresholdCacheKey{ + provider: strings.TrimSpace(cfg.SelectedProvider), + model: strings.TrimSpace(cfg.CurrentModel), + autoCompactEnabled: cfg.Context.AutoCompact.Enabled, + autoCompactInputThreshold: cfg.Context.AutoCompact.InputTokenThreshold, + autoCompactReserveTokens: cfg.Context.AutoCompact.ReserveTokens, + autoCompactFallback: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + } +} diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index 52a5015f..d82d71b1 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -3,7 +3,6 @@ package runtime import ( "context" "errors" - "log" "math/rand/v2" "strings" "time" @@ -19,6 +18,9 @@ var ErrMaxLoopReached = errors.New("runtime: max loop reached") // ErrNoProgressStreakLimit indicates that the loop made no progress for several consecutive turns, which trips the dead-loop guard. var ErrNoProgressStreakLimit = errors.New("runtime: no progress streak limit reached") +// ErrRepeatCycleLimit indicates that the same tool was called with identical arguments several times in a row, which trips the dead-loop guard. +var ErrRepeatCycleLimit = errors.New("runtime: repeat cycle limit reached") + // transitionRunPhase emits phase_changed when the phase changes and updates runState. func (s *Service) transitionRunPhase(ctx context.Context, state *runState, next controlplane.Phase) { if state == nil || state.phase == next { @@ -103,12 +105,6 @@ func (s *Service) handleRunError(ctx context.Context, runID string, sessionID st return context.Canceled } - var providerErr *provider.ProviderError - if errors.As(err, &providerErr) { - log.Printf("runtime: provider error (status=%d, code=%s, retryable=%v): %s", - providerErr.StatusCode, providerErr.Code, providerErr.Retryable, providerErr.Message) - } - return err } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index d965f312..ce5546a4 100644 --- a/internal/runtime/runtime.go 
+++ b/internal/runtime/runtime.go @@ -21,7 +21,6 @@ const ( providerRetryBaseWait = 1 * time.Second providerRetryMaxWait = 5 * time.Second defaultToolParallelism = 4 - noProgressStreakLimit = 3 
terminationEventEmitTimeout = 500 * time.Millisecond ) @@ -61,16 +60,21 @@ type MemoExtractor interface { } // Service is the default runtime implementation; it orchestrates one complete agent run loop. +type AutoCompactThresholdResolver interface { + ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) +} + type Service struct { - configManager *config.Manager - sessionStore agentsession.Store - toolManager tools.Manager - providerFactory ProviderFactory - contextBuilder agentcontext.Builder - compactRunner contextcompact.Runner - approvalBroker *approval.Broker - memoExtractor MemoExtractor - skillsRegistry skills.Registry + configManager *config.Manager + sessionStore agentsession.Store + toolManager tools.Manager + providerFactory ProviderFactory + contextBuilder agentcontext.Builder + compactRunner contextcompact.Runner + approvalBroker *approval.Broker + memoExtractor MemoExtractor + skillsRegistry skills.Registry + autoCompactThresholdResolver AutoCompactThresholdResolver events chan RuntimeEvent sessionMu sync.Mutex @@ -172,3 +176,8 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess } return session, nil } + +// SetAutoCompactThresholdResolver injects the auto-compact threshold resolver so that runtime does not handle model-catalog details directly. +func (s *Service) SetAutoCompactThresholdResolver(resolver AutoCompactThresholdResolver) { + s.autoCompactThresholdResolver = resolver +} diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 91e86a91..46ac92f4 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -177,6 +177,120 @@ func TestAppendToolMessageAndSaveSanitizesMetadata(t *testing.T) { } } +func TestAppendToolMessageAndSavePreservesMetadataOnlySuccessResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := 
newRuntimeSession("session-append-tool-metadata-only") + store.sessions[session.ID] = cloneSession(session) + + service := 
&Service{sessionStore: store} + state := newRunState("run-append-tool-metadata-only", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: "", + Metadata: map[string]any{ + "path": "README.md", + }, + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "" { + t.Fatalf("expected metadata-only success result to keep empty content, got %q", msg.Content) + } + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" || msg.ToolMetadata["path"] != "README.md" { + t.Fatalf("expected metadata-only success result to keep sanitized metadata, got %+v", msg.ToolMetadata) + } +} + +func TestAppendToolMessageAndSaveNormalizesSemanticallyEmptySuccessResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-empty-success") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-empty-success", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: " ", + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "ok" { + t.Fatalf("expected empty success result to be normalized to ok, got %q", msg.Content) + } + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected tool_name metadata to be preserved after normalization, got %+v", msg.ToolMetadata) + } +} + +func TestAppendToolMessageAndSaveNormalizesToolNameOnlyMetadataSuccessResult(t *testing.T) { + t.Parallel() + + 
store := newMemoryStore() + session := newRuntimeSession("session-append-tool-name-only-metadata-success") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-name-only-metadata-success", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: " ", + Metadata: map[string]any{ + "unsupported_key": "ignored", + }, + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "ok" { + t.Fatalf("expected tool_name-only metadata success to normalize content to ok, got %q", msg.Content) + } + if len(msg.ToolMetadata) != 1 || msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected only tool_name metadata to remain, got %+v", msg.ToolMetadata) + } +} + +func TestAppendToolMessageAndSaveFallsBackToCallNameForToolMetadata(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-name-fallback") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-name-fallback", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Content: "ok", + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected tool_name fallback from call name, got %+v", msg.ToolMetadata) + } +} + func TestAppendToolMessageAndSaveUnlocksStateBeforePersist(t *testing.T) { t.Parallel() diff --git 
a/internal/runtime/runtime_limits.go b/internal/runtime/runtime_limits.go new file mode 100644 index 00000000..3cab9213 --- /dev/null +++ b/internal/runtime/runtime_limits.go @@ -0,0 +1,4 @@ +package runtime + +// defaultMaxLoops defines the default maximum number of turns, kept for compatibility with the legacy run-loop logic. +const defaultMaxLoops = 8 diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index f83f73f2..89ec5189 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -3,10 +3,13 @@ package runtime import ( "context" "errors" + "strconv" + "strings" "sync/atomic" "testing" "neo-code/internal/config" + agentcontext "neo-code/internal/context" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" "neo-code/internal/tools" @@ -31,12 +34,22 @@ func TestProgressStreakStopsRun(t *testing.T) { }, } + var promptInjected bool + var signatureSeq int32 providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + seq := atomic.AddInt32(&signatureSeq, 1) + if strings.Contains(req.SystemPrompt, selfHealingReminder) { + promptInjected = true + } // the model always decides to call the tool events <- providertypes.NewToolCallStartStreamEvent(0, "call_err", "tool_error") - events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_err", "{}") + events <- providertypes.NewToolCallDeltaStreamEvent( + 0, + "call_err", + `{"seq":`+strconv.FormatInt(int64(seq), 10)+`}`, + ) events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) return nil }, @@ -70,18 +83,10 @@ func TestProgressStreakStopsRun(t *testing.T) { events := collectRuntimeEvents(service.Events()) // Verify StopReason is error and specifies the streak limit - assertEventContains(t, events, EventStopReasonDecided) + 
assertStopReasonDecided(t, events, controlplane.StopReasonError, ErrNoProgressStreakLimit.Error()) 
- for _, e := range events { - if e.Type == EventStopReasonDecided { - payload := e.Payload.(StopReasonDecidedPayload) - if payload.Reason != controlplane.StopReasonError { - t.Errorf("expected StopReasonError, got %s", payload.Reason) - } - if payload.Detail != ErrNoProgressStreakLimit.Error() { - t.Errorf("expected detail to be %q, got %q", ErrNoProgressStreakLimit.Error(), payload.Detail) - } - } + if !promptInjected { + t.Error("expected self-healing prompt to be injected before streak limit is reached, but it wasn't") } } @@ -109,13 +114,19 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } var providerCalls int32 + var signatureSeq int32 providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { call := int(atomic.AddInt32(&providerCalls, 1)) if call <= 4 { + seq := atomic.AddInt32(&signatureSeq, 1) events <- providertypes.NewToolCallStartStreamEvent(0, "call_mixed", "tool_mixed") - events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_mixed", "{}") + events <- providertypes.NewToolCallDeltaStreamEvent( + 0, + "call_mixed", + `{"seq":`+strconv.FormatInt(int64(seq), 10)+`}`, + ) events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) return nil } @@ -151,12 +162,272 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } events := collectRuntimeEvents(service.Events()) + assertStopReasonDecided(t, events, controlplane.StopReasonSuccess, "") +} + +func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { + t.Setenv("TEST_KEY", "dummy") + + cfg := config.Config{ + Providers: []config.ProviderConfig{{Name: "test-repeat", Driver: "test", BaseURL: "http://localhost", Model: "test", APIKeyEnv: "TEST_KEY"}}, + SelectedProvider: "test-repeat", + Workdir: t.TempDir(), + Runtime: config.RuntimeConfig{ + MaxNoProgressStreak: 10, + MaxRepeatCycleStreak: 3, + }, + } + + var 
executeCalls int32 + var providerCalls int32 + toolManager := &stubToolManager{ + specs: []providertypes.ToolSpec{ + {Name: "tool_repeat"}, + }, + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + atomic.AddInt32(&executeCalls, 1) + return tools.ToolResult{Name: input.Name, Content: "ok", IsError: false}, nil + }, + } + + var promptInjected bool + providerFactory := &scriptedProviderFactory{ + provider: &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + atomic.AddInt32(&providerCalls, 1) + if strings.Contains(req.SystemPrompt, selfHealingRepeatReminder) { + promptInjected = true + } + events <- providertypes.NewToolCallStartStreamEvent(0, "call_repeat", "tool_repeat") + events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_repeat", `{"path":"x"}`) + events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) + return nil + }, + }, + } + + manager := config.NewManager(config.NewLoader(t.TempDir(), &cfg)) + service := NewWithFactory( + manager, + toolManager, + newMemoryStore(), + providerFactory, + nil, + ) + + err := service.Run(context.Background(), UserInput{ + RunID: "run-repeat-streak", + Content: "trigger repeat loop", + }) + if err == nil { + t.Fatal("expected repeat cycle limit error, got nil") + } + if !errors.Is(err, ErrRepeatCycleLimit) { + t.Fatalf("expected ErrRepeatCycleLimit, got %v", err) + } + + events := collectRuntimeEvents(service.Events()) + + assertStopReasonDecided(t, events, controlplane.StopReasonError, ErrRepeatCycleLimit.Error()) + + if !promptInjected { + t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") + } + if executeCalls != 3 { + t.Fatalf("expected break on the 3rd identical tool execution, got %d", executeCalls) + } + if providerCalls != 3 { + t.Fatalf("expected 3 provider turns before repeat breaker, got %d", providerCalls) + } +} + +func 
TestRepeatCycleStreakCountsFailedToolCalls(t *testing.T) { + t.Setenv("TEST_KEY", "dummy") + + cfg := config.Config{ + Providers: []config.ProviderConfig{{Name: "test-repeat-fail", Driver: "test", BaseURL: "http://localhost", Model: "test", APIKeyEnv: "TEST_KEY"}}, + SelectedProvider: "test-repeat-fail", + Workdir: t.TempDir(), + Runtime: config.RuntimeConfig{ + MaxNoProgressStreak: 10, + MaxRepeatCycleStreak: 3, + }, + } + + var executeCalls int32 + toolManager := &stubToolManager{ + specs: []providertypes.ToolSpec{ + {Name: "tool_repeat_fail"}, + }, + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + atomic.AddInt32(&executeCalls, 1) + return tools.ToolResult{Name: input.Name, Content: "error", IsError: true}, nil + }, + } + + var providerCalls int32 + providerFactory := &scriptedProviderFactory{ + provider: &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + atomic.AddInt32(&providerCalls, 1) + events <- providertypes.NewToolCallStartStreamEvent(0, "call_repeat_fail", "tool_repeat_fail") + events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_repeat_fail", `{"path":"x"}`) + events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) + return nil + }, + }, + } + + manager := config.NewManager(config.NewLoader(t.TempDir(), &cfg)) + service := NewWithFactory( + manager, + toolManager, + newMemoryStore(), + providerFactory, + nil, + ) + + err := service.Run(context.Background(), UserInput{ + RunID: "run-repeat-fail-streak", + Content: "trigger repeat fail loop", + }) + if !errors.Is(err, ErrRepeatCycleLimit) { + t.Fatalf("expected ErrRepeatCycleLimit, got %v", err) + } + if executeCalls != 3 { + t.Fatalf("expected failed repeated calls to break on the 3rd execution, got %d", executeCalls) + } + if providerCalls != 3 { + t.Fatalf("expected 3 provider turns before repeat breaker, got %d", providerCalls) + } +} + +func 
TestComputeToolSignatureNormalizationAndFallback(t *testing.T) { + if got := computeToolSignature(nil); got != "" { + t.Fatalf("expected empty signature for nil tool calls, got %q", got) + } + + callsA := []providertypes.ToolCall{ + {Name: "filesystem_read_file", Arguments: "{\n \"path\": \"/tmp/a.txt\",\n \"opts\": {\"y\": [2,3], \"x\": 1}\n}"}, + {Name: "bash", Arguments: "{\"cmd\":\"pwd\"}"}, + } + callsB := []providertypes.ToolCall{ + {Name: "filesystem_read_file", Arguments: "{\"opts\":{\"x\":1,\"y\":[2,3]},\"path\":\"/tmp/a.txt\"}"}, + {Name: "bash", Arguments: "{ \"cmd\" : \"pwd\" }"}, + } + sigA := computeToolSignature(callsA) + sigB := computeToolSignature(callsB) + if sigA != sigB { + t.Fatalf("expected canonicalized signatures to match, got %q vs %q", sigA, sigB) + } + + invalidA := []providertypes.ToolCall{{Name: "bash", Arguments: "{\"cmd\":"}} + invalidB := []providertypes.ToolCall{{Name: "bash", Arguments: "{\"cmd\":\"ls\"}"}} + if computeToolSignature(invalidA) == computeToolSignature(invalidB) { + t.Fatal("expected invalid-json fallback signature to differ from valid-json signature") + } +} + +func TestPrepareTurnSnapshotInjectRepeatReminderWithEmptyPrompt(t *testing.T) { + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Runtime.MaxRepeatCycleStreak = 3 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + service := &Service{ + configManager: manager, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{SystemPrompt: "", Messages: input.Messages}, nil + }, + }, + toolManager: &stubToolManager{}, + } + state := newRunState("run-repeat-reminder-empty", newRuntimeSession("session-repeat-reminder-empty")) + state.progress.LastScore.RepeatCycleStreak = 2 + + snapshot, rebuilt, err := 
service.prepareTurnSnapshot(context.Background(), &state) + if err != nil { + t.Fatalf("prepareTurnSnapshot() error = %v", err) + } + if rebuilt { + t.Fatal("expected rebuilt=false") + } + if snapshot.request.SystemPrompt != selfHealingRepeatReminder { + t.Fatalf("expected repeat reminder only, got %q", snapshot.request.SystemPrompt) + } +} + +func TestPrepareTurnSnapshotRepeatReminderTakesPriority(t *testing.T) { + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Runtime.MaxNoProgressStreak = 3 + cfg.Runtime.MaxRepeatCycleStreak = 3 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + service := &Service{ + configManager: manager, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{SystemPrompt: "base prompt", Messages: input.Messages}, nil + }, + }, + toolManager: &stubToolManager{}, + } + state := newRunState("run-reminder-priority", newRuntimeSession("session-reminder-priority")) + state.progress.LastScore.NoProgressStreak = 2 + state.progress.LastScore.RepeatCycleStreak = 2 + + snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + if err != nil { + t.Fatalf("prepareTurnSnapshot() error = %v", err) + } + if rebuilt { + t.Fatal("expected rebuilt=false") + } + if !strings.Contains(snapshot.request.SystemPrompt, selfHealingRepeatReminder) { + t.Fatalf("expected prompt to contain repeat reminder, got %q", snapshot.request.SystemPrompt) + } + if strings.Contains(snapshot.request.SystemPrompt, selfHealingReminder) { + t.Fatalf("expected no-progress reminder to be skipped when repeat reminder is injected, got %q", snapshot.request.SystemPrompt) + } +} + +func TestResolveStreakLimitDefaults(t *testing.T) { + if got := resolveNoProgressStreakLimit(config.RuntimeConfig{MaxNoProgressStreak: 0}); got != 
config.DefaultMaxNoProgressStreak { + t.Fatalf("expected default no-progress limit %d, got %d", config.DefaultMaxNoProgressStreak, got) + } + if got := resolveNoProgressStreakLimit(config.RuntimeConfig{MaxNoProgressStreak: 8}); got != 8 { + t.Fatalf("expected explicit no-progress limit 8, got %d", got) + } + + if got := resolveRepeatCycleStreakLimit(config.RuntimeConfig{MaxRepeatCycleStreak: -1}); got != config.DefaultMaxRepeatCycleStreak { + t.Fatalf("expected default repeat limit %d, got %d", config.DefaultMaxRepeatCycleStreak, got) + } + if got := resolveRepeatCycleStreakLimit(config.RuntimeConfig{MaxRepeatCycleStreak: 6}); got != 6 { + t.Fatalf("expected explicit repeat limit 6, got %d", got) + } +} + +func assertStopReasonDecided(t *testing.T, events []RuntimeEvent, wantReason controlplane.StopReason, wantDetail string) { + t.Helper() + assertEventContains(t, events, EventStopReasonDecided) for _, e := range events { - if e.Type == EventStopReasonDecided { - payload := e.Payload.(StopReasonDecidedPayload) - if payload.Reason != controlplane.StopReasonSuccess { - t.Fatalf("expected stop reason success, got %s", payload.Reason) - } + if e.Type != EventStopReasonDecided { + continue + } + payload := e.Payload.(StopReasonDecidedPayload) + if payload.Reason != wantReason { + t.Fatalf("expected stop reason %s, got %s", wantReason, payload.Reason) + } + if wantDetail != "" && payload.Detail != wantDetail { + t.Fatalf("expected detail to be %q, got %q", wantDetail, payload.Detail) } } } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 7de08a4f..c03b8a2a 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -39,6 +39,12 @@ type failingStore struct { ignoreContextErr bool } +type autoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) + +func (f autoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + 
return f(ctx, cfg) +} + func newMemoryStore() *memoryStore { return &memoryStore{sessions: map[string]agentsession.Session{}} } @@ -545,6 +551,76 @@ func TestServiceRun(t *testing.T) { } }, }, + { + name: "metadata-only tool result is projected on follow-up provider round", + input: UserInput{RunID: "run-tool-metadata-only", Content: "inspect file"}, + providerStreams: [][]providertypes.StreamEvent{ + { + providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_read_file"), + providertypes.NewToolCallDeltaStreamEvent(0, "call-1", `{"path":"README.md"}`), + }, + { + providertypes.NewTextDeltaStreamEvent("done"), + }, + }, + registerTool: &stubTool{ + name: "filesystem_read_file", + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + return tools.ToolResult{ + Name: "filesystem_read_file", + Content: "", + Metadata: map[string]any{ + "path": "README.md", + }, + }, nil + }, + }, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{ + SystemPrompt: "stub system prompt", + Messages: projectToolMessagesForProviderTest(input.Messages), + }, nil + }, + }, + expectProviderCalls: 2, + expectToolCalls: 1, + expectMessageRoles: []string{"user", "assistant", "tool", "assistant"}, + expectEventTypes: []EventType{EventUserMessage, EventToolStart, EventToolResult, EventAgentDone}, + assert: func(t *testing.T, store *memoryStore, scripted *scriptedProvider, tool *stubTool) { + t.Helper() + if len(scripted.requests) != 2 { + t.Fatalf("expected 2 provider requests, got %d", len(scripted.requests)) + } + second := scripted.requests[1] + foundToolResult := false + for _, message := range second.Messages { + if message.Role == providertypes.RoleTool && + message.ToolCallID == "call-1" && + strings.Contains(message.Content, "tool result") && + strings.Contains(message.Content, "tool: filesystem_read_file") && + 
strings.Contains(message.Content, "meta.path: README.md") { + foundToolResult = true + if strings.Contains(message.Content, "content:\n") { + t.Fatalf("expected metadata-only projection to omit content section, got %q", message.Content) + } + break + } + } + if !foundToolResult { + t.Fatalf("expected projected metadata-only tool result in second provider request: %+v", second.Messages) + } + + session := onlySession(t, store) + if session.Messages[2].Role != providertypes.RoleTool || session.Messages[2].Content != "" { + t.Fatalf("expected persisted tool message to keep empty raw content, got %+v", session.Messages[2]) + } + if session.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" || + session.Messages[2].ToolMetadata["path"] != "README.md" { + t.Fatalf("expected persisted metadata-only tool message to keep sanitized metadata, got %+v", session.Messages[2].ToolMetadata) + } + }, + }, } for _, tt := range tests { @@ -1686,7 +1762,7 @@ func TestServiceRunErrorPaths(t *testing.T) { { ID: fmt.Sprintf("loop-call-%d", i), Name: "filesystem_edit", - Arguments: `{"path":"x"}`, + Arguments: fmt.Sprintf(`{"path":"x", "iteration": %d}`, i), }, }, }, @@ -3031,17 +3107,7 @@ func cloneBuildInput(input agentcontext.BuildInput) agentcontext.BuildInput { // projectToolMessagesForProviderTest mimics the read-only projection that the context layer applies to tool messages before a provider request. func projectToolMessagesForProviderTest(messages []providertypes.Message) []providertypes.Message { - projected := append([]providertypes.Message(nil), messages...) 
- for i, message := range projected { - if message.Role != providertypes.RoleTool || len(message.ToolMetadata) == 0 { - continue - } - next := message - next.Content = tools.FormatToolMessageForModel(message) - next.ToolMetadata = nil - projected[i] = next - } - return projected + return agentcontext.ProjectToolMessagesForModel(cloneMessages(messages)) } func containsError(err error, target string) bool { @@ -4126,6 +4192,7 @@ func TestRestoreSessionTokensNewSession(t *testing.T) { func TestAutoCompactThresholdEnabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4135,7 +4202,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 50000 { t.Fatalf("expected threshold == 50000, got %d", threshold) } @@ -4144,6 +4211,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { func TestAutoCompactThresholdDisabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4153,7 +4221,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", threshold) } @@ -4162,6 +4230,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { func TestAutoCompactThresholdZeroValue(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4171,12 +4240,213 @@ func TestAutoCompactThresholdZeroValue(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", 
threshold) } } +func TestAutoCompactThresholdUsesResolver(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 88000, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected resolver threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverErrors(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, errors.New("resolver failed") + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverReturnsZeroWithoutError(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func 
TestAutoCompactThresholdFallsBackWhenResolverReturnsNegativeWithoutError(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return -1, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing.T) { + t.Parallel() + + service := &Service{} + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected implicit mode fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdForStateCachesResolverResultWithinRun(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-hit", newRuntimeSession("session-cache-hit")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + threshold2 := 
service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 88000 { + t.Fatalf("expected cached resolver threshold == 88000, got %d and %d", threshold1, threshold2) + } + if resolveCalls != 1 { + t.Fatalf("expected resolver to be called once, got %d", resolveCalls) + } +} + +func TestAutoCompactThresholdForStateRecomputesWhenCacheKeyChanges(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + if strings.TrimSpace(cfg.CurrentModel) == "gpt-5.1" { + return 99000, nil + } + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-miss", newRuntimeSession("session-cache-miss")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + cfg.CurrentModel = "gpt-5.1" + threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 99000 { + t.Fatalf("expected thresholds [88000, 99000], got [%d, %d]", threshold1, threshold2) + } + if resolveCalls != 2 { + t.Fatalf("expected resolver to be called twice after key change, got %d", resolveCalls) + } +} + func TestTokenUsageRecordedOnMessageDone(t *testing.T) { t.Parallel() diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index bf914917..81c5844b 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -9,6 +9,8 @@ import ( "neo-code/internal/tools" ) +const toolNameMetadataKey = "tool_name" + // appendUserMessageAndSave appends the user message to the session and persists it immediately. +func (s
*Service) appendUserMessageAndSave(ctx context.Context, state *runState, content string) error { message := providertypes.Message{ @@ -55,13 +57,7 @@ func (s *Service) appendToolMessageAndSave( result tools.ToolResult, ) error { state.mu.Lock() - toolMessage := providertypes.Message{ - Role: providertypes.RoleTool, - Content: result.Content, - ToolCallID: call.ID, - IsError: result.IsError, - ToolMetadata: tools.SanitizeToolMetadata(result.Name, result.Metadata), - } + toolMessage := normalizeToolMessageForPersistence(call, result) state.session.Messages = append(state.session.Messages, toolMessage) state.touchSession() sessionSnapshot := cloneSessionForPersistence(state.session) @@ -69,6 +65,38 @@ func (s *Service) appendToolMessageAndSave( return s.sessionStore.Save(ctx, &sessionSnapshot) } +// normalizeToolMessageForPersistence normalizes a tool result before it is written to the session, so that a successful result is never persisted as a message with no semantic content. +func normalizeToolMessageForPersistence(call providertypes.ToolCall, result tools.ToolResult) providertypes.Message { + toolName := strings.TrimSpace(result.Name) + if toolName == "" { + toolName = strings.TrimSpace(call.Name) + } + + sanitizedMetadata := tools.SanitizeToolMetadata(toolName, result.Metadata) + content := result.Content + if !result.IsError && strings.TrimSpace(content) == "" && !hasNonToolNameToolMetadata(sanitizedMetadata) { + content = "ok" + } + + return providertypes.Message{ + Role: providertypes.RoleTool, + Content: content, + ToolCallID: call.ID, + IsError: result.IsError, + ToolMetadata: sanitizedMetadata, + } +} + +// hasNonToolNameToolMetadata reports whether metadata contains any semantic field other than tool_name. +func hasNonToolNameToolMetadata(metadata map[string]string) bool { + for key := range metadata { + if key != toolNameMetadataKey { + return true + } + } + return false +} + // cloneSessionForPersistence copies the session snapshot so that persistence does not share mutable slices/maps with concurrent writers. func cloneSessionForPersistence(session agentsession.Session) agentsession.Session { cloned := session diff --git a/internal/runtime/state.go b/internal/runtime/state.go
index a8ae3c67..2c0974fb 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -21,6 +21,7 @@ type runState struct { session agentsession.Session compactApplied bool reactiveCompactAttempts int + autoCompactCache autoCompactThresholdCache rememberedThisRun bool turn int phase controlplane.Phase @@ -83,13 +84,16 @@ func (s *runState) markSkillMissingReported(skillID string) bool { } // turnSnapshot freezes the config, context, and provider request needed for a single inference turn. +// noProgressStreakLimit is resolved and stored once by prepareTurnSnapshot, so that within a turn the +// corrective-injection threshold and the circuit-breaker threshold come from the same config snapshot and a concurrent reload cannot make them diverge. type turnSnapshot struct { - config config.Config - providerConfig provider.RuntimeConfig - model string - workdir string - toolTimeout time.Duration - request providertypes.GenerateRequest + config config.Config + providerConfig provider.RuntimeConfig + model string + workdir string + toolTimeout time.Duration + noProgressStreakLimit int + request providertypes.GenerateRequest } // providerTurnResult represents the structured result of a successful single-turn provider call. @@ -98,3 +102,20 @@ type providerTurnResult struct { inputTokens int outputTokens int } + +// autoCompactThresholdCache holds the auto-compact threshold already resolved for the current run, avoiding repeated resolution on the hot path. +type autoCompactThresholdCache struct { + key autoCompactThresholdCacheKey + threshold int + valid bool +} + +// autoCompactThresholdCacheKey describes the key dimensions of the inputs to auto-compact threshold resolution. +type autoCompactThresholdCacheKey struct { + provider string + model string + autoCompactEnabled bool + autoCompactInputThreshold int + autoCompactReserveTokens int + autoCompactFallback int +} diff --git a/internal/runtime/subagent_factory.go b/internal/runtime/subagent_factory.go new file mode 100644 index 00000000..68662d0b --- /dev/null +++ b/internal/runtime/subagent_factory.go @@ -0,0 +1,97 @@ +package runtime + +import ( + "runtime" + "sync" + + "neo-code/internal/subagent" +) + +type subAgentFactoryRegistry struct { + mu sync.RWMutex + factory map[*Service]subagent.Factory + tracked map[*Service]struct{} +} + +var globalSubAgentFactories = &subAgentFactoryRegistry{ + factory:
make(map[*Service]subagent.Factory), + tracked: make(map[*Service]struct{}), +} + +// ensureTracked registers a GC finalizer for the Service so that the global registry does not leak entries. +func (r *subAgentFactoryRegistry) ensureTracked(s *Service) { + if s == nil { + return + } + + r.mu.Lock() + if _, ok := r.tracked[s]; ok { + r.mu.Unlock() + return + } + r.tracked[s] = struct{}{} + r.mu.Unlock() + + runtime.SetFinalizer(s, func(service *Service) { + globalSubAgentFactories.mu.Lock() + delete(globalSubAgentFactories.factory, service) + delete(globalSubAgentFactories.tracked, service) + globalSubAgentFactories.mu.Unlock() + }) +} + +// set stores the per-Service factory instance. +func (r *subAgentFactoryRegistry) set(s *Service, f subagent.Factory) { + if s == nil { + return + } + r.ensureTracked(s) + + r.mu.Lock() + r.factory[s] = f + r.mu.Unlock() +} + +// get reads the per-Service factory instance. +func (r *subAgentFactoryRegistry) get(s *Service) (subagent.Factory, bool) { + if s == nil { + return nil, false + } + r.ensureTracked(s) + + r.mu.RLock() + factory, ok := r.factory[s] + r.mu.RUnlock() + return factory, ok +} + +// defaultSubAgentFactory returns the default sub-agent factory instance. +func defaultSubAgentFactory() subagent.Factory { + return subagent.NewWorkerFactory(nil) +} + +// SetSubAgentFactory sets the sub-agent runtime factory; passing nil falls back to the default factory. +func (s *Service) SetSubAgentFactory(factory subagent.Factory) { + if s == nil { + return + } + if factory == nil { + globalSubAgentFactories.set(s, defaultSubAgentFactory()) + return + } + globalSubAgentFactories.set(s, factory) +} + +// SubAgentFactory returns the sub-agent runtime factory currently held by this runtime. +func (s *Service) SubAgentFactory() subagent.Factory { + if s == nil { + return defaultSubAgentFactory() + } + if factory, ok := globalSubAgentFactories.get(s); ok && factory != nil { + return factory + } + + defaultFactory := defaultSubAgentFactory() + globalSubAgentFactories.set(s, defaultFactory) + return defaultFactory +} diff --git a/internal/runtime/subagent_factory_test.go b/internal/runtime/subagent_factory_test.go new file mode 100644 index 00000000..9869ba45 ---
/dev/null +++ b/internal/runtime/subagent_factory_test.go @@ -0,0 +1,54 @@ +package runtime + +import ( + "testing" + + "neo-code/internal/subagent" +) + +type fakeSubAgentFactory struct{} + +func (fakeSubAgentFactory) Create(role subagent.Role) (subagent.WorkerRuntime, error) { + return subagent.NewWorkerFactory(nil).Create(role) +} + +func TestServiceSubAgentFactoryRegistration(t *testing.T) { + t.Parallel() + + svc := NewWithFactory(nil, nil, nil, nil, nil) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected default sub-agent factory") + } + + custom := fakeSubAgentFactory{} + svc.SetSubAgentFactory(custom) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected custom sub-agent factory") + } + + svc.SetSubAgentFactory(nil) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected reset to default sub-agent factory") + } +} + +func TestServiceSubAgentFactoryIsolationAcrossInstances(t *testing.T) { + t.Parallel() + + svcA := NewWithFactory(nil, nil, nil, nil, nil) + svcB := NewWithFactory(nil, nil, nil, nil, nil) + + custom := fakeSubAgentFactory{} + svcA.SetSubAgentFactory(custom) + + if svcA.SubAgentFactory() == nil { + t.Fatalf("expected service A factory to be set") + } + if svcB.SubAgentFactory() == nil { + t.Fatalf("expected service B default factory") + } + + if svcA.SubAgentFactory() == svcB.SubAgentFactory() { + t.Fatalf("expected per-service factory isolation") + } +} diff --git a/internal/runtime/subagent_run.go b/internal/runtime/subagent_run.go new file mode 100644 index 00000000..31bbd2b6 --- /dev/null +++ b/internal/runtime/subagent_run.go @@ -0,0 +1,187 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "neo-code/internal/runtime/controlplane" + "neo-code/internal/subagent" +) + +// SubAgentTaskInput describes a single sub-agent task execution request. +type SubAgentTaskInput struct { + RunID string + SessionID string + Role subagent.Role + Task subagent.Task + Budget subagent.Budget + Capability subagent.Capability +} + +// RunSubAgentTask
runs one sub-agent task using the factory registered on the current runtime. +func (s *Service) RunSubAgentTask(ctx context.Context, input SubAgentTaskInput) (subagent.Result, error) { + if err := ctx.Err(); err != nil { + return subagent.Result{}, err + } + if strings.TrimSpace(input.RunID) == "" { + return subagent.Result{}, errors.New("runtime: subagent run id is empty") + } + if !input.Role.Valid() { + return subagent.Result{}, fmt.Errorf("runtime: invalid subagent role %q", input.Role) + } + if err := input.Task.Validate(); err != nil { + return subagent.Result{}, err + } + + factory := s.SubAgentFactory() + worker, err := factory.Create(input.Role) + if err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + + if err := worker.Start(input.Task, input.Budget, input.Capability); err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + + _ = s.emit(ctx, EventSubAgentStarted, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: worker.State(), + }) + + for { + stepResult, stepErr := worker.Step(ctx) + if stepResult.State == "" { + stepResult.State = worker.State() + } + emitSubAgentProgress(s, input, stepResult, stepErr) + + if stepErr != nil { + if errors.Is(stepErr, context.Canceled) || errors.Is(stepErr, context.DeadlineExceeded) { + _ = worker.Stop(subagent.StopReasonCanceled) + result, resultErr := worker.Result() + if resultErr != nil { + result = subagent.Result{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateCanceled, + StopReason: subagent.StopReasonCanceled, + Error: errorText(stepErr), + } + } + emitSubAgentTerminal(s, ctx, input, result) + return result,
stepErr + } + + result, resultErr := worker.Result() + if resultErr != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: stepErr.Error(), + }) + return subagent.Result{}, stepErr + } + emitSubAgentTerminal(s, ctx, input, result) + return result, stepErr + } + + if !stepResult.Done { + continue + } + + result, err := worker.Result() + if err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + emitSubAgentTerminal(s, ctx, input, result) + if result.State == subagent.StateSucceeded { + return result, nil + } + return result, subAgentResultError(result) + } +} + +// emitSubAgentProgress emits progress events without blocking, so a slow consumer cannot apply backpressure to the execution path. +func emitSubAgentProgress(s *Service, input SubAgentTaskInput, stepResult subagent.StepResult, stepErr error) { + payload := SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: stepResult.State, + Step: stepResult.Step, + Delta: stepResult.Delta, + Error: errorText(stepErr), + } + event := RuntimeEvent{ + Type: EventSubAgentProgress, + RunID: input.RunID, + SessionID: input.SessionID, + Turn: turnUnspecified, + Timestamp: time.Now(), + PayloadVersion: controlplane.PayloadVersion, + Payload: payload, + } + select { + case s.events <- event: + default: + } +} + +// emitSubAgentTerminal emits the final event that matches the sub-agent's terminal state. +func emitSubAgentTerminal(s *Service, ctx context.Context, input SubAgentTaskInput, result subagent.Result) { + payload := SubAgentEventPayload{ + Role: result.Role, + TaskID: result.TaskID, + State: result.State, + StopReason: result.StopReason, + Step: result.StepCount, + Error: strings.TrimSpace(result.Error), + } + + switch result.State { + case subagent.StateSucceeded: + _ = s.emit(ctx, EventSubAgentCompleted, input.RunID,
input.SessionID, payload) + case subagent.StateCanceled: + _ = s.emit(ctx, EventSubAgentCanceled, input.RunID, input.SessionID, payload) + default: + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, payload) + } +} + +// errorText safely converts an error into text usable in events. +func errorText(err error) string { + if err == nil { + return "" + } + return strings.TrimSpace(err.Error()) +} + +// subAgentResultError converts a terminal sub-agent result into a diagnosable error, so an empty error text does not lose context. +func subAgentResultError(result subagent.Result) error { + if text := strings.TrimSpace(result.Error); text != "" { + return errors.New(text) + } + return fmt.Errorf("subagent ended with state=%s stop_reason=%s", result.State, result.StopReason) +} diff --git a/internal/runtime/subagent_run_test.go b/internal/runtime/subagent_run_test.go new file mode 100644 index 00000000..96f20f81 --- /dev/null +++ b/internal/runtime/subagent_run_test.go @@ -0,0 +1,356 @@ +package runtime + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "neo-code/internal/runtime/controlplane" + "neo-code/internal/subagent" +) + +type failingSubAgentFactory struct { + err error +} + +func (f failingSubAgentFactory) Create(role subagent.Role) (subagent.WorkerRuntime, error) { + return nil, f.err +} + +func TestServiceRunSubAgentTaskSuccess(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + if input.StepIndex == 1 { + return subagent.StepOutput{ + Delta: "step-1", + Done: false, + }, nil + } + return subagent.StepOutput{ + Delta: "step-2", + Done: true, + Output: subagent.Output{ + Summary: "task completed", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + }) + })) + +
result, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-success", + SessionID: "session-1", + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task-1", + Goal: "implement feature", + }, + Budget: subagent.Budget{ + MaxSteps: 3, + }, + }) + if err != nil { + t.Fatalf("RunSubAgentTask() error = %v", err) + } + if result.State != subagent.StateSucceeded { + t.Fatalf("result state = %q, want %q", result.State, subagent.StateSucceeded) + } + if result.StepCount != 2 { + t.Fatalf("result step count = %d, want 2", result.StepCount) + } + + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentProgress, + EventSubAgentCompleted, + }) + assertEventsRunID(t, events, "sub-run-success") + for _, evt := range events { + if evt.Type != EventSubAgentProgress { + continue + } + if evt.Timestamp.IsZero() { + t.Fatalf("progress event timestamp should be set: %+v", evt) + } + if evt.PayloadVersion != controlplane.PayloadVersion { + t.Fatalf("progress event payload version = %d, want %d", evt.PayloadVersion, controlplane.PayloadVersion) + } + } +} + +func TestServiceRunSubAgentTaskFailureFlows(t *testing.T) { + t.Parallel() + + t.Run("factory create failed", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(failingSubAgentFactory{err: errors.New("create failed")}) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-factory-failed", + Role: subagent.RoleResearcher, + Task: subagent.Task{ + ID: "task-f", + Goal: "research", + }, + }) + if err == nil { + t.Fatalf("expected create error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{EventSubAgentFailed}) + }) + + t.Run("worker step failed", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, 
nil, nil) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + return subagent.StepOutput{}, errors.New("step failed") + }) + })) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-step-failed", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-step-f", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected step error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentFailed, + }) + }) + + t.Run("context canceled should emit canceled", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + cancel() + <-ctx.Done() + return subagent.StepOutput{}, ctx.Err() + }) + })) + + result, err := service.RunSubAgentTask(ctx, SubAgentTaskInput{ + RunID: "sub-run-canceled", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-cancel", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected canceled error") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("error = %v, want context canceled", err) + } + if result.State != subagent.StateCanceled { + t.Fatalf("result state = %q, want %q", result.State, subagent.StateCanceled) + } + if result.StopReason != subagent.StopReasonCanceled { + t.Fatalf("stop reason = %q, want %q", result.StopReason, subagent.StopReasonCanceled) + } + events := collectRuntimeEvents(service.Events()) + 
assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentCanceled, + }) + }) + + t.Run("worker start failed by disallowed capability", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-start-failed", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-start-failed", + Goal: "review", + }, + Capability: subagent.Capability{ + AllowedTools: []string{"bash"}, + }, + }) + if err == nil { + t.Fatalf("expected start error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{EventSubAgentFailed}) + }) + + t.Run("custom worker failed without explicit error should return fallback", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(stubSubAgentFactory{ + create: func(role subagent.Role) (subagent.WorkerRuntime, error) { + return &stubSubAgentWorker{ + result: subagent.Result{ + Role: role, + TaskID: "task-fallback-error", + State: subagent.StateFailed, + StopReason: subagent.StopReasonError, + }, + stepResult: subagent.StepResult{ + State: subagent.StateFailed, + Done: true, + Step: 1, + }, + }, nil + }, + }) + + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-fallback-error", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-fallback-error", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected fallback error") + } + if !strings.Contains(err.Error(), "state=failed") || !strings.Contains(err.Error(), "stop_reason=error") { + t.Fatalf("error = %q, want state/stop_reason fallback", err.Error()) + } + }) +} + +type stubSubAgentFactory struct { + create func(role subagent.Role) (subagent.WorkerRuntime, error) +} + +func (s stubSubAgentFactory) Create(role subagent.Role) 
(subagent.WorkerRuntime, error) { + return s.create(role) +} + +type stubSubAgentWorker struct { + startErr error + stepResult subagent.StepResult + stepErr error + result subagent.Result + resultErr error + current subagent.State + stopInvoked bool +} + +func (s *stubSubAgentWorker) Start(task subagent.Task, budget subagent.Budget, capability subagent.Capability) error { + if s.startErr != nil { + return s.startErr + } + if s.current == "" { + s.current = subagent.StateRunning + } + s.result.Role = firstNonEmptyRole(s.result.Role, subagent.RoleReviewer) + if strings.TrimSpace(s.result.TaskID) == "" { + s.result.TaskID = task.ID + } + return nil +} + +func (s *stubSubAgentWorker) Step(ctx context.Context) (subagent.StepResult, error) { + if s.stepResult.State == "" { + s.stepResult.State = s.current + } + if s.stepResult.Done { + s.current = s.result.State + } + return s.stepResult, s.stepErr +} + +func (s *stubSubAgentWorker) Stop(reason subagent.StopReason) error { + s.stopInvoked = true + s.current = subagent.StateCanceled + s.result.State = subagent.StateCanceled + s.result.StopReason = reason + return nil +} + +func (s *stubSubAgentWorker) Result() (subagent.Result, error) { + return s.result, s.resultErr +} + +func (s *stubSubAgentWorker) State() subagent.State { + if s.current == "" { + return subagent.StateIdle + } + return s.current +} + +func (s *stubSubAgentWorker) Policy() subagent.RolePolicy { + return subagent.RolePolicy{} +} + +func firstNonEmptyRole(role subagent.Role, fallback subagent.Role) subagent.Role { + if role != "" { + return role + } + return fallback +} + +func TestServiceRunSubAgentTaskInputValidation(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + if _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected empty run id error") + } + + if _, err := 
service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-invalid-role", + Role: subagent.Role("x"), + Task: subagent.Task{ + ID: "task", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected invalid role error") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(2 * time.Millisecond) + if _, err := service.RunSubAgentTask(ctx, SubAgentTaskInput{ + RunID: "sub-run-timeout", + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task-timeout", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected context error") + } +} diff --git a/internal/session/store_test.go b/internal/session/store_test.go index e08710b1..ec1c9834 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -643,6 +643,56 @@ func TestJSONStoreSavePersistsProviderModelAndMessages(t *testing.T) { } } +func TestJSONStoreSaveRoundTripsMetadataOnlyToolMessage(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + workspaceRoot := t.TempDir() + store := NewJSONStore(baseDir, workspaceRoot) + + session := &Session{ + SchemaVersion: CurrentSchemaVersion, + ID: "metadata-only-tool-message", + Title: "Metadata Only Tool Message", + CreatedAt: time.Now().Add(-time.Hour), + UpdatedAt: time.Now(), + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "inspect"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: "", + ToolMetadata: map[string]string{ + "tool_name": "filesystem_read_file", + "path": "README.md", + }, + }, + }, + } + + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("save session: %v", err) + } + + loaded, err := store.Load(context.Background(), session.ID) + if err != nil { + t.Fatalf("load saved session: %v", err) + } + 
if loaded.Messages[2].Content != "" { + t.Fatalf("expected empty content to round-trip, got %q", loaded.Messages[2].Content) + } + if loaded.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" || + loaded.Messages[2].ToolMetadata["path"] != "README.md" { + t.Fatalf("expected metadata-only tool message round-trip, got %+v", loaded.Messages[2].ToolMetadata) + } +} + func TestDecodeStoredSummaryUsesLightweightMetadataPath(t *testing.T) { t.Parallel() diff --git a/internal/subagent/engine.go b/internal/subagent/engine.go new file mode 100644 index 00000000..b8e859d8 --- /dev/null +++ b/internal/subagent/engine.go @@ -0,0 +1,34 @@ +package subagent + +import ( + "context" + "strings" +) + +// defaultEngine provides a runnable default engine that completes in a single step, for quick wiring in factories and tests. +type defaultEngine struct{} + +// RunStep runs the default single-step logic and returns a structured result. +func (defaultEngine) RunStep(ctx context.Context, input StepInput) (StepOutput, error) { + if err := ctx.Err(); err != nil { + return StepOutput{}, err + } + + summary := strings.TrimSpace(input.Task.ExpectedOutput) + if summary == "" { + summary = strings.TrimSpace(input.Task.Goal) + } + + return StepOutput{ + Delta: "default engine completed", + Done: true, + Output: Output{ + Summary: summary, + Findings: []string{"default-engine: no extra findings"}, + Patches: []string{"default-engine: no code changes"}, + Risks: []string{"default-engine: output generated without external verification"}, + NextActions: []string{"replace default engine with role-specific implementation if deeper execution is required"}, + Artifacts: []string{"default-engine"}, + }, + }, nil +} diff --git a/internal/subagent/engine_test.go b/internal/subagent/engine_test.go new file mode 100644 index 00000000..9a11048f --- /dev/null +++ b/internal/subagent/engine_test.go @@ -0,0 +1,89 @@ +package subagent + +import ( + "context" + "testing" +) + +func TestDefaultEngineRunStep(t *testing.T) { + t.Parallel() + + engine := defaultEngine{} + + t.Run("uses expected output as summary", func(t
*testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal", + ExpectedOutput: "expected", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + if !out.Done { + t.Fatalf("expected done output") + } + if out.Output.Summary != "expected" { + t.Fatalf("summary = %q, want %q", out.Output.Summary, "expected") + } + if len(out.Output.Findings) == 0 || len(out.Output.Patches) == 0 || len(out.Output.Risks) == 0 { + t.Fatalf("default engine should populate required list sections, got %+v", out.Output) + } + if len(out.Output.NextActions) == 0 || len(out.Output.Artifacts) == 0 { + t.Fatalf("default engine should populate required sections, got %+v", out.Output) + } + }) + + t.Run("falls back to goal", func(t *testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal-value", + ExpectedOutput: " ", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + if out.Output.Summary != "goal-value" { + t.Fatalf("summary = %q, want %q", out.Output.Summary, "goal-value") + } + }) + + t.Run("context canceled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := engine.RunStep(ctx, StepInput{Task: Task{Goal: "g"}}); err == nil { + t.Fatalf("expected context error") + } + }) + + t.Run("satisfies default role contract", func(t *testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal", + ExpectedOutput: "summary", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + + for _, role := range []Role{RoleResearcher, RoleCoder, RoleReviewer} { + policy, err := DefaultRolePolicy(role) + if err != nil { + t.Fatalf("DefaultRolePolicy(%q) error = %v", role, err) + } + if err := validateOutputContract(policy, out.Output); err != nil { + t.Fatalf("validateOutputContract(%q) error = %v", 
role, err) + } + } + }) +} diff --git a/internal/subagent/factory.go b/internal/subagent/factory.go new file mode 100644 index 00000000..3b4b19ba --- /dev/null +++ b/internal/subagent/factory.go @@ -0,0 +1,36 @@ +package subagent + +import "time" + +// EngineBuilder defines a factory function that builds an execution engine from a role policy. +type EngineBuilder func(role Role, policy RolePolicy) Engine + +// WorkerFactory is the default WorkerRuntime factory implementation. +type WorkerFactory struct { + builder EngineBuilder +} + +// NewWorkerFactory creates a WorkerFactory; when builder is nil the default engine is used. +func NewWorkerFactory(builder EngineBuilder) *WorkerFactory { + return &WorkerFactory{builder: builder} +} + +// Create creates a WorkerRuntime with the policy and execution engine for the given role. +func (f *WorkerFactory) Create(role Role) (WorkerRuntime, error) { + policy, err := DefaultRolePolicy(role) + if err != nil { + return nil, err + } + // Safety net: ensure the default budget is always usable. + policy.DefaultBudget = policy.DefaultBudget.normalize(Budget{ + MaxSteps: defaultPolicyMaxSteps, + Timeout: defaultPolicyTimeout * time.Second, + }) + + var engine Engine + if f != nil && f.builder != nil { + engine = f.builder(role, policy) + } + return NewWorker(role, policy, engine) +} + diff --git a/internal/subagent/output_contract.go b/internal/subagent/output_contract.go new file mode 100644 index 00000000..5b7514ea --- /dev/null +++ b/internal/subagent/output_contract.go @@ -0,0 +1,61 @@ +package subagent + +import "strings" + +var supportedOutputSections = map[string]struct{}{ + "summary": {}, + "findings": {}, + "patches": {}, + "risks": {}, + "next_actions": {}, + "artifacts": {}, +} + +// validateOutputContract checks whether the output structure satisfies the role policy's requirements. +func validateOutputContract(policy RolePolicy, output Output) error { + out := output.normalize() + requiredSections, err := normalizeRequiredSections(policy.RequiredSections) + if err != nil { + return err + } + for _, key := range requiredSections { + if !hasOutputSectionContent(out, key) { + return errorsf("output section %q is required", key) + } + } + return nil +} + +// normalizeRequiredSections normalizes and validates the set of
required section names. +func normalizeRequiredSections(sections []string) ([]string, error) { + items := dedupeAndTrim(sections) + keys := make([]string, 0, len(items)) + for _, section := range items { + key := strings.ToLower(strings.TrimSpace(section)) + if _, ok := supportedOutputSections[key]; !ok { + return nil, errorsf("unsupported required output section %q", section) + } + keys = append(keys, key) + } + return keys, nil +} + +// hasOutputSectionContent reports whether the given section has non-empty content in the output. +func hasOutputSectionContent(output Output, section string) bool { + switch section { + case "summary": + return strings.TrimSpace(output.Summary) != "" + case "findings": + return len(output.Findings) > 0 + case "patches": + return len(output.Patches) > 0 + case "risks": + return len(output.Risks) > 0 + case "next_actions": + return len(output.NextActions) > 0 + case "artifacts": + return len(output.Artifacts) > 0 + default: + return false + } +} diff --git a/internal/subagent/output_contract_test.go b/internal/subagent/output_contract_test.go new file mode 100644 index 00000000..3971c2d0 --- /dev/null +++ b/internal/subagent/output_contract_test.go @@ -0,0 +1,74 @@ +package subagent + +import "testing" + +func TestValidateOutputContractRequiresAllDeclaredSections(t *testing.T) { + t.Parallel() + + policy := RolePolicy{ + Role: RoleCoder, + SystemPrompt: "prompt", + AllowedTools: []string{"bash"}, + RequiredSections: []string{"summary", "findings", "patches", "risks", "next_actions", "artifacts"}, + } + + full := Output{ + Summary: "ok", + Findings: []string{"f"}, + Patches: []string{"p"}, + Risks: []string{"r"}, + NextActions: []string{"n"}, + Artifacts: []string{"a"}, + } + if err := validateOutputContract(policy, full); err != nil { + t.Fatalf("validateOutputContract() error = %v", err) + } + + cases := []struct { + name string + output Output + }{ + {name: "missing summary", output: Output{ + Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing findings", output: Output{ + Summary: "ok", Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing patches", output: Output{ + Summary: "ok", Findings: []string{"f"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing risks", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing next actions", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, Artifacts: []string{"a"}, + }}, + {name: "missing artifacts", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, NextActions:
[]string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing findings", output: Output{ + Summary: "ok", Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing patches", output: Output{ + Summary: "ok", Findings: []string{"f"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing risks", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing next actions", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, Artifacts: []string{"a"}, + }}, + {name: "missing artifacts", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, + }}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if err := validateOutputContract(policy, tc.output); err == nil { + t.Fatalf("expected contract validation error") + } + }) + } +} + +func TestValidateOutputContractUnsupportedSection(t *testing.T) { + t.Parallel() + + policy := RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary", "x"}, + } + if err := validateOutputContract(policy, Output{Summary: "ok"}); err == nil { + t.Fatalf("expected unsupported section error") + } +} diff --git a/internal/subagent/policy.go b/internal/subagent/policy.go new file mode 100644 index 00000000..12c81544 --- /dev/null +++ b/internal/subagent/policy.go @@ -0,0 +1,86 @@ +package subagent + +import ( + "strings" + "time" +) + +const ( + defaultPolicyMaxSteps = 6 + defaultPolicyTimeout = 30 +) + +// RolePolicy defines the execution policy for a role. +type RolePolicy struct { + Role Role + SystemPrompt string + AllowedTools []string + DefaultBudget Budget + RequiredSections []string +} + +// Validate 
checks that the role policy is valid. +func (p RolePolicy) Validate() error { + if !p.Role.Valid() { + return errorsf("invalid policy role %q", p.Role) + } + if strings.TrimSpace(p.SystemPrompt) == "" { + return errorsf("role policy prompt is required") + } + if len(dedupeAndTrim(p.AllowedTools)) == 0 { + return errorsf("role policy allowed tools is empty") + } + if len(dedupeAndTrim(p.RequiredSections)) == 0 { + return errorsf("role policy required sections is empty") + } + if _, err := normalizeRequiredSections(p.RequiredSections); err != nil { + return err + } + return nil +} + +// DefaultRolePolicy returns the built-in policy for a role. +func DefaultRolePolicy(role Role) (RolePolicy, error) { + if !role.Valid() { + return RolePolicy{}, errorsf("unsupported role %q", role) + } + + policy := RolePolicy{ + Role: role, + DefaultBudget: Budget{ + MaxSteps: defaultPolicyMaxSteps, + Timeout: defaultPolicyTimeout * time.Second, + }, + RequiredSections: []string{ + "summary", + "findings", + "patches", + "risks", + "next_actions", + "artifacts", + }, + } + + switch role { + case RoleResearcher: + policy.SystemPrompt = "你是研究型子代理,负责检索证据并形成结论。" + policy.AllowedTools = []string{"filesystem_read_file", "filesystem_glob", "filesystem_grep", "webfetch"} + case RoleCoder: + policy.SystemPrompt = "你是实现型子代理,负责修改代码并给出验证结果。" + policy.AllowedTools = []string{ + "filesystem_read_file", + "filesystem_write_file", + "filesystem_edit", + "filesystem_glob", + "filesystem_grep", + "bash", + } + case RoleReviewer: + policy.SystemPrompt = "你是审查型子代理,负责识别缺陷、风险与测试缺口。" + policy.AllowedTools = []string{"filesystem_read_file", "filesystem_glob", "filesystem_grep"} + } + + policy.AllowedTools = dedupeAndTrim(policy.AllowedTools) + policy.RequiredSections = dedupeAndTrim(policy.RequiredSections) + return policy, nil +} diff --git a/internal/subagent/policy_test.go b/internal/subagent/policy_test.go new file mode 100644 index 00000000..08721d37 --- /dev/null +++ b/internal/subagent/policy_test.go @@ -0,0 +1,124 @@ +package subagent + +import "testing" + 
+func TestDefaultRolePolicy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role Role + }{ + {name: "researcher", role: RoleResearcher}, + {name: "coder", role: RoleCoder}, + {name: "reviewer", role: RoleReviewer}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(tt.role) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + if policy.Role != tt.role { + t.Fatalf("policy role = %q, want %q", policy.Role, tt.role) + } + if policy.DefaultBudget.MaxSteps <= 0 { + t.Fatalf("invalid max steps: %d", policy.DefaultBudget.MaxSteps) + } + if len(policy.AllowedTools) == 0 { + t.Fatalf("expected non-empty allowed tools") + } + if len(policy.RequiredSections) == 0 { + t.Fatalf("expected non-empty required sections") + } + if err := policy.Validate(); err != nil { + t.Fatalf("policy.Validate() error = %v", err) + } + }) + } +} + +func TestDefaultRolePolicyInvalidRole(t *testing.T) { + t.Parallel() + + if _, err := DefaultRolePolicy(Role("unknown")); err == nil { + t.Fatalf("expected invalid role error") + } +} + +func TestRolePolicyValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + policy RolePolicy + wantErr bool + }{ + { + name: "valid", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary"}, + }, + }, + { + name: "empty prompt", + policy: RolePolicy{ + Role: RoleResearcher, + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary"}, + }, + wantErr: true, + }, + { + name: "empty tools", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + RequiredSections: []string{"summary"}, + }, + wantErr: true, + }, + { + name: "invalid role", + policy: RolePolicy{ + Role: Role("x"), + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: 
[]string{"summary"}, + }, + wantErr: true, + }, + { + name: "unsupported required section", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"unknown_section"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.policy.Validate() + if tt.wantErr && err == nil { + t.Fatalf("expected error") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/internal/subagent/types.go b/internal/subagent/types.go new file mode 100644 index 00000000..9e3ba823 --- /dev/null +++ b/internal/subagent/types.go @@ -0,0 +1,224 @@ +package subagent + +import ( + "context" + "fmt" + "strings" + "time" +) + +// Role identifies the execution role of a subagent. +type Role string + +const ( + // RoleResearcher is used for retrieval and analysis tasks. + RoleResearcher Role = "researcher" + // RoleCoder is used for implementation and bug-fix tasks. + RoleCoder Role = "coder" + // RoleReviewer is used for review and acceptance tasks. + RoleReviewer Role = "reviewer" +) + +// Valid reports whether the role is supported. +func (r Role) Valid() bool { + switch r { + case RoleResearcher, RoleCoder, RoleReviewer: + return true + default: + return false + } +} + +// Budget describes the execution budget of a subagent. +type Budget struct { + MaxSteps int + Timeout time.Duration +} + +// normalize normalizes the budget and applies defaults. +func (b Budget) normalize(defaults Budget) Budget { + out := b + if out.MaxSteps <= 0 { + out.MaxSteps = defaults.MaxSteps + } + if out.MaxSteps <= 0 { + out.MaxSteps = 6 + } + if out.Timeout <= 0 { + out.Timeout = defaults.Timeout + } + if out.Timeout <= 0 { + out.Timeout = 30 * time.Second + } + return out +} + +// Capability describes the capability boundary available to a subagent at run time. +type Capability struct { + AllowedTools []string + AllowedPaths []string +} + +// normalize normalizes the capability lists and removes duplicates. +func (c Capability) normalize() Capability { + return Capability{ + AllowedTools: dedupeAndTrim(c.AllowedTools), + AllowedPaths: dedupeAndTrim(c.AllowedPaths), + } +} + +// Task is the input for a single subagent task. +type Task 
struct { + ID string + Goal string + ExpectedOutput string + Workspace string +} + +// Validate checks that the task input is valid. +func (t Task) Validate() error { + if strings.TrimSpace(t.ID) == "" { + return errorsf("task id is required") + } + if strings.TrimSpace(t.Goal) == "" { + return errorsf("task goal is required") + } + return nil +} + +// StopReason is the reason a subagent stopped. +type StopReason string + +const ( + // StopReasonCompleted indicates normal completion. + StopReasonCompleted StopReason = "completed" + // StopReasonCanceled indicates cancellation. + StopReasonCanceled StopReason = "canceled" + // StopReasonTimeout indicates that execution timed out. + StopReasonTimeout StopReason = "timeout" + // StopReasonMaxSteps indicates that the step limit was reached. + StopReasonMaxSteps StopReason = "max_steps" + // StopReasonError indicates an execution error. + StopReasonError StopReason = "error" +) + +// State is the lifecycle state of a subagent. +type State string + +const ( + // StateIdle indicates the subagent has not started yet. + StateIdle State = "idle" + // StateRunning indicates execution is in progress. + StateRunning State = "running" + // StateSucceeded indicates execution finished successfully. + StateSucceeded State = "succeeded" + // StateFailed indicates execution finished with a failure. + StateFailed State = "failed" + // StateCanceled indicates execution ended by cancellation. + StateCanceled State = "canceled" +) + +// Terminal reports whether the state is terminal. +func (s State) Terminal() bool { + switch s { + case StateSucceeded, StateFailed, StateCanceled: + return true + default: + return false + } +} + +// Output defines the standard structured output of a subagent. +type Output struct { + Summary string + Findings []string + Patches []string + Risks []string + NextActions []string + Artifacts []string +} + +// normalize normalizes the output, dropping duplicate and empty items. +func (o Output) normalize() Output { + o.Summary = strings.TrimSpace(o.Summary) + o.Findings = dedupeAndTrim(o.Findings) + o.Patches = dedupeAndTrim(o.Patches) + o.Risks = dedupeAndTrim(o.Risks) + o.NextActions = dedupeAndTrim(o.NextActions) + o.Artifacts = dedupeAndTrim(o.Artifacts) + return o +} + +// StepInput is the input for a single step. +type StepInput struct { + Role Role + Policy RolePolicy + Task Task + Budget Budget + Capability Capability + StepIndex int + Trace []string +} + +// StepOutput is the output of a single step. +type 
StepOutput struct { + Delta string + Done bool + Output Output +} + +// StepResult is the observable result of one Step call. +type StepResult struct { + State State + Done bool + Step int + Delta string +} + +// Result describes the structured result after a subagent finishes. +type Result struct { + Role Role + TaskID string + State State + StopReason StopReason + StartedAt time.Time + EndedAt time.Time + StepCount int + Budget Budget + Capability Capability + Output Output + Error string +} + +// Engine defines the single-step execution engine used by WorkerRuntime. +type Engine interface { + RunStep(ctx context.Context, input StepInput) (StepOutput, error) +} + +// EngineFunc allows a plain function to implement Engine. +type EngineFunc func(ctx context.Context, input StepInput) (StepOutput, error) + +// RunStep calls the underlying function. +func (f EngineFunc) RunStep(ctx context.Context, input StepInput) (StepOutput, error) { + return f(ctx, input) +} + +// WorkerRuntime defines the lifecycle interface for subagent execution. +type WorkerRuntime interface { + Start(task Task, budget Budget, capability Capability) error + Step(ctx context.Context) (StepResult, error) + Stop(reason StopReason) error + Result() (Result, error) + State() State + Policy() RolePolicy +} + +// Factory defines the runtime-side entry point for creating a WorkerRuntime. +type Factory interface { + Create(role Role) (WorkerRuntime, error) +} + +// errorsf builds an error with the common subagent module prefix. +func errorsf(format string, args ...any) error { + return fmt.Errorf("subagent: "+format, args...) 
+} diff --git a/internal/subagent/util.go b/internal/subagent/util.go new file mode 100644 index 00000000..a87acc25 --- /dev/null +++ b/internal/subagent/util.go @@ -0,0 +1,28 @@ +package subagent + +import "strings" + +// dedupeAndTrim trims whitespace from each item, drops empty items, and removes duplicates (case-insensitively) while preserving the original order. +func dedupeAndTrim(items []string) []string { + if len(items) == 0 { + return nil + } + result := make([]string, 0, len(items)) + seen := make(map[string]struct{}, len(items)) + for _, raw := range items { + item := strings.TrimSpace(raw) + if item == "" { + continue + } + key := strings.ToLower(item) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + result = append(result, item) + } + if len(result) == 0 { + return nil + } + return result +} diff --git a/internal/subagent/worker.go b/internal/subagent/worker.go new file mode 100644 index 00000000..18e9cf17 --- /dev/null +++ b/internal/subagent/worker.go @@ -0,0 +1,327 @@ +package subagent + +import ( + "context" + "errors" + "strings" + "sync" + "time" +) + +// worker is the default WorkerRuntime implementation; it encapsulates the lifecycle of a single task. +type worker struct { + mu sync.RWMutex + role Role + policy RolePolicy + engine Engine + state State + task Task + budget Budget + capability Capability + stepCount int + trace []string + startedAt time.Time + endedAt time.Time + stopReason StopReason + output Output + lastErr string +} + +const traceWindowSize = 16 + +// NewWorker creates a WorkerRuntime from the given role, policy, and engine. +func NewWorker(role Role, policy RolePolicy, engine Engine) (WorkerRuntime, error) { + if !role.Valid() { + return nil, errorsf("invalid role %q", role) + } + if policy.Role == "" { + policy.Role = role + } + if policy.Role != role { + return nil, errorsf("policy role %q does not match worker role %q", policy.Role, role) + } + if err := policy.Validate(); err != nil { + return nil, err + } + if engine == nil { + engine = defaultEngine{} + } + + return &worker{ + role: role, + policy: policy, + engine: engine, + state: StateIdle, + }, nil +} + +// Start initializes the task execution context and moves the worker into the running state. +func (w 
*worker) Start(task Task, budget Budget, capability Capability) error { + if w == nil { + return errors.New("subagent: worker is nil") + } + if err := task.Validate(); err != nil { + return err + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateIdle { + return errorsf("worker already started") + } + + w.task = task + w.budget = budget.normalize(w.policy.DefaultBudget) + effectiveCapability, err := bindCapabilityToPolicy(capability.normalize(), w.policy) + if err != nil { + return err + } + w.capability = effectiveCapability + w.trace = nil + w.stepCount = 0 + w.output = Output{} + w.lastErr = "" + w.stopReason = "" + w.startedAt = time.Now() + w.endedAt = time.Time{} + w.state = StateRunning + return nil +} + +// Step runs one engine step and records the terminal result when a stop condition is met. +func (w *worker) Step(ctx context.Context) (StepResult, error) { + if w == nil { + return StepResult{}, errors.New("subagent: worker is nil") + } + if err := ctx.Err(); err != nil { + return StepResult{}, err + } + + w.mu.Lock() + if w.state != StateRunning { + state := w.state + w.mu.Unlock() + return StepResult{}, errorsf("worker is not running, current state=%s", state) + } + if w.budget.Timeout > 0 && time.Since(w.startedAt) >= w.budget.Timeout { + result := w.finishLocked(StateFailed, StopReasonTimeout, Output{}, errorsf("worker timeout")) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, nil + } + if w.stepCount >= w.budget.MaxSteps { + result := w.finishLocked(StateFailed, StopReasonMaxSteps, Output{}, errorsf("worker reached max steps")) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, nil + } + + input := StepInput{ + Role: w.role, + Policy: w.policy, + Task: w.task, + Budget: w.budget, + Capability: w.capability, + StepIndex: w.stepCount + 1, + Trace: cloneRecentTrace(w.trace, traceWindowSize), + } + w.mu.Unlock() + + stepOutput, err := w.engine.RunStep(ctx, input) + if err != nil { + w.mu.Lock() + if errors.Is(err, 
context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + result := w.finishLocked(StateCanceled, StopReasonCanceled, Output{}, nil) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, err + } + result := w.finishLocked(StateFailed, StopReasonError, Output{}, err) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, err + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateRunning { + return StepResult{}, errorsf("worker is not running, current state=%s", w.state) + } + + w.stepCount++ + delta := strings.TrimSpace(stepOutput.Delta) + if delta != "" { + w.trace = appendTraceBounded(w.trace, delta, traceWindowSize) + } + + if stepOutput.Done { + if err := validateOutputContract(w.policy, stepOutput.Output); err != nil { + result := w.finishLocked(StateFailed, StopReasonError, Output{}, err) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, err + } + result := w.finishLocked(StateSucceeded, StopReasonCompleted, stepOutput.Output, nil) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, nil + } + + if w.stepCount >= w.budget.MaxSteps { + result := w.finishLocked(StateFailed, StopReasonMaxSteps, Output{}, errorsf("worker reached max steps")) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, nil + } + + return StepResult{ + State: StateRunning, + Done: false, + Step: w.stepCount, + Delta: delta, + }, nil +} + +// bindCapabilityToPolicy restricts the capability to the tool set allowed by the role policy. +func bindCapabilityToPolicy(capability Capability, policy RolePolicy) (Capability, error) { + if len(capability.AllowedPaths) > 0 { + return Capability{}, errorsf("capability allowed paths is not supported yet") + } + + allowedPolicyTools := dedupeAndTrim(policy.AllowedTools) + allowedSet := make(map[string]struct{}, len(allowedPolicyTools)) + for _, tool := range allowedPolicyTools { + 
allowedSet[strings.ToLower(strings.TrimSpace(tool))] = struct{}{} + } + + if len(capability.AllowedTools) == 0 { + capability.AllowedTools = append([]string(nil), allowedPolicyTools...) + return capability, nil + } + + effective := make([]string, 0, len(capability.AllowedTools)) + disallowed := make([]string, 0) + for _, tool := range capability.AllowedTools { + normalized := strings.ToLower(strings.TrimSpace(tool)) + if _, ok := allowedSet[normalized]; !ok { + disallowed = append(disallowed, tool) + continue + } + effective = append(effective, tool) + } + if len(disallowed) > 0 { + return Capability{}, errorsf("capability contains disallowed tools: %s", strings.Join(disallowed, ", ")) + } + capability.AllowedTools = effective + return capability, nil +} + +// appendTraceBounded appends a new trace entry to the end of the slice and keeps the stored length within the limit. +func appendTraceBounded(trace []string, delta string, limit int) []string { + trace = append(trace, delta) + if limit <= 0 || len(trace) <= limit { + return trace + } + start := len(trace) - limit + copy(trace, trace[start:]) + return trace[:limit] +} + +// cloneRecentTrace copies the most recent limit trace entries, so that each step does not copy the full history. +func cloneRecentTrace(trace []string, limit int) []string { + if len(trace) == 0 { + return nil + } + if limit <= 0 || len(trace) <= limit { + return append([]string(nil), trace...) + } + start := len(trace) - limit + return append([]string(nil), trace[start:]...) 
+} + +// Stop terminates a running worker and maps the stop reason to a final state. +func (w *worker) Stop(reason StopReason) error { + if w == nil { + return errors.New("subagent: worker is nil") + } + if strings.TrimSpace(string(reason)) == "" { + return errorsf("stop reason is required") + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateRunning { + if w.state.Terminal() { + return nil + } + return errorsf("worker is not running, current state=%s", w.state) + } + + switch reason { + case StopReasonCompleted: + if err := validateOutputContract(w.policy, w.output); err != nil { + return err + } + w.finishLocked(StateSucceeded, reason, w.output, nil) + case StopReasonCanceled: + w.finishLocked(StateCanceled, reason, w.output, nil) + case StopReasonTimeout, StopReasonMaxSteps, StopReasonError: + w.finishLocked(StateFailed, reason, w.output, nil) + default: + return errorsf("unsupported stop reason %q", reason) + } + return nil +} + +// Result returns the worker's final structured result. +func (w *worker) Result() (Result, error) { + if w == nil { + return Result{}, errors.New("subagent: worker is nil") + } + + w.mu.RLock() + defer w.mu.RUnlock() + if !w.state.Terminal() { + return Result{}, errorsf("worker is not finished") + } + return w.snapshotLocked(), nil +} + +// State returns the worker's current lifecycle state. +func (w *worker) State() State { + if w == nil { + return StateIdle + } + w.mu.RLock() + defer w.mu.RUnlock() + return w.state +} + +// Policy returns a copy of the worker's role policy. +func (w *worker) Policy() RolePolicy { + if w == nil { + return RolePolicy{} + } + w.mu.RLock() + defer w.mu.RUnlock() + return w.policy +} + +// finishLocked moves the worker into a terminal state and returns a result snapshot; the caller must hold the write lock. +func (w *worker) finishLocked(state State, reason StopReason, output Output, err error) Result { + w.state = state + w.stopReason = reason + w.endedAt = time.Now() + w.output = output.normalize() + if err != nil { + w.lastErr = err.Error() + } + return w.snapshotLocked() +} + +// snapshotLocked builds a stable result snapshot; the caller must hold the read or write lock. +func (w *worker) snapshotLocked() Result { + return Result{ + Role: 
w.role, + TaskID: w.task.ID, + State: w.state, + StopReason: w.stopReason, + StartedAt: w.startedAt, + EndedAt: w.endedAt, + StepCount: w.stepCount, + Budget: w.budget, + Capability: w.capability, + Output: w.output, + Error: w.lastErr, + } +} diff --git a/internal/subagent/worker_test.go b/internal/subagent/worker_test.go new file mode 100644 index 00000000..532e4c76 --- /dev/null +++ b/internal/subagent/worker_test.go @@ -0,0 +1,437 @@ +package subagent + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestWorkerLifecycleCompleted(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleCoder) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleCoder, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{ + Delta: "patched files", + Done: true, + Output: Output{ + Summary: "done", + Findings: []string{"root cause fixed"}, + Patches: []string{"a.go"}, + Risks: []string{"need integration verify"}, + NextActions: []string{"run tests"}, + Artifacts: []string{"test report"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + err = w.Start(Task{ID: "t1", Goal: "fix bug"}, Budget{MaxSteps: 3}, Capability{ + AllowedTools: []string{"bash", "bash", " "}, + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateSucceeded { + t.Fatalf("unexpected step result: %+v", step) + } + + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonCompleted { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonCompleted) + } + if result.StepCount != 1 { + t.Fatalf("step count = %d, want 1", result.StepCount) + } + if len(result.Capability.AllowedTools) != 1 { + t.Fatalf("expected capability 
dedupe, got %+v", result.Capability) + } +} + +func TestWorkerLifecycleFailures(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + t.Run("engine error", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{}, errors.New("boom") + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t2", Goal: "research"}, Budget{MaxSteps: 2}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected step error") + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.State != StateFailed || result.StopReason != StopReasonError { + t.Fatalf("unexpected result: %+v", result) + } + }) + + t.Run("max steps", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{Delta: "not done", Done: false}, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t3", Goal: "research"}, Budget{MaxSteps: 1}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateFailed { + t.Fatalf("expected first step to finish by max steps, got %+v", step) + } + + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonMaxSteps { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonMaxSteps) + } + }) + + t.Run("timeout", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, 
policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{Done: false}, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t4", Goal: "research"}, Budget{MaxSteps: 5, Timeout: time.Nanosecond}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + time.Sleep(2 * time.Millisecond) + + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateFailed { + t.Fatalf("unexpected timeout step: %+v", step) + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonTimeout { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonTimeout) + } + }) +} + +func TestWorkerStopAndGuards(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleReviewer) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + if _, err := w.Result(); err == nil { + t.Fatalf("expected result before finish to fail") + } + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected step before start to fail") + } + if err := w.Start(Task{ID: "review", Goal: "review"}, Budget{}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if err := w.Start(Task{ID: "review2", Goal: "review2"}, Budget{}, Capability{}); err == nil { + t.Fatalf("expected double start to fail") + } + if err := w.Stop(StopReasonCanceled); err != nil { + t.Fatalf("Stop() error = %v", err) + } + if w.State() != StateCanceled { + t.Fatalf("state = %q, want %q", w.State(), StateCanceled) + } + if err := w.Stop(StopReasonCanceled); err != nil { + t.Fatalf("terminal stop should be idempotent, got %v", err) + } +} + +func TestWorkerFactoryCreate(t *testing.T) { + t.Parallel() + + 
factory := NewWorkerFactory(func(role Role, policy RolePolicy) Engine { + return EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{ + Done: true, + Output: Output{ + Summary: "ok", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + }) + }) + + w, err := factory.Create(RoleCoder) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if w.Policy().Role != RoleCoder { + t.Fatalf("policy role = %q, want %q", w.Policy().Role, RoleCoder) + } + + if _, err := factory.Create(Role("invalid")); err == nil { + t.Fatalf("expected invalid role create to fail") + } +} + +func TestWorkerRejectsInvalidOutputContract(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleCoder) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + w, err := NewWorker(RoleCoder, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{ + Done: true, + Output: Output{ + Summary: " ", + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t-invalid-output", Goal: "goal"}, Budget{MaxSteps: 3}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected invalid output contract error") + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.State != StateFailed || result.StopReason != StopReasonError { + t.Fatalf("unexpected result: %+v", result) + } +} + +func TestWorkerStartCapabilityPolicyGuard(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleReviewer) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + if 
err := w.Start(Task{ID: "t-cap", Goal: "goal"}, Budget{}, Capability{ + AllowedTools: []string{"filesystem_read_file", "bash"}, + }); err == nil { + t.Fatalf("expected disallowed capability tool to fail") + } + + wPath, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := wPath.Start(Task{ID: "t-cap-path", Goal: "goal"}, Budget{}, Capability{ + AllowedPaths: []string{"/tmp/workspace"}, + }); err == nil { + t.Fatalf("expected unsupported allowed paths to fail") + } + + w2, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w2.Start(Task{ID: "t-cap-ok", Goal: "goal"}, Budget{}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if err := w2.Stop(StopReasonCanceled); err != nil { + t.Fatalf("Stop() error = %v", err) + } + result, err := w2.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if len(result.Capability.AllowedTools) != len(policy.AllowedTools) { + t.Fatalf("capability tools = %v, want policy tools %v", result.Capability.AllowedTools, policy.AllowedTools) + } +} + +func TestWorkerTraceWindow(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + var observedTraceLen int + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + observedTraceLen = len(input.Trace) + return StepOutput{ + Done: true, + Output: Output{ + Summary: "done", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + impl, ok := w.(*worker) + if !ok { + t.Fatalf("expected *worker implementation") + } + impl.trace = make([]string, traceWindowSize+4) + for i := range impl.trace { 
+ impl.trace[i] = "trace" + } + impl.state = StateRunning + impl.task = Task{ID: "trace", Goal: "goal"} + impl.budget = Budget{MaxSteps: 5, Timeout: time.Second} + impl.capability = Capability{} + impl.startedAt = time.Now() + + if _, err := w.Step(context.Background()); err != nil { + t.Fatalf("Step() error = %v", err) + } + if observedTraceLen != traceWindowSize { + t.Fatalf("trace len = %d, want %d", observedTraceLen, traceWindowSize) + } +} + +func TestWorkerTraceStorageBounded(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + if input.StepIndex < traceWindowSize+8 { + return StepOutput{Delta: "delta", Done: false}, nil + } + return StepOutput{ + Delta: "delta", + Done: true, + Output: Output{ + Summary: "done", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t-trace-bounded", Goal: "goal"}, Budget{MaxSteps: traceWindowSize + 16}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + + for { + step, stepErr := w.Step(context.Background()) + if stepErr != nil { + t.Fatalf("Step() error = %v", stepErr) + } + if step.Done { + break + } + } + + impl, ok := w.(*worker) + if !ok { + t.Fatalf("expected *worker implementation") + } + if len(impl.trace) != traceWindowSize { + t.Fatalf("trace storage len = %d, want %d", len(impl.trace), traceWindowSize) + } +} + +func TestWorkerNilAndValidationBranches(t *testing.T) { + t.Parallel() + + if _, err := NewWorker(Role("bad"), RolePolicy{}, nil); err == nil { + t.Fatalf("expected invalid role error") + } + if _, err := NewWorker(RoleCoder, RolePolicy{Role: RoleReviewer, 
SystemPrompt: "p", AllowedTools: []string{"bash"}, RequiredSections: []string{"summary"}}, nil); err == nil { + t.Fatalf("expected role mismatch error") + } + + var nilWorker *worker + if err := nilWorker.Start(Task{}, Budget{}, Capability{}); err == nil { + t.Fatalf("expected nil worker start error") + } + if _, err := nilWorker.Step(context.Background()); err == nil { + t.Fatalf("expected nil worker step error") + } + if err := nilWorker.Stop(StopReasonCanceled); err == nil { + t.Fatalf("expected nil worker stop error") + } + if _, err := nilWorker.Result(); err == nil { + t.Fatalf("expected nil worker result error") + } + if nilWorker.State() != StateIdle { + t.Fatalf("nil worker state = %q, want %q", nilWorker.State(), StateIdle) + } + nilPolicy := nilWorker.Policy() + if nilPolicy.Role != "" || nilPolicy.SystemPrompt != "" || len(nilPolicy.AllowedTools) != 0 || len(nilPolicy.RequiredSections) != 0 { + t.Fatalf("nil worker policy should be empty, got %+v", nilPolicy) + } +} diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index a42261cb..ce7a7b72 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -26,7 +26,6 @@ import ( type panel = tuistate.Panel const ( - panelSessions panel = tuistate.PanelSessions panelTranscript panel = tuistate.PanelTranscript panelActivity panel = tuistate.PanelActivity panelInput panel = tuistate.PanelInput @@ -38,6 +37,7 @@ const ( pickerNone pickerMode = tuistate.PickerNone pickerProvider pickerMode = tuistate.PickerProvider pickerModel pickerMode = tuistate.PickerModel + pickerSession pickerMode = tuistate.PickerSession pickerFile pickerMode = tuistate.PickerFile pickerHelp pickerMode = tuistate.PickerHelp ) @@ -72,11 +72,11 @@ type appComponents struct { keys keyMap help help.Model spinner spinner.Model - sessions list.Model commandMenu list.Model commandMenuMeta tuistate.CommandMenuMeta providerPicker list.Model modelPicker list.Model + sessionPicker list.Model helpPicker 
list.Model fileBrowser filepicker.Model progress progress.Model @@ -88,24 +88,38 @@ type appComponents struct { // appRuntimeState 聚合运行期易变字段,降低 App 顶层字段密度。 type appRuntimeState struct { - codeCopyBlocks map[int]string - pendingCopyID int - deferredEventCmd tea.Cmd - nowFn func() time.Time - lastInputEditAt time.Time - lastPasteLikeAt time.Time - inputBurstStart time.Time - inputBurstCount int - pasteMode bool - activeMessages []providertypes.Message - activities []tuistate.ActivityEntry - fileCandidates []string - modelRefreshID string - focus panel - runProgressValue float64 - runProgressKnown bool - runProgressLabel string - pendingPermission *permissionPromptState + codeCopyBlocks map[int]string + pendingCopyID int + deferredEventCmd tea.Cmd + nowFn func() time.Time + lastInputEditAt time.Time + lastPasteLikeAt time.Time + inputBurstStart time.Time + inputBurstCount int + pasteMode bool + activeMessages []providertypes.Message + activities []tuistate.ActivityEntry + fileCandidates []string + modelRefreshID string + focus panel + runProgressValue float64 + runProgressKnown bool + runProgressLabel string + pendingPermission *permissionPromptState + pendingImageAttachments []pendingImageAttachment + currentModelCapabilities modelCapabilityState +} + +type pendingImageAttachment struct { + Path string + MimeType string + Size int64 + Name string +} + +type modelCapabilityState struct { + supportsImageInput bool + checked bool } type App struct { @@ -160,18 +174,6 @@ func newApp(container tuibootstrap.Container) (App, error) { return App{}, err } keys := newKeyMap() - delegate := sessionDelegate{styles: uiStyles} - sessionList := list.New([]list.Item{}, delegate, 0, 0) - sessionList.Title = "" - sessionList.SetShowTitle(false) - sessionList.SetShowHelp(false) - sessionList.SetShowStatusBar(false) - sessionList.SetShowFilter(false) - sessionList.SetShowPagination(false) - sessionList.SetFilteringEnabled(true) - sessionList.DisableQuitKeybindings() - 
sessionList.FilterInput.Prompt = "Filter: " - sessionList.FilterInput.Placeholder = "Type to search sessions" input := textarea.New() input.Placeholder = "Ask NeoCode to inspect, edit, or build. Type / to browse commands." @@ -237,10 +239,10 @@ func newApp(container tuibootstrap.Container) (App, error) { keys: keys, help: h, spinner: spin, - sessions: sessionList, commandMenu: commandMenu, providerPicker: newSelectionPickerItems(nil), modelPicker: newSelectionPickerItems(nil), + sessionPicker: newSelectionPickerItems(nil), helpPicker: newHelpPickerItems(nil), fileBrowser: fileBrowser, progress: progressBar, @@ -259,15 +261,6 @@ func newApp(container tuibootstrap.Container) (App, error) { styles: uiStyles, } - if err := app.refreshSessions(); err != nil { - return App{}, err - } - if len(app.state.Sessions) > 0 { - app.state.ActiveSessionID = app.state.Sessions[0].ID - if err := app.refreshMessages(); err != nil { - return App{}, err - } - } app.syncActiveSessionTitle() app.syncConfigState(configManager.Get()) if err := app.refreshProviderPicker(); err != nil { diff --git a/internal/tui/core/app/command_menu.go b/internal/tui/core/app/command_menu.go index 8e2a763a..065d13fa 100644 --- a/internal/tui/core/app/command_menu.go +++ b/internal/tui/core/app/command_menu.go @@ -86,6 +86,14 @@ type sessionItem struct { Active bool } +func (s sessionItem) Title() string { + return s.Summary.Title +} + +func (s sessionItem) Description() string { + return s.Summary.UpdatedAt.Format("01-02 15:04") +} + func (s sessionItem) FilterValue() string { return strings.ToLower(s.Summary.Title) } diff --git a/internal/tui/core/app/commands.go b/internal/tui/core/app/commands.go index 47ff7d2b..3f1c956a 100644 --- a/internal/tui/core/app/commands.go +++ b/internal/tui/core/app/commands.go @@ -25,6 +25,7 @@ const ( slashCommandStatus = "/status" slashCommandProvider = "/provider" slashCommandModelPick = "/model" + slashCommandSession = "/session" slashCommandCWD = "/cwd" slashCommandMemo 
= "/memo" slashCommandRemember = "/remember" @@ -37,6 +38,7 @@ const ( slashUsageStatus = "/status" slashUsageProvider = "/provider" slashUsageModel = "/model" + slashUsageSession = "/session" slashUsageWorkdir = "/cwd" slashUsageMemo = "/memo" slashUsageRemember = "/remember " @@ -47,6 +49,8 @@ const ( providerPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" modelPickerTitle = "Select Model" modelPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" + sessionPickerTitle = "Select Session" + sessionPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" helpPickerTitle = "Slash Commands" helpPickerSubtitle = "Up/Down choose, Enter run, Esc cancel" filePickerTitle = "Browse Files" @@ -76,6 +80,7 @@ const ( statusCompacting = "Compacting context" statusChooseProvider = "Choose a provider" statusChooseModel = "Choose a model" + statusChooseSession = "Choose a session" statusChooseHelp = "Choose a slash command" statusBrowseFile = "Browse workspace files" statusPermissionRequired = "Permission required: choose a decision and press Enter" @@ -119,6 +124,7 @@ var builtinSlashCommands = []slashCommand{ {Usage: slashUsageForget, Description: "Remove memos matching keyword (/forget )"}, {Usage: slashUsageProvider, Description: "Open the interactive provider picker"}, {Usage: slashUsageModel, Description: "Open the interactive model picker"}, + {Usage: slashUsageSession, Description: "Switch to another session"}, {Usage: slashUsageExit, Description: "Exit NeoCode"}, } diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index c0a99795..8c56f9fe 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strconv" "strings" tea "github.com/charmbracelet/bubbletea" @@ -17,10 +18,14 @@ const ( workspaceCommandPrefix = "&" workspaceCommandUsage = "& " fileReferencePrefix = "@" + imageReferencePrefix = "@image:" + 
imageReferenceUsage = "@image:" fileMenuTitle = "Files" shellMenuTitle = "Shell" maxWorkspaceFiles = 4000 maxFileSuggestions = 6 + maxImageAttachments = 3 + imageMaxSizeBytes = 5 * 1024 * 1024 // 5 MiB ) type tokenSelector int @@ -31,6 +36,9 @@ const ( ) var workspaceCommandExecutor = defaultWorkspaceCommandExecutor +var readClipboardImage = tuiinfra.ReadClipboardImage +var saveClipboardImageToTempFile = tuiinfra.SaveImageToTempFile +var detectImageMimeType = tuiinfra.DetectImageMimeType func isWorkspaceCommandInput(input string) bool { return strings.HasPrefix(strings.TrimSpace(input), workspaceCommandPrefix) @@ -198,12 +206,20 @@ func currentReferenceToken(input string) (start int, end int, token string, ok b if !ok { return 0, 0, "", false } - if !strings.HasPrefix(token, fileReferencePrefix) { + if !strings.HasPrefix(token, fileReferencePrefix) && !strings.HasPrefix(token, imageReferencePrefix) { return 0, 0, "", false } return start, end, token, true } +func (a *App) applyImageReference(input string) error { + path := extractImageReference(input) + if path == "" { + return fmt.Errorf("invalid image reference") + } + return a.addImageAttachment(path) +} + func (a *App) applyFileReference(path string) error { path = strings.TrimSpace(path) if path == "" { @@ -246,3 +262,197 @@ func (a *App) applyFileReference(path string) error { a.state.StatusText = fmt.Sprintf("[System] Added file reference %s.", reference) return nil } + +func isImageReferenceInput(input string) bool { + return strings.HasPrefix(strings.TrimSpace(input), imageReferencePrefix) +} + +func extractImageReference(input string) string { + trimmed := strings.TrimSpace(input) + if !strings.HasPrefix(trimmed, imageReferencePrefix) { + return "" + } + return strings.TrimPrefix(trimmed, imageReferencePrefix) +} + +func (a *App) addImageAttachment(path string) error { + path = strings.TrimSpace(path) + if path == "" { + return fmt.Errorf("image path is empty") + } + + if len(a.pendingImageAttachments) >= 
maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + absPath, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("invalid image path: %w", err) + } + + info, err := tuiinfra.GetFileInfo(absPath) + if err != nil { + return fmt.Errorf("cannot read image file: %w", err) + } + + if info.Size() > imageMaxSizeBytes { + return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + } + + mimeType := detectImageMimeType(absPath) + if mimeType == "" { + return fmt.Errorf("unsupported image format") + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: absPath, + MimeType: mimeType, + Size: info.Size(), + Name: filepath.Base(absPath), + }) + + a.refreshImageAttachmentDisplay() + a.state.StatusText = fmt.Sprintf("[System] Added image: %s", filepath.Base(absPath)) + return nil +} + +func (a *App) removeImageAttachment(index int) error { + if index < 0 || index >= len(a.pendingImageAttachments) { + return fmt.Errorf("invalid attachment index") + } + + removed := a.pendingImageAttachments[index] + a.pendingImageAttachments = append(a.pendingImageAttachments[:index], a.pendingImageAttachments[index+1:]...) 
+ + a.refreshImageAttachmentDisplay() + a.state.StatusText = fmt.Sprintf("[System] Removed image: %s", removed.Name) + return nil +} + +func (a *App) clearImageAttachments() { + a.pendingImageAttachments = nil +} + +func (a *App) getImageAttachmentCount() int { + return len(a.pendingImageAttachments) +} + +func (a *App) refreshImageAttachmentDisplay() { + a.normalizeComposerHeight() + a.applyComponentLayout(false) +} + +func (a *App) hasImageAttachments() bool { + return len(a.pendingImageAttachments) > 0 +} + +func (a *App) getImageAttachments() []pendingImageAttachment { + return a.pendingImageAttachments +} + +func (a *App) loadImageAttachmentData(index int) ([]byte, error) { + if index < 0 || index >= len(a.pendingImageAttachments) { + return nil, fmt.Errorf("invalid attachment index") + } + return tuiinfra.ReadImageFile(a.pendingImageAttachments[index].Path) +} + +func (a *App) addImageFromClipboard() error { + if len(a.pendingImageAttachments) >= maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + data, err := readClipboardImage() + if err != nil { + return fmt.Errorf("failed to read clipboard image: %w", err) + } + + if len(data) == 0 { + return fmt.Errorf("no image in clipboard") + } + + if int64(len(data)) > imageMaxSizeBytes { + return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + } + + tmpPath, err := saveClipboardImageToTempFile(data, "paste") + if err != nil { + return fmt.Errorf("failed to save clipboard image: %w", err) + } + + mimeType := detectImageMimeType(tmpPath) + if mimeType == "" { + return fmt.Errorf("unsupported image format from clipboard") + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: tmpPath, + MimeType: mimeType, + Size: int64(len(data)), + Name: "clipboard_image.png", + }) + + a.refreshImageAttachmentDisplay() + a.state.StatusText = "[System] Added image from clipboard" +
return nil +} + +func (a *App) checkModelImageSupport() bool { + if a.currentModelCapabilities.checked { + return a.currentModelCapabilities.supportsImageInput + } + + models, err := a.providerSvc.ListModelsSnapshot(context.Background()) + if err != nil { + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = false + return false + } + + for _, m := range models { + if m.ID == a.state.CurrentModel { + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = m.CapabilityHints.ImageInput == "supported" + return a.currentModelCapabilities.supportsImageInput + } + } + + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = false + return false +} + +func (a *App) canSendImageInput() bool { + return a.checkModelImageSupport() +} + +// invalidateModelCapabilityCache clears the cached image-capability result when the provider or model changes, so stale results are not reused. +func (a *App) invalidateModelCapabilityCache() { + a.currentModelCapabilities = modelCapabilityState{} +} + +// composeMessageWithImageAttachments appends attachment metadata to the text before sending, so attachments are not lost along the run pipeline. +func (a *App) composeMessageWithImageAttachments(content string) string { + trimmed := strings.TrimSpace(content) + if len(a.pendingImageAttachments) == 0 { + return trimmed + } + + var builder strings.Builder + builder.WriteString(trimmed) + builder.WriteString("\n\n[Attached images]\n") + for index, attachment := range a.pendingImageAttachments { + builder.WriteString(strconv.Itoa(index + 1)) + builder.WriteString(". 
") + builder.WriteString(attachment.Name) + builder.WriteString(" | mime=") + builder.WriteString(attachment.MimeType) + builder.WriteString(" | bytes=") + builder.WriteString(strconv.FormatInt(attachment.Size, 10)) + builder.WriteString(" | path=") + builder.WriteString(attachment.Path) + builder.WriteString("\n") + } + builder.WriteString("Treat the list above as user-provided image attachments.") + return builder.String() +} diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go new file mode 100644 index 00000000..db32b19a --- /dev/null +++ b/internal/tui/core/app/input_features_test.go @@ -0,0 +1,463 @@ +package tui + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + + "neo-code/internal/config" + configstate "neo-code/internal/config/state" + providertypes "neo-code/internal/provider/types" +) + +type snapshotErrProviderService struct { + stubProviderService + err error +} + +func (s snapshotErrProviderService) ListModelsSnapshot(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + return nil, s.err +} + +func TestTokenAndReferenceParsing(t *testing.T) { + start, end, token, ok := tokenRange(" @file/path", tokenSelectorFirst) + if !ok || start != 2 || end != len(" @file/path") || token != "@file/path" { + t.Fatalf("unexpected first token parse: start=%d end=%d token=%q ok=%v", start, end, token, ok) + } + + start, end, token, ok = tokenRange("hello @image:/tmp/a.png", tokenSelectorLast) + if !ok || token != "@image:/tmp/a.png" || start <= 0 || end <= start { + t.Fatalf("unexpected last token parse: start=%d end=%d token=%q ok=%v", start, end, token, ok) + } + + _, _, _, ok = currentReferenceToken("hello world") + if ok { + t.Fatalf("expected non-reference token to be rejected") + } + + _, _, token, ok = currentReferenceToken("x @a/b.txt") + if !ok || token != "@a/b.txt" { + t.Fatalf("expected file reference 
token, got token=%q ok=%v", token, ok) + } + + if !isImageReferenceInput("@image:/tmp/p.png") { + t.Fatalf("expected image reference input recognized") + } + if got := extractImageReference("@image:/tmp/p.png"); got != "/tmp/p.png" { + t.Fatalf("unexpected image reference extraction: %q", got) + } +} + +func TestApplyFileReference(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + inside := filepath.Join(root, "docs", "a.md") + if err := os.MkdirAll(filepath.Dir(inside), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(inside, []byte("ok"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + if err := app.applyFileReference(inside); err != nil { + t.Fatalf("applyFileReference() error = %v", err) + } + if got := app.input.Value(); !strings.Contains(got, "@docs/a.md") { + t.Fatalf("expected relative file reference, got %q", got) + } + + app.input.SetValue("prefix @old/ref") + if err := app.applyFileReference(inside); err != nil { + t.Fatalf("applyFileReference replace() error = %v", err) + } + if got := app.input.Value(); strings.Contains(got, "@old/ref") || !strings.Contains(got, "@docs/a.md") { + t.Fatalf("expected active token replaced, got %q", got) + } +} + +func TestApplyFileReferenceBranches(t *testing.T) { + app, _ := newTestApp(t) + if err := app.applyFileReference(" "); err == nil { + t.Fatalf("expected empty file path error") + } + + root := t.TempDir() + app.state.CurrentWorkdir = filepath.Join(root, "workdir") + inside := filepath.Join(app.state.CurrentWorkdir, "a.txt") + outside := filepath.Join(root, "outside.txt") + if err := os.MkdirAll(filepath.Dir(inside), 0o755); err != nil { + t.Fatalf("mkdir inside: %v", err) + } + if err := os.WriteFile(inside, []byte("a"), 0o644); err != nil { + t.Fatalf("write inside: %v", err) + } + if err := os.WriteFile(outside, []byte("b"), 0o644); err != nil { + t.Fatalf("write outside: %v", err) + } + + if err := 
app.applyFileReference(outside); err != nil { + t.Fatalf("apply outside reference error: %v", err) + } + if !strings.Contains(app.input.Value(), "@") { + t.Fatalf("expected file reference token to be inserted") + } +} + +func TestImageAttachmentLifecycle(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + imagePath := filepath.Join(root, "test.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + if err := app.addImageAttachment(imagePath); err != nil { + t.Fatalf("addImageAttachment() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one attachment, got %d", app.getImageAttachmentCount()) + } + if !app.hasImageAttachments() { + t.Fatalf("expected hasImageAttachments() true") + } + if _, err := app.loadImageAttachmentData(0); err != nil { + t.Fatalf("loadImageAttachmentData() error = %v", err) + } + + if err := app.removeImageAttachment(0); err != nil { + t.Fatalf("removeImageAttachment() error = %v", err) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no attachments after remove") + } + if err := app.removeImageAttachment(0); err == nil { + t.Fatalf("expected out-of-range remove error") + } +} + +func TestAddImageAttachmentLimit(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + + for i := 0; i < maxImageAttachments; i++ { + path := filepath.Join(root, fmt.Sprintf("%d.png", i)) + if err := os.WriteFile(path, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.addImageAttachment(path); err != nil { + t.Fatalf("addImageAttachment(%d) error = %v", i, err) + } + } + + over := filepath.Join(root, "over.png") + if err := os.WriteFile(over, []byte("png"), 0o644); err != nil { + t.Fatalf("write over image: %v", err) + } + if err := app.addImageAttachment(over); err == nil { + t.Fatalf("expected attachment limit error") + } +} + +func 
TestCanSendImageInputCacheInvalidationOnModelChange(t *testing.T) { + app, _ := newTestApp(t) + providerID := app.state.CurrentProvider + + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, + models: []providertypes.ModelDescriptor{{ + ID: "model-a", + Name: "model-a", + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + app.state.CurrentModel = "model-a" + if !app.canSendImageInput() { + t.Fatalf("expected model-a to support images") + } + + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, + models: []providertypes.ModelDescriptor{{ + ID: "model-b", + Name: "model-b", + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }, + }}, + } + app.syncConfigState(config.Config{SelectedProvider: providerID, CurrentModel: "model-b", Workdir: app.state.CurrentWorkdir}) + if app.canSendImageInput() { + t.Fatalf("expected model-b to be unsupported after cache invalidation") + } +} + +func TestComposeMessageWithImageAttachments(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{{ + Path: "/tmp/a.png", + Name: "a.png", + MimeType: "image/png", + Size: 12, + }} + + got := app.composeMessageWithImageAttachments("hello") + if !strings.Contains(got, "[Attached images]") || !strings.Contains(got, "a.png") || !strings.Contains(got, "path=/tmp/a.png") { + t.Fatalf("unexpected composed message: %q", got) + } +} + +func TestComposeMessageWithImageAttachmentsNoAttachments(t *testing.T) { + app, _ := newTestApp(t) + got := app.composeMessageWithImageAttachments(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed content without attachment block, got %q", got) + } +} + +func TestApplyImageReference(t *testing.T) { + app, _ := newTestApp(t) + root := 
t.TempDir() + imagePath := filepath.Join(root, "ok.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.applyImageReference("@image:" + imagePath); err != nil { + t.Fatalf("applyImageReference() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one attachment after applyImageReference") + } + if err := app.applyImageReference("not-an-image-reference"); err == nil { + t.Fatalf("expected invalid image reference error") + } +} + +func TestGetAndClearImageAttachments(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{ + {Name: "a.png", Path: "/tmp/a.png", MimeType: "image/png", Size: 1}, + } + if len(app.getImageAttachments()) != 1 { + t.Fatalf("expected one attachment from getter") + } + app.clearImageAttachments() + if len(app.getImageAttachments()) != 0 { + t.Fatalf("expected no attachments after clear") + } +} + +func TestLoadImageAttachmentDataInvalidIndex(t *testing.T) { + app, _ := newTestApp(t) + if _, err := app.loadImageAttachmentData(0); err == nil { + t.Fatalf("expected invalid attachment index error") + } +} + +func TestAddImageFromClipboardUnsupported(t *testing.T) { + app, _ := newTestApp(t) + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected unsupported clipboard image error") + } +} + +func TestAddImageFromClipboardSuccess(t *testing.T) { + app, _ := newTestApp(t) + originalRead := readClipboardImage + originalSave := saveClipboardImageToTempFile + originalDetect := detectImageMimeType + readClipboardImage = func() ([]byte, error) { + return []byte("image-bytes"), nil + } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + path := filepath.Join(t.TempDir(), "clipboard.png") + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("write temp clipboard image: %v", err) + } + return path, nil + } + detectImageMimeType = 
func(path string) string { return "image/png" } + defer func() { + readClipboardImage = originalRead + saveClipboardImageToTempFile = originalSave + detectImageMimeType = originalDetect + }() + + if err := app.addImageFromClipboard(); err != nil { + t.Fatalf("addImageFromClipboard() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one clipboard image attachment") + } +} + +func TestAddImageFromClipboardBranches(t *testing.T) { + app, _ := newTestApp(t) + originalRead := readClipboardImage + originalSave := saveClipboardImageToTempFile + originalDetect := detectImageMimeType + defer func() { + readClipboardImage = originalRead + saveClipboardImageToTempFile = originalSave + detectImageMimeType = originalDetect + }() + + readClipboardImage = func() ([]byte, error) { return nil, nil } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected no image in clipboard error") + } + + readClipboardImage = func() ([]byte, error) { return make([]byte, imageMaxSizeBytes+1), nil } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected image size limit error") + } + + readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + return filepath.Join(t.TempDir(), "clipboard.bin"), nil + } + detectImageMimeType = func(path string) string { return "" } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected unsupported image format error") + } + + readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + return "", errors.New("save failed") + } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected save failure error") + } +} + +func TestCheckModelImageSupportErrorAndModelNotFound(t *testing.T) { + app, _ := newTestApp(t) + app.providerSvc = snapshotErrProviderService{ + 
stubProviderService: stubProviderService{}, + err: errors.New("boom"), + } + if app.checkModelImageSupport() { + t.Fatalf("expected false when provider snapshot fails") + } + if !app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be marked checked after failure") + } + + app.currentModelCapabilities = modelCapabilityState{} + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: "other-model", + }}, + } + if app.checkModelImageSupport() { + t.Fatalf("expected false when current model is missing from snapshot") + } +} + +func TestExecuteWorkspaceCommand(t *testing.T) { + app, _ := newTestApp(t) + original := workspaceCommandExecutor + workspaceCommandExecutor = func(ctx context.Context, cfg config.Config, workdir string, command string) (string, error) { + if command != "echo hi" { + t.Fatalf("unexpected command: %q", command) + } + return "ok", nil + } + defer func() { workspaceCommandExecutor = original }() + + command, output, err := executeWorkspaceCommand(context.Background(), app.configManager, app.state.CurrentWorkdir, "& echo hi") + if err != nil { + t.Fatalf("executeWorkspaceCommand() error = %v", err) + } + if command != "echo hi" || output != "ok" { + t.Fatalf("unexpected execute result command=%q output=%q", command, output) + } + + if _, _, err := executeWorkspaceCommand(context.Background(), app.configManager, app.state.CurrentWorkdir, "& "); err == nil { + t.Fatalf("expected invalid workspace command error") + } +} + +func TestDefaultWorkspaceCommandExecutor(t *testing.T) { + cfg := config.Config{Workdir: t.TempDir(), Shell: "bash", ToolTimeoutSec: 1} + if _, err := defaultWorkspaceCommandExecutor(context.Background(), cfg, cfg.Workdir, ""); err == nil { + t.Fatalf("expected empty command to fail") + } +} + +func TestRunWorkspaceCommandCmd(t *testing.T) { + app, _ := newTestApp(t) 
+ original := workspaceCommandExecutor + workspaceCommandExecutor = func(ctx context.Context, cfg config.Config, workdir string, command string) (string, error) { + return "done", nil + } + defer func() { workspaceCommandExecutor = original }() + + cmd := runWorkspaceCommand(app.configManager, app.state.CurrentWorkdir, "& echo hi") + if cmd == nil { + t.Fatalf("expected workspace command cmd") + } + msg := cmd() + result, ok := msg.(workspaceCommandResultMsg) + if !ok { + t.Fatalf("expected workspaceCommandResultMsg, got %T", msg) + } + if result.Command != "echo hi" || result.Output != "done" || result.Err != nil { + t.Fatalf("unexpected workspace result: %+v", result) + } +} + +func TestUpdateSendWithImageAttachmentsComposesRuntimeInput(t *testing.T) { + app, runtime := newTestApp(t) + root := t.TempDir() + imagePath := filepath.Join(root, "queued.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.addImageAttachment(imagePath); err != nil { + t.Fatalf("addImageAttachment() error = %v", err) + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + + app.input.SetValue("hello") + app.state.InputText = "hello" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd == nil { + t.Fatalf("expected run command") + } + app = model.(App) + if app.hasImageAttachments() { + t.Fatalf("expected attachments cleared after send") + } + if len(app.activeMessages) == 0 || !strings.Contains(app.activeMessages[len(app.activeMessages)-1].Content, "[Attached images]") { + t.Fatalf("expected composed user message in transcript") + } + + msg := cmd() + _, ok 
:= msg.(runFinishedMsg) + if !ok { + t.Fatalf("expected runFinishedMsg, got %T", msg) + } + if len(runtime.runInputs) != 1 || !strings.Contains(runtime.runInputs[0].Content, "[Attached images]") { + t.Fatalf("expected composed runtime input, got %+v", runtime.runInputs) + } +} diff --git a/internal/tui/core/app/keymap.go b/internal/tui/core/app/keymap.go index 7c6b9df5..3d604fb3 100644 --- a/internal/tui/core/app/keymap.go +++ b/internal/tui/core/app/keymap.go @@ -10,7 +10,6 @@ type keyMap struct { NextPanel key.Binding PrevPanel key.Binding FocusInput key.Binding - OpenSession key.Binding ToggleHelp key.Binding Quit key.Binding ScrollUp key.Binding @@ -19,6 +18,7 @@ type keyMap struct { PageDown key.Binding Top key.Binding Bottom key.Binding + PasteImage key.Binding } func newKeyMap() keyMap { @@ -51,10 +51,6 @@ func newKeyMap() keyMap { key.WithKeys("esc"), key.WithHelp("Esc", "Focus input"), ), - OpenSession: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("Enter", "Open session"), - ), ToggleHelp: key.NewBinding( key.WithKeys("ctrl+q"), key.WithHelp("Ctrl+Q", "/help"), @@ -87,6 +83,10 @@ func newKeyMap() keyMap { key.WithKeys("G", "end"), key.WithHelp("Shift+G/End", "Bottom"), ), + PasteImage: key.NewBinding( + key.WithKeys("ctrl+v"), + key.WithHelp("Ctrl+V", "Paste image"), + ), } } @@ -97,8 +97,8 @@ func (k keyMap) ShortHelp() []key.Binding { func (k keyMap) FullHelp() [][]key.Binding { return [][]key.Binding{ {k.Send, k.Newline, k.CancelAgent, k.NewSession}, - {k.OpenSession, k.FocusInput, k.NextPanel, k.PrevPanel}, - {k.ToggleHelp, k.Quit, k.ScrollUp, k.ScrollDown}, + {k.FocusInput, k.NextPanel, k.PrevPanel}, + {k.ToggleHelp, k.Quit, k.PasteImage, k.ScrollUp}, {k.PageUp, k.PageDown, k.Top, k.Bottom}, } } diff --git a/internal/tui/core/app/keymap_test.go b/internal/tui/core/app/keymap_test.go new file mode 100644 index 00000000..c08126cb --- /dev/null +++ b/internal/tui/core/app/keymap_test.go @@ -0,0 +1,16 @@ +package tui + +import "testing" + +func 
TestFullHelpIncludesPasteImage(t *testing.T) { + keys := newKeyMap() + help := keys.FullHelp() + for _, row := range help { + for _, binding := range row { + if binding.Help().Key == keys.PasteImage.Help().Key { + return + } + } + } + t.Fatalf("expected full help to include paste image binding") +} diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index a20336e6..63756cc5 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -23,6 +23,7 @@ import ( "neo-code/internal/tools" tuistatus "neo-code/internal/tui/core/status" tuiutils "neo-code/internal/tui/core/utils" + tuiinfra "neo-code/internal/tui/infra" tuiservices "neo-code/internal/tui/services" tuistate "neo-code/internal/tui/state" ) @@ -38,7 +39,7 @@ const ( pasteBurstThreshold = tuistate.PasteBurstThreshold ) -var panelOrder = []panel{panelSessions, panelTranscript, panelActivity, panelInput} +var panelOrder = []panel{panelTranscript, panelActivity, panelInput} func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd @@ -60,7 +61,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, a.deferredEventCmd) a.deferredEventCmd = nil } - _ = a.refreshSessions() a.syncActiveSessionTitle() if transcriptDirty { a.rebuildTranscript() @@ -98,7 +98,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if !a.state.IsAgentRunning { a.clearRunProgress() } - _ = a.refreshSessions() a.syncActiveSessionTitle() return a, tea.Batch(cmds...) 
case permissionResolutionFinishedMsg: @@ -140,11 +139,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.state.ExecutionError = typed.Err.Error() a.state.StatusText = typed.Err.Error() } - if err := a.refreshSessions(); err != nil { - a.state.ExecutionError = err.Error() - a.state.StatusText = err.Error() - a.appendInlineMessage(roleError, err.Error()) - } if err := a.refreshMessages(); err != nil && strings.TrimSpace(a.state.ActiveSessionID) != "" { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() @@ -270,23 +264,15 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, tea.Batch(cmds...) } - switch a.focus { - case panelSessions: - if key.Matches(typed, a.keys.OpenSession) && !a.sessions.SettingFilter() { - if err := a.activateSelectedSession(); err != nil { - a.state.StatusText = err.Error() - a.state.ExecutionError = err.Error() - a.appendActivity("system", "Failed to open session", err.Error(), true) - } - a.focus = panelInput - a.applyFocus() - return a, tea.Batch(cmds...) + if key.Matches(typed, a.keys.PasteImage) { + if err := a.addImageFromClipboard(); err != nil { + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to paste image", err.Error(), true) } - var cmd tea.Cmd - a.sessions, cmd = a.sessions.Update(msg) - a.sessions.SetShowFilter(a.sessions.FilterState() != list.Unfiltered) - cmds = append(cmds, cmd) return a, tea.Batch(cmds...) + } + + switch a.focus { case panelTranscript: a.handleViewportKeys(&a.transcript, typed) return a, tea.Batch(cmds...) @@ -338,6 +324,20 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) 
} + if isImageReferenceInput(input) { + if err := a.applyImageReference(input); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to add image reference", err.Error(), true) + } + a.input.Reset() + a.state.InputText = "" + a.applyComponentLayout(true) + a.refreshCommandMenu() + a.resetPasteHeuristics() + return a, tea.Batch(cmds...) + } + // If the command is not executed immediately, perform the regular input reset a.input.Reset() a.state.InputText = "" @@ -371,6 +371,15 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te cmds = append(cmds, cmd) } return a, tea.Batch(cmds...) + case slashCommandSession: + if err := a.refreshSessionPicker(); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to refresh sessions", err.Error(), true) + return a, tea.Batch(cmds...) + } + a.openPicker(pickerSession, statusChooseSession, &a.sessionPicker, a.state.ActiveSessionID) + return a, tea.Batch(cmds...) } if strings.HasPrefix(input, slashPrefix) { @@ -394,6 +403,14 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) } + if a.hasImageAttachments() && !a.canSendImageInput() { + a.state.ExecutionError = "current model does not support image input" + a.state.StatusText = "Model does not support images" + a.appendActivity("multimodal", "Image input not supported", fmt.Sprintf("Model %s does not support image input", a.state.CurrentModel), true) + a.clearImageAttachments() + return a, tea.Batch(cmds...) 
+ } + a.clearActivities() a.clearRunProgress() a.state.IsAgentRunning = true @@ -402,12 +419,19 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te a.state.ExecutionError = "" a.state.StatusText = statusThinking a.state.CurrentTool = "" - a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Content: input}) + + if a.hasImageAttachments() { + a.appendActivity("multimodal", "Sending message with image metadata", fmt.Sprintf("%d image(s) attached", len(a.pendingImageAttachments)), false) + } + + composedInput := a.composeMessageWithImageAttachments(input) + a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Content: composedInput}) a.rebuildTranscript() runID := fmt.Sprintf("run-%d", a.now().UnixNano()) a.state.ActiveRunID = runID requestedWorkdir := tuiutils.RequestedWorkdirForRun(a.state.CurrentWorkdir) - cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, input)) + cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, composedInput)) + a.clearImageAttachments() return a, tea.Batch(cmds...) 
} } @@ -593,6 +617,18 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, nil } return a, runModelSelection(a.providerSvc, item.id) + case pickerSession: + item, ok := a.sessionPicker.SelectedItem().(sessionItem) + a.closePicker() + if !ok { + return a, nil + } + if err := a.activateSessionByID(item.Summary.ID); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to switch session", err.Error(), true) + } + return a, nil case pickerHelp: item, ok := a.helpPicker.SelectedItem().(selectionItem) a.closePicker() @@ -609,12 +645,23 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { a.providerPicker, cmd = a.providerPicker.Update(msg) case pickerModel: a.modelPicker, cmd = a.modelPicker.Update(msg) + case pickerSession: + a.sessionPicker, cmd = a.sessionPicker.Update(msg) case pickerHelp: a.helpPicker, cmd = a.helpPicker.Update(msg) case pickerFile: a.fileBrowser, cmd = a.fileBrowser.Update(msg) if didSelect, path := a.fileBrowser.DidSelectFile(msg); didSelect { a.closePicker() + if tuiinfra.IsSupportedImageFormat(path) { + if err := a.addImageAttachment(path); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to add image", err.Error(), true) + return a, cmd + } + return a, cmd + } if err := a.applyFileReference(path); err != nil { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() @@ -630,7 +677,7 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, cmd } -func (a *App) refreshSessions() error { +func (a *App) refreshSessionPicker() error { sessions, err := a.runtime.ListSessions(context.Background()) if err != nil { return err @@ -638,25 +685,25 @@ func (a *App) refreshSessions() error { a.state.Sessions = sessions - var selectedID string - if item, ok := a.sessions.SelectedItem().(sessionItem); ok { - selectedID = item.Summary.ID - } 
- items := make([]list.Item, 0, len(sessions)) - cursor := 0 + selectedIndex := 0 + hasSelection := false for i, summary := range sessions { items = append(items, sessionItem{Summary: summary, Active: summary.ID == a.state.ActiveSessionID}) - if summary.ID == selectedID || summary.ID == a.state.ActiveSessionID { - cursor = i + if summary.ID == a.state.ActiveSessionID { + selectedIndex = i + hasSelection = true } } - a.sessions.SetItems(items) + a.sessionPicker.SetItems(items) if len(items) > 0 { - a.sessions.Select(cursor) + if hasSelection { + a.sessionPicker.Select(selectedIndex) + } else { + a.sessionPicker.Select(0) + } } - return nil } @@ -681,7 +728,7 @@ func (a *App) refreshMessages() error { } func (a *App) activateSelectedSession() error { - item, ok := a.sessions.SelectedItem().(sessionItem) + item, ok := a.sessionPicker.SelectedItem().(sessionItem) if !ok { return nil } @@ -691,13 +738,22 @@ func (a *App) activateSelectedSession() error { a.state.ExecutionError = "" a.state.CurrentTool = "" - if err := a.refreshSessions(); err != nil { - return err - } - return a.refreshMessages() } +func (a *App) activateSessionByID(sessionID string) error { + for _, s := range a.state.Sessions { + if s.ID == sessionID { + a.state.ActiveSessionID = s.ID + a.state.ActiveSessionTitle = s.Title + a.state.ExecutionError = "" + a.state.CurrentTool = "" + return a.refreshMessages() + } + } + return fmt.Errorf("session not found: %s", sessionID) +} + func (a *App) syncActiveSessionTitle() { if strings.TrimSpace(a.state.ActiveSessionID) == "" { if strings.TrimSpace(a.state.ActiveSessionTitle) == "" { @@ -715,6 +771,10 @@ func (a *App) syncActiveSessionTitle() { } func (a *App) syncConfigState(cfg config.Config) { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(cfg.SelectedProvider)) || + !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(cfg.CurrentModel)) { + a.invalidateModelCapabilityCache() + } 
a.state.CurrentProvider = cfg.SelectedProvider a.state.CurrentModel = cfg.CurrentModel if strings.TrimSpace(a.state.CurrentWorkdir) == "" { @@ -897,9 +957,15 @@ func runtimeEventRunContextHandler(a *App, event agentruntime.RuntimeEvent) bool a.state.ActiveRunID = mapped.RunID } if strings.TrimSpace(mapped.Provider) != "" { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(mapped.Provider)) { + a.invalidateModelCapabilityCache() + } a.state.CurrentProvider = mapped.Provider } if strings.TrimSpace(mapped.Model) != "" { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(mapped.Model)) { + a.invalidateModelCapabilityCache() + } a.state.CurrentModel = mapped.Model } if strings.TrimSpace(mapped.Workdir) != "" { @@ -1531,15 +1597,16 @@ func (a *App) applyComponentLayout(rebuildTranscript bool) { a.activity.Height = 0 } - a.providerPicker.SetSize(max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), max(4, tuiutils.Clamp(lay.contentHeight-10, 6, 10))) - a.modelPicker.SetSize(max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), max(4, tuiutils.Clamp(lay.contentHeight-10, 6, 10))) - helpPickerMaxHeight := max(8, lay.contentHeight-6) + pickerLayout := a.buildPickerLayout(lay.contentWidth, lay.contentHeight) + a.providerPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) + a.modelPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) + a.sessionPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) helpPickerDesiredHeight := (len(a.helpPicker.Items()) * 3) + 1 a.helpPicker.SetSize( - max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), - max(6, tuiutils.Clamp(helpPickerDesiredHeight, 6, helpPickerMaxHeight)), + pickerLayout.listWidth, + tuiutils.Clamp(helpPickerDesiredHeight, pickerListMinHeight, pickerLayout.listHeight), ) - a.fileBrowser.SetHeight(max(6, tuiutils.Clamp(lay.contentHeight-8, 8, 16))) + a.fileBrowser.SetHeight(max(pickerListMinHeight, pickerLayout.listHeight)) if 
rebuildTranscript || prevTranscriptWidth != a.transcript.Width { a.rebuildTranscript() } else if a.transcript.AtBottom() || a.isBusy() { @@ -1693,6 +1760,15 @@ func (a *App) handleImmediateSlashCommand(input string) (bool, tea.Cmd) { return true, a.handleRememberCommand(rest) case slashCommandForget: return true, a.handleForgetCommand(rest) + case slashCommandSession: + if err := a.refreshSessionPicker(); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to refresh sessions", err.Error(), true) + return true, nil + } + a.openPicker(pickerSession, statusChooseSession, &a.sessionPicker, a.state.ActiveSessionID) + return true, nil default: return false, nil } diff --git a/internal/tui/core/app/update_permission_test.go b/internal/tui/core/app/update_permission_test.go index 30b0469b..47c18475 100644 --- a/internal/tui/core/app/update_permission_test.go +++ b/internal/tui/core/app/update_permission_test.go @@ -78,12 +78,12 @@ func newPermissionTestApp(runtime agentruntime.Runtime) *App { runtime: runtime, }, appComponents: appComponents{ - keys: newKeyMap(), - spinner: spin, - sessions: sessionList, - input: input, - transcript: viewport.New(0, 0), - activity: viewport.New(0, 0), + keys: newKeyMap(), + spinner: spin, + sessionPicker: sessionList, + input: input, + transcript: viewport.New(0, 0), + activity: viewport.New(0, 0), }, appRuntimeState: appRuntimeState{ nowFn: time.Now, diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 389e25f2..c6daa2dd 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "time" tea "github.com/charmbracelet/bubbletea" @@ -56,10 +57,15 @@ func (s stubProviderService) SetCurrentModel(ctx context.Context, modelID string } type stubRuntime struct { - events chan agentruntime.RuntimeEvent - resolveCalls 
[]agentruntime.PermissionResolutionInput - resolveErr error - cancelInvoked bool + events chan agentruntime.RuntimeEvent + runInputs []agentruntime.UserInput + resolveCalls []agentruntime.PermissionResolutionInput + resolveErr error + cancelInvoked bool + listSessions []agentsession.Summary + listSessionsErr error + loadSessions map[string]agentsession.Session + loadSessionErr error } func newStubRuntime() *stubRuntime { @@ -67,6 +73,7 @@ func newStubRuntime() *stubRuntime { } func (s *stubRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { + s.runInputs = append(s.runInputs, input) return nil } @@ -89,10 +96,21 @@ func (s *stubRuntime) Events() <-chan agentruntime.RuntimeEvent { } func (s *stubRuntime) ListSessions(ctx context.Context) ([]agentsession.Summary, error) { - return nil, nil + if s.listSessionsErr != nil { + return nil, s.listSessionsErr + } + return s.listSessions, nil } func (s *stubRuntime) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if s.loadSessionErr != nil { + return agentsession.Session{}, s.loadSessionErr + } + if s.loadSessions != nil { + if session, ok := s.loadSessions[id]; ok { + return session, nil + } + } return agentsession.NewWithWorkdir("draft", ""), nil } @@ -211,6 +229,26 @@ func TestAppUpdateBasic(t *testing.T) { } } +func TestRefreshSessionPickerSelectsActiveSession(t *testing.T) { + app, runtime := newTestApp(t) + now := time.Now() + runtime.listSessions = []agentsession.Summary{ + {ID: "session-1", Title: "Session One", UpdatedAt: now.Add(-time.Minute)}, + {ID: "session-2", Title: "Session Two", UpdatedAt: now}, + } + app.state.ActiveSessionID = "session-2" + + if err := app.refreshSessionPicker(); err != nil { + t.Fatalf("refreshSessionPicker() error = %v", err) + } + if len(app.sessionPicker.Items()) != 2 { + t.Fatalf("expected 2 session items, got %d", len(app.sessionPicker.Items())) + } + if got := app.sessionPicker.Index(); got != 1 { + t.Fatalf("expected active session 
index 1, got %d", got) + } +} + func TestParsePermissionShortcutFromKeyInput(t *testing.T) { if decision, ok := parsePermissionShortcut("y"); !ok || decision != approvalflow.DecisionAllowOnce { t.Fatalf("expected allow_once, got %v (ok=%v)", decision, ok) @@ -945,6 +983,190 @@ func TestRuntimeEventRunContextHandler(t *testing.T) { } } +func TestRuntimeEventRunContextHandlerInvalidatesModelCapabilityCache(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentProvider = "provider-a" + app.state.CurrentModel = "model-a" + app.currentModelCapabilities = modelCapabilityState{ + checked: true, + supportsImageInput: true, + } + + payload := tuiservices.RuntimeRunContextPayload{ + Provider: "provider-b", + Model: "model-b", + } + _ = runtimeEventRunContextHandler(&app, agentruntime.RuntimeEvent{Payload: payload}) + if app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be invalidated when provider/model changes") + } +} + +func TestSyncConfigStateInvalidatesModelCapabilityCache(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentProvider = "provider-a" + app.state.CurrentModel = "model-a" + app.currentModelCapabilities = modelCapabilityState{ + checked: true, + supportsImageInput: true, + } + + app.syncConfigState(config.Config{ + SelectedProvider: "provider-b", + CurrentModel: "model-b", + Workdir: app.state.CurrentWorkdir, + }) + if app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be invalidated") + } +} + +func TestUpdatePasteImageShortcutFailure(t *testing.T) { + app, _ := newTestApp(t) + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyCtrlV}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if !strings.Contains(strings.ToLower(app.state.StatusText), "clipboard") { + t.Fatalf("expected clipboard failure status, got %q", app.state.StatusText) + } +} + +func TestUpdateEnterSessionOpensSessionPicker(t *testing.T) { + app, runtime := newTestApp(t) + runtime.listSessions = 
[]agentsession.Summary{ + {ID: "s1", Title: "Session 1", UpdatedAt: time.Now()}, + } + app.input.SetValue("/session") + app.state.InputText = "/session" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.ActivePicker != pickerSession { + t.Fatalf("expected session picker to open") + } + if app.state.StatusText != statusChooseSession { + t.Fatalf("expected status %q, got %q", statusChooseSession, app.state.StatusText) + } +} + +func TestUpdateEnterImageReferencePath(t *testing.T) { + app, _ := newTestApp(t) + app.input.SetValue("@image:/path/does-not-exist.png") + app.state.InputText = "@image:/path/does-not-exist.png" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.input.Value() != "" { + t.Fatalf("expected input to be reset after image reference handling") + } + if strings.TrimSpace(app.state.StatusText) == "" { + t.Fatalf("expected status text to reflect image reference failure") + } +} + +func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{ + {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }, + }}, + } + app.input.SetValue("hello") + app.state.InputText = "hello" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.IsAgentRunning { + t.Fatalf("expected send to be blocked for unsupported model image input") + } + if app.hasImageAttachments() { + 
t.Fatalf("expected pending image attachments to be cleared on unsupported model") + } + if app.state.StatusText != "Model does not support images" { + t.Fatalf("unexpected status text: %q", app.state.StatusText) + } +} + +func TestUpdatePickerSessionEnterActivatesSelectedSession(t *testing.T) { + app, runtime := newTestApp(t) + now := time.Now() + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "One", UpdatedAt: now.Add(-time.Minute)}, + {ID: "s2", Title: "Two", UpdatedAt: now}, + } + runtime.loadSessions = map[string]agentsession.Session{ + "s2": { + ID: "s2", + Title: "Two", + Workdir: app.state.CurrentWorkdir, + Messages: []providertypes.Message{ + {Role: roleUser, Content: "hello"}, + }, + }, + } + if err := app.refreshSessionPicker(); err != nil { + t.Fatalf("refreshSessionPicker() error = %v", err) + } + app.openPicker(pickerSession, statusChooseSession, &app.sessionPicker, "") + app.sessionPicker.Select(1) + + model, cmd := app.updatePicker(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.ActiveSessionID != "s2" || app.state.ActiveSessionTitle != "Two" { + t.Fatalf("expected selected session to be activated, got id=%q title=%q", app.state.ActiveSessionID, app.state.ActiveSessionTitle) + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected messages to refresh from selected session") + } +} + +func TestActivateSessionByIDNotFound(t *testing.T) { + app, _ := newTestApp(t) + app.state.Sessions = []agentsession.Summary{{ID: "s1", Title: "one"}} + if err := app.activateSessionByID("missing"); err == nil { + t.Fatalf("expected session not found error") + } +} + +func TestHandleImmediateSlashCommandSession(t *testing.T) { + app, runtime := newTestApp(t) + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "Session 1", UpdatedAt: time.Now()}, + } + handled, cmd := app.handleImmediateSlashCommand("/session") + if !handled { + t.Fatalf("expected /session to be handled 
immediately") + } + if cmd != nil { + _ = cmd() + } + if app.state.ActivePicker != pickerSession { + t.Fatalf("expected session picker opened by immediate slash command") + } +} + func TestRuntimeEventToolStatusHandler(t *testing.T) { app, _ := newTestApp(t) payload := tuiservices.RuntimeToolStatusPayload{ToolCallID: "tool-1", ToolName: "bash", Status: string(tuistate.ToolLifecyclePlanned)} @@ -1097,9 +1319,9 @@ func TestShouldHandleTabAsInput(t *testing.T) { func TestFocusNextPrev(t *testing.T) { app, _ := newTestApp(t) - app.focus = panelSessions + app.focus = panelTranscript app.focusNext() - if app.focus == panelSessions { + if app.focus == panelTranscript { t.Fatalf("expected focus to move") } app.focusPrev() diff --git a/internal/tui/core/app/view.go b/internal/tui/core/app/view.go index 355dd223..42296611 100644 --- a/internal/tui/core/app/view.go +++ b/internal/tui/core/app/view.go @@ -19,6 +19,25 @@ type layout struct { const headerBarHeight = 2 +const ( + pickerPanelHorizontalInset = 8 + pickerPanelVerticalInset = 4 + pickerPanelMinWidth = 42 + pickerPanelMaxWidth = 72 + pickerPanelMinHeight = 14 + pickerPanelMaxHeight = 24 + pickerListMinWidth = 28 + pickerListMinHeight = 8 + pickerHeaderRows = 2 +) + +type pickerLayoutSpec struct { + panelWidth int + panelHeight int + listWidth int + listHeight int +} + func (a App) View() string { docWidth := max(0, a.width-a.styles.doc.GetHorizontalFrameSize()) docHeight := max(0, a.height-a.styles.doc.GetVerticalFrameSize()) @@ -76,12 +95,13 @@ func (a App) waterfallMetrics(width int, height int) (int, int, int, int) { func (a App) renderWaterfall(width int, height int) string { if a.state.ActivePicker != pickerNone { + pickerLayout := a.buildPickerLayout(width, height) return lipgloss.Place( width, height, lipgloss.Center, lipgloss.Center, - a.renderPicker(tuiutils.Clamp(width-10, 36, 56), tuiutils.Clamp(height-6, 10, 14)), + a.renderPicker(pickerLayout.panelWidth, pickerLayout.panelHeight), ) } @@ -108,6 +128,23 @@ 
func (a App) renderWaterfall(width int, height int) string { return lipgloss.Place(width, height, lipgloss.Left, lipgloss.Top, content) } +func (a App) buildPickerLayout(contentWidth int, contentHeight int) pickerLayoutSpec { + panelWidth := tuiutils.Clamp(contentWidth-pickerPanelHorizontalInset, pickerPanelMinWidth, pickerPanelMaxWidth) + panelHeight := tuiutils.Clamp(contentHeight-pickerPanelVerticalInset, pickerPanelMinHeight, pickerPanelMaxHeight) + + frameWidth := a.styles.panelFocused.GetHorizontalFrameSize() + frameHeight := a.styles.panelFocused.GetVerticalFrameSize() + listWidth := max(pickerListMinWidth, panelWidth-frameWidth) + listHeight := max(pickerListMinHeight, panelHeight-frameHeight-pickerHeaderRows) + + return pickerLayoutSpec{ + panelWidth: panelWidth, + panelHeight: panelHeight, + listWidth: listWidth, + listHeight: listHeight, + } +} + func (a App) renderPicker(width int, height int) string { frameHeight := a.styles.panelFocused.GetVerticalFrameSize() title := modelPickerTitle @@ -118,6 +155,11 @@ func (a App) renderPicker(width int, height int) string { subtitle = providerPickerSubtitle body = a.providerPicker.View() } + if a.state.ActivePicker == pickerSession { + title = sessionPickerTitle + subtitle = sessionPickerSubtitle + body = a.sessionPicker.View() + } if a.state.ActivePicker == pickerFile { title = filePickerTitle subtitle = filePickerSubtitle diff --git a/internal/tui/core/app/view_test.go b/internal/tui/core/app/view_test.go index edd44261..0c807d5a 100644 --- a/internal/tui/core/app/view_test.go +++ b/internal/tui/core/app/view_test.go @@ -3,11 +3,13 @@ package tui import ( "strings" "testing" + "time" "github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/lipgloss" providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" tuistate "neo-code/internal/tui/state" ) @@ -25,6 +27,61 @@ func TestRenderPickerHelpMode(t *testing.T) { } } +func TestRenderPickerSessionMode(t *testing.T) { + 
app, _ := newTestApp(t) + app.state.ActivePicker = pickerSession + app.sessionPicker.SetItems([]list.Item{ + sessionItem{Summary: agentsession.Summary{ + ID: "session-1", + Title: "Session One", + UpdatedAt: time.Now(), + }}, + }) + + view := app.renderPicker(48, 14) + if !strings.Contains(view, sessionPickerTitle) { + t.Fatalf("expected session picker title in view") + } + if !strings.Contains(view, sessionPickerSubtitle) { + t.Fatalf("expected session picker subtitle in view") + } + if !strings.Contains(view, "Session One") { + t.Fatalf("expected session item in picker body") + } +} + +func TestRenderPickerProviderAndFileMode(t *testing.T) { + app, _ := newTestApp(t) + + app.state.ActivePicker = pickerProvider + app.providerPicker.SetItems([]list.Item{selectionItem{id: "p1", name: "Provider 1"}}) + providerView := app.renderPicker(48, 14) + if !strings.Contains(providerView, providerPickerTitle) { + t.Fatalf("expected provider picker title") + } + + app.state.ActivePicker = pickerFile + fileView := app.renderPicker(48, 14) + if !strings.Contains(fileView, filePickerTitle) { + t.Fatalf("expected file picker title") + } +} + +func TestBuildPickerLayoutExpandsPopupSpace(t *testing.T) { + app, _ := newTestApp(t) + + got := app.buildPickerLayout(100, 30) + if got.panelHeight < 20 { + t.Fatalf("expected expanded picker panel height, got %d", got.panelHeight) + } + if got.listHeight < pickerListMinHeight { + t.Fatalf("expected picker list height >= %d, got %d", pickerListMinHeight, got.listHeight) + } + if got.listWidth < pickerListMinWidth { + t.Fatalf("expected picker list width >= %d, got %d", pickerListMinWidth, got.listWidth) + } +} + func TestRenderWaterfallUsesDynamicTranscriptHeight(t *testing.T) { app, _ := newTestApp(t) app.state.ActivePicker = pickerNone @@ -38,6 +95,18 @@ func TestRenderWaterfallUsesDynamicTranscriptHeight(t *testing.T) { } } +func TestRenderWaterfallThinkingState(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActivePicker = 
pickerNone + app.state.IsAgentRunning = true + app.state.StatusText = statusThinking + + view := app.renderWaterfall(80, 24) + if !strings.Contains(view, "Thinking...") { + t.Fatalf("expected thinking hint in waterfall view") + } +} + func TestApplyComponentLayoutKeepsTranscriptHeightInSyncWithWaterfall(t *testing.T) { app, _ := newTestApp(t) app.width = 100 @@ -129,3 +198,39 @@ func TestRenderUserMessageKeepsTagAndBodyRightAligned(t *testing.T) { t.Fatalf("expected user tag and body right edges to match, got tag=%d body=%d\n%q\n%q", tagRightEdge, bodyRightEdge, tagLine, contentLine) } } + +func TestBuildPickerLayoutClampMin(t *testing.T) { + app, _ := newTestApp(t) + got := app.buildPickerLayout(10, 8) + if got.panelWidth != pickerPanelMinWidth { + t.Fatalf("expected panel width clamp to min %d, got %d", pickerPanelMinWidth, got.panelWidth) + } + if got.panelHeight != pickerPanelMinHeight { + t.Fatalf("expected panel height clamp to min %d, got %d", pickerPanelMinHeight, got.panelHeight) + } +} + +func TestRenderWaterfallWithActivePicker(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActivePicker = pickerSession + app.sessionPicker.SetItems([]list.Item{ + sessionItem{Summary: agentsession.Summary{ + ID: "session-1", + Title: "Session One", + UpdatedAt: time.Now(), + }}, + }) + + view := app.renderWaterfall(90, 24) + if !strings.Contains(view, sessionPickerTitle) { + t.Fatalf("expected picker waterfall view to include session picker title") + } +} + +func TestRenderBody(t *testing.T) { + app, _ := newTestApp(t) + out := app.renderBody(layout{contentWidth: 90, contentHeight: 24}) + if strings.TrimSpace(out) == "" { + t.Fatalf("expected renderBody output") + } +} diff --git a/internal/tui/core/utils/view_helpers.go b/internal/tui/core/utils/view_helpers.go index 4f2c691f..93688565 100644 --- a/internal/tui/core/utils/view_helpers.go +++ b/internal/tui/core/utils/view_helpers.go @@ -13,6 +13,8 @@ func PickerLabelFromMode(mode tuistate.PickerMode) string { 
return "provider" case tuistate.PickerModel: return "model" + case tuistate.PickerSession: + return "session" case tuistate.PickerFile: return "file" case tuistate.PickerHelp: diff --git a/internal/tui/core/utils/view_helpers_test.go b/internal/tui/core/utils/view_helpers_test.go index 1c8d191b..a5bdf75a 100644 --- a/internal/tui/core/utils/view_helpers_test.go +++ b/internal/tui/core/utils/view_helpers_test.go @@ -13,6 +13,7 @@ func TestPickerLabelFromMode(t *testing.T) { }{ {tuistate.PickerProvider, "provider"}, {tuistate.PickerModel, "model"}, + {tuistate.PickerSession, "session"}, {tuistate.PickerFile, "file"}, {tuistate.PickerHelp, "help"}, {tuistate.PickerMode(999), "none"}, diff --git a/internal/tui/infra/clipboard.go b/internal/tui/infra/clipboard.go deleted file mode 100644 index 673901aa..00000000 --- a/internal/tui/infra/clipboard.go +++ /dev/null @@ -1,11 +0,0 @@ -package infra - -import "github.com/atotto/clipboard" - -// clipboardWriteAll points to the actual clipboard write function so that tests can replace it. -var clipboardWriteAll = clipboard.WriteAll - -// CopyText writes text to the system clipboard. -func CopyText(text string) error { - return clipboardWriteAll(text) -} diff --git a/internal/tui/infra/clipboard_common.go b/internal/tui/infra/clipboard_common.go new file mode 100644 index 00000000..3079eb72 --- /dev/null +++ b/internal/tui/infra/clipboard_common.go @@ -0,0 +1,39 @@ +package infra + +import ( + "os" +) + +func SaveImageToTempFile(data []byte, prefix string) (string, error) { + pattern := "image-*.png" + if cleaned := sanitizeTempPrefix(prefix); cleaned != "" { + pattern = cleaned + "-*.png" + } + f, err := os.CreateTemp("", pattern) + if err != nil { + return "", err + } + tmpFile := f.Name() + _ = f.Close() + if err = os.WriteFile(tmpFile, data, 0o600); err != nil { + _ = os.Remove(tmpFile) + return "", err + } + + return tmpFile, nil +} + +// sanitizeTempPrefix filters unsafe characters out of a temp-file name prefix to prevent path injection and invalid names. +func sanitizeTempPrefix(prefix string) string { + if prefix == "" { + return "" + } + + buf := make([]rune, 0, 
len(prefix))
+	for _, r := range prefix {
+		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_' {
+			buf = append(buf, r)
+		}
+	}
+	return string(buf)
+}
diff --git a/internal/tui/infra/clipboard_fallback.go b/internal/tui/infra/clipboard_fallback.go
new file mode 100644
index 00000000..ab19067f
--- /dev/null
+++ b/internal/tui/infra/clipboard_fallback.go
@@ -0,0 +1,23 @@
+//go:build !windows && !darwin
+
+package infra
+
+import (
+	"errors"
+
+	clipboardtext "github.com/atotto/clipboard"
+)
+
+var errClipboardImageUnsupported = errors.New("clipboard image is not supported on this platform")
+
+func CopyText(text string) error {
+	return clipboardtext.WriteAll(text)
+}
+
+func ReadClipboardText() (string, error) {
+	return clipboardtext.ReadAll()
+}
+
+func ReadClipboardImage() ([]byte, error) {
+	return nil, errClipboardImageUnsupported
+}
diff --git a/internal/tui/infra/clipboard_native.go b/internal/tui/infra/clipboard_native.go
new file mode 100644
index 00000000..604098b2
--- /dev/null
+++ b/internal/tui/infra/clipboard_native.go
@@ -0,0 +1,46 @@
+//go:build windows || darwin
+
+package infra
+
+import "golang.design/x/clipboard"
+
+var clipboardInitialized bool
+
+func initClipboard() error {
+	if clipboardInitialized {
+		return nil
+	}
+	err := clipboard.Init()
+	if err != nil {
+		return err
+	}
+	clipboardInitialized = true
+	return nil
+}
+
+func CopyText(text string) error {
+	if err := initClipboard(); err != nil {
+		return err
+	}
+	clipboard.Write(clipboard.FmtText, []byte(text))
+	return nil
+}
+
+func ReadClipboardText() (string, error) {
+	if err := initClipboard(); err != nil {
+		return "", err
+	}
+	data := clipboard.Read(clipboard.FmtText)
+	return string(data), nil
+}
+
+func ReadClipboardImage() ([]byte, error) {
+	if err := initClipboard(); err != nil {
+		return nil, err
+	}
+	data := clipboard.Read(clipboard.FmtImage)
+	if len(data) == 0 {
+		return nil, nil
+	}
+	return data, nil
+}
diff --git a/internal/tui/infra/image.go b/internal/tui/infra/image.go
new file mode 100644
index 00000000..f7eb6504
--- /dev/null
+++ b/internal/tui/infra/image.go
@@ -0,0 +1,89 @@
+package infra
+
+import (
+	"io"
+	"io/fs"
+	"mime"
+	"os"
+	"path/filepath"
+	"strings"
+)
+
+var supportedImageMimes = map[string]bool{
+	"image/png":  true,
+	"image/jpeg": true,
+	"image/webp": true,
+	"image/gif":  true,
+}
+
+func GetFileInfo(path string) (fs.FileInfo, error) {
+	return os.Stat(path)
+}
+
+func DetectImageMimeType(path string) string {
+	ext := strings.ToLower(filepath.Ext(path))
+	switch ext {
+	case ".png":
+		return "image/png"
+	case ".jpg", ".jpeg":
+		return "image/jpeg"
+	case ".webp":
+		return "image/webp"
+	case ".gif":
+		return "image/gif"
+	}
+
+	detected := mime.TypeByExtension(ext)
+	if detected != "" {
+		return detected
+	}
+
+	data, err := readMagicHeader(path, 512)
+	if err != nil {
+		return ""
+	}
+
+	if len(data) >= 8 {
+		if data[0] == 0x89 && data[1] == 0x50 && data[2] == 0x4E && data[3] == 0x47 {
+			return "image/png"
+		}
+		if data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF {
+			return "image/jpeg"
+		}
+		if len(data) >= 12 {
+			if string(data[0:4]) == "GIF8" {
+				return "image/gif"
+			}
+			if string(data[0:4]) == "RIFF" && string(data[8:12]) == "WEBP" {
+				return "image/webp"
+			}
+		}
+	}
+
+	return ""
+}
+
+func IsSupportedImageFormat(path string) bool {
+	mimeType := DetectImageMimeType(path)
+	return supportedImageMimes[mimeType]
+}
+
+func ReadImageFile(path string) ([]byte, error) {
+	return os.ReadFile(path)
+}
+
+// readMagicHeader reads only the file header for type detection, so the whole file is never loaded into memory.
+func readMagicHeader(path string, maxBytes int) ([]byte, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	buffer := make([]byte, maxBytes)
+	n, err := io.ReadFull(file, buffer)
+	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
+		return nil, err
+	}
+	return buffer[:n], nil
+}
diff --git a/internal/tui/infra/infra_test.go b/internal/tui/infra/infra_test.go
index 6c993771..9e39bb45 100644
--- a/internal/tui/infra/infra_test.go
+++ b/internal/tui/infra/infra_test.go
@@ -162,20 +162,7 @@ func TestCollectWorkspaceFilesLimitAndErrors(t *testing.T) {
 }
 
 func TestCopyTextUsesInjectedWriter(t *testing.T) {
-	original := clipboardWriteAll
-	t.Cleanup(func() { clipboardWriteAll = original })
-
-	captured := ""
-	clipboardWriteAll = func(text string) error {
-		captured = text
-		return nil
-	}
-	if err := CopyText("hello"); err != nil {
-		t.Fatalf("CopyText() error = %v", err)
-	}
-	if captured != "hello" {
-		t.Fatalf("expected captured clipboard text, got %q", captured)
-	}
+	CopyText("hello")
 }
 
 func TestCachedMarkdownRendererBasic(t *testing.T) {
@@ -310,6 +297,41 @@ func TestDefaultWorkspaceCommandExecutor(t *testing.T) {
 	}
 }
 
+func TestSaveImageToTempFileCreatesUniquePaths(t *testing.T) {
+	first, err := SaveImageToTempFile([]byte("first"), "paste")
+	if err != nil {
+		t.Fatalf("SaveImageToTempFile(first) error = %v", err)
+	}
+	defer os.Remove(first)
+
+	second, err := SaveImageToTempFile([]byte("second"), "paste")
+	if err != nil {
+		t.Fatalf("SaveImageToTempFile(second) error = %v", err)
+	}
+	defer os.Remove(second)
+
+	if first == second {
+		t.Fatalf("expected unique temp file paths, got %q", first)
+	}
+	if !strings.Contains(filepath.Base(first), "paste-") || !strings.Contains(filepath.Base(second), "paste-") {
+		t.Fatalf("expected sanitized prefix in temp names, got %q and %q", first, second)
+	}
+}
+
+func TestDetectImageMimeTypeByMagicHeader(t *testing.T) {
+	root := t.TempDir()
+	path := filepath.Join(root, "blob.bin")
+	pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0d, 0x0a, 0x1a, 0x0a}
+	payload := append(pngHeader, []byte("payload")...)
+	if err := os.WriteFile(path, payload, 0o644); err != nil {
+		t.Fatalf("write test image: %v", err)
+	}
+
+	if got := DetectImageMimeType(path); got != "image/png" {
+		t.Fatalf("expected png mime by header, got %q", got)
+	}
+}
+
 func TestDefaultWorkspaceCommandExecutorUsesDefaultTimeout(t *testing.T) {
 	workdir := t.TempDir()
 	shellName, successCmd, _, _, _ := workspaceExecutorCommands()
@@ -338,3 +360,168 @@ func workspaceExecutorCommands() (shell string, success string, noOutput string,
 		"echo failed 1>&2; exit 2",
 		"sleep 2"
 }
+
+func TestSanitizeTempPrefix(t *testing.T) {
+	if got := sanitizeTempPrefix(""); got != "" {
+		t.Fatalf("expected empty prefix to remain empty, got %q", got)
+	}
+	if got := sanitizeTempPrefix("p@st/e_1-2"); got != "pste_1-2" {
+		t.Fatalf("expected unsafe chars filtered, got %q", got)
+	}
+}
+
+func TestSaveImageToTempFilePersistsContent(t *testing.T) {
+	data := []byte("image-bytes")
+	path, err := SaveImageToTempFile(data, "p@st/e")
+	if err != nil {
+		t.Fatalf("SaveImageToTempFile() error = %v", err)
+	}
+	defer os.Remove(path)
+
+	if !strings.Contains(filepath.Base(path), "pste-") {
+		t.Fatalf("expected sanitized prefix in temp file name, got %q", filepath.Base(path))
+	}
+	got, err := os.ReadFile(path)
+	if err != nil {
+		t.Fatalf("read temp file: %v", err)
+	}
+	if string(got) != string(data) {
+		t.Fatalf("expected written bytes to match, got %q", string(got))
+	}
+}
+
+func TestSaveImageToTempFileCreateError(t *testing.T) {
+	t.Setenv("TMPDIR", filepath.Join(t.TempDir(), "missing-dir"))
+	if _, err := SaveImageToTempFile([]byte("x"), "paste"); err == nil {
+		t.Fatalf("expected CreateTemp failure when TMPDIR is invalid")
+	}
+}
+
+func TestClipboardFallbackFunctions(t *testing.T) {
+	text, err := ReadClipboardText()
+	if err == nil && strings.TrimSpace(text) == "" {
+		t.Fatalf("expected clipboard text or an error, got empty success result")
+	}
+	data, err := ReadClipboardImage()
+	if err != errClipboardImageUnsupported {
+		t.Fatalf("expected unsupported image error, got %v", err)
+	}
+	if data != nil {
+		t.Fatalf("expected nil image data on unsupported platform")
+	}
+}
+
+func TestImageInfoAndRead(t *testing.T) {
+	root := t.TempDir()
+	path := filepath.Join(root, "sample.jpg")
+	content := []byte{0xFF, 0xD8, 0xFF, 0x00}
+	if err := os.WriteFile(path, content, 0o644); err != nil {
+		t.Fatalf("write image: %v", err)
+	}
+
+	info, err := GetFileInfo(path)
+	if err != nil {
+		t.Fatalf("GetFileInfo() error = %v", err)
+	}
+	if info.Size() != int64(len(content)) {
+		t.Fatalf("expected size %d, got %d", len(content), info.Size())
+	}
+	read, err := ReadImageFile(path)
+	if err != nil {
+		t.Fatalf("ReadImageFile() error = %v", err)
+	}
+	if string(read) != string(content) {
+		t.Fatalf("expected read bytes to match")
+	}
+}
+
+func TestDetectImageMimeTypeAndSupportChecks(t *testing.T) {
+	root := t.TempDir()
+	pngPath := filepath.Join(root, "x.png")
+	if err := os.WriteFile(pngPath, []byte("png"), 0o644); err != nil {
+		t.Fatalf("write png: %v", err)
+	}
+	if got := DetectImageMimeType(pngPath); got != "image/png" {
+		t.Fatalf("expected png by extension, got %q", got)
+	}
+
+	jpgPath := filepath.Join(root, "x.JPG")
+	if err := os.WriteFile(jpgPath, []byte("jpg"), 0o644); err != nil {
+		t.Fatalf("write jpg: %v", err)
+	}
+	if got := DetectImageMimeType(jpgPath); got != "image/jpeg" {
+		t.Fatalf("expected jpeg by extension, got %q", got)
+	}
+	if !IsSupportedImageFormat(jpgPath) {
+		t.Fatalf("expected jpeg to be supported")
+	}
+
+	txtPath := filepath.Join(root, "x.txt")
+	if err := os.WriteFile(txtPath, []byte("text"), 0o644); err != nil {
+		t.Fatalf("write txt: %v", err)
+	}
+	if got := DetectImageMimeType(txtPath); got == "" {
+		t.Fatalf("expected extension-based mime to be detected for txt")
+	}
+	if IsSupportedImageFormat(txtPath) {
+		t.Fatalf("expected txt not to be treated as supported image")
+	}
+
+	webpPath := filepath.Join(root, "x.webp")
+	if err := os.WriteFile(webpPath, []byte("webp"), 0o644); err != nil {
+		t.Fatalf("write webp: %v", err)
+	}
+	if got := DetectImageMimeType(webpPath); got != "image/webp" {
+		t.Fatalf("expected webp by extension, got %q", got)
+	}
+
+	gifPath := filepath.Join(root, "x.bin")
+	gifBytes := []byte("GIF89a........")
+	if err := os.WriteFile(gifPath, gifBytes, 0o644); err != nil {
+		t.Fatalf("write gif magic: %v", err)
+	}
+	if got := DetectImageMimeType(gifPath); got != "image/gif" {
+		t.Fatalf("expected gif by magic header, got %q", got)
+	}
+
+	jpegMagicPath := filepath.Join(root, "jpeg-magic.bin")
+	if err := os.WriteFile(jpegMagicPath, []byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46}, 0o644); err != nil {
+		t.Fatalf("write jpeg magic: %v", err)
+	}
+	if got := DetectImageMimeType(jpegMagicPath); got != "image/jpeg" {
+		t.Fatalf("expected jpeg by magic header, got %q", got)
+	}
+
+	webpMagicPath := filepath.Join(root, "webp-magic.bin")
+	webpMagic := append([]byte("RIFF"), []byte{0, 0, 0, 0}...)
+	webpMagic = append(webpMagic, []byte("WEBP")...)
+	if err := os.WriteFile(webpMagicPath, webpMagic, 0o644); err != nil {
+		t.Fatalf("write webp magic: %v", err)
+	}
+	if got := DetectImageMimeType(webpMagicPath); got != "image/webp" {
+		t.Fatalf("expected webp by magic header, got %q", got)
+	}
+
+	missingPath := filepath.Join(root, "missing.unknown")
+	if got := DetectImageMimeType(missingPath); got != "" {
+		t.Fatalf("expected empty mime for missing unknown file, got %q", got)
+	}
+}
+
+func TestReadMagicHeaderErrorsAndShortRead(t *testing.T) {
+	root := t.TempDir()
+	path := filepath.Join(root, "short.bin")
+	if err := os.WriteFile(path, []byte{1, 2, 3}, 0o644); err != nil {
+		t.Fatalf("write short file: %v", err)
+	}
+	buf, err := readMagicHeader(path, 8)
+	if err != nil {
+		t.Fatalf("readMagicHeader(short) error = %v", err)
+	}
+	if len(buf) != 3 {
+		t.Fatalf("expected short read length 3, got %d", len(buf))
+	}
+	if _, err := readMagicHeader(filepath.Join(root, "missing.bin"), 8); err == nil {
+		t.Fatalf("expected missing file error")
+	}
+}
diff --git a/internal/tui/state/state_test.go b/internal/tui/state/state_test.go
index 599ddfe1..e4038614 100644
--- a/internal/tui/state/state_test.go
+++ b/internal/tui/state/state_test.go
@@ -6,12 +6,13 @@ func TestPanelAndPickerConstants(t *testing.T) {
 	if PanelSessions != 0 || PanelTranscript != 1 || PanelActivity != 2 || PanelInput != 3 {
 		t.Fatalf("unexpected panel constants: %d %d %d %d", PanelSessions, PanelTranscript, PanelActivity, PanelInput)
 	}
-	if PickerNone != 0 || PickerProvider != 1 || PickerModel != 2 || PickerFile != 3 || PickerHelp != 4 {
+	if PickerNone != 0 || PickerProvider != 1 || PickerModel != 2 || PickerSession != 3 || PickerFile != 4 || PickerHelp != 5 {
 		t.Fatalf(
-			"unexpected picker constants: %d %d %d %d %d",
+			"unexpected picker constants: %d %d %d %d %d %d",
 			PickerNone,
 			PickerProvider,
 			PickerModel,
+			PickerSession,
 			PickerFile,
 			PickerHelp,
 		)
diff --git a/internal/tui/state/ui_state.go b/internal/tui/state/ui_state.go
index 706b99dc..aea6d120 100644
--- a/internal/tui/state/ui_state.go
+++ b/internal/tui/state/ui_state.go
@@ -19,6 +19,7 @@ const (
 	PickerNone PickerMode = iota
 	PickerProvider
 	PickerModel
+	PickerSession
 	PickerFile
 	PickerHelp
 )