diff --git a/internal/config/loader.go b/internal/config/loader.go index d3e40a41..14cf26cb 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -38,10 +38,12 @@ type persistedContextConfig struct { } type persistedCompactConfig struct { - ManualStrategy string `yaml:"manual_strategy,omitempty"` - ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` - MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` - MicroCompactDisabled bool `yaml:"micro_compact_disabled,omitempty"` + ManualStrategy string `yaml:"manual_strategy,omitempty"` + ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` + MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` + MicroCompactDisabled bool `yaml:"micro_compact_disabled,omitempty"` + MicroCompactRetainedToolSpans int `yaml:"micro_compact_retained_tool_spans,omitempty"` + MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` } type persistedAutoCompactConfig struct { @@ -237,10 +239,12 @@ func marshalPersistedConfig(snapshot Config) ([]byte, error) { func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { return persistedContextConfig{ Compact: persistedCompactConfig{ - ManualStrategy: cfg.Compact.ManualStrategy, - ManualKeepRecentMessages: cfg.Compact.ManualKeepRecentMessages, - MaxSummaryChars: cfg.Compact.MaxSummaryChars, - MicroCompactDisabled: cfg.Compact.MicroCompactDisabled, + ManualStrategy: cfg.Compact.ManualStrategy, + ManualKeepRecentMessages: cfg.Compact.ManualKeepRecentMessages, + MaxSummaryChars: cfg.Compact.MaxSummaryChars, + MicroCompactDisabled: cfg.Compact.MicroCompactDisabled, + MicroCompactRetainedToolSpans: cfg.Compact.MicroCompactRetainedToolSpans, + MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, AutoCompact: persistedAutoCompactConfig{ Enabled: cfg.AutoCompact.Enabled, @@ -253,10 +257,12 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { func fromPersistedContextConfig(file persistedContextConfig, defaults ContextConfig) ContextConfig { out := ContextConfig{ Compact: CompactConfig{ - ManualStrategy: strings.TrimSpace(file.Compact.ManualStrategy), - ManualKeepRecentMessages: file.Compact.ManualKeepRecentMessages, - MaxSummaryChars: file.Compact.MaxSummaryChars, - MicroCompactDisabled: file.Compact.MicroCompactDisabled, + ManualStrategy: strings.TrimSpace(file.Compact.ManualStrategy), + ManualKeepRecentMessages: file.Compact.ManualKeepRecentMessages, + MaxSummaryChars: file.Compact.MaxSummaryChars, + MicroCompactDisabled: file.Compact.MicroCompactDisabled, + MicroCompactRetainedToolSpans: file.Compact.MicroCompactRetainedToolSpans, + MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, AutoCompact: AutoCompactConfig{ Enabled: file.AutoCompact.Enabled, @@ -288,24 +294,26 @@ func assembleProviders(builtin []ProviderConfig, custom []ProviderConfig) ([]Pro return nil } - for _, provider := range builtin { - candidate := cloneProviderConfig(provider) - if candidate.Source == "" { - candidate.Source = ProviderSourceBuiltin - } - if err := appendProvider(candidate); err != nil { - return nil, err - } + sections := []struct { + providers []ProviderConfig + source ProviderSource + }{ + {providers: builtin, source: ProviderSourceBuiltin}, + {providers: custom, source: ProviderSourceCustom}, } - for _, provider := range custom { - candidate := cloneProviderConfig(provider) - if candidate.Source == "" { - candidate.Source = ProviderSourceCustom - } - if err := appendProvider(candidate); err != nil { - return nil, err + + for _, section := range sections { + for _, provider := range section.providers { + candidate := cloneProviderConfig(provider) + if candidate.Source == "" { + candidate.Source = section.source + } + if err := appendProvider(candidate); err != nil { + return nil, err + } } } + return assembled, nil } diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 7d0224e1..86c87754 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -8,6 +8,20 @@ import ( "testing" ) +func writeLoaderConfig(t *testing.T, loader *Loader, raw string) { + t.Helper() + if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { + t.Fatalf("mkdir base dir: %v", err) + } + content := raw + if strings.Contains(raw, "\n") { + content = strings.TrimSpace(raw) + "\n" + } + if err := os.WriteFile(loader.ConfigPath(), []byte(content), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } +} + func TestLoaderLoadMissingConfigCreatesDefault(t *testing.T) { t.Parallel() @@ -32,12 +46,7 @@ func TestLoaderLoadMalformedYAML(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } - if err := os.WriteFile(loader.ConfigPath(), []byte("providers:\n - name: [\n"), 0o644); err != nil { - t.Fatalf("write malformed config: %v", err) - } + writeLoaderConfig(t, loader, "providers:\n - name: [\n") _, err := loader.Load(context.Background()) if err == nil || !strings.Contains(err.Error(), "parse config file") { @@ -49,18 +58,13 @@ func TestLoaderRejectsLegacyWorkdirKey(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: openai current_model: gpt-4.1 workdir: . shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write legacy config: %v", err) - } + writeLoaderConfig(t, loader, raw) _, err := loader.Load(context.Background()) if err == nil || !strings.Contains(err.Error(), "field workdir not found") { @@ -72,18 +76,13 @@ func TestLoaderRejectsLegacyDefaultWorkdirKey(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: openai current_model: gpt-4.1 default_workdir: . shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write legacy config: %v", err) - } + writeLoaderConfig(t, loader, raw) _, err := loader.Load(context.Background()) if err == nil || !strings.Contains(err.Error(), "field default_workdir not found") { @@ -111,9 +110,6 @@ func TestLoaderRejectsLegacyProvidersFormatOnLoad(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } legacy := ` selected_provider: openai @@ -126,9 +122,7 @@ providers: model: gpt-5.4 api_key_env: OPENAI_API_KEY ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(legacy)+"\n"), 0o644); err != nil { - t.Fatalf("write legacy config: %v", err) - } + writeLoaderConfig(t, loader, legacy) _, err := loader.Load(context.Background()) if err == nil || !strings.Contains(err.Error(), "field providers not found") { @@ -140,17 +134,12 @@ func TestLoaderPreservesSelectionStateOnLoad(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: missing-provider shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, raw) cfg, err := loader.Load(context.Background()) if err != nil { @@ -177,17 +166,12 @@ func TestLoaderPreservesMissingCurrentModelOnLoad(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: openai shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, raw) cfg, err := loader.Load(context.Background()) if err != nil { @@ -223,9 +207,7 @@ func TestLoaderAllowsSelectedCustomProviderWithEmptyCurrentModel(t *testing.T) { selected_provider: company-gateway shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(rawConfig)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, rawConfig) providerYAML := ` name: company-gateway @@ -263,9 +245,7 @@ selected_provider: company-gateway current_model: deepseek-coder shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(rawConfig)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, rawConfig) providerYAML := ` name: company-gateway @@ -523,9 +503,7 @@ selected_provider: company-gateway current_model: server-model shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(rawConfig)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, rawConfig) providerYAML := ` name: company-gateway @@ -582,9 +560,7 @@ selected_provider: company-gateway current_model: server-model shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(rawConfig)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, rawConfig) providerYAML := ` name: company-gateway @@ -764,9 +740,6 @@ func TestLoaderMemoConfigPreservesExplicitFalse(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: openai @@ -777,9 +750,7 @@ memo: auto_extract: false max_index_lines: 123 ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, raw) cfg, err := loader.Load(context.Background()) if err != nil { @@ -812,18 +783,13 @@ func TestLoaderMemoConfigAppliesDefaultsWhenSectionMissing(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) - if err := os.MkdirAll(loader.BaseDir(), 0o755); err != nil { - t.Fatalf("mkdir base dir: %v", err) - } raw := ` selected_provider: openai current_model: gpt-4.1 shell: powershell ` - if err := os.WriteFile(loader.ConfigPath(), []byte(strings.TrimSpace(raw)+"\n"), 0o644); err != nil { - t.Fatalf("write config: %v", err) - } + writeLoaderConfig(t, loader, raw) cfg, err := loader.Load(context.Background()) if err != nil { @@ -839,3 +805,70 @@ shell: powershell t.Fatalf("expected memo.max_index_lines to be defaulted, got %d", cfg.Memo.MaxIndexLines) } } + +func TestLoaderLoadsCompactExtendedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +context: + compact: + manual_strategy: keep_recent + manual_keep_recent_messages: 9 + max_summary_chars: 900 + micro_compact_retained_tool_spans: 4 + max_archived_prompt_chars: 4096 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.Context.Compact.MicroCompactRetainedToolSpans != 4 { + t.Fatalf("expected micro_compact_retained_tool_spans=4, got %d", cfg.Context.Compact.MicroCompactRetainedToolSpans) + } + if cfg.Context.Compact.MaxArchivedPromptChars != 4096 { + t.Fatalf("expected max_archived_prompt_chars=4096, got %d", cfg.Context.Compact.MaxArchivedPromptChars) + } +} + +func TestLoaderSaveRoundTripsCompactExtendedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + cfg := loader.DefaultConfig() + cfg.Context.Compact.MicroCompactRetainedToolSpans = 5 + cfg.Context.Compact.MaxArchivedPromptChars = 3072 + + if err := loader.Save(context.Background(), &cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("read config: %v", err) + } + text := string(data) + if !strings.Contains(text, "micro_compact_retained_tool_spans: 5") { + t.Fatalf("expected persisted micro_compact_retained_tool_spans, got:\n%s", text) + } + if !strings.Contains(text, "max_archived_prompt_chars: 3072") { + t.Fatalf("expected persisted max_archived_prompt_chars, got:\n%s", text) + } + + loaded, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Context.Compact.MicroCompactRetainedToolSpans != 5 { + t.Fatalf("expected round-trip micro_compact_retained_tool_spans=5, got %d", loaded.Context.Compact.MicroCompactRetainedToolSpans) + } + if loaded.Context.Compact.MaxArchivedPromptChars != 3072 { + t.Fatalf("expected round-trip max_archived_prompt_chars=3072, got %d", loaded.Context.Compact.MaxArchivedPromptChars) + } +} diff --git a/internal/runtime/compact_generator.go b/internal/runtime/compact_generator.go index 8d63ea04..570dd276 100644 --- a/internal/runtime/compact_generator.go +++ b/internal/runtime/compact_generator.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "strings" @@ -115,17 +116,42 @@ func parseCompactSummaryOutput(content string) (contextcompact.SummaryOutput, er } task := raw.TaskState + progress, err := coerceStringArray("progress", task.Progress) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + openItems, err := coerceStringArray("open_items", task.OpenItems) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + blockers, err := coerceStringArray("blockers", task.Blockers) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + keyArtifacts, err := coerceStringArray("key_artifacts", task.KeyArtifacts) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + decisions, err := coerceStringArray("decisions", task.Decisions) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + userConstraints, err := coerceStringArray("user_constraints", task.UserConstraints) + if err != nil { + return contextcompact.SummaryOutput{}, err + } + output := contextcompact.SummaryOutput{ DisplaySummary: strings.TrimSpace(raw.DisplaySummary), TaskState: agentsession.TaskState{ Goal: task.Goal, - Progress: coerceStringArray(task.Progress), - OpenItems: coerceStringArray(task.OpenItems), + Progress: progress, + OpenItems: openItems, NextStep: task.NextStep, - Blockers: coerceStringArray(task.Blockers), - KeyArtifacts: coerceStringArray(task.KeyArtifacts), - Decisions: coerceStringArray(task.Decisions), - UserConstraints: coerceStringArray(task.UserConstraints), + Blockers: blockers, + KeyArtifacts: keyArtifacts, + Decisions: decisions, + UserConstraints: userConstraints, }, } @@ -151,29 +177,33 @@ func decodeCompactSummaryResponse(jsonText string) (tolerantSummaryResponse, err } // coerceStringArray 尝试将 json.RawMessage 解析为 []string,容忍单个 string 值。 -func coerceStringArray(raw json.RawMessage) []string { +func coerceStringArray(fieldName string, raw json.RawMessage) ([]string, error) { if len(raw) == 0 { - return nil + return nil, nil } // 根据首字节判断 JSON 类型,避免双重 Unmarshal switch raw[0] { case '[': var arr []string - if err := json.Unmarshal(raw, &arr); err == nil { - return arr + if err := json.Unmarshal(raw, &arr); err != nil { + return nil, fmt.Errorf("runtime: compact summary task_state.%s must be string array: %w", fieldName, err) } + return arr, nil case '"': var s string - if err := json.Unmarshal(raw, &s); err == nil { - trimmed := strings.TrimSpace(s) - if trimmed != "" { - return []string{trimmed} - } + if err := json.Unmarshal(raw, &s); err != nil { + return nil, fmt.Errorf("runtime: compact summary task_state.%s must be string: %w", fieldName, err) + } + trimmed := strings.TrimSpace(s) + if trimmed != "" { + return []string{trimmed}, nil } + return nil, nil + case 'n': + return nil, nil } - // null、数字、布尔、对象等均返回 nil - return nil + return nil, fmt.Errorf("runtime: compact summary task_state.%s must be string or string array", fieldName) } // extractJSONObject 从模型响应中提取首个满足 compact 协议的 JSON 对象,容忍前后噪音。 diff --git a/internal/runtime/compact_generator_test.go b/internal/runtime/compact_generator_test.go index b59c9e0e..2472dccf 100644 --- a/internal/runtime/compact_generator_test.go +++ b/internal/runtime/compact_generator_test.go @@ -268,16 +268,16 @@ func TestParseCompactSummaryOutputToleratesStringInsteadOfArray(t *testing.T) { wantOK: true, }, { - name: "数字代替数组产生nil", + name: "数字代替数组报错", json: `{"task_state":{"goal":"g","progress":42,"open_items":[],"next_step":"n","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"summary"}`, want: nil, - wantOK: true, + wantOK: false, }, { - name: "嵌套对象代替数组产生nil", + name: "嵌套对象代替数组报错", json: `{"task_state":{"goal":"g","progress":{"nested":true},"open_items":[],"next_step":"n","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"summary"}`, want: nil, - wantOK: true, + wantOK: false, }, } @@ -307,9 +307,10 @@ func TestCoerceStringArray(t *testing.T) { t.Parallel() tests := []struct { - name string - raw string - want []string + name string + raw string + want []string + wantErr bool }{ { name: "正常字符串数组", @@ -332,19 +333,22 @@ func TestCoerceStringArray(t *testing.T) { want: nil, }, { - name: "数字返回nil", - raw: `42`, - want: nil, + name: "数字返回nil", + raw: `42`, + want: nil, + wantErr: true, }, { - name: "布尔返回nil", - raw: `true`, - want: nil, + name: "布尔返回nil", + raw: `true`, + want: nil, + wantErr: true, }, { - name: "嵌套对象返回nil", - raw: `{"key":"val"}`, - want: nil, + name: "嵌套对象返回nil", + raw: `{"key":"val"}`, + want: nil, + wantErr: true, }, { name: "空RawMessage返回nil", @@ -356,7 +360,13 @@ func TestCoerceStringArray(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := coerceStringArray(json.RawMessage(tt.raw)) + got, err := coerceStringArray("progress", json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Fatalf("coerceStringArray(%q) error = %v, wantErr %v", tt.raw, err, tt.wantErr) + } + if tt.wantErr { + return + } if len(got) != len(tt.want) { t.Fatalf("coerceStringArray(%q) = %v, want %v", tt.raw, got, tt.want) }