From b2f30b478f3d2655d3e95579ca585d26629f063c Mon Sep 17 00:00:00 2001 From: xgopilot Date: Mon, 13 Apr 2026 14:54:34 +0000 Subject: [PATCH] fix(session,runtime): tighten summary task_state validation and linearize compact JSON extraction Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Yumiue <188874804+Yumiue@users.noreply.github.com> --- internal/runtime/compact_generator.go | 52 +++++------ internal/runtime/compact_generator_test.go | 102 ++++++++++----------- internal/session/store.go | 20 +++- internal/session/store_test.go | 66 ++++++++++--- 4 files changed, 144 insertions(+), 96 deletions(-) diff --git a/internal/runtime/compact_generator.go b/internal/runtime/compact_generator.go index 321cc7b0..24db3b24 100644 --- a/internal/runtime/compact_generator.go +++ b/internal/runtime/compact_generator.go @@ -151,35 +151,13 @@ func cloneStringSlice(items []string) []string { // extractJSONObject 从模型响应中提取首个满足 compact 协议的 JSON 对象,容忍前后噪音。 func extractJSONObject(text string) (string, error) { - start := strings.IndexByte(text, '{') - if start < 0 { - return "", errors.New("runtime: compact summary response does not contain a JSON object") - } - - for { - candidate, err := extractJSONObjectCandidate(text, start) - if err == nil { - if _, decodeErr := decodeCompactSummaryResponse(candidate); decodeErr == nil { - return candidate, nil - } - } - - next := strings.IndexByte(text[start+1:], '{') - if next < 0 { - break - } - start += next + 1 - } - - return "", errors.New("runtime: compact summary response does not contain a valid compact JSON object") -} - -// extractJSONObjectCandidate 从给定起点抽取平衡的 JSON 对象片段。 -func extractJSONObjectCandidate(text string, start int) (string, error) { depth := 0 inString := false escaped := false - for index := start; index < len(text); index++ { + start := -1 + seenObjectStart := false + + for index := 0; index < len(text); index++ { ch := text[index] if inString { if escaped { @@ -199,14 +177,30 @@ func extractJSONObjectCandidate(text string, start int) (string, error) { case '"': inString = true case '{': + seenObjectStart = true + if depth == 0 { + start = index + } depth++ case '}': - depth-- if depth == 0 { - return strings.TrimSpace(text[start : index+1]), nil + continue + } + depth-- + if depth == 0 && start >= 0 { + candidate := strings.TrimSpace(text[start : index+1]) + if _, err := decodeCompactSummaryResponse(candidate); err == nil { + return candidate, nil + } } } } - return "", errors.New("runtime: compact summary response contains an incomplete JSON object") + if !seenObjectStart { + return "", errors.New("runtime: compact summary response does not contain a JSON object") + } + if depth > 0 { + return "", errors.New("runtime: compact summary response contains an incomplete JSON object") + } + return "", errors.New("runtime: compact summary response does not contain a valid compact JSON object") } diff --git a/internal/runtime/compact_generator_test.go b/internal/runtime/compact_generator_test.go index 6c389467..ce48b937 100644 --- a/internal/runtime/compact_generator_test.go +++ b/internal/runtime/compact_generator_test.go @@ -23,19 +23,12 @@ func validCompactSummaryJSON() string { func TestCompactSummaryGeneratorBuildsProviderRequestWithoutTools(t *testing.T) { t.Parallel() - manager := newRuntimeConfigManager(t) - resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) - if err != nil { - t.Fatalf("resolve provider: %v", err) - } - scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ {providertypes.NewTextDeltaStreamEvent(validCompactSummaryJSON())}, }, } - factory := &scriptedProviderFactory{provider: scripted} - generator := newCompactSummaryGenerator(factory, resolvedProvider.ToRuntimeConfig(), "session-model") + generator, factory, manager := newCompactGeneratorTestSetup(t, scripted) summary, err := generator.Generate(context.Background(), contextcompact.SummaryInput{ Mode: contextcompact.ModeManual, @@ -116,12 +109,6 @@ func TestCompactSummaryGeneratorBuildsProviderRequestWithoutTools(t *testing.T) func TestCompactSummaryGeneratorRejectsToolCalls(t *testing.T) { t.Parallel() - manager := newRuntimeConfigManager(t) - resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) - if err != nil { - t.Fatalf("resolve provider: %v", err) - } - scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ { @@ -130,29 +117,21 @@ func TestCompactSummaryGeneratorRejectsToolCalls(t *testing.T) { }, }, } - generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model") + generator, _, manager := newCompactGeneratorTestSetup(t, scripted) - _, err = generator.Generate(context.Background(), contextcompact.SummaryInput{ + _, err := generator.Generate(context.Background(), contextcompact.SummaryInput{ Mode: contextcompact.ModeManual, ArchivedMessages: []providertypes.Message{ {Role: providertypes.RoleUser, Content: "legacy request"}, }, Config: manager.Get().Context.Compact, }) - if err == nil || !strings.Contains(err.Error(), "must not contain tool calls") { - t.Fatalf("expected tool call rejection, got %v", err) - } + requireErrorContains(t, err, "must not contain tool calls") } func TestCompactSummaryGeneratorRejectsMalformedStreamEvent(t *testing.T) { t.Parallel() - manager := newRuntimeConfigManager(t) - resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) - if err != nil { - t.Fatalf("resolve provider: %v", err) - } - scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ { @@ -160,26 +139,18 @@ func TestCompactSummaryGeneratorRejectsMalformedStreamEvent(t *testing.T) { }, }, } - generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model") + generator, _, manager := newCompactGeneratorTestSetup(t, scripted) - _, err = generator.Generate(context.Background(), contextcompact.SummaryInput{ + _, err := generator.Generate(context.Background(), contextcompact.SummaryInput{ Mode: contextcompact.ModeManual, Config: manager.Get().Context.Compact, }) - if err == nil || !strings.Contains(err.Error(), "text_delta event payload is nil") { - t.Fatalf("expected malformed stream event rejection, got %v", err) - } + requireErrorContains(t, err, "text_delta event payload is nil") } func TestCompactSummaryGeneratorRejectsCompletionWithoutMessageDone(t *testing.T) { t.Parallel() - manager := newRuntimeConfigManager(t) - resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) - if err != nil { - t.Fatalf("resolve provider: %v", err) - } - scripted := &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { select { @@ -190,29 +161,21 @@ func TestCompactSummaryGeneratorRejectsCompletionWithoutMessageDone(t *testing.T return nil }, } - generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model") + generator, _, manager := newCompactGeneratorTestSetup(t, scripted) - _, err = generator.Generate(context.Background(), contextcompact.SummaryInput{ + _, err := generator.Generate(context.Background(), contextcompact.SummaryInput{ Mode: contextcompact.ModeManual, Config: manager.Get().Context.Compact, }) if !errors.Is(err, provider.ErrStreamInterrupted) { t.Fatalf("expected ErrStreamInterrupted, got %v", err) } - if !strings.Contains(err.Error(), "without message_done") { - t.Fatalf("expected missing message_done error, got %v", err) - } + requireErrorContains(t, err, "without message_done") } func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T) { t.Parallel() - manager := newRuntimeConfigManager(t) - resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) - if err != nil { - t.Fatalf("resolve provider: %v", err) - } - stream := []providertypes.StreamEvent{{Type: providertypes.StreamEventTextDelta}} for i := 0; i < 40; i++ { stream = append(stream, providertypes.NewTextDeltaStreamEvent("ignored")) @@ -220,7 +183,7 @@ func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{stream}, } - generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model") + generator, _, manager := newCompactGeneratorTestSetup(t, scripted) errCh := make(chan error, 1) go func() { @@ -233,9 +196,7 @@ func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T select { case genErr := <-errCh: - if genErr == nil || !strings.Contains(genErr.Error(), "text_delta event payload is nil") { - t.Fatalf("expected malformed stream event rejection, got %v", genErr) - } + requireErrorContains(t, genErr, "text_delta event payload is nil") case <-time.After(2 * time.Second): t.Fatal("expected compact generation to fail instead of deadlocking on malformed stream event") } @@ -275,3 +236,42 @@ func TestParseCompactSummaryOutputSkipsNonCompactJSONPreface(t *testing.T) { t.Fatalf("expected parsed goal, got %+v", output.TaskState) } } + +func TestParseCompactSummaryOutputSkipsBraceNoiseAndFindsFirstValidObject(t *testing.T) { + t.Parallel() + + noise := strings.Repeat("{not-json}", 80) + content := strings.Join([]string{ + noise, + `{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"[compact_summary]\nok"}`, + }, "\n") + + output, err := parseCompactSummaryOutput(content) + if err != nil { + t.Fatalf("expected parser to recover valid compact payload after brace noise, got %v", err) + } + if output.TaskState.Goal != "g" { + t.Fatalf("expected parsed goal, got %+v", output.TaskState) + } +} + +func newCompactGeneratorTestSetup(t *testing.T, scripted *scriptedProvider) (contextcompact.SummaryGenerator, *scriptedProviderFactory, *config.Manager) { + t.Helper() + + manager := newRuntimeConfigManager(t) + resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName) + if err != nil { + t.Fatalf("resolve provider: %v", err) + } + + factory := &scriptedProviderFactory{provider: scripted} + generator := newCompactSummaryGenerator(factory, resolvedProvider.ToRuntimeConfig(), "session-model") + return generator, factory, manager +} + +func requireErrorContains(t *testing.T, err error, substring string) { + t.Helper() + if err == nil || !strings.Contains(err.Error(), substring) { + t.Fatalf("expected error containing %q, got %v", substring, err) + } +} diff --git a/internal/session/store.go b/internal/session/store.go index a5018040..3ef30fa7 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -86,8 +86,8 @@ func (s *JSONStore) Save(ctx context.Context, session *Session) error { s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.baseDir, 0o755); err != nil { - return fmt.Errorf("session: create sessions dir: %w", err) + if err := s.ensureSessionsDir(); err != nil { + return err } payload, err := json.MarshalIndent(session, "", " ") @@ -141,8 +141,8 @@ func (s *JSONStore) ListSummaries(ctx context.Context) ([]Summary, error) { s.mu.RLock() defer s.mu.RUnlock() - if err := os.MkdirAll(s.baseDir, 0o755); err != nil { - return nil, fmt.Errorf("session: create sessions dir: %w", err) + if err := s.ensureSessionsDir(); err != nil { + return nil, err } entries, err := os.ReadDir(s.baseDir) @@ -189,6 +189,14 @@ func (s *JSONStore) filePath(id string) string { return filepath.Join(s.baseDir, id+".json") } +// ensureSessionsDir 确保会话目录存在,避免重复的 mkdir 原子逻辑。 +func (s *JSONStore) ensureSessionsDir() error { + if err := os.MkdirAll(s.baseDir, 0o755); err != nil { + return fmt.Errorf("session: create sessions dir: %w", err) + } + return nil +} + // New 创建一个默认标题策略的新会话对象。 func New(title string) Session { return NewWithWorkdir(title, "") @@ -308,6 +316,10 @@ func decodeStoredSummary(data []byte) (Summary, error) { if len(stored.TaskState) == 0 { return Summary{}, errors.New("missing required field task_state") } + var taskState *TaskState + if err := json.Unmarshal(stored.TaskState, &taskState); err != nil || taskState == nil { + return Summary{}, errors.New("invalid field task_state") + } if err := validateSessionSchema(Session{SchemaVersion: *stored.SchemaVersion}); err != nil { return Summary{}, err } diff --git a/internal/session/store_test.go b/internal/session/store_test.go index 5a0dc1e6..6f00e92d 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -17,12 +17,7 @@ import ( func TestJSONStoreSaveLoadAndListSummaries(t *testing.T) { t.Parallel() - baseDir := t.TempDir() - workspaceRoot := filepath.Join(t.TempDir(), "workspace") - if err := os.MkdirAll(workspaceRoot, 0o755); err != nil { - t.Fatalf("mkdir workspace root: %v", err) - } - store := NewJSONStore(baseDir, workspaceRoot) + store, baseDir, workspaceRoot := newTempWorkspaceStore(t) older := &Session{ SchemaVersion: CurrentSchemaVersion, @@ -47,12 +42,8 @@ func TestJSONStoreSaveLoadAndListSummaries(t *testing.T) { }, } - if err := store.Save(context.Background(), older); err != nil { - t.Fatalf("Save older session: %v", err) - } - if err := store.Save(context.Background(), newer); err != nil { - t.Fatalf("Save newer session: %v", err) - } + mustSaveSession(t, store, older) + mustSaveSession(t, store, newer) loaded, err := store.Load(context.Background(), older.ID) if err != nil { @@ -679,6 +670,38 @@ func TestDecodeStoredSummaryUsesLightweightMetadataPath(t *testing.T) { } } +func TestDecodeStoredSummaryRejectsNullTaskState(t *testing.T) { + t.Parallel() + + _, err := decodeStoredSummary([]byte(`{ + "schema_version": 1, + "id": "summary-null-task-state", + "title": "Summary Null Task State", + "created_at": "2026-04-13T08:00:00Z", + "updated_at": "2026-04-13T09:00:00Z", + "task_state": null +}`)) + if err == nil || !strings.Contains(err.Error(), "invalid field task_state") { + t.Fatalf("expected null task_state rejection, got %v", err) + } +} + +func TestDecodeStoredSummaryRejectsNonObjectTaskState(t *testing.T) { + t.Parallel() + + _, err := decodeStoredSummary([]byte(`{ + "schema_version": 1, + "id": "summary-invalid-task-state", + "title": "Summary Invalid Task State", + "created_at": "2026-04-13T08:00:00Z", + "updated_at": "2026-04-13T09:00:00Z", + "task_state": [] +}`)) + if err == nil || !strings.Contains(err.Error(), "invalid field task_state") { + t.Fatalf("expected non-object task_state rejection, got %v", err) + } +} + func TestJSONStoreSaveClampsOversizedTaskState(t *testing.T) { t.Parallel() @@ -784,6 +807,25 @@ func buildIndexedSuffix(index int) string { return string([]rune{hi, lo, 'x', 'x'}) } +func newTempWorkspaceStore(t testing.TB) (*JSONStore, string, string) { + t.Helper() + + baseDir := t.TempDir() + workspaceRoot := filepath.Join(t.TempDir(), "workspace") + if err := os.MkdirAll(workspaceRoot, 0o755); err != nil { + t.Fatalf("mkdir workspace root: %v", err) + } + + return NewJSONStore(baseDir, workspaceRoot), baseDir, workspaceRoot +} + +func mustSaveSession(t testing.TB, store *JSONStore, session *Session) { + t.Helper() + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("save session %s: %v", session.ID, err) + } +} + func mustWriteSessionFile(t *testing.T, path string, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {