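diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index e7ef7836..d0c023ab 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -130,3 +130,14 @@ custom providers come from: - Do not pile legacy-field compatibility logic into `config` - Do not mix selection correction into snapshot validation - Do not leak provider/catalog/runtime semantics back into `config` + +## Validation constraints for custom provider `models` + +`~/.neocode/providers//provider.yaml` may declare `models` to fill in model metadata for cases where catalog/discovery cannot supply a complete `ContextWindow` or `MaxOutputTokens`. + +The constraints on this capability are: + +- `models[].id` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if provided explicitly, must be greater than `0`. +- A duplicate model `id` fails immediately when the custom provider is loaded; the lenient silently-drop behavior is not kept. +- This metadata is never written back to `config.yaml`; it is declared only in the custom provider file and takes part in runtime resolution through the existing catalog merge path. diff --git a/docs/context-compact.md b/docs/context-compact.md index 87fded85..86c20a63 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -1,5 +1,11 @@ # Context Compact +## Auto Compact Failure Fallback + +- When `context.auto_compact.input_token_threshold <= 0`, the system tries to derive the threshold automatically from the current model's `ContextWindow`. +- If the current provider selection is invalid, the catalog snapshot query fails, or the model's window metadata is missing, the system falls back directly to `fallback_input_token_threshold`. +- A failed derivation does not silently disable auto compact; the runtime still receives a usable fallback threshold. + This document describes the configuration, execution path, and summary conventions of context compact in NeoCode. ## Overview @@ -23,7 +29,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` - `manual_strategy` @@ -39,7 +47,7 @@ context: - `auto_compact.enabled` Controls whether token-threshold-based automatic compaction is enabled; off by default. - `auto_compact.input_token_threshold` - Triggers automatic compaction when the session's cumulative input token count reaches this threshold; default 100000. + Triggers automatic compaction when the session's cumulative input token count reaches this threshold; default `0` (automatic derivation), falling back to `fallback_input_token_threshold` (default `100000`) when derivation fails. ## Automatic compaction @@ -154,3 +162,12 @@ compact-related runtime events include: - `trigger_mode` - `transcript_id` - `transcript_path` + +## Auto Compact threshold resolution + +- When `context.auto_compact.input_token_threshold > 0`, the explicit manual threshold is used directly. +- 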
`context.auto_compact.input_token_threshold <= 0` makes the system run automatic derivation for the currently selected provider/model. +- The derivation formula is `resolved_threshold = context_window - reserve_tokens`. +- `reserve_tokens` defaults to `13000` and reserves a buffer for output, tool calls, and the system prompt. +- If the current model has no usable `ContextWindow`, or the window value is less than or equal to `reserve_tokens`, the system falls back to `fallback_input_token_threshold`. +- `fallback_input_token_threshold` defaults to `100000`, which keeps the main path running stably when model window metadata is missing. diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index 95d94b01..22d32636 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -270,3 +270,34 @@ func DefaultProviders() []ProviderConfig { ``` All built-in providers are registered centrally in code. The candidate models shown in the model selector are assembled from the default model, dynamic discovery results, and the local cache. + +## Filling in model metadata for a custom provider + +For a custom provider that reuses the `openaicompat` driver, if the upstream `GET /models` cannot return reliable context window information, you can edit: + +```text +~/.neocode/providers//provider.yaml +``` + +and declare `models` explicitly: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + api_style: chat_completions +``` + +The constraints are: + +- `models[].id` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if configured explicitly, must be greater than `0`. +- A model `id` repeated within the same `provider.yaml` raises an error at load time. +- This metadata enters the unified model catalog merge path, and the priority remains "configured model metadata takes precedence over discovery/default". diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 0d8724ca..c03f3209 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -58,7 +58,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` ### Basic fields @@ -81,6 +83,8 @@ context: | `context.compact.micro_compact_disabled` | Whether to turn off micro compact, which is enabled by default | | `context.auto_compact.enabled` | Whether automatic compaction is enabled | | `context.auto_compact.input_token_threshold` 
| Input token threshold for automatic compaction | +| `context.auto_compact.reserve_tokens` | Token buffer reserved during automatic threshold derivation (`resolved_threshold = context_window - reserve_tokens`) | +| `context.auto_compact.fallback_input_token_threshold` | Fallback threshold used when automatic derivation fails | ### `runtime` fields @@ -142,6 +146,13 @@ openai_compatible: api_style: chat_completions ``` +## Auto Compact failure and validation notes + +- When `context.auto_compact.input_token_threshold <= 0`, if the current provider selection is invalid, the catalog snapshot query fails, or the model lacks a usable `ContextWindow`, the system falls back to `fallback_input_token_threshold` rather than silently disabling auto compact. +- `models[].id` in `~/.neocode/providers//provider.yaml` must be non-empty. +- `models[].context_window` and `models[].max_output_tokens`, if configured explicitly, must be greater than `0`. +- A duplicate model `id` in `models` raises an error when `provider.yaml` is loaded. + File path: ```text @@ -230,3 +241,26 @@ config: environment variable OPENAI_API_KEY is empty - [Adding a Provider](./adding-providers.md) - [Configuration Management Detailed Design](../config-management-detail-design.md) - [Context Compact](../context-compact.md) + +## Additional notes on Auto Compact + +- When `context.auto_compact.input_token_threshold > 0`, the system uses that explicit threshold directly. +- When `context.auto_compact.input_token_threshold <= 0`, the system derives the input threshold automatically from the `ContextWindow` of the current `current_model`. +- The derivation formula is `context_window - reserve_tokens`. +- `reserve_tokens` defaults to `13000`. +- If the current provider/model has no usable `ContextWindow` metadata, the system falls back to `fallback_input_token_threshold`. +- A custom provider can fill in model metadata through the `models` field in `~/.neocode/providers//provider.yaml`, for example: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + api_style: chat_completions +``` diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index f5cd22c4..0f64a269 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -168,6 +168,15 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er contextBuilder, ) 
runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir())) + runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolution, err := configstate.ResolveAutoCompactThreshold(ctx, cfg, modelCatalogs) + if err != nil { + return 0, err + } + return resolution.Threshold, nil + }, + )) // Inject the memory extraction hook: when AutoExtract is enabled and memoSvc is available, extract memories asynchronously after the ReAct loop completes. if memoSvc != nil && cfg.Memo.AutoExtract { @@ -306,3 +315,9 @@ type textGenAdapter func(ctx context.Context, prompt string, msgs []providertype func (f textGenAdapter) Generate(ctx context.Context, prompt string, msgs []providertypes.Message) (string, error) { return f(ctx, prompt, msgs) } + +type runtimeAutoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) + +func (f runtimeAutoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + return f(ctx, cfg) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 03ec8197..0a55a580 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1242,6 +1242,14 @@ func TestAutoCompactConfigDefaults(t *testing.T) { t.Fatalf("expected input_token_threshold=%d, got %d", DefaultAutoCompactInputTokenThreshold, cfg.Context.AutoCompact.InputTokenThreshold) } + if cfg.Context.AutoCompact.ReserveTokens != DefaultAutoCompactReserveTokens { + t.Fatalf("expected reserve_tokens=%d, got %d", + DefaultAutoCompactReserveTokens, cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != DefaultAutoCompactFallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d, got %d", + DefaultAutoCompactFallbackInputTokenThreshold, cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } if cfg.Context.AutoCompact.Enabled != false { t.Fatalf("expected enabled=false, got %v", 
cfg.Context.AutoCompact.Enabled) @@ -1253,13 +1261,20 @@ func TestAutoCompactConfigApplyDefaults(t *testing.T) { cfg := AutoCompactConfig{} defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) - if cfg.InputTokenThreshold != 50000 { - t.Fatalf("expected threshold=50000, got %d", cfg.InputTokenThreshold) + if cfg.InputTokenThreshold != 0 { + t.Fatalf("expected threshold to remain implicit 0, got %d", cfg.InputTokenThreshold) + } + if cfg.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.FallbackInputTokenThreshold) } } @@ -1267,10 +1282,14 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - InputTokenThreshold: 200000, + InputTokenThreshold: 200000, + ReserveTokens: 5000, + FallbackInputTokenThreshold: 80000, } defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) @@ -1278,13 +1297,19 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { if cfg.InputTokenThreshold != 200000 { t.Fatalf("expected explicit threshold=200000 to be preserved, got %d", cfg.InputTokenThreshold) } + if cfg.ReserveTokens != 5000 { + t.Fatalf("expected explicit reserve_tokens=5000 to be preserved, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 80000 { + t.Fatalf("expected explicit fallback_input_token_threshold=80000 to be preserved, got %d", cfg.FallbackInputTokenThreshold) + } } func TestAutoCompactConfigApplyDefaultsNilReceiver(t *testing.T) { t.Parallel() var cfg *AutoCompactConfig - cfg.ApplyDefaults(AutoCompactConfig{InputTokenThreshold: 50000}) + 
cfg.ApplyDefaults(AutoCompactConfig{ReserveTokens: 13000, FallbackInputTokenThreshold: 100000}) } func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { @@ -1293,7 +1318,8 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { cfg := ContextConfig{} cfg.ApplyDefaults(ContextConfig{ AutoCompact: AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1303,8 +1329,14 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { }, }) - if cfg.AutoCompact.InputTokenThreshold != 50000 { - t.Fatalf("expected auto compact threshold=50000, got %d", cfg.AutoCompact.InputTokenThreshold) + if cfg.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected auto compact threshold to remain implicit 0, got %d", cfg.AutoCompact.InputTokenThreshold) + } + if cfg.AutoCompact.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.AutoCompact.ReserveTokens) + } + if cfg.AutoCompact.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.AutoCompact.FallbackInputTokenThreshold) } } @@ -1312,16 +1344,15 @@ func TestAutoCompactConfigValidateEnabledWithoutThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "input_token_threshold") { - t.Fatalf("expected error about input_token_threshold, got %v", err) + if err != nil { + t.Fatalf("expected validation to allow implicit threshold, got %v", err) } } @@ -1329,8 +1360,10 @@ func TestAutoCompactConfigValidateDisabledWithoutThreshold(t *testing.T) { t.Parallel() cfg 
:= AutoCompactConfig{ - Enabled: false, - InputTokenThreshold: 0, + Enabled: false, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 0, } err := cfg.Validate() @@ -1343,8 +1376,10 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, + Enabled: true, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() @@ -1353,12 +1388,46 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { } } +func TestAutoCompactConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 100000, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "reserve_tokens") { + t.Fatalf("expected reserve_tokens validation error, got %v", err) + } +} + +func TestAutoCompactConfigValidateRejectsNonPositiveFallbackThreshold(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 0, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "fallback_input_token_threshold") { + t.Fatalf("expected fallback_input_token_threshold validation error, got %v", err) + } +} + func TestAutoCompactConfigClone(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 75000, + Enabled: true, + InputTokenThreshold: 75000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cloned := cfg.Clone() @@ -1370,6 +1439,13 @@ func TestAutoCompactConfigClone(t *testing.T) { t.Fatalf("expected threshold=%d to be cloned, got %d", cfg.InputTokenThreshold, cloned.InputTokenThreshold) } + if cfg.ReserveTokens != cloned.ReserveTokens { + t.Fatalf("expected reserve_tokens=%d to be 
cloned, got %d", cfg.ReserveTokens, cloned.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != cloned.FallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d to be cloned, got %d", + cfg.FallbackInputTokenThreshold, cloned.FallbackInputTokenThreshold) + } cloned.InputTokenThreshold = 100000 if cfg.InputTokenThreshold == cloned.InputTokenThreshold { @@ -1382,8 +1458,10 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { ctx := ContextConfig{ AutoCompact: AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1394,11 +1472,8 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { } err := ctx.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "auto_compact") { - t.Fatalf("expected error to contain 'auto_compact', got %v", err) + if err != nil { + t.Fatalf("expected context validation to allow implicit threshold, got %v", err) } } diff --git a/internal/config/context.go b/internal/config/context.go index c0c2ffb5..57552b32 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -7,11 +7,13 @@ import ( ) const ( - DefaultCompactManualKeepRecentMessages = 10 - DefaultCompactMaxSummaryChars = 1200 - DefaultAutoCompactInputTokenThreshold = 100000 - DefaultMicroCompactRetainedToolSpans = 2 - DefaultCompactReadTimeMaxMessageSpans = 24 + DefaultCompactManualKeepRecentMessages = 10 + DefaultCompactMaxSummaryChars = 1200 + DefaultAutoCompactInputTokenThreshold = 0 + DefaultAutoCompactReserveTokens = 13000 + DefaultAutoCompactFallbackInputTokenThreshold = 100000 + DefaultMicroCompactRetainedToolSpans = 2 + DefaultCompactReadTimeMaxMessageSpans = 24 CompactManualStrategyKeepRecent = "keep_recent" CompactManualStrategyFullReplace = 
"full_replace" @@ -34,8 +36,10 @@ type CompactConfig struct { // AutoCompactConfig controls automatic context compression triggered by token thresholds. type AutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } // defaultContextConfig 返回上下文压缩相关配置的默认值。 @@ -48,7 +52,9 @@ func defaultContextConfig() ContextConfig { func defaultAutoCompactConfig() AutoCompactConfig { return AutoCompactConfig{ - InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + ReserveTokens: DefaultAutoCompactReserveTokens, + FallbackInputTokenThreshold: DefaultAutoCompactFallbackInputTokenThreshold, } } @@ -119,8 +125,11 @@ func (c *AutoCompactConfig) ApplyDefaults(defaults AutoCompactConfig) { if c == nil { return } - if c.InputTokenThreshold <= 0 { - c.InputTokenThreshold = defaults.InputTokenThreshold + if c.ReserveTokens <= 0 { + c.ReserveTokens = defaults.ReserveTokens + } + if c.FallbackInputTokenThreshold <= 0 { + c.FallbackInputTokenThreshold = defaults.FallbackInputTokenThreshold } } @@ -157,8 +166,14 @@ func (c CompactConfig) Validate() error { // Validate 校验 auto_compact 配置是否合法。 func (c AutoCompactConfig) Validate() error { - if c.Enabled && c.InputTokenThreshold <= 0 { - return errors.New("input_token_threshold must be greater than 0 when enabled") + if !c.Enabled { + return nil + } + if c.ReserveTokens <= 0 { + return errors.New("reserve_tokens must be greater than 0 when enabled") + } + if c.FallbackInputTokenThreshold <= 0 { + return errors.New("fallback_input_token_threshold must be greater than 0 when enabled") } return nil } diff --git a/internal/config/loader.go b/internal/config/loader.go index 
878fd9af..5e35c3a6 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -48,8 +48,10 @@ type persistedCompactConfig struct { } type persistedAutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } type persistedMemoConfig struct { @@ -249,8 +251,10 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, AutoCompact: persistedAutoCompactConfig{ - Enabled: cfg.AutoCompact.Enabled, - InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + Enabled: cfg.AutoCompact.Enabled, + InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + ReserveTokens: cfg.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: cfg.AutoCompact.FallbackInputTokenThreshold, }, } } @@ -268,8 +272,10 @@ func fromPersistedContextConfig(file persistedContextConfig, defaults ContextCon MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, AutoCompact: AutoCompactConfig{ - Enabled: file.AutoCompact.Enabled, - InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + Enabled: file.AutoCompact.Enabled, + InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + ReserveTokens: file.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: file.AutoCompact.FallbackInputTokenThreshold, }, } out.Compact.ApplyDefaults(defaults.Compact) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 86c87754..fb89594e 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -251,6 +251,11 @@ shell: powershell name: company-gateway driver: openaicompat api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: 
DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 openai_compatible: base_url: https://llm.example.com/v1 api_style: chat_completions @@ -290,8 +295,11 @@ openai_compatible: if customProvider.Model != "" { t.Fatalf("expected custom provider default model to be empty, got %q", customProvider.Model) } - if len(customProvider.Models) != 0 { - t.Fatalf("expected custom provider models to come only from remote discovery, got %+v", customProvider.Models) + if len(customProvider.Models) != 1 { + t.Fatalf("expected custom provider model metadata from provider.yaml, got %+v", customProvider.Models) + } + if customProvider.Models[0].ID != "deepseek-coder" || customProvider.Models[0].ContextWindow != 131072 { + t.Fatalf("expected parsed model metadata, got %+v", customProvider.Models[0]) } } @@ -421,6 +429,183 @@ models: } } +func TestLoaderRejectsCustomProviderModelWithoutID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - name: DeepSeek Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty model id rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidContextWindow(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if 
err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + context_window: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected invalid context_window rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidMaxOutputTokens(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + max_output_tokens: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected invalid max_output_tokens rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderDuplicateModelID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: 
%v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + - id: DeepSeek-Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate model id rejection, got %v", err) + } +} + +func TestLoaderParsesAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +context: + auto_compact: + enabled: true + input_token_threshold: 0 + reserve_tokens: 9000 + fallback_input_token_threshold: 88000 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.Context.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected implicit threshold 0, got %d", cfg.Context.AutoCompact.InputTokenThreshold) + } + if cfg.Context.AutoCompact.ReserveTokens != 9000 { + t.Fatalf("expected reserve_tokens=9000, got %d", cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != 88000 { + t.Fatalf("expected fallback_input_token_threshold=88000, got %d", cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } +} + +func TestLoaderSavePersistsAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 9000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + if 
err := loader.Save(context.Background(), &cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("read config: %v", err) + } + text := string(data) + if strings.Contains(text, "input_token_threshold: 100000") { + t.Fatalf("expected implicit threshold to avoid legacy default, got:\n%s", text) + } + if !strings.Contains(text, "reserve_tokens: 9000") { + t.Fatalf("expected reserve_tokens to persist, got:\n%s", text) + } + if !strings.Contains(text, "fallback_input_token_threshold: 88000") { + t.Fatalf("expected fallback_input_token_threshold to persist, got:\n%s", text) + } +} + func TestLoaderRejectsCustomProviderNameConflictingWithBuiltin(t *testing.T) { t.Parallel() diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index f7a2b2ac..43448ea6 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -10,6 +10,7 @@ import ( "gopkg.in/yaml.v3" "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" ) const ( @@ -22,11 +23,19 @@ type customProviderFile struct { Driver string `yaml:"driver"` APIKeyEnv string `yaml:"api_key_env"` BaseURL string `yaml:"base_url,omitempty"` + Models []customProviderModelFile `yaml:"models,omitempty"` OpenAICompatible customOpenAICompatibleFile `yaml:"openai_compatible,omitempty"` Gemini customGeminiProviderFile `yaml:"gemini,omitempty"` Anthropic customAnthropicProviderFile `yaml:"anthropic,omitempty"` } +type customProviderModelFile struct { + ID string `yaml:"id"` + Name string `yaml:"name,omitempty"` + ContextWindow *int `yaml:"context_window,omitempty"` + MaxOutputTokens *int `yaml:"max_output_tokens,omitempty"` +} + type customOpenAICompatibleFile struct { BaseURL string `yaml:"base_url"` APIStyle string `yaml:"api_style,omitempty"` @@ -106,6 +115,11 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { } settings := 
resolveCustomProviderSettings(file) + models, err := customProviderModels(file.Models) + if err != nil { + return ProviderConfig{}, fmt.Errorf("config: custom provider %q: %w", filepath.Base(providerDir), err) + } + cfg := ProviderConfig{ Name: strings.TrimSpace(file.Name), Driver: strings.TrimSpace(file.Driver), @@ -114,6 +128,7 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { APIStyle: settings.APIStyle, DeploymentMode: settings.DeploymentMode, APIVersion: settings.APIVersion, + Models: models, Source: ProviderSourceCustom, } @@ -128,6 +143,48 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { return cfg, nil } +// customProviderModels validates and normalizes the model metadata declared in a custom provider.yaml. +func customProviderModels(models []customProviderModelFile) ([]providertypes.ModelDescriptor, error) { + if len(models) == 0 { + return nil, nil + } + + descriptors := make([]providertypes.ModelDescriptor, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for index, model := range models { + id := strings.TrimSpace(model.ID) + if id == "" { + return nil, fmt.Errorf("models[%d].id is empty", index) + } + + key := provider.NormalizeKey(id) + if _, exists := seen[key]; exists { + return nil, fmt.Errorf("models[%d].id %q is duplicated", index, id) + } + seen[key] = struct{}{} + + descriptor := providertypes.ModelDescriptor{ + ID: id, + Name: strings.TrimSpace(model.Name), + } + if model.ContextWindow != nil { + if *model.ContextWindow <= 0 { + return nil, fmt.Errorf("models[%d].context_window must be greater than 0", index) + } + descriptor.ContextWindow = *model.ContextWindow + } + if model.MaxOutputTokens != nil { + if *model.MaxOutputTokens <= 0 { + return nil, fmt.Errorf("models[%d].max_output_tokens must be greater than 0", index) + } + descriptor.MaxOutputTokens = *model.MaxOutputTokens + } + descriptors = append(descriptors, descriptor) + } + + return providertypes.MergeModelDescriptors(descriptors), nil +} + // 
resolveCustomProviderSettings extracts, based on the driver, only the config fields that actually take effect for the current protocol, so values from other protocol blocks are not picked up by mistake. // Known drivers read base_url only from the protocol block; unknown drivers use the top-level base_url as the sole entry point. func resolveCustomProviderSettings(file customProviderFile) customProviderSettings { diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 22525bb5..f364b7ce 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -194,6 +194,78 @@ func TestCloneProviderConfigModelDescriptorsIndependence(t *testing.T) { } } +func TestCustomProviderModelsParsesSupportedMetadata(t *testing.T) { + t.Parallel() + + contextWindow := 131072 + maxOutputTokens := 8192 + models, err := customProviderModels([]customProviderModelFile{ + { + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: &contextWindow, + MaxOutputTokens: &maxOutputTokens, + }, + }) + if err != nil { + t.Fatalf("customProviderModels() error = %v", err) + } + + if len(models) != 1 { + t.Fatalf("expected one parsed model, got %+v", models) + } + if models[0].ID != "deepseek-coder" || models[0].ContextWindow != 131072 || models[0].MaxOutputTokens != 8192 { + t.Fatalf("unexpected parsed model descriptor: %+v", models[0]) + } +} + +func TestCustomProviderModelsRejectsEmptyID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{{Name: "Missing ID"}}) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty id validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveContextWindow(t *testing.T) { + t.Parallel() + + contextWindow := 0 + _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + ContextWindow: &contextWindow, + }}) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected context_window validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveMaxOutputTokens(t *testing.T) { + t.Parallel() + + maxOutputTokens := 0 
+ _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + MaxOutputTokens: &maxOutputTokens, + }}) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected max_output_tokens validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsDuplicateID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{ + {ID: "deepseek-coder"}, + {ID: " DeepSeek-Coder "}, + }) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate id validation error, got %v", err) + } +} + func TestProviderByNameCaseInsensitive(t *testing.T) { t.Parallel() diff --git a/internal/config/state/auto_compact_threshold.go b/internal/config/state/auto_compact_threshold.go new file mode 100644 index 00000000..490a2d54 --- /dev/null +++ b/internal/config/state/auto_compact_threshold.go @@ -0,0 +1,90 @@ +package state + +import ( + "context" + "strings" + + "neo-code/internal/config" + "neo-code/internal/provider" +) + +// AutoCompactThresholdSource identifies the source that the auto compact threshold was ultimately taken from. +type AutoCompactThresholdSource string + +const ( + AutoCompactThresholdSourceDisabled AutoCompactThresholdSource = "disabled" + AutoCompactThresholdSourceExplicit AutoCompactThresholdSource = "explicit" + AutoCompactThresholdSourceDerived AutoCompactThresholdSource = "derived" + AutoCompactThresholdSourceFallback AutoCompactThresholdSource = "fallback" +) + +// AutoCompactThresholdResolution describes the resolved auto compact threshold for direct consumption by the runtime. +type AutoCompactThresholdResolution struct { + Threshold int + Source AutoCompactThresholdSource + ContextWindow int + ModelID string +} + +// fallbackAutoCompactThresholdResolution builds the fallback threshold result used when automatic derivation fails. +func fallbackAutoCompactThresholdResolution(cfg config.Config) AutoCompactThresholdResolution { + return AutoCompactThresholdResolution{ + Threshold: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + Source: AutoCompactThresholdSourceFallback, + ModelID: 
strings.TrimSpace(cfg.CurrentModel), + } +} + +// ResolveAutoCompactThreshold resolves the final threshold from the currently selected provider/model and the model catalog snapshot. +func ResolveAutoCompactThreshold( + ctx context.Context, + cfg config.Config, + catalogs ModelCatalog, +) (AutoCompactThresholdResolution, error) { + autoCompact := cfg.Context.AutoCompact + if !autoCompact.Enabled { + return AutoCompactThresholdResolution{Source: AutoCompactThresholdSourceDisabled}, nil + } + + if autoCompact.InputTokenThreshold > 0 { + return AutoCompactThresholdResolution{ + Threshold: autoCompact.InputTokenThreshold, + Source: AutoCompactThresholdSourceExplicit, + ModelID: strings.TrimSpace(cfg.CurrentModel), + }, nil + } + + resolution := fallbackAutoCompactThresholdResolution(cfg) + providerCfg, err := selectedProviderConfig(cfg) + if err != nil { + return resolution, nil + } + if catalogs == nil { + return resolution, nil + } + + input, err := catalogInputFromProvider(providerCfg) + if err != nil { + return resolution, nil + } + + models, err := catalogs.ListProviderModelsSnapshot(ctx, input) + if err != nil { + return resolution, nil + } + + modelID := provider.NormalizeKey(cfg.CurrentModel) + for _, model := range models { + if provider.NormalizeKey(model.ID) != modelID { + continue + } + resolution.ContextWindow = model.ContextWindow + if model.ContextWindow > autoCompact.ReserveTokens { + resolution.Threshold = model.ContextWindow - autoCompact.ReserveTokens + resolution.Source = AutoCompactThresholdSourceDerived + } + return resolution, nil + } + + return resolution, nil +} diff --git a/internal/config/state/auto_compact_threshold_test.go b/internal/config/state/auto_compact_threshold_test.go new file mode 100644 index 00000000..24bd9efc --- /dev/null +++ b/internal/config/state/auto_compact_threshold_test.go @@ -0,0 +1,166 @@ +package state + +import ( + "context" + "errors" + "testing" + + configpkg "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" +) + +func
TestResolveAutoCompactThresholdDisabled(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = false + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 0 || resolution.Source != AutoCompactThresholdSourceDisabled { + t.Fatalf("expected disabled resolution, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdExplicitWins(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 42000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 42000 || resolution.Source != AutoCompactThresholdSourceExplicit { + t.Fatalf("expected explicit resolution, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdDerivedFromContextWindow(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.CurrentModel = "deepseek-coder" + cfg.Providers[0].Model = "deepseek-coder" + cfg.Providers[0].Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + ContextWindow: 131072, + }} + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: cfg.Providers[0].Models, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 118072 || resolution.Source != AutoCompactThresholdSourceDerived { + t.Fatalf("expected derived threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenWindowTooSmall(t *testing.T) { + 
t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "small-model" + cfg.Providers[0].Model = "small-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ + ID: "small-model", + ContextWindow: 8000, + }}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected fallback threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenModelMissing(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "missing-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ID: "other-model", ContextWindow: 131072}}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected missing model to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenSelectedProviderInvalid(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.SelectedProvider = "missing-provider" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, 
catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid selection to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenCatalogInputResolutionFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.Providers[0].BaseURL = "" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid catalog input to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenSnapshotLookupFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotErr: errors.New("snapshot failed"), + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected snapshot error to use fallback, got %+v", resolution) + } +} diff --git a/internal/gateway/handlers/wake.go b/internal/gateway/handlers/wake.go index db7d7f40..2d72ce2d 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -97,7 +97,7 @@ func isSafeReviewPath(path string) bool { if hasBlockedWindowsPathPrefix(trimmed) { return 
false } - if filepath.IsAbs(trimmed) || strings.HasPrefix(trimmed, "/") || strings.HasPrefix(trimmed, "\\") { + if isAbsoluteReviewPath(trimmed) { return false } if containsParentTraversalSegment(trimmed) { @@ -113,7 +113,7 @@ func isSafeReviewPath(path string) bool { return true } -// hasWindowsDriveLetterPrefix reports whether the path has a Windows drive-letter prefix (such as C:foo), so platform differences do not let such paths slip through. +// hasWindowsDriveLetterPrefix reports whether the path has a Windows drive-letter prefix, such as C:foo. func hasWindowsDriveLetterPrefix(path string) bool { trimmed := strings.TrimSpace(path) if len(trimmed) < 2 { @@ -123,12 +123,21 @@ func hasWindowsDriveLetterPrefix(path string) bool { return ((drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')) && trimmed[1] == ':' } -// hasBlockedWindowsPathPrefix reports whether the path has a low-level Windows device path prefix, which would bypass the regular path checks. +// hasBlockedWindowsPathPrefix reports whether the path has a Windows device path prefix, which would bypass the regular path checks. func hasBlockedWindowsPathPrefix(path string) bool { normalized := strings.ReplaceAll(strings.TrimSpace(path), "/", "\\") return strings.HasPrefix(normalized, `\\?\`) || strings.HasPrefix(normalized, `\\.\`) } +// isAbsoluteReviewPath detects absolute paths uniformly across platforms, so a Unix-style leading slash is not missed on Windows. +func isAbsoluteReviewPath(path string) bool { + if filepath.IsAbs(path) { + return true + } + normalized := normalizePath(path) + return strings.HasPrefix(normalized, "/") || strings.HasPrefix(normalized, "\\") +} + // containsParentTraversalSegment detects parent-directory segments by path-segment semantics, avoiding the false positives of substring matching. func containsParentTraversalSegment(path string) bool { normalized := normalizePath(path) diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index df826fc7..3488dd3a 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -103,6 +103,34 @@ func TestListProviderModelsMergesConfiguredMetadataAfterDiscovery(t *testing.T) } } +func TestListProviderModelsUsesConfiguredContextWindowWhenDiscoveryMissesIt(t *testing.T) { + t.Setenv(testAPIKeyEnv, "test-key") + + registry := newRegistry(t, openaicompat.DriverName,
func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + return []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: "Server DeepSeek", + ContextWindow: 0, + }}, nil + }) + + service := NewService("", registry, newMemoryStore()) + providerCfg := customGatewayProvider() + providerCfg.Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: 131072, + }} + + models, err := service.ListProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) + if err != nil { + t.Fatalf("ListProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ContextWindow != 131072 { + t.Fatalf("expected configured context window to fill discovery gap, got %+v", models) + } +} + func TestListProviderModelsSnapshotReturnsDefaultAndRefreshesInBackgroundOnMiss(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") diff --git a/internal/runtime/run.go b/internal/runtime/run.go index d698b420..ea3930c5 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -78,32 +78,18 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { initialCfg := s.configManager.Get() sessionID := strings.TrimSpace(input.SessionID) - releaseSessionLock := func() {} + releaseSessionLock := s.bindSessionLock(sessionID) defer func() { releaseSessionLock() }() - if sessionID != "" { - sessionMu, releaseLockRef := s.acquireSessionLock(sessionID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } - } - session, err := s.loadOrCreateSession(ctx, input.SessionID, input.Content, initialCfg.Workdir, input.Workdir) if err != nil { return s.handleRunError(ctx, input.RunID, input.SessionID, err) } if sessionID == "" { - sessionMu, releaseLockRef := s.acquireSessionLock(session.ID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } + releaseSessionLock = s.bindSessionLock(session.ID) 
} state := newRunState(input.RunID, session) @@ -230,7 +216,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - AutoCompactThreshold: autoCompactThreshold(cfg), + AutoCompactThreshold: s.autoCompactThresholdForState(ctx, cfg, state), MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, @@ -267,6 +253,9 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur state.mu.Unlock() limit := resolveNoProgressStreakLimit(cfg.Runtime) +<<<<<<< codex/issue-294-auto-compact-threshold + systemPrompt := withSelfHealingReminder(builtContext.SystemPrompt, streak, limit) +======= repeatLimit := resolveRepeatCycleStreakLimit(cfg.Runtime) systemPrompt := builtContext.SystemPrompt @@ -285,6 +274,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur systemPrompt = trimmed + "\n\n" + selfHealingReminder } } +>>>>>>> main model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ @@ -413,11 +403,39 @@ func (s *Service) applyCompactForState( } // autoCompactThreshold returns the auto compact trigger threshold for the current config. -func autoCompactThreshold(cfg config.Config) int { - if cfg.Context.AutoCompact.Enabled && cfg.Context.AutoCompact.InputTokenThreshold > 0 { +func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) int { + return s.autoCompactThresholdForState(ctx, cfg, nil) +} + +// autoCompactThresholdForState returns the auto compact trigger threshold for the current config and caches the result per key inputs within a single run. +func (s *Service) autoCompactThresholdForState(ctx context.Context, cfg config.Config, state *runState) int { + if !cfg.Context.AutoCompact.Enabled { + return 0 + } + if cfg.Context.AutoCompact.InputTokenThreshold > 0 { return cfg.Context.AutoCompact.InputTokenThreshold } - return 0 + + key := autoCompactCacheKeyFromConfig(cfg) + if state
!= nil && state.autoCompactCache.valid && state.autoCompactCache.key == key { + return state.autoCompactCache.threshold + } + + threshold := fallbackAutoCompactThreshold(cfg) + if s != nil && s.autoCompactThresholdResolver != nil { + resolvedThreshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) + if err == nil && resolvedThreshold > 0 { + threshold = resolvedThreshold + } + } + if state != nil { + state.autoCompactCache = autoCompactThresholdCache{ + key: key, + threshold: threshold, + valid: true, + } + } + return threshold } // degradeKeepRecentMessages gradually reduces the number of retained messages based on the reactive compact attempt count. @@ -430,3 +448,49 @@ func degradeKeepRecentMessages(base int, attempt int) int { } return base } + +// fallbackAutoCompactThreshold returns the fallback threshold that remains usable when automatic derivation fails. +func fallbackAutoCompactThreshold(cfg config.Config) int { + if cfg.Context.AutoCompact.FallbackInputTokenThreshold > 0 { + return cfg.Context.AutoCompact.FallbackInputTokenThreshold + } + return 0 +} + +// bindSessionLock acquires and holds the lock for the given session and returns the matching release function. +func (s *Service) bindSessionLock(sessionID string) func() { + id := strings.TrimSpace(sessionID) + if id == "" { + return func() {} + } + sessionMu, releaseLockRef := s.acquireSessionLock(id) + sessionMu.Lock() + return func() { + sessionMu.Unlock() + releaseLockRef() + } +} + +// withSelfHealingReminder injects the self-healing reminder on the critical no-progress turn, keeping the prompt concatenation rules in one place. +func withSelfHealingReminder(systemPrompt string, streak int, limit int) string { + if streak != limit-1 { + return systemPrompt + } + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + return selfHealingReminder + } + return trimmed + "\n\n" + selfHealingReminder +} + +// autoCompactCacheKeyFromConfig extracts the config dimensions that affect auto compact threshold resolution, used to decide cache hits within a run. +func autoCompactCacheKeyFromConfig(cfg config.Config) autoCompactThresholdCacheKey { + return autoCompactThresholdCacheKey{ + provider: strings.TrimSpace(cfg.SelectedProvider), + model: strings.TrimSpace(cfg.CurrentModel), + autoCompactEnabled: cfg.Context.AutoCompact.Enabled,
+ autoCompactInputThreshold: cfg.Context.AutoCompact.InputTokenThreshold, + autoCompactReserveTokens: cfg.Context.AutoCompact.ReserveTokens, + autoCompactFallback: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + } +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index ca619ffc..ce5546a4 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -60,16 +60,21 @@ type MemoExtractor interface { } // Service is the default runtime implementation, responsible for orchestrating one complete agent run loop. +type AutoCompactThresholdResolver interface { + ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) +} + type Service struct { - configManager *config.Manager - sessionStore agentsession.Store - toolManager tools.Manager - providerFactory ProviderFactory - contextBuilder agentcontext.Builder - compactRunner contextcompact.Runner - approvalBroker *approval.Broker - memoExtractor MemoExtractor - skillsRegistry skills.Registry + configManager *config.Manager + sessionStore agentsession.Store + toolManager tools.Manager + providerFactory ProviderFactory + contextBuilder agentcontext.Builder + compactRunner contextcompact.Runner + approvalBroker *approval.Broker + memoExtractor MemoExtractor + skillsRegistry skills.Registry + autoCompactThresholdResolver AutoCompactThresholdResolver events chan RuntimeEvent sessionMu sync.Mutex @@ -171,3 +176,8 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess } return session, nil } + +// SetAutoCompactThresholdResolver injects the auto compact threshold resolution capability so the runtime does not handle model catalog details directly. +func (s *Service) SetAutoCompactThresholdResolver(resolver AutoCompactThresholdResolver) { + s.autoCompactThresholdResolver = resolver +} diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 39b145ae..c03b8a2a 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -39,6 +39,12 @@ type failingStore struct { ignoreContextErr bool } +type autoCompactThresholdResolverFunc
func(ctx context.Context, cfg config.Config) (int, error) + +func (f autoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + return f(ctx, cfg) +} + func newMemoryStore() *memoryStore { return &memoryStore{sessions: map[string]agentsession.Session{}} } @@ -4186,6 +4192,7 @@ func TestRestoreSessionTokensNewSession(t *testing.T) { func TestAutoCompactThresholdEnabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4195,7 +4202,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 50000 { t.Fatalf("expected threshold == 50000, got %d", threshold) } @@ -4204,6 +4211,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { func TestAutoCompactThresholdDisabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4213,7 +4221,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", threshold) } @@ -4222,6 +4230,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { func TestAutoCompactThresholdZeroValue(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4231,12 +4240,213 @@ func TestAutoCompactThresholdZeroValue(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", threshold) } } +func TestAutoCompactThresholdUsesResolver(t *testing.T) { + t.Parallel() + + 
service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 88000, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected resolver threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverErrors(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, errors.New("resolver failed") + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverReturnsZeroWithoutError(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverReturnsNegativeWithoutError(t *testing.T) { + t.Parallel() + + service := 
&Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return -1, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing.T) { + t.Parallel() + + service := &Service{} + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected implicit mode fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdForStateCachesResolverResultWithinRun(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-hit", newRuntimeSession("session-cache-hit")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 88000 { + t.Fatalf("expected 
cached resolver threshold == 88000, got %d and %d", threshold1, threshold2) + } + if resolveCalls != 1 { + t.Fatalf("expected resolver to be called once, got %d", resolveCalls) + } +} + +func TestAutoCompactThresholdForStateRecomputesWhenCacheKeyChanges(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + if strings.TrimSpace(cfg.CurrentModel) == "gpt-5.1" { + return 99000, nil + } + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-miss", newRuntimeSession("session-cache-miss")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + cfg.CurrentModel = "gpt-5.1" + threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 99000 { + t.Fatalf("expected thresholds [88000, 99000], got [%d, %d]", threshold1, threshold2) + } + if resolveCalls != 2 { + t.Fatalf("expected resolver to be called twice after key change, got %d", resolveCalls) + } +} + func TestTokenUsageRecordedOnMessageDone(t *testing.T) { t.Parallel() diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 9a202343..2c0974fb 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -21,6 +21,7 @@ type runState struct { session agentsession.Session compactApplied bool reactiveCompactAttempts int + autoCompactCache autoCompactThresholdCache rememberedThisRun bool turn int phase controlplane.Phase @@ -101,3 +102,20 @@ type providerTurnResult struct { inputTokens int outputTokens int } + +// 
autoCompactThresholdCache stores the auto compact threshold already resolved for the current run, avoiding repeated resolution on the hot path. +type autoCompactThresholdCache struct { + key autoCompactThresholdCacheKey + threshold int + valid bool +} + +// autoCompactThresholdCacheKey describes the key dimensions of the inputs to auto compact threshold resolution. +type autoCompactThresholdCacheKey struct { + provider string + model string + autoCompactEnabled bool + autoCompactInputThreshold int + autoCompactReserveTokens int + autoCompactFallback int +}
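
Reviewer note: the threshold precedence this change introduces (explicit value, then `context_window - reserve_tokens`, then the fallback) can be sketched in isolation. The snippet below is a minimal standalone illustration, not code from this PR: `sketchConfig`, `thresholdSource`, and `resolveThresholdSketch` are hypothetical names, and it assumes auto compact is already enabled and that the model's context window has been looked up (`0` meaning the metadata is missing).

```go
package main

import "fmt"

// thresholdSource labels where the resolved threshold came from (hypothetical type).
type thresholdSource string

const (
	sourceExplicit thresholdSource = "explicit"
	sourceDerived  thresholdSource = "derived"
	sourceFallback thresholdSource = "fallback"
)

// sketchConfig is a hypothetical, trimmed-down stand-in for the auto compact settings.
type sketchConfig struct {
	InputTokenThreshold         int
	ReserveTokens               int
	FallbackInputTokenThreshold int
}

// resolveThresholdSketch applies the precedence described in the diff:
// an explicit positive threshold wins, otherwise the threshold is derived
// from the context window, otherwise the fallback value is used.
func resolveThresholdSketch(cfg sketchConfig, contextWindow int) (int, thresholdSource) {
	if cfg.InputTokenThreshold > 0 {
		return cfg.InputTokenThreshold, sourceExplicit
	}
	// Derivation only applies when the window leaves room after the reserve.
	if contextWindow > cfg.ReserveTokens {
		return contextWindow - cfg.ReserveTokens, sourceDerived
	}
	return cfg.FallbackInputTokenThreshold, sourceFallback
}

func main() {
	cfg := sketchConfig{ReserveTokens: 13000, FallbackInputTokenThreshold: 100000}
	for _, window := range []int{131072, 8000, 0} {
		threshold, source := resolveThresholdSketch(cfg, window)
		fmt.Printf("window=%d threshold=%d source=%s\n", window, threshold, source)
	}
}
```

With a 131072-token window and a 13000-token reserve this yields 118072, the same value asserted in `TestResolveAutoCompactThresholdDerivedFromContextWindow`; a window of 8000 or a missing window falls through to the fallback.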