diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index a5ba64db..78f24857 100644 --- a/internal/runtime/compact.go +++ b/internal/runtime/compact.go @@ -96,23 +96,31 @@ func (s *Service) runCompactForSession( mode contextcompact.Mode, errorPolicy compactErrorPolicy, ) (agentsession.Session, contextcompact.Result, error) { + failCompact := func(err error) (agentsession.Session, contextcompact.Result, error) { + s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{ + TriggerMode: string(mode), + Message: err.Error(), + }) + if errorPolicy == compactErrorStrict { + return session, contextcompact.Result{}, err + } + return session, contextcompact.Result{}, nil + } + runner := s.compactRunner if runner == nil { var err error runner, err = s.defaultCompactRunner(session, cfg) if err != nil { - s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{ - TriggerMode: string(mode), - Message: err.Error(), - }) - if errorPolicy == compactErrorStrict { - return session, contextcompact.Result{}, err - } - return session, contextcompact.Result{}, nil + return failCompact(err) } } originalMessages := append([]providertypes.Message(nil), session.Messages...) + originalTaskState := session.TaskState.Clone() + originalTokenInputTotal := session.TokenInputTotal + originalTokenOutputTotal := session.TokenOutputTotal + originalUpdatedAt := session.UpdatedAt s.emit(ctx, EventCompactStart, runID, session.ID, string(mode)) result, err := runner.Run(ctx, contextcompact.Input{ @@ -124,14 +132,7 @@ func (s *Service) runCompactForSession( Config: cfg.Context.Compact, }) if err != nil { - s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{ - TriggerMode: string(mode), - Message: err.Error(), - }) - if errorPolicy == compactErrorStrict { - return session, contextcompact.Result{}, err - } - return session, contextcompact.Result{}, nil + return failCompact(err) } if result.Applied { @@ -141,15 +142,12 @@ func (s *Service) runCompactForSession( session.TokenOutputTotal = 0 session.UpdatedAt = time.Now() if err := s.sessionStore.Save(ctx, &session); err != nil { - s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{ - TriggerMode: string(mode), - Message: err.Error(), - }) session.Messages = originalMessages - if errorPolicy == compactErrorStrict { - return session, contextcompact.Result{}, err - } - return session, contextcompact.Result{}, nil + session.TaskState = originalTaskState + session.TokenInputTotal = originalTokenInputTotal + session.TokenOutputTotal = originalTokenOutputTotal + session.UpdatedAt = originalUpdatedAt + return failCompact(err) } } diff --git a/internal/runtime/compact_generator.go b/internal/runtime/compact_generator.go index f002aa38..6ef580ed 100644 --- a/internal/runtime/compact_generator.go +++ b/internal/runtime/compact_generator.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "io" "strings" agentcontext "neo-code/internal/context" @@ -105,8 +106,8 @@ func parseCompactSummaryOutput(content string) (contextcompact.SummaryOutput, er return contextcompact.SummaryOutput{}, err } - var response compactSummaryResponse - if err := json.Unmarshal([]byte(jsonText), &response); err != nil { + response, err := decodeCompactSummaryResponse(jsonText) + if err != nil { return contextcompact.SummaryOutput{}, err } @@ -114,13 +115,13 @@ func parseCompactSummaryOutput(content string) (contextcompact.SummaryOutput, er DisplaySummary: strings.TrimSpace(response.DisplaySummary), } output.TaskState.Goal = response.TaskState.Goal - output.TaskState.Progress = append([]string(nil), response.TaskState.Progress...) - output.TaskState.OpenItems = append([]string(nil), response.TaskState.OpenItems...) + output.TaskState.Progress = cloneStringSlice(response.TaskState.Progress) + output.TaskState.OpenItems = cloneStringSlice(response.TaskState.OpenItems) output.TaskState.NextStep = response.TaskState.NextStep - output.TaskState.Blockers = append([]string(nil), response.TaskState.Blockers...) - output.TaskState.KeyArtifacts = append([]string(nil), response.TaskState.KeyArtifacts...) - output.TaskState.Decisions = append([]string(nil), response.TaskState.Decisions...) - output.TaskState.UserConstraints = append([]string(nil), response.TaskState.UserConstraints...) + output.TaskState.Blockers = cloneStringSlice(response.TaskState.Blockers) + output.TaskState.KeyArtifacts = cloneStringSlice(response.TaskState.KeyArtifacts) + output.TaskState.Decisions = cloneStringSlice(response.TaskState.Decisions) + output.TaskState.UserConstraints = cloneStringSlice(response.TaskState.UserConstraints) if output.DisplaySummary == "" { return contextcompact.SummaryOutput{}, errors.New("runtime: compact summary response is empty") @@ -128,6 +129,26 @@ func parseCompactSummaryOutput(content string) (contextcompact.SummaryOutput, er return output, nil } +// decodeCompactSummaryResponse 对 compact JSON 响应执行严格解码,拒绝未知字段与尾随垃圾内容。 +func decodeCompactSummaryResponse(jsonText string) (compactSummaryResponse, error) { + decoder := json.NewDecoder(strings.NewReader(jsonText)) + decoder.DisallowUnknownFields() + + var response compactSummaryResponse + if err := decoder.Decode(&response); err != nil { + return compactSummaryResponse{}, err + } + if err := decoder.Decode(&struct{}{}); err != nil && !errors.Is(err, io.EOF) { + return compactSummaryResponse{}, errors.New("runtime: compact summary response contains trailing JSON content") + } + return response, nil +} + +// cloneStringSlice 复制字符串切片,避免结果复用解析对象的底层数组。 +func cloneStringSlice(items []string) []string { + return append([]string(nil), items...) +} + // extractJSONObject 从模型响应中提取最外层 JSON 对象,容忍前后噪音。 func extractJSONObject(text string) (string, error) { start := strings.Index(text, "{") diff --git a/internal/runtime/compact_generator_test.go b/internal/runtime/compact_generator_test.go index ee048646..ac2c908a 100644 --- a/internal/runtime/compact_generator_test.go +++ b/internal/runtime/compact_generator_test.go @@ -240,3 +240,21 @@ func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T t.Fatal("expected compact generation to fail instead of deadlocking on malformed stream event") } } + +func TestParseCompactSummaryOutputRejectsUnknownTopLevelField(t *testing.T) { + t.Parallel() + + content := `{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"[compact_summary]\nok","unexpected":"value"}` + if _, err := parseCompactSummaryOutput(content); err == nil { + t.Fatal("expected unknown top-level field to be rejected") + } +} + +func TestParseCompactSummaryOutputRejectsUnknownTaskStateField(t *testing.T) { + t.Parallel() + + content := `{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[],"extra":"x"},"display_summary":"[compact_summary]\nok"}` + if _, err := parseCompactSummaryOutput(content); err == nil { + t.Fatal("expected unknown task_state field to be rejected") + } +} diff --git a/internal/runtime/runtime_gap_coverage_test.go b/internal/runtime/runtime_gap_coverage_test.go index 692a0f88..cd6d6685 100644 --- a/internal/runtime/runtime_gap_coverage_test.go +++ b/internal/runtime/runtime_gap_coverage_test.go @@ -78,6 +78,11 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) { baseStore := newMemoryStore() session := newRuntimeSession("session-compact-save-error") session.Messages = []providertypes.Message{{Role: providertypes.RoleUser, Content: "before"}} + session.TaskState.Goal = "before-goal" + session.TokenInputTotal = 11 + session.TokenOutputTotal = 22 + originalUpdatedAt := time.Unix(1700000000, 0) + session.UpdatedAt = originalUpdatedAt baseStore.sessions[session.ID] = cloneSession(session) store := &failingStore{Store: baseStore, saveErr: errors.New("save failed"), failOnSave: 1, ignoreContextErr: true} @@ -87,6 +92,9 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) { compactRunner: &stubCompactRunner{result: contextcompact.Result{ Applied: true, Messages: []providertypes.Message{{Role: providertypes.RoleAssistant, Content: "after"}}, + TaskState: agentsession.TaskState{ + Goal: "after-goal", + }, }}, } @@ -97,6 +105,15 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) { if strictSession.Messages[0].Content != "before" { t.Fatalf("expected strict mode to rollback messages, got %+v", strictSession.Messages) } + if strictSession.TaskState.Goal != "before-goal" { + t.Fatalf("expected strict mode to rollback task state, got %+v", strictSession.TaskState) + } + if strictSession.TokenInputTotal != 11 || strictSession.TokenOutputTotal != 22 { + t.Fatalf("expected strict mode to rollback token totals, got input=%d output=%d", strictSession.TokenInputTotal, strictSession.TokenOutputTotal) + } + if !strictSession.UpdatedAt.Equal(originalUpdatedAt) { + t.Fatalf("expected strict mode to rollback updated_at, got %s", strictSession.UpdatedAt) + } store.saveCalls = 0 bestEffortSession, bestEffortResult, err := service.runCompactForSession(context.Background(), "run-compact-save", session, config.Config{}, contextcompact.ModeManual, compactErrorBestEffort) @@ -109,6 +126,15 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) { if bestEffortSession.Messages[0].Content != "before" { t.Fatalf("expected best effort rollback messages, got %+v", bestEffortSession.Messages) } + if bestEffortSession.TaskState.Goal != "before-goal" { + t.Fatalf("expected best effort rollback task state, got %+v", bestEffortSession.TaskState) + } + if bestEffortSession.TokenInputTotal != 11 || bestEffortSession.TokenOutputTotal != 22 { + t.Fatalf("expected best effort rollback token totals, got input=%d output=%d", bestEffortSession.TokenInputTotal, bestEffortSession.TokenOutputTotal) + } + if !bestEffortSession.UpdatedAt.Equal(originalUpdatedAt) { + t.Fatalf("expected best effort rollback updated_at, got %s", bestEffortSession.UpdatedAt) + } } func TestCompactProviderSelectionErrorBranches(t *testing.T) {