Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 23 additions & 29 deletions internal/runtime/compact_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,35 +151,13 @@ func cloneStringSlice(items []string) []string {

// extractJSONObject 从模型响应中提取首个满足 compact 协议的 JSON 对象,容忍前后噪音。
func extractJSONObject(text string) (string, error) {
start := strings.IndexByte(text, '{')
if start < 0 {
return "", errors.New("runtime: compact summary response does not contain a JSON object")
}

for {
candidate, err := extractJSONObjectCandidate(text, start)
if err == nil {
if _, decodeErr := decodeCompactSummaryResponse(candidate); decodeErr == nil {
return candidate, nil
}
}

next := strings.IndexByte(text[start+1:], '{')
if next < 0 {
break
}
start += next + 1
}

return "", errors.New("runtime: compact summary response does not contain a valid compact JSON object")
}

// extractJSONObjectCandidate 从给定起点抽取平衡的 JSON 对象片段。
func extractJSONObjectCandidate(text string, start int) (string, error) {
depth := 0
inString := false
escaped := false
for index := start; index < len(text); index++ {
start := -1
seenObjectStart := false

for index := 0; index < len(text); index++ {
ch := text[index]
if inString {
if escaped {
Expand All @@ -199,14 +177,30 @@ func extractJSONObjectCandidate(text string, start int) (string, error) {
case '"':
inString = true
case '{':
seenObjectStart = true
if depth == 0 {
start = index
}
depth++
case '}':
depth--
if depth == 0 {
return strings.TrimSpace(text[start : index+1]), nil
continue
}
depth--
if depth == 0 && start >= 0 {
candidate := strings.TrimSpace(text[start : index+1])
if _, err := decodeCompactSummaryResponse(candidate); err == nil {
return candidate, nil
}
}
}
}

return "", errors.New("runtime: compact summary response contains an incomplete JSON object")
if !seenObjectStart {
return "", errors.New("runtime: compact summary response does not contain a JSON object")
}
if depth > 0 {
return "", errors.New("runtime: compact summary response contains an incomplete JSON object")
}
return "", errors.New("runtime: compact summary response does not contain a valid compact JSON object")
}
102 changes: 51 additions & 51 deletions internal/runtime/compact_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,12 @@ func validCompactSummaryJSON() string {
func TestCompactSummaryGeneratorBuildsProviderRequestWithoutTools(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

scripted := &scriptedProvider{
streams: [][]providertypes.StreamEvent{
{providertypes.NewTextDeltaStreamEvent(validCompactSummaryJSON())},
},
}
factory := &scriptedProviderFactory{provider: scripted}
generator := newCompactSummaryGenerator(factory, resolvedProvider.ToRuntimeConfig(), "session-model")
generator, factory, manager := newCompactGeneratorTestSetup(t, scripted)

summary, err := generator.Generate(context.Background(), contextcompact.SummaryInput{
Mode: contextcompact.ModeManual,
Expand Down Expand Up @@ -116,12 +109,6 @@ func TestCompactSummaryGeneratorBuildsProviderRequestWithoutTools(t *testing.T)
func TestCompactSummaryGeneratorRejectsToolCalls(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

scripted := &scriptedProvider{
streams: [][]providertypes.StreamEvent{
{
Expand All @@ -130,56 +117,40 @@ func TestCompactSummaryGeneratorRejectsToolCalls(t *testing.T) {
},
},
}
generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model")
generator, _, manager := newCompactGeneratorTestSetup(t, scripted)

_, err = generator.Generate(context.Background(), contextcompact.SummaryInput{
_, err := generator.Generate(context.Background(), contextcompact.SummaryInput{
Mode: contextcompact.ModeManual,
ArchivedMessages: []providertypes.Message{
{Role: providertypes.RoleUser, Content: "legacy request"},
},
Config: manager.Get().Context.Compact,
})
if err == nil || !strings.Contains(err.Error(), "must not contain tool calls") {
t.Fatalf("expected tool call rejection, got %v", err)
}
requireErrorContains(t, err, "must not contain tool calls")
}

func TestCompactSummaryGeneratorRejectsMalformedStreamEvent(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

scripted := &scriptedProvider{
streams: [][]providertypes.StreamEvent{
{
{Type: providertypes.StreamEventTextDelta},
},
},
}
generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model")
generator, _, manager := newCompactGeneratorTestSetup(t, scripted)

_, err = generator.Generate(context.Background(), contextcompact.SummaryInput{
_, err := generator.Generate(context.Background(), contextcompact.SummaryInput{
Mode: contextcompact.ModeManual,
Config: manager.Get().Context.Compact,
})
if err == nil || !strings.Contains(err.Error(), "text_delta event payload is nil") {
t.Fatalf("expected malformed stream event rejection, got %v", err)
}
requireErrorContains(t, err, "text_delta event payload is nil")
}

func TestCompactSummaryGeneratorRejectsCompletionWithoutMessageDone(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

scripted := &scriptedProvider{
chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error {
select {
Expand All @@ -190,37 +161,29 @@ func TestCompactSummaryGeneratorRejectsCompletionWithoutMessageDone(t *testing.T
return nil
},
}
generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model")
generator, _, manager := newCompactGeneratorTestSetup(t, scripted)

_, err = generator.Generate(context.Background(), contextcompact.SummaryInput{
_, err := generator.Generate(context.Background(), contextcompact.SummaryInput{
Mode: contextcompact.ModeManual,
Config: manager.Get().Context.Compact,
})
if !errors.Is(err, provider.ErrStreamInterrupted) {
t.Fatalf("expected ErrStreamInterrupted, got %v", err)
}
if !strings.Contains(err.Error(), "without message_done") {
t.Fatalf("expected missing message_done error, got %v", err)
}
requireErrorContains(t, err, "without message_done")
}

func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T) {
t.Parallel()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

stream := []providertypes.StreamEvent{{Type: providertypes.StreamEventTextDelta}}
for i := 0; i < 40; i++ {
stream = append(stream, providertypes.NewTextDeltaStreamEvent("ignored"))
}
scripted := &scriptedProvider{
streams: [][]providertypes.StreamEvent{stream},
}
generator := newCompactSummaryGenerator(&scriptedProviderFactory{provider: scripted}, resolvedProvider.ToRuntimeConfig(), "session-model")
generator, _, manager := newCompactGeneratorTestSetup(t, scripted)

errCh := make(chan error, 1)
go func() {
Expand All @@ -233,9 +196,7 @@ func TestCompactSummaryGeneratorMalformedStreamEventDoesNotDeadlock(t *testing.T

select {
case genErr := <-errCh:
if genErr == nil || !strings.Contains(genErr.Error(), "text_delta event payload is nil") {
t.Fatalf("expected malformed stream event rejection, got %v", genErr)
}
requireErrorContains(t, genErr, "text_delta event payload is nil")
case <-time.After(2 * time.Second):
t.Fatal("expected compact generation to fail instead of deadlocking on malformed stream event")
}
Expand Down Expand Up @@ -275,3 +236,42 @@ func TestParseCompactSummaryOutputSkipsNonCompactJSONPreface(t *testing.T) {
t.Fatalf("expected parsed goal, got %+v", output.TaskState)
}
}

func TestParseCompactSummaryOutputSkipsBraceNoiseAndFindsFirstValidObject(t *testing.T) {
t.Parallel()

noise := strings.Repeat("{not-json}", 80)
content := strings.Join([]string{
noise,
`{"task_state":{"goal":"g","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[]},"display_summary":"[compact_summary]\nok"}`,
}, "\n")

output, err := parseCompactSummaryOutput(content)
if err != nil {
t.Fatalf("expected parser to recover valid compact payload after brace noise, got %v", err)
}
if output.TaskState.Goal != "g" {
t.Fatalf("expected parsed goal, got %+v", output.TaskState)
}
}

func newCompactGeneratorTestSetup(t *testing.T, scripted *scriptedProvider) (contextcompact.SummaryGenerator, *scriptedProviderFactory, *config.Manager) {
t.Helper()

manager := newRuntimeConfigManager(t)
resolvedProvider, err := resolvedProviderForTests(manager.Get(), config.OpenAIName)
if err != nil {
t.Fatalf("resolve provider: %v", err)
}

factory := &scriptedProviderFactory{provider: scripted}
generator := newCompactSummaryGenerator(factory, resolvedProvider.ToRuntimeConfig(), "session-model")
return generator, factory, manager
}

func requireErrorContains(t *testing.T, err error, substring string) {
t.Helper()
if err == nil || !strings.Contains(err.Error(), substring) {
t.Fatalf("expected error containing %q, got %v", substring, err)
}
}
20 changes: 16 additions & 4 deletions internal/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ func (s *JSONStore) Save(ctx context.Context, session *Session) error {
s.mu.Lock()
defer s.mu.Unlock()

if err := os.MkdirAll(s.baseDir, 0o755); err != nil {
return fmt.Errorf("session: create sessions dir: %w", err)
if err := s.ensureSessionsDir(); err != nil {
return err
}

payload, err := json.MarshalIndent(session, "", " ")
Expand Down Expand Up @@ -141,8 +141,8 @@ func (s *JSONStore) ListSummaries(ctx context.Context) ([]Summary, error) {
s.mu.RLock()
defer s.mu.RUnlock()

if err := os.MkdirAll(s.baseDir, 0o755); err != nil {
return nil, fmt.Errorf("session: create sessions dir: %w", err)
if err := s.ensureSessionsDir(); err != nil {
return nil, err
}

entries, err := os.ReadDir(s.baseDir)
Expand Down Expand Up @@ -189,6 +189,14 @@ func (s *JSONStore) filePath(id string) string {
return filepath.Join(s.baseDir, id+".json")
}

// ensureSessionsDir 确保会话目录存在,避免重复的 mkdir 原子逻辑。
func (s *JSONStore) ensureSessionsDir() error {
if err := os.MkdirAll(s.baseDir, 0o755); err != nil {
return fmt.Errorf("session: create sessions dir: %w", err)
}
return nil
}

// New 创建一个默认标题策略的新会话对象。
func New(title string) Session {
return NewWithWorkdir(title, "")
Expand Down Expand Up @@ -308,6 +316,10 @@ func decodeStoredSummary(data []byte) (Summary, error) {
if len(stored.TaskState) == 0 {
return Summary{}, errors.New("missing required field task_state")
}
var taskState *TaskState
if err := json.Unmarshal(stored.TaskState, &taskState); err != nil || taskState == nil {
return Summary{}, errors.New("invalid field task_state")
}
if err := validateSessionSchema(Session{SchemaVersion: *stored.SchemaVersion}); err != nil {
return Summary{}, err
}
Expand Down
Loading
Loading