Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions internal/runtime/compact.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,31 @@ func (s *Service) runCompactForSession(
mode contextcompact.Mode,
errorPolicy compactErrorPolicy,
) (agentsession.Session, contextcompact.Result, error) {
failCompact := func(err error) (agentsession.Session, contextcompact.Result, error) {
s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{
TriggerMode: string(mode),
Message: err.Error(),
})
if errorPolicy == compactErrorStrict {
return session, contextcompact.Result{}, err
}
return session, contextcompact.Result{}, nil
}

runner := s.compactRunner
if runner == nil {
var err error
runner, err = s.defaultCompactRunner(session, cfg)
if err != nil {
s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{
TriggerMode: string(mode),
Message: err.Error(),
})
if errorPolicy == compactErrorStrict {
return session, contextcompact.Result{}, err
}
return session, contextcompact.Result{}, nil
return failCompact(err)
}
}

originalMessages := append([]providertypes.Message(nil), session.Messages...)
originalTaskState := session.TaskState.Clone()
originalTokenInputTotal := session.TokenInputTotal
originalTokenOutputTotal := session.TokenOutputTotal
originalUpdatedAt := session.UpdatedAt
s.emit(ctx, EventCompactStart, runID, session.ID, string(mode))

result, err := runner.Run(ctx, contextcompact.Input{
Expand All @@ -124,14 +132,7 @@ func (s *Service) runCompactForSession(
Config: cfg.Context.Compact,
})
if err != nil {
s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{
TriggerMode: string(mode),
Message: err.Error(),
})
if errorPolicy == compactErrorStrict {
return session, contextcompact.Result{}, err
}
return session, contextcompact.Result{}, nil
return failCompact(err)
}

if result.Applied {
Expand All @@ -141,15 +142,12 @@ func (s *Service) runCompactForSession(
session.TokenOutputTotal = 0
session.UpdatedAt = time.Now()
if err := s.sessionStore.Save(ctx, &session); err != nil {
s.emit(ctx, EventCompactError, runID, session.ID, CompactErrorPayload{
TriggerMode: string(mode),
Message: err.Error(),
})
session.Messages = originalMessages
if errorPolicy == compactErrorStrict {
return session, contextcompact.Result{}, err
}
return session, contextcompact.Result{}, nil
session.TaskState = originalTaskState
session.TokenInputTotal = originalTokenInputTotal
session.TokenOutputTotal = originalTokenOutputTotal
session.UpdatedAt = originalUpdatedAt
return failCompact(err)
}
}

Expand Down
37 changes: 29 additions & 8 deletions internal/runtime/compact_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"io"
"strings"

agentcontext "neo-code/internal/context"
Expand Down Expand Up @@ -105,29 +106,49 @@ func parseCompactSummaryOutput(content string) (contextcompact.SummaryOutput, er
return contextcompact.SummaryOutput{}, err
}

var response compactSummaryResponse
if err := json.Unmarshal([]byte(jsonText), &response); err != nil {
response, err := decodeCompactSummaryResponse(jsonText)
if err != nil {
return contextcompact.SummaryOutput{}, err
}

output := contextcompact.SummaryOutput{
DisplaySummary: strings.TrimSpace(response.DisplaySummary),
}
output.TaskState.Goal = response.TaskState.Goal
output.TaskState.Progress = append([]string(nil), response.TaskState.Progress...)
output.TaskState.OpenItems = append([]string(nil), response.TaskState.OpenItems...)
output.TaskState.Progress = cloneStringSlice(response.TaskState.Progress)
output.TaskState.OpenItems = cloneStringSlice(response.TaskState.OpenItems)
output.TaskState.NextStep = response.TaskState.NextStep
output.TaskState.Blockers = append([]string(nil), response.TaskState.Blockers...)
output.TaskState.KeyArtifacts = append([]string(nil), response.TaskState.KeyArtifacts...)
output.TaskState.Decisions = append([]string(nil), response.TaskState.Decisions...)
output.TaskState.UserConstraints = append([]string(nil), response.TaskState.UserConstraints...)
output.TaskState.Blockers = cloneStringSlice(response.TaskState.Blockers)
output.TaskState.KeyArtifacts = cloneStringSlice(response.TaskState.KeyArtifacts)
output.TaskState.Decisions = cloneStringSlice(response.TaskState.Decisions)
output.TaskState.UserConstraints = cloneStringSlice(response.TaskState.UserConstraints)

if output.DisplaySummary == "" {
return contextcompact.SummaryOutput{}, errors.New("runtime: compact summary response is empty")
}
return output, nil
}

// decodeCompactSummaryResponse 对 compact JSON 响应执行严格解码,拒绝未知字段与尾随垃圾内容。
func decodeCompactSummaryResponse(jsonText string) (compactSummaryResponse, error) {
decoder := json.NewDecoder(strings.NewReader(jsonText))
decoder.DisallowUnknownFields()

var response compactSummaryResponse
if err := decoder.Decode(&response); err != nil {
return compactSummaryResponse{}, err
}
if err := decoder.Decode(&struct{}{}); err != nil && !errors.Is(err, io.EOF) {
return compactSummaryResponse{}, errors.New("runtime: compact summary response contains trailing JSON content")
}
return response, nil
}

// cloneStringSlice 复制字符串切片,避免结果复用解析对象的底层数组。
func cloneStringSlice(items []string) []string {
return append([]string(nil), items...)
}

// extractJSONObject 从模型响应中提取最外层 JSON 对象,容忍前后噪音。
func extractJSONObject(text string) (string, error) {
start := strings.Index(text, "{")
Expand Down
18 changes: 18 additions & 0 deletions internal/runtime/compact_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,21 @@ func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T
t.Fatal("expected compact generation to fail instead of deadlocking on malformed stream event")
}
}

func TestParseCompactSummaryOutputRejectsUnknownTopLevelField(t *testing.T) {
t.Parallel()

content := `{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"[compact_summary]\nok","unexpected":"value"}`
if _, err := parseCompactSummaryOutput(content); err == nil {
t.Fatal("expected unknown top-level field to be rejected")
}
}

func TestParseCompactSummaryOutputRejectsUnknownTaskStateField(t *testing.T) {
t.Parallel()

content := `{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[],"extra":"x"},"display_summary":"[compact_summary]\nok"}`
if _, err := parseCompactSummaryOutput(content); err == nil {
t.Fatal("expected unknown task_state field to be rejected")
}
}
26 changes: 26 additions & 0 deletions internal/runtime/runtime_gap_coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) {
baseStore := newMemoryStore()
session := newRuntimeSession("session-compact-save-error")
session.Messages = []providertypes.Message{{Role: providertypes.RoleUser, Content: "before"}}
session.TaskState.Goal = "before-goal"
session.TokenInputTotal = 11
session.TokenOutputTotal = 22
originalUpdatedAt := time.Unix(1700000000, 0)
session.UpdatedAt = originalUpdatedAt
baseStore.sessions[session.ID] = cloneSession(session)

store := &failingStore{Store: baseStore, saveErr: errors.New("save failed"), failOnSave: 1, ignoreContextErr: true}
Expand All @@ -87,6 +92,9 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) {
compactRunner: &stubCompactRunner{result: contextcompact.Result{
Applied: true,
Messages: []providertypes.Message{{Role: providertypes.RoleAssistant, Content: "after"}},
TaskState: agentsession.TaskState{
Goal: "after-goal",
},
}},
}

Expand All @@ -97,6 +105,15 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) {
if strictSession.Messages[0].Content != "before" {
t.Fatalf("expected strict mode to rollback messages, got %+v", strictSession.Messages)
}
if strictSession.TaskState.Goal != "before-goal" {
t.Fatalf("expected strict mode to rollback task state, got %+v", strictSession.TaskState)
}
if strictSession.TokenInputTotal != 11 || strictSession.TokenOutputTotal != 22 {
t.Fatalf("expected strict mode to rollback token totals, got input=%d output=%d", strictSession.TokenInputTotal, strictSession.TokenOutputTotal)
}
if !strictSession.UpdatedAt.Equal(originalUpdatedAt) {
t.Fatalf("expected strict mode to rollback updated_at, got %s", strictSession.UpdatedAt)
}

store.saveCalls = 0
bestEffortSession, bestEffortResult, err := service.runCompactForSession(context.Background(), "run-compact-save", session, config.Config{}, contextcompact.ModeManual, compactErrorBestEffort)
Expand All @@ -109,6 +126,15 @@ func TestRunCompactForSessionSaveErrorPolicyBranches(t *testing.T) {
if bestEffortSession.Messages[0].Content != "before" {
t.Fatalf("expected best effort rollback messages, got %+v", bestEffortSession.Messages)
}
if bestEffortSession.TaskState.Goal != "before-goal" {
t.Fatalf("expected best effort rollback task state, got %+v", bestEffortSession.TaskState)
}
if bestEffortSession.TokenInputTotal != 11 || bestEffortSession.TokenOutputTotal != 22 {
t.Fatalf("expected best effort rollback token totals, got input=%d output=%d", bestEffortSession.TokenInputTotal, bestEffortSession.TokenOutputTotal)
}
if !bestEffortSession.UpdatedAt.Equal(originalUpdatedAt) {
t.Fatalf("expected best effort rollback updated_at, got %s", bestEffortSession.UpdatedAt)
}
}

func TestCompactProviderSelectionErrorBranches(t *testing.T) {
Expand Down
Loading