From 2573fd1e3fe9b9258d78a9acb0e2feca1c32522d Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Tue, 14 Apr 2026 23:38:21 +0800 Subject: [PATCH 01/33] =?UTF-8?q?feat():=20=E5=AE=9E=E7=8E=B0=20SubAgent?= =?UTF-8?q?=20WorkerRuntime=E3=80=81=E8=A7=92=E8=89=B2=E7=AD=96=E7=95=A5?= =?UTF-8?q?=E4=B8=8E=20Runtime=20=E7=94=9F=E5=91=BD=E5=91=A8=E6=9C=9F?= =?UTF-8?q?=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runtime/events.go | 26 +++ internal/runtime/runtime.go | 3 + internal/runtime/subagent_factory.go | 20 ++ internal/runtime/subagent_factory_test.go | 33 +++ internal/runtime/subagent_run.go | 140 ++++++++++++ internal/runtime/subagent_run_test.go | 154 +++++++++++++ internal/subagent/engine.go | 29 +++ internal/subagent/factory.go | 36 +++ internal/subagent/output_contract.go | 27 +++ internal/subagent/policy.go | 86 +++++++ internal/subagent/policy_test.go | 124 +++++++++++ internal/subagent/types.go | 224 +++++++++++++++++++ internal/subagent/util.go | 28 +++ internal/subagent/worker.go | 259 ++++++++++++++++++++++ internal/subagent/worker_test.go | 245 ++++++++++++++++++++ 15 files changed, 1434 insertions(+) create mode 100644 internal/runtime/subagent_factory.go create mode 100644 internal/runtime/subagent_factory_test.go create mode 100644 internal/runtime/subagent_run.go create mode 100644 internal/runtime/subagent_run_test.go create mode 100644 internal/subagent/engine.go create mode 100644 internal/subagent/factory.go create mode 100644 internal/subagent/output_contract.go create mode 100644 internal/subagent/policy.go create mode 100644 internal/subagent/policy_test.go create mode 100644 internal/subagent/types.go create mode 100644 internal/subagent/util.go create mode 100644 internal/subagent/worker.go create mode 100644 internal/subagent/worker_test.go diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 97eadc6a..f9506c40 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -1,5 +1,7 @@ package runtime +import "neo-code/internal/subagent" + // EventType identifies the kind of runtime event emitted during a run. type EventType string @@ -90,3 +92,27 @@ type TokenUsagePayload struct { SessionInputTokens int `json:"session_input_tokens"` SessionOutputTokens int `json:"session_output_tokens"` } + +// SubAgentEventPayload 描述子代理执行生命周期的事件载荷。 +type SubAgentEventPayload struct { + Role subagent.Role `json:"role"` + TaskID string `json:"task_id"` + State subagent.State `json:"state"` + StopReason subagent.StopReason `json:"stop_reason,omitempty"` + Step int `json:"step,omitempty"` + Delta string `json:"delta,omitempty"` + Error string `json:"error,omitempty"` +} + +const ( + // EventSubAgentStarted 在子代理任务启动后触发。 + EventSubAgentStarted EventType = "subagent_started" + // EventSubAgentProgress 在子代理执行每一步后触发。 + EventSubAgentProgress EventType = "subagent_progress" + // EventSubAgentCompleted 在子代理成功结束后触发。 + EventSubAgentCompleted EventType = "subagent_completed" + // EventSubAgentFailed 在子代理失败结束后触发。 + EventSubAgentFailed EventType = "subagent_failed" + // EventSubAgentCanceled 在子代理被取消后触发。 + EventSubAgentCanceled EventType = "subagent_canceled" +) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index beece19f..ec1aa87f 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -14,6 +14,7 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" + "neo-code/internal/subagent" "neo-code/internal/tools" ) @@ -65,6 +66,7 @@ type Service struct { compactRunner contextcompact.Runner approvalBroker *approval.Broker memoExtractor MemoExtractor + subAgentFactory subagent.Factory events chan RuntimeEvent sessionMu sync.Mutex @@ -106,6 +108,7 @@ func NewWithFactory( providerFactory: providerFactory, contextBuilder: contextBuilder, approvalBroker: approval.NewBroker(), + subAgentFactory: subagent.NewWorkerFactory(nil), events: make(chan RuntimeEvent, 128), sessionLocks: make(map[string]*sessionLockEntry), activeRunCancels: make(map[uint64]context.CancelFunc), diff --git a/internal/runtime/subagent_factory.go b/internal/runtime/subagent_factory.go new file mode 100644 index 00000000..140f4f70 --- /dev/null +++ b/internal/runtime/subagent_factory.go @@ -0,0 +1,20 @@ +package runtime + +import "neo-code/internal/subagent" + +// SetSubAgentFactory 设置子代理运行时工厂;传入 nil 时回退到默认工厂。 +func (s *Service) SetSubAgentFactory(factory subagent.Factory) { + if factory == nil { + s.subAgentFactory = subagent.NewWorkerFactory(nil) + return + } + s.subAgentFactory = factory +} + +// SubAgentFactory 返回当前 runtime 持有的子代理运行时工厂。 +func (s *Service) SubAgentFactory() subagent.Factory { + if s.subAgentFactory == nil { + s.subAgentFactory = subagent.NewWorkerFactory(nil) + } + return s.subAgentFactory +} diff --git a/internal/runtime/subagent_factory_test.go b/internal/runtime/subagent_factory_test.go new file mode 100644 index 00000000..eea246b1 --- /dev/null +++ b/internal/runtime/subagent_factory_test.go @@ -0,0 +1,33 @@ +package runtime + +import ( + "testing" + + "neo-code/internal/subagent" +) + +type fakeSubAgentFactory struct{} + +func (fakeSubAgentFactory) Create(role subagent.Role) (subagent.WorkerRuntime, error) { + return subagent.NewWorkerFactory(nil).Create(role) +} + +func TestServiceSubAgentFactoryRegistration(t *testing.T) { + t.Parallel() + + svc := NewWithFactory(nil, nil, nil, nil, nil) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected default sub-agent factory") + } + + custom := fakeSubAgentFactory{} + svc.SetSubAgentFactory(custom) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected custom sub-agent factory") + } + + svc.SetSubAgentFactory(nil) + if svc.SubAgentFactory() == nil { + t.Fatalf("expected reset to default sub-agent factory") + } +} diff --git a/internal/runtime/subagent_run.go b/internal/runtime/subagent_run.go new file mode 100644 index 00000000..45169c03 --- /dev/null +++ b/internal/runtime/subagent_run.go @@ -0,0 +1,140 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "strings" + + "neo-code/internal/subagent" +) + +// SubAgentTaskInput 描述一次子代理任务执行请求。 +type SubAgentTaskInput struct { + RunID string + SessionID string + Role subagent.Role + Task subagent.Task + Budget subagent.Budget + Capability subagent.Capability +} + +// RunSubAgentTask 使用当前 runtime 注册的工厂执行一条子代理任务。 +func (s *Service) RunSubAgentTask(ctx context.Context, input SubAgentTaskInput) (subagent.Result, error) { + if err := ctx.Err(); err != nil { + return subagent.Result{}, err + } + if strings.TrimSpace(input.RunID) == "" { + return subagent.Result{}, errors.New("runtime: subagent run id is empty") + } + if !input.Role.Valid() { + return subagent.Result{}, fmt.Errorf("runtime: invalid subagent role %q", input.Role) + } + if err := input.Task.Validate(); err != nil { + return subagent.Result{}, err + } + + factory := s.SubAgentFactory() + worker, err := factory.Create(input.Role) + if err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + + if err := worker.Start(input.Task, input.Budget, input.Capability); err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + + _ = s.emit(ctx, EventSubAgentStarted, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: worker.State(), + }) + + for { + stepResult, stepErr := worker.Step(ctx) + _ = s.emit(ctx, EventSubAgentProgress, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: stepResult.State, + Step: stepResult.Step, + Delta: stepResult.Delta, + Error: errorText(stepErr), + }) + + if stepErr != nil { + result, resultErr := worker.Result() + if resultErr != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: stepErr.Error(), + }) + return subagent.Result{}, stepErr + } + emitSubAgentTerminal(s, ctx, input, result) + return result, stepErr + } + + if !stepResult.Done { + continue + } + + result, err := worker.Result() + if err != nil { + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateFailed, + Error: err.Error(), + }) + return subagent.Result{}, err + } + emitSubAgentTerminal(s, ctx, input, result) + if result.State == subagent.StateSucceeded { + return result, nil + } + return result, errors.New(result.Error) + } +} + +// emitSubAgentTerminal 按子代理终态发射最终事件。 +func emitSubAgentTerminal(s *Service, ctx context.Context, input SubAgentTaskInput, result subagent.Result) { + payload := SubAgentEventPayload{ + Role: result.Role, + TaskID: result.TaskID, + State: result.State, + StopReason: result.StopReason, + Step: result.StepCount, + Error: strings.TrimSpace(result.Error), + } + + switch result.State { + case subagent.StateSucceeded: + _ = s.emit(ctx, EventSubAgentCompleted, input.RunID, input.SessionID, payload) + case subagent.StateCanceled: + _ = s.emit(ctx, EventSubAgentCanceled, input.RunID, input.SessionID, payload) + default: + _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, payload) + } +} + +// errorText 将 error 安全转换为事件可用文本。 +func errorText(err error) string { + if err == nil { + return "" + } + return strings.TrimSpace(err.Error()) +} diff --git a/internal/runtime/subagent_run_test.go b/internal/runtime/subagent_run_test.go new file mode 100644 index 00000000..fa66d19b --- /dev/null +++ b/internal/runtime/subagent_run_test.go @@ -0,0 +1,154 @@ +package runtime + +import ( + "context" + "errors" + "testing" + + "neo-code/internal/subagent" +) + +type failingSubAgentFactory struct { + err error +} + +func (f failingSubAgentFactory) Create(role subagent.Role) (subagent.WorkerRuntime, error) { + return nil, f.err +} + +func TestServiceRunSubAgentTaskSuccess(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + if input.StepIndex == 1 { + return subagent.StepOutput{ + Delta: "step-1", + Done: false, + }, nil + } + return subagent.StepOutput{ + Delta: "step-2", + Done: true, + Output: subagent.Output{ + Summary: "task completed", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + }) + })) + + result, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-success", + SessionID: "session-1", + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task-1", + Goal: "implement feature", + }, + Budget: subagent.Budget{ + MaxSteps: 3, + }, + }) + if err != nil { + t.Fatalf("RunSubAgentTask() error = %v", err) + } + if result.State != subagent.StateSucceeded { + t.Fatalf("result state = %q, want %q", result.State, subagent.StateSucceeded) + } + if result.StepCount != 2 { + t.Fatalf("result step count = %d, want 2", result.StepCount) + } + + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentProgress, + EventSubAgentCompleted, + }) + assertEventsRunID(t, events, "sub-run-success") +} + +func TestServiceRunSubAgentTaskFailureFlows(t *testing.T) { + t.Parallel() + + t.Run("factory create failed", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(failingSubAgentFactory{err: errors.New("create failed")}) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-factory-failed", + Role: subagent.RoleResearcher, + Task: subagent.Task{ + ID: "task-f", + Goal: "research", + }, + }) + if err == nil { + t.Fatalf("expected create error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{EventSubAgentFailed}) + }) + + t.Run("worker step failed", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + return subagent.StepOutput{}, errors.New("step failed") + }) + })) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-step-failed", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-step-f", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected step error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentFailed, + }) + }) +} + +func TestServiceRunSubAgentTaskInputValidation(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + if _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected empty run id error") + } + + if _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-invalid-role", + Role: subagent.Role("x"), + Task: subagent.Task{ + ID: "task", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected invalid role error") + } +} diff --git a/internal/subagent/engine.go b/internal/subagent/engine.go new file mode 100644 index 00000000..6d849558 --- /dev/null +++ b/internal/subagent/engine.go @@ -0,0 +1,29 @@ +package subagent + +import ( + "context" + "strings" +) + +// defaultEngine 提供一个可运行的默认单步完成引擎,便于工厂与测试快速装配。 +type defaultEngine struct{} + +// RunStep 执行默认单步逻辑并返回结构化结果。 +func (defaultEngine) RunStep(ctx context.Context, input StepInput) (StepOutput, error) { + if err := ctx.Err(); err != nil { + return StepOutput{}, err + } + + summary := strings.TrimSpace(input.Task.ExpectedOutput) + if summary == "" { + summary = strings.TrimSpace(input.Task.Goal) + } + + return StepOutput{ + Delta: "default engine completed", + Done: true, + Output: Output{ + Summary: summary, + }, + }, nil +} diff --git a/internal/subagent/factory.go b/internal/subagent/factory.go new file mode 100644 index 00000000..3b4b19ba --- /dev/null +++ b/internal/subagent/factory.go @@ -0,0 +1,36 @@ +package subagent + +import "time" + +// EngineBuilder 定义基于角色策略构建执行引擎的工厂函数。 +type EngineBuilder func(role Role, policy RolePolicy) Engine + +// WorkerFactory 是默认的 WorkerRuntime 工厂实现。 +type WorkerFactory struct { + builder EngineBuilder +} + +// NewWorkerFactory 创建 WorkerFactory;当 builder 为空时使用默认引擎。 +func NewWorkerFactory(builder EngineBuilder) *WorkerFactory { + return &WorkerFactory{builder: builder} +} + +// Create 基于角色创建对应策略与执行引擎的 WorkerRuntime。 +func (f *WorkerFactory) Create(role Role) (WorkerRuntime, error) { + policy, err := DefaultRolePolicy(role) + if err != nil { + return nil, err + } + // 兜底保证默认预算始终可用。 + policy.DefaultBudget = policy.DefaultBudget.normalize(Budget{ + MaxSteps: defaultPolicyMaxSteps, + Timeout: defaultPolicyTimeout * time.Second, + }) + + var engine Engine + if f != nil && f.builder != nil { + engine = f.builder(role, policy) + } + return NewWorker(role, policy, engine) +} + diff --git a/internal/subagent/output_contract.go b/internal/subagent/output_contract.go new file mode 100644 index 00000000..4fe063ee --- /dev/null +++ b/internal/subagent/output_contract.go @@ -0,0 +1,27 @@ +package subagent + +import "strings" + +var supportedOutputSections = map[string]struct{}{ + "summary": {}, + "findings": {}, + "patches": {}, + "risks": {}, + "next_actions": {}, + "artifacts": {}, +} + +// validateOutputContract 校验输出结构是否满足角色策略要求。 +func validateOutputContract(policy RolePolicy, output Output) error { + out := output.normalize() + for _, section := range dedupeAndTrim(policy.RequiredSections) { + key := strings.ToLower(strings.TrimSpace(section)) + if _, ok := supportedOutputSections[key]; !ok { + return errorsf("unsupported required output section %q", section) + } + if key == "summary" && strings.TrimSpace(out.Summary) == "" { + return errorsf("output summary is required") + } + } + return nil +} diff --git a/internal/subagent/policy.go b/internal/subagent/policy.go new file mode 100644 index 00000000..7a5b3d62 --- /dev/null +++ b/internal/subagent/policy.go @@ -0,0 +1,86 @@ +package subagent + +import ( + "strings" + "time" +) + +const ( + defaultPolicyMaxSteps = 6 + defaultPolicyTimeout = 30 +) + +// RolePolicy 定义不同角色的执行策略。 +type RolePolicy struct { + Role Role + SystemPrompt string + AllowedTools []string + DefaultBudget Budget + RequiredSections []string +} + +// Validate 校验角色策略是否合法。 +func (p RolePolicy) Validate() error { + if !p.Role.Valid() { + return errorsf("invalid policy role %q", p.Role) + } + if strings.TrimSpace(p.SystemPrompt) == "" { + return errorsf("role policy prompt is required") + } + if len(dedupeAndTrim(p.AllowedTools)) == 0 { + return errorsf("role policy allowed tools is empty") + } + if len(dedupeAndTrim(p.RequiredSections)) == 0 { + return errorsf("role policy required sections is empty") + } + if err := validateOutputContract(p, Output{Summary: "probe"}); err != nil { + return err + } + return nil +} + +// DefaultRolePolicy 返回内置角色策略。 +func DefaultRolePolicy(role Role) (RolePolicy, error) { + if !role.Valid() { + return RolePolicy{}, errorsf("unsupported role %q", role) + } + + policy := RolePolicy{ + Role: role, + DefaultBudget: Budget{ + MaxSteps: defaultPolicyMaxSteps, + Timeout: defaultPolicyTimeout * time.Second, + }, + RequiredSections: []string{ + "summary", + "findings", + "patches", + "risks", + "next_actions", + "artifacts", + }, + } + + switch role { + case RoleResearcher: + policy.SystemPrompt = "你是研究型子代理,负责检索证据并形成结论。" + policy.AllowedTools = []string{"filesystem_read_file", "filesystem_glob", "filesystem_grep", "webfetch"} + case RoleCoder: + policy.SystemPrompt = "你是实现型子代理,负责修改代码并给出验证结果。" + policy.AllowedTools = []string{ + "filesystem_read_file", + "filesystem_write_file", + "filesystem_edit", + "filesystem_glob", + "filesystem_grep", + "bash", + } + case RoleReviewer: + policy.SystemPrompt = "你是审查型子代理,负责识别缺陷、风险与测试缺口。" + policy.AllowedTools = []string{"filesystem_read_file", "filesystem_glob", "filesystem_grep"} + } + + policy.AllowedTools = dedupeAndTrim(policy.AllowedTools) + policy.RequiredSections = dedupeAndTrim(policy.RequiredSections) + return policy, nil +} diff --git a/internal/subagent/policy_test.go b/internal/subagent/policy_test.go new file mode 100644 index 00000000..08721d37 --- /dev/null +++ b/internal/subagent/policy_test.go @@ -0,0 +1,124 @@ +package subagent + +import "testing" + +func TestDefaultRolePolicy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role Role + }{ + {name: "researcher", role: RoleResearcher}, + {name: "coder", role: RoleCoder}, + {name: "reviewer", role: RoleReviewer}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(tt.role) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + if policy.Role != tt.role { + t.Fatalf("policy role = %q, want %q", policy.Role, tt.role) + } + if policy.DefaultBudget.MaxSteps <= 0 { + t.Fatalf("invalid max steps: %d", policy.DefaultBudget.MaxSteps) + } + if len(policy.AllowedTools) == 0 { + t.Fatalf("expected non-empty allowed tools") + } + if len(policy.RequiredSections) == 0 { + t.Fatalf("expected non-empty required sections") + } + if err := policy.Validate(); err != nil { + t.Fatalf("policy.Validate() error = %v", err) + } + }) + } +} + +func TestDefaultRolePolicyInvalidRole(t *testing.T) { + t.Parallel() + + if _, err := DefaultRolePolicy(Role("unknown")); err == nil { + t.Fatalf("expected invalid role error") + } +} + +func TestRolePolicyValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + policy RolePolicy + wantErr bool + }{ + { + name: "valid", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary"}, + }, + }, + { + name: "empty prompt", + policy: RolePolicy{ + Role: RoleResearcher, + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary"}, + }, + wantErr: true, + }, + { + name: "empty tools", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + RequiredSections: []string{"summary"}, + }, + wantErr: true, + }, + { + name: "invalid role", + policy: RolePolicy{ + Role: Role("x"), + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary"}, + }, + wantErr: true, + }, + { + name: "unsupported required section", + policy: RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"unknown_section"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.policy.Validate() + if tt.wantErr && err == nil { + t.Fatalf("expected error") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/internal/subagent/types.go b/internal/subagent/types.go new file mode 100644 index 00000000..9e3ba823 --- /dev/null +++ b/internal/subagent/types.go @@ -0,0 +1,224 @@ +package subagent + +import ( + "context" + "fmt" + "strings" + "time" +) + +// Role 表示子代理的执行角色。 +type Role string + +const ( + // RoleResearcher 用于检索与分析任务。 + RoleResearcher Role = "researcher" + // RoleCoder 用于实现与修复任务。 + RoleCoder Role = "coder" + // RoleReviewer 用于审查与验收任务。 + RoleReviewer Role = "reviewer" +) + +// Valid 判断角色是否受支持。 +func (r Role) Valid() bool { + switch r { + case RoleResearcher, RoleCoder, RoleReviewer: + return true + default: + return false + } +} + +// Budget 描述子代理执行预算。 +type Budget struct { + MaxSteps int + Timeout time.Duration +} + +// normalize 归一化预算并应用默认值。 +func (b Budget) normalize(defaults Budget) Budget { + out := b + if out.MaxSteps <= 0 { + out.MaxSteps = defaults.MaxSteps + } + if out.MaxSteps <= 0 { + out.MaxSteps = 6 + } + if out.Timeout <= 0 { + out.Timeout = defaults.Timeout + } + if out.Timeout <= 0 { + out.Timeout = 30 * time.Second + } + return out +} + +// Capability 描述子代理运行时可用能力边界。 +type Capability struct { + AllowedTools []string + AllowedPaths []string +} + +// normalize 归一化能力列表并去重。 +func (c Capability) normalize() Capability { + return Capability{ + AllowedTools: dedupeAndTrim(c.AllowedTools), + AllowedPaths: dedupeAndTrim(c.AllowedPaths), + } +} + +// Task 表示单个子代理任务输入。 +type Task struct { + ID string + Goal string + ExpectedOutput string + Workspace string +} + +// Validate 校验任务输入是否合法。 +func (t Task) Validate() error { + if strings.TrimSpace(t.ID) == "" { + return errorsf("task id is required") + } + if strings.TrimSpace(t.Goal) == "" { + return errorsf("task goal is required") + } + return nil +} + +// StopReason 表示子代理终止原因。 +type StopReason string + +const ( + // StopReasonCompleted 表示正常完成。 + StopReasonCompleted StopReason = "completed" + // StopReasonCanceled 表示被取消。 + StopReasonCanceled StopReason = "canceled" + // StopReasonTimeout 表示执行超时。 + StopReasonTimeout StopReason = "timeout" + // StopReasonMaxSteps 表示达到步数上限。 + StopReasonMaxSteps StopReason = "max_steps" + // StopReasonError 表示执行错误。 + StopReasonError StopReason = "error" +) + +// State 表示子代理生命周期状态。 +type State string + +const ( + // StateIdle 表示尚未启动。 + StateIdle State = "idle" + // StateRunning 表示执行中。 + StateRunning State = "running" + // StateSucceeded 表示执行成功结束。 + StateSucceeded State = "succeeded" + // StateFailed 表示执行失败结束。 + StateFailed State = "failed" + // StateCanceled 表示被取消结束。 + StateCanceled State = "canceled" +) + +// Terminal 判断当前状态是否为终态。 +func (s State) Terminal() bool { + switch s { + case StateSucceeded, StateFailed, StateCanceled: + return true + default: + return false + } +} + +// Output 定义子代理标准结构化输出。 +type Output struct { + Summary string + Findings []string + Patches []string + Risks []string + NextActions []string + Artifacts []string +} + +// normalize 归一化输出,避免重复与空项。 +func (o Output) normalize() Output { + o.Summary = strings.TrimSpace(o.Summary) + o.Findings = dedupeAndTrim(o.Findings) + o.Patches = dedupeAndTrim(o.Patches) + o.Risks = dedupeAndTrim(o.Risks) + o.NextActions = dedupeAndTrim(o.NextActions) + o.Artifacts = dedupeAndTrim(o.Artifacts) + return o +} + +// StepInput 表示单步执行输入。 +type StepInput struct { + Role Role + Policy RolePolicy + Task Task + Budget Budget + Capability Capability + StepIndex int + Trace []string +} + +// StepOutput 表示单步执行输出。 +type StepOutput struct { + Delta string + Done bool + Output Output +} + +// StepResult 表示一次 Step 的可观测结果。 +type StepResult struct { + State State + Done bool + Step int + Delta string +} + +// Result 描述子代理完成后的结构化结果。 +type Result struct { + Role Role + TaskID string + State State + StopReason StopReason + StartedAt time.Time + EndedAt time.Time + StepCount int + Budget Budget + Capability Capability + Output Output + Error string +} + +// Engine 定义 WorkerRuntime 的单步执行引擎。 +type Engine interface { + RunStep(ctx context.Context, input StepInput) (StepOutput, error) +} + +// EngineFunc 允许用函数实现 Engine。 +type EngineFunc func(ctx context.Context, input StepInput) (StepOutput, error) + +// RunStep 执行函数式引擎逻辑。 +func (f EngineFunc) RunStep(ctx context.Context, input StepInput) (StepOutput, error) { + return f(ctx, input) +} + +// WorkerRuntime 定义子代理执行生命周期接口。 +type WorkerRuntime interface { + Start(task Task, budget Budget, capability Capability) error + Step(ctx context.Context) (StepResult, error) + Stop(reason StopReason) error + Result() (Result, error) + State() State + Policy() RolePolicy +} + +// Factory 定义 runtime 侧创建 WorkerRuntime 的入口。 +type Factory interface { + Create(role Role) (WorkerRuntime, error) +} + +// errorsf 统一组装 subagent 模块错误前缀。 +func errorsf(format string, args ...any) error { + return fmt.Errorf("subagent: "+format, args...) +} diff --git a/internal/subagent/util.go b/internal/subagent/util.go new file mode 100644 index 00000000..a87acc25 --- /dev/null +++ b/internal/subagent/util.go @@ -0,0 +1,28 @@ +package subagent + +import "strings" + +// dedupeAndTrim 对字符串切片做去空白、去重并保持原顺序。 +func dedupeAndTrim(items []string) []string { + if len(items) == 0 { + return nil + } + result := make([]string, 0, len(items)) + seen := make(map[string]struct{}, len(items)) + for _, raw := range items { + item := strings.TrimSpace(raw) + if item == "" { + continue + } + key := strings.ToLower(item) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + result = append(result, item) + } + if len(result) == 0 { + return nil + } + return result +} diff --git a/internal/subagent/worker.go b/internal/subagent/worker.go new file mode 100644 index 00000000..995f4501 --- /dev/null +++ b/internal/subagent/worker.go @@ -0,0 +1,259 @@ +package subagent + +import ( + "context" + "errors" + "strings" + "sync" + "time" +) + +// worker 是 WorkerRuntime 的默认实现,负责封装单任务生命周期。 +type worker struct { + mu sync.RWMutex + role Role + policy RolePolicy + engine Engine + state State + task Task + budget Budget + capability Capability + stepCount int + trace []string + startedAt time.Time + endedAt time.Time + stopReason StopReason + output Output + lastErr string +} + +// NewWorker 根据角色、策略与引擎创建一个 WorkerRuntime 实例。 +func NewWorker(role Role, policy RolePolicy, engine Engine) (WorkerRuntime, error) { + if !role.Valid() { + return nil, errorsf("invalid role %q", role) + } + if policy.Role == "" { + policy.Role = role + } + if policy.Role != role { + return nil, errorsf("policy role %q does not match worker role %q", policy.Role, role) + } + if err := policy.Validate(); err != nil { + return nil, err + } + if engine == nil { + engine = defaultEngine{} + } + + return &worker{ + role: role, + policy: policy, + engine: engine, + state: StateIdle, + }, nil +} + +// Start 初始化任务执行上下文并进入运行态。 +func (w *worker) Start(task Task, budget Budget, capability Capability) error { + if w == nil { + return errors.New("subagent: worker is nil") + } + if err := task.Validate(); err != nil { + return err + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateIdle { + return errorsf("worker already started") + } + + w.task = task + w.budget = budget.normalize(w.policy.DefaultBudget) + w.capability = capability.normalize() + w.trace = nil + w.stepCount = 0 + w.output = Output{} + w.lastErr = "" + w.stopReason = "" + w.startedAt = time.Now() + w.endedAt = time.Time{} + w.state = StateRunning + return nil +} + +// Step 执行一次引擎步骤,并在满足终止条件时更新终态结果。 +func (w *worker) Step(ctx context.Context) (StepResult, error) { + if w == nil { + return StepResult{}, errors.New("subagent: worker is nil") + } + if err := ctx.Err(); err != nil { + return StepResult{}, err + } + + w.mu.Lock() + if w.state != StateRunning { + state := w.state + w.mu.Unlock() + return StepResult{}, errorsf("worker is not running, current state=%s", state) + } + if w.budget.Timeout > 0 && time.Since(w.startedAt) >= w.budget.Timeout { + result := w.finishLocked(StateFailed, StopReasonTimeout, Output{}, errorsf("worker timeout")) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, nil + } + if w.stepCount >= w.budget.MaxSteps { + result := w.finishLocked(StateFailed, StopReasonMaxSteps, Output{}, errorsf("worker reached max steps")) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, nil + } + + input := StepInput{ + Role: w.role, + Policy: w.policy, + Task: w.task, + Budget: w.budget, + Capability: w.capability, + StepIndex: w.stepCount + 1, + Trace: append([]string(nil), w.trace...), + } + w.mu.Unlock() + + stepOutput, err := w.engine.RunStep(ctx, input) + if err != nil { + w.mu.Lock() + result := w.finishLocked(StateFailed, StopReasonError, Output{}, err) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, err + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateRunning { + return StepResult{}, errorsf("worker is not running, current state=%s", w.state) + } + + w.stepCount++ + delta := strings.TrimSpace(stepOutput.Delta) + if delta != "" { + w.trace = append(w.trace, delta) + } + + if stepOutput.Done { + if err := validateOutputContract(w.policy, stepOutput.Output); err != nil { + result := w.finishLocked(StateFailed, StopReasonError, Output{}, err) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, err + } + result := w.finishLocked(StateSucceeded, StopReasonCompleted, stepOutput.Output, nil) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, nil + } + + if w.stepCount >= w.budget.MaxSteps { + result := w.finishLocked(StateFailed, StopReasonMaxSteps, Output{}, errorsf("worker reached max steps")) + return StepResult{State: result.State, Done: true, Step: result.StepCount, Delta: delta}, nil + } + + return StepResult{ + State: StateRunning, + Done: false, + Step: w.stepCount, + Delta: delta, + }, nil +} + +// Stop 主动终止运行中的 worker,并按终止原因映射最终状态。 +func (w *worker) Stop(reason StopReason) error { + if w == nil { + return errors.New("subagent: worker is nil") + } + if strings.TrimSpace(string(reason)) == "" { + return errorsf("stop reason is required") + } + + w.mu.Lock() + defer w.mu.Unlock() + if w.state != StateRunning { + if w.state.Terminal() { + return nil + } + return errorsf("worker is not running, current state=%s", w.state) + } + + switch reason { + case StopReasonCompleted: + if err := validateOutputContract(w.policy, w.output); err != nil { + return err + } + w.finishLocked(StateSucceeded, reason, w.output, nil) + case StopReasonCanceled: + w.finishLocked(StateCanceled, reason, w.output, nil) + case StopReasonTimeout, StopReasonMaxSteps, StopReasonError: + w.finishLocked(StateFailed, reason, w.output, nil) + default: + return errorsf("unsupported stop reason %q", reason) + } + return nil +} + +// Result 返回 worker 的最终结构化结果。 +func (w *worker) Result() (Result, error) { + if w == nil { + return Result{}, errors.New("subagent: worker is nil") + } + + w.mu.RLock() + defer w.mu.RUnlock() + if !w.state.Terminal() { + return Result{}, errorsf("worker is not finished") + } + return w.snapshotLocked(), nil +} + +// State 返回当前 worker 生命周期状态。 +func (w *worker) State() State { + if w == nil { + return StateIdle + } + w.mu.RLock() + defer w.mu.RUnlock() + return w.state +} + +// Policy 返回当前 worker 角色策略副本。 +func (w *worker) Policy() RolePolicy { + if w == nil { + return RolePolicy{} + } + w.mu.RLock() + defer w.mu.RUnlock() + return w.policy +} + +// finishLocked 在持有写锁时将 worker 切换为终态并返回结果快照。 +func (w *worker) finishLocked(state State, reason StopReason, output Output, err error) Result { + w.state = state + w.stopReason = reason + w.endedAt = time.Now() + w.output = output.normalize() + if err != nil { + w.lastErr = err.Error() + } + return w.snapshotLocked() +} + +// snapshotLocked 在持有读锁或写锁时构造稳定结果快照。 +func (w *worker) snapshotLocked() Result { + return Result{ + Role: w.role, + TaskID: w.task.ID, + State: w.state, + StopReason: w.stopReason, + StartedAt: w.startedAt, + EndedAt: w.endedAt, + StepCount: w.stepCount, + Budget: w.budget, + Capability: w.capability, + Output: w.output, + Error: w.lastErr, + } +} diff --git a/internal/subagent/worker_test.go b/internal/subagent/worker_test.go new file mode 100644 index 00000000..ef4e42ce --- /dev/null +++ b/internal/subagent/worker_test.go @@ -0,0 +1,245 @@ +package subagent + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestWorkerLifecycleCompleted(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleCoder) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleCoder, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{ + Delta: "patched files", + Done: true, + Output: Output{ + Summary: "done", + Patches: []string{"a.go"}, + NextActions: []string{"run tests"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + err = w.Start(Task{ID: "t1", Goal: "fix bug"}, Budget{MaxSteps: 3}, Capability{ + AllowedTools: []string{"bash", "bash", " "}, + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateSucceeded { + t.Fatalf("unexpected step result: %+v", step) + } + + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonCompleted { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonCompleted) + } + if result.StepCount != 1 { + t.Fatalf("step count = %d, want 1", result.StepCount) + } + if len(result.Capability.AllowedTools) != 1 { + t.Fatalf("expected capability dedupe, got %+v", result.Capability) + } +} + +func TestWorkerLifecycleFailures(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + t.Run("engine error", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{}, errors.New("boom") + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t2", Goal: "research"}, Budget{MaxSteps: 2}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected step error") + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.State != StateFailed || result.StopReason != StopReasonError { + t.Fatalf("unexpected result: %+v", result) + } + }) + + t.Run("max steps", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{Delta: "not done", Done: false}, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t3", Goal: "research"}, Budget{MaxSteps: 1}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateFailed { + t.Fatalf("expected first step to finish by max steps, got %+v", step) + } + + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonMaxSteps { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonMaxSteps) + } + }) + + t.Run("timeout", func(t *testing.T) { + t.Parallel() + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{Done: false}, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t4", Goal: "research"}, Budget{MaxSteps: 5, Timeout: time.Nanosecond}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + time.Sleep(2 * time.Millisecond) + + step, err := w.Step(context.Background()) + if err != nil { + t.Fatalf("Step() error = %v", err) + } + if !step.Done || step.State != StateFailed { + t.Fatalf("unexpected timeout step: %+v", step) + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.StopReason != StopReasonTimeout { + t.Fatalf("stop reason = %q, want %q", result.StopReason, StopReasonTimeout) + } + }) +} + +func TestWorkerStopAndGuards(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleReviewer) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + if _, err := w.Result(); err == nil { + t.Fatalf("expected result before finish to fail") + } + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected step before start to fail") + } + if err := w.Start(Task{ID: "review", Goal: "review"}, Budget{}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if err := w.Start(Task{ID: "review2", Goal: "review2"}, Budget{}, Capability{}); err == nil { + t.Fatalf("expected double start to fail") + } + if err := w.Stop(StopReasonCanceled); err != nil { + t.Fatalf("Stop() error = %v", err) + } + if w.State() != StateCanceled { + t.Fatalf("state = %q, want %q", w.State(), StateCanceled) + } + if err := w.Stop(StopReasonCanceled); err != nil { + t.Fatalf("terminal stop should be idempotent, got %v", err) + } +} + +func TestWorkerFactoryCreate(t *testing.T) { + t.Parallel() + + factory := NewWorkerFactory(func(role Role, policy RolePolicy) Engine { + return EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{Done: true, Output: Output{Summary: "ok"}}, nil + }) + }) + + w, err := factory.Create(RoleCoder) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if w.Policy().Role != RoleCoder { + t.Fatalf("policy role = %q, want %q", w.Policy().Role, RoleCoder) + } + + if _, err := factory.Create(Role("invalid")); err == nil { + t.Fatalf("expected invalid role create to fail") + } +} + +func TestWorkerRejectsInvalidOutputContract(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleCoder) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + w, err := NewWorker(RoleCoder, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + return StepOutput{ + Done: true, + Output: Output{ + Summary: " ", + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t-invalid-output", Goal: "goal"}, Budget{MaxSteps: 3}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + + if _, err := w.Step(context.Background()); err == nil { + t.Fatalf("expected invalid output contract error") + } + result, err := w.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if result.State != StateFailed || result.StopReason != StopReasonError { + t.Fatalf("unexpected result: %+v", result) + } +} From 9384e801bbbce6bd8024dcc4092b432e7c0389bd Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 16:01:32 +0000 Subject: [PATCH 02/33] fix(subagent): resolve review conflicts and strengthen coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/subagent_factory_test.go | 5 + internal/runtime/subagent_run.go | 49 ++++++-- internal/runtime/subagent_run_test.go | 79 +++++++++++++ internal/subagent/engine_test.go | 59 ++++++++++ internal/subagent/output_contract.go | 46 +++++++- internal/subagent/output_contract_test.go | 74 ++++++++++++ internal/subagent/policy.go | 2 +- internal/subagent/worker.go | 57 +++++++++- internal/subagent/worker_test.go | 133 +++++++++++++++++++++- 9 files changed, 486 insertions(+), 18 deletions(-) create mode 100644 internal/subagent/engine_test.go create mode 100644 internal/subagent/output_contract_test.go diff --git a/internal/runtime/subagent_factory_test.go b/internal/runtime/subagent_factory_test.go index eea246b1..70256ce4 100644 --- a/internal/runtime/subagent_factory_test.go +++ b/internal/runtime/subagent_factory_test.go @@ -30,4 +30,9 @@ func TestServiceSubAgentFactoryRegistration(t *testing.T) { if svc.SubAgentFactory() == nil { t.Fatalf("expected reset to default sub-agent factory") } + + svc.subAgentFactory = nil + if svc.SubAgentFactory() == nil { + t.Fatalf("expected lazy init default sub-agent factory") + } } diff --git a/internal/runtime/subagent_run.go b/internal/runtime/subagent_run.go index 45169c03..bf0adbd8 100644 --- a/internal/runtime/subagent_run.go +++ b/internal/runtime/subagent_run.go @@ -64,16 +64,28 @@ func (s *Service) RunSubAgentTask(ctx context.Context, input SubAgentTaskInput) for { stepResult, stepErr := worker.Step(ctx) - _ = s.emit(ctx, EventSubAgentProgress, input.RunID, input.SessionID, SubAgentEventPayload{ - Role: input.Role, - TaskID: input.Task.ID, - State: stepResult.State, - Step: stepResult.Step, - Delta: stepResult.Delta, - Error: errorText(stepErr), - }) + if stepResult.State == "" { + stepResult.State = worker.State() + } + emitSubAgentProgress(s, input, stepResult, stepErr) if stepErr != nil { + if errors.Is(stepErr, context.Canceled) || errors.Is(stepErr, context.DeadlineExceeded) { + _ = worker.Stop(subagent.StopReasonCanceled) + result, resultErr := worker.Result() + if resultErr != nil { + result = subagent.Result{ + Role: input.Role, + TaskID: input.Task.ID, + State: subagent.StateCanceled, + StopReason: subagent.StopReasonCanceled, + Error: errorText(stepErr), + } + } + emitSubAgentTerminal(s, ctx, input, result) + return result, stepErr + } + result, resultErr := worker.Result() if resultErr != nil { _ = s.emit(ctx, EventSubAgentFailed, input.RunID, input.SessionID, SubAgentEventPayload{ @@ -110,6 +122,27 @@ func (s *Service) RunSubAgentTask(ctx context.Context, input SubAgentTaskInput) } } +// emitSubAgentProgress 非阻塞发射进度事件,避免慢消费者反压执行路径。 +func emitSubAgentProgress(s *Service, input SubAgentTaskInput, stepResult subagent.StepResult, stepErr error) { + payload := SubAgentEventPayload{ + Role: input.Role, + TaskID: input.Task.ID, + State: stepResult.State, + Step: stepResult.Step, + Delta: stepResult.Delta, + Error: errorText(stepErr), + } + select { + case s.events <- RuntimeEvent{ + Type: EventSubAgentProgress, + RunID: input.RunID, + SessionID: input.SessionID, + Payload: payload, + }: + default: + } +} + // emitSubAgentTerminal 按子代理终态发射最终事件。 func emitSubAgentTerminal(s *Service, ctx context.Context, input SubAgentTaskInput, result subagent.Result) { payload := SubAgentEventPayload{ diff --git a/internal/runtime/subagent_run_test.go b/internal/runtime/subagent_run_test.go index fa66d19b..116abaf2 100644 --- a/internal/runtime/subagent_run_test.go +++ b/internal/runtime/subagent_run_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "neo-code/internal/subagent" ) @@ -125,6 +126,70 @@ func TestServiceRunSubAgentTaskFailureFlows(t *testing.T) { EventSubAgentFailed, }) }) + + t.Run("context canceled should emit canceled", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + service.SetSubAgentFactory(subagent.NewWorkerFactory(func(role subagent.Role, policy subagent.RolePolicy) subagent.Engine { + return subagent.EngineFunc(func(ctx context.Context, input subagent.StepInput) (subagent.StepOutput, error) { + cancel() + <-ctx.Done() + return subagent.StepOutput{}, ctx.Err() + }) + })) + + result, err := service.RunSubAgentTask(ctx, SubAgentTaskInput{ + RunID: "sub-run-canceled", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-cancel", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected canceled error") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("error = %v, want context canceled", err) + } + if result.State != subagent.StateCanceled { + t.Fatalf("result state = %q, want %q", result.State, subagent.StateCanceled) + } + if result.StopReason != subagent.StopReasonCanceled { + t.Fatalf("stop reason = %q, want %q", result.StopReason, subagent.StopReasonCanceled) + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{ + EventSubAgentStarted, + EventSubAgentProgress, + EventSubAgentCanceled, + }) + }) + + t.Run("worker start failed by disallowed capability", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-start-failed", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-start-failed", + Goal: "review", + }, + Capability: subagent.Capability{ + AllowedTools: []string{"bash"}, + }, + }) + if err == nil { + t.Fatalf("expected start error") + } + events := collectRuntimeEvents(service.Events()) + assertEventSequence(t, events, []EventType{EventSubAgentFailed}) + }) } func TestServiceRunSubAgentTaskInputValidation(t *testing.T) { @@ -151,4 +216,18 @@ func TestServiceRunSubAgentTaskInputValidation(t *testing.T) { }); err == nil { t.Fatalf("expected invalid role error") } + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(2 * time.Millisecond) + if _, err := service.RunSubAgentTask(ctx, SubAgentTaskInput{ + RunID: "sub-run-timeout", + Role: subagent.RoleCoder, + Task: subagent.Task{ + ID: "task-timeout", + Goal: "goal", + }, + }); err == nil { + t.Fatalf("expected context error") + } } diff --git a/internal/subagent/engine_test.go b/internal/subagent/engine_test.go new file mode 100644 index 00000000..feeb449c --- /dev/null +++ b/internal/subagent/engine_test.go @@ -0,0 +1,59 @@ +package subagent + +import ( + "context" + "testing" +) + +func TestDefaultEngineRunStep(t *testing.T) { + t.Parallel() + + engine := defaultEngine{} + + t.Run("uses expected output as summary", func(t *testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal", + ExpectedOutput: "expected", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + if !out.Done { + t.Fatalf("expected done output") + } + if out.Output.Summary != "expected" { + t.Fatalf("summary = %q, want %q", out.Output.Summary, "expected") + } + }) + + t.Run("falls back to goal", func(t *testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal-value", + ExpectedOutput: " ", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + if out.Output.Summary != "goal-value" { + t.Fatalf("summary = %q, want %q", out.Output.Summary, "goal-value") + } + }) + + t.Run("context canceled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := engine.RunStep(ctx, StepInput{Task: Task{Goal: "g"}}); err == nil { + t.Fatalf("expected context error") + } + }) +} diff --git a/internal/subagent/output_contract.go b/internal/subagent/output_contract.go index 4fe063ee..5b7514ea 100644 --- a/internal/subagent/output_contract.go +++ b/internal/subagent/output_contract.go @@ -14,14 +14,48 @@ var supportedOutputSections = map[string]struct{}{ // validateOutputContract 校验输出结构是否满足角色策略要求。 func validateOutputContract(policy RolePolicy, output Output) error { out := output.normalize() - for _, section := range dedupeAndTrim(policy.RequiredSections) { + requiredSections, err := normalizeRequiredSections(policy.RequiredSections) + if err != nil { + return err + } + for _, key := range requiredSections { + if !hasOutputSectionContent(out, key) { + return errorsf("output section %q is required", key) + } + } + return nil +} + +// normalizeRequiredSections 归一化并校验 required section 名称集合。 +func normalizeRequiredSections(sections []string) ([]string, error) { + items := dedupeAndTrim(sections) + keys := make([]string, 0, len(items)) + for _, section := range items { key := strings.ToLower(strings.TrimSpace(section)) if _, ok := supportedOutputSections[key]; !ok { - return errorsf("unsupported required output section %q", section) - } - if key == "summary" && strings.TrimSpace(out.Summary) == "" { - return errorsf("output summary is required") + return nil, errorsf("unsupported required output section %q", section) } + keys = append(keys, key) + } + return keys, nil +} + +// hasOutputSectionContent 判断指定 section 在输出中是否包含有效内容。 +func hasOutputSectionContent(output Output, section string) bool { + switch section { + case "summary": + return strings.TrimSpace(output.Summary) != "" + case "findings": + return len(output.Findings) > 0 + case "patches": + return len(output.Patches) > 0 + case "risks": + return len(output.Risks) > 0 + case "next_actions": + return len(output.NextActions) > 0 + case "artifacts": + return len(output.Artifacts) > 0 + default: + return false } - return nil } diff --git a/internal/subagent/output_contract_test.go b/internal/subagent/output_contract_test.go new file mode 100644 index 00000000..3971c2d0 --- /dev/null +++ b/internal/subagent/output_contract_test.go @@ -0,0 +1,74 @@ +package subagent + +import "testing" + +func TestValidateOutputContractRequiresAllDeclaredSections(t *testing.T) { + t.Parallel() + + policy := RolePolicy{ + Role: RoleCoder, + SystemPrompt: "prompt", + AllowedTools: []string{"bash"}, + RequiredSections: []string{"summary", "findings", "patches", "risks", "next_actions", "artifacts"}, + } + + full := Output{ + Summary: "ok", + Findings: []string{"f"}, + Patches: []string{"p"}, + Risks: []string{"r"}, + NextActions: []string{"n"}, + Artifacts: []string{"a"}, + } + if err := validateOutputContract(policy, full); err != nil { + t.Fatalf("validateOutputContract() error = %v", err) + } + + cases := []struct { + name string + output Output + }{ + {name: "missing summary", output: Output{ + Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing findings", output: Output{ + Summary: "ok", Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing patches", output: Output{ + Summary: "ok", Findings: []string{"f"}, Risks: []string{"r"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing risks", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, NextActions: []string{"n"}, Artifacts: []string{"a"}, + }}, + {name: "missing next actions", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, Artifacts: []string{"a"}, + }}, + {name: "missing artifacts", output: Output{ + Summary: "ok", Findings: []string{"f"}, Patches: []string{"p"}, Risks: []string{"r"}, NextActions: []string{"n"}, + }}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if err := validateOutputContract(policy, tc.output); err == nil { + t.Fatalf("expected contract validation error") + } + }) + } +} + +func TestValidateOutputContractUnsupportedSection(t *testing.T) { + t.Parallel() + + policy := RolePolicy{ + Role: RoleResearcher, + SystemPrompt: "prompt", + AllowedTools: []string{"filesystem_grep"}, + RequiredSections: []string{"summary", "x"}, + } + if err := validateOutputContract(policy, Output{Summary: "ok"}); err == nil { + t.Fatalf("expected unsupported section error") + } +} diff --git a/internal/subagent/policy.go b/internal/subagent/policy.go index 7a5b3d62..12c81544 100644 --- a/internal/subagent/policy.go +++ b/internal/subagent/policy.go @@ -33,7 +33,7 @@ func (p RolePolicy) Validate() error { if len(dedupeAndTrim(p.RequiredSections)) == 0 { return errorsf("role policy required sections is empty") } - if err := validateOutputContract(p, Output{Summary: "probe"}); err != nil { + if _, err := normalizeRequiredSections(p.RequiredSections); err != nil { return err } return nil diff --git a/internal/subagent/worker.go b/internal/subagent/worker.go index 995f4501..6e6c791a 100644 --- a/internal/subagent/worker.go +++ b/internal/subagent/worker.go @@ -27,6 +27,8 @@ type worker struct { lastErr string } +const traceWindowSize = 16 + // NewWorker 根据角色、策略与引擎创建一个 WorkerRuntime 实例。 func NewWorker(role Role, policy RolePolicy, engine Engine) (WorkerRuntime, error) { if !role.Valid() { @@ -70,7 +72,11 @@ func (w *worker) Start(task Task, budget Budget, capability Capability) error { w.task = task w.budget = budget.normalize(w.policy.DefaultBudget) - w.capability = capability.normalize() + effectiveCapability, err := bindCapabilityToPolicy(capability.normalize(), w.policy) + if err != nil { + return err + } + w.capability = effectiveCapability w.trace = nil w.stepCount = 0 w.output = Output{} @@ -115,13 +121,18 @@ func (w *worker) Step(ctx context.Context) (StepResult, error) { Budget: w.budget, Capability: w.capability, StepIndex: w.stepCount + 1, - Trace: append([]string(nil), w.trace...), + Trace: cloneRecentTrace(w.trace, traceWindowSize), } w.mu.Unlock() stepOutput, err := w.engine.RunStep(ctx, input) if err != nil { w.mu.Lock() + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + result := w.finishLocked(StateCanceled, StopReasonCanceled, Output{}, nil) + w.mu.Unlock() + return StepResult{State: result.State, Done: true, Step: result.StepCount}, err + } result := w.finishLocked(StateFailed, StopReasonError, Output{}, err) w.mu.Unlock() return StepResult{State: result.State, Done: true, Step: result.StepCount}, err @@ -161,6 +172,48 @@ func (w *worker) Step(ctx context.Context) (StepResult, error) { }, nil } +// bindCapabilityToPolicy 将 capability 约束在角色策略允许的工具集合内。 +func bindCapabilityToPolicy(capability Capability, policy RolePolicy) (Capability, error) { + allowedPolicyTools := dedupeAndTrim(policy.AllowedTools) + allowedSet := make(map[string]struct{}, len(allowedPolicyTools)) + for _, tool := range allowedPolicyTools { + allowedSet[strings.ToLower(strings.TrimSpace(tool))] = struct{}{} + } + + if len(capability.AllowedTools) == 0 { + capability.AllowedTools = append([]string(nil), allowedPolicyTools...) + return capability, nil + } + + effective := make([]string, 0, len(capability.AllowedTools)) + disallowed := make([]string, 0) + for _, tool := range capability.AllowedTools { + normalized := strings.ToLower(strings.TrimSpace(tool)) + if _, ok := allowedSet[normalized]; !ok { + disallowed = append(disallowed, tool) + continue + } + effective = append(effective, tool) + } + if len(disallowed) > 0 { + return Capability{}, errorsf("capability contains disallowed tools: %s", strings.Join(disallowed, ", ")) + } + capability.AllowedTools = effective + return capability, nil +} + +// cloneRecentTrace 复制最近 limit 条 trace,避免每步复制完整历史导致复杂度放大。 +func cloneRecentTrace(trace []string, limit int) []string { + if len(trace) == 0 { + return nil + } + if limit <= 0 || len(trace) <= limit { + return append([]string(nil), trace...) + } + start := len(trace) - limit + return append([]string(nil), trace[start:]...) +} + // Stop 主动终止运行中的 worker,并按终止原因映射最终状态。 func (w *worker) Stop(reason StopReason) error { if w == nil { diff --git a/internal/subagent/worker_test.go b/internal/subagent/worker_test.go index ef4e42ce..b323ad98 100644 --- a/internal/subagent/worker_test.go +++ b/internal/subagent/worker_test.go @@ -21,8 +21,11 @@ func TestWorkerLifecycleCompleted(t *testing.T) { Done: true, Output: Output{ Summary: "done", + Findings: []string{"root cause fixed"}, Patches: []string{"a.go"}, + Risks: []string{"need integration verify"}, NextActions: []string{"run tests"}, + Artifacts: []string{"test report"}, }, }, nil })) @@ -193,7 +196,17 @@ func TestWorkerFactoryCreate(t *testing.T) { factory := NewWorkerFactory(func(role Role, policy RolePolicy) Engine { return EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { - return StepOutput{Done: true, Output: Output{Summary: "ok"}}, nil + return StepOutput{ + Done: true, + Output: Output{ + Summary: "ok", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil }) }) @@ -243,3 +256,121 @@ func TestWorkerRejectsInvalidOutputContract(t *testing.T) { t.Fatalf("unexpected result: %+v", result) } } + +func TestWorkerStartCapabilityPolicyGuard(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleReviewer) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + + if err := w.Start(Task{ID: "t-cap", Goal: "goal"}, Budget{}, Capability{ + AllowedTools: []string{"filesystem_read_file", "bash"}, + }); err == nil { + t.Fatalf("expected disallowed capability tool to fail") + } + + w2, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w2.Start(Task{ID: "t-cap-ok", Goal: "goal"}, Budget{}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + if err := w2.Stop(StopReasonCanceled); err != nil { + t.Fatalf("Stop() error = %v", err) + } + result, err := w2.Result() + if err != nil { + t.Fatalf("Result() error = %v", err) + } + if len(result.Capability.AllowedTools) != len(policy.AllowedTools) { + t.Fatalf("capability tools = %v, want policy tools %v", result.Capability.AllowedTools, policy.AllowedTools) + } +} + +func TestWorkerTraceWindow(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + var observedTraceLen int + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + observedTraceLen = len(input.Trace) + return StepOutput{ + Done: true, + Output: Output{ + Summary: "done", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + impl, ok := w.(*worker) + if !ok { + t.Fatalf("expected *worker implementation") + } + impl.trace = make([]string, traceWindowSize+4) + for i := range impl.trace { + impl.trace[i] = "trace" + } + impl.state = StateRunning + impl.task = Task{ID: "trace", Goal: "goal"} + impl.budget = Budget{MaxSteps: 5, Timeout: time.Second} + impl.capability = Capability{} + impl.startedAt = time.Now() + + if _, err := w.Step(context.Background()); err != nil { + t.Fatalf("Step() error = %v", err) + } + if observedTraceLen != traceWindowSize { + t.Fatalf("trace len = %d, want %d", observedTraceLen, traceWindowSize) + } +} + +func TestWorkerNilAndValidationBranches(t *testing.T) { + t.Parallel() + + if _, err := NewWorker(Role("bad"), RolePolicy{}, nil); err == nil { + t.Fatalf("expected invalid role error") + } + if _, err := NewWorker(RoleCoder, RolePolicy{Role: RoleReviewer, SystemPrompt: "p", AllowedTools: []string{"bash"}, RequiredSections: []string{"summary"}}, nil); err == nil { + t.Fatalf("expected role mismatch error") + } + + var nilWorker *worker + if err := nilWorker.Start(Task{}, Budget{}, Capability{}); err == nil { + t.Fatalf("expected nil worker start error") + } + if _, err := nilWorker.Step(context.Background()); err == nil { + t.Fatalf("expected nil worker step error") + } + if err := nilWorker.Stop(StopReasonCanceled); err == nil { + t.Fatalf("expected nil worker stop error") + } + if _, err := nilWorker.Result(); err == nil { + t.Fatalf("expected nil worker result error") + } + if nilWorker.State() != StateIdle { + t.Fatalf("nil worker state = %q, want %q", nilWorker.State(), StateIdle) + } + nilPolicy := nilWorker.Policy() + if nilPolicy.Role != "" || nilPolicy.SystemPrompt != "" || len(nilPolicy.AllowedTools) != 0 || len(nilPolicy.RequiredSections) != 0 { + t.Fatalf("nil worker policy should be empty, got %+v", nilPolicy) + } +} From cc74bd41131e958db9d54c1cd0080c0a05e4e1b1 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 16:18:05 +0000 Subject: [PATCH 03/33] fix(runtime): resolve merge conflicts with main event/runtime changes Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/events.go | 63 +++++++++++++++++++++++++++++++------ internal/runtime/runtime.go | 47 +++++++++++++++++---------- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/internal/runtime/events.go b/internal/runtime/events.go index f9506c40..f8608473 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -1,6 +1,10 @@ package runtime -import "neo-code/internal/subagent" +import ( + "time" + + "neo-code/internal/subagent" +) // EventType identifies the kind of runtime event emitted during a run. type EventType string @@ -9,10 +13,41 @@ type EventType string // for a specific run. RunID is provided by the caller and is echoed back on all // events so upper layers can ignore stale events from older runs. type RuntimeEvent struct { - Type EventType - RunID string - SessionID string - Payload any + Type EventType + RunID string + SessionID string + Turn int + Phase string + Timestamp time.Time + PayloadVersion int + Payload any +} + +// PhaseChangedPayload 描述 phase 迁移。 +type PhaseChangedPayload struct { + From string `json:"from"` + To string `json:"to"` +} + +// BudgetCheckedPayload 为预算检查壳事件负载(1A 仅占位,1B阶段使用)。 +type BudgetCheckedPayload struct { + Note string `json:"note,omitempty"` +} + +// ProgressEvaluatedPayload 汇总 progress 控制面评估结果。 +type ProgressEvaluatedPayload struct { + Score string `json:"score"` +} + +// StopReasonDecidedPayload 承载唯一停止原因决议结果。 +type StopReasonDecidedPayload struct { + Reason string `json:"reason"` + Detail string `json:"detail,omitempty"` +} + +// LedgerReconciledPayload 为账本对账壳事件负载(1A 仅占位)。 +type LedgerReconciledPayload struct { + Note string `json:"note,omitempty"` } // PermissionRequestPayload 描述一次需要审批的权限请求上下文。 @@ -71,18 +106,28 @@ const ( // EventProviderRetry is emitted when runtime retries a provider call due to // a retryable error (e.g. 429, 5xx). Payload is a human-readable message. EventProviderRetry EventType = "provider_retry" - // EventPermissionRequest is emitted when a tool call hits an ask decision. - EventPermissionRequest EventType = "permission_request" + // EventPermissionRequested 是 1A 权限请求事件名。 + EventPermissionRequested EventType = "permission_requested" + // EventPermissionRequest 为兼容旧事件名保留,语义等同 EventPermissionRequested。 + EventPermissionRequest = EventPermissionRequested // EventPermissionResolved is emitted when runtime resolves a permission request or denial. EventPermissionResolved EventType = "permission_resolved" // EventCompactStart is emitted when a compact cycle starts. EventCompactStart EventType = "compact_start" - // EventCompactDone is emitted when a compact cycle completes. - EventCompactDone EventType = "compact_done" + // EventCompactApplied 表示一次 compact 已成功应用或校验完成(1A 主事件)。 + EventCompactApplied EventType = "compact_applied" + // EventCompactDone 为兼容旧事件名保留,语义等同 EventCompactApplied。 + EventCompactDone = EventCompactApplied // EventCompactError is emitted when compact fails. EventCompactError EventType = "compact_error" // EventTokenUsage is emitted after each provider response with token statistics. EventTokenUsage EventType = "token_usage" + // EventPhaseChanged 表示显式 phase 迁移。 + EventPhaseChanged EventType = "phase_changed" + // EventProgressEvaluated 表示 progress 评估结果。 + EventProgressEvaluated EventType = "progress_evaluated" + // EventStopReasonDecided 表示唯一停止原因已决议。 + EventStopReasonDecided EventType = "stop_reason_decided" ) // TokenUsagePayload carries token usage statistics for a single provider turn. diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index ec1aa87f..ad4e28a6 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -23,6 +23,10 @@ const ( providerRetryBaseWait = 1 * time.Second providerRetryMaxWait = 5 * time.Second defaultMaxLoops = 8 + defaultToolParallelism = 4 + noProgressStreakLimit = 3 + + terminationEventEmitTimeout = 500 * time.Millisecond ) // Runtime 定义 runtime 对外暴露的运行、压缩与审批接口。 @@ -68,13 +72,15 @@ type Service struct { memoExtractor MemoExtractor subAgentFactory subagent.Factory - events chan RuntimeEvent - sessionMu sync.Mutex - sessionLocks map[string]*sessionLockEntry - runMu sync.Mutex - activeRunToken uint64 - nextRunToken uint64 - activeRunCancels map[uint64]context.CancelFunc + events chan RuntimeEvent + sessionMu sync.Mutex + sessionLocks map[string]*sessionLockEntry + runMu sync.Mutex + activeRunToken uint64 + nextRunToken uint64 + activeRunCancels map[uint64]context.CancelFunc + permissionAskMapMu sync.Mutex + permissionAskLocks map[string]*permissionAskLockEntry } // sessionLockEntry 维护单个会话锁及其当前引用计数,用于在无引用时回收 map 项。 @@ -83,6 +89,12 @@ type sessionLockEntry struct { refs int } +// permissionAskLockEntry 维护单个运行的审批串行锁与引用计数。 +type permissionAskLockEntry struct { + mu sync.Mutex + refs int +} + // NewWithFactory 使用注入依赖构建默认 runtime Service。 func NewWithFactory( configManager *config.Manager, @@ -102,16 +114,17 @@ func NewWithFactory( } return &Service{ - configManager: configManager, - sessionStore: sessionStore, - toolManager: toolManager, - providerFactory: providerFactory, - contextBuilder: contextBuilder, - approvalBroker: approval.NewBroker(), - subAgentFactory: subagent.NewWorkerFactory(nil), - events: make(chan RuntimeEvent, 128), - sessionLocks: make(map[string]*sessionLockEntry), - activeRunCancels: make(map[uint64]context.CancelFunc), + configManager: configManager, + sessionStore: sessionStore, + toolManager: toolManager, + providerFactory: providerFactory, + contextBuilder: contextBuilder, + approvalBroker: approval.NewBroker(), + subAgentFactory: subagent.NewWorkerFactory(nil), + events: make(chan RuntimeEvent, 128), + sessionLocks: make(map[string]*sessionLockEntry), + activeRunCancels: make(map[uint64]context.CancelFunc), + permissionAskLocks: make(map[string]*permissionAskLockEntry), } } From 730a61b129c5ccb04d04b837699ce8e6b2d2fb66 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 16:21:23 +0000 Subject: [PATCH 04/33] fix(runtime): align runtime core with main to resolve merge conflicts Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/controlplane/decider.go | 38 ++++++++ internal/runtime/controlplane/decider_test.go | 88 +++++++++++++++++++ internal/runtime/controlplane/envelope.go | 4 + internal/runtime/controlplane/phase.go | 13 +++ internal/runtime/controlplane/progress.go | 39 ++++++++ .../runtime/controlplane/progress_test.go | 42 +++++++++ internal/runtime/controlplane/stop_reason.go | 15 ++++ internal/runtime/events.go | 36 +------- internal/runtime/events_subagent.go | 33 +++++++ internal/runtime/runtime.go | 7 +- internal/runtime/runtime_limits.go | 4 + 11 files changed, 283 insertions(+), 36 deletions(-) create mode 100644 internal/runtime/controlplane/decider.go create mode 100644 internal/runtime/controlplane/decider_test.go create mode 100644 internal/runtime/controlplane/envelope.go create mode 100644 internal/runtime/controlplane/phase.go create mode 100644 internal/runtime/controlplane/progress.go create mode 100644 internal/runtime/controlplane/progress_test.go create mode 100644 internal/runtime/controlplane/stop_reason.go create mode 100644 internal/runtime/events_subagent.go create mode 100644 internal/runtime/runtime_limits.go diff --git a/internal/runtime/controlplane/decider.go b/internal/runtime/controlplane/decider.go new file mode 100644 index 00000000..641bc6cb --- /dev/null +++ b/internal/runtime/controlplane/decider.go @@ -0,0 +1,38 @@ +package controlplane + +import ( + "context" + "errors" + "strings" +) + +// StopInput 汇总停止决议所需的信号(可多信号并存,由 DecideStopReason 按优先级表决)。 +type StopInput struct { + ContextCanceled bool + MaxLoopsReached bool + RunError error + Success bool +} + +// DecideStopReason 按固定优先级返回唯一 StopReason:取消 > 达到轮数上限 > 错误 > 成功。 +func DecideStopReason(in StopInput) (StopReason, string) { + if in.ContextCanceled { + return StopReasonCanceled, "" + } + if in.MaxLoopsReached { + if in.RunError != nil { + return StopReasonMaxLoops, strings.TrimSpace(in.RunError.Error()) + } + return StopReasonMaxLoops, "runtime: max loop reached" + } + if in.RunError != nil { + if errors.Is(in.RunError, context.Canceled) { + return StopReasonCanceled, "" + } + return StopReasonError, strings.TrimSpace(in.RunError.Error()) + } + if in.Success { + return StopReasonSuccess, "" + } + return StopReasonError, "runtime: stop reason undetermined" +} diff --git a/internal/runtime/controlplane/decider_test.go b/internal/runtime/controlplane/decider_test.go new file mode 100644 index 00000000..1f41b8e4 --- /dev/null +++ b/internal/runtime/controlplane/decider_test.go @@ -0,0 +1,88 @@ +package controlplane + +import ( + "context" + "errors" + "testing" +) + +func TestDecideStopReasonPriority(t *testing.T) { + t.Parallel() + + errSample := errors.New("boom") + cases := []struct { + name string + in StopInput + reason StopReason + }{ + { + name: "canceled_wins_over_max_loops", + in: StopInput{ + ContextCanceled: true, + MaxLoopsReached: true, + RunError: errSample, + }, + reason: StopReasonCanceled, + }, + { + name: "max_loops_wins_over_error", + in: StopInput{ + MaxLoopsReached: true, + RunError: errSample, + }, + reason: StopReasonMaxLoops, + }, + { + name: "error_when_no_max_loop_flag", + in: StopInput{ + RunError: errSample, + }, + reason: StopReasonError, + }, + { + name: "success", + in: StopInput{ + Success: true, + }, + reason: StopReasonSuccess, + }, + { + name: "context_canceled_on_error_field", + in: StopInput{ + RunError: context.Canceled, + }, + reason: StopReasonCanceled, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, _ := DecideStopReason(tc.in) + if got != tc.reason { + t.Fatalf("DecideStopReason() = %q, want %q", got, tc.reason) + } + }) + } +} + +func TestDecideStopReasonDetails(t *testing.T) { + t.Parallel() + + reason, detail := DecideStopReason(StopInput{MaxLoopsReached: true}) + if reason != StopReasonMaxLoops { + t.Fatalf("reason = %q, want %q", reason, StopReasonMaxLoops) + } + if detail != "runtime: max loop reached" { + t.Fatalf("detail = %q, want default max-loop detail", detail) + } + + reason, detail = DecideStopReason(StopInput{}) + if reason != StopReasonError { + t.Fatalf("reason = %q, want %q", reason, StopReasonError) + } + if detail != "runtime: stop reason undetermined" { + t.Fatalf("detail = %q, want undetermined detail", detail) + } +} diff --git a/internal/runtime/controlplane/envelope.go b/internal/runtime/controlplane/envelope.go new file mode 100644 index 00000000..ec2006fd --- /dev/null +++ b/internal/runtime/controlplane/envelope.go @@ -0,0 +1,4 @@ +package controlplane + +// PayloadVersion 为 runtime 事件 envelope 的当前协议版本号。 +const PayloadVersion = 1 diff --git a/internal/runtime/controlplane/phase.go b/internal/runtime/controlplane/phase.go new file mode 100644 index 00000000..e43b583d --- /dev/null +++ b/internal/runtime/controlplane/phase.go @@ -0,0 +1,13 @@ +package controlplane + +// Phase 表示单轮 ReAct 内的显式阶段(plan -> execute -> verify)。 +type Phase string + +const ( + // PhasePlan 规划阶段:构建上下文、调用 provider 直至得到 assistant 消息(含工具调用决策)。 + PhasePlan Phase = "plan" + // PhaseExecute 执行阶段:执行本批次全部工具调用。 + PhaseExecute Phase = "execute" + // PhaseVerify 验证阶段:工具结果已回灌,等待下一轮 provider 校验或收尾。 + PhaseVerify Phase = "verify" +) diff --git a/internal/runtime/controlplane/progress.go b/internal/runtime/controlplane/progress.go new file mode 100644 index 00000000..a1d26875 --- /dev/null +++ b/internal/runtime/controlplane/progress.go @@ -0,0 +1,39 @@ +package controlplane + +// ProgressEvidenceKind 标识工具/适配器产出的证据类型,runtime 仅聚合不做语义推断。 +type ProgressEvidenceKind string + +const ( + // EvidenceNewInfoNonDup 表示本轮引入了非重复的新信息(用于 streak 回归约束)。 + EvidenceNewInfoNonDup ProgressEvidenceKind = "EVIDENCE_NEW_INFO_NON_DUP" +) + +// ProgressEvidenceRecord 描述一条可计分的进展证据。 +type ProgressEvidenceRecord struct { + Kind ProgressEvidenceKind `json:"kind"` + Detail string `json:"detail,omitempty"` +} + +// ProgressScore 表示一次评估后的分值增量与 streak 快照。 +type ProgressScore struct { + ScoreDelta int `json:"score_delta"` + NoProgressStreak int `json:"no_progress_streak"` + RepeatCycleStreak int `json:"repeat_cycle_streak"` +} + +// ProgressState 汇总当前运行期 progress 控制面状态。 +type ProgressState struct { + LastScore ProgressScore `json:"last_score"` +} + +// ApplyProgressEvidence 根据证据更新分值与 streak。 +func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord) ProgressState { + next := state.LastScore + if len(records) == 0 { + next.NoProgressStreak++ + } else { + next.NoProgressStreak = 0 + next.ScoreDelta++ + } + return ProgressState{LastScore: next} +} diff --git a/internal/runtime/controlplane/progress_test.go b/internal/runtime/controlplane/progress_test.go new file mode 100644 index 00000000..cd25006e --- /dev/null +++ b/internal/runtime/controlplane/progress_test.go @@ -0,0 +1,42 @@ +package controlplane + +import "testing" + +func TestApplyProgressEvidenceNoEvidenceIncrementsNoProgress(t *testing.T) { + t.Parallel() + state := ProgressState{} + next := ApplyProgressEvidence(state, nil) + if next.LastScore.NoProgressStreak != 1 { + t.Fatalf("expected no_progress_streak 1, got %d", next.LastScore.NoProgressStreak) + } +} + +func TestApplyProgressEvidenceOnlyNonDupResetsNoProgressStreak(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 3}, + } + next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + }) + if next.LastScore.NoProgressStreak != 0 { + t.Fatalf("expected streak reset to 0, got %d", next.LastScore.NoProgressStreak) + } + if next.LastScore.ScoreDelta != 1 { + t.Fatalf("expected score_delta 1, got %d", next.LastScore.ScoreDelta) + } +} + +func TestApplyProgressEvidenceMixedResetsNoProgress(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 2}, + } + next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + {Kind: ProgressEvidenceKind("other_evidence")}, + }) + if next.LastScore.NoProgressStreak != 0 { + t.Fatalf("expected streak reset, got %d", next.LastScore.NoProgressStreak) + } +} diff --git a/internal/runtime/controlplane/stop_reason.go b/internal/runtime/controlplane/stop_reason.go new file mode 100644 index 00000000..d08b6cf6 --- /dev/null +++ b/internal/runtime/controlplane/stop_reason.go @@ -0,0 +1,15 @@ +package controlplane + +// StopReason 表示一次 Run 的最终停止原因,互斥且由决议器唯一确定。 +type StopReason string + +const ( + // StopReasonSuccess 表示助手正常结束(无待执行工具调用)。 + StopReasonSuccess StopReason = "success" + // StopReasonMaxLoops 表示达到配置的最大推理轮数。 + StopReasonMaxLoops StopReason = "max_loops" + // StopReasonError 表示不可恢复的运行时或 provider 错误。 + StopReasonError StopReason = "error" + // StopReasonCanceled 表示运行上下文被取消(含用户中断)。 + StopReasonCanceled StopReason = "canceled" +) diff --git a/internal/runtime/events.go b/internal/runtime/events.go index f8608473..a79eca47 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -3,7 +3,7 @@ package runtime import ( "time" - "neo-code/internal/subagent" + "neo-code/internal/runtime/controlplane" ) // EventType identifies the kind of runtime event emitted during a run. @@ -36,13 +36,13 @@ type BudgetCheckedPayload struct { // ProgressEvaluatedPayload 汇总 progress 控制面评估结果。 type ProgressEvaluatedPayload struct { - Score string `json:"score"` + Score controlplane.ProgressScore `json:"score"` } // StopReasonDecidedPayload 承载唯一停止原因决议结果。 type StopReasonDecidedPayload struct { - Reason string `json:"reason"` - Detail string `json:"detail,omitempty"` + Reason controlplane.StopReason `json:"reason"` + Detail string `json:"detail,omitempty"` } // LedgerReconciledPayload 为账本对账壳事件负载(1A 仅占位)。 @@ -108,16 +108,12 @@ const ( EventProviderRetry EventType = "provider_retry" // EventPermissionRequested 是 1A 权限请求事件名。 EventPermissionRequested EventType = "permission_requested" - // EventPermissionRequest 为兼容旧事件名保留,语义等同 EventPermissionRequested。 - EventPermissionRequest = EventPermissionRequested // EventPermissionResolved is emitted when runtime resolves a permission request or denial. EventPermissionResolved EventType = "permission_resolved" // EventCompactStart is emitted when a compact cycle starts. EventCompactStart EventType = "compact_start" // EventCompactApplied 表示一次 compact 已成功应用或校验完成(1A 主事件)。 EventCompactApplied EventType = "compact_applied" - // EventCompactDone 为兼容旧事件名保留,语义等同 EventCompactApplied。 - EventCompactDone = EventCompactApplied // EventCompactError is emitted when compact fails. EventCompactError EventType = "compact_error" // EventTokenUsage is emitted after each provider response with token statistics. @@ -137,27 +133,3 @@ type TokenUsagePayload struct { SessionInputTokens int `json:"session_input_tokens"` SessionOutputTokens int `json:"session_output_tokens"` } - -// SubAgentEventPayload 描述子代理执行生命周期的事件载荷。 -type SubAgentEventPayload struct { - Role subagent.Role `json:"role"` - TaskID string `json:"task_id"` - State subagent.State `json:"state"` - StopReason subagent.StopReason `json:"stop_reason,omitempty"` - Step int `json:"step,omitempty"` - Delta string `json:"delta,omitempty"` - Error string `json:"error,omitempty"` -} - -const ( - // EventSubAgentStarted 在子代理任务启动后触发。 - EventSubAgentStarted EventType = "subagent_started" - // EventSubAgentProgress 在子代理执行每一步后触发。 - EventSubAgentProgress EventType = "subagent_progress" - // EventSubAgentCompleted 在子代理成功结束后触发。 - EventSubAgentCompleted EventType = "subagent_completed" - // EventSubAgentFailed 在子代理失败结束后触发。 - EventSubAgentFailed EventType = "subagent_failed" - // EventSubAgentCanceled 在子代理被取消后触发。 - EventSubAgentCanceled EventType = "subagent_canceled" -) diff --git a/internal/runtime/events_subagent.go b/internal/runtime/events_subagent.go new file mode 100644 index 00000000..094564ab --- /dev/null +++ b/internal/runtime/events_subagent.go @@ -0,0 +1,33 @@ +package runtime + +import "neo-code/internal/subagent" + +// EventPermissionRequest 为兼容旧事件名保留,语义等同 EventPermissionRequested。 +const EventPermissionRequest EventType = EventPermissionRequested + +// EventCompactDone 为兼容旧事件名保留,语义等同 EventCompactApplied。 +const EventCompactDone EventType = EventCompactApplied + +// SubAgentEventPayload 描述子代理执行生命周期的事件载荷。 +type SubAgentEventPayload struct { + Role subagent.Role `json:"role"` + TaskID string `json:"task_id"` + State subagent.State `json:"state"` + StopReason subagent.StopReason `json:"stop_reason,omitempty"` + Step int `json:"step,omitempty"` + Delta string `json:"delta,omitempty"` + Error string `json:"error,omitempty"` +} + +const ( + // EventSubAgentStarted 在子代理任务启动后触发。 + EventSubAgentStarted EventType = "subagent_started" + // EventSubAgentProgress 在子代理执行每一步后触发。 + EventSubAgentProgress EventType = "subagent_progress" + // EventSubAgentCompleted 在子代理成功结束后触发。 + EventSubAgentCompleted EventType = "subagent_completed" + // EventSubAgentFailed 在子代理失败结束后触发。 + EventSubAgentFailed EventType = "subagent_failed" + // EventSubAgentCanceled 在子代理被取消后触发。 + EventSubAgentCanceled EventType = "subagent_canceled" +) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index ad4e28a6..f2bb84e3 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -22,7 +22,6 @@ const ( defaultProviderRetryMax = 2 providerRetryBaseWait = 1 * time.Second providerRetryMaxWait = 5 * time.Second - defaultMaxLoops = 8 defaultToolParallelism = 4 noProgressStreakLimit = 3 @@ -70,7 +69,6 @@ type Service struct { compactRunner contextcompact.Runner approvalBroker *approval.Broker memoExtractor MemoExtractor - subAgentFactory subagent.Factory events chan RuntimeEvent sessionMu sync.Mutex @@ -81,6 +79,7 @@ type Service struct { activeRunCancels map[uint64]context.CancelFunc permissionAskMapMu sync.Mutex permissionAskLocks map[string]*permissionAskLockEntry + subAgentFactory subagent.Factory } // sessionLockEntry 维护单个会话锁及其当前引用计数,用于在无引用时回收 map 项。 @@ -120,11 +119,11 @@ func NewWithFactory( providerFactory: providerFactory, contextBuilder: contextBuilder, approvalBroker: approval.NewBroker(), - subAgentFactory: subagent.NewWorkerFactory(nil), events: make(chan RuntimeEvent, 128), sessionLocks: make(map[string]*sessionLockEntry), - activeRunCancels: make(map[uint64]context.CancelFunc), permissionAskLocks: make(map[string]*permissionAskLockEntry), + activeRunCancels: make(map[uint64]context.CancelFunc), + subAgentFactory: subagent.NewWorkerFactory(nil), } } diff --git a/internal/runtime/runtime_limits.go b/internal/runtime/runtime_limits.go new file mode 100644 index 00000000..3cab9213 --- /dev/null +++ b/internal/runtime/runtime_limits.go @@ -0,0 +1,4 @@ +package runtime + +// defaultMaxLoops 定义兼容旧运行循环逻辑时的默认最大轮数。 +const defaultMaxLoops = 8 From a5bf778be7c058a2e9349bde6542caa7e8a12b53 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 16:22:48 +0000 Subject: [PATCH 05/33] refactor(runtime): decouple subagent factory storage from service struct Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/runtime.go | 3 --- internal/runtime/subagent_factory.go | 33 ++++++++++++++++++----- internal/runtime/subagent_factory_test.go | 5 ---- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index f2bb84e3..dd69c342 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -14,7 +14,6 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" - "neo-code/internal/subagent" "neo-code/internal/tools" ) @@ -79,7 +78,6 @@ type Service struct { activeRunCancels map[uint64]context.CancelFunc permissionAskMapMu sync.Mutex permissionAskLocks map[string]*permissionAskLockEntry - subAgentFactory subagent.Factory } // sessionLockEntry 维护单个会话锁及其当前引用计数,用于在无引用时回收 map 项。 @@ -123,7 +121,6 @@ func NewWithFactory( sessionLocks: make(map[string]*sessionLockEntry), permissionAskLocks: make(map[string]*permissionAskLockEntry), activeRunCancels: make(map[uint64]context.CancelFunc), - subAgentFactory: subagent.NewWorkerFactory(nil), } } diff --git a/internal/runtime/subagent_factory.go b/internal/runtime/subagent_factory.go index 140f4f70..bc55651c 100644 --- a/internal/runtime/subagent_factory.go +++ b/internal/runtime/subagent_factory.go @@ -1,20 +1,41 @@ package runtime -import "neo-code/internal/subagent" +import ( + "sync" + + "neo-code/internal/subagent" +) + +var serviceSubAgentFactory sync.Map + +// defaultSubAgentFactory 返回默认的子代理工厂实例。 +func defaultSubAgentFactory() subagent.Factory { + return subagent.NewWorkerFactory(nil) +} // SetSubAgentFactory 设置子代理运行时工厂;传入 nil 时回退到默认工厂。 func (s *Service) SetSubAgentFactory(factory subagent.Factory) { + if s == nil { + return + } if factory == nil { - s.subAgentFactory = subagent.NewWorkerFactory(nil) + serviceSubAgentFactory.Store(s, defaultSubAgentFactory()) return } - s.subAgentFactory = factory + serviceSubAgentFactory.Store(s, factory) } // SubAgentFactory 返回当前 runtime 持有的子代理运行时工厂。 func (s *Service) SubAgentFactory() subagent.Factory { - if s.subAgentFactory == nil { - s.subAgentFactory = subagent.NewWorkerFactory(nil) + if s == nil { + return defaultSubAgentFactory() + } + if factory, ok := serviceSubAgentFactory.Load(s); ok { + if typed, valid := factory.(subagent.Factory); valid && typed != nil { + return typed + } } - return s.subAgentFactory + defaultFactory := defaultSubAgentFactory() + serviceSubAgentFactory.Store(s, defaultFactory) + return defaultFactory } diff --git a/internal/runtime/subagent_factory_test.go b/internal/runtime/subagent_factory_test.go index 70256ce4..eea246b1 100644 --- a/internal/runtime/subagent_factory_test.go +++ b/internal/runtime/subagent_factory_test.go @@ -30,9 +30,4 @@ func TestServiceSubAgentFactoryRegistration(t *testing.T) { if svc.SubAgentFactory() == nil { t.Fatalf("expected reset to default sub-agent factory") } - - svc.subAgentFactory = nil - if svc.SubAgentFactory() == nil { - t.Fatalf("expected lazy init default sub-agent factory") - } } From d7c1aac6f2b9fdcf331ab56334dffdab2a788db7 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Wed, 15 Apr 2026 01:17:14 +0800 Subject: [PATCH 06/33] =?UTF-8?q?feat(runtime):=E5=A2=9E=E5=8A=A0=E5=8F=AF?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=AD=BB=E5=BE=AA=E7=8E=AF=E6=AC=A1=E6=95=B0?= =?UTF-8?q?=E4=BB=A5=E5=8F=8A=E8=87=AA=E6=88=91=E7=BA=A0=E6=AD=A3=E6=8F=90?= =?UTF-8?q?=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/config/config.go | 7 ++++ internal/config/config_test.go | 3 ++ internal/config/loader.go | 3 ++ internal/config/runtime.go | 39 +++++++++++++++++++++++ internal/runtime/run.go | 26 +++++++++++++-- internal/runtime/runtime.go | 1 - internal/runtime/runtime_progress_test.go | 11 +++++++ 7 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 internal/config/runtime.go diff --git a/internal/config/config.go b/internal/config/config.go index 7d0d5c07..87ce5d0f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ type Config struct { Workdir string `yaml:"-"` Shell string `yaml:"shell"` ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` Context ContextConfig `yaml:"context,omitempty"` Tools ToolsConfig `yaml:"tools,omitempty"` Memo MemoConfig `yaml:"memo,omitempty"` @@ -32,6 +33,7 @@ func StaticDefaults() *Config { Workdir: DefaultWorkdir, Shell: defaultShell(), ToolTimeoutSec: DefaultToolTimeoutSec, + Runtime: defaultRuntimeConfig(), Context: defaultContextConfig(), Tools: ToolsConfig{ WebFetch: defaultWebFetchConfig(), @@ -48,6 +50,7 @@ func (c *Config) Clone() Config { clone := *c clone.Providers = cloneProviders(c.Providers) + clone.Runtime = c.Runtime.Clone() clone.Context = c.Context.Clone() clone.Tools = c.Tools.Clone() clone.Memo = c.Memo.Clone() @@ -69,6 +72,7 @@ func (c *Config) applyStaticDefaults(defaults Config) { if c.ToolTimeoutSec <= 0 { c.ToolTimeoutSec = defaults.ToolTimeoutSec } + c.Runtime.ApplyDefaults(defaults.Runtime) c.Context.ApplyDefaults(defaults.Context) c.Tools.ApplyDefaults(defaults.Tools) c.Memo.ApplyDefaults(defaults.Memo) @@ -122,6 +126,9 @@ func (c *Config) ValidateSnapshot() error { if err := c.Tools.Validate(); err != nil { return fmt.Errorf("config: tools: %w", err) } + if err := c.Runtime.Validate(); err != nil { + return fmt.Errorf("config: runtime: %w", err) + } if err := c.Context.Validate(); err != nil { return fmt.Errorf("config: context: %w", err) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f7b5de4e..b62b2092 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1592,6 +1592,9 @@ func TestValidateSnapshotPropagatesCompactError(t *testing.T) { SupportedContentTypes: []string{"text/html"}, }, }, + Runtime: RuntimeConfig{ + MaxNoProgressStreak: 3, + }, Context: ContextConfig{ Compact: CompactConfig{ ManualStrategy: "invalid_strategy", diff --git a/internal/config/loader.go b/internal/config/loader.go index c7fba9b4..878fd9af 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -26,6 +26,7 @@ type persistedConfig struct { CurrentModel string `yaml:"current_model,omitempty"` Shell string `yaml:"shell"` ToolTimeoutSec int `yaml:"tool_timeout_sec,omitempty"` + Runtime RuntimeConfig `yaml:"runtime,omitempty"` Context persistedContextConfig `yaml:"context,omitempty"` Tools ToolsConfig `yaml:"tools,omitempty"` Memo persistedMemoConfig `yaml:"memo,omitempty"` @@ -204,6 +205,7 @@ func parseCurrentConfig(data []byte, contextDefaults ContextConfig, memoDefaults CurrentModel: strings.TrimSpace(file.CurrentModel), Shell: strings.TrimSpace(file.Shell), ToolTimeoutSec: file.ToolTimeoutSec, + Runtime: file.Runtime, Context: fromPersistedContextConfig(file.Context, contextDefaults), Tools: file.Tools, Memo: fromPersistedMemoConfig(file.Memo, memoDefaults), @@ -218,6 +220,7 @@ func marshalPersistedConfig(snapshot Config) ([]byte, error) { CurrentModel: snapshot.CurrentModel, Shell: snapshot.Shell, ToolTimeoutSec: snapshot.ToolTimeoutSec, + Runtime: snapshot.Runtime, Context: newPersistedContextConfig(snapshot.Context), Tools: snapshot.Tools, Memo: newPersistedMemoConfig(snapshot.Memo), diff --git a/internal/config/runtime.go b/internal/config/runtime.go new file mode 100644 index 00000000..a8f4da5b --- /dev/null +++ b/internal/config/runtime.go @@ -0,0 +1,39 @@ +package config + +import ( + "errors" +) + +const ( + DefaultMaxNoProgressStreak = 3 +) + +type RuntimeConfig struct { + MaxNoProgressStreak int `yaml:"max_no_progress_streak,omitempty"` +} + +func defaultRuntimeConfig() RuntimeConfig { + return RuntimeConfig{ + MaxNoProgressStreak: DefaultMaxNoProgressStreak, + } +} + +func (c RuntimeConfig) Clone() RuntimeConfig { + return c +} + +func (c *RuntimeConfig) ApplyDefaults(defaults RuntimeConfig) { + if c == nil { + return + } + if c.MaxNoProgressStreak <= 0 { + c.MaxNoProgressStreak = defaults.MaxNoProgressStreak + } +} + +func (c RuntimeConfig) Validate() error { + if c.MaxNoProgressStreak <= 0 { + return errors.New("max_no_progress_streak must be greater than 0") + } + return nil +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index ba5c13f8..7ee2fb7d 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -157,7 +157,12 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) - if streak >= noProgressStreakLimit { + limit := snapshot.config.Runtime.MaxNoProgressStreak + if limit <= 0 { + limit = config.DefaultMaxNoProgressStreak + } + + if streak >= limit { err = ErrNoProgressStreakLimit return err } @@ -216,6 +221,23 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur return turnSnapshot{}, false, err } + state.mu.Lock() + streak := state.progress.LastScore.NoProgressStreak + state.mu.Unlock() + + limit := cfg.Runtime.MaxNoProgressStreak + if limit <= 0 { + limit = config.DefaultMaxNoProgressStreak + } + + messages := builtContext.Messages + if streak == limit-1 { + messages = append(messages, providertypes.Message{ + Role: providertypes.RoleUser, + Content: "System Reminder: You have made multiple consecutive attempts without making substantial progress. Please stop your current repetitive or ineffective strategy. Carefully review the previous errors, change your approach, or ask the user directly for help.", + }) + } + model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ config: cfg, @@ -226,7 +248,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur request: providertypes.GenerateRequest{ Model: model, SystemPrompt: builtContext.SystemPrompt, - Messages: builtContext.Messages, + Messages: messages, Tools: toolSpecs, }, }, false, nil diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 1a345f5a..86733bb2 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -20,7 +20,6 @@ const ( providerRetryBaseWait = 1 * time.Second providerRetryMaxWait = 5 * time.Second defaultToolParallelism = 4 - noProgressStreakLimit = 3 terminationEventEmitTimeout = 500 * time.Millisecond ) diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index f83f73f2..141944d5 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -3,6 +3,7 @@ package runtime import ( "context" "errors" + "strings" "sync/atomic" "testing" @@ -31,9 +32,15 @@ func TestProgressStreakStopsRun(t *testing.T) { }, } + var promptInjected bool providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + for _, msg := range req.Messages { + if strings.Contains(msg.Content, "System Reminder: You have made multiple consecutive attempts") { + promptInjected = true + } + } // the model always decides to call the tool events <- providertypes.NewToolCallStartStreamEvent(0, "call_err", "tool_error") events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_err", "{}") @@ -83,6 +90,10 @@ func TestProgressStreakStopsRun(t *testing.T) { } } } + + if !promptInjected { + t.Error("expected self-healing prompt to be injected before streak limit is reached, but it wasn't") + } } func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { From 4098da8d1cc669da6537e8e0ace8bcee0dceeab2 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 17:43:39 +0000 Subject: [PATCH 07/33] fix(runtime,config): resolve PR review items and simplify streak handling Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- docs/guides/configuration.md | 8 ++++ internal/config/config_test.go | 46 +++++++++++++++++++++++ internal/config/runtime.go | 5 +++ internal/config/runtime_test.go | 46 +++++++++++++++++++++++ internal/runtime/run.go | 37 ++++++++++-------- internal/runtime/runtime_progress_test.go | 6 +-- 6 files changed, 128 insertions(+), 20 deletions(-) create mode 100644 internal/config/runtime_test.go diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index b2a407c0..c6a9a4a3 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -38,6 +38,8 @@ selected_provider: openai current_model: gpt-5.4 shell: bash tool_timeout_sec: 20 +runtime: + max_no_progress_streak: 3 tools: webfetch: @@ -80,6 +82,12 @@ context: | `context.auto_compact.enabled` | 是否启用自动压缩 | | `context.auto_compact.input_token_threshold` | 自动压缩输入 token 阈值 | +### `runtime` 字段 + +| 字段 | 说明 | +|------|------| +| `runtime.max_no_progress_streak` | 连续“无进展”轮次熔断阈值,默认 `3`;当达到 `limit-1` 时会向模型注入一次系统级纠偏提示 | + ### `tools` 字段 | 字段 | 说明 | diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b62b2092..ffd0067c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1685,6 +1685,52 @@ func TestMarshalPersistedConfigEndsWithNewline(t *testing.T) { } } +func TestParseCurrentConfigRoundTripRuntimeConfig(t *testing.T) { + t.Parallel() + + snapshot := testDefaultConfig().Clone() + snapshot.Runtime.MaxNoProgressStreak = 5 + + data, err := marshalPersistedConfig(snapshot) + if err != nil { + t.Fatalf("marshalPersistedConfig() error = %v", err) + } + + parsed, err := parseCurrentConfig(data, StaticDefaults().Context, StaticDefaults().Memo) + if err != nil { + t.Fatalf("parseCurrentConfig() error = %v", err) + } + if parsed.Runtime.MaxNoProgressStreak != 5 { + t.Fatalf("expected max_no_progress_streak=5, got %d", parsed.Runtime.MaxNoProgressStreak) + } +} + +func TestParseCurrentConfigInvalidRuntimeValueDefaultsBeforeValidation(t *testing.T) { + t.Parallel() + + raw := []byte(` +selected_provider: openai +current_model: gpt-4.1 +shell: bash +runtime: + max_no_progress_streak: -2 +`) + + parsed, err := parseCurrentConfig(raw, StaticDefaults().Context, StaticDefaults().Memo) + if err != nil { + t.Fatalf("parseCurrentConfig() error = %v", err) + } + parsed.Providers = cloneProviders(testDefaultConfig().Providers) + parsed.applyStaticDefaults(*StaticDefaults()) + if err := parsed.ValidateSnapshot(); err != nil { + t.Fatalf("ValidateSnapshot() error = %v", err) + } + if parsed.Runtime.MaxNoProgressStreak != DefaultMaxNoProgressStreak { + t.Fatalf("expected default max_no_progress_streak=%d, got %d", + DefaultMaxNoProgressStreak, parsed.Runtime.MaxNoProgressStreak) + } +} + func TestAssembleProvidersAcceptsEmptyNameProvider(t *testing.T) { t.Parallel() diff --git a/internal/config/runtime.go b/internal/config/runtime.go index a8f4da5b..86bd20ce 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -8,20 +8,24 @@ const ( DefaultMaxNoProgressStreak = 3 ) +// RuntimeConfig 定义 runtime 层的可调参数。 type RuntimeConfig struct { MaxNoProgressStreak int `yaml:"max_no_progress_streak,omitempty"` } +// defaultRuntimeConfig 返回 runtime 配置的静态默认值。 func defaultRuntimeConfig() RuntimeConfig { return RuntimeConfig{ MaxNoProgressStreak: DefaultMaxNoProgressStreak, } } +// Clone 复制 runtime 配置,避免调用方共享可变状态。 func (c RuntimeConfig) Clone() RuntimeConfig { return c } +// ApplyDefaults 在配置缺失或非法时回填默认阈值。 func (c *RuntimeConfig) ApplyDefaults(defaults RuntimeConfig) { if c == nil { return @@ -31,6 +35,7 @@ func (c *RuntimeConfig) ApplyDefaults(defaults RuntimeConfig) { } } +// Validate 校验 runtime 配置是否满足最小约束。 func (c RuntimeConfig) Validate() error { if c.MaxNoProgressStreak <= 0 { return errors.New("max_no_progress_streak must be greater than 0") diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go new file mode 100644 index 00000000..f7126648 --- /dev/null +++ b/internal/config/runtime_test.go @@ -0,0 +1,46 @@ +package config + +import "testing" + +func TestRuntimeConfigClone(t *testing.T) { + t.Parallel() + + cfg := RuntimeConfig{MaxNoProgressStreak: 7} + cloned := cfg.Clone() + if cloned.MaxNoProgressStreak != 7 { + t.Fatalf("expected cloned MaxNoProgressStreak=7, got %d", cloned.MaxNoProgressStreak) + } +} + +func TestRuntimeConfigApplyDefaults(t *testing.T) { + t.Parallel() + + defaults := RuntimeConfig{MaxNoProgressStreak: 3} + + cfg := RuntimeConfig{MaxNoProgressStreak: 0} + cfg.ApplyDefaults(defaults) + if cfg.MaxNoProgressStreak != 3 { + t.Fatalf("expected defaulted MaxNoProgressStreak=3, got %d", cfg.MaxNoProgressStreak) + } + + cfg = RuntimeConfig{MaxNoProgressStreak: 5} + cfg.ApplyDefaults(defaults) + if cfg.MaxNoProgressStreak != 5 { + t.Fatalf("expected existing MaxNoProgressStreak=5 to be preserved, got %d", cfg.MaxNoProgressStreak) + } + + var nilCfg *RuntimeConfig + nilCfg.ApplyDefaults(defaults) +} + +func TestRuntimeConfigValidate(t *testing.T) { + t.Parallel() + + if err := (RuntimeConfig{MaxNoProgressStreak: 1}).Validate(); err != nil { + t.Fatalf("expected valid config, got %v", err) + } + + if err := (RuntimeConfig{MaxNoProgressStreak: 0}).Validate(); err == nil { + t.Fatal("expected validation error for zero MaxNoProgressStreak") + } +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 7ee2fb7d..9ee10856 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -18,6 +18,8 @@ import ( "neo-code/internal/tools" ) +const selfHealingReminder = "System Reminder: You have made multiple consecutive attempts without making substantial progress. Please stop your current repetitive or ineffective strategy. Carefully review the previous errors, change your approach, or ask the user directly for help." + // Run 执行一次完整的 ReAct 闭环:保存用户输入、驱动模型、执行工具并发出事件。 // 已有会话会先加锁再加载/更新,确保同一会话并发 Run 不会出现状态覆盖; // 新会话在创建后再绑定会话锁,不同会话可并行执行。 @@ -157,11 +159,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) - limit := snapshot.config.Runtime.MaxNoProgressStreak - if limit <= 0 { - limit = config.DefaultMaxNoProgressStreak - } - + limit := resolveNoProgressStreakLimit(snapshot.config) if streak >= limit { err = ErrNoProgressStreakLimit return err @@ -225,17 +223,15 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur streak := state.progress.LastScore.NoProgressStreak state.mu.Unlock() - limit := cfg.Runtime.MaxNoProgressStreak - if limit <= 0 { - limit = config.DefaultMaxNoProgressStreak - } + limit := resolveNoProgressStreakLimit(cfg) + systemPrompt := builtContext.SystemPrompt - messages := builtContext.Messages if streak == limit-1 { - messages = append(messages, providertypes.Message{ - Role: providertypes.RoleUser, - Content: "System Reminder: You have made multiple consecutive attempts without making substantial progress. Please stop your current repetitive or ineffective strategy. Carefully review the previous errors, change your approach, or ask the user directly for help.", - }) + if strings.TrimSpace(systemPrompt) == "" { + systemPrompt = selfHealingReminder + } else { + systemPrompt = systemPrompt + "\n\n" + selfHealingReminder + } } model := strings.TrimSpace(cfg.CurrentModel) @@ -247,13 +243,22 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, request: providertypes.GenerateRequest{ Model: model, - SystemPrompt: builtContext.SystemPrompt, - Messages: messages, + SystemPrompt: systemPrompt, + Messages: builtContext.Messages, Tools: toolSpecs, }, }, false, nil } +// resolveNoProgressStreakLimit 统一解析熔断阈值,避免运行期出现无效值导致分支行为不一致。 +func resolveNoProgressStreakLimit(cfg config.Config) int { + limit := cfg.Runtime.MaxNoProgressStreak + if limit <= 0 { + return config.DefaultMaxNoProgressStreak + } + return limit +} + // callProviderWithRetry 使用冻结后的 turnSnapshot 执行 provider 调用与必要重试。 func (s *Service) callProviderWithRetry( ctx context.Context, diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index 141944d5..f9e90a1a 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -36,10 +36,8 @@ func TestProgressStreakStopsRun(t *testing.T) { providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - for _, msg := range req.Messages { - if strings.Contains(msg.Content, "System Reminder: You have made multiple consecutive attempts") { - promptInjected = true - } + if strings.Contains(req.SystemPrompt, selfHealingReminder) { + promptInjected = true } // the model always decides to call the tool events <- providertypes.NewToolCallStartStreamEvent(0, "call_err", "tool_error") From e43f70c63369cddeadb439643c3250c526ff906a Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 01:33:56 +0000 Subject: [PATCH 08/33] fix(subagent): resolve review issues for factory lifecycle and default output contract Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/runtime.go | 4 +++ internal/runtime/subagent_factory.go | 34 +++++++++++++---------- internal/runtime/subagent_factory_test.go | 21 ++++++++++++++ internal/subagent/engine.go | 7 ++++- internal/subagent/engine_test.go | 30 ++++++++++++++++++++ 5 files changed, 81 insertions(+), 15 deletions(-) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index dd69c342..928dd6c7 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -14,6 +14,7 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" + "neo-code/internal/subagent" "neo-code/internal/tools" ) @@ -68,6 +69,8 @@ type Service struct { compactRunner contextcompact.Runner approvalBroker *approval.Broker memoExtractor MemoExtractor + subAgentFactory subagent.Factory + subAgentMu sync.RWMutex events chan RuntimeEvent sessionMu sync.Mutex @@ -117,6 +120,7 @@ func NewWithFactory( providerFactory: providerFactory, contextBuilder: contextBuilder, approvalBroker: approval.NewBroker(), + subAgentFactory: defaultSubAgentFactory(), events: make(chan RuntimeEvent, 128), sessionLocks: make(map[string]*sessionLockEntry), permissionAskLocks: make(map[string]*permissionAskLockEntry), diff --git a/internal/runtime/subagent_factory.go b/internal/runtime/subagent_factory.go index bc55651c..e1889275 100644 --- a/internal/runtime/subagent_factory.go +++ b/internal/runtime/subagent_factory.go @@ -1,12 +1,6 @@ package runtime -import ( - "sync" - - "neo-code/internal/subagent" -) - -var serviceSubAgentFactory sync.Map +import "neo-code/internal/subagent" // defaultSubAgentFactory 返回默认的子代理工厂实例。 func defaultSubAgentFactory() subagent.Factory { @@ -18,11 +12,13 @@ func (s *Service) SetSubAgentFactory(factory subagent.Factory) { if s == nil { return } + s.subAgentMu.Lock() + defer s.subAgentMu.Unlock() if factory == nil { - serviceSubAgentFactory.Store(s, defaultSubAgentFactory()) + s.subAgentFactory = defaultSubAgentFactory() return } - serviceSubAgentFactory.Store(s, factory) + s.subAgentFactory = factory } // SubAgentFactory 返回当前 runtime 持有的子代理运行时工厂。 @@ -30,12 +26,22 @@ func (s *Service) SubAgentFactory() subagent.Factory { if s == nil { return defaultSubAgentFactory() } - if factory, ok := serviceSubAgentFactory.Load(s); ok { - if typed, valid := factory.(subagent.Factory); valid && typed != nil { - return typed - } + s.subAgentMu.RLock() + factory := s.subAgentFactory + s.subAgentMu.RUnlock() + if factory != nil { + return factory } + defaultFactory := defaultSubAgentFactory() - serviceSubAgentFactory.Store(s, defaultFactory) + s.subAgentMu.Lock() + if s.subAgentFactory == nil { + s.subAgentFactory = defaultFactory + } + factory = s.subAgentFactory + s.subAgentMu.Unlock() + if factory != nil { + return factory + } return defaultFactory } diff --git a/internal/runtime/subagent_factory_test.go b/internal/runtime/subagent_factory_test.go index eea246b1..9869ba45 100644 --- a/internal/runtime/subagent_factory_test.go +++ b/internal/runtime/subagent_factory_test.go @@ -31,3 +31,24 @@ func TestServiceSubAgentFactoryRegistration(t *testing.T) { t.Fatalf("expected reset to default sub-agent factory") } } + +func TestServiceSubAgentFactoryIsolationAcrossInstances(t *testing.T) { + t.Parallel() + + svcA := NewWithFactory(nil, nil, nil, nil, nil) + svcB := NewWithFactory(nil, nil, nil, nil, nil) + + custom := fakeSubAgentFactory{} + svcA.SetSubAgentFactory(custom) + + if svcA.SubAgentFactory() == nil { + t.Fatalf("expected service A factory to be set") + } + if svcB.SubAgentFactory() == nil { + t.Fatalf("expected service B default factory") + } + + if svcA.SubAgentFactory() == svcB.SubAgentFactory() { + t.Fatalf("expected per-service factory isolation") + } +} diff --git a/internal/subagent/engine.go b/internal/subagent/engine.go index 6d849558..b8e859d8 100644 --- a/internal/subagent/engine.go +++ b/internal/subagent/engine.go @@ -23,7 +23,12 @@ func (defaultEngine) RunStep(ctx context.Context, input StepInput) (StepOutput, Delta: "default engine completed", Done: true, Output: Output{ - Summary: summary, + Summary: summary, + Findings: []string{"default-engine: no extra findings"}, + Patches: []string{"default-engine: no code changes"}, + Risks: []string{"default-engine: output generated without external verification"}, + NextActions: []string{"replace default engine with role-specific implementation if deeper execution is required"}, + Artifacts: []string{"default-engine"}, }, }, nil } diff --git a/internal/subagent/engine_test.go b/internal/subagent/engine_test.go index feeb449c..9a11048f 100644 --- a/internal/subagent/engine_test.go +++ b/internal/subagent/engine_test.go @@ -28,6 +28,12 @@ func TestDefaultEngineRunStep(t *testing.T) { if out.Output.Summary != "expected" { t.Fatalf("summary = %q, want %q", out.Output.Summary, "expected") } + if len(out.Output.Findings) == 0 || len(out.Output.Patches) == 0 || len(out.Output.Risks) == 0 { + t.Fatalf("default engine should populate required list sections, got %+v", out.Output) + } + if len(out.Output.NextActions) == 0 || len(out.Output.Artifacts) == 0 { + t.Fatalf("default engine should populate required sections, got %+v", out.Output) + } }) t.Run("falls back to goal", func(t *testing.T) { @@ -56,4 +62,28 @@ func TestDefaultEngineRunStep(t *testing.T) { t.Fatalf("expected context error") } }) + + t.Run("satisfies default role contract", func(t *testing.T) { + t.Parallel() + + out, err := engine.RunStep(context.Background(), StepInput{ + Task: Task{ + Goal: "goal", + ExpectedOutput: "summary", + }, + }) + if err != nil { + t.Fatalf("RunStep() error = %v", err) + } + + for _, role := range []Role{RoleResearcher, RoleCoder, RoleReviewer} { + policy, err := DefaultRolePolicy(role) + if err != nil { + t.Fatalf("DefaultRolePolicy(%q) error = %v", role, err) + } + if err := validateOutputContract(policy, out.Output); err != nil { + t.Fatalf("validateOutputContract(%q) error = %v", role, err) + } + } + }) } From 319b30365e7720dcde320cfa90d5cc6a2fd4e497 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 02:19:43 +0000 Subject: [PATCH 09/33] refactor(runtime,config): address review feedback on streak handling - Narrow resolveNoProgressStreakLimit param from config.Config to config.RuntimeConfig - Store resolved noProgressStreakLimit in turnSnapshot to ensure reminder injection and termination checks use the same value, eliminating concurrent-reload inconsistency - TrimSpace systemPrompt once before appending self-healing reminder - Merge double lock sections into one in the progress evidence path - Expand TestRuntimeConfigValidate to cover 0, -1, -99 rejection cases - Update docs: concrete limit-1/limit examples (default 2nd/3rd turn) Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- docs/guides/configuration.md | 2 +- internal/config/runtime_test.go | 6 +++-- internal/runtime/run.go | 40 +++++++++++++-------------------- internal/runtime/state.go | 15 ++++++++----- 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index c6a9a4a3..0d8724ca 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -86,7 +86,7 @@ context: | 字段 | 说明 | |------|------| -| `runtime.max_no_progress_streak` | 连续“无进展”轮次熔断阈值,默认 `3`;当达到 `limit-1` 时会向模型注入一次系统级纠偏提示 | +| `runtime.max_no_progress_streak` | 连续”无进展”轮次熔断阈值,默认 `3`;streak 达到 `limit-1`(默认第 2 轮)时向模型注入一次系统级纠偏提示,达到 `limit`(默认第 3 轮)时终止运行 | ### `tools` 字段 diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go index f7126648..0c6c705d 100644 --- a/internal/config/runtime_test.go +++ b/internal/config/runtime_test.go @@ -40,7 +40,9 @@ func TestRuntimeConfigValidate(t *testing.T) { t.Fatalf("expected valid config, got %v", err) } - if err := (RuntimeConfig{MaxNoProgressStreak: 0}).Validate(); err == nil { - t.Fatal("expected validation error for zero MaxNoProgressStreak") + for _, bad := range []int{0, -1, -99} { + if err := (RuntimeConfig{MaxNoProgressStreak: bad}).Validate(); err == nil { + t.Fatalf("expected validation error for MaxNoProgressStreak=%d", bad) + } } } diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 9ee10856..34c9a35a 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -133,25 +133,16 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.transitionRunPhase(ctx, &state, controlplane.PhaseVerify) var evidence []controlplane.ProgressEvidenceRecord - hasProgress := false toolCallCount := len(turnResult.assistant.ToolCalls) state.mu.Lock() if len(state.session.Messages) >= toolCallCount { for i := len(state.session.Messages) - toolCallCount; i < len(state.session.Messages); i++ { - msg := state.session.Messages[i] - if msg.Role == providertypes.RoleTool && !msg.IsError { - hasProgress = true + if msg := state.session.Messages[i]; msg.Role == providertypes.RoleTool && !msg.IsError { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) break } } } - state.mu.Unlock() - - if hasProgress { - evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) - } - - state.mu.Lock() state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence) streak := state.progress.LastScore.NoProgressStreak currentScore := state.progress.LastScore @@ -159,7 +150,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) - limit := resolveNoProgressStreakLimit(snapshot.config) + limit := snapshot.noProgressStreakLimit if streak >= limit { err = ErrNoProgressStreakLimit return err @@ -223,24 +214,26 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur streak := state.progress.LastScore.NoProgressStreak state.mu.Unlock() - limit := resolveNoProgressStreakLimit(cfg) + limit := resolveNoProgressStreakLimit(cfg.Runtime) systemPrompt := builtContext.SystemPrompt if streak == limit-1 { - if strings.TrimSpace(systemPrompt) == "" { + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { systemPrompt = selfHealingReminder } else { - systemPrompt = systemPrompt + "\n\n" + selfHealingReminder + systemPrompt = trimmed + "\n\n" + selfHealingReminder } } model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ - config: cfg, - providerConfig: resolvedProvider.ToRuntimeConfig(), - model: model, - workdir: activeWorkdir, - toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + config: cfg, + providerConfig: resolvedProvider.ToRuntimeConfig(), + model: model, + workdir: activeWorkdir, + toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + noProgressStreakLimit: limit, request: providertypes.GenerateRequest{ Model: model, SystemPrompt: systemPrompt, @@ -251,12 +244,11 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur } // resolveNoProgressStreakLimit 统一解析熔断阈值,避免运行期出现无效值导致分支行为不一致。 -func resolveNoProgressStreakLimit(cfg config.Config) int { - limit := cfg.Runtime.MaxNoProgressStreak - if limit <= 0 { +func resolveNoProgressStreakLimit(rc config.RuntimeConfig) int { + if rc.MaxNoProgressStreak <= 0 { return config.DefaultMaxNoProgressStreak } - return limit + return rc.MaxNoProgressStreak } // callProviderWithRetry 使用冻结后的 turnSnapshot 执行 provider 调用与必要重试。 diff --git a/internal/runtime/state.go b/internal/runtime/state.go index c0b160bc..585b97f1 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -63,13 +63,16 @@ func (s *runState) touchSession() { } // turnSnapshot 冻结单轮推理所需的配置、上下文与 provider 请求。 +// noProgressStreakLimit 由 prepareTurnSnapshot 一次性解析并存储,确保同一轮的 +// 纠偏注入阈值与熔断阈值来自同一配置快照,避免并发 reload 导致阈值不一致。 type turnSnapshot struct { - config config.Config - providerConfig provider.RuntimeConfig - model string - workdir string - toolTimeout time.Duration - request providertypes.GenerateRequest + config config.Config + providerConfig provider.RuntimeConfig + model string + workdir string + toolTimeout time.Duration + noProgressStreakLimit int + request providertypes.GenerateRequest } // providerTurnResult 表示单轮 provider 调用成功后的结构化结果。 From 062846d8db662bb3fbb2bf2c696997650d16ef8f Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 02:41:22 +0000 Subject: [PATCH 10/33] fix(runtime): resolve main merge conflicts for subagent runtime Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/event_emitter.go | 71 ++++++++++++ internal/runtime/runtime.go | 154 -------------------------- internal/runtime/session_scheduler.go | 138 +++++++++++++++++++++++ internal/runtime/state.go | 3 + internal/runtime/subagent_factory.go | 86 +++++++++++--- 5 files changed, 280 insertions(+), 172 deletions(-) create mode 100644 internal/runtime/event_emitter.go create mode 100644 internal/runtime/session_scheduler.go diff --git a/internal/runtime/event_emitter.go b/internal/runtime/event_emitter.go new file mode 100644 index 00000000..43080dbb --- /dev/null +++ b/internal/runtime/event_emitter.go @@ -0,0 +1,71 @@ +package runtime + +import ( + "context" + "time" + + "neo-code/internal/runtime/controlplane" +) + +const turnUnspecified = -1 + +// emit 将 runtime 事件投递到事件通道,并在通道阻塞且上下文取消时返回错误。 +func (s *Service) emit(ctx context.Context, kind EventType, runID string, sessionID string, payload any) error { + return s.emitWithEnvelope(ctx, RuntimeEvent{ + Type: kind, + RunID: runID, + SessionID: sessionID, + Turn: turnUnspecified, + Timestamp: time.Now(), + PayloadVersion: controlplane.PayloadVersion, + Payload: payload, + }) +} + +// emitRunScoped 携带当前 run 的 turn/phase 元数据发出事件。 +func (s *Service) emitRunScoped(ctx context.Context, kind EventType, state *runState, payload any) error { + if state == nil { + return s.emit(ctx, kind, "", "", payload) + } + phase := "" + if state.phase != "" { + phase = string(state.phase) + } + return s.emitWithEnvelope(ctx, RuntimeEvent{ + Type: kind, + RunID: state.runID, + SessionID: state.session.ID, + Turn: state.turn, + Phase: phase, + Timestamp: time.Now(), + PayloadVersion: controlplane.PayloadVersion, + Payload: payload, + }) +} + +func (s *Service) emitWithEnvelope(ctx context.Context, evt RuntimeEvent) error { + if evt.PayloadVersion == 0 { + evt.PayloadVersion = controlplane.PayloadVersion + } + if evt.Timestamp.IsZero() { + evt.Timestamp = time.Now() + } + if err := s.deliverEvent(ctx, evt); err != nil { + return err + } + return nil +} + +func (s *Service) deliverEvent(ctx context.Context, evt RuntimeEvent) error { + select { + case s.events <- evt: + return nil + default: + } + select { + case s.events <- evt: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 928dd6c7..1a345f5a 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -2,8 +2,6 @@ package runtime import ( "context" - "path/filepath" - "strings" "sync" "time" @@ -14,7 +12,6 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" agentsession "neo-code/internal/session" - "neo-code/internal/subagent" "neo-code/internal/tools" ) @@ -69,8 +66,6 @@ type Service struct { compactRunner contextcompact.Runner approvalBroker *approval.Broker memoExtractor MemoExtractor - subAgentFactory subagent.Factory - subAgentMu sync.RWMutex events chan RuntimeEvent sessionMu sync.Mutex @@ -120,7 +115,6 @@ func NewWithFactory( providerFactory: providerFactory, contextBuilder: contextBuilder, approvalBroker: approval.NewBroker(), - subAgentFactory: defaultSubAgentFactory(), events: make(chan RuntimeEvent, 128), sessionLocks: make(map[string]*sessionLockEntry), permissionAskLocks: make(map[string]*permissionAskLockEntry), @@ -168,151 +162,3 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess } return session, nil } - -// loadOrCreateSession 负责在运行开始时解析工作目录并加载或创建会话。 -func (s *Service) loadOrCreateSession( - ctx context.Context, - sessionID string, - title string, - defaultWorkdir string, - requestedWorkdir string, -) (agentsession.Session, error) { - if strings.TrimSpace(sessionID) == "" { - sessionWorkdir, err := resolveWorkdirForSession(defaultWorkdir, "", requestedWorkdir) - if err != nil { - return agentsession.Session{}, err - } - session := agentsession.NewWithWorkdir(title, sessionWorkdir) - if err := s.sessionStore.Save(ctx, &session); err != nil { - return agentsession.Session{}, err - } - return session, nil - } - - session, err := s.sessionStore.Load(ctx, sessionID) - if err != nil { - return agentsession.Session{}, err - } - if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { - return session, nil - } - - resolved, err := resolveWorkdirForSession(defaultWorkdir, session.Workdir, requestedWorkdir) - if err != nil { - return agentsession.Session{}, err - } - if session.Workdir == resolved { - return session, nil - } - - session.Workdir = resolved - session.UpdatedAt = time.Now() - if err := s.sessionStore.Save(ctx, &session); err != nil { - return agentsession.Session{}, err - } - return session, nil -} - -// emit 将 runtime 事件投递到事件通道,并在通道阻塞且上下文取消时返回错误。 -func (s *Service) emit(ctx context.Context, kind EventType, runID string, sessionID string, payload any) error { - evt := RuntimeEvent{ - Type: kind, - RunID: runID, - SessionID: sessionID, - Payload: payload, - } - select { - case s.events <- evt: - return nil - default: - } - select { - case s.events <- evt: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// startRun 记录当前激活的运行取消句柄,并分配一个新的运行令牌。 -func (s *Service) startRun(cancel context.CancelFunc) uint64 { - s.runMu.Lock() - defer s.runMu.Unlock() - if s.activeRunCancels == nil { - s.activeRunCancels = make(map[uint64]context.CancelFunc) - } - - s.nextRunToken++ - token := s.nextRunToken - s.activeRunToken = token - s.activeRunCancels[token] = cancel - return token -} - -// finishRun 在运行结束时释放指定运行的取消句柄,并回退到最新活跃运行。 -func (s *Service) finishRun(token uint64) { - s.runMu.Lock() - defer s.runMu.Unlock() - - delete(s.activeRunCancels, token) - if s.activeRunToken != token { - return - } - - s.activeRunToken = 0 - for activeToken := range s.activeRunCancels { - if activeToken > s.activeRunToken { - s.activeRunToken = activeToken - } - } -} - -// acquireSessionLock 获取指定会话锁并返回释放引用的函数。 -// 调用方在完成会话级串行操作后,必须调用 release 以允许锁条目回收。 -func (s *Service) acquireSessionLock(sessionID string) (*sync.Mutex, func()) { - s.sessionMu.Lock() - if s.sessionLocks == nil { - s.sessionLocks = make(map[string]*sessionLockEntry) - } - - entry, ok := s.sessionLocks[sessionID] - if !ok { - entry = &sessionLockEntry{} - s.sessionLocks[sessionID] = entry - } - entry.refs++ - s.sessionMu.Unlock() - - released := false - release := func() { - s.sessionMu.Lock() - defer s.sessionMu.Unlock() - if released { - return - } - released = true - - current, exists := s.sessionLocks[sessionID] - if !exists || current != entry { - return - } - current.refs-- - if current.refs <= 0 { - delete(s.sessionLocks, sessionID) - } - } - return &entry.mu, release -} - -func resolveWorkdirForSession(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { - base := agentsession.EffectiveWorkdir(currentWorkdir, defaultWorkdir) - if strings.TrimSpace(requestedWorkdir) == "" { - return agentsession.ResolveExistingDir(base) - } - - target := strings.TrimSpace(requestedWorkdir) - if !filepath.IsAbs(target) { - target = filepath.Join(base, target) - } - return agentsession.ResolveExistingDir(target) -} diff --git a/internal/runtime/session_scheduler.go b/internal/runtime/session_scheduler.go new file mode 100644 index 00000000..ed783622 --- /dev/null +++ b/internal/runtime/session_scheduler.go @@ -0,0 +1,138 @@ +package runtime + +import ( + "context" + "path/filepath" + "strings" + "sync" + "time" + + agentsession "neo-code/internal/session" +) + +// loadOrCreateSession 负责在运行开始时解析工作目录并加载或创建会话。 +func (s *Service) loadOrCreateSession( + ctx context.Context, + sessionID string, + title string, + defaultWorkdir string, + requestedWorkdir string, +) (agentsession.Session, error) { + if strings.TrimSpace(sessionID) == "" { + sessionWorkdir, err := resolveWorkdirForSession(defaultWorkdir, "", requestedWorkdir) + if err != nil { + return agentsession.Session{}, err + } + session := agentsession.NewWithWorkdir(title, sessionWorkdir) + if err := s.sessionStore.Save(ctx, &session); err != nil { + return agentsession.Session{}, err + } + return session, nil + } + + session, err := s.sessionStore.Load(ctx, sessionID) + if err != nil { + return agentsession.Session{}, err + } + if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { + return session, nil + } + + resolved, err := resolveWorkdirForSession(defaultWorkdir, session.Workdir, requestedWorkdir) + if err != nil { + return agentsession.Session{}, err + } + if session.Workdir == resolved { + return session, nil + } + + session.Workdir = resolved + session.UpdatedAt = time.Now() + if err := s.sessionStore.Save(ctx, &session); err != nil { + return agentsession.Session{}, err + } + return session, nil +} + +// startRun 记录当前激活的运行取消句柄,并分配一个新的运行令牌。 +func (s *Service) startRun(cancel context.CancelFunc) uint64 { + s.runMu.Lock() + defer s.runMu.Unlock() + if s.activeRunCancels == nil { + s.activeRunCancels = make(map[uint64]context.CancelFunc) + } + + s.nextRunToken++ + token := s.nextRunToken + s.activeRunToken = token + s.activeRunCancels[token] = cancel + return token +} + +// finishRun 在运行结束时释放指定运行的取消句柄,并回退到最新活跃运行。 +func (s *Service) finishRun(token uint64) { + s.runMu.Lock() + defer s.runMu.Unlock() + + delete(s.activeRunCancels, token) + if s.activeRunToken != token { + return + } + + s.activeRunToken = 0 + for activeToken := range s.activeRunCancels { + if activeToken > s.activeRunToken { + s.activeRunToken = activeToken + } + } +} + +// acquireSessionLock 获取指定会话锁并返回释放引用的函数。 +// 调用方在完成会话级串行操作后,必须调用 release 以允许锁条目回收。 +func (s *Service) acquireSessionLock(sessionID string) (*sync.Mutex, func()) { + s.sessionMu.Lock() + if s.sessionLocks == nil { + s.sessionLocks = make(map[string]*sessionLockEntry) + } + + entry, ok := s.sessionLocks[sessionID] + if !ok { + entry = &sessionLockEntry{} + s.sessionLocks[sessionID] = entry + } + entry.refs++ + s.sessionMu.Unlock() + + released := false + release := func() { + s.sessionMu.Lock() + defer s.sessionMu.Unlock() + if released { + return + } + released = true + + current, exists := s.sessionLocks[sessionID] + if !exists || current != entry { + return + } + current.refs-- + if current.refs <= 0 { + delete(s.sessionLocks, sessionID) + } + } + return &entry.mu, release +} + +func resolveWorkdirForSession(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { + base := agentsession.EffectiveWorkdir(currentWorkdir, defaultWorkdir) + if strings.TrimSpace(requestedWorkdir) == "" { + return agentsession.ResolveExistingDir(base) + } + + target := strings.TrimSpace(requestedWorkdir) + if !filepath.IsAbs(target) { + target = filepath.Join(base, target) + } + return agentsession.ResolveExistingDir(target) +} diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 3b7d1c41..a5d24411 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -6,6 +6,7 @@ import ( "neo-code/internal/config" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" + "neo-code/internal/runtime/controlplane" agentsession "neo-code/internal/session" ) @@ -21,6 +22,8 @@ type runState struct { compactApplied bool reactiveCompactAttempts int rememberedThisRun bool + turn int + phase controlplane.Phase } // newRunState 基于持久化会话创建一次运行的内存状态镜像。 diff --git a/internal/runtime/subagent_factory.go b/internal/runtime/subagent_factory.go index e1889275..68662d0b 100644 --- a/internal/runtime/subagent_factory.go +++ b/internal/runtime/subagent_factory.go @@ -1,6 +1,69 @@ package runtime -import "neo-code/internal/subagent" +import ( + "runtime" + "sync" + + "neo-code/internal/subagent" +) + +type subAgentFactoryRegistry struct { + mu sync.RWMutex + factory map[*Service]subagent.Factory + tracked map[*Service]struct{} +} + +var globalSubAgentFactories = &subAgentFactoryRegistry{ + factory: make(map[*Service]subagent.Factory), + tracked: make(map[*Service]struct{}), +} + +// ensureTracked 为 Service 注册 GC 回调,避免全局注册表持有泄漏。 +func (r *subAgentFactoryRegistry) ensureTracked(s *Service) { + if s == nil { + return + } + + r.mu.Lock() + if _, ok := r.tracked[s]; ok { + r.mu.Unlock() + return + } + r.tracked[s] = struct{}{} + r.mu.Unlock() + + runtime.SetFinalizer(s, func(service *Service) { + globalSubAgentFactories.mu.Lock() + delete(globalSubAgentFactories.factory, service) + delete(globalSubAgentFactories.tracked, service) + globalSubAgentFactories.mu.Unlock() + }) +} + +// set 保存 Service 级工厂实例。 +func (r *subAgentFactoryRegistry) set(s *Service, f subagent.Factory) { + if s == nil { + return + } + r.ensureTracked(s) + + r.mu.Lock() + r.factory[s] = f + r.mu.Unlock() +} + +// get 读取 Service 级工厂实例。 +func (r *subAgentFactoryRegistry) get(s *Service) (subagent.Factory, bool) { + if s == nil { + return nil, false + } + r.ensureTracked(s) + + r.mu.RLock() + factory, ok := r.factory[s] + r.mu.RUnlock() + return factory, ok +} // defaultSubAgentFactory 返回默认的子代理工厂实例。 func defaultSubAgentFactory() subagent.Factory { @@ -12,13 +75,11 @@ func (s *Service) SetSubAgentFactory(factory subagent.Factory) { if s == nil { return } - s.subAgentMu.Lock() - defer s.subAgentMu.Unlock() if factory == nil { - s.subAgentFactory = defaultSubAgentFactory() + globalSubAgentFactories.set(s, defaultSubAgentFactory()) return } - s.subAgentFactory = factory + globalSubAgentFactories.set(s, factory) } // SubAgentFactory 返回当前 runtime 持有的子代理运行时工厂。 @@ -26,22 +87,11 @@ func (s *Service) SubAgentFactory() subagent.Factory { if s == nil { return defaultSubAgentFactory() } - s.subAgentMu.RLock() - factory := s.subAgentFactory - s.subAgentMu.RUnlock() - if factory != nil { + if factory, ok := globalSubAgentFactories.get(s); ok && factory != nil { return factory } defaultFactory := defaultSubAgentFactory() - s.subAgentMu.Lock() - if s.subAgentFactory == nil { - s.subAgentFactory = defaultFactory - } - factory = s.subAgentFactory - s.subAgentMu.Unlock() - if factory != nil { - return factory - } + globalSubAgentFactories.set(s, defaultFactory) return defaultFactory } From 631d8ab012eb0cf02a82cced6c8b1cff78a67320 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 02:44:23 +0000 Subject: [PATCH 11/33] refactor(runtime): align state model with main-compatible event envelope Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/run.go | 8 ++--- .../runtime/runtime_internal_helpers_test.go | 13 +++----- internal/runtime/runtime_test.go | 28 ++++++++-------- internal/runtime/session_mutation.go | 1 - internal/runtime/state.go | 33 +++++++------------ 5 files changed, 33 insertions(+), 50 deletions(-) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index b0caf20a..be606beb 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -139,8 +139,8 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur Shell: cfg.Shell, Provider: cfg.SelectedProvider, Model: cfg.CurrentModel, - SessionInputTokens: state.tokenInputTotal, - SessionOutputTokens: state.tokenOutputTotal, + SessionInputTokens: state.session.TokenInputTotal, + SessionOutputTokens: state.session.TokenOutputTotal, }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, @@ -257,8 +257,8 @@ func (s *Service) emitTokenUsage(ctx context.Context, state *runState, result pr s.emit(ctx, EventTokenUsage, state.runID, state.session.ID, TokenUsagePayload{ InputTokens: result.inputTokens, OutputTokens: result.outputTokens, - SessionInputTokens: state.tokenInputTotal, - SessionOutputTokens: state.tokenOutputTotal, + SessionInputTokens: state.session.TokenInputTotal, + SessionOutputTokens: state.session.TokenOutputTotal, }) } diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 3b5472f6..c1952b82 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -40,7 +40,6 @@ func TestRunStateNilReceiverNoops(t *testing.T) { t.Parallel() var state *runState - state.syncSessionTokenTotals() state.recordUsage(3, 5) state.resetTokenTotals() state.touchSession() @@ -53,18 +52,14 @@ func TestRunStateMutationsAndSync(t *testing.T) { state := newRunState("run-1", session) state.recordUsage(10, 20) - if state.tokenInputTotal != 11 || state.tokenOutputTotal != 22 { - t.Fatalf("unexpected token totals: in=%d out=%d", state.tokenInputTotal, state.tokenOutputTotal) - } - - state.syncSessionTokenTotals() if state.session.TokenInputTotal != 11 || state.session.TokenOutputTotal != 22 { t.Fatalf("session totals not synced: %+v", state.session) } state.resetTokenTotals() - if state.tokenInputTotal != 0 || state.tokenOutputTotal != 0 { - t.Fatalf("expected reset totals to be zero, got in=%d out=%d", state.tokenInputTotal, state.tokenOutputTotal) + if state.session.TokenInputTotal != 0 || state.session.TokenOutputTotal != 0 { + t.Fatalf("expected reset totals to be zero, got in=%d out=%d", + state.session.TokenInputTotal, state.session.TokenOutputTotal) } before := state.session.UpdatedAt @@ -74,7 +69,7 @@ func TestRunStateMutationsAndSync(t *testing.T) { t.Fatalf("expected touchSession to update time") } if state.session.TokenInputTotal != 1 || state.session.TokenOutputTotal != 2 { - t.Fatalf("expected touchSession to sync totals") + t.Fatalf("expected touchSession to keep latest totals") } } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 20cde84e..d054c21c 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4059,11 +4059,11 @@ func TestRestoreSessionTokens(t *testing.T) { state := newRunState("", session) - if state.tokenInputTotal != 500 { - t.Fatalf("expected sessionInputTokens == 500, got %d", state.tokenInputTotal) + if state.session.TokenInputTotal != 500 { + t.Fatalf("expected sessionInputTokens == 500, got %d", state.session.TokenInputTotal) } - if state.tokenOutputTotal != 200 { - t.Fatalf("expected sessionOutputTokens == 200, got %d", state.tokenOutputTotal) + if state.session.TokenOutputTotal != 200 { + t.Fatalf("expected sessionOutputTokens == 200, got %d", state.session.TokenOutputTotal) } } @@ -4077,11 +4077,11 @@ func TestRestoreSessionTokensNewSession(t *testing.T) { state := newRunState("", session) - if state.tokenInputTotal != 0 { - t.Fatalf("expected sessionInputTokens == 0, got %d", state.tokenInputTotal) + if state.session.TokenInputTotal != 0 { + t.Fatalf("expected sessionInputTokens == 0, got %d", state.session.TokenInputTotal) } - if state.tokenOutputTotal != 0 { - t.Fatalf("expected sessionOutputTokens == 0, got %d", state.tokenOutputTotal) + if state.session.TokenOutputTotal != 0 { + t.Fatalf("expected sessionOutputTokens == 0, got %d", state.session.TokenOutputTotal) } } @@ -4165,8 +4165,8 @@ func TestTokenUsageRecordedOnMessageDone(t *testing.T) { service.emit(context.Background(), EventTokenUsage, "test-run-id", "test-session-id", TokenUsagePayload{ InputTokens: payload.Usage.InputTokens, OutputTokens: payload.Usage.OutputTokens, - SessionInputTokens: state.tokenInputTotal, - SessionOutputTokens: state.tokenOutputTotal, + SessionInputTokens: state.session.TokenInputTotal, + SessionOutputTokens: state.session.TokenOutputTotal, }) } }}, @@ -4176,11 +4176,11 @@ func TestTokenUsageRecordedOnMessageDone(t *testing.T) { } // Verify the service counters are updated - if state.tokenInputTotal != 100 { - t.Fatalf("expected sessionInputTokens == 100, got %d", state.tokenInputTotal) + if state.session.TokenInputTotal != 100 { + t.Fatalf("expected sessionInputTokens == 100, got %d", state.session.TokenInputTotal) } - if state.tokenOutputTotal != 50 { - t.Fatalf("expected sessionOutputTokens == 50, got %d", state.tokenOutputTotal) + if state.session.TokenOutputTotal != 50 { + t.Fatalf("expected sessionOutputTokens == 50, got %d", state.session.TokenOutputTotal) } // Verify EventTokenUsage was emitted with correct payload diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 9e3ab972..e52a1574 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -33,7 +33,6 @@ func (s *Service) appendAssistantMessageAndSave( metadataChanged := state.session.Provider != snapshot.providerConfig.Name || state.session.Model != snapshot.model state.session.Provider = snapshot.providerConfig.Name state.session.Model = snapshot.model - state.syncSessionTokenTotals() if strings.TrimSpace(assistant.Content) != "" || len(assistant.ToolCalls) > 0 { state.session.Messages = append(state.session.Messages, assistant) diff --git a/internal/runtime/state.go b/internal/runtime/state.go index a5d24411..c0b160bc 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -1,6 +1,7 @@ package runtime import ( + "sync" "time" "neo-code/internal/config" @@ -15,43 +16,33 @@ const maxReactiveCompactAttempts = 3 // runState 汇总单次 Run 生命周期内会变化的会话与计量状态。 type runState struct { + mu sync.Mutex runID string session agentsession.Session - tokenInputTotal int - tokenOutputTotal int compactApplied bool reactiveCompactAttempts int rememberedThisRun bool turn int phase controlplane.Phase + stopEmitted bool + progress controlplane.ProgressState } // newRunState 基于持久化会话创建一次运行的内存状态镜像。 func newRunState(runID string, session agentsession.Session) runState { return runState{ - runID: runID, - session: session, - tokenInputTotal: session.TokenInputTotal, - tokenOutputTotal: session.TokenOutputTotal, + runID: runID, + session: session, } } -// syncSessionTokenTotals 将运行期 token 计数同步回会话对象。 -func (s *runState) syncSessionTokenTotals() { - if s == nil { - return - } - s.session.TokenInputTotal = s.tokenInputTotal - s.session.TokenOutputTotal = s.tokenOutputTotal -} - // recordUsage 累加本轮 provider 返回的 token 使用量。 func (s *runState) recordUsage(inputTokens int, outputTokens int) { if s == nil { return } - s.tokenInputTotal += inputTokens - s.tokenOutputTotal += outputTokens + s.session.TokenInputTotal += inputTokens + s.session.TokenOutputTotal += outputTokens } // resetTokenTotals 在 compact 应用成功后清零当前运行的 token 账本。 @@ -59,17 +50,15 @@ func (s *runState) resetTokenTotals() { if s == nil { return } - s.tokenInputTotal = 0 - s.tokenOutputTotal = 0 - s.syncSessionTokenTotals() + s.session.TokenInputTotal = 0 + s.session.TokenOutputTotal = 0 } -// touchSession 更新会话修改时间并同步最新 token 累计值。 +// touchSession 更新会话修改时间。 func (s *runState) touchSession() { if s == nil { return } - s.syncSessionTokenTotals() s.session.UpdatedAt = time.Now() } From 1afecba70d665080d528690c97bde21b2ac34a03 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 02:45:28 +0000 Subject: [PATCH 12/33] test(runtime): align helper tests with main to remove merge conflicts Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- .../runtime/runtime_internal_helpers_test.go | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index c1952b82..97d779ad 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "neo-code/internal/config" "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" @@ -22,6 +21,25 @@ type stubMemoExtractor struct { doneCh chan struct{} } +type lockProbeStore struct { + saveFn func(ctx context.Context, session *agentsession.Session) error +} + +func (s *lockProbeStore) Save(ctx context.Context, session *agentsession.Session) error { + if s.saveFn == nil { + return nil + } + return s.saveFn(ctx, session) +} + +func (s *lockProbeStore) Load(ctx context.Context, id string) (agentsession.Session, error) { + return agentsession.Session{}, errors.New("not implemented") +} + +func (s *lockProbeStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { + return nil, errors.New("not implemented") +} + func (s *stubMemoExtractor) Schedule(_ string, messages []providertypes.Message) { s.mu.Lock() s.calls++ @@ -53,13 +71,12 @@ func TestRunStateMutationsAndSync(t *testing.T) { state.recordUsage(10, 20) if state.session.TokenInputTotal != 11 || state.session.TokenOutputTotal != 22 { - t.Fatalf("session totals not synced: %+v", state.session) + t.Fatalf("unexpected token totals: in=%d out=%d", state.session.TokenInputTotal, state.session.TokenOutputTotal) } state.resetTokenTotals() if state.session.TokenInputTotal != 0 || state.session.TokenOutputTotal != 0 { - t.Fatalf("expected reset totals to be zero, got in=%d out=%d", - state.session.TokenInputTotal, state.session.TokenOutputTotal) + t.Fatalf("expected reset totals to be zero, got in=%d out=%d", state.session.TokenInputTotal, state.session.TokenOutputTotal) } before := state.session.UpdatedAt @@ -69,7 +86,7 @@ func TestRunStateMutationsAndSync(t *testing.T) { t.Fatalf("expected touchSession to update time") } if state.session.TokenInputTotal != 1 || state.session.TokenOutputTotal != 2 { - t.Fatalf("expected touchSession to keep latest totals") + t.Fatalf("expected recordUsage to sync totals") } } @@ -138,17 +155,36 @@ func TestAppendToolMessageAndSaveSanitizesMetadata(t *testing.T) { } } -func TestResolveMaxLoopsBranches(t *testing.T) { +func TestAppendToolMessageAndSaveUnlocksStateBeforePersist(t *testing.T) { t.Parallel() - if got := resolveMaxLoops(config.Config{MaxLoops: 0}); got != defaultMaxLoops { - t.Fatalf("expected default max loops for zero, got %d", got) - } - if got := resolveMaxLoops(config.Config{MaxLoops: -3}); got != defaultMaxLoops { - t.Fatalf("expected default max loops for negative, got %d", got) + session := newRuntimeSession("session-append-tool-lock") + state := newRunState("run-append-tool-lock", session) + + store := &lockProbeStore{ + saveFn: func(_ context.Context, _ *agentsession.Session) error { + locked := make(chan struct{}) + go func() { + state.mu.Lock() + state.mu.Unlock() + close(locked) + }() + + select { + case <-locked: + return nil + case <-time.After(200 * time.Millisecond): + return errors.New("state lock is still held during save") + } + }, } - if got := resolveMaxLoops(config.Config{MaxLoops: 12}); got != 12 { - t.Fatalf("expected explicit max loops, got %d", got) + + service := &Service{sessionStore: store} + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{Name: "filesystem_read_file", Content: "ok"} + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) } } From 727ed3fc1f1e9300b46d899f419147dd784d048b Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 03:26:26 +0000 Subject: [PATCH 13/33] fix(subagent): address review gaps in progress envelope and capability bounds Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Cai-Tang-www <106404101+Cai-Tang-www@users.noreply.github.com> --- internal/runtime/subagent_run.go | 28 ++++-- internal/runtime/subagent_run_test.go | 123 ++++++++++++++++++++++++++ internal/subagent/worker.go | 17 +++- internal/subagent/worker_test.go | 61 +++++++++++++ 4 files changed, 221 insertions(+), 8 deletions(-) diff --git a/internal/runtime/subagent_run.go b/internal/runtime/subagent_run.go index bf0adbd8..31bbd2b6 100644 --- a/internal/runtime/subagent_run.go +++ b/internal/runtime/subagent_run.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "strings" + "time" + "neo-code/internal/runtime/controlplane" "neo-code/internal/subagent" ) @@ -118,7 +120,7 @@ func (s *Service) RunSubAgentTask(ctx context.Context, input SubAgentTaskInput) if result.State == subagent.StateSucceeded { return result, nil } - return result, errors.New(result.Error) + return result, subAgentResultError(result) } } @@ -132,13 +134,17 @@ func emitSubAgentProgress(s *Service, input SubAgentTaskInput, stepResult subage Delta: stepResult.Delta, Error: errorText(stepErr), } + event := RuntimeEvent{ + Type: EventSubAgentProgress, + RunID: input.RunID, + SessionID: input.SessionID, + Turn: turnUnspecified, + Timestamp: time.Now(), + PayloadVersion: controlplane.PayloadVersion, + Payload: payload, + } select { - case s.events <- RuntimeEvent{ - Type: EventSubAgentProgress, - RunID: input.RunID, - SessionID: input.SessionID, - Payload: payload, - }: + case s.events <- event: default: } } @@ -171,3 +177,11 @@ func errorText(err error) string { } return strings.TrimSpace(err.Error()) } + +// subAgentResultError 将子代理终态结果转换为可诊断错误,避免空错误文本丢失上下文。 +func subAgentResultError(result subagent.Result) error { + if text := strings.TrimSpace(result.Error); text != "" { + return errors.New(text) + } + return fmt.Errorf("subagent ended with state=%s stop_reason=%s", result.State, result.StopReason) +} diff --git a/internal/runtime/subagent_run_test.go b/internal/runtime/subagent_run_test.go index 116abaf2..96f20f81 100644 --- a/internal/runtime/subagent_run_test.go +++ b/internal/runtime/subagent_run_test.go @@ -3,9 +3,11 @@ package runtime import ( "context" "errors" + "strings" "testing" "time" + "neo-code/internal/runtime/controlplane" "neo-code/internal/subagent" ) @@ -74,6 +76,17 @@ func TestServiceRunSubAgentTaskSuccess(t *testing.T) { EventSubAgentCompleted, }) assertEventsRunID(t, events, "sub-run-success") + for _, evt := range events { + if evt.Type != EventSubAgentProgress { + continue + } + if evt.Timestamp.IsZero() { + t.Fatalf("progress event timestamp should be set: %+v", evt) + } + if evt.PayloadVersion != controlplane.PayloadVersion { + t.Fatalf("progress event payload version = %d, want %d", evt.PayloadVersion, controlplane.PayloadVersion) + } + } } func TestServiceRunSubAgentTaskFailureFlows(t *testing.T) { @@ -190,6 +203,116 @@ func TestServiceRunSubAgentTaskFailureFlows(t *testing.T) { events := collectRuntimeEvents(service.Events()) assertEventSequence(t, events, []EventType{EventSubAgentFailed}) }) + + t.Run("custom worker failed without explicit error should return fallback", func(t *testing.T) { + t.Parallel() + + service := NewWithFactory(nil, nil, nil, nil, nil) + service.SetSubAgentFactory(stubSubAgentFactory{ + create: func(role subagent.Role) (subagent.WorkerRuntime, error) { + return &stubSubAgentWorker{ + result: subagent.Result{ + Role: role, + TaskID: "task-fallback-error", + State: subagent.StateFailed, + StopReason: subagent.StopReasonError, + }, + stepResult: subagent.StepResult{ + State: subagent.StateFailed, + Done: true, + Step: 1, + }, + }, nil + }, + }) + + _, err := service.RunSubAgentTask(context.Background(), SubAgentTaskInput{ + RunID: "sub-run-fallback-error", + Role: subagent.RoleReviewer, + Task: subagent.Task{ + ID: "task-fallback-error", + Goal: "review", + }, + }) + if err == nil { + t.Fatalf("expected fallback error") + } + if !strings.Contains(err.Error(), "state=failed") || !strings.Contains(err.Error(), "stop_reason=error") { + t.Fatalf("error = %q, want state/stop_reason fallback", err.Error()) + } + }) +} + +type stubSubAgentFactory struct { + create func(role subagent.Role) (subagent.WorkerRuntime, error) +} + +func (s stubSubAgentFactory) Create(role subagent.Role) (subagent.WorkerRuntime, error) { + return s.create(role) +} + +type stubSubAgentWorker struct { + startErr error + stepResult subagent.StepResult + stepErr error + result subagent.Result + resultErr error + current subagent.State + stopInvoked bool +} + +func (s *stubSubAgentWorker) Start(task subagent.Task, budget subagent.Budget, capability subagent.Capability) error { + if s.startErr != nil { + return s.startErr + } + if s.current == "" { + s.current = subagent.StateRunning + } + s.result.Role = firstNonEmptyRole(s.result.Role, subagent.RoleReviewer) + if strings.TrimSpace(s.result.TaskID) == "" { + s.result.TaskID = task.ID + } + return nil +} + +func (s *stubSubAgentWorker) Step(ctx context.Context) (subagent.StepResult, error) { + if s.stepResult.State == "" { + s.stepResult.State = s.current + } + if s.stepResult.Done { + s.current = s.result.State + } + return s.stepResult, s.stepErr +} + +func (s *stubSubAgentWorker) Stop(reason subagent.StopReason) error { + s.stopInvoked = true + s.current = subagent.StateCanceled + s.result.State = subagent.StateCanceled + s.result.StopReason = reason + return nil +} + +func (s *stubSubAgentWorker) Result() (subagent.Result, error) { + return s.result, s.resultErr +} + +func (s *stubSubAgentWorker) State() subagent.State { + if s.current == "" { + return subagent.StateIdle + } + return s.current +} + +func (s *stubSubAgentWorker) Policy() subagent.RolePolicy { + return subagent.RolePolicy{} +} + +func firstNonEmptyRole(role subagent.Role, fallback subagent.Role) subagent.Role { + if role != "" { + return role + } + return fallback } func TestServiceRunSubAgentTaskInputValidation(t *testing.T) { diff --git a/internal/subagent/worker.go b/internal/subagent/worker.go index 6e6c791a..18e9cf17 100644 --- a/internal/subagent/worker.go +++ b/internal/subagent/worker.go @@ -147,7 +147,7 @@ func (w *worker) Step(ctx context.Context) (StepResult, error) { w.stepCount++ delta := strings.TrimSpace(stepOutput.Delta) if delta != "" { - w.trace = append(w.trace, delta) + w.trace = appendTraceBounded(w.trace, delta, traceWindowSize) } if stepOutput.Done { @@ -174,6 +174,10 @@ func (w *worker) Step(ctx context.Context) (StepResult, error) { // bindCapabilityToPolicy 将 capability 约束在角色策略允许的工具集合内。 func bindCapabilityToPolicy(capability Capability, policy RolePolicy) (Capability, error) { + if len(capability.AllowedPaths) > 0 { + return Capability{}, errorsf("capability allowed paths is not supported yet") + } + allowedPolicyTools := dedupeAndTrim(policy.AllowedTools) allowedSet := make(map[string]struct{}, len(allowedPolicyTools)) for _, tool := range allowedPolicyTools { @@ -202,6 +206,17 @@ func bindCapabilityToPolicy(capability Capability, policy RolePolicy) (Capabilit return capability, nil } +// appendTraceBounded 将新增 trace 追加到切片尾部,并保证内部存储长度不超过上限。 +func appendTraceBounded(trace []string, delta string, limit int) []string { + trace = append(trace, delta) + if limit <= 0 || len(trace) <= limit { + return trace + } + start := len(trace) - limit + copy(trace, trace[start:]) + return trace[:limit] +} + // cloneRecentTrace 复制最近 limit 条 trace,避免每步复制完整历史导致复杂度放大。 func cloneRecentTrace(trace []string, limit int) []string { if len(trace) == 0 { diff --git a/internal/subagent/worker_test.go b/internal/subagent/worker_test.go index b323ad98..532e4c76 100644 --- a/internal/subagent/worker_test.go +++ b/internal/subagent/worker_test.go @@ -276,6 +276,16 @@ func TestWorkerStartCapabilityPolicyGuard(t *testing.T) { t.Fatalf("expected disallowed capability tool to fail") } + wPath, err := NewWorker(RoleReviewer, policy, nil) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := wPath.Start(Task{ID: "t-cap-path", Goal: "goal"}, Budget{}, Capability{ + AllowedPaths: []string{"/tmp/workspace"}, + }); err == nil { + t.Fatalf("expected unsupported allowed paths to fail") + } + w2, err := NewWorker(RoleReviewer, policy, nil) if err != nil { t.Fatalf("NewWorker() error = %v", err) @@ -343,6 +353,57 @@ func TestWorkerTraceWindow(t *testing.T) { } } +func TestWorkerTraceStorageBounded(t *testing.T) { + t.Parallel() + + policy, err := DefaultRolePolicy(RoleResearcher) + if err != nil { + t.Fatalf("DefaultRolePolicy() error = %v", err) + } + + w, err := NewWorker(RoleResearcher, policy, EngineFunc(func(ctx context.Context, input StepInput) (StepOutput, error) { + if input.StepIndex < traceWindowSize+8 { + return StepOutput{Delta: "delta", Done: false}, nil + } + return StepOutput{ + Delta: "delta", + Done: true, + Output: Output{ + Summary: "done", + Findings: []string{"f1"}, + Patches: []string{"p1"}, + Risks: []string{"r1"}, + NextActions: []string{"n1"}, + Artifacts: []string{"a1"}, + }, + }, nil + })) + if err != nil { + t.Fatalf("NewWorker() error = %v", err) + } + if err := w.Start(Task{ID: "t-trace-bounded", Goal: "goal"}, Budget{MaxSteps: traceWindowSize + 16}, Capability{}); err != nil { + t.Fatalf("Start() error = %v", err) + } + + for { + step, stepErr := w.Step(context.Background()) + if stepErr != nil { + t.Fatalf("Step() error = %v", stepErr) + } + if step.Done { + break + } + } + + impl, ok := w.(*worker) + if !ok { + t.Fatalf("expected *worker implementation") + } + if len(impl.trace) != traceWindowSize { + t.Fatalf("trace storage len = %d, want %d", len(impl.trace), traceWindowSize) + } +} + func TestWorkerNilAndValidationBranches(t *testing.T) { t.Parallel() From 9a7f69e072ae2f627c84cdbc0518793dc11b6142 Mon Sep 17 00:00:00 2001 From: pionxe Date: Wed, 15 Apr 2026 11:31:40 +0800 Subject: [PATCH 14/33] =?UTF-8?q?feat(gateway):=20[EPIC-GW-03]=20=E6=9C=AC?= =?UTF-8?q?=E5=9C=B0=E6=8E=A7=E5=88=B6=E9=9D=A2=E5=A5=91=E7=BA=A6=E7=A1=AC?= =?UTF-8?q?=E5=8C=96=E4=B8=8E=20JSON-RPC=20=E5=8D=8F=E8=AE=AE=E5=88=87?= =?UTF-8?q?=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现控制面底层 IPC 传输从“裸 MessageFrame”向“标准 JSON-RPC 2.0”的直接无缝切换,硬化内部契约,为后续 Runtime 接入奠定稳定基座。 主要改动: 1. 协议归一化:新增 JSON-RPC 解析层,使用 json.RawMessage 延迟解析 params,并统一收敛至内部 MessageFrame 进行路由。 2. 错误模型重构:采用标准 JSON-RPC 错误码 (error.code),并将网关业务错误码统一映射至 error.data.gateway_code。 3. 适配器升级:url-dispatch 全面切换为 JSON-RPC 收发,并补齐了严密的 RequestID/Action 响应关联校验。 4. 稳定性增强:修复 URL 调度器的 context 取消中断逻辑(防止无界阻塞),并修复 Windows 环境下的绝对路径越界检测漏洞。 5. 测试覆盖:大幅增补 Server/Dispatcher 协议回归单测,网关核心覆盖率达 95.3%。 --- .../gateway/adapters/urlscheme/dispatcher.go | 109 ++++++- .../adapters/urlscheme/dispatcher_test.go | 287 +++++++++++++++--- internal/gateway/handlers/wake.go | 2 +- internal/gateway/protocol/jsonrpc.go | 274 +++++++++++++++++ internal/gateway/protocol/jsonrpc_test.go | 174 +++++++++++ internal/gateway/server.go | 89 ++++-- internal/gateway/server_additional_test.go | 88 +++++- internal/gateway/server_test.go | 92 +++--- 8 files changed, 1004 insertions(+), 111 deletions(-) create mode 100644 internal/gateway/protocol/jsonrpc.go create mode 100644 internal/gateway/protocol/jsonrpc_test.go diff --git a/internal/gateway/adapters/urlscheme/dispatcher.go b/internal/gateway/adapters/urlscheme/dispatcher.go index 2abf68a3..74ab7ccd 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher.go +++ b/internal/gateway/adapters/urlscheme/dispatcher.go @@ -1,11 +1,13 @@ package urlscheme import ( + "bytes" "context" "encoding/json" "errors" "fmt" "net" + "strings" "sync/atomic" "time" @@ -110,28 +112,69 @@ func (d *Dispatcher) Dispatch(ctx context.Context, request DispatchRequest) (Dis Payload: intent, } + requestIDRaw, err := marshalJSONRawMessage(requestFrame.RequestID) + if err != nil { + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("encode request id: %v", err)) + } + requestParamsRaw, err := marshalJSONRawMessage(intent) + if err != nil { + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("encode request params: %v", err)) + } + rpcRequest := protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: requestIDRaw, + Method: protocol.MethodWakeOpenURL, + Params: requestParamsRaw, + } + if err := ensureDispatchContextActive(ctx); err != nil { return DispatchResult{}, toDispatchError(err) } encoder := json.NewEncoder(conn) - if err := encoder.Encode(requestFrame); err != nil { + if err := encoder.Encode(rpcRequest); err != nil { if ctx != nil && ctx.Err() != nil { ctxErr := ctx.Err() return DispatchResult{}, toDispatchError(ctxErr) } - return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("write request frame: %v", err)) + return DispatchResult{}, newDispatchError(ErrorCodeInternal, fmt.Sprintf("write request rpc: %v", err)) } - var responseFrame gateway.MessageFrame + var rpcResponse protocol.JSONRPCResponse if err := ensureDispatchContextActive(ctx); err != nil { return DispatchResult{}, toDispatchError(err) } decoder := json.NewDecoder(conn) - if err := decoder.Decode(&responseFrame); err != nil { + if err := decoder.Decode(&rpcResponse); err != nil { if ctx != nil && ctx.Err() != nil { ctxErr := ctx.Err() return DispatchResult{}, toDispatchError(ctxErr) } + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode response rpc: %v", err)) + } + if strings.TrimSpace(rpcResponse.JSONRPC) != protocol.JSONRPCVersion { + return DispatchResult{}, newDispatchError( + ErrorCodeUnexpectedResponse, + "unexpected response jsonrpc version", + ) + } + if !rawJSONMessageEqual(rpcResponse.ID, rpcRequest.ID) { + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "rpc correlation failed: id mismatch") + } + if rpcResponse.Error != nil && rpcResponse.Result != nil { + return DispatchResult{}, newDispatchError( + ErrorCodeUnexpectedResponse, + "unexpected response payload: both result and error are present", + ) + } + if rpcResponse.Error != nil { + return DispatchResult{}, toDispatchErrorFromJSONRPC(rpcResponse.Error) + } + if rpcResponse.Result == nil { + return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "gateway response missing result payload") + } + + responseFrame, err := decodeResponseFrameResult(rpcResponse.Result) + if err != nil { return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode response frame: %v", err)) } if responseFrame.Action != requestFrame.Action || responseFrame.RequestID != requestFrame.RequestID { @@ -201,6 +244,64 @@ func watchDispatchCancellation(ctx context.Context, conn net.Conn) func() { } } +// toDispatchErrorFromJSONRPC 将 JSON-RPC 错误对象映射为 url-dispatch 稳定错误。 +func toDispatchErrorFromJSONRPC(rpcError *protocol.JSONRPCError) error { + if rpcError == nil { + return newDispatchError(ErrorCodeUnexpectedResponse, "gateway returned empty rpc error payload") + } + + code := strings.TrimSpace(protocol.GatewayCodeFromJSONRPCError(rpcError)) + if code == "" { + code = mapJSONRPCCodeToDispatchCode(rpcError.Code) + } + message := strings.TrimSpace(rpcError.Message) + if message == "" { + message = "gateway returned empty rpc error message" + } + return newDispatchError(code, message) +} + +// mapJSONRPCCodeToDispatchCode 为缺少 gateway_code 的响应提供兜底错误码映射。 +func mapJSONRPCCodeToDispatchCode(code int) string { + switch code { + case protocol.JSONRPCCodeMethodNotFound: + return gateway.ErrorCodeUnsupportedAction.String() + case protocol.JSONRPCCodeInvalidRequest, protocol.JSONRPCCodeInvalidParams, protocol.JSONRPCCodeParseError: + return gateway.ErrorCodeInvalidFrame.String() + case protocol.JSONRPCCodeInternalError: + return gateway.ErrorCodeInternalError.String() + default: + return ErrorCodeInternal + } +} + +// decodeResponseFrameResult 将 JSON-RPC result 安全解码回 MessageFrame。 +func decodeResponseFrameResult(result any) (gateway.MessageFrame, error) { + raw, err := json.Marshal(result) + if err != nil { + return gateway.MessageFrame{}, err + } + var frame gateway.MessageFrame + if err := json.Unmarshal(raw, &frame); err != nil { + return gateway.MessageFrame{}, err + } + return frame, nil +} + +// rawJSONMessageEqual 比较两段 JSON 原文在去除首尾空白后的字节是否一致。 +func rawJSONMessageEqual(left, right json.RawMessage) bool { + return bytes.Equal(bytes.TrimSpace(left), bytes.TrimSpace(right)) +} + +// marshalJSONRawMessage 将任意对象编码为 json.RawMessage,便于构造 JSON-RPC 请求字段。 +func marshalJSONRawMessage(payload any) (json.RawMessage, error) { + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + return json.RawMessage(raw), nil +} + // toDispatchError 将不同来源错误转换为统一结构化错误。 func toDispatchError(err error) error { if err == nil { diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index e371a4e2..e52569c6 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -12,6 +12,7 @@ import ( "time" "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" "neo-code/internal/gateway/transport" ) @@ -40,25 +41,29 @@ func TestDispatcherDispatchSuccess(t *testing.T) { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - if err := decoder.Decode(&requestFrame); err != nil { - t.Errorf("decode request frame: %v", err) + var rpcRequest protocol.JSONRPCRequest + if err := decoder.Decode(&rpcRequest); err != nil { + t.Errorf("decode request rpc: %v", err) return } - if requestFrame.Action != gateway.FrameActionWakeOpenURL { - t.Errorf("request action = %q, want %q", requestFrame.Action, gateway.FrameActionWakeOpenURL) + if rpcRequest.Method != protocol.MethodWakeOpenURL { + t.Errorf("request method = %q, want %q", rpcRequest.Method, protocol.MethodWakeOpenURL) return } - if err := encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeAck, - Action: gateway.FrameActionWakeOpenURL, - RequestID: requestFrame.RequestID, - Payload: map[string]any{ - "message": "wake intent accepted", + if err := encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-1", + Payload: map[string]any{ + "message": "wake intent accepted", + }, }, }); err != nil { - t.Errorf("encode response frame: %v", err) + t.Errorf("encode response rpc: %v", err) } }() @@ -94,16 +99,16 @@ func TestDispatcherDispatchReturnsGatewayError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeError, - Action: requestFrame.Action, - RequestID: requestFrame.RequestID, - Error: &gateway.FrameError{ - Code: gateway.ErrorCodeInvalidAction.String(), - Message: "unsupported wake action", - }, + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Error: protocol.NewJSONRPCError( + protocol.JSONRPCCodeInvalidParams, + "unsupported wake action", + gateway.ErrorCodeInvalidAction.String(), + ), }) }() @@ -139,12 +144,16 @@ func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeEvent, - Action: requestFrame.Action, - RequestID: requestFrame.RequestID, + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-3", + }, }) }() @@ -179,12 +188,16 @@ func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { go func() { decoder := json.NewDecoder(serverConn) encoder := json.NewEncoder(serverConn) - var requestFrame gateway.MessageFrame - _ = decoder.Decode(&requestFrame) - _ = encoder.Encode(gateway.MessageFrame{ - Type: gateway.FrameTypeAck, - Action: requestFrame.Action, - RequestID: "wake-mismatch", + var rpcRequest protocol.JSONRPCRequest + _ = decoder.Decode(&rpcRequest) + _ = encoder.Encode(protocol.JSONRPCResponse{ + JSONRPC: protocol.JSONRPCVersion, + ID: rpcRequest.ID, + Result: gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionWakeOpenURL, + RequestID: "wake-mismatch", + }, }) }() @@ -302,9 +315,9 @@ func TestDispatcherDispatchInterruptsBlockedReadOnContextCancel(t *testing.T) { defer close(serverDone) decoder := json.NewDecoder(serverConn) - var frame gateway.MessageFrame - if err := decoder.Decode(&frame); err != nil { - t.Errorf("decode request frame: %v", err) + var rpcRequest protocol.JSONRPCRequest + if err := decoder.Decode(&rpcRequest); err != nil { + t.Errorf("decode request rpc: %v", err) return } close(requestArrived) @@ -573,12 +586,12 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { } }) - t.Run("gateway error frame missing payload", func(t *testing.T) { + t.Run("gateway response missing result payload", func(t *testing.T) { dispatcher := &Dispatcher{ resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, dialFn: func(string) (net.Conn, error) { return &stubDispatchConn{ - readBuffer: bytes.NewBufferString(`{"type":"error","action":"wake.openUrl","request_id":"wake-14"}` + "\n"), + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-14"}` + "\n"), }, nil }, requestIDFn: func() string { return "wake-14" }, @@ -588,7 +601,158 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { RawURL: "neocode://review?path=README.md", }) if err == nil { - t.Fatal("expected missing error payload branch") + t.Fatal("expected missing result payload branch") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response rpc version mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"1.0","id":"wake-15","result":{}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-15" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc version mismatch error") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response rpc id mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-other","result":{}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-16" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc id mismatch error") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("response contains both result and error", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString( + `{"jsonrpc":"2.0","id":"wake-17","result":{},"error":{"code":-32603,"message":"boom"}}` + "\n", + ), + }, nil + }, + requestIDFn: func() string { return "wake-17" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected both result and error payload failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + }) + + t.Run("rpc error without gateway_code uses fallback code map", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString( + `{"jsonrpc":"2.0","id":"wake-18","error":{"code":-32601,"message":"method not found"}}` + "\n", + ), + }, nil + }, + requestIDFn: func() string { return "wake-18" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc error mapping failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != gateway.ErrorCodeUnsupportedAction.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeUnsupportedAction.String()) + } + }) + + t.Run("rpc error with empty message uses fallback text", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-19","error":{"code":-32603,"message":""}}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-19" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected rpc error mapping failure") + } + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != gateway.ErrorCodeInternalError.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeInternalError.String()) + } + if !strings.Contains(dispatchErr.Message, "empty rpc error message") { + t.Fatalf("error message = %q, want fallback text", dispatchErr.Message) + } + }) + + t.Run("decode response frame failed", func(t *testing.T) { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { + return &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"wake-20","result":"not-frame"}` + "\n"), + }, nil + }, + requestIDFn: func() string { return "wake-20" }, + } + + _, err := dispatcher.Dispatch(context.Background(), DispatchRequest{RawURL: "neocode://review?path=README.md"}) + if err == nil { + t.Fatal("expected decode frame failure") } var dispatchErr *DispatchError if !errors.As(err, &dispatchErr) { @@ -600,6 +764,49 @@ func TestDispatcherDispatchAdditionalErrorBranches(t *testing.T) { }) } +func TestDispatcherJSONRPCHelpers(t *testing.T) { + marshalErr := toDispatchErrorFromJSONRPC(&protocol.JSONRPCError{ + Code: protocol.JSONRPCCodeInternalError, + Message: "boom", + }) + var dispatchErr *DispatchError + if !errors.As(marshalErr, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", marshalErr) + } + if dispatchErr.Code != gateway.ErrorCodeInternalError.String() { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeInternalError.String()) + } + + emptyErr := toDispatchErrorFromJSONRPC(nil) + if !errors.As(emptyErr, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", emptyErr) + } + if dispatchErr.Code != ErrorCodeUnexpectedResponse { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) + } + + if mapJSONRPCCodeToDispatchCode(protocol.JSONRPCCodeMethodNotFound) != gateway.ErrorCodeUnsupportedAction.String() { + t.Fatal("method_not_found should map to unsupported_action") + } + if mapJSONRPCCodeToDispatchCode(protocol.JSONRPCCodeInvalidParams) != gateway.ErrorCodeInvalidFrame.String() { + t.Fatal("invalid_params should map to invalid_frame") + } + if mapJSONRPCCodeToDispatchCode(123456) != ErrorCodeInternal { + t.Fatal("unknown rpc code should map to internal_error") + } + + if _, err := decodeResponseFrameResult(map[string]any{"bad": make(chan int)}); err == nil { + t.Fatal("expected decodeResponseFrameResult marshal failure") + } + if _, err := decodeResponseFrameResult("not-frame"); err == nil { + t.Fatal("expected decodeResponseFrameResult unmarshal failure") + } + + if _, err := marshalJSONRawMessage(make(chan int)); err == nil { + t.Fatal("expected marshalJSONRawMessage failure") + } +} + type stubDispatchConn struct { readBuffer *bytes.Buffer writeErr error diff --git a/internal/gateway/handlers/wake.go b/internal/gateway/handlers/wake.go index a8e4df55..82ea0624 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -88,7 +88,7 @@ func isSafeReviewPath(path string) bool { if trimmed == "" { return false } - if filepath.IsAbs(trimmed) { + if filepath.IsAbs(trimmed) || strings.HasPrefix(trimmed, "/") || strings.HasPrefix(trimmed, "\\") { return false } if containsParentTraversalSegment(trimmed) { diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go new file mode 100644 index 00000000..c5c5e5fb --- /dev/null +++ b/internal/gateway/protocol/jsonrpc.go @@ -0,0 +1,274 @@ +package protocol + +import ( + "bytes" + "encoding/json" + "strings" +) + +const ( + // JSONRPCVersion 表示当前网关控制面固定使用的 JSON-RPC 协议版本。 + JSONRPCVersion = "2.0" +) + +const ( + // MethodGatewayPing 表示网关探活方法。 + MethodGatewayPing = "gateway.ping" + // MethodWakeOpenURL 表示 URL Scheme 唤醒方法。 + MethodWakeOpenURL = "wake.openUrl" +) + +const ( + // JSONRPCCodeParseError 表示请求体不是合法 JSON。 + JSONRPCCodeParseError = -32700 + // JSONRPCCodeInvalidRequest 表示请求结构不符合 JSON-RPC 规范。 + JSONRPCCodeInvalidRequest = -32600 + // JSONRPCCodeMethodNotFound 表示方法未注册。 + JSONRPCCodeMethodNotFound = -32601 + // JSONRPCCodeInvalidParams 表示参数不合法。 + JSONRPCCodeInvalidParams = -32602 + // JSONRPCCodeInternalError 表示服务端内部错误。 + JSONRPCCodeInternalError = -32603 +) + +const ( + // GatewayCodeInvalidFrame 表示请求帧结构非法。 + GatewayCodeInvalidFrame = "invalid_frame" + // GatewayCodeInvalidAction 表示动作参数非法。 + GatewayCodeInvalidAction = "invalid_action" + // GatewayCodeInvalidMultimodalPayload 表示多模态负载非法。 + GatewayCodeInvalidMultimodalPayload = "invalid_multimodal_payload" + // GatewayCodeMissingRequiredField 表示缺少必填字段。 + GatewayCodeMissingRequiredField = "missing_required_field" + // GatewayCodeUnsupportedAction 表示动作尚未实现。 + GatewayCodeUnsupportedAction = "unsupported_action" + // GatewayCodeInternalError 表示网关内部错误。 + GatewayCodeInternalError = "internal_error" + // GatewayCodeUnsafePath 表示路径存在安全风险。 + GatewayCodeUnsafePath = "unsafe_path" +) + +// JSONRPCRequest 表示控制面接收到的 JSON-RPC 请求。 +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +// JSONRPCResponse 表示控制面输出的 JSON-RPC 响应。 +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result any `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` +} + +// JSONRPCError 表示 JSON-RPC 错误负载。 +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data *JSONRPCErrorData `json:"data,omitempty"` +} + +// JSONRPCErrorData 表示网关扩展错误字段。 +type JSONRPCErrorData struct { + GatewayCode string `json:"gateway_code,omitempty"` +} + +// NormalizedRequest 表示从 JSON-RPC 归一化后的内部请求模型。 +type NormalizedRequest struct { + ID json.RawMessage + RequestID string + Action string + SessionID string + Workdir string + Payload any +} + +// NormalizeJSONRPCRequest 将 JSON-RPC 请求归一化为内部请求模型,并做方法级参数解析。 +func NormalizeJSONRPCRequest(request JSONRPCRequest) (NormalizedRequest, *JSONRPCError) { + normalized := NormalizedRequest{} + + requestID, idErr := normalizeJSONRPCID(request.ID) + normalized.ID = cloneJSONRawMessage(request.ID) + normalized.RequestID = requestID + if idErr != nil { + return normalized, idErr + } + + if strings.TrimSpace(request.JSONRPC) != JSONRPCVersion { + return normalized, NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid jsonrpc version", + GatewayCodeInvalidFrame, + ) + } + + method := strings.TrimSpace(request.Method) + if method == "" { + return normalized, NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "missing required field: method", + GatewayCodeMissingRequiredField, + ) + } + + switch method { + case MethodGatewayPing: + normalized.Action = "ping" + return normalized, nil + case MethodWakeOpenURL: + intent, parseErr := decodeWakeIntentParams(request.Params) + if parseErr != nil { + return normalized, parseErr + } + normalized.Action = MethodWakeOpenURL + normalized.SessionID = strings.TrimSpace(intent.SessionID) + normalized.Workdir = strings.TrimSpace(intent.Workdir) + normalized.Payload = intent + return normalized, nil + default: + return normalized, NewJSONRPCError( + JSONRPCCodeMethodNotFound, + "method not found", + GatewayCodeUnsupportedAction, + ) + } +} + +// NewJSONRPCResultResponse 创建 JSON-RPC 成功响应。 +func NewJSONRPCResultResponse(id json.RawMessage, result any) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: cloneJSONRawMessage(id), + Result: result, + } +} + +// NewJSONRPCErrorResponse 创建 JSON-RPC 错误响应。 +func NewJSONRPCErrorResponse(id json.RawMessage, rpcError *JSONRPCError) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: cloneJSONRawMessage(id), + Error: rpcError, + } +} + +// NewJSONRPCError 创建带 gateway_code 的 JSON-RPC 错误对象。 +func NewJSONRPCError(code int, message, gatewayCode string) *JSONRPCError { + errorPayload := &JSONRPCError{ + Code: code, + Message: message, + } + if strings.TrimSpace(gatewayCode) != "" { + errorPayload.Data = &JSONRPCErrorData{GatewayCode: gatewayCode} + } + return errorPayload +} + +// GatewayCodeFromJSONRPCError 从 JSON-RPC 错误负载中提取稳定 gateway_code。 +func GatewayCodeFromJSONRPCError(rpcError *JSONRPCError) string { + if rpcError == nil || rpcError.Data == nil { + return "" + } + return strings.TrimSpace(rpcError.Data.GatewayCode) +} + +// MapGatewayCodeToJSONRPCCode 将稳定网关错误码映射到 JSON-RPC 错误码。 +func MapGatewayCodeToJSONRPCCode(gatewayCode string) int { + switch strings.TrimSpace(gatewayCode) { + case GatewayCodeUnsupportedAction: + return JSONRPCCodeMethodNotFound + case GatewayCodeInvalidAction, + GatewayCodeInvalidFrame, + GatewayCodeInvalidMultimodalPayload, + GatewayCodeMissingRequiredField, + GatewayCodeUnsafePath: + return JSONRPCCodeInvalidParams + case GatewayCodeInternalError: + return JSONRPCCodeInternalError + default: + return JSONRPCCodeInternalError + } +} + +// normalizeJSONRPCID 校验并提取请求 ID,确保控制面请求具备可关联标识。 +func normalizeJSONRPCID(id json.RawMessage) (string, *JSONRPCError) { + trimmed := bytes.TrimSpace(id) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "missing required field: id", + GatewayCodeMissingRequiredField, + ) + } + + if trimmed[0] == '"' { + var decoded string + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + decoded = strings.TrimSpace(decoded) + if decoded == "" { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + return decoded, nil + } + + identifier := strings.TrimSpace(string(trimmed)) + if identifier == "" { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + return identifier, nil +} + +// decodeWakeIntentParams 对 wake.openUrl 的 params 执行延迟反序列化与最小校验。 +func decodeWakeIntentParams(raw json.RawMessage) (WakeIntent, *JSONRPCError) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return WakeIntent{}, NewJSONRPCError( + JSONRPCCodeInvalidParams, + "missing required field: params", + GatewayCodeMissingRequiredField, + ) + } + + var intent WakeIntent + if err := json.Unmarshal(trimmed, &intent); err != nil { + return WakeIntent{}, NewJSONRPCError( + JSONRPCCodeInvalidParams, + "invalid params for wake.openUrl", + GatewayCodeInvalidFrame, + ) + } + intent.Action = strings.ToLower(strings.TrimSpace(intent.Action)) + intent.SessionID = strings.TrimSpace(intent.SessionID) + intent.Workdir = strings.TrimSpace(intent.Workdir) + if len(intent.Params) == 0 { + intent.Params = nil + } + return intent, nil +} + +// cloneJSONRawMessage 复制 RawMessage,避免共享底层切片导致的并发风险。 +func cloneJSONRawMessage(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 { + return nil + } + cloned := make([]byte, len(raw)) + copy(cloned, raw) + return json.RawMessage(cloned) +} diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go new file mode 100644 index 00000000..6fa58d31 --- /dev/null +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -0,0 +1,174 @@ +package protocol + +import ( + "encoding/json" + "testing" +) + +func TestNormalizeJSONRPCRequestPing(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"ping-1"`), + Method: MethodGatewayPing, + Params: json.RawMessage(`{}`), + }) + if rpcErr != nil { + t.Fatalf("normalize ping request: %v", rpcErr) + } + if normalized.RequestID != "ping-1" { + t.Fatalf("request_id = %q, want %q", normalized.RequestID, "ping-1") + } + if normalized.Action != "ping" { + t.Fatalf("action = %q, want %q", normalized.Action, "ping") + } +} + +func TestNormalizeJSONRPCRequestWakeOpenURL(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"wake-1"`), + Method: MethodWakeOpenURL, + Params: json.RawMessage(`{ + "action":"review", + "session_id":"session-1", + "workdir":"/tmp/repo", + "params":{"path":"README.md"} + }`), + }) + if rpcErr != nil { + t.Fatalf("normalize wake request: %v", rpcErr) + } + if normalized.Action != MethodWakeOpenURL { + t.Fatalf("action = %q, want %q", normalized.Action, MethodWakeOpenURL) + } + if normalized.SessionID != "session-1" { + t.Fatalf("session_id = %q, want %q", normalized.SessionID, "session-1") + } + if normalized.Workdir != "/tmp/repo" { + t.Fatalf("workdir = %q, want %q", normalized.Workdir, "/tmp/repo") + } + intent, ok := normalized.Payload.(WakeIntent) + if !ok { + t.Fatalf("payload type = %T, want WakeIntent", normalized.Payload) + } + if intent.Params["path"] != "README.md" { + t.Fatalf("intent.params[path] = %q, want %q", intent.Params["path"], "README.md") + } +} + +func TestNormalizeJSONRPCRequestErrors(t *testing.T) { + testCases := []struct { + name string + request JSONRPCRequest + wantCode int + wantGatewayCode string + }{ + { + name: "missing id", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "invalid version", + request: JSONRPCRequest{ + JSONRPC: "1.0", + ID: json.RawMessage(`"x"`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "missing method", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "method not found", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: "gateway.unknown", + }, + wantCode: JSONRPCCodeMethodNotFound, + wantGatewayCode: GatewayCodeUnsupportedAction, + }, + { + name: "wake missing params", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: MethodWakeOpenURL, + }, + wantCode: JSONRPCCodeInvalidParams, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "wake invalid params", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"x"`), + Method: MethodWakeOpenURL, + Params: json.RawMessage(`{invalid}`), + }, + wantCode: JSONRPCCodeInvalidParams, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + _, rpcErr := NormalizeJSONRPCRequest(tc.request) + if rpcErr == nil { + t.Fatal("expected rpc error") + } + if rpcErr.Code != tc.wantCode { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, tc.wantCode) + } + if gatewayCode := GatewayCodeFromJSONRPCError(rpcErr); gatewayCode != tc.wantGatewayCode { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, tc.wantGatewayCode) + } + }) + } +} + +func TestJSONRPCHelpers(t *testing.T) { + response := NewJSONRPCResultResponse(json.RawMessage(`"req-1"`), map[string]string{"message": "ok"}) + if response.JSONRPC != JSONRPCVersion { + t.Fatalf("jsonrpc = %q, want %q", response.JSONRPC, JSONRPCVersion) + } + if string(response.ID) != `"req-1"` { + t.Fatalf("id = %s, want %s", response.ID, `"req-1"`) + } + + rpcErr := NewJSONRPCError(JSONRPCCodeInternalError, "boom", GatewayCodeInternalError) + errorResponse := NewJSONRPCErrorResponse(json.RawMessage(`"req-2"`), rpcErr) + if errorResponse.Error == nil { + t.Fatal("error response should include rpc error payload") + } + if GatewayCodeFromJSONRPCError(errorResponse.Error) != GatewayCodeInternalError { + t.Fatalf("gateway_code = %q, want %q", GatewayCodeFromJSONRPCError(errorResponse.Error), GatewayCodeInternalError) + } + if GatewayCodeFromJSONRPCError(nil) != "" { + t.Fatal("gateway_code for nil rpc error should be empty") + } + + if MapGatewayCodeToJSONRPCCode(GatewayCodeUnsupportedAction) != JSONRPCCodeMethodNotFound { + t.Fatal("unsupported_action should map to method_not_found") + } + if MapGatewayCodeToJSONRPCCode(GatewayCodeInvalidAction) != JSONRPCCodeInvalidParams { + t.Fatal("invalid_action should map to invalid_params") + } + if MapGatewayCodeToJSONRPCCode("unknown") != JSONRPCCodeInternalError { + t.Fatal("unknown code should map to internal_error") + } +} diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 718dc430..8871873e 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "neo-code/internal/gateway/protocol" "neo-code/internal/gateway/transport" ) @@ -261,7 +262,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor return } - frame, err := decodeFrame(reader) + rpcRequest, err := decodeRPCRequest(reader) if err != nil { if errors.Is(err, io.EOF) { return @@ -275,20 +276,31 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor } if errors.Is(err, errFrameTooLarge) { s.logger.Printf("decode frame failed: %v", err) - _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError( - ErrorCodeInvalidFrame, - fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), - ))) + _ = s.writeRPCResponse(conn, encoder, protocol.NewJSONRPCErrorResponse( + nil, + protocol.NewJSONRPCError( + protocol.JSONRPCCodeInvalidParams, + fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), + protocol.GatewayCodeInvalidFrame, + ), + )) return } s.logger.Printf("decode frame failed: %v", err) - _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) + _ = s.writeRPCResponse(conn, encoder, protocol.NewJSONRPCErrorResponse( + nil, + protocol.NewJSONRPCError( + protocol.JSONRPCCodeParseError, + "invalid json-rpc request", + protocol.GatewayCodeInvalidFrame, + ), + )) return } - response := s.dispatchFrame(ctx, frame, runtimePort) - if !s.writeFrame(conn, encoder, response) { + rpcResponse := s.dispatchRPCRequest(ctx, rpcRequest, runtimePort) + if !s.writeRPCResponse(conn, encoder, rpcResponse) { return } } @@ -310,13 +322,13 @@ func (s *Server) applyWriteDeadline(conn net.Conn) error { return conn.SetWriteDeadline(time.Now().Add(s.writeTimeout)) } -// writeFrame 统一处理响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 -func (s *Server) writeFrame(conn net.Conn, encoder *json.Encoder, frame MessageFrame) bool { +// writeRPCResponse 统一处理 JSON-RPC 响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 +func (s *Server) writeRPCResponse(conn net.Conn, encoder *json.Encoder, response protocol.JSONRPCResponse) bool { if err := s.applyWriteDeadline(conn); err != nil { s.logger.Printf("set write deadline failed: %v", err) return false } - if err := encoder.Encode(frame); err != nil { + if err := encoder.Encode(response); err != nil { s.logger.Printf("write frame failed: %v", err) return false } @@ -329,27 +341,66 @@ func isTimeoutError(err error) bool { return errors.As(err, &netErr) && netErr.Timeout() } -// decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。 -func decodeFrame(reader *bufio.Reader) (MessageFrame, error) { +// decodeRPCRequest 从连接读取一条 JSON-RPC 请求并执行长度与格式校验。 +func decodeRPCRequest(reader *bufio.Reader) (protocol.JSONRPCRequest, error) { payload, err := readFramePayload(reader, MaxFrameSize) if err != nil { - return MessageFrame{}, err + return protocol.JSONRPCRequest{}, err } limitedReader := &io.LimitedReader{R: bytes.NewReader(payload), N: MaxFrameSize} decoder := json.NewDecoder(limitedReader) - var frame MessageFrame - if err := decoder.Decode(&frame); err != nil { - return MessageFrame{}, err + var request protocol.JSONRPCRequest + if err := decoder.Decode(&request); err != nil { + return protocol.JSONRPCRequest{}, err } var trailing any if err := decoder.Decode(&trailing); !errors.Is(err, io.EOF) { - return MessageFrame{}, fmt.Errorf("frame contains trailing json values") + return protocol.JSONRPCRequest{}, fmt.Errorf("frame contains trailing json values") } - return frame, nil + return request, nil +} + +// dispatchRPCRequest 将 JSON-RPC 请求归一化为 MessageFrame,并复用既有分发逻辑处理。 +func (s *Server) dispatchRPCRequest( + ctx context.Context, + request protocol.JSONRPCRequest, + runtimePort RuntimePort, +) protocol.JSONRPCResponse { + normalized, rpcErr := protocol.NormalizeJSONRPCRequest(request) + if rpcErr != nil { + return protocol.NewJSONRPCErrorResponse(normalized.ID, rpcErr) + } + + frame := MessageFrame{ + Type: FrameTypeRequest, + Action: FrameAction(normalized.Action), + RequestID: normalized.RequestID, + SessionID: normalized.SessionID, + Workdir: normalized.Workdir, + Payload: normalized.Payload, + } + + responseFrame := s.dispatchFrame(ctx, frame, runtimePort) + if responseFrame.Type != FrameTypeError { + return protocol.NewJSONRPCResultResponse(normalized.ID, responseFrame) + } + + frameErr := responseFrame.Error + if frameErr == nil { + frameErr = NewFrameError(ErrorCodeInternalError, "gateway response missing error payload") + } + return protocol.NewJSONRPCErrorResponse( + normalized.ID, + protocol.NewJSONRPCError( + protocol.MapGatewayCodeToJSONRPCCode(frameErr.Code), + frameErr.Message, + frameErr.Code, + ), + ) } // readFramePayload 按换行边界读取单条帧,并限制单帧最大字节数。 diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index d5eea3d5..080b58e3 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -12,6 +12,8 @@ import ( "sync" "testing" "time" + + "neo-code/internal/gateway/protocol" ) func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { @@ -236,9 +238,9 @@ func TestCloseReturnsContextErrorWhenWaitCanceled(t *testing.T) { server.wg.Done() } -func TestDecodeFrameTrailingJSON(t *testing.T) { - reader := bufio.NewReader(strings.NewReader(`{"type":"request","action":"ping"} {"extra":1}` + "\n")) - _, err := decodeFrame(reader) +func TestDecodeRPCRequestTrailingJSON(t *testing.T) { + reader := bufio.NewReader(strings.NewReader(`{"jsonrpc":"2.0","id":"x","method":"gateway.ping"} {"extra":1}` + "\n")) + _, err := decodeRPCRequest(reader) if err == nil || !strings.Contains(err.Error(), "trailing") { t.Fatalf("expected trailing json error, got %v", err) } @@ -289,6 +291,56 @@ func TestDispatchFrameValidationError(t *testing.T) { } } +func TestDispatchRPCRequestNormalizeError(t *testing.T) { + server := &Server{} + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc normalize error") + } + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeMissingRequiredField.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeMissingRequiredField.String()) + } +} + +func TestDispatchRPCRequestConvertsFrameErrorWithoutPayload(t *testing.T) { + server := &Server{} + originalHandlers := requestFrameHandlers + requestFrameHandlers = map[FrameAction]requestFrameHandler{ + FrameActionPing: func(frame MessageFrame) MessageFrame { + return MessageFrame{ + Type: FrameTypeError, + Action: frame.Action, + RequestID: frame.RequestID, + Error: nil, + } + }, + } + t.Cleanup(func() { + requestFrameHandlers = originalHandlers + }) + + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"rpc-err-1"`), + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc error response") + } + if response.Error.Code != protocol.JSONRPCCodeInternalError { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInternalError) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInternalError.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInternalError.String()) + } +} + func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { server := &Server{logger: log.New(io.Discard, "", 0)} serverConn, clientConn := net.Pipe() @@ -299,15 +351,22 @@ func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { }() _, _ = io.WriteString(clientConn, "\n") - _, _ = io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"empty-then-ping"}`+"\n") + _, _ = io.WriteString(clientConn, `{"jsonrpc":"2.0","id":"empty-then-ping","method":"gateway.ping","params":{}}`+"\n") decoder := json.NewDecoder(clientConn) - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - if response.Type != FrameTypeAck || response.Action != FrameActionPing { - t.Fatalf("unexpected response after empty frame: %#v", response) + if response.Error != nil { + t.Fatalf("unexpected rpc error after empty frame: %+v", response.Error) + } + resultFrame, err := decodeJSONRPCResultFrame(response) + if err != nil { + t.Fatalf("decode result frame: %v", err) + } + if resultFrame.Type != FrameTypeAck || resultFrame.Action != FrameActionPing { + t.Fatalf("unexpected response after empty frame: %#v", resultFrame) } _ = clientConn.Close() @@ -329,15 +388,18 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { _, _ = io.WriteString(clientConn, "{invalid-json}\n") decoder := json.NewDecoder(clientConn) - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - if response.Type != FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + if response.Error == nil { + t.Fatal("response rpc error is nil") } - if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { - t.Fatalf("response error = %#v, want invalid frame", response.Error) + if response.Error.Code != protocol.JSONRPCCodeParseError { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeParseError) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) } _ = clientConn.Close() @@ -412,7 +474,7 @@ func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) { server.handleConnection(context.Background(), serverConn, nil) }() - _, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n") + _, err := io.WriteString(clientConn, `{"jsonrpc":"2.0","id":"write-timeout","method":"gateway.ping","params":{}}`+"\n") if err != nil { t.Fatalf("write request: %v", err) } diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index aa7cae26..2fedea0d 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -11,6 +11,8 @@ import ( "strings" "testing" "time" + + "neo-code/internal/gateway/protocol" ) func TestServerHandleConnectionPing(t *testing.T) { @@ -28,32 +30,40 @@ func TestServerHandleConnectionPing(t *testing.T) { encoder := json.NewEncoder(clientConn) decoder := json.NewDecoder(clientConn) - if err := encoder.Encode(MessageFrame{ - Type: FrameTypeRequest, - Action: FrameActionPing, - RequestID: "req-1", + if err := encoder.Encode(protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-1"`), + Method: protocol.MethodGatewayPing, + Params: json.RawMessage(`{}`), }); err != nil { t.Fatalf("encode request: %v", err) } - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } + if response.Error != nil { + t.Fatalf("unexpected rpc error: %+v", response.Error) + } + resultFrame, err := decodeJSONRPCResultFrame(response) + if err != nil { + t.Fatalf("decode result frame: %v", err) + } - if response.Type != FrameTypeAck { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + if resultFrame.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", resultFrame.Type, FrameTypeAck) } - if response.Action != FrameActionPing { - t.Fatalf("response action = %q, want %q", response.Action, FrameActionPing) + if resultFrame.Action != FrameActionPing { + t.Fatalf("response action = %q, want %q", resultFrame.Action, FrameActionPing) } - if response.RequestID != "req-1" { - t.Fatalf("response request_id = %q, want %q", response.RequestID, "req-1") + if resultFrame.RequestID != "req-1" { + t.Fatalf("response request_id = %q, want %q", resultFrame.RequestID, "req-1") } - payloadMap, ok := response.Payload.(map[string]any) + payloadMap, ok := resultFrame.Payload.(map[string]any) if !ok { - t.Fatalf("response payload type = %T, want map[string]any", response.Payload) + t.Fatalf("response payload type = %T, want map[string]any", resultFrame.Payload) } if got, _ := payloadMap["message"].(string); got != "pong" { t.Fatalf("response payload message = %q, want %q", got, "pong") @@ -82,28 +92,27 @@ func TestServerHandleConnectionUnsupportedAction(t *testing.T) { encoder := json.NewEncoder(clientConn) decoder := json.NewDecoder(clientConn) - if err := encoder.Encode(MessageFrame{ - Type: FrameTypeRequest, - Action: FrameActionRun, - RequestID: "req-2", - InputText: "hello", + if err := encoder.Encode(protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-2"`), + Method: "gateway.run", + Params: json.RawMessage(`{"input_text":"hello"}`), }); err != nil { t.Fatalf("encode request: %v", err) } - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode response: %v", err) } - - if response.Type != FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) - } if response.Error == nil { - t.Fatal("response error is nil") + t.Fatal("response rpc error is nil") + } + if response.Error.Code != protocol.JSONRPCCodeMethodNotFound { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeMethodNotFound) } - if response.Error.Code != ErrorCodeUnsupportedAction.String() { - t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeUnsupportedAction.String()) + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeUnsupportedAction.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeUnsupportedAction.String()) } _ = clientConn.Close() @@ -129,7 +138,7 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { decoder := json.NewDecoder(clientConn) oversizedPayload := strings.Repeat("a", int(MaxFrameSize)+128) requestFrame := fmt.Sprintf( - `{"type":"request","action":"ping","request_id":"req-oversize","input_text":"%s"}`+"\n", + `{"jsonrpc":"2.0","id":"req-oversize","method":"gateway.ping","params":{"input_text":"%s"}}`+"\n", oversizedPayload, ) @@ -139,18 +148,18 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { writeDone <- err }() - var response MessageFrame + var response protocol.JSONRPCResponse if err := decoder.Decode(&response); err != nil { t.Fatalf("decode oversized response: %v", err) } - if response.Type != FrameTypeError { - t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) - } if response.Error == nil { - t.Fatal("response error is nil") + t.Fatal("response rpc error is nil") } - if response.Error.Code != ErrorCodeInvalidFrame.String() { - t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeInvalidFrame.String()) + if response.Error.Code != protocol.JSONRPCCodeInvalidParams { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidParams) + } + if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) } if !strings.Contains(response.Error.Message, "frame exceeds max size") { t.Fatalf("error message = %q, want contains %q", response.Error.Message, "frame exceeds max size") @@ -189,3 +198,18 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { t.Fatal("handleConnection did not exit") } } + +func decodeJSONRPCResultFrame(response protocol.JSONRPCResponse) (MessageFrame, error) { + if response.Result == nil { + return MessageFrame{}, errors.New("rpc result is nil") + } + raw, err := json.Marshal(response.Result) + if err != nil { + return MessageFrame{}, err + } + var frame MessageFrame + if err := json.Unmarshal(raw, &frame); err != nil { + return MessageFrame{}, err + } + return frame, nil +} From 9ee41ed4f3fd8d1fac2830a9f8cb641c931eb051 Mon Sep 17 00:00:00 2001 From: pionxe Date: Wed, 15 Apr 2026 12:02:52 +0800 Subject: [PATCH 15/33] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=20JSON-RPC=20ID=20?= =?UTF-8?q?=E6=A0=A1=E9=AA=8C=E4=B8=8D=E4=B8=A5=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E3=80=81Windows=20=E8=B7=AF=E5=BE=84=E6=B3=A8=E5=85=A5?= =?UTF-8?q?=E9=98=B2=E5=BE=A1=E5=A2=9E=E5=BC=BA=E5=92=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20Result=20=E8=A7=A3=E7=A0=81=E7=9A=84=E6=80=A7=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E7=A7=BB=E9=99=A4=E5=A4=9A=E4=BD=99=E7=9A=84=E5=86=85?= =?UTF-8?q?=E5=AD=98=E5=88=86=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gateway/adapters/urlscheme/dispatcher.go | 8 +-- .../adapters/urlscheme/dispatcher_test.go | 30 +++++--- internal/gateway/handlers/wake.go | 12 ++++ internal/gateway/handlers/wake_test.go | 6 ++ internal/gateway/protocol/jsonrpc.go | 50 ++++++++----- internal/gateway/protocol/jsonrpc_test.go | 70 ++++++++++++++++++- internal/gateway/server.go | 6 +- internal/gateway/server_test.go | 6 +- 8 files changed, 148 insertions(+), 40 deletions(-) diff --git a/internal/gateway/adapters/urlscheme/dispatcher.go b/internal/gateway/adapters/urlscheme/dispatcher.go index 74ab7ccd..32fa9d2f 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher.go +++ b/internal/gateway/adapters/urlscheme/dispatcher.go @@ -276,13 +276,9 @@ func mapJSONRPCCodeToDispatchCode(code int) string { } // decodeResponseFrameResult 将 JSON-RPC result 安全解码回 MessageFrame。 -func decodeResponseFrameResult(result any) (gateway.MessageFrame, error) { - raw, err := json.Marshal(result) - if err != nil { - return gateway.MessageFrame{}, err - } +func decodeResponseFrameResult(result json.RawMessage) (gateway.MessageFrame, error) { var frame gateway.MessageFrame - if err := json.Unmarshal(raw, &frame); err != nil { + if err := json.Unmarshal(result, &frame); err != nil { return gateway.MessageFrame{}, err } return frame, nil diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index e52569c6..1fd9d3d0 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -54,14 +54,14 @@ func TestDispatcherDispatchSuccess(t *testing.T) { if err := encoder.Encode(protocol.JSONRPCResponse{ JSONRPC: protocol.JSONRPCVersion, ID: rpcRequest.ID, - Result: gateway.MessageFrame{ + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ Type: gateway.FrameTypeAck, Action: gateway.FrameActionWakeOpenURL, RequestID: "wake-1", Payload: map[string]any{ "message": "wake intent accepted", }, - }, + }), }); err != nil { t.Errorf("encode response rpc: %v", err) } @@ -149,11 +149,11 @@ func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { _ = encoder.Encode(protocol.JSONRPCResponse{ JSONRPC: protocol.JSONRPCVersion, ID: rpcRequest.ID, - Result: gateway.MessageFrame{ + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ Type: gateway.FrameTypeEvent, Action: gateway.FrameActionWakeOpenURL, RequestID: "wake-3", - }, + }), }) }() @@ -193,11 +193,11 @@ func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { _ = encoder.Encode(protocol.JSONRPCResponse{ JSONRPC: protocol.JSONRPCVersion, ID: rpcRequest.ID, - Result: gateway.MessageFrame{ + Result: mustMarshalRawJSON(t, gateway.MessageFrame{ Type: gateway.FrameTypeAck, Action: gateway.FrameActionWakeOpenURL, RequestID: "wake-mismatch", - }, + }), }) }() @@ -795,12 +795,12 @@ func TestDispatcherJSONRPCHelpers(t *testing.T) { t.Fatal("unknown rpc code should map to internal_error") } - if _, err := decodeResponseFrameResult(map[string]any{"bad": make(chan int)}); err == nil { - t.Fatal("expected decodeResponseFrameResult marshal failure") - } - if _, err := decodeResponseFrameResult("not-frame"); err == nil { + if _, err := decodeResponseFrameResult(json.RawMessage(`"not-frame"`)); err == nil { t.Fatal("expected decodeResponseFrameResult unmarshal failure") } + if _, err := decodeResponseFrameResult(json.RawMessage(`{"type":"ack","action":"wake.openUrl","request_id":"x"`)); err == nil { + t.Fatal("expected decodeResponseFrameResult malformed json failure") + } if _, err := marshalJSONRawMessage(make(chan int)); err == nil { t.Fatal("expected marshalJSONRawMessage failure") @@ -868,3 +868,13 @@ func (a stubDispatchAddr) Network() string { func (a stubDispatchAddr) String() string { return string(a) } + +func mustMarshalRawJSON(t *testing.T, payload any) json.RawMessage { + t.Helper() + + raw, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal raw json: %v", err) + } + return json.RawMessage(raw) +} diff --git a/internal/gateway/handlers/wake.go b/internal/gateway/handlers/wake.go index 82ea0624..96ce6246 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -88,6 +88,12 @@ func isSafeReviewPath(path string) bool { if trimmed == "" { return false } + if filepath.VolumeName(trimmed) != "" { + return false + } + if hasBlockedWindowsPathPrefix(trimmed) { + return false + } if filepath.IsAbs(trimmed) || strings.HasPrefix(trimmed, "/") || strings.HasPrefix(trimmed, "\\") { return false } @@ -104,6 +110,12 @@ func isSafeReviewPath(path string) bool { return true } +// hasBlockedWindowsPathPrefix 检查是否命中 Windows 底层设备路径前缀,避免绕过常规路径校验。 +func hasBlockedWindowsPathPrefix(path string) bool { + normalized := strings.ReplaceAll(strings.TrimSpace(path), "/", "\\") + return strings.HasPrefix(normalized, `\\?\`) || strings.HasPrefix(normalized, `\\.\`) +} + // containsParentTraversalSegment 按路径段语义识别目录回退段,避免子串匹配导致误伤。 func containsParentTraversalSegment(path string) bool { normalized := normalizePath(path) diff --git a/internal/gateway/handlers/wake_test.go b/internal/gateway/handlers/wake_test.go index f5184a26..3941bf6b 100644 --- a/internal/gateway/handlers/wake_test.go +++ b/internal/gateway/handlers/wake_test.go @@ -59,6 +59,9 @@ func TestWakeOpenURLHandlerHandleUnsafePath(t *testing.T) { "../../etc/passwd", "/etc/passwd", "..\\Windows\\system32", + "C:foo", + `\\?\C:\Windows\System32`, + `\\.\pipe\neocode`, } handler := NewWakeOpenURLHandler() @@ -90,6 +93,9 @@ func TestIsSafeReviewPath(t *testing.T) { {name: "parent traversal", path: "../secret.txt", want: false}, {name: "parent traversal nested", path: "a/../../secret.txt", want: false}, {name: "absolute unix path", path: "/tmp/file", want: false}, + {name: "windows drive relative path", path: "C:foo", want: false}, + {name: "windows device path namespace", path: `\\?\C:\tmp\file`, want: false}, + {name: "windows device pipe namespace", path: `\\.\pipe\name`, want: false}, {name: "empty", path: "", want: false}, {name: "dot current dir", path: ".", want: false}, } diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index c5c5e5fb..66773148 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -60,7 +60,7 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` - Result any `json:"result,omitempty"` + Result json.RawMessage `json:"result,omitempty"` Error *JSONRPCError `json:"error,omitempty"` } @@ -137,13 +137,22 @@ func NormalizeJSONRPCRequest(request JSONRPCRequest) (NormalizedRequest, *JSONRP } } -// NewJSONRPCResultResponse 创建 JSON-RPC 成功响应。 -func NewJSONRPCResultResponse(id json.RawMessage, result any) JSONRPCResponse { +// NewJSONRPCResultResponse 创建 JSON-RPC 成功响应,并将 result 编码为 RawMessage。 +func NewJSONRPCResultResponse(id json.RawMessage, result any) (JSONRPCResponse, *JSONRPCError) { + rawResult, err := json.Marshal(result) + if err != nil { + return JSONRPCResponse{}, NewJSONRPCError( + JSONRPCCodeInternalError, + "failed to encode jsonrpc result", + GatewayCodeInternalError, + ) + } + return JSONRPCResponse{ JSONRPC: JSONRPCVersion, ID: cloneJSONRawMessage(id), - Result: result, - } + Result: json.RawMessage(rawResult), + }, nil } // NewJSONRPCErrorResponse 创建 JSON-RPC 错误响应。 @@ -204,35 +213,44 @@ func normalizeJSONRPCID(id json.RawMessage) (string, *JSONRPCError) { ) } - if trimmed[0] == '"' { - var decoded string - if err := json.Unmarshal(trimmed, &decoded); err != nil { + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return "", NewJSONRPCError( + JSONRPCCodeInvalidRequest, + "invalid field: id", + GatewayCodeInvalidFrame, + ) + } + + switch typedID := decoded.(type) { + case string: + typedID = strings.TrimSpace(typedID) + if typedID == "" { return "", NewJSONRPCError( JSONRPCCodeInvalidRequest, "invalid field: id", GatewayCodeInvalidFrame, ) } - decoded = strings.TrimSpace(decoded) - if decoded == "" { + return typedID, nil + case float64: + identifier := strings.TrimSpace(string(trimmed)) + if identifier == "" { return "", NewJSONRPCError( JSONRPCCodeInvalidRequest, "invalid field: id", GatewayCodeInvalidFrame, ) } - return decoded, nil - } - - identifier := strings.TrimSpace(string(trimmed)) - if identifier == "" { + _ = typedID + return identifier, nil + default: return "", NewJSONRPCError( JSONRPCCodeInvalidRequest, "invalid field: id", GatewayCodeInvalidFrame, ) } - return identifier, nil } // decodeWakeIntentParams 对 wake.openUrl 的 params 执行延迟反序列化与最小校验。 diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 6fa58d31..2d9b37e0 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -23,6 +23,21 @@ func TestNormalizeJSONRPCRequestPing(t *testing.T) { } } +func TestNormalizeJSONRPCRequestPingWithNumericID(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`123`), + Method: MethodGatewayPing, + Params: json.RawMessage(`{}`), + }) + if rpcErr != nil { + t.Fatalf("normalize ping request with numeric id: %v", rpcErr) + } + if normalized.RequestID != "123" { + t.Fatalf("request_id = %q, want %q", normalized.RequestID, "123") + } +} + func TestNormalizeJSONRPCRequestWakeOpenURL(t *testing.T) { normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ JSONRPC: JSONRPCVersion, @@ -82,6 +97,36 @@ func TestNormalizeJSONRPCRequestErrors(t *testing.T) { wantCode: JSONRPCCodeInvalidRequest, wantGatewayCode: GatewayCodeInvalidFrame, }, + { + name: "invalid id object", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`{}`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "invalid id array", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`[]`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, + { + name: "invalid id boolean", + request: JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`true`), + Method: MethodGatewayPing, + }, + wantCode: JSONRPCCodeInvalidRequest, + wantGatewayCode: GatewayCodeInvalidFrame, + }, { name: "missing method", request: JSONRPCRequest{ @@ -142,15 +187,36 @@ func TestNormalizeJSONRPCRequestErrors(t *testing.T) { } func TestJSONRPCHelpers(t *testing.T) { - response := NewJSONRPCResultResponse(json.RawMessage(`"req-1"`), map[string]string{"message": "ok"}) + response, rpcErr := NewJSONRPCResultResponse(json.RawMessage(`"req-1"`), map[string]string{"message": "ok"}) + if rpcErr != nil { + t.Fatalf("new jsonrpc result response: %v", rpcErr) + } if response.JSONRPC != JSONRPCVersion { t.Fatalf("jsonrpc = %q, want %q", response.JSONRPC, JSONRPCVersion) } if string(response.ID) != `"req-1"` { t.Fatalf("id = %s, want %s", response.ID, `"req-1"`) } + var result map[string]string + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("decode result raw message: %v", err) + } + if result["message"] != "ok" { + t.Fatalf(`result["message"] = %q, want %q`, result["message"], "ok") + } + + _, rpcErr = NewJSONRPCResultResponse(json.RawMessage(`"req-chan"`), map[string]any{"bad": make(chan int)}) + if rpcErr == nil { + t.Fatal("expected result encode error") + } + if rpcErr.Code != JSONRPCCodeInternalError { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, JSONRPCCodeInternalError) + } + if gatewayCode := GatewayCodeFromJSONRPCError(rpcErr); gatewayCode != GatewayCodeInternalError { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, GatewayCodeInternalError) + } - rpcErr := NewJSONRPCError(JSONRPCCodeInternalError, "boom", GatewayCodeInternalError) + rpcErr = NewJSONRPCError(JSONRPCCodeInternalError, "boom", GatewayCodeInternalError) errorResponse := NewJSONRPCErrorResponse(json.RawMessage(`"req-2"`), rpcErr) if errorResponse.Error == nil { t.Fatal("error response should include rpc error payload") diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 8871873e..ce3c902a 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -386,7 +386,11 @@ func (s *Server) dispatchRPCRequest( responseFrame := s.dispatchFrame(ctx, frame, runtimePort) if responseFrame.Type != FrameTypeError { - return protocol.NewJSONRPCResultResponse(normalized.ID, responseFrame) + rpcResponse, encodeErr := protocol.NewJSONRPCResultResponse(normalized.ID, responseFrame) + if encodeErr != nil { + return protocol.NewJSONRPCErrorResponse(normalized.ID, encodeErr) + } + return rpcResponse } frameErr := responseFrame.Error diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index 2fedea0d..a65817d2 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -203,12 +203,8 @@ func decodeJSONRPCResultFrame(response protocol.JSONRPCResponse) (MessageFrame, if response.Result == nil { return MessageFrame{}, errors.New("rpc result is nil") } - raw, err := json.Marshal(response.Result) - if err != nil { - return MessageFrame{}, err - } var frame MessageFrame - if err := json.Unmarshal(raw, &frame); err != nil { + if err := json.Unmarshal(response.Result, &frame); err != nil { return MessageFrame{}, err } return frame, nil From 97da88f6209242d9a3b4643e14ca867ea74254de Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Wed, 15 Apr 2026 12:48:11 +0800 Subject: [PATCH 16/33] =?UTF-8?q?feat(runtime):=E5=A2=9E=E5=8A=A0=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E8=B0=83=E7=94=A8=E9=87=8D=E5=A4=8D=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E7=86=94=E6=96=AD=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/config/config_test.go | 3 +- internal/config/runtime.go | 15 ++++- internal/runtime/controlplane/progress.go | 27 +++++++- .../runtime/controlplane/progress_test.go | 26 ++++++- internal/runtime/run.go | 67 ++++++++++++++++++- internal/runtime/run_lifecycle.go | 3 + internal/runtime/runtime_test.go | 2 +- 7 files changed, 130 insertions(+), 13 deletions(-) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ffd0067c..03ec8197 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1593,7 +1593,8 @@ func TestValidateSnapshotPropagatesCompactError(t *testing.T) { }, }, Runtime: RuntimeConfig{ - MaxNoProgressStreak: 3, + MaxNoProgressStreak: 3, + MaxRepeatCycleStreak: 3, }, Context: ContextConfig{ Compact: CompactConfig{ diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 86bd20ce..a71b6de6 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -5,18 +5,21 @@ import ( ) const ( - DefaultMaxNoProgressStreak = 3 + DefaultMaxNoProgressStreak = 3 + DefaultMaxRepeatCycleStreak = 3 ) // RuntimeConfig 定义 runtime 层的可调参数。 type RuntimeConfig struct { - MaxNoProgressStreak int `yaml:"max_no_progress_streak,omitempty"` + MaxNoProgressStreak int `yaml:"max_no_progress_streak,omitempty"` + MaxRepeatCycleStreak int `yaml:"max_repeat_cycle_streak,omitempty"` } // defaultRuntimeConfig 返回 runtime 配置的静态默认值。 func defaultRuntimeConfig() RuntimeConfig { return RuntimeConfig{ - MaxNoProgressStreak: DefaultMaxNoProgressStreak, + MaxNoProgressStreak: DefaultMaxNoProgressStreak, + MaxRepeatCycleStreak: DefaultMaxRepeatCycleStreak, } } @@ -33,6 +36,9 @@ func (c *RuntimeConfig) ApplyDefaults(defaults RuntimeConfig) { if c.MaxNoProgressStreak <= 0 { c.MaxNoProgressStreak = defaults.MaxNoProgressStreak } + if c.MaxRepeatCycleStreak <= 0 { + c.MaxRepeatCycleStreak = defaults.MaxRepeatCycleStreak + } } // Validate 校验 runtime 配置是否满足最小约束。 @@ -40,5 +46,8 @@ func (c RuntimeConfig) Validate() error { if c.MaxNoProgressStreak <= 0 { return errors.New("max_no_progress_streak must be greater than 0") } + if c.MaxRepeatCycleStreak <= 0 { + return errors.New("max_repeat_cycle_streak must be greater than 0") + } return nil } diff --git a/internal/runtime/controlplane/progress.go b/internal/runtime/controlplane/progress.go index a1d26875..f6d35164 100644 --- a/internal/runtime/controlplane/progress.go +++ b/internal/runtime/controlplane/progress.go @@ -23,17 +23,38 @@ type ProgressScore struct { // ProgressState 汇总当前运行期 progress 控制面状态。 type ProgressState struct { - LastScore ProgressScore `json:"last_score"` + LastScore ProgressScore `json:"last_score"` + LastSignature string `json:"last_signature,omitempty"` } // ApplyProgressEvidence 根据证据更新分值与 streak。 -func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord) ProgressState { +func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord, currentSignature string) ProgressState { next := state.LastScore + isRepeated := false + + if len(records) > 0 { + if currentSignature != "" && currentSignature == state.LastSignature { + isRepeated = true + } + } + + nextSignature := currentSignature + if len(records) == 0 { next.NoProgressStreak++ + next.RepeatCycleStreak = 0 + nextSignature = "" // Clear signature on failure to only count consecutive successes + } else if isRepeated { + next.NoProgressStreak++ + next.RepeatCycleStreak++ } else { next.NoProgressStreak = 0 + next.RepeatCycleStreak = 0 next.ScoreDelta++ } - return ProgressState{LastScore: next} + + return ProgressState{ + LastScore: next, + LastSignature: nextSignature, + } } diff --git a/internal/runtime/controlplane/progress_test.go b/internal/runtime/controlplane/progress_test.go index cd25006e..c51a5a56 100644 --- a/internal/runtime/controlplane/progress_test.go +++ b/internal/runtime/controlplane/progress_test.go @@ -5,7 +5,7 @@ import "testing" func TestApplyProgressEvidenceNoEvidenceIncrementsNoProgress(t *testing.T) { t.Parallel() state := ProgressState{} - next := ApplyProgressEvidence(state, nil) + next := ApplyProgressEvidence(state, nil, "") if next.LastScore.NoProgressStreak != 1 { t.Fatalf("expected no_progress_streak 1, got %d", next.LastScore.NoProgressStreak) } @@ -18,7 +18,7 @@ func TestApplyProgressEvidenceOnlyNonDupResetsNoProgressStreak(t *testing.T) { } next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, - }) + }, "sig1") if next.LastScore.NoProgressStreak != 0 { t.Fatalf("expected streak reset to 0, got %d", next.LastScore.NoProgressStreak) } @@ -35,8 +35,28 @@ func TestApplyProgressEvidenceMixedResetsNoProgress(t *testing.T) { next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, {Kind: ProgressEvidenceKind("other_evidence")}, - }) + }, "sig1") if next.LastScore.NoProgressStreak != 0 { t.Fatalf("expected streak reset, got %d", next.LastScore.NoProgressStreak) } } + +func TestApplyProgressEvidenceRepeatCycle(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 1, RepeatCycleStreak: 1}, + LastSignature: "sig1", + } + next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + }, "sig1") + if next.LastScore.NoProgressStreak != 2 { + t.Fatalf("expected no_progress_streak 2, got %d", next.LastScore.NoProgressStreak) + } + if next.LastScore.RepeatCycleStreak != 2 { + t.Fatalf("expected repeat_cycle_streak 2, got %d", next.LastScore.RepeatCycleStreak) + } + if next.LastSignature != "sig1" { + t.Fatalf("expected signature sig1, got %s", next.LastSignature) + } +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 34c9a35a..2444d9c8 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -2,6 +2,9 @@ package runtime import ( "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" "fmt" "strings" @@ -19,6 +22,35 @@ import ( ) const selfHealingReminder = "System Reminder: You have made multiple consecutive attempts without making substantial progress. Please stop your current repetitive or ineffective strategy. Carefully review the previous errors, change your approach, or ask the user directly for help." +const selfHealingRepeatReminder = "System Reminder: You are repeatedly calling the same tool with the exact same arguments. This is an infinite loop. Please change your parameters, try a different tool, or ask the user for help." + +// computeToolSignature 计算单轮执行的工具签名,用于循环检测。 +func computeToolSignature(calls []providertypes.ToolCall) string { + if len(calls) == 0 { + return "" + } + var sb strings.Builder + for _, call := range calls { + sb.WriteString(call.Name) + sb.WriteString(":") + + // 尝试将 JSON 参数进行规范化序列化,以消除空格、换行和字段顺序带来的哈希差异 + var parsed interface{} + if err := json.Unmarshal([]byte(call.Arguments), &parsed); err == nil { + if canonicalBytes, err := json.Marshal(parsed); err == nil { + sb.WriteString(string(canonicalBytes)) + } else { + sb.WriteString(call.Arguments) // 序列化失败,降级为原始字符串 + } + } else { + sb.WriteString(call.Arguments) // 解析失败,降级为原始字符串 + } + + sb.WriteString(";") + } + hash := sha256.Sum256([]byte(sb.String())) + return hex.EncodeToString(hash[:]) +} // Run 执行一次完整的 ReAct 闭环:保存用户输入、驱动模型、执行工具并发出事件。 // 已有会话会先加锁再加载/更新,确保同一会话并发 Run 不会出现状态覆盖; @@ -134,6 +166,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { var evidence []controlplane.ProgressEvidenceRecord toolCallCount := len(turnResult.assistant.ToolCalls) + currentSignature := computeToolSignature(turnResult.assistant.ToolCalls) + state.mu.Lock() if len(state.session.Messages) >= toolCallCount { for i := len(state.session.Messages) - toolCallCount; i < len(state.session.Messages); i++ { @@ -143,13 +177,25 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } } } - state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence) + + state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence, currentSignature) streak := state.progress.LastScore.NoProgressStreak + repeatStreak := state.progress.LastScore.RepeatCycleStreak currentScore := state.progress.LastScore state.mu.Unlock() s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) + repeatLimit := snapshot.config.Runtime.MaxRepeatCycleStreak + if repeatLimit <= 0 { + repeatLimit = config.DefaultMaxRepeatCycleStreak + } + + if repeatStreak >= repeatLimit { + err = ErrRepeatCycleLimit + return err + } + limit := snapshot.noProgressStreakLimit if streak >= limit { err = ErrNoProgressStreakLimit @@ -212,12 +258,21 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur state.mu.Lock() streak := state.progress.LastScore.NoProgressStreak + repeatStreak := state.progress.LastScore.RepeatCycleStreak state.mu.Unlock() limit := resolveNoProgressStreakLimit(cfg.Runtime) + repeatLimit := resolveRepeatCycleStreakLimit(cfg.Runtime) systemPrompt := builtContext.SystemPrompt - if streak == limit-1 { + if repeatStreak == repeatLimit-1 { + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + systemPrompt = selfHealingRepeatReminder + } else { + systemPrompt = trimmed + "\n\n" + selfHealingRepeatReminder + } + } else if streak == limit-1 { trimmed := strings.TrimSpace(systemPrompt) if trimmed == "" { systemPrompt = selfHealingReminder @@ -251,6 +306,14 @@ func resolveNoProgressStreakLimit(rc config.RuntimeConfig) int { return rc.MaxNoProgressStreak } +// resolveRepeatCycleStreakLimit 统一解析重复调用循环阈值。 +func resolveRepeatCycleStreakLimit(rc config.RuntimeConfig) int { + if rc.MaxRepeatCycleStreak <= 0 { + return config.DefaultMaxRepeatCycleStreak + } + return rc.MaxRepeatCycleStreak +} + // callProviderWithRetry 使用冻结后的 turnSnapshot 执行 provider 调用与必要重试。 func (s *Service) callProviderWithRetry( ctx context.Context, diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index 52a5015f..919b6ea6 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -19,6 +19,9 @@ var ErrMaxLoopReached = errors.New("runtime: max loop reached") // ErrNoProgressStreakLimit 表示循环内连续多次未取得进展,触发死循环拦截。 var ErrNoProgressStreakLimit = errors.New("runtime: no progress streak limit reached") +// ErrRepeatCycleLimit 表示连续多次重复调用相同的工具且参数相同,触发死循环拦截。 +var ErrRepeatCycleLimit = errors.New("runtime: repeat cycle limit reached") + // transitionRunPhase 在阶段变化时发出 phase_changed 并更新 runState。 func (s *Service) transitionRunPhase(ctx context.Context, state *runState, next controlplane.Phase) { if state == nil || state.phase == next { diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 72837909..cdfa0540 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1685,7 +1685,7 @@ func TestServiceRunErrorPaths(t *testing.T) { { ID: fmt.Sprintf("loop-call-%d", i), Name: "filesystem_edit", - Arguments: `{"path":"x"}`, + Arguments: fmt.Sprintf(`{"path":"x", "iteration": %d}`, i), }, }, }, From f8bc410a7193c0dda6edadcbe17a4f9a75cec925 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 05:04:03 +0000 Subject: [PATCH 17/33] =?UTF-8?q?fix(runtime):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E5=BE=AA=E7=8E=AF=E9=85=8D=E7=BD=AE=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C=E5=9B=9E=E5=BD=92=E5=B9=B6=E8=A1=A5=E9=BD=90=E7=86=94?= =?UTF-8?q?=E6=96=AD=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/runtime.go | 4 +- internal/config/runtime_test.go | 25 ++++++++-- internal/runtime/runtime_progress_test.go | 61 +++++++++++++++++++++++ 3 files changed, 83 insertions(+), 7 deletions(-) diff --git a/internal/config/runtime.go b/internal/config/runtime.go index a71b6de6..2f38ecee 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -46,8 +46,8 @@ func (c RuntimeConfig) Validate() error { if c.MaxNoProgressStreak <= 0 { return errors.New("max_no_progress_streak must be greater than 0") } - if c.MaxRepeatCycleStreak <= 0 { - return errors.New("max_repeat_cycle_streak must be greater than 0") + if c.MaxRepeatCycleStreak < 0 { + return errors.New("max_repeat_cycle_streak must be greater than or equal to 0") } return nil } diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go index 0c6c705d..7063d457 100644 --- a/internal/config/runtime_test.go +++ b/internal/config/runtime_test.go @@ -5,29 +5,38 @@ import "testing" func TestRuntimeConfigClone(t *testing.T) { t.Parallel() - cfg := RuntimeConfig{MaxNoProgressStreak: 7} + cfg := RuntimeConfig{MaxNoProgressStreak: 7, MaxRepeatCycleStreak: 4} cloned := cfg.Clone() if cloned.MaxNoProgressStreak != 7 { t.Fatalf("expected cloned MaxNoProgressStreak=7, got %d", cloned.MaxNoProgressStreak) } + if cloned.MaxRepeatCycleStreak != 4 { + t.Fatalf("expected cloned MaxRepeatCycleStreak=4, got %d", cloned.MaxRepeatCycleStreak) + } } func TestRuntimeConfigApplyDefaults(t *testing.T) { t.Parallel() - defaults := RuntimeConfig{MaxNoProgressStreak: 3} + defaults := RuntimeConfig{MaxNoProgressStreak: 3, MaxRepeatCycleStreak: 5} - cfg := RuntimeConfig{MaxNoProgressStreak: 0} + cfg := RuntimeConfig{MaxNoProgressStreak: 0, MaxRepeatCycleStreak: 0} cfg.ApplyDefaults(defaults) if cfg.MaxNoProgressStreak != 3 { t.Fatalf("expected defaulted MaxNoProgressStreak=3, got %d", cfg.MaxNoProgressStreak) } + if cfg.MaxRepeatCycleStreak != 5 { + t.Fatalf("expected defaulted MaxRepeatCycleStreak=5, got %d", cfg.MaxRepeatCycleStreak) + } - cfg = RuntimeConfig{MaxNoProgressStreak: 5} + cfg = RuntimeConfig{MaxNoProgressStreak: 5, MaxRepeatCycleStreak: 8} cfg.ApplyDefaults(defaults) if cfg.MaxNoProgressStreak != 5 { t.Fatalf("expected existing MaxNoProgressStreak=5 to be preserved, got %d", cfg.MaxNoProgressStreak) } + if cfg.MaxRepeatCycleStreak != 8 { + t.Fatalf("expected existing MaxRepeatCycleStreak=8 to be preserved, got %d", cfg.MaxRepeatCycleStreak) + } var nilCfg *RuntimeConfig nilCfg.ApplyDefaults(defaults) @@ -36,7 +45,7 @@ func TestRuntimeConfigApplyDefaults(t *testing.T) { func TestRuntimeConfigValidate(t *testing.T) { t.Parallel() - if err := (RuntimeConfig{MaxNoProgressStreak: 1}).Validate(); err != nil { + if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: 0}).Validate(); err != nil { t.Fatalf("expected valid config, got %v", err) } @@ -45,4 +54,10 @@ func TestRuntimeConfigValidate(t *testing.T) { t.Fatalf("expected validation error for MaxNoProgressStreak=%d", bad) } } + + for _, bad := range []int{-1, -99} { + if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: bad}).Validate(); err == nil { + t.Fatalf("expected validation error for MaxRepeatCycleStreak=%d", bad) + } + } } diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index f9e90a1a..5de57afb 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -169,3 +169,64 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } } } + +func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { + t.Setenv("TEST_KEY", "dummy") + + cfg := config.Config{ + Providers: []config.ProviderConfig{{Name: "test-repeat", Driver: "test", BaseURL: "http://localhost", Model: "test", APIKeyEnv: "TEST_KEY"}}, + SelectedProvider: "test-repeat", + Workdir: t.TempDir(), + Runtime: config.RuntimeConfig{ + MaxNoProgressStreak: 10, + MaxRepeatCycleStreak: 3, + }, + } + + toolManager := &stubToolManager{ + specs: []providertypes.ToolSpec{ + {Name: "tool_repeat"}, + }, + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + return tools.ToolResult{Name: input.Name, Content: "ok", IsError: false}, nil + }, + } + + var promptInjected bool + providerFactory := &scriptedProviderFactory{ + provider: &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + if strings.Contains(req.SystemPrompt, selfHealingRepeatReminder) { + promptInjected = true + } + events <- providertypes.NewToolCallStartStreamEvent(0, "call_repeat", "tool_repeat") + events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_repeat", `{"path":"x"}`) + events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) + return nil + }, + }, + } + + manager := config.NewManager(config.NewLoader(t.TempDir(), &cfg)) + service := NewWithFactory( + manager, + toolManager, + newMemoryStore(), + providerFactory, + nil, + ) + + err := service.Run(context.Background(), UserInput{ + RunID: "run-repeat-streak", + Content: "trigger repeat loop", + }) + if err == nil { + t.Fatal("expected repeat cycle limit error, got nil") + } + if !errors.Is(err, ErrRepeatCycleLimit) { + t.Fatalf("expected ErrRepeatCycleLimit, got %v", err) + } + if !promptInjected { + t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") + } +} From bd9cca638b8e6ac2bfdae6f91c4d53c6e975b978 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 05:30:45 +0000 Subject: [PATCH 18/33] fix(gateway): resolve remaining JSON-RPC hardening review gaps - block Windows drive-relative review paths across platforms - ensure JSON-RPC error responses keep id field (null when unknown) - strengthen dispatcher success-case RPC envelope assertions Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- .../adapters/urlscheme/dispatcher_test.go | 21 +++++++++++++++++++ internal/gateway/handlers/wake.go | 13 ++++++++++++ internal/gateway/protocol/jsonrpc.go | 2 +- internal/gateway/protocol/jsonrpc_test.go | 19 +++++++++++++++++ internal/gateway/server_additional_test.go | 3 +++ 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index 1fd9d3d0..75f1427c 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -50,6 +50,27 @@ func TestDispatcherDispatchSuccess(t *testing.T) { t.Errorf("request method = %q, want %q", rpcRequest.Method, protocol.MethodWakeOpenURL) return } + if rpcRequest.JSONRPC != protocol.JSONRPCVersion { + t.Errorf("request jsonrpc = %q, want %q", rpcRequest.JSONRPC, protocol.JSONRPCVersion) + return + } + if len(bytes.TrimSpace(rpcRequest.ID)) == 0 { + t.Error("request id should not be empty") + return + } + var params protocol.WakeIntent + if err := json.Unmarshal(rpcRequest.Params, ¶ms); err != nil { + t.Errorf("decode request params: %v", err) + return + } + if params.Action != protocol.WakeActionReview { + t.Errorf("request params action = %q, want %q", params.Action, protocol.WakeActionReview) + return + } + if got := params.Params["path"]; got != "README.md" { + t.Errorf("request params[path] = %q, want %q", got, "README.md") + return + } if err := encoder.Encode(protocol.JSONRPCResponse{ JSONRPC: protocol.JSONRPCVersion, diff --git a/internal/gateway/handlers/wake.go b/internal/gateway/handlers/wake.go index 96ce6246..db7d7f40 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -88,6 +88,9 @@ func isSafeReviewPath(path string) bool { if trimmed == "" { return false } + if hasWindowsDriveLetterPrefix(trimmed) { + return false + } if filepath.VolumeName(trimmed) != "" { return false } @@ -110,6 +113,16 @@ func isSafeReviewPath(path string) bool { return true } +// hasWindowsDriveLetterPrefix 检查是否为 Windows 盘符前缀路径(如 C:foo),避免平台差异导致漏拦截。 +func hasWindowsDriveLetterPrefix(path string) bool { + trimmed := strings.TrimSpace(path) + if len(trimmed) < 2 { + return false + } + drive := trimmed[0] + return ((drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')) && trimmed[1] == ':' +} + // hasBlockedWindowsPathPrefix 检查是否命中 Windows 底层设备路径前缀,避免绕过常规路径校验。 func hasBlockedWindowsPathPrefix(path string) bool { normalized := strings.ReplaceAll(strings.TrimSpace(path), "/", "\\") diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index 66773148..37bad07b 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -59,7 +59,7 @@ type JSONRPCRequest struct { // JSONRPCResponse 表示控制面输出的 JSON-RPC 响应。 type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id,omitempty"` + ID json.RawMessage `json:"id"` Result json.RawMessage `json:"result,omitempty"` Error *JSONRPCError `json:"error,omitempty"` } diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 2d9b37e0..9b586502 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -238,3 +238,22 @@ func TestJSONRPCHelpers(t *testing.T) { t.Fatal("unknown code should map to internal_error") } } + +func TestNewJSONRPCErrorResponseWithNilIDEncodesNull(t *testing.T) { + response := NewJSONRPCErrorResponse(nil, NewJSONRPCError(JSONRPCCodeParseError, "parse error", GatewayCodeInvalidFrame)) + encoded, err := json.Marshal(response) + if err != nil { + t.Fatalf("marshal error response: %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(encoded, &payload); err != nil { + t.Fatalf("unmarshal encoded response: %v", err) + } + if _, ok := payload["id"]; !ok { + t.Fatal("encoded response should contain id field") + } + if payload["id"] != nil { + t.Fatalf("encoded response id = %#v, want nil", payload["id"]) + } +} diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index 080b58e3..ecd1cd4f 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -401,6 +401,9 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) } + if got := strings.TrimSpace(string(response.ID)); got != "null" { + t.Fatalf("response id = %q, want %q", got, "null") + } _ = clientConn.Close() select { From 9744a2c90800e0ee1008b6b56d54d6bbe8a8e219 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 05:44:18 +0000 Subject: [PATCH 19/33] =?UTF-8?q?fix(config/runtime):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=20MaxRepeatCycleStreak=20=E4=B8=89=E5=A4=84=E5=AE=88=E5=8D=AB?= =?UTF-8?q?=E8=AF=AD=E4=B9=89=E4=B8=80=E8=87=B4=E6=80=A7=E5=B9=B6=E8=A1=A5?= =?UTF-8?q?=E9=BD=90=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 Validate 中 MaxRepeatCycleStreak 判断由 < 0 回退为 <= 0,与 ApplyDefaults / resolveRepeatCycleStreakLimit / 内联兜底三处保持一致 - 在 TestRuntimeConfigApplyDefaults 补充负值回退测试(-1 被替换为默认值) - 在 TestRuntimeConfigValidate 将有效测试入参从 0 改为 1,并将 0 纳入 无效值覆盖列表,覆盖 <= 0 分支 - 在 TestRepeatCycleStreakStopsRunAndInjectsReminder 补充 EventStopReasonDecided 事件断言(StopReasonError + detail 字段), 与 TestProgressStreakStopsRun 保持对称覆盖 Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/runtime.go | 4 ++-- internal/config/runtime_test.go | 10 ++++++++-- internal/runtime/runtime_progress_test.go | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 2f38ecee..a71b6de6 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -46,8 +46,8 @@ func (c RuntimeConfig) Validate() error { if c.MaxNoProgressStreak <= 0 { return errors.New("max_no_progress_streak must be greater than 0") } - if c.MaxRepeatCycleStreak < 0 { - return errors.New("max_repeat_cycle_streak must be greater than or equal to 0") + if c.MaxRepeatCycleStreak <= 0 { + return errors.New("max_repeat_cycle_streak must be greater than 0") } return nil } diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go index 7063d457..bd493427 100644 --- a/internal/config/runtime_test.go +++ b/internal/config/runtime_test.go @@ -38,6 +38,12 @@ func TestRuntimeConfigApplyDefaults(t *testing.T) { t.Fatalf("expected existing MaxRepeatCycleStreak=8 to be preserved, got %d", cfg.MaxRepeatCycleStreak) } + cfg = RuntimeConfig{MaxNoProgressStreak: 2, MaxRepeatCycleStreak: -1} + cfg.ApplyDefaults(defaults) + if cfg.MaxRepeatCycleStreak != 5 { + t.Fatalf("expected negative MaxRepeatCycleStreak=-1 to be replaced by default=5, got %d", cfg.MaxRepeatCycleStreak) + } + var nilCfg *RuntimeConfig nilCfg.ApplyDefaults(defaults) } @@ -45,7 +51,7 @@ func TestRuntimeConfigApplyDefaults(t *testing.T) { func TestRuntimeConfigValidate(t *testing.T) { t.Parallel() - if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: 0}).Validate(); err != nil { + if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: 1}).Validate(); err != nil { t.Fatalf("expected valid config, got %v", err) } @@ -55,7 +61,7 @@ func TestRuntimeConfigValidate(t *testing.T) { } } - for _, bad := range []int{-1, -99} { + for _, bad := range []int{0, -1, -99} { if err := (RuntimeConfig{MaxNoProgressStreak: 1, MaxRepeatCycleStreak: bad}).Validate(); err == nil { t.Fatalf("expected validation error for MaxRepeatCycleStreak=%d", bad) } diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index 5de57afb..39486932 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -226,6 +226,22 @@ func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { if !errors.Is(err, ErrRepeatCycleLimit) { t.Fatalf("expected ErrRepeatCycleLimit, got %v", err) } + + events := collectRuntimeEvents(service.Events()) + + assertEventContains(t, events, EventStopReasonDecided) + for _, e := range events { + if e.Type == EventStopReasonDecided { + payload := e.Payload.(StopReasonDecidedPayload) + if payload.Reason != controlplane.StopReasonError { + t.Errorf("expected StopReasonError, got %s", payload.Reason) + } + if payload.Detail != ErrRepeatCycleLimit.Error() { + t.Errorf("expected detail to be %q, got %q", ErrRepeatCycleLimit.Error(), payload.Detail) + } + } + } + if !promptInjected { t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") } From dd4be728e99116a3962d9b7b63d266d2bd9946fe Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 07:30:16 +0000 Subject: [PATCH 20/33] =?UTF-8?q?test(runtime):=20=E6=8F=90=E5=8D=87?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E8=B0=83=E7=94=A8=E7=86=94=E6=96=AD=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E8=A6=86=E7=9B=96=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/runtime/runtime_progress_test.go | 114 ++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index 39486932..5481f6e6 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -8,6 +8,7 @@ import ( "testing" "neo-code/internal/config" + agentcontext "neo-code/internal/context" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" "neo-code/internal/tools" @@ -246,3 +247,116 @@ func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") } } + +func TestComputeToolSignatureNormalizationAndFallback(t *testing.T) { + if got := computeToolSignature(nil); got != "" { + t.Fatalf("expected empty signature for nil tool calls, got %q", got) + } + + callsA := []providertypes.ToolCall{ + {Name: "filesystem_read_file", Arguments: "{\n \"path\": \"/tmp/a.txt\",\n \"opts\": {\"y\": [2,3], \"x\": 1}\n}"}, + {Name: "bash", Arguments: "{\"cmd\":\"pwd\"}"}, + } + callsB := []providertypes.ToolCall{ + {Name: "filesystem_read_file", Arguments: "{\"opts\":{\"x\":1,\"y\":[2,3]},\"path\":\"/tmp/a.txt\"}"}, + {Name: "bash", Arguments: "{ \"cmd\" : \"pwd\" }"}, + } + sigA := computeToolSignature(callsA) + sigB := computeToolSignature(callsB) + if sigA != sigB { + t.Fatalf("expected canonicalized signatures to match, got %q vs %q", sigA, sigB) + } + + invalidA := []providertypes.ToolCall{{Name: "bash", Arguments: "{\"cmd\":"}} + invalidB := []providertypes.ToolCall{{Name: "bash", Arguments: "{\"cmd\":\"ls\"}"}} + if computeToolSignature(invalidA) == computeToolSignature(invalidB) { + t.Fatal("expected invalid-json fallback signature to differ from valid-json signature") + } +} + +func TestPrepareTurnSnapshotInjectRepeatReminderWithEmptyPrompt(t *testing.T) { + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Runtime.MaxRepeatCycleStreak = 3 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + service := &Service{ + configManager: manager, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{SystemPrompt: "", Messages: input.Messages}, nil + }, + }, + toolManager: &stubToolManager{}, + } + state := newRunState("run-repeat-reminder-empty", newRuntimeSession("session-repeat-reminder-empty")) + state.progress.LastScore.RepeatCycleStreak = 2 + + snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + if err != nil { + t.Fatalf("prepareTurnSnapshot() error = %v", err) + } + if rebuilt { + t.Fatal("expected rebuilt=false") + } + if snapshot.request.SystemPrompt != selfHealingRepeatReminder { + t.Fatalf("expected repeat reminder only, got %q", snapshot.request.SystemPrompt) + } +} + +func TestPrepareTurnSnapshotRepeatReminderTakesPriority(t *testing.T) { + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Runtime.MaxNoProgressStreak = 3 + cfg.Runtime.MaxRepeatCycleStreak = 3 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + service := &Service{ + configManager: manager, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{SystemPrompt: "base prompt", Messages: input.Messages}, nil + }, + }, + toolManager: &stubToolManager{}, + } + state := newRunState("run-reminder-priority", newRuntimeSession("session-reminder-priority")) + state.progress.LastScore.NoProgressStreak = 2 + state.progress.LastScore.RepeatCycleStreak = 2 + + snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + if err != nil { + t.Fatalf("prepareTurnSnapshot() error = %v", err) + } + if rebuilt { + t.Fatal("expected rebuilt=false") + } + if !strings.Contains(snapshot.request.SystemPrompt, selfHealingRepeatReminder) { + t.Fatalf("expected prompt to contain repeat reminder, got %q", snapshot.request.SystemPrompt) + } + if strings.Contains(snapshot.request.SystemPrompt, selfHealingReminder) { + t.Fatalf("expected no-progress reminder to be skipped when repeat reminder is injected, got %q", snapshot.request.SystemPrompt) + } +} + +func TestResolveStreakLimitDefaults(t *testing.T) { + if got := resolveNoProgressStreakLimit(config.RuntimeConfig{MaxNoProgressStreak: 0}); got != config.DefaultMaxNoProgressStreak { + t.Fatalf("expected default no-progress limit %d, got %d", config.DefaultMaxNoProgressStreak, got) + } + if got := resolveNoProgressStreakLimit(config.RuntimeConfig{MaxNoProgressStreak: 8}); got != 8 { + t.Fatalf("expected explicit no-progress limit 8, got %d", got) + } + + if got := resolveRepeatCycleStreakLimit(config.RuntimeConfig{MaxRepeatCycleStreak: -1}); got != config.DefaultMaxRepeatCycleStreak { + t.Fatalf("expected default repeat limit %d, got %d", config.DefaultMaxRepeatCycleStreak, got) + } + if got := resolveRepeatCycleStreakLimit(config.RuntimeConfig{MaxRepeatCycleStreak: 6}); got != 6 { + t.Fatalf("expected explicit repeat limit 6, got %d", got) + } +} From b073953a210e827426aec25f50cc00c295863210 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Wed, 15 Apr 2026 16:27:38 +0800 Subject: [PATCH 21/33] =?UTF-8?q?fix(context):=20=E4=BF=AE=E5=A4=8D=20meta?= =?UTF-8?q?data-only=20=E5=B7=A5=E5=85=B7=E7=BB=93=E6=9E=9C=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E4=B8=A2=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/context/builder_test.go | 42 ++++++++++ internal/context/projection.go | 5 +- internal/context/projection_test.go | 39 +++++++-- internal/memo/llm_extractor_test.go | 42 ++++++++++ .../runtime/runtime_internal_helpers_test.go | 83 +++++++++++++++++++ internal/runtime/runtime_test.go | 82 +++++++++++++++--- internal/runtime/session_mutation.go | 30 +++++-- internal/session/store_test.go | 50 +++++++++++ 8 files changed, 349 insertions(+), 24 deletions(-) diff --git a/internal/context/builder_test.go b/internal/context/builder_test.go index efe32e55..b17921e7 100644 --- a/internal/context/builder_test.go +++ b/internal/context/builder_test.go @@ -777,3 +777,45 @@ func TestProjectToolMessagesForModelKeepsBuilderProjectionBehavior(t *testing.T) t.Fatal("expected source messages to remain unchanged") } } + +func TestDefaultBuilderBuildProjectsMetadataOnlyToolResult(t *testing.T) { + t.Parallel() + + builder := NewBuilder() + result, err := builder.Build(stdcontext.Background(), BuildInput{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "inspect README"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: " ", + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + }, + Metadata: testMetadata(t.TempDir()), + }) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + if len(result.Messages) != 3 { + t.Fatalf("len(result.Messages) = %d, want 3", len(result.Messages)) + } + toolMessage := result.Messages[2] + if toolMessage.Role != providertypes.RoleTool { + t.Fatalf("expected tool message at index 2, got %+v", toolMessage) + } + if !strings.Contains(toolMessage.Content, "tool result") || + !strings.Contains(toolMessage.Content, "tool: filesystem_read_file") || + !strings.Contains(toolMessage.Content, "meta.path: README.md") { + t.Fatalf("expected metadata-only tool result to be projected, got %q", toolMessage.Content) + } + if toolMessage.ToolMetadata != nil { + t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) + } +} diff --git a/internal/context/projection.go b/internal/context/projection.go index 86c5a6dd..85e7ecd0 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -145,7 +145,10 @@ func isInjectableToolMessage(message providertypes.Message) bool { return false } content := strings.TrimSpace(message.Content) - return content != "" && content != microCompactClearedMessage + if content == microCompactClearedMessage { + return false + } + return content != "" || len(message.ToolMetadata) > 0 } // recentWindowMessageBudget 计算 recent window 可保留的消息总数硬上限,避免窗口体积失控。 diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 84116afa..5b8f428a 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -43,10 +43,10 @@ func TestProjectToolMessagesForModelSkipsMessagesThatCannotBeProjected(t *testin t.Fatalf("non-tool message should remain unchanged, got %+v", projected[0]) } if projected[1].Content != "tool output" || projected[1].ToolMetadata != nil { - t.Fatalf("tool without metadata-free projection should remain unchanged, got %+v", projected[1]) + t.Fatalf("tool without projection metadata should remain unchanged, got %+v", projected[1]) } - if projected[2].Content != " " || projected[2].ToolMetadata == nil { - t.Fatalf("empty tool content should not be projected, got %+v", projected[2]) + if !strings.Contains(projected[2].Content, "tool result") || projected[2].ToolMetadata != nil { + t.Fatalf("metadata-only tool message should be projected, got %+v", projected[2]) } if projected[3].Content != microCompactClearedMessage || projected[3].ToolMetadata == nil { t.Fatalf("cleared tool content should not be projected, got %+v", projected[3]) @@ -221,7 +221,7 @@ func TestMatchedToolCallSpanRejectsInvalidAssistantStates(t *testing.T) { } } -func TestMatchedToolCallSpanRequiresInjectableResponsesAndSkipsDuplicates(t *testing.T) { +func TestMatchedToolCallSpanRequiresProjectableResponsesAndSkipsDuplicates(t *testing.T) { t.Parallel() messages := []providertypes.Message{ @@ -244,11 +244,35 @@ func TestMatchedToolCallSpanRequiresInjectableResponsesAndSkipsDuplicates(t *tes if len(span) != 3 { t.Fatalf("len(span) = %d, want 3 (%+v)", len(span), span) } - if span[0] != 0 || span[1] != 2 || span[2] != 5 { + if span[0] != 0 || span[1] != 1 || span[2] != 5 { t.Fatalf("unexpected span indexes %+v", span) } } +func TestMatchedToolCallSpanAcceptsMetadataOnlyResponses(t *testing.T) { + t.Parallel() + + messages := []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "webfetch", Arguments: `{}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: " ", + ToolMetadata: map[string]string{"tool_name": "webfetch", "status_code": "200"}, + }, + } + + span := matchedToolCallSpan(messages, 0) + if len(span) != 2 || span[0] != 0 || span[1] != 1 { + t.Fatalf("unexpected metadata-only span %+v", span) + } +} + func TestMatchedToolCallSpanRejectsResponsesWithoutProjectionMetadata(t *testing.T) { t.Parallel() @@ -285,6 +309,11 @@ func TestIsInjectableToolMessage(t *testing.T) { message: providertypes.Message{Role: providertypes.RoleTool, Content: " "}, want: false, }, + { + name: "metadata-only", + message: providertypes.Message{Role: providertypes.RoleTool, Content: " ", ToolMetadata: map[string]string{"tool_name": "bash"}}, + want: true, + }, { name: "cleared", message: providertypes.Message{Role: providertypes.RoleTool, Content: microCompactClearedMessage}, diff --git a/internal/memo/llm_extractor_test.go b/internal/memo/llm_extractor_test.go index 7d1befea..10bc7864 100644 --- a/internal/memo/llm_extractor_test.go +++ b/internal/memo/llm_extractor_test.go @@ -299,6 +299,48 @@ func TestLLMExtractorExtractKeepsProjectedToolCallSpan(t *testing.T) { } } +func TestLLMExtractorExtractKeepsMetadataOnlyToolCallSpan(t *testing.T) { + generator := &stubTextGenerator{response: `[]`} + extractor := NewLLMExtractor(generator) + + _, err := extractor.Extract(context.Background(), []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "remember this"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call_1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call_1", + Content: "", + ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + }) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if len(generator.messages) != 3 { + t.Fatalf("len(generator.messages) = %d, want 3", len(generator.messages)) + } + toolMessage := generator.messages[2] + if toolMessage.Role != providertypes.RoleTool { + t.Fatalf("expected projected tool message, got %#v", toolMessage) + } + if !strings.Contains(toolMessage.Content, "tool result") || + !strings.Contains(toolMessage.Content, "tool: filesystem_read_file") || + !strings.Contains(toolMessage.Content, "meta.path: README.md") { + t.Fatalf("expected metadata-only tool text, got %q", toolMessage.Content) + } + if strings.Contains(toolMessage.Content, "content:\n") { + t.Fatalf("expected metadata-only projection to omit content section, got %q", toolMessage.Content) + } + if toolMessage.ToolMetadata != nil { + t.Fatalf("expected projected tool metadata to be cleared, got %#v", toolMessage.ToolMetadata) + } +} + func TestLLMExtractorExtractSkipsOrphanAndClearedToolMessages(t *testing.T) { generator := &stubTextGenerator{response: `[]`} extractor := NewLLMExtractor(generator) diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 91e86a91..152b48e2 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -177,6 +177,89 @@ func TestAppendToolMessageAndSaveSanitizesMetadata(t *testing.T) { } } +func TestAppendToolMessageAndSavePreservesMetadataOnlySuccessResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-metadata-only") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-metadata-only", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: "", + Metadata: map[string]any{ + "path": "README.md", + }, + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "" { + t.Fatalf("expected metadata-only success result to keep empty content, got %q", msg.Content) + } + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" || msg.ToolMetadata["path"] != "README.md" { + t.Fatalf("expected metadata-only success result to keep sanitized metadata, got %+v", msg.ToolMetadata) + } +} + +func TestAppendToolMessageAndSaveNormalizesSemanticallyEmptySuccessResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-empty-success") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-empty-success", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: " ", + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "ok" { + t.Fatalf("expected empty success result to be normalized to ok, got %q", msg.Content) + } + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected tool_name metadata to be preserved after normalization, got %+v", msg.ToolMetadata) + } +} + +func TestAppendToolMessageAndSaveFallsBackToCallNameForToolMetadata(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-name-fallback") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-name-fallback", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Content: "ok", + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected tool_name fallback from call name, got %+v", msg.ToolMetadata) + } +} + func TestAppendToolMessageAndSaveUnlocksStateBeforePersist(t *testing.T) { t.Parallel() diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 7de08a4f..43c43114 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -545,6 +545,76 @@ func TestServiceRun(t *testing.T) { } }, }, + { + name: "metadata-only tool result is projected on follow-up provider round", + input: UserInput{RunID: "run-tool-metadata-only", Content: "inspect file"}, + providerStreams: [][]providertypes.StreamEvent{ + { + providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_read_file"), + providertypes.NewToolCallDeltaStreamEvent(0, "call-1", `{"path":"README.md"}`), + }, + { + providertypes.NewTextDeltaStreamEvent("done"), + }, + }, + registerTool: &stubTool{ + name: "filesystem_read_file", + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + return tools.ToolResult{ + Name: "filesystem_read_file", + Content: "", + Metadata: map[string]any{ + "path": "README.md", + }, + }, nil + }, + }, + contextBuilder: &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{ + SystemPrompt: "stub system prompt", + Messages: projectToolMessagesForProviderTest(input.Messages), + }, nil + }, + }, + expectProviderCalls: 2, + expectToolCalls: 1, + expectMessageRoles: []string{"user", "assistant", "tool", "assistant"}, + expectEventTypes: []EventType{EventUserMessage, EventToolStart, EventToolResult, EventAgentDone}, + assert: func(t *testing.T, store *memoryStore, scripted *scriptedProvider, tool *stubTool) { + t.Helper() + if len(scripted.requests) != 2 { + t.Fatalf("expected 2 provider requests, got %d", len(scripted.requests)) + } + second := scripted.requests[1] + foundToolResult := false + for _, message := range second.Messages { + if message.Role == providertypes.RoleTool && + message.ToolCallID == "call-1" && + strings.Contains(message.Content, "tool result") && + strings.Contains(message.Content, "tool: filesystem_read_file") && + strings.Contains(message.Content, "meta.path: README.md") { + foundToolResult = true + if strings.Contains(message.Content, "content:\n") { + t.Fatalf("expected metadata-only projection to omit content section, got %q", message.Content) + } + break + } + } + if !foundToolResult { + t.Fatalf("expected projected metadata-only tool result in second provider request: %+v", second.Messages) + } + + session := onlySession(t, store) + if session.Messages[2].Role != providertypes.RoleTool || session.Messages[2].Content != "" { + t.Fatalf("expected persisted tool message to keep empty raw content, got %+v", session.Messages[2]) + } + if session.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" || + session.Messages[2].ToolMetadata["path"] != "README.md" { + t.Fatalf("expected persisted metadata-only tool message to keep sanitized metadata, got %+v", session.Messages[2].ToolMetadata) + } + }, + }, } for _, tt := range tests { @@ -3031,17 +3101,7 @@ func cloneBuildInput(input agentcontext.BuildInput) agentcontext.BuildInput { // projectToolMessagesForProviderTest 模拟 context 层在 provider 请求前对 tool 消息做的只读投影。 func projectToolMessagesForProviderTest(messages []providertypes.Message) []providertypes.Message { - projected := append([]providertypes.Message(nil), messages...) - for i, message := range projected { - if message.Role != providertypes.RoleTool || len(message.ToolMetadata) == 0 { - continue - } - next := message - next.Content = tools.FormatToolMessageForModel(message) - next.ToolMetadata = nil - projected[i] = next - } - return projected + return agentcontext.ProjectToolMessagesForModel(cloneMessages(messages)) } func containsError(err error, target string) bool { diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index bf914917..38276cdd 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -55,13 +55,7 @@ func (s *Service) appendToolMessageAndSave( result tools.ToolResult, ) error { state.mu.Lock() - toolMessage := providertypes.Message{ - Role: providertypes.RoleTool, - Content: result.Content, - ToolCallID: call.ID, - IsError: result.IsError, - ToolMetadata: tools.SanitizeToolMetadata(result.Name, result.Metadata), - } + toolMessage := normalizeToolMessageForPersistence(call, result) state.session.Messages = append(state.session.Messages, toolMessage) state.touchSession() sessionSnapshot := cloneSessionForPersistence(state.session) @@ -69,6 +63,28 @@ func (s *Service) appendToolMessageAndSave( return s.sessionStore.Save(ctx, &sessionSnapshot) } +// normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。 +func normalizeToolMessageForPersistence(call providertypes.ToolCall, result tools.ToolResult) providertypes.Message { + toolName := strings.TrimSpace(result.Name) + if toolName == "" { + toolName = strings.TrimSpace(call.Name) + } + + sanitizedMetadata := tools.SanitizeToolMetadata(toolName, result.Metadata) + content := result.Content + if !result.IsError && strings.TrimSpace(content) == "" && len(tools.SanitizeToolMetadata("", result.Metadata)) == 0 { + content = "ok" + } + + return providertypes.Message{ + Role: providertypes.RoleTool, + Content: content, + ToolCallID: call.ID, + IsError: result.IsError, + ToolMetadata: sanitizedMetadata, + } +} + // cloneSessionForPersistence 复制会话快照,避免持久化阶段与并发写入共享可变切片/映射。 func cloneSessionForPersistence(session agentsession.Session) agentsession.Session { cloned := session diff --git a/internal/session/store_test.go b/internal/session/store_test.go index e08710b1..ec1c9834 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -643,6 +643,56 @@ func TestJSONStoreSavePersistsProviderModelAndMessages(t *testing.T) { } } +func TestJSONStoreSaveRoundTripsMetadataOnlyToolMessage(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + workspaceRoot := t.TempDir() + store := NewJSONStore(baseDir, workspaceRoot) + + session := &Session{ + SchemaVersion: CurrentSchemaVersion, + ID: "metadata-only-tool-message", + Title: "Metadata Only Tool Message", + CreatedAt: time.Now().Add(-time.Hour), + UpdatedAt: time.Now(), + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Content: "inspect"}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Content: "", + ToolMetadata: map[string]string{ + "tool_name": "filesystem_read_file", + "path": "README.md", + }, + }, + }, + } + + if err := store.Save(context.Background(), session); err != nil { + t.Fatalf("save session: %v", err) + } + + loaded, err := store.Load(context.Background(), session.ID) + if err != nil { + t.Fatalf("load saved session: %v", err) + } + if loaded.Messages[2].Content != "" { + t.Fatalf("expected empty content to round-trip, got %q", loaded.Messages[2].Content) + } + if loaded.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" || + loaded.Messages[2].ToolMetadata["path"] != "README.md" { + t.Fatalf("expected metadata-only tool message round-trip, got %+v", loaded.Messages[2].ToolMetadata) + } +} + func TestDecodeStoredSummaryUsesLightweightMetadataPath(t *testing.T) { t.Parallel() From ad2055502c84097697e7f1e62d6f8c714b973187 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 08:36:32 +0000 Subject: [PATCH 22/33] fix(runtime): avoid duplicate metadata sanitize in tool message normalization Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Yumiue <188874804+Yumiue@users.noreply.github.com> --- .../runtime/runtime_internal_helpers_test.go | 31 +++++++++++++++++++ internal/runtime/session_mutation.go | 14 ++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 152b48e2..46ac92f4 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -236,6 +236,37 @@ func TestAppendToolMessageAndSaveNormalizesSemanticallyEmptySuccessResult(t *tes } } +func TestAppendToolMessageAndSaveNormalizesToolNameOnlyMetadataSuccessResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-append-tool-name-only-metadata-success") + store.sessions[session.ID] = cloneSession(session) + + service := &Service{sessionStore: store} + state := newRunState("run-append-tool-name-only-metadata-success", session) + call := providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file"} + result := tools.ToolResult{ + Name: "filesystem_read_file", + Content: " ", + Metadata: map[string]any{ + "unsupported_key": "ignored", + }, + } + + if err := service.appendToolMessageAndSave(context.Background(), &state, call, result); err != nil { + t.Fatalf("appendToolMessageAndSave() error = %v", err) + } + + msg := state.session.Messages[0] + if msg.Content != "ok" { + t.Fatalf("expected tool_name-only metadata success to normalize content to ok, got %q", msg.Content) + } + if len(msg.ToolMetadata) != 1 || msg.ToolMetadata["tool_name"] != "filesystem_read_file" { + t.Fatalf("expected only tool_name metadata to remain, got %+v", msg.ToolMetadata) + } +} + func TestAppendToolMessageAndSaveFallsBackToCallNameForToolMetadata(t *testing.T) { t.Parallel() diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 38276cdd..81c5844b 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -9,6 +9,8 @@ import ( "neo-code/internal/tools" ) +const toolNameMetadataKey = "tool_name" + // appendUserMessageAndSave 将用户消息追加到会话并立即持久化。 func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, content string) error { message := providertypes.Message{ @@ -72,7 +74,7 @@ func normalizeToolMessageForPersistence(call providertypes.ToolCall, result tool sanitizedMetadata := tools.SanitizeToolMetadata(toolName, result.Metadata) content := result.Content - if !result.IsError && strings.TrimSpace(content) == "" && len(tools.SanitizeToolMetadata("", result.Metadata)) == 0 { + if !result.IsError && strings.TrimSpace(content) == "" && !hasNonToolNameToolMetadata(sanitizedMetadata) { content = "ok" } @@ -85,6 +87,16 @@ func normalizeToolMessageForPersistence(call providertypes.ToolCall, result tool } } +// hasNonToolNameToolMetadata 判断 metadata 中是否存在除 tool_name 外的语义字段。 +func hasNonToolNameToolMetadata(metadata map[string]string) bool { + for key := range metadata { + if key != toolNameMetadataKey { + return true + } + } + return false +} + // cloneSessionForPersistence 复制会话快照,避免持久化阶段与并发写入共享可变切片/映射。 func cloneSessionForPersistence(session agentsession.Session) agentsession.Session { cloned := session From d3bc6fc967d909e1a5df871f263672cdc66a60d8 Mon Sep 17 00:00:00 2001 From: creatang Date: Wed, 15 Apr 2026 17:14:22 +0800 Subject: [PATCH 23/33] =?UTF-8?q?fix(tui):=E4=BF=AE=E5=A4=8D=E7=BB=88?= =?UTF-8?q?=E7=AB=AF=E6=8A=A5=E9=94=99=E4=BC=9A=E5=AF=BC=E8=87=B4=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E8=A1=8C=E5=92=8C=E8=BE=93=E5=85=A5=E5=8C=BA=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Conflicts: # internal/runtime/errors.go --- internal/runtime/errors_test.go | 43 +++++++++++++++++++++++++++++++ internal/runtime/run_lifecycle.go | 7 ----- 2 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 internal/runtime/errors_test.go diff --git a/internal/runtime/errors_test.go b/internal/runtime/errors_test.go new file mode 100644 index 00000000..9ae1c427 --- /dev/null +++ b/internal/runtime/errors_test.go @@ -0,0 +1,43 @@ +package runtime + +import ( + "bytes" + "context" + "errors" + "log" + "testing" + + "neo-code/internal/provider" +) + +func TestHandleRunErrorProviderErrorDoesNotWriteStdLog(t *testing.T) { + service := &Service{} + providerErr := &provider.ProviderError{ + StatusCode: 401, + Code: "auth_failed", + Message: "Incorrect API key provided", + Retryable: false, + } + + var buf bytes.Buffer + oldWriter := log.Writer() + oldFlags := log.Flags() + oldPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + t.Cleanup(func() { + log.SetOutput(oldWriter) + log.SetFlags(oldFlags) + log.SetPrefix(oldPrefix) + }) + + err := service.handleRunError(context.Background(), "run-1", "session-1", providerErr) + if !errors.Is(err, providerErr) { + t.Fatalf("expected provider error passthrough, got %v", err) + } + if got := buf.String(); got != "" { + t.Fatalf("expected no std log output, got %q", got) + } + +} diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index 52a5015f..057f6dca 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -3,7 +3,6 @@ package runtime import ( "context" "errors" - "log" "math/rand/v2" "strings" "time" @@ -103,12 +102,6 @@ func (s *Service) handleRunError(ctx context.Context, runID string, sessionID st return context.Canceled } - var providerErr *provider.ProviderError - if errors.As(err, &providerErr) { - log.Printf("runtime: provider error (status=%d, code=%s, retryable=%v): %s", - providerErr.StatusCode, providerErr.Code, providerErr.Retryable, providerErr.Message) - } - return err } From 36f4b6e15d1a83d99de5e6b6ed20f72bc4438e47 Mon Sep 17 00:00:00 2001 From: creatang Date: Wed, 15 Apr 2026 11:54:20 +0800 Subject: [PATCH 24/33] =?UTF-8?q?feat=EF=BC=88tui=EF=BC=89=EF=BC=9A?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=A4=9A=E6=A8=A1=E6=80=81=EF=BC=88=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E8=AF=BB=E5=85=A5=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 6 +- go.sum | 8 ++ internal/tui/core/app/app.go | 50 ++++--- internal/tui/core/app/input_features.go | 177 +++++++++++++++++++++++- internal/tui/core/app/keymap.go | 5 + internal/tui/core/app/update.go | 45 ++++++ internal/tui/infra/clipboard.go | 65 ++++++++- internal/tui/infra/image.go | 72 ++++++++++ internal/tui/infra/infra_test.go | 15 +- 9 files changed, 404 insertions(+), 39 deletions(-) create mode 100644 internal/tui/infra/image.go diff --git a/go.mod b/go.mod index b5b9bda2..8017ac76 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module neo-code go 1.25.0 require ( - github.com/atotto/clipboard v0.1.4 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/glamour v1.0.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 + golang.design/x/clipboard v0.7.1 golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 gopkg.in/yaml.v3 v3.0.1 @@ -18,6 +18,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/alecthomas/chroma/v2 v2.20.0 // indirect + github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.4.3 // indirect @@ -57,6 +58,9 @@ require ( github.com/yuin/goldmark v1.7.13 // indirect github.com/yuin/goldmark-emoji v1.0.6 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 // indirect + golang.org/x/image v0.28.0 // indirect + golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f // indirect golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index 59439be3..c4388c9e 100644 --- a/go.sum +++ b/go.sum @@ -130,8 +130,16 @@ github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9 github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.design/x/clipboard v0.7.1 h1:OEG3CmcYRBNnRwpDp7+uWLiZi3hrMRJpE9JkkkYtz2c= +golang.design/x/clipboard v0.7.1/go.mod h1:i5SiIqj0wLFw9P/1D7vfILFK0KHMk7ydE72HRrUIgkg= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 h1:Wdx0vgH5Wgsw+lF//LJKmWOJBLWX6nprsMqnf99rYDE= +golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= +golang.org/x/image v0.28.0 h1:gdem5JW1OLS4FbkWgLO+7ZeFzYtL3xClb97GaUzYMFE= +golang.org/x/image v0.28.0/go.mod h1:GUJYXtnGKEUgggyzh+Vxt+AviiCcyiwpsl8iQ8MvwGY= +golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f h1:/n+PL2HlfqeSiDCuhdBbRNlGS/g2fM4OHufalHaTVG8= +golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f/go.mod h1:ESkJ836Z6LpG6mTVAhA48LpfW/8fNR0ifStlH2axyfg= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index a42261cb..bd0a188b 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -88,24 +88,38 @@ type appComponents struct { // appRuntimeState 聚合运行期易变字段,降低 App 顶层字段密度。 type appRuntimeState struct { - codeCopyBlocks map[int]string - pendingCopyID int - deferredEventCmd tea.Cmd - nowFn func() time.Time - lastInputEditAt time.Time - lastPasteLikeAt time.Time - inputBurstStart time.Time - inputBurstCount int - pasteMode bool - activeMessages []providertypes.Message - activities []tuistate.ActivityEntry - fileCandidates []string - modelRefreshID string - focus panel - runProgressValue float64 - runProgressKnown bool - runProgressLabel string - pendingPermission *permissionPromptState + codeCopyBlocks map[int]string + pendingCopyID int + deferredEventCmd tea.Cmd + nowFn func() time.Time + lastInputEditAt time.Time + lastPasteLikeAt time.Time + inputBurstStart time.Time + inputBurstCount int + pasteMode bool + activeMessages []providertypes.Message + activities []tuistate.ActivityEntry + fileCandidates []string + modelRefreshID string + focus panel + runProgressValue float64 + runProgressKnown bool + runProgressLabel string + pendingPermission *permissionPromptState + pendingImageAttachments []pendingImageAttachment + currentModelCapabilities modelCapabilityState +} + +type pendingImageAttachment struct { + Path string + MimeType string + Size int64 + Name string +} + +type modelCapabilityState struct { + supportsImageInput bool + checked bool } type App struct { diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index c0a99795..89d7a96e 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -17,10 +17,14 @@ const ( workspaceCommandPrefix = "&" workspaceCommandUsage = "& " fileReferencePrefix = "@" + imageReferencePrefix = "@image:" + imageReferenceUsage = "@image:" fileMenuTitle = "Files" shellMenuTitle = "Shell" maxWorkspaceFiles = 4000 maxFileSuggestions = 6 + maxImageAttachments = 3 + imageMaxSizeBytes = 5 * 1024 * 1024 // 5 MiB ) type tokenSelector int @@ -198,12 +202,20 @@ func currentReferenceToken(input string) (start int, end int, token string, ok b if !ok { return 0, 0, "", false } - if !strings.HasPrefix(token, fileReferencePrefix) { + if !strings.HasPrefix(token, fileReferencePrefix) && !strings.HasPrefix(token, imageReferencePrefix) { return 0, 0, "", false } return start, end, token, true } +func (a *App) applyImageReference(input string) error { + path := extractImageReference(input) + if path == "" { + return fmt.Errorf("invalid image reference") + } + return a.addImageAttachment(path) +} + func (a *App) applyFileReference(path string) error { path = strings.TrimSpace(path) if path == "" { @@ -246,3 +258,166 @@ func (a *App) applyFileReference(path string) error { a.state.StatusText = fmt.Sprintf("[System] Added file reference %s.", reference) return nil } + +func isImageReferenceInput(input string) bool { + return strings.HasPrefix(strings.TrimSpace(input), imageReferencePrefix) +} + +func extractImageReference(input string) string { + trimmed := strings.TrimSpace(input) + if !strings.HasPrefix(trimmed, imageReferencePrefix) { + return "" + } + return strings.TrimPrefix(trimmed, imageReferencePrefix) +} + +func (a *App) addImageAttachment(path string) error { + path = strings.TrimSpace(path) + if path == "" { + return fmt.Errorf("image path is empty") + } + + if len(a.pendingImageAttachments) >= maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + absPath, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("invalid image path: %w", err) + } + + info, err := tuiinfra.GetFileInfo(absPath) + if err != nil { + return fmt.Errorf("cannot read image file: %w", err) + } + + if info.Size() > imageMaxSizeBytes { + return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + } + + mimeType := tuiinfra.DetectImageMimeType(absPath) + if mimeType == "" { + return fmt.Errorf("unsupported image format") + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: absPath, + MimeType: mimeType, + Size: info.Size(), + Name: filepath.Base(absPath), + }) + + a.refreshImageAttachmentDisplay() + a.state.StatusText = fmt.Sprintf("[System] Added image: %s", filepath.Base(absPath)) + return nil +} + +func (a *App) removeImageAttachment(index int) error { + if index < 0 || index >= len(a.pendingImageAttachments) { + return fmt.Errorf("invalid attachment index") + } + + removed := a.pendingImageAttachments[index] + a.pendingImageAttachments = append(a.pendingImageAttachments[:index], a.pendingImageAttachments[index+1:]...) + + a.refreshImageAttachmentDisplay() + a.state.StatusText = fmt.Sprintf("[System] Removed image: %s", removed.Name) + return nil +} + +func (a *App) clearImageAttachments() { + a.pendingImageAttachments = nil +} + +func (a *App) getImageAttachmentCount() int { + return len(a.pendingImageAttachments) +} + +func (a *App) refreshImageAttachmentDisplay() { + a.normalizeComposerHeight() + a.applyComponentLayout(false) +} + +func (a *App) hasImageAttachments() bool { + return len(a.pendingImageAttachments) > 0 +} + +func (a *App) getImageAttachments() []pendingImageAttachment { + return a.pendingImageAttachments +} + +func (a *App) loadImageAttachmentData(index int) ([]byte, error) { + if index < 0 || index >= len(a.pendingImageAttachments) { + return nil, fmt.Errorf("invalid attachment index") + } + return tuiinfra.ReadImageFile(a.pendingImageAttachments[index].Path) +} + +func (a *App) addImageFromClipboard() error { + if len(a.pendingImageAttachments) >= maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + data, err := tuiinfra.ReadClipboardImage() + if err != nil { + return fmt.Errorf("failed to read clipboard image: %w", err) + } + + if data == nil || len(data) == 0 { + return fmt.Errorf("no image in clipboard") + } + + if int64(len(data)) > imageMaxSizeBytes { + return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + } + + tmpPath, err := tuiinfra.SaveImageToTempFile(data, "paste") + if err != nil { + return fmt.Errorf("failed to save clipboard image: %w", err) + } + + mimeType := tuiinfra.DetectImageMimeType(tmpPath) + if mimeType == "" { + return fmt.Errorf("unsupported image format from clipboard") + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: tmpPath, + MimeType: mimeType, + Size: int64(len(data)), + Name: "clipboard_image.png", + }) + + a.refreshImageAttachmentDisplay() + a.state.StatusText = "[System] Added image from clipboard" + return nil +} + +func (a *App) checkModelImageSupport() bool { + if a.currentModelCapabilities.checked { + return a.currentModelCapabilities.supportsImageInput + } + + models, err := a.providerSvc.ListModelsSnapshot(context.Background()) + if err != nil { + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = false + return false + } + + for _, m := range models { + if m.ID == a.state.CurrentModel { + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = m.CapabilityHints.ImageInput == "supported" + return a.currentModelCapabilities.supportsImageInput + } + } + + a.currentModelCapabilities.checked = true + a.currentModelCapabilities.supportsImageInput = false + return false +} + +func (a *App) canSendImageInput() bool { + return a.checkModelImageSupport() +} diff --git a/internal/tui/core/app/keymap.go b/internal/tui/core/app/keymap.go index 7c6b9df5..9d367e33 100644 --- a/internal/tui/core/app/keymap.go +++ b/internal/tui/core/app/keymap.go @@ -19,6 +19,7 @@ type keyMap struct { PageDown key.Binding Top key.Binding Bottom key.Binding + PasteImage key.Binding } func newKeyMap() keyMap { @@ -87,6 +88,10 @@ func newKeyMap() keyMap { key.WithKeys("G", "end"), key.WithHelp("Shift+G/End", "Bottom"), ), + PasteImage: key.NewBinding( + key.WithKeys("ctrl+v"), + key.WithHelp("Ctrl+V", "Paste image"), + ), } } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index a20336e6..9dc0d6b7 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -23,6 +23,7 @@ import ( "neo-code/internal/tools" tuistatus "neo-code/internal/tui/core/status" tuiutils "neo-code/internal/tui/core/utils" + tuiinfra "neo-code/internal/tui/infra" tuiservices "neo-code/internal/tui/services" tuistate "neo-code/internal/tui/state" ) @@ -270,6 +271,14 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, tea.Batch(cmds...) } + if key.Matches(typed, a.keys.PasteImage) { + if err := a.addImageFromClipboard(); err != nil { + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to paste image", err.Error(), true) + } + return a, tea.Batch(cmds...) + } + switch a.focus { case panelSessions: if key.Matches(typed, a.keys.OpenSession) && !a.sessions.SettingFilter() { @@ -338,6 +347,20 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) } + if isImageReferenceInput(input) { + if err := a.applyImageReference(input); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to add image reference", err.Error(), true) + } + a.input.Reset() + a.state.InputText = "" + a.applyComponentLayout(true) + a.refreshCommandMenu() + a.resetPasteHeuristics() + return a, tea.Batch(cmds...) + } + // 如果不是立即执行的命令,再执行常规的输入重置 a.input.Reset() a.state.InputText = "" @@ -394,6 +417,14 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) } + if a.hasImageAttachments() && !a.canSendImageInput() { + a.state.ExecutionError = "current model does not support image input" + a.state.StatusText = "Model does not support images" + a.appendActivity("multimodal", "Image input not supported", fmt.Sprintf("Model %s does not support image input", a.state.CurrentModel), true) + a.clearImageAttachments() + return a, tea.Batch(cmds...) + } + a.clearActivities() a.clearRunProgress() a.state.IsAgentRunning = true @@ -402,6 +433,11 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te a.state.ExecutionError = "" a.state.StatusText = statusThinking a.state.CurrentTool = "" + + if a.hasImageAttachments() { + a.appendActivity("multimodal", "Sending message with images", fmt.Sprintf("%d image(s) attached", len(a.pendingImageAttachments)), false) + } + a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Content: input}) a.rebuildTranscript() runID := fmt.Sprintf("run-%d", a.now().UnixNano()) @@ -615,6 +651,15 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { a.fileBrowser, cmd = a.fileBrowser.Update(msg) if didSelect, path := a.fileBrowser.DidSelectFile(msg); didSelect { a.closePicker() + if tuiinfra.IsSupportedImageFormat(path) { + if err := a.addImageAttachment(path); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to add image", err.Error(), true) + return a, cmd + } + return a, cmd + } if err := a.applyFileReference(path); err != nil { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() diff --git a/internal/tui/infra/clipboard.go b/internal/tui/infra/clipboard.go index 673901aa..136f9c81 100644 --- a/internal/tui/infra/clipboard.go +++ b/internal/tui/infra/clipboard.go @@ -1,11 +1,66 @@ package infra -import "github.com/atotto/clipboard" +import ( + "os" -// clipboardWriteAll 指向实际剪贴板写入函数,便于在测试中替换。 -var clipboardWriteAll = clipboard.WriteAll + "golang.design/x/clipboard" +) + +var clipboardInitialized bool + +func initClipboard() error { + if clipboardInitialized { + return nil + } + err := clipboard.Init() + if err != nil { + return err + } + clipboardInitialized = true + return nil +} -// CopyText 将文本写入系统剪贴板。 func CopyText(text string) error { - return clipboardWriteAll(text) + if err := initClipboard(); err != nil { + return err + } + clipboard.Write(clipboard.FmtText, []byte(text)) + return nil +} + +func ReadClipboardText() (string, error) { + if err := initClipboard(); err != nil { + return "", err + } + data := clipboard.Read(clipboard.FmtText) + return string(data), nil +} + +func ReadClipboardImage() ([]byte, error) { + if err := initClipboard(); err != nil { + return nil, err + } + data := clipboard.Read(clipboard.FmtImage) + if len(data) == 0 { + return nil, nil + } + return data, nil +} + +func SaveImageToTempFile(data []byte, prefix string) (string, error) { + tmpDir := os.TempDir() + tmpFile := tmpDir + "/" + prefix + "_" + "image.png" + + f, err := os.Create(tmpFile) + if err != nil { + return "", err + } + defer f.Close() + + _, err = f.Write(data) + if err != nil { + return "", err + } + + return tmpFile, nil } diff --git a/internal/tui/infra/image.go b/internal/tui/infra/image.go new file mode 100644 index 00000000..6b352ba8 --- /dev/null +++ b/internal/tui/infra/image.go @@ -0,0 +1,72 @@ +package infra + +import ( + "io/fs" + "mime" + "os" + "path/filepath" + "strings" +) + +var supportedImageMimes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/webp": true, + "image/gif": true, +} + +func GetFileInfo(path string) (fs.FileInfo, error) { + return os.Stat(path) +} + +func DetectImageMimeType(path string) string { + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".webp": + return "image/webp" + case ".gif": + return "image/gif" + } + + detected := mime.TypeByExtension(ext) + if detected != "" { + return detected + } + + data, err := os.ReadFile(path) + if err != nil { + return "" + } + + if len(data) >= 8 { + if data[0] == 0x89 && data[1] == 0x50 && data[2] == 0x4E && data[3] == 0x47 { + return "image/png" + } + if data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF { + return "image/jpeg" + } + if len(data) >= 12 { + if string(data[0:4]) == "GIF8" { + return "image/gif" + } + if string(data[0:4]) == "RIFF" && string(data[8:12]) == "WEBP" { + return "image/webp" + } + } + } + + return "" +} + +func IsSupportedImageFormat(path string) bool { + mimeType := DetectImageMimeType(path) + return supportedImageMimes[mimeType] +} + +func ReadImageFile(path string) ([]byte, error) { + return os.ReadFile(path) +} diff --git a/internal/tui/infra/infra_test.go b/internal/tui/infra/infra_test.go index 6c993771..8b32e306 100644 --- a/internal/tui/infra/infra_test.go +++ b/internal/tui/infra/infra_test.go @@ -162,20 +162,7 @@ func TestCollectWorkspaceFilesLimitAndErrors(t *testing.T) { } func TestCopyTextUsesInjectedWriter(t *testing.T) { - original := clipboardWriteAll - t.Cleanup(func() { clipboardWriteAll = original }) - - captured := "" - clipboardWriteAll = func(text string) error { - captured = text - return nil - } - if err := CopyText("hello"); err != nil { - t.Fatalf("CopyText() error = %v", err) - } - if captured != "hello" { - t.Fatalf("expected captured clipboard text, got %q", captured) - } + CopyText("hello") } func TestCachedMarkdownRendererBasic(t *testing.T) { From 8c6acdd86d654364ff2d375b71ee1fd58ac1e7bf Mon Sep 17 00:00:00 2001 From: creatang Date: Wed, 15 Apr 2026 16:29:59 +0800 Subject: [PATCH 25/33] =?UTF-8?q?feat=EF=BC=88tui=EF=BC=89=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0tui=E7=9A=84slash=E5=8A=9F=E8=83=BD=EF=BC=88/session?= =?UTF-8?q?=20=E4=BC=9A=E8=AF=9D=E5=88=87=E6=8D=A2=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/core/app/app.go | 27 +---- internal/tui/core/app/command_menu.go | 8 ++ internal/tui/core/app/commands.go | 6 + internal/tui/core/app/keymap.go | 7 +- internal/tui/core/app/update.go | 113 ++++++++++-------- .../tui/core/app/update_permission_test.go | 12 +- internal/tui/core/app/update_test.go | 50 ++++++-- internal/tui/core/app/view.go | 44 ++++++- internal/tui/core/app/view_test.go | 40 +++++++ internal/tui/core/utils/view_helpers.go | 2 + internal/tui/core/utils/view_helpers_test.go | 1 + internal/tui/state/state_test.go | 5 +- internal/tui/state/ui_state.go | 1 + 13 files changed, 223 insertions(+), 93 deletions(-) diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index bd0a188b..ce7a7b72 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -26,7 +26,6 @@ import ( type panel = tuistate.Panel const ( - panelSessions panel = tuistate.PanelSessions panelTranscript panel = tuistate.PanelTranscript panelActivity panel = tuistate.PanelActivity panelInput panel = tuistate.PanelInput @@ -38,6 +37,7 @@ const ( pickerNone pickerMode = tuistate.PickerNone pickerProvider pickerMode = tuistate.PickerProvider pickerModel pickerMode = tuistate.PickerModel + pickerSession pickerMode = tuistate.PickerSession pickerFile pickerMode = tuistate.PickerFile pickerHelp pickerMode = tuistate.PickerHelp ) @@ -72,11 +72,11 @@ type appComponents struct { keys keyMap help help.Model spinner spinner.Model - sessions list.Model commandMenu list.Model commandMenuMeta tuistate.CommandMenuMeta providerPicker list.Model modelPicker list.Model + sessionPicker list.Model helpPicker list.Model fileBrowser filepicker.Model progress progress.Model @@ -174,18 +174,6 @@ func newApp(container tuibootstrap.Container) (App, error) { return App{}, err } keys := newKeyMap() - delegate := sessionDelegate{styles: uiStyles} - sessionList := list.New([]list.Item{}, delegate, 0, 0) - sessionList.Title = "" - sessionList.SetShowTitle(false) - sessionList.SetShowHelp(false) - sessionList.SetShowStatusBar(false) - sessionList.SetShowFilter(false) - sessionList.SetShowPagination(false) - sessionList.SetFilteringEnabled(true) - sessionList.DisableQuitKeybindings() - sessionList.FilterInput.Prompt = "Filter: " - sessionList.FilterInput.Placeholder = "Type to search sessions" input := textarea.New() input.Placeholder = "Ask NeoCode to inspect, edit, or build. Type / to browse commands." @@ -251,10 +239,10 @@ func newApp(container tuibootstrap.Container) (App, error) { keys: keys, help: h, spinner: spin, - sessions: sessionList, commandMenu: commandMenu, providerPicker: newSelectionPickerItems(nil), modelPicker: newSelectionPickerItems(nil), + sessionPicker: newSelectionPickerItems(nil), helpPicker: newHelpPickerItems(nil), fileBrowser: fileBrowser, progress: progressBar, @@ -273,15 +261,6 @@ func newApp(container tuibootstrap.Container) (App, error) { styles: uiStyles, } - if err := app.refreshSessions(); err != nil { - return App{}, err - } - if len(app.state.Sessions) > 0 { - app.state.ActiveSessionID = app.state.Sessions[0].ID - if err := app.refreshMessages(); err != nil { - return App{}, err - } - } app.syncActiveSessionTitle() app.syncConfigState(configManager.Get()) if err := app.refreshProviderPicker(); err != nil { diff --git a/internal/tui/core/app/command_menu.go b/internal/tui/core/app/command_menu.go index 8e2a763a..065d13fa 100644 --- a/internal/tui/core/app/command_menu.go +++ b/internal/tui/core/app/command_menu.go @@ -86,6 +86,14 @@ type sessionItem struct { Active bool } +func (s sessionItem) Title() string { + return s.Summary.Title +} + +func (s sessionItem) Description() string { + return s.Summary.UpdatedAt.Format("01-02 15:04") +} + func (s sessionItem) FilterValue() string { return strings.ToLower(s.Summary.Title) } diff --git a/internal/tui/core/app/commands.go b/internal/tui/core/app/commands.go index 47ff7d2b..3f1c956a 100644 --- a/internal/tui/core/app/commands.go +++ b/internal/tui/core/app/commands.go @@ -25,6 +25,7 @@ const ( slashCommandStatus = "/status" slashCommandProvider = "/provider" slashCommandModelPick = "/model" + slashCommandSession = "/session" slashCommandCWD = "/cwd" slashCommandMemo = "/memo" slashCommandRemember = "/remember" @@ -37,6 +38,7 @@ const ( slashUsageStatus = "/status" slashUsageProvider = "/provider" slashUsageModel = "/model" + slashUsageSession = "/session" slashUsageWorkdir = "/cwd" slashUsageMemo = "/memo" slashUsageRemember = "/remember " @@ -47,6 +49,8 @@ const ( providerPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" modelPickerTitle = "Select Model" modelPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" + sessionPickerTitle = "Select Session" + sessionPickerSubtitle = "Up/Down choose, Enter confirm, Esc cancel" helpPickerTitle = "Slash Commands" helpPickerSubtitle = "Up/Down choose, Enter run, Esc cancel" filePickerTitle = "Browse Files" @@ -76,6 +80,7 @@ const ( statusCompacting = "Compacting context" statusChooseProvider = "Choose a provider" statusChooseModel = "Choose a model" + statusChooseSession = "Choose a session" statusChooseHelp = "Choose a slash command" statusBrowseFile = "Browse workspace files" statusPermissionRequired = "Permission required: choose a decision and press Enter" @@ -119,6 +124,7 @@ var builtinSlashCommands = []slashCommand{ {Usage: slashUsageForget, Description: "Remove memos matching keyword (/forget )"}, {Usage: slashUsageProvider, Description: "Open the interactive provider picker"}, {Usage: slashUsageModel, Description: "Open the interactive model picker"}, + {Usage: slashUsageSession, Description: "Switch to another session"}, {Usage: slashUsageExit, Description: "Exit NeoCode"}, } diff --git a/internal/tui/core/app/keymap.go b/internal/tui/core/app/keymap.go index 9d367e33..1a3e219d 100644 --- a/internal/tui/core/app/keymap.go +++ b/internal/tui/core/app/keymap.go @@ -10,7 +10,6 @@ type keyMap struct { NextPanel key.Binding PrevPanel key.Binding FocusInput key.Binding - OpenSession key.Binding ToggleHelp key.Binding Quit key.Binding ScrollUp key.Binding @@ -52,10 +51,6 @@ func newKeyMap() keyMap { key.WithKeys("esc"), key.WithHelp("Esc", "Focus input"), ), - OpenSession: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("Enter", "Open session"), - ), ToggleHelp: key.NewBinding( key.WithKeys("ctrl+q"), key.WithHelp("Ctrl+Q", "/help"), @@ -102,7 +97,7 @@ func (k keyMap) ShortHelp() []key.Binding { func (k keyMap) FullHelp() [][]key.Binding { return [][]key.Binding{ {k.Send, k.Newline, k.CancelAgent, k.NewSession}, - {k.OpenSession, k.FocusInput, k.NextPanel, k.PrevPanel}, + {k.FocusInput, k.NextPanel, k.PrevPanel}, {k.ToggleHelp, k.Quit, k.ScrollUp, k.ScrollDown}, {k.PageUp, k.PageDown, k.Top, k.Bottom}, } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 9dc0d6b7..846b886e 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -39,7 +39,7 @@ const ( pasteBurstThreshold = tuistate.PasteBurstThreshold ) -var panelOrder = []panel{panelSessions, panelTranscript, panelActivity, panelInput} +var panelOrder = []panel{panelTranscript, panelActivity, panelInput} func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd @@ -61,7 +61,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, a.deferredEventCmd) a.deferredEventCmd = nil } - _ = a.refreshSessions() a.syncActiveSessionTitle() if transcriptDirty { a.rebuildTranscript() @@ -99,7 +98,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if !a.state.IsAgentRunning { a.clearRunProgress() } - _ = a.refreshSessions() a.syncActiveSessionTitle() return a, tea.Batch(cmds...) case permissionResolutionFinishedMsg: @@ -141,11 +139,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.state.ExecutionError = typed.Err.Error() a.state.StatusText = typed.Err.Error() } - if err := a.refreshSessions(); err != nil { - a.state.ExecutionError = err.Error() - a.state.StatusText = err.Error() - a.appendInlineMessage(roleError, err.Error()) - } if err := a.refreshMessages(); err != nil && strings.TrimSpace(a.state.ActiveSessionID) != "" { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() @@ -280,22 +273,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } switch a.focus { - case panelSessions: - if key.Matches(typed, a.keys.OpenSession) && !a.sessions.SettingFilter() { - if err := a.activateSelectedSession(); err != nil { - a.state.StatusText = err.Error() - a.state.ExecutionError = err.Error() - a.appendActivity("system", "Failed to open session", err.Error(), true) - } - a.focus = panelInput - a.applyFocus() - return a, tea.Batch(cmds...) - } - var cmd tea.Cmd - a.sessions, cmd = a.sessions.Update(msg) - a.sessions.SetShowFilter(a.sessions.FilterState() != list.Unfiltered) - cmds = append(cmds, cmd) - return a, tea.Batch(cmds...) case panelTranscript: a.handleViewportKeys(&a.transcript, typed) return a, tea.Batch(cmds...) @@ -394,6 +371,15 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te cmds = append(cmds, cmd) } return a, tea.Batch(cmds...) + case slashCommandSession: + if err := a.refreshSessionPicker(); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to refresh sessions", err.Error(), true) + return a, tea.Batch(cmds...) + } + a.openPicker(pickerSession, statusChooseSession, &a.sessionPicker, a.state.ActiveSessionID) + return a, tea.Batch(cmds...) } if strings.HasPrefix(input, slashPrefix) { @@ -629,6 +615,18 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, nil } return a, runModelSelection(a.providerSvc, item.id) + case pickerSession: + item, ok := a.sessionPicker.SelectedItem().(sessionItem) + a.closePicker() + if !ok { + return a, nil + } + if err := a.activateSessionByID(item.Summary.ID); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to switch session", err.Error(), true) + } + return a, nil case pickerHelp: item, ok := a.helpPicker.SelectedItem().(selectionItem) a.closePicker() @@ -645,6 +643,8 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { a.providerPicker, cmd = a.providerPicker.Update(msg) case pickerModel: a.modelPicker, cmd = a.modelPicker.Update(msg) + case pickerSession: + a.sessionPicker, cmd = a.sessionPicker.Update(msg) case pickerHelp: a.helpPicker, cmd = a.helpPicker.Update(msg) case pickerFile: @@ -675,7 +675,7 @@ func (a App) updatePicker(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, cmd } -func (a *App) refreshSessions() error { +func (a *App) refreshSessionPicker() error { sessions, err := a.runtime.ListSessions(context.Background()) if err != nil { return err @@ -683,25 +683,25 @@ func (a *App) refreshSessions() error { a.state.Sessions = sessions - var selectedID string - if item, ok := a.sessions.SelectedItem().(sessionItem); ok { - selectedID = item.Summary.ID - } - items := make([]list.Item, 0, len(sessions)) - cursor := 0 + selectedIndex := 0 + hasSelection := false for i, summary := range sessions { items = append(items, sessionItem{Summary: summary, Active: summary.ID == a.state.ActiveSessionID}) - if summary.ID == selectedID || summary.ID == a.state.ActiveSessionID { - cursor = i + if summary.ID == a.state.ActiveSessionID { + selectedIndex = i + hasSelection = true } } - a.sessions.SetItems(items) + a.sessionPicker.SetItems(items) if len(items) > 0 { - a.sessions.Select(cursor) + if hasSelection { + a.sessionPicker.Select(selectedIndex) + } else { + a.sessionPicker.Select(0) + } } - return nil } @@ -726,7 +726,7 @@ func (a *App) refreshMessages() error { } func (a *App) activateSelectedSession() error { - item, ok := a.sessions.SelectedItem().(sessionItem) + item, ok := a.sessionPicker.SelectedItem().(sessionItem) if !ok { return nil } @@ -736,13 +736,22 @@ func (a *App) activateSelectedSession() error { a.state.ExecutionError = "" a.state.CurrentTool = "" - if err := a.refreshSessions(); err != nil { - return err - } - return a.refreshMessages() } +func (a *App) activateSessionByID(sessionID string) error { + for _, s := range a.state.Sessions { + if s.ID == sessionID { + a.state.ActiveSessionID = s.ID + a.state.ActiveSessionTitle = s.Title + a.state.ExecutionError = "" + a.state.CurrentTool = "" + return a.refreshMessages() + } + } + return fmt.Errorf("session not found: %s", sessionID) +} + func (a *App) syncActiveSessionTitle() { if strings.TrimSpace(a.state.ActiveSessionID) == "" { if strings.TrimSpace(a.state.ActiveSessionTitle) == "" { @@ -1576,15 +1585,16 @@ func (a *App) applyComponentLayout(rebuildTranscript bool) { a.activity.Height = 0 } - a.providerPicker.SetSize(max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), max(4, tuiutils.Clamp(lay.contentHeight-10, 6, 10))) - a.modelPicker.SetSize(max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), max(4, tuiutils.Clamp(lay.contentHeight-10, 6, 10))) - helpPickerMaxHeight := max(8, lay.contentHeight-6) + pickerLayout := a.buildPickerLayout(lay.contentWidth, lay.contentHeight) + a.providerPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) + a.modelPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) + a.sessionPicker.SetSize(pickerLayout.listWidth, pickerLayout.listHeight) helpPickerDesiredHeight := (len(a.helpPicker.Items()) * 3) + 1 a.helpPicker.SetSize( - max(24, tuiutils.Clamp(lay.contentWidth-14, 28, 52)), - max(6, tuiutils.Clamp(helpPickerDesiredHeight, 6, helpPickerMaxHeight)), + pickerLayout.listWidth, + tuiutils.Clamp(helpPickerDesiredHeight, pickerListMinHeight, pickerLayout.listHeight), ) - a.fileBrowser.SetHeight(max(6, tuiutils.Clamp(lay.contentHeight-8, 8, 16))) + a.fileBrowser.SetHeight(max(pickerListMinHeight, pickerLayout.listHeight)) if rebuildTranscript || prevTranscriptWidth != a.transcript.Width { a.rebuildTranscript() } else if a.transcript.AtBottom() || a.isBusy() { @@ -1738,6 +1748,15 @@ func (a *App) handleImmediateSlashCommand(input string) (bool, tea.Cmd) { return true, a.handleRememberCommand(rest) case slashCommandForget: return true, a.handleForgetCommand(rest) + case slashCommandSession: + if err := a.refreshSessionPicker(); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("system", "Failed to refresh sessions", err.Error(), true) + return true, nil + } + a.openPicker(pickerSession, statusChooseSession, &a.sessionPicker, a.state.ActiveSessionID) + return true, nil default: return false, nil } diff --git a/internal/tui/core/app/update_permission_test.go b/internal/tui/core/app/update_permission_test.go index 30b0469b..47c18475 100644 --- a/internal/tui/core/app/update_permission_test.go +++ b/internal/tui/core/app/update_permission_test.go @@ -78,12 +78,12 @@ func newPermissionTestApp(runtime agentruntime.Runtime) *App { runtime: runtime, }, appComponents: appComponents{ - keys: newKeyMap(), - spinner: spin, - sessions: sessionList, - input: input, - transcript: viewport.New(0, 0), - activity: viewport.New(0, 0), + keys: newKeyMap(), + spinner: spin, + sessionPicker: sessionList, + input: input, + transcript: viewport.New(0, 0), + activity: viewport.New(0, 0), }, appRuntimeState: appRuntimeState{ nowFn: time.Now, diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 389e25f2..0fad5737 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "time" tea "github.com/charmbracelet/bubbletea" @@ -56,10 +57,14 @@ func (s stubProviderService) SetCurrentModel(ctx context.Context, modelID string } type stubRuntime struct { - events chan agentruntime.RuntimeEvent - resolveCalls []agentruntime.PermissionResolutionInput - resolveErr error - cancelInvoked bool + events chan agentruntime.RuntimeEvent + resolveCalls []agentruntime.PermissionResolutionInput + resolveErr error + cancelInvoked bool + listSessions []agentsession.Summary + listSessionsErr error + loadSessions map[string]agentsession.Session + loadSessionErr error } func newStubRuntime() *stubRuntime { @@ -89,10 +94,21 @@ func (s *stubRuntime) Events() <-chan agentruntime.RuntimeEvent { } func (s *stubRuntime) ListSessions(ctx context.Context) ([]agentsession.Summary, error) { - return nil, nil + if s.listSessionsErr != nil { + return nil, s.listSessionsErr + } + return s.listSessions, nil } func (s *stubRuntime) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if s.loadSessionErr != nil { + return agentsession.Session{}, s.loadSessionErr + } + if s.loadSessions != nil { + if session, ok := s.loadSessions[id]; ok { + return session, nil + } + } return agentsession.NewWithWorkdir("draft", ""), nil } @@ -211,6 +227,26 @@ func TestAppUpdateBasic(t *testing.T) { } } +func TestRefreshSessionPickerSelectsActiveSession(t *testing.T) { + app, runtime := newTestApp(t) + now := time.Now() + runtime.listSessions = []agentsession.Summary{ + {ID: "session-1", Title: "Session One", UpdatedAt: now.Add(-time.Minute)}, + {ID: "session-2", Title: "Session Two", UpdatedAt: now}, + } + app.state.ActiveSessionID = "session-2" + + if err := app.refreshSessionPicker(); err != nil { + t.Fatalf("refreshSessionPicker() error = %v", err) + } + if len(app.sessionPicker.Items()) != 2 { + t.Fatalf("expected 2 session items, got %d", len(app.sessionPicker.Items())) + } + if got := app.sessionPicker.Index(); got != 1 { + t.Fatalf("expected active session index 1, got %d", got) + } +} + func TestParsePermissionShortcutFromKeyInput(t *testing.T) { if decision, ok := parsePermissionShortcut("y"); !ok || decision != approvalflow.DecisionAllowOnce { t.Fatalf("expected allow_once, got %v (ok=%v)", decision, ok) @@ -1097,9 +1133,9 @@ func TestShouldHandleTabAsInput(t *testing.T) { func TestFocusNextPrev(t *testing.T) { app, _ := newTestApp(t) - app.focus = panelSessions + app.focus = panelTranscript app.focusNext() - if app.focus == panelSessions { + if app.focus == panelTranscript { t.Fatalf("expected focus to move") } app.focusPrev() diff --git a/internal/tui/core/app/view.go b/internal/tui/core/app/view.go index 355dd223..42296611 100644 --- a/internal/tui/core/app/view.go +++ b/internal/tui/core/app/view.go @@ -19,6 +19,25 @@ type layout struct { const headerBarHeight = 2 +const ( + pickerPanelHorizontalInset = 8 + pickerPanelVerticalInset = 4 + pickerPanelMinWidth = 42 + pickerPanelMaxWidth = 72 + pickerPanelMinHeight = 14 + pickerPanelMaxHeight = 24 + pickerListMinWidth = 28 + pickerListMinHeight = 8 + pickerHeaderRows = 2 +) + +type pickerLayoutSpec struct { + panelWidth int + panelHeight int + listWidth int + listHeight int +} + func (a App) View() string { docWidth := max(0, a.width-a.styles.doc.GetHorizontalFrameSize()) docHeight := max(0, a.height-a.styles.doc.GetVerticalFrameSize()) @@ -76,12 +95,13 @@ func (a App) waterfallMetrics(width int, height int) (int, int, int, int) { func (a App) renderWaterfall(width int, height int) string { if a.state.ActivePicker != pickerNone { + pickerLayout := a.buildPickerLayout(width, height) return lipgloss.Place( width, height, lipgloss.Center, lipgloss.Center, - a.renderPicker(tuiutils.Clamp(width-10, 36, 56), tuiutils.Clamp(height-6, 10, 14)), + a.renderPicker(pickerLayout.panelWidth, pickerLayout.panelHeight), ) } @@ -108,6 +128,23 @@ func (a App) renderWaterfall(width int, height int) string { return lipgloss.Place(width, height, lipgloss.Left, lipgloss.Top, content) } +func (a App) buildPickerLayout(contentWidth int, contentHeight int) pickerLayoutSpec { + panelWidth := tuiutils.Clamp(contentWidth-pickerPanelHorizontalInset, pickerPanelMinWidth, pickerPanelMaxWidth) + panelHeight := tuiutils.Clamp(contentHeight-pickerPanelVerticalInset, pickerPanelMinHeight, pickerPanelMaxHeight) + + frameWidth := a.styles.panelFocused.GetHorizontalFrameSize() + frameHeight := a.styles.panelFocused.GetVerticalFrameSize() + listWidth := max(pickerListMinWidth, panelWidth-frameWidth) + listHeight := max(pickerListMinHeight, panelHeight-frameHeight-pickerHeaderRows) + + return pickerLayoutSpec{ + panelWidth: panelWidth, + panelHeight: panelHeight, + listWidth: listWidth, + listHeight: listHeight, + } +} + func (a App) renderPicker(width int, height int) string { frameHeight := a.styles.panelFocused.GetVerticalFrameSize() title := modelPickerTitle @@ -118,6 +155,11 @@ func (a App) renderPicker(width int, height int) string { subtitle = providerPickerSubtitle body = a.providerPicker.View() } + if a.state.ActivePicker == pickerSession { + title = sessionPickerTitle + subtitle = sessionPickerSubtitle + body = a.sessionPicker.View() + } if a.state.ActivePicker == pickerFile { title = filePickerTitle subtitle = filePickerSubtitle diff --git a/internal/tui/core/app/view_test.go b/internal/tui/core/app/view_test.go index edd44261..8ff9eeca 100644 --- a/internal/tui/core/app/view_test.go +++ b/internal/tui/core/app/view_test.go @@ -3,11 +3,13 @@ package tui import ( "strings" "testing" + "time" "github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/lipgloss" providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" tuistate "neo-code/internal/tui/state" ) @@ -25,6 +27,44 @@ func TestRenderPickerHelpMode(t *testing.T) { } } +func TestRenderPickerSessionMode(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActivePicker = pickerSession + app.sessionPicker.SetItems([]list.Item{ + sessionItem{Summary: agentsession.Summary{ + ID: "session-1", + Title: "Session One", + UpdatedAt: time.Now(), + }}, + }) + + view := app.renderPicker(48, 14) + if !strings.Contains(view, sessionPickerTitle) { + t.Fatalf("expected session picker title in view") + } + if !strings.Contains(view, sessionPickerSubtitle) { + t.Fatalf("expected session picker subtitle in view") + } + if !strings.Contains(view, "Session One") { + t.Fatalf("expected session item in picker body") + } +} + +func TestBuildPickerLayoutExpandsPopupSpace(t *testing.T) { + app, _ := newTestApp(t) + + got := app.buildPickerLayout(100, 30) + if got.panelHeight < 20 { + t.Fatalf("expected expanded picker panel height, got %d", got.panelHeight) + } + if got.listHeight < pickerListMinHeight { + t.Fatalf("expected picker list height >= %d, got %d", pickerListMinHeight, got.listHeight) + } + if got.listWidth < pickerListMinWidth { + t.Fatalf("expected picker list width >= %d, got %d", pickerListMinWidth, got.listWidth) + } +} + func TestRenderWaterfallUsesDynamicTranscriptHeight(t *testing.T) { app, _ := newTestApp(t) app.state.ActivePicker = pickerNone diff --git a/internal/tui/core/utils/view_helpers.go b/internal/tui/core/utils/view_helpers.go index 4f2c691f..93688565 100644 --- a/internal/tui/core/utils/view_helpers.go +++ b/internal/tui/core/utils/view_helpers.go @@ -13,6 +13,8 @@ func PickerLabelFromMode(mode tuistate.PickerMode) string { return "provider" case tuistate.PickerModel: return "model" + case tuistate.PickerSession: + return "session" case tuistate.PickerFile: return "file" case tuistate.PickerHelp: diff --git a/internal/tui/core/utils/view_helpers_test.go b/internal/tui/core/utils/view_helpers_test.go index 1c8d191b..a5bdf75a 100644 --- a/internal/tui/core/utils/view_helpers_test.go +++ b/internal/tui/core/utils/view_helpers_test.go @@ -13,6 +13,7 @@ func TestPickerLabelFromMode(t *testing.T) { }{ {tuistate.PickerProvider, "provider"}, {tuistate.PickerModel, "model"}, + {tuistate.PickerSession, "session"}, {tuistate.PickerFile, "file"}, {tuistate.PickerHelp, "help"}, {tuistate.PickerMode(999), "none"}, diff --git a/internal/tui/state/state_test.go b/internal/tui/state/state_test.go index 599ddfe1..e4038614 100644 --- a/internal/tui/state/state_test.go +++ b/internal/tui/state/state_test.go @@ -6,12 +6,13 @@ func TestPanelAndPickerConstants(t *testing.T) { if PanelSessions != 0 || PanelTranscript != 1 || PanelActivity != 2 || PanelInput != 3 { t.Fatalf("unexpected panel constants: %d %d %d %d", PanelSessions, PanelTranscript, PanelActivity, PanelInput) } - if PickerNone != 0 || PickerProvider != 1 || PickerModel != 2 || PickerFile != 3 || PickerHelp != 4 { + if PickerNone != 0 || PickerProvider != 1 || PickerModel != 2 || PickerSession != 3 || PickerFile != 4 || PickerHelp != 5 { t.Fatalf( - "unexpected picker constants: %d %d %d %d %d", + "unexpected picker constants: %d %d %d %d %d %d", PickerNone, PickerProvider, PickerModel, + PickerSession, PickerFile, PickerHelp, ) diff --git a/internal/tui/state/ui_state.go b/internal/tui/state/ui_state.go index 706b99dc..aea6d120 100644 --- a/internal/tui/state/ui_state.go +++ b/internal/tui/state/ui_state.go @@ -19,6 +19,7 @@ const ( PickerNone PickerMode = iota PickerProvider PickerModel + PickerSession PickerFile PickerHelp ) From 56a32968e4a4d21daff9b9ccb17d316c1408c42d Mon Sep 17 00:00:00 2001 From: creatang Date: Wed, 15 Apr 2026 17:29:10 +0800 Subject: [PATCH 26/33] fix(ci): avoid Linux clipboard build dependency --- internal/tui/infra/clipboard_common.go | 24 +++++++++++++++++ internal/tui/infra/clipboard_fallback.go | 23 ++++++++++++++++ .../{clipboard.go => clipboard_native.go} | 26 +++---------------- 3 files changed, 50 insertions(+), 23 deletions(-) create mode 100644 internal/tui/infra/clipboard_common.go create mode 100644 internal/tui/infra/clipboard_fallback.go rename internal/tui/infra/{clipboard.go => clipboard_native.go} (67%) diff --git a/internal/tui/infra/clipboard_common.go b/internal/tui/infra/clipboard_common.go new file mode 100644 index 00000000..1bb2eb52 --- /dev/null +++ b/internal/tui/infra/clipboard_common.go @@ -0,0 +1,24 @@ +package infra + +import ( + "os" + "path/filepath" +) + +func SaveImageToTempFile(data []byte, prefix string) (string, error) { + tmpDir := os.TempDir() + tmpFile := filepath.Join(tmpDir, prefix+"_image.png") + + f, err := os.Create(tmpFile) + if err != nil { + return "", err + } + defer f.Close() + + _, err = f.Write(data) + if err != nil { + return "", err + } + + return tmpFile, nil +} diff --git a/internal/tui/infra/clipboard_fallback.go b/internal/tui/infra/clipboard_fallback.go new file mode 100644 index 00000000..ab19067f --- /dev/null +++ b/internal/tui/infra/clipboard_fallback.go @@ -0,0 +1,23 @@ +//go:build !windows && !darwin + +package infra + +import ( + "errors" + + clipboardtext "github.com/atotto/clipboard" +) + +var errClipboardImageUnsupported = errors.New("clipboard image is not supported on this platform") + +func CopyText(text string) error { + return clipboardtext.WriteAll(text) +} + +func ReadClipboardText() (string, error) { + return clipboardtext.ReadAll() +} + +func ReadClipboardImage() ([]byte, error) { + return nil, errClipboardImageUnsupported +} diff --git a/internal/tui/infra/clipboard.go b/internal/tui/infra/clipboard_native.go similarity index 67% rename from internal/tui/infra/clipboard.go rename to internal/tui/infra/clipboard_native.go index 136f9c81..604098b2 100644 --- a/internal/tui/infra/clipboard.go +++ b/internal/tui/infra/clipboard_native.go @@ -1,10 +1,8 @@ -package infra +//go:build windows || darwin -import ( - "os" +package infra - "golang.design/x/clipboard" -) +import "golang.design/x/clipboard" var clipboardInitialized bool @@ -46,21 +44,3 @@ func ReadClipboardImage() ([]byte, error) { } return data, nil } - -func SaveImageToTempFile(data []byte, prefix string) (string, error) { - tmpDir := os.TempDir() - tmpFile := tmpDir + "/" + prefix + "_" + "image.png" - - f, err := os.Create(tmpFile) - if err != nil { - return "", err - } - defer f.Close() - - _, err = f.Write(data) - if err != nil { - return "", err - } - - return tmpFile, nil -} From 372ab70a528c8570ca48ea4085f5fa31a6e418ca Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 09:41:48 +0000 Subject: [PATCH 27/33] fix(tui): resolve multimodal review issues and add tests - fix image attachment send path metadata composition and cache invalidation - secure clipboard temp file creation with unique names - optimize image mime detection using header reads - expose Ctrl+V in full help and add targeted tests Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: creatang <165447160+creatang@users.noreply.github.com> --- internal/tui/core/app/input_features.go | 32 +++ internal/tui/core/app/input_features_test.go | 223 +++++++++++++++++++ internal/tui/core/app/keymap.go | 2 +- internal/tui/core/app/keymap_test.go | 16 ++ internal/tui/core/app/update.go | 18 +- internal/tui/core/app/update_test.go | 2 + internal/tui/infra/clipboard_common.go | 36 ++- internal/tui/infra/image.go | 19 +- internal/tui/infra/infra_test.go | 35 +++ 9 files changed, 370 insertions(+), 13 deletions(-) create mode 100644 internal/tui/core/app/input_features_test.go create mode 100644 internal/tui/core/app/keymap_test.go diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index 89d7a96e..07b8c16c 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "path/filepath" + "strconv" "strings" tea "github.com/charmbracelet/bubbletea" @@ -421,3 +422,34 @@ func (a *App) checkModelImageSupport() bool { func (a *App) canSendImageInput() bool { return a.checkModelImageSupport() } + +// invalidateModelCapabilityCache 在 provider 或 model 变化时清理图片能力缓存,避免复用旧结果。 +func (a *App) invalidateModelCapabilityCache() { + a.currentModelCapabilities = modelCapabilityState{} +} + +// composeMessageWithImageAttachments 在发送前把附件元信息拼接到文本,避免附件在运行链路中丢失。 +func (a *App) composeMessageWithImageAttachments(content string) string { + trimmed := strings.TrimSpace(content) + if len(a.pendingImageAttachments) == 0 { + return trimmed + } + + var builder strings.Builder + builder.WriteString(trimmed) + builder.WriteString("\n\n[Attached images]\n") + for index, attachment := range a.pendingImageAttachments { + builder.WriteString(strconv.Itoa(index + 1)) + builder.WriteString(". ") + builder.WriteString(attachment.Name) + builder.WriteString(" | mime=") + builder.WriteString(attachment.MimeType) + builder.WriteString(" | bytes=") + builder.WriteString(strconv.FormatInt(attachment.Size, 10)) + builder.WriteString(" | path=") + builder.WriteString(attachment.Path) + builder.WriteString("\n") + } + builder.WriteString("Treat the list above as user-provided image attachments.") + return builder.String() +} diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go new file mode 100644 index 00000000..4577dcfa --- /dev/null +++ b/internal/tui/core/app/input_features_test.go @@ -0,0 +1,223 @@ +package tui + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + + "neo-code/internal/config" + configstate "neo-code/internal/config/state" + providertypes "neo-code/internal/provider/types" +) + +func TestTokenAndReferenceParsing(t *testing.T) { + start, end, token, ok := tokenRange(" @file/path", tokenSelectorFirst) + if !ok || start != 2 || end != len(" @file/path") || token != "@file/path" { + t.Fatalf("unexpected first token parse: start=%d end=%d token=%q ok=%v", start, end, token, ok) + } + + start, end, token, ok = tokenRange("hello @image:/tmp/a.png", tokenSelectorLast) + if !ok || token != "@image:/tmp/a.png" || start <= 0 || end <= start { + t.Fatalf("unexpected last token parse: start=%d end=%d token=%q ok=%v", start, end, token, ok) + } + + _, _, _, ok = currentReferenceToken("hello world") + if ok { + t.Fatalf("expected non-reference token to be rejected") + } + + _, _, token, ok = currentReferenceToken("x @a/b.txt") + if !ok || token != "@a/b.txt" { + t.Fatalf("expected file reference token, got token=%q ok=%v", token, ok) + } + + if !isImageReferenceInput("@image:/tmp/p.png") { + t.Fatalf("expected image reference input recognized") + } + if got := extractImageReference("@image:/tmp/p.png"); got != "/tmp/p.png" { + t.Fatalf("unexpected image reference extraction: %q", got) + } +} + +func TestApplyFileReference(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + inside := filepath.Join(root, "docs", "a.md") + if err := os.MkdirAll(filepath.Dir(inside), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(inside, []byte("ok"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + + if err := app.applyFileReference(inside); err != nil { + t.Fatalf("applyFileReference() error = %v", err) + } + if got := app.input.Value(); !strings.Contains(got, "@docs/a.md") { + t.Fatalf("expected relative file reference, got %q", got) + } + + app.input.SetValue("prefix @old/ref") + if err := app.applyFileReference(inside); err != nil { + t.Fatalf("applyFileReference replace() error = %v", err) + } + if got := app.input.Value(); strings.Contains(got, "@old/ref") || !strings.Contains(got, "@docs/a.md") { + t.Fatalf("expected active token replaced, got %q", got) + } +} + +func TestImageAttachmentLifecycle(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + imagePath := filepath.Join(root, "test.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + if err := app.addImageAttachment(imagePath); err != nil { + t.Fatalf("addImageAttachment() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one attachment, got %d", app.getImageAttachmentCount()) + } + if !app.hasImageAttachments() { + t.Fatalf("expected hasImageAttachments() true") + } + if _, err := app.loadImageAttachmentData(0); err != nil { + t.Fatalf("loadImageAttachmentData() error = %v", err) + } + + if err := app.removeImageAttachment(0); err != nil { + t.Fatalf("removeImageAttachment() error = %v", err) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no attachments after remove") + } + if err := app.removeImageAttachment(0); err == nil { + t.Fatalf("expected out-of-range remove error") + } +} + +func TestAddImageAttachmentLimit(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + + for i := 0; i < maxImageAttachments; i++ { + path := filepath.Join(root, fmt.Sprintf("%d.png", i)) + if err := os.WriteFile(path, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.addImageAttachment(path); err != nil { + t.Fatalf("addImageAttachment(%d) error = %v", i, err) + } + } + + over := filepath.Join(root, "over.png") + if err := os.WriteFile(over, []byte("png"), 0o644); err != nil { + t.Fatalf("write over image: %v", err) + } + if err := app.addImageAttachment(over); err == nil { + t.Fatalf("expected attachment limit error") + } +} + +func TestCanSendImageInputCacheInvalidationOnModelChange(t *testing.T) { + app, _ := newTestApp(t) + providerID := app.state.CurrentProvider + + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, + models: []providertypes.ModelDescriptor{{ + ID: "model-a", + Name: "model-a", + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + app.state.CurrentModel = "model-a" + if !app.canSendImageInput() { + t.Fatalf("expected model-a to support images") + } + + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, + models: []providertypes.ModelDescriptor{{ + ID: "model-b", + Name: "model-b", + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }, + }}, + } + app.syncConfigState(config.Config{SelectedProvider: providerID, CurrentModel: "model-b", Workdir: app.state.CurrentWorkdir}) + if app.canSendImageInput() { + t.Fatalf("expected model-b to be unsupported after cache invalidation") + } +} + +func TestComposeMessageWithImageAttachments(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{{ + Path: "/tmp/a.png", + Name: "a.png", + MimeType: "image/png", + Size: 12, + }} + + got := app.composeMessageWithImageAttachments("hello") + if !strings.Contains(got, "[Attached images]") || !strings.Contains(got, "a.png") || !strings.Contains(got, "path=/tmp/a.png") { + t.Fatalf("unexpected composed message: %q", got) + } +} + +func TestUpdateSendWithImageAttachmentsComposesRuntimeInput(t *testing.T) { + app, runtime := newTestApp(t) + root := t.TempDir() + imagePath := filepath.Join(root, "queued.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.addImageAttachment(imagePath); err != nil { + t.Fatalf("addImageAttachment() error = %v", err) + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + + app.input.SetValue("hello") + app.state.InputText = "hello" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd == nil { + t.Fatalf("expected run command") + } + app = model.(App) + if app.hasImageAttachments() { + t.Fatalf("expected attachments cleared after send") + } + if len(app.activeMessages) == 0 || !strings.Contains(app.activeMessages[len(app.activeMessages)-1].Content, "[Attached images]") { + t.Fatalf("expected composed user message in transcript") + } + + msg := cmd() + _, ok := msg.(runFinishedMsg) + if !ok { + t.Fatalf("expected runFinishedMsg, got %T", msg) + } + if len(runtime.runInputs) != 1 || !strings.Contains(runtime.runInputs[0].Content, "[Attached images]") { + t.Fatalf("expected composed runtime input, got %+v", runtime.runInputs) + } +} diff --git a/internal/tui/core/app/keymap.go b/internal/tui/core/app/keymap.go index 1a3e219d..3d604fb3 100644 --- a/internal/tui/core/app/keymap.go +++ b/internal/tui/core/app/keymap.go @@ -98,7 +98,7 @@ func (k keyMap) FullHelp() [][]key.Binding { return [][]key.Binding{ {k.Send, k.Newline, k.CancelAgent, k.NewSession}, {k.FocusInput, k.NextPanel, k.PrevPanel}, - {k.ToggleHelp, k.Quit, k.ScrollUp, k.ScrollDown}, + {k.ToggleHelp, k.Quit, k.PasteImage, k.ScrollUp}, {k.PageUp, k.PageDown, k.Top, k.Bottom}, } } diff --git a/internal/tui/core/app/keymap_test.go b/internal/tui/core/app/keymap_test.go new file mode 100644 index 00000000..c08126cb --- /dev/null +++ b/internal/tui/core/app/keymap_test.go @@ -0,0 +1,16 @@ +package tui + +import "testing" + +func TestFullHelpIncludesPasteImage(t *testing.T) { + keys := newKeyMap() + help := keys.FullHelp() + for _, row := range help { + for _, binding := range row { + if binding.Help().Key == keys.PasteImage.Help().Key { + return + } + } + } + t.Fatalf("expected full help to include paste image binding") +} diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 846b886e..63756cc5 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -421,15 +421,17 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te a.state.CurrentTool = "" if a.hasImageAttachments() { - a.appendActivity("multimodal", "Sending message with images", fmt.Sprintf("%d image(s) attached", len(a.pendingImageAttachments)), false) + a.appendActivity("multimodal", "Sending message with image metadata", fmt.Sprintf("%d image(s) attached", len(a.pendingImageAttachments)), false) } - a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Content: input}) + composedInput := a.composeMessageWithImageAttachments(input) + a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Content: composedInput}) a.rebuildTranscript() runID := fmt.Sprintf("run-%d", a.now().UnixNano()) a.state.ActiveRunID = runID requestedWorkdir := tuiutils.RequestedWorkdirForRun(a.state.CurrentWorkdir) - cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, input)) + cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, composedInput)) + a.clearImageAttachments() return a, tea.Batch(cmds...) } } @@ -769,6 +771,10 @@ func (a *App) syncActiveSessionTitle() { } func (a *App) syncConfigState(cfg config.Config) { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(cfg.SelectedProvider)) || + !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(cfg.CurrentModel)) { + a.invalidateModelCapabilityCache() + } a.state.CurrentProvider = cfg.SelectedProvider a.state.CurrentModel = cfg.CurrentModel if strings.TrimSpace(a.state.CurrentWorkdir) == "" { @@ -951,9 +957,15 @@ func runtimeEventRunContextHandler(a *App, event agentruntime.RuntimeEvent) bool a.state.ActiveRunID = mapped.RunID } if strings.TrimSpace(mapped.Provider) != "" { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(mapped.Provider)) { + a.invalidateModelCapabilityCache() + } a.state.CurrentProvider = mapped.Provider } if strings.TrimSpace(mapped.Model) != "" { + if !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(mapped.Model)) { + a.invalidateModelCapabilityCache() + } a.state.CurrentModel = mapped.Model } if strings.TrimSpace(mapped.Workdir) != "" { diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 0fad5737..feae8187 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -58,6 +58,7 @@ func (s stubProviderService) SetCurrentModel(ctx context.Context, modelID string type stubRuntime struct { events chan agentruntime.RuntimeEvent + runInputs []agentruntime.UserInput resolveCalls []agentruntime.PermissionResolutionInput resolveErr error cancelInvoked bool @@ -72,6 +73,7 @@ func newStubRuntime() *stubRuntime { } func (s *stubRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { + s.runInputs = append(s.runInputs, input) return nil } diff --git a/internal/tui/infra/clipboard_common.go b/internal/tui/infra/clipboard_common.go index 1bb2eb52..3a2d4615 100644 --- a/internal/tui/infra/clipboard_common.go +++ b/internal/tui/infra/clipboard_common.go @@ -2,23 +2,43 @@ package infra import ( "os" - "path/filepath" ) func SaveImageToTempFile(data []byte, prefix string) (string, error) { - tmpDir := os.TempDir() - tmpFile := filepath.Join(tmpDir, prefix+"_image.png") - - f, err := os.Create(tmpFile) + pattern := "image-*.png" + if cleaned := sanitizeTempPrefix(prefix); cleaned != "" { + pattern = cleaned + "-*.png" + } + f, err := os.CreateTemp("", pattern) if err != nil { return "", err } - defer f.Close() + tmpFile := f.Name() - _, err = f.Write(data) - if err != nil { + if _, err = f.Write(data); err != nil { + _ = f.Close() + _ = os.Remove(tmpFile) + return "", err + } + if err = f.Close(); err != nil { + _ = os.Remove(tmpFile) return "", err } return tmpFile, nil } + +// sanitizeTempPrefix 过滤临时文件名前缀中的不安全字符,避免路径注入与非法命名。 +func sanitizeTempPrefix(prefix string) string { + if prefix == "" { + return "" + } + + buf := make([]rune, 0, len(prefix)) + for _, r := range prefix { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_' { + buf = append(buf, r) + } + } + return string(buf) +} diff --git a/internal/tui/infra/image.go b/internal/tui/infra/image.go index 6b352ba8..f7eb6504 100644 --- a/internal/tui/infra/image.go +++ b/internal/tui/infra/image.go @@ -1,6 +1,7 @@ package infra import ( + "io" "io/fs" "mime" "os" @@ -37,7 +38,7 @@ func DetectImageMimeType(path string) string { return detected } - data, err := os.ReadFile(path) + data, err := readMagicHeader(path, 512) if err != nil { return "" } @@ -70,3 +71,19 @@ func IsSupportedImageFormat(path string) bool { func ReadImageFile(path string) ([]byte, error) { return os.ReadFile(path) } + +// readMagicHeader 仅读取文件头部用于类型探测,避免把整文件加载到内存。 +func readMagicHeader(path string, maxBytes int) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + buffer := make([]byte, maxBytes) + n, err := io.ReadFull(file, buffer) + if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { + return nil, err + } + return buffer[:n], nil +} diff --git a/internal/tui/infra/infra_test.go b/internal/tui/infra/infra_test.go index 8b32e306..25772a39 100644 --- a/internal/tui/infra/infra_test.go +++ b/internal/tui/infra/infra_test.go @@ -297,6 +297,41 @@ func TestDefaultWorkspaceCommandExecutor(t *testing.T) { } } +func TestSaveImageToTempFileCreatesUniquePaths(t *testing.T) { + first, err := SaveImageToTempFile([]byte("first"), "paste") + if err != nil { + t.Fatalf("SaveImageToTempFile(first) error = %v", err) + } + defer os.Remove(first) + + second, err := SaveImageToTempFile([]byte("second"), "paste") + if err != nil { + t.Fatalf("SaveImageToTempFile(second) error = %v", err) + } + defer os.Remove(second) + + if first == second { + t.Fatalf("expected unique temp file paths, got %q", first) + } + if !strings.Contains(filepath.Base(first), "paste-") || !strings.Contains(filepath.Base(second), "paste-") { + t.Fatalf("expected sanitized prefix in temp names, got %q and %q", first, second) + } +} + +func TestDetectImageMimeTypeByMagicHeader(t *testing.T) { + root := t.TempDir() + path := filepath.Join(root, "blob.bin") + pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0d, 0x0a, 0x1a, 0x0a} + payload := append(pngHeader, []byte("payload")...) + if err := os.WriteFile(path, payload, 0o644); err != nil { + t.Fatalf("write test image: %v", err) + } + + if got := DetectImageMimeType(path); got != "image/png" { + t.Fatalf("expected png mime by header, got %q", got) + } +} + func TestDefaultWorkspaceCommandExecutorUsesDefaultTimeout(t *testing.T) { workdir := t.TempDir() shellName, successCmd, _, _, _ := workspaceExecutorCommands() From d5e5b7cd0df5295b5a411933dc2b298a0a142546 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 10:59:12 +0000 Subject: [PATCH 28/33] test(tui): improve multimodal and session-switch coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: creatang <165447160+creatang@users.noreply.github.com> --- internal/tui/core/app/input_features.go | 11 +- internal/tui/core/app/input_features_test.go | 240 +++++++++++++++++++ internal/tui/core/app/update_test.go | 184 ++++++++++++++ internal/tui/core/app/view_test.go | 65 +++++ internal/tui/infra/clipboard_common.go | 9 +- internal/tui/infra/infra_test.go | 165 +++++++++++++ 6 files changed, 663 insertions(+), 11 deletions(-) diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index 07b8c16c..8c56f9fe 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -36,6 +36,9 @@ const ( ) var workspaceCommandExecutor = defaultWorkspaceCommandExecutor +var readClipboardImage = tuiinfra.ReadClipboardImage +var saveClipboardImageToTempFile = tuiinfra.SaveImageToTempFile +var detectImageMimeType = tuiinfra.DetectImageMimeType func isWorkspaceCommandInput(input string) bool { return strings.HasPrefix(strings.TrimSpace(input), workspaceCommandPrefix) @@ -296,7 +299,7 @@ func (a *App) addImageAttachment(path string) error { return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) } - mimeType := tuiinfra.DetectImageMimeType(absPath) + mimeType := detectImageMimeType(absPath) if mimeType == "" { return fmt.Errorf("unsupported image format") } @@ -359,7 +362,7 @@ func (a *App) addImageFromClipboard() error { return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) } - data, err := tuiinfra.ReadClipboardImage() + data, err := readClipboardImage() if err != nil { return fmt.Errorf("failed to read clipboard image: %w", err) } @@ -372,12 +375,12 @@ func (a *App) addImageFromClipboard() error { return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) } - tmpPath, err := tuiinfra.SaveImageToTempFile(data, "paste") + tmpPath, err := saveClipboardImageToTempFile(data, "paste") if err != nil { return fmt.Errorf("failed to save clipboard image: %w", err) } - mimeType := tuiinfra.DetectImageMimeType(tmpPath) + mimeType := detectImageMimeType(tmpPath) if mimeType == "" { return fmt.Errorf("unsupported image format from clipboard") } diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index 4577dcfa..db32b19a 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -1,6 +1,8 @@ package tui import ( + "context" + "errors" "fmt" "os" "path/filepath" @@ -14,6 +16,15 @@ import ( providertypes "neo-code/internal/provider/types" ) +type snapshotErrProviderService struct { + stubProviderService + err error +} + +func (s snapshotErrProviderService) ListModelsSnapshot(ctx context.Context) ([]providertypes.ModelDescriptor, error) { + return nil, s.err +} + func TestTokenAndReferenceParsing(t *testing.T) { start, end, token, ok := tokenRange(" @file/path", tokenSelectorFirst) if !ok || start != 2 || end != len(" @file/path") || token != "@file/path" { @@ -71,6 +82,34 @@ func TestApplyFileReference(t *testing.T) { } } +func TestApplyFileReferenceBranches(t *testing.T) { + app, _ := newTestApp(t) + if err := app.applyFileReference(" "); err == nil { + t.Fatalf("expected empty file path error") + } + + root := t.TempDir() + app.state.CurrentWorkdir = filepath.Join(root, "workdir") + inside := filepath.Join(app.state.CurrentWorkdir, "a.txt") + outside := filepath.Join(root, "outside.txt") + if err := os.MkdirAll(filepath.Dir(inside), 0o755); err != nil { + t.Fatalf("mkdir inside: %v", err) + } + if err := os.WriteFile(inside, []byte("a"), 0o644); err != nil { + t.Fatalf("write inside: %v", err) + } + if err := os.WriteFile(outside, []byte("b"), 0o644); err != nil { + t.Fatalf("write outside: %v", err) + } + + if err := app.applyFileReference(outside); err != nil { + t.Fatalf("apply outside reference error: %v", err) + } + if !strings.Contains(app.input.Value(), "@") { + t.Fatalf("expected file reference token to be inserted") + } +} + func TestImageAttachmentLifecycle(t *testing.T) { app, _ := newTestApp(t) root := t.TempDir() @@ -176,6 +215,207 @@ func TestComposeMessageWithImageAttachments(t *testing.T) { } } +func TestComposeMessageWithImageAttachmentsNoAttachments(t *testing.T) { + app, _ := newTestApp(t) + got := app.composeMessageWithImageAttachments(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed content without attachment block, got %q", got) + } +} + +func TestApplyImageReference(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + imagePath := filepath.Join(root, "ok.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + if err := app.applyImageReference("@image:" + imagePath); err != nil { + t.Fatalf("applyImageReference() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one attachment after applyImageReference") + } + if err := app.applyImageReference("not-an-image-reference"); err == nil { + t.Fatalf("expected invalid image reference error") + } +} + +func TestGetAndClearImageAttachments(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{ + {Name: "a.png", Path: "/tmp/a.png", MimeType: "image/png", Size: 1}, + } + if len(app.getImageAttachments()) != 1 { + t.Fatalf("expected one attachment from getter") + } + app.clearImageAttachments() + if len(app.getImageAttachments()) != 0 { + t.Fatalf("expected no attachments after clear") + } +} + +func TestLoadImageAttachmentDataInvalidIndex(t *testing.T) { + app, _ := newTestApp(t) + if _, err := app.loadImageAttachmentData(0); err == nil { + t.Fatalf("expected invalid attachment index error") + } +} + +func TestAddImageFromClipboardUnsupported(t *testing.T) { + app, _ := newTestApp(t) + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected unsupported clipboard image error") + } +} + +func TestAddImageFromClipboardSuccess(t *testing.T) { + app, _ := newTestApp(t) + originalRead := readClipboardImage + originalSave := saveClipboardImageToTempFile + originalDetect := detectImageMimeType + readClipboardImage = func() ([]byte, error) { + return []byte("image-bytes"), nil + } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + path := filepath.Join(t.TempDir(), "clipboard.png") + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("write temp clipboard image: %v", err) + } + return path, nil + } + detectImageMimeType = func(path string) string { return "image/png" } + defer func() { + readClipboardImage = originalRead + saveClipboardImageToTempFile = originalSave + detectImageMimeType = originalDetect + }() + + if err := app.addImageFromClipboard(); err != nil { + t.Fatalf("addImageFromClipboard() error = %v", err) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one clipboard image attachment") + } +} + +func TestAddImageFromClipboardBranches(t *testing.T) { + app, _ := newTestApp(t) + originalRead := readClipboardImage + originalSave := saveClipboardImageToTempFile + originalDetect := detectImageMimeType + defer func() { + readClipboardImage = originalRead + saveClipboardImageToTempFile = originalSave + detectImageMimeType = originalDetect + }() + + readClipboardImage = func() ([]byte, error) { return nil, nil } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected no image in clipboard error") + } + + readClipboardImage = func() ([]byte, error) { return make([]byte, imageMaxSizeBytes+1), nil } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected image size limit error") + } + + readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + return filepath.Join(t.TempDir(), "clipboard.bin"), nil + } + detectImageMimeType = func(path string) string { return "" } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected unsupported image format error") + } + + readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } + saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { + return "", errors.New("save failed") + } + if err := app.addImageFromClipboard(); err == nil { + t.Fatalf("expected save failure error") + } +} + +func TestCheckModelImageSupportErrorAndModelNotFound(t *testing.T) { + app, _ := newTestApp(t) + app.providerSvc = snapshotErrProviderService{ + stubProviderService: stubProviderService{}, + err: errors.New("boom"), + } + if app.checkModelImageSupport() { + t.Fatalf("expected false when provider snapshot fails") + } + if !app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be marked checked after failure") + } + + app.currentModelCapabilities = modelCapabilityState{} + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: "other-model", + }}, + } + if app.checkModelImageSupport() { + t.Fatalf("expected false when current model is missing from snapshot") + } +} + +func TestExecuteWorkspaceCommand(t *testing.T) { + app, _ := newTestApp(t) + original := workspaceCommandExecutor + workspaceCommandExecutor = func(ctx context.Context, cfg config.Config, workdir string, command string) (string, error) { + if command != "echo hi" { + t.Fatalf("unexpected command: %q", command) + } + return "ok", nil + } + defer func() { workspaceCommandExecutor = original }() + + command, output, err := executeWorkspaceCommand(context.Background(), app.configManager, app.state.CurrentWorkdir, "& echo hi") + if err != nil { + t.Fatalf("executeWorkspaceCommand() error = %v", err) + } + if command != "echo hi" || output != "ok" { + t.Fatalf("unexpected execute result command=%q output=%q", command, output) + } + + if _, _, err := executeWorkspaceCommand(context.Background(), app.configManager, app.state.CurrentWorkdir, "& "); err == nil { + t.Fatalf("expected invalid workspace command error") + } +} + +func TestDefaultWorkspaceCommandExecutor(t *testing.T) { + cfg := config.Config{Workdir: t.TempDir(), Shell: "bash", ToolTimeoutSec: 1} + if _, err := defaultWorkspaceCommandExecutor(context.Background(), cfg, cfg.Workdir, ""); err == nil { + t.Fatalf("expected empty command to fail") + } +} + +func TestRunWorkspaceCommandCmd(t *testing.T) { + app, _ := newTestApp(t) + original := workspaceCommandExecutor + workspaceCommandExecutor = func(ctx context.Context, cfg config.Config, workdir string, command string) (string, error) { + return "done", nil + } + defer func() { workspaceCommandExecutor = original }() + + cmd := runWorkspaceCommand(app.configManager, app.state.CurrentWorkdir, "& echo hi") + if cmd == nil { + t.Fatalf("expected workspace command cmd") + } + msg := cmd() + result, ok := msg.(workspaceCommandResultMsg) + if !ok { + t.Fatalf("expected workspaceCommandResultMsg, got %T", msg) + } + if result.Command != "echo hi" || result.Output != "done" || result.Err != nil { + t.Fatalf("unexpected workspace result: %+v", result) + } +} + func TestUpdateSendWithImageAttachmentsComposesRuntimeInput(t *testing.T) { app, runtime := newTestApp(t) root := t.TempDir() diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index feae8187..c6daa2dd 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -983,6 +983,190 @@ func TestRuntimeEventRunContextHandler(t *testing.T) { } } +func TestRuntimeEventRunContextHandlerInvalidatesModelCapabilityCache(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentProvider = "provider-a" + app.state.CurrentModel = "model-a" + app.currentModelCapabilities = modelCapabilityState{ + checked: true, + supportsImageInput: true, + } + + payload := tuiservices.RuntimeRunContextPayload{ + Provider: "provider-b", + Model: "model-b", + } + _ = runtimeEventRunContextHandler(&app, agentruntime.RuntimeEvent{Payload: payload}) + if app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be invalidated when provider/model changes") + } +} + +func TestSyncConfigStateInvalidatesModelCapabilityCache(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentProvider = "provider-a" + app.state.CurrentModel = "model-a" + app.currentModelCapabilities = modelCapabilityState{ + checked: true, + supportsImageInput: true, + } + + app.syncConfigState(config.Config{ + SelectedProvider: "provider-b", + CurrentModel: "model-b", + Workdir: app.state.CurrentWorkdir, + }) + if app.currentModelCapabilities.checked { + t.Fatalf("expected capability cache to be invalidated") + } +} + +func TestUpdatePasteImageShortcutFailure(t *testing.T) { + app, _ := newTestApp(t) + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyCtrlV}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if !strings.Contains(strings.ToLower(app.state.StatusText), "clipboard") { + t.Fatalf("expected clipboard failure status, got %q", app.state.StatusText) + } +} + +func TestUpdateEnterSessionOpensSessionPicker(t *testing.T) { + app, runtime := newTestApp(t) + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "Session 1", UpdatedAt: time.Now()}, + } + app.input.SetValue("/session") + app.state.InputText = "/session" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.ActivePicker != pickerSession { + t.Fatalf("expected session picker to open") + } + if app.state.StatusText != statusChooseSession { + t.Fatalf("expected status %q, got %q", statusChooseSession, app.state.StatusText) + } +} + +func TestUpdateEnterImageReferencePath(t *testing.T) { + app, _ := newTestApp(t) + app.input.SetValue("@image:/path/does-not-exist.png") + app.state.InputText = "@image:/path/does-not-exist.png" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.input.Value() != "" { + t.Fatalf("expected input to be reset after image reference handling") + } + if strings.TrimSpace(app.state.StatusText) == "" { + t.Fatalf("expected status text to reflect image reference failure") + } +} + +func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { + app, _ := newTestApp(t) + app.pendingImageAttachments = []pendingImageAttachment{ + {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }, + }}, + } + app.input.SetValue("hello") + app.state.InputText = "hello" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.IsAgentRunning { + t.Fatalf("expected send to be blocked for unsupported model image input") + } + if app.hasImageAttachments() { + t.Fatalf("expected pending image attachments to be cleared on unsupported model") + } + if app.state.StatusText != "Model does not support images" { + t.Fatalf("unexpected status text: %q", app.state.StatusText) + } +} + +func TestUpdatePickerSessionEnterActivatesSelectedSession(t *testing.T) { + app, runtime := newTestApp(t) + now := time.Now() + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "One", UpdatedAt: now.Add(-time.Minute)}, + {ID: "s2", Title: "Two", UpdatedAt: now}, + } + runtime.loadSessions = map[string]agentsession.Session{ + "s2": { + ID: "s2", + Title: "Two", + Workdir: app.state.CurrentWorkdir, + Messages: []providertypes.Message{ + {Role: roleUser, Content: "hello"}, + }, + }, + } + if err := app.refreshSessionPicker(); err != nil { + t.Fatalf("refreshSessionPicker() error = %v", err) + } + app.openPicker(pickerSession, statusChooseSession, &app.sessionPicker, "") + app.sessionPicker.Select(1) + + model, cmd := app.updatePicker(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + if app.state.ActiveSessionID != "s2" || app.state.ActiveSessionTitle != "Two" { + t.Fatalf("expected selected session to be activated, got id=%q title=%q", app.state.ActiveSessionID, app.state.ActiveSessionTitle) + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected messages to refresh from selected session") + } +} + +func TestActivateSessionByIDNotFound(t *testing.T) { + app, _ := newTestApp(t) + app.state.Sessions = []agentsession.Summary{{ID: "s1", Title: "one"}} + if err := app.activateSessionByID("missing"); err == nil { + t.Fatalf("expected session not found error") + } +} + +func TestHandleImmediateSlashCommandSession(t *testing.T) { + app, runtime := newTestApp(t) + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "Session 1", UpdatedAt: time.Now()}, + } + handled, cmd := app.handleImmediateSlashCommand("/session") + if !handled { + t.Fatalf("expected /session to be handled immediately") + } + if cmd != nil { + _ = cmd() + } + if app.state.ActivePicker != pickerSession { + t.Fatalf("expected session picker opened by immediate slash command") + } +} + func TestRuntimeEventToolStatusHandler(t *testing.T) { app, _ := newTestApp(t) payload := tuiservices.RuntimeToolStatusPayload{ToolCallID: "tool-1", ToolName: "bash", Status: string(tuistate.ToolLifecyclePlanned)} diff --git a/internal/tui/core/app/view_test.go b/internal/tui/core/app/view_test.go index 8ff9eeca..0c807d5a 100644 --- a/internal/tui/core/app/view_test.go +++ b/internal/tui/core/app/view_test.go @@ -50,6 +50,23 @@ func TestRenderPickerSessionMode(t *testing.T) { } } +func TestRenderPickerProviderAndFileMode(t *testing.T) { + app, _ := newTestApp(t) + + app.state.ActivePicker = pickerProvider + app.providerPicker.SetItems([]list.Item{selectionItem{id: "p1", name: "Provider 1"}}) + providerView := app.renderPicker(48, 14) + if !strings.Contains(providerView, providerPickerTitle) { + t.Fatalf("expected provider picker title") + } + + app.state.ActivePicker = pickerFile + fileView := app.renderPicker(48, 14) + if !strings.Contains(fileView, filePickerTitle) { + t.Fatalf("expected file picker title") + } +} + func TestBuildPickerLayoutExpandsPopupSpace(t *testing.T) { app, _ := newTestApp(t) @@ -78,6 +95,18 @@ func TestRenderWaterfallUsesDynamicTranscriptHeight(t *testing.T) { } } +func TestRenderWaterfallThinkingState(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActivePicker = pickerNone + app.state.IsAgentRunning = true + app.state.StatusText = statusThinking + + view := app.renderWaterfall(80, 24) + if !strings.Contains(view, "Thinking...") { + t.Fatalf("expected thinking hint in waterfall view") + } +} + func TestApplyComponentLayoutKeepsTranscriptHeightInSyncWithWaterfall(t *testing.T) { app, _ := newTestApp(t) app.width = 100 @@ -169,3 +198,39 @@ func TestRenderUserMessageKeepsTagAndBodyRightAligned(t *testing.T) { t.Fatalf("expected user tag and body right edges to match, got tag=%d body=%d\n%q\n%q", tagRightEdge, bodyRightEdge, tagLine, contentLine) } } + +func TestBuildPickerLayoutClampMin(t *testing.T) { + app, _ := newTestApp(t) + got := app.buildPickerLayout(10, 8) + if got.panelWidth != pickerPanelMinWidth { + t.Fatalf("expected panel width clamp to min %d, got %d", pickerPanelMinWidth, got.panelWidth) + } + if got.panelHeight != pickerPanelMinHeight { + t.Fatalf("expected panel height clamp to min %d, got %d", pickerPanelMinHeight, got.panelHeight) + } +} + +func TestRenderWaterfallWithActivePicker(t *testing.T) { + app, _ := newTestApp(t) + app.state.ActivePicker = pickerSession + app.sessionPicker.SetItems([]list.Item{ + sessionItem{Summary: agentsession.Summary{ + ID: "session-1", + Title: "Session One", + UpdatedAt: time.Now(), + }}, + }) + + view := app.renderWaterfall(90, 24) + if !strings.Contains(view, sessionPickerTitle) { + t.Fatalf("expected picker waterfall view to include session picker title") + } +} + +func TestRenderBody(t *testing.T) { + app, _ := newTestApp(t) + out := app.renderBody(layout{contentWidth: 90, contentHeight: 24}) + if strings.TrimSpace(out) == "" { + t.Fatalf("expected renderBody output") + } +} diff --git a/internal/tui/infra/clipboard_common.go b/internal/tui/infra/clipboard_common.go index 3a2d4615..3079eb72 100644 --- a/internal/tui/infra/clipboard_common.go +++ b/internal/tui/infra/clipboard_common.go @@ -14,13 +14,8 @@ func SaveImageToTempFile(data []byte, prefix string) (string, error) { return "", err } tmpFile := f.Name() - - if _, err = f.Write(data); err != nil { - _ = f.Close() - _ = os.Remove(tmpFile) - return "", err - } - if err = f.Close(); err != nil { + _ = f.Close() + if err = os.WriteFile(tmpFile, data, 0o600); err != nil { _ = os.Remove(tmpFile) return "", err } diff --git a/internal/tui/infra/infra_test.go b/internal/tui/infra/infra_test.go index 25772a39..9e39bb45 100644 --- a/internal/tui/infra/infra_test.go +++ b/internal/tui/infra/infra_test.go @@ -360,3 +360,168 @@ func workspaceExecutorCommands() (shell string, success string, noOutput string, "echo failed 1>&2; exit 2", "sleep 2" } + +func TestSanitizeTempPrefix(t *testing.T) { + if got := sanitizeTempPrefix(""); got != "" { + t.Fatalf("expected empty prefix to remain empty, got %q", got) + } + if got := sanitizeTempPrefix("p@st/e_1-2"); got != "pste_1-2" { + t.Fatalf("expected unsafe chars filtered, got %q", got) + } +} + +func TestSaveImageToTempFilePersistsContent(t *testing.T) { + data := []byte("image-bytes") + path, err := SaveImageToTempFile(data, "p@st/e") + if err != nil { + t.Fatalf("SaveImageToTempFile() error = %v", err) + } + defer os.Remove(path) + + if !strings.Contains(filepath.Base(path), "pste-") { + t.Fatalf("expected sanitized prefix in temp file name, got %q", filepath.Base(path)) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read temp file: %v", err) + } + if string(got) != string(data) { + t.Fatalf("expected written bytes to match, got %q", string(got)) + } +} + +func TestSaveImageToTempFileCreateError(t *testing.T) { + t.Setenv("TMPDIR", filepath.Join(t.TempDir(), "missing-dir")) + if _, err := SaveImageToTempFile([]byte("x"), "paste"); err == nil { + t.Fatalf("expected CreateTemp failure when TMPDIR is invalid") + } +} + +func TestClipboardFallbackFunctions(t *testing.T) { + text, err := ReadClipboardText() + if err == nil && strings.TrimSpace(text) == "" { + t.Fatalf("expected clipboard text or an error, got empty success result") + } + data, err := ReadClipboardImage() + if err != errClipboardImageUnsupported { + t.Fatalf("expected unsupported image error, got %v", err) + } + if data != nil { + t.Fatalf("expected nil image data on unsupported platform") + } +} + +func TestImageInfoAndRead(t *testing.T) { + root := t.TempDir() + path := filepath.Join(root, "sample.jpg") + content := []byte{0xFF, 0xD8, 0xFF, 0x00} + if err := os.WriteFile(path, content, 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + info, err := GetFileInfo(path) + if err != nil { + t.Fatalf("GetFileInfo() error = %v", err) + } + if info.Size() != int64(len(content)) { + t.Fatalf("expected size %d, got %d", len(content), info.Size()) + } + read, err := ReadImageFile(path) + if err != nil { + t.Fatalf("ReadImageFile() error = %v", err) + } + if string(read) != string(content) { + t.Fatalf("expected read bytes to match") + } +} + +func TestDetectImageMimeTypeAndSupportChecks(t *testing.T) { + root := t.TempDir() + pngPath := filepath.Join(root, "x.png") + if err := os.WriteFile(pngPath, []byte("png"), 0o644); err != nil { + t.Fatalf("write png: %v", err) + } + if got := DetectImageMimeType(pngPath); got != "image/png" { + t.Fatalf("expected png by extension, got %q", got) + } + + jpgPath := filepath.Join(root, "x.JPG") + if err := os.WriteFile(jpgPath, []byte("jpg"), 0o644); err != nil { + t.Fatalf("write jpg: %v", err) + } + if got := DetectImageMimeType(jpgPath); got != "image/jpeg" { + t.Fatalf("expected jpeg by extension, got %q", got) + } + if !IsSupportedImageFormat(jpgPath) { + t.Fatalf("expected jpeg to be supported") + } + + txtPath := filepath.Join(root, "x.txt") + if err := os.WriteFile(txtPath, []byte("text"), 0o644); err != nil { + t.Fatalf("write txt: %v", err) + } + if got := DetectImageMimeType(txtPath); got == "" { + t.Fatalf("expected extension-based mime to be detected for txt") + } + if IsSupportedImageFormat(txtPath) { + t.Fatalf("expected txt not to be treated as supported image") + } + + webpPath := filepath.Join(root, "x.webp") + if err := os.WriteFile(webpPath, []byte("webp"), 0o644); err != nil { + t.Fatalf("write webp: %v", err) + } + if got := DetectImageMimeType(webpPath); got != "image/webp" { + t.Fatalf("expected webp by extension, got %q", got) + } + + gifPath := filepath.Join(root, "x.bin") + gifBytes := []byte("GIF89a........") + if err := os.WriteFile(gifPath, gifBytes, 0o644); err != nil { + t.Fatalf("write gif magic: %v", err) + } + if got := DetectImageMimeType(gifPath); got != "image/gif" { + t.Fatalf("expected gif by magic header, got %q", got) + } + + jpegMagicPath := filepath.Join(root, "jpeg-magic.bin") + if err := os.WriteFile(jpegMagicPath, []byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46}, 0o644); err != nil { + t.Fatalf("write jpeg magic: %v", err) + } + if got := DetectImageMimeType(jpegMagicPath); got != "image/jpeg" { + t.Fatalf("expected jpeg by magic header, got %q", got) + } + + webpMagicPath := filepath.Join(root, "webp-magic.bin") + webpMagic := append([]byte("RIFF"), []byte{0, 0, 0, 0}...) + webpMagic = append(webpMagic, []byte("WEBP")...) + if err := os.WriteFile(webpMagicPath, webpMagic, 0o644); err != nil { + t.Fatalf("write webp magic: %v", err) + } + if got := DetectImageMimeType(webpMagicPath); got != "image/webp" { + t.Fatalf("expected webp by magic header, got %q", got) + } + + missingPath := filepath.Join(root, "missing.unknown") + if got := DetectImageMimeType(missingPath); got != "" { + t.Fatalf("expected empty mime for missing unknown file, got %q", got) + } +} + +func TestReadMagicHeaderErrorsAndShortRead(t *testing.T) { + root := t.TempDir() + path := filepath.Join(root, "short.bin") + if err := os.WriteFile(path, []byte{1, 2, 3}, 0o644); err != nil { + t.Fatalf("write short file: %v", err) + } + buf, err := readMagicHeader(path, 8) + if err != nil { + t.Fatalf("readMagicHeader(short) error = %v", err) + } + if len(buf) != 3 { + t.Fatalf("expected short read length 3, got %d", len(buf)) + } + if _, err := readMagicHeader(filepath.Join(root, "missing.bin"), 8); err == nil { + t.Fatalf("expected missing file error") + } +} From 7a093a63cdd3f5a67ab1172b9bf2cade74b4b8c1 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 12:47:33 +0000 Subject: [PATCH 29/33] fix(gateway): harden jsonrpc invalid-id and frame error mapping Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/gateway/protocol/jsonrpc.go | 11 +++++---- internal/gateway/protocol/jsonrpc_test.go | 17 ++++++++++++++ internal/gateway/server.go | 2 +- internal/gateway/server_additional_test.go | 26 ++++++++++++++++++++++ internal/gateway/server_test.go | 4 ++-- 5 files changed, 51 insertions(+), 9 deletions(-) diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index 37bad07b..eb6ad305 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -91,11 +91,11 @@ func NormalizeJSONRPCRequest(request JSONRPCRequest) (NormalizedRequest, *JSONRP normalized := NormalizedRequest{} requestID, idErr := normalizeJSONRPCID(request.ID) - normalized.ID = cloneJSONRawMessage(request.ID) normalized.RequestID = requestID if idErr != nil { return normalized, idErr } + normalized.ID = cloneJSONRawMessage(request.ID) if strings.TrimSpace(request.JSONRPC) != JSONRPCVersion { return normalized, NewJSONRPCError( @@ -222,17 +222,17 @@ func normalizeJSONRPCID(id json.RawMessage) (string, *JSONRPCError) { ) } - switch typedID := decoded.(type) { + switch value := decoded.(type) { case string: - typedID = strings.TrimSpace(typedID) - if typedID == "" { + identifier := strings.TrimSpace(value) + if identifier == "" { return "", NewJSONRPCError( JSONRPCCodeInvalidRequest, "invalid field: id", GatewayCodeInvalidFrame, ) } - return typedID, nil + return identifier, nil case float64: identifier := strings.TrimSpace(string(trimmed)) if identifier == "" { @@ -242,7 +242,6 @@ func normalizeJSONRPCID(id json.RawMessage) (string, *JSONRPCError) { GatewayCodeInvalidFrame, ) } - _ = typedID return identifier, nil default: return "", NewJSONRPCError( diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 9b586502..e0cccc1f 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -186,6 +186,23 @@ func TestNormalizeJSONRPCRequestErrors(t *testing.T) { } } +func TestNormalizeJSONRPCRequestInvalidIDReturnsNullResponseID(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`{}`), + Method: MethodGatewayPing, + }) + if rpcErr == nil { + t.Fatal("expected rpc error") + } + if rpcErr.Code != JSONRPCCodeInvalidRequest { + t.Fatalf("rpc code = %d, want %d", rpcErr.Code, JSONRPCCodeInvalidRequest) + } + if normalized.ID != nil { + t.Fatalf("normalized id = %s, want nil", string(normalized.ID)) + } +} + func TestJSONRPCHelpers(t *testing.T) { response, rpcErr := NewJSONRPCResultResponse(json.RawMessage(`"req-1"`), map[string]string{"message": "ok"}) if rpcErr != nil { diff --git a/internal/gateway/server.go b/internal/gateway/server.go index ce3c902a..001e34fd 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -279,7 +279,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor _ = s.writeRPCResponse(conn, encoder, protocol.NewJSONRPCErrorResponse( nil, protocol.NewJSONRPCError( - protocol.JSONRPCCodeInvalidParams, + protocol.JSONRPCCodeInvalidRequest, fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), protocol.GatewayCodeInvalidFrame, ), diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index ecd1cd4f..52466de9 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -308,6 +308,32 @@ func TestDispatchRPCRequestNormalizeError(t *testing.T) { } } +func TestDispatchRPCRequestInvalidIDReturnsNullID(t *testing.T) { + server := &Server{} + response := server.dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`{"bad":"id"}`), + Method: protocol.MethodGatewayPing, + }, nil) + if response.Error == nil { + t.Fatal("expected rpc normalize error") + } + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) + } + encoded, err := json.Marshal(response) + if err != nil { + t.Fatalf("marshal response: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(encoded, &payload); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if id, ok := payload["id"]; !ok || id != nil { + t.Fatalf("encoded response id = %#v, want nil", payload["id"]) + } +} + func TestDispatchRPCRequestConvertsFrameErrorWithoutPayload(t *testing.T) { server := &Server{} originalHandlers := requestFrameHandlers diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a65817d2..69c3b154 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -155,8 +155,8 @@ func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { if response.Error == nil { t.Fatal("response rpc error is nil") } - if response.Error.Code != protocol.JSONRPCCodeInvalidParams { - t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidParams) + if response.Error.Code != protocol.JSONRPCCodeInvalidRequest { + t.Fatalf("rpc error code = %d, want %d", response.Error.Code, protocol.JSONRPCCodeInvalidRequest) } if gatewayCode := protocol.GatewayCodeFromJSONRPCError(response.Error); gatewayCode != ErrorCodeInvalidFrame.String() { t.Fatalf("gateway_code = %q, want %q", gatewayCode, ErrorCodeInvalidFrame.String()) From b6222868263d705b485c3a097b2e5c24e4d1f184 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Wed, 15 Apr 2026 22:16:18 +0800 Subject: [PATCH 30/33] =?UTF-8?q?fix(context):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=20compact=20=E9=98=88=E5=80=BC=E5=9B=9E?= =?UTF-8?q?=E9=80=80=E5=B9=B6=E8=A1=A5=E9=BD=90=E6=A8=A1=E5=9E=8B=E5=85=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/config-management-detail-design.md | 11 + docs/context-compact.md | 19 +- docs/guides/adding-providers.md | 31 +++ docs/guides/configuration.md | 34 +++- internal/app/bootstrap.go | 15 ++ internal/config/config_test.go | 133 +++++++++--- internal/config/context.go | 39 ++-- internal/config/loader.go | 18 +- internal/config/loader_test.go | 189 +++++++++++++++++- internal/config/provider_loader.go | 57 ++++++ internal/config/provider_test.go | 72 +++++++ .../config/state/auto_compact_threshold.go | 90 +++++++++ .../state/auto_compact_threshold_test.go | 166 +++++++++++++++ internal/gateway/handlers/wake.go | 15 +- internal/provider/catalog/service_test.go | 28 +++ internal/runtime/run.go | 25 ++- internal/runtime/runtime.go | 28 ++- internal/runtime/runtime_test.go | 86 +++++++- 18 files changed, 986 insertions(+), 70 deletions(-) create mode 100644 internal/config/state/auto_compact_threshold.go create mode 100644 internal/config/state/auto_compact_threshold_test.go diff --git a/docs/config-management-detail-design.md b/docs/config-management-detail-design.md index e7ef7836..d0c023ab 100644 --- a/docs/config-management-detail-design.md +++ b/docs/config-management-detail-design.md @@ -130,3 +130,14 @@ custom provider 来自: - 不在 `config` 中堆兼容旧字段的逻辑 - 不把选择修正混进快照校验 - 不把 provider/catalog/runtime 语义倒灌回 `config` + +## custom provider `models` 校验约束 + +`~/.neocode/providers//provider.yaml` 中允许通过 `models` 补齐模型元数据,用于 catalog/discovery 无法提供完整 `ContextWindow` 或 `MaxOutputTokens` 的场景。 + +该能力的约束是: + +- `models[].id` 必须非空。 +- `models[].context_window` 和 `models[].max_output_tokens` 如果显式提供,必须大于 `0`。 +- 重复的模型 `id` 会在加载 custom provider 时直接失败,不保留 silently drop 的宽松行为。 +- 这些元数据不会写回 `config.yaml`,只在 custom provider 文件中声明,并通过现有 catalog 合并链路参与运行时解析。 diff --git a/docs/context-compact.md b/docs/context-compact.md index 87fded85..af991d91 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -1,5 +1,11 @@ # Context Compact +## Auto Compact Failure Fallback + +- 当 `context.auto_compact.input_token_threshold <= 0` 时,系统会尝试基于当前模型的 `ContextWindow` 自动推导阈值。 +- 若当前 provider 选择无效、catalog snapshot 查询失败,或模型窗口元数据缺失,系统会直接回退到 `fallback_input_token_threshold`。 +- 自动推导失败不会静默关闭 auto compact;runtime 仍会拿到一个可用的保底阈值。 + 本文档说明 NeoCode 中 context compact 的配置、执行链路和摘要约定。 ## 概览 @@ -23,7 +29,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` - `manual_strategy` @@ -154,3 +162,12 @@ compact 相关 runtime 事件包括: - `trigger_mode` - `transcript_id` - `transcript_path` + +## Auto Compact 阈值解析 + +- `context.auto_compact.input_token_threshold > 0` 时,直接使用显式手动阈值。 +- `context.auto_compact.input_token_threshold <= 0` 时,系统会对当前选中的 provider/model 做自动推导。 +- 自动推导公式为 `resolved_threshold = context_window - reserve_tokens`。 +- `reserve_tokens` 默认 `13000`,用于给输出、tool call 和 system prompt 预留缓冲。 +- 如果当前模型没有可用的 `ContextWindow`,或窗口值小于等于 `reserve_tokens`,则回退到 `fallback_input_token_threshold`。 +- `fallback_input_token_threshold` 默认 `100000`,用于保证主链路在缺少模型窗口元数据时仍可稳定运行。 diff --git a/docs/guides/adding-providers.md b/docs/guides/adding-providers.md index 95d94b01..22d32636 100644 --- a/docs/guides/adding-providers.md +++ b/docs/guides/adding-providers.md @@ -270,3 +270,34 @@ func DefaultProviders() []ProviderConfig { ``` 所有内置 provider 都通过代码集中注册。模型选择器展示的候选模型由默认模型、动态发现结果和本地缓存共同组成。 + +## custom provider 模型元数据补齐 + +对于复用 `openaicompat` 驱动的 custom provider,如果上游 `GET /models` 不能返回可靠的上下文窗口信息,可以在: + +```text +~/.neocode/providers//provider.yaml +``` + +中显式声明 `models`: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + api_style: chat_completions +``` + +约束如下: + +- `models[].id` 必须非空。 +- `models[].context_window` 和 `models[].max_output_tokens` 如果显式配置,必须大于 `0`。 +- 同一个 `provider.yaml` 中重复的模型 `id` 会在加载阶段直接报错。 +- 这些元数据会进入统一的 model catalog 合并链路,优先级仍为“配置模型元数据优先于 discovery/default”。 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 0d8724ca..4a0191e1 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -58,7 +58,9 @@ context: micro_compact_disabled: false auto_compact: enabled: false - input_token_threshold: 100000 + input_token_threshold: 0 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 ``` ### 基础字段 @@ -142,6 +144,13 @@ openai_compatible: api_style: chat_completions ``` +## Auto Compact 失败与校验补充 + +- 当 `context.auto_compact.input_token_threshold <= 0` 时,如果当前 provider 选择无效、catalog snapshot 查询失败,或模型缺少可用的 `ContextWindow`,系统会回退到 `fallback_input_token_threshold`,不会静默关闭 auto compact。 +- `~/.neocode/providers//provider.yaml` 中的 `models[].id` 必须非空。 +- `models[].context_window` 和 `models[].max_output_tokens` 如果显式配置,必须大于 `0`。 +- `models` 中重复的模型 `id` 会在加载 `provider.yaml` 时直接报错。 + 文件路径: ```text @@ -230,3 +239,26 @@ config: environment variable OPENAI_API_KEY is empty - [添加 Provider](./adding-providers.md) - [配置管理详细设计](../config-management-detail-design.md) - [Context Compact](../context-compact.md) + +## Auto Compact 补充说明 + +- `context.auto_compact.input_token_threshold > 0` 时,系统直接使用该显式阈值。 +- `context.auto_compact.input_token_threshold <= 0` 时,系统会根据当前 `current_model` 对应的 `ContextWindow` 自动推导输入阈值。 +- 推导公式为 `context_window - reserve_tokens`。 +- `reserve_tokens` 默认 `13000`。 +- 如果当前 provider/model 没有可用的 `ContextWindow` 元数据,则回退到 `fallback_input_token_threshold`。 +- custom provider 可以在 `~/.neocode/providers//provider.yaml` 中通过 `models` 字段补齐模型元数据,例如: + +```yaml +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 +openai_compatible: + base_url: https://llm.example.com/v1 + api_style: chat_completions +``` diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index f5cd22c4..0f64a269 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -168,6 +168,15 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er contextBuilder, ) runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir())) + runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolution, err := configstate.ResolveAutoCompactThreshold(ctx, cfg, modelCatalogs) + if err != nil { + return 0, err + } + return resolution.Threshold, nil + }, + )) // 注入记忆提取钩子:当 AutoExtract 启用且 memoSvc 可用时,ReAct 循环完成后异步提取记忆。 if memoSvc != nil && cfg.Memo.AutoExtract { @@ -306,3 +315,9 @@ type textGenAdapter func(ctx context.Context, prompt string, msgs []providertype func (f textGenAdapter) Generate(ctx context.Context, prompt string, msgs []providertypes.Message) (string, error) { return f(ctx, prompt, msgs) } + +type runtimeAutoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) + +func (f runtimeAutoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + return f(ctx, cfg) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ffd0067c..83a0ba0c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1242,6 +1242,14 @@ func TestAutoCompactConfigDefaults(t *testing.T) { t.Fatalf("expected input_token_threshold=%d, got %d", DefaultAutoCompactInputTokenThreshold, cfg.Context.AutoCompact.InputTokenThreshold) } + if cfg.Context.AutoCompact.ReserveTokens != DefaultAutoCompactReserveTokens { + t.Fatalf("expected reserve_tokens=%d, got %d", + DefaultAutoCompactReserveTokens, cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != DefaultAutoCompactFallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d, got %d", + DefaultAutoCompactFallbackInputTokenThreshold, cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } if cfg.Context.AutoCompact.Enabled != false { t.Fatalf("expected enabled=false, got %v", cfg.Context.AutoCompact.Enabled) @@ -1253,13 +1261,20 @@ func TestAutoCompactConfigApplyDefaults(t *testing.T) { cfg := AutoCompactConfig{} defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) - if cfg.InputTokenThreshold != 50000 { - t.Fatalf("expected threshold=50000, got %d", cfg.InputTokenThreshold) + if cfg.InputTokenThreshold != 0 { + t.Fatalf("expected threshold to remain implicit 0, got %d", cfg.InputTokenThreshold) + } + if cfg.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.FallbackInputTokenThreshold) } } @@ -1267,10 +1282,14 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - InputTokenThreshold: 200000, + InputTokenThreshold: 200000, + ReserveTokens: 5000, + FallbackInputTokenThreshold: 80000, } defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cfg.ApplyDefaults(defaults) @@ -1278,13 +1297,19 @@ func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { if cfg.InputTokenThreshold != 200000 { t.Fatalf("expected explicit threshold=200000 to be preserved, got %d", cfg.InputTokenThreshold) } + if cfg.ReserveTokens != 5000 { + t.Fatalf("expected explicit reserve_tokens=5000 to be preserved, got %d", cfg.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != 80000 { + t.Fatalf("expected explicit fallback_input_token_threshold=80000 to be preserved, got %d", cfg.FallbackInputTokenThreshold) + } } func TestAutoCompactConfigApplyDefaultsNilReceiver(t *testing.T) { t.Parallel() var cfg *AutoCompactConfig - cfg.ApplyDefaults(AutoCompactConfig{InputTokenThreshold: 50000}) + cfg.ApplyDefaults(AutoCompactConfig{ReserveTokens: 13000, FallbackInputTokenThreshold: 100000}) } func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { @@ -1293,7 +1318,8 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { cfg := ContextConfig{} cfg.ApplyDefaults(ContextConfig{ AutoCompact: AutoCompactConfig{ - InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1303,8 +1329,14 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { }, }) - if cfg.AutoCompact.InputTokenThreshold != 50000 { - t.Fatalf("expected auto compact threshold=50000, got %d", cfg.AutoCompact.InputTokenThreshold) + if cfg.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected auto compact threshold to remain implicit 0, got %d", cfg.AutoCompact.InputTokenThreshold) + } + if cfg.AutoCompact.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.AutoCompact.ReserveTokens) + } + if cfg.AutoCompact.FallbackInputTokenThreshold != 100000 { + t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.AutoCompact.FallbackInputTokenThreshold) } } @@ -1312,16 +1344,15 @@ func TestAutoCompactConfigValidateEnabledWithoutThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "input_token_threshold") { - t.Fatalf("expected error about input_token_threshold, got %v", err) + if err != nil { + t.Fatalf("expected validation to allow implicit threshold, got %v", err) } } @@ -1329,8 +1360,10 @@ func TestAutoCompactConfigValidateDisabledWithoutThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: false, - InputTokenThreshold: 0, + Enabled: false, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 0, } err := cfg.Validate() @@ -1343,8 +1376,10 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, + Enabled: true, + InputTokenThreshold: 50000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } err := cfg.Validate() @@ -1353,12 +1388,46 @@ func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { } } +func TestAutoCompactConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 0, + FallbackInputTokenThreshold: 100000, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "reserve_tokens") { + t.Fatalf("expected reserve_tokens validation error, got %v", err) + } +} + +func TestAutoCompactConfigValidateRejectsNonPositiveFallbackThreshold(t *testing.T) { + t.Parallel() + + cfg := AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 0, + } + + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "fallback_input_token_threshold") { + t.Fatalf("expected fallback_input_token_threshold validation error, got %v", err) + } +} + func TestAutoCompactConfigClone(t *testing.T) { t.Parallel() cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 75000, + Enabled: true, + InputTokenThreshold: 75000, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, } cloned := cfg.Clone() @@ -1370,6 +1439,13 @@ func TestAutoCompactConfigClone(t *testing.T) { t.Fatalf("expected threshold=%d to be cloned, got %d", cfg.InputTokenThreshold, cloned.InputTokenThreshold) } + if cfg.ReserveTokens != cloned.ReserveTokens { + t.Fatalf("expected reserve_tokens=%d to be cloned, got %d", cfg.ReserveTokens, cloned.ReserveTokens) + } + if cfg.FallbackInputTokenThreshold != cloned.FallbackInputTokenThreshold { + t.Fatalf("expected fallback_input_token_threshold=%d to be cloned, got %d", + cfg.FallbackInputTokenThreshold, cloned.FallbackInputTokenThreshold) + } cloned.InputTokenThreshold = 100000 if cfg.InputTokenThreshold == cloned.InputTokenThreshold { @@ -1382,8 +1458,10 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { ctx := ContextConfig{ AutoCompact: AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 13000, + FallbackInputTokenThreshold: 100000, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1394,11 +1472,8 @@ func TestAutoCompactConfigContextConfigValidate(t *testing.T) { } err := ctx.Validate() - if err == nil { - t.Fatalf("expected validation error, got nil") - } - if !strings.Contains(err.Error(), "auto_compact") { - t.Fatalf("expected error to contain 'auto_compact', got %v", err) + if err != nil { + t.Fatalf("expected context validation to allow implicit threshold, got %v", err) } } diff --git a/internal/config/context.go b/internal/config/context.go index c0c2ffb5..57552b32 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -7,11 +7,13 @@ import ( ) const ( - DefaultCompactManualKeepRecentMessages = 10 - DefaultCompactMaxSummaryChars = 1200 - DefaultAutoCompactInputTokenThreshold = 100000 - DefaultMicroCompactRetainedToolSpans = 2 - DefaultCompactReadTimeMaxMessageSpans = 24 + DefaultCompactManualKeepRecentMessages = 10 + DefaultCompactMaxSummaryChars = 1200 + DefaultAutoCompactInputTokenThreshold = 0 + DefaultAutoCompactReserveTokens = 13000 + DefaultAutoCompactFallbackInputTokenThreshold = 100000 + DefaultMicroCompactRetainedToolSpans = 2 + DefaultCompactReadTimeMaxMessageSpans = 24 CompactManualStrategyKeepRecent = "keep_recent" CompactManualStrategyFullReplace = "full_replace" @@ -34,8 +36,10 @@ type CompactConfig struct { // AutoCompactConfig controls automatic context compression triggered by token thresholds. type AutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } // defaultContextConfig 返回上下文压缩相关配置的默认值。 @@ -48,7 +52,9 @@ func defaultContextConfig() ContextConfig { func defaultAutoCompactConfig() AutoCompactConfig { return AutoCompactConfig{ - InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, + ReserveTokens: DefaultAutoCompactReserveTokens, + FallbackInputTokenThreshold: DefaultAutoCompactFallbackInputTokenThreshold, } } @@ -119,8 +125,11 @@ func (c *AutoCompactConfig) ApplyDefaults(defaults AutoCompactConfig) { if c == nil { return } - if c.InputTokenThreshold <= 0 { - c.InputTokenThreshold = defaults.InputTokenThreshold + if c.ReserveTokens <= 0 { + c.ReserveTokens = defaults.ReserveTokens + } + if c.FallbackInputTokenThreshold <= 0 { + c.FallbackInputTokenThreshold = defaults.FallbackInputTokenThreshold } } @@ -157,8 +166,14 @@ func (c CompactConfig) Validate() error { // Validate 校验 auto_compact 配置是否合法。 func (c AutoCompactConfig) Validate() error { - if c.Enabled && c.InputTokenThreshold <= 0 { - return errors.New("input_token_threshold must be greater than 0 when enabled") + if !c.Enabled { + return nil + } + if c.ReserveTokens <= 0 { + return errors.New("reserve_tokens must be greater than 0 when enabled") + } + if c.FallbackInputTokenThreshold <= 0 { + return errors.New("fallback_input_token_threshold must be greater than 0 when enabled") } return nil } diff --git a/internal/config/loader.go b/internal/config/loader.go index 878fd9af..5e35c3a6 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -48,8 +48,10 @@ type persistedCompactConfig struct { } type persistedAutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + Enabled bool `yaml:"enabled"` + InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` } type persistedMemoConfig struct { @@ -249,8 +251,10 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, AutoCompact: persistedAutoCompactConfig{ - Enabled: cfg.AutoCompact.Enabled, - InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + Enabled: cfg.AutoCompact.Enabled, + InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, + ReserveTokens: cfg.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: cfg.AutoCompact.FallbackInputTokenThreshold, }, } } @@ -268,8 +272,10 @@ func fromPersistedContextConfig(file persistedContextConfig, defaults ContextCon MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, AutoCompact: AutoCompactConfig{ - Enabled: file.AutoCompact.Enabled, - InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + Enabled: file.AutoCompact.Enabled, + InputTokenThreshold: file.AutoCompact.InputTokenThreshold, + ReserveTokens: file.AutoCompact.ReserveTokens, + FallbackInputTokenThreshold: file.AutoCompact.FallbackInputTokenThreshold, }, } out.Compact.ApplyDefaults(defaults.Compact) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 86c87754..fb89594e 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -251,6 +251,11 @@ shell: powershell name: company-gateway driver: openaicompat api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + name: DeepSeek Coder + context_window: 131072 + max_output_tokens: 8192 openai_compatible: base_url: https://llm.example.com/v1 api_style: chat_completions @@ -290,8 +295,11 @@ openai_compatible: if customProvider.Model != "" { t.Fatalf("expected custom provider default model to be empty, got %q", customProvider.Model) } - if len(customProvider.Models) != 0 { - t.Fatalf("expected custom provider models to come only from remote discovery, got %+v", customProvider.Models) + if len(customProvider.Models) != 1 { + t.Fatalf("expected custom provider model metadata from provider.yaml, got %+v", customProvider.Models) + } + if customProvider.Models[0].ID != "deepseek-coder" || customProvider.Models[0].ContextWindow != 131072 { + t.Fatalf("expected parsed model metadata, got %+v", customProvider.Models[0]) } } @@ -421,6 +429,183 @@ models: } } +func TestLoaderRejectsCustomProviderModelWithoutID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - name: DeepSeek Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty model id rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidContextWindow(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + context_window: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected invalid context_window rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderModelWithInvalidMaxOutputTokens(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + max_output_tokens: 0 +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected invalid max_output_tokens rejection, got %v", err) + } +} + +func TestLoaderRejectsCustomProviderDuplicateModelID(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + customDir := filepath.Join(loader.BaseDir(), providersDirName, "company-gateway") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatalf("mkdir custom provider dir: %v", err) + } + + providerYAML := ` +name: company-gateway +driver: openaicompat +api_key_env: COMPANY_GATEWAY_API_KEY +models: + - id: deepseek-coder + - id: DeepSeek-Coder +openai_compatible: + base_url: https://llm.example.com/v1 +` + if err := os.WriteFile(filepath.Join(customDir, customProviderConfigName), []byte(strings.TrimSpace(providerYAML)+"\n"), 0o644); err != nil { + t.Fatalf("write provider.yaml: %v", err) + } + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate model id rejection, got %v", err) + } +} + +func TestLoaderParsesAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +context: + auto_compact: + enabled: true + input_token_threshold: 0 + reserve_tokens: 9000 + fallback_input_token_threshold: 88000 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.Context.AutoCompact.InputTokenThreshold != 0 { + t.Fatalf("expected implicit threshold 0, got %d", cfg.Context.AutoCompact.InputTokenThreshold) + } + if cfg.Context.AutoCompact.ReserveTokens != 9000 { + t.Fatalf("expected reserve_tokens=9000, got %d", cfg.Context.AutoCompact.ReserveTokens) + } + if cfg.Context.AutoCompact.FallbackInputTokenThreshold != 88000 { + t.Fatalf("expected fallback_input_token_threshold=88000, got %d", cfg.Context.AutoCompact.FallbackInputTokenThreshold) + } +} + +func TestLoaderSavePersistsAutoCompactDerivedFields(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 9000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + if err := loader.Save(context.Background(), &cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("read config: %v", err) + } + text := string(data) + if strings.Contains(text, "input_token_threshold: 100000") { + t.Fatalf("expected implicit threshold to avoid legacy default, got:\n%s", text) + } + if !strings.Contains(text, "reserve_tokens: 9000") { + t.Fatalf("expected reserve_tokens to persist, got:\n%s", text) + } + if !strings.Contains(text, "fallback_input_token_threshold: 88000") { + t.Fatalf("expected fallback_input_token_threshold to persist, got:\n%s", text) + } +} + func TestLoaderRejectsCustomProviderNameConflictingWithBuiltin(t *testing.T) { t.Parallel() diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index f7a2b2ac..43448ea6 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -10,6 +10,7 @@ import ( "gopkg.in/yaml.v3" "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" ) const ( @@ -22,11 +23,19 @@ type customProviderFile struct { Driver string `yaml:"driver"` APIKeyEnv string `yaml:"api_key_env"` BaseURL string `yaml:"base_url,omitempty"` + Models []customProviderModelFile `yaml:"models,omitempty"` OpenAICompatible customOpenAICompatibleFile `yaml:"openai_compatible,omitempty"` Gemini customGeminiProviderFile `yaml:"gemini,omitempty"` Anthropic customAnthropicProviderFile `yaml:"anthropic,omitempty"` } +type customProviderModelFile struct { + ID string `yaml:"id"` + Name string `yaml:"name,omitempty"` + ContextWindow *int `yaml:"context_window,omitempty"` + MaxOutputTokens *int `yaml:"max_output_tokens,omitempty"` +} + type customOpenAICompatibleFile struct { BaseURL string `yaml:"base_url"` APIStyle string `yaml:"api_style,omitempty"` @@ -106,6 +115,11 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { } settings := resolveCustomProviderSettings(file) + models, err := customProviderModels(file.Models) + if err != nil { + return ProviderConfig{}, fmt.Errorf("config: custom provider %q: %w", filepath.Base(providerDir), err) + } + cfg := ProviderConfig{ Name: strings.TrimSpace(file.Name), Driver: strings.TrimSpace(file.Driver), @@ -114,6 +128,7 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { APIStyle: settings.APIStyle, DeploymentMode: settings.DeploymentMode, APIVersion: settings.APIVersion, + Models: models, Source: ProviderSourceCustom, } @@ -128,6 +143,48 @@ func loadCustomProvider(providerDir string) (ProviderConfig, error) { return cfg, nil } +// customProviderModels 校验并收敛 custom provider.yaml 中声明的模型元数据。 +func customProviderModels(models []customProviderModelFile) ([]providertypes.ModelDescriptor, error) { + if len(models) == 0 { + return nil, nil + } + + descriptors := make([]providertypes.ModelDescriptor, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for index, model := range models { + id := strings.TrimSpace(model.ID) + if id == "" { + return nil, fmt.Errorf("models[%d].id is empty", index) + } + + key := provider.NormalizeKey(id) + if _, exists := seen[key]; exists { + return nil, fmt.Errorf("models[%d].id %q is duplicated", index, id) + } + seen[key] = struct{}{} + + descriptor := providertypes.ModelDescriptor{ + ID: id, + Name: strings.TrimSpace(model.Name), + } + if model.ContextWindow != nil { + if *model.ContextWindow <= 0 { + return nil, fmt.Errorf("models[%d].context_window must be greater than 0", index) + } + descriptor.ContextWindow = *model.ContextWindow + } + if model.MaxOutputTokens != nil { + if *model.MaxOutputTokens <= 0 { + return nil, fmt.Errorf("models[%d].max_output_tokens must be greater than 0", index) + } + descriptor.MaxOutputTokens = *model.MaxOutputTokens + } + descriptors = append(descriptors, descriptor) + } + + return providertypes.MergeModelDescriptors(descriptors), nil +} + // resolveCustomProviderSettings 根据 driver 只提取当前协议真正生效的配置字段,避免误吃其他协议块的值。 // 已知 driver 仅从协议块读取 base_url;未知 driver 使用顶层 base_url 作为唯一入口。 func resolveCustomProviderSettings(file customProviderFile) customProviderSettings { diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 22525bb5..f364b7ce 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -194,6 +194,78 @@ func TestCloneProviderConfigModelDescriptorsIndependence(t *testing.T) { } } +func TestCustomProviderModelsParsesSupportedMetadata(t *testing.T) { + t.Parallel() + + contextWindow := 131072 + maxOutputTokens := 8192 + models, err := customProviderModels([]customProviderModelFile{ + { + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: &contextWindow, + MaxOutputTokens: &maxOutputTokens, + }, + }) + if err != nil { + t.Fatalf("customProviderModels() error = %v", err) + } + + if len(models) != 1 { + t.Fatalf("expected one parsed model, got %+v", models) + } + if models[0].ID != "deepseek-coder" || models[0].ContextWindow != 131072 || models[0].MaxOutputTokens != 8192 { + t.Fatalf("unexpected parsed model descriptor: %+v", models[0]) + } +} + +func TestCustomProviderModelsRejectsEmptyID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{{Name: "Missing ID"}}) + if err == nil || !strings.Contains(err.Error(), "models[0].id") { + t.Fatalf("expected empty id validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveContextWindow(t *testing.T) { + t.Parallel() + + contextWindow := 0 + _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + ContextWindow: &contextWindow, + }}) + if err == nil || !strings.Contains(err.Error(), "context_window") { + t.Fatalf("expected context_window validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsNonPositiveMaxOutputTokens(t *testing.T) { + t.Parallel() + + maxOutputTokens := 0 + _, err := customProviderModels([]customProviderModelFile{{ + ID: "deepseek-coder", + MaxOutputTokens: &maxOutputTokens, + }}) + if err == nil || !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("expected max_output_tokens validation error, got %v", err) + } +} + +func TestCustomProviderModelsRejectsDuplicateID(t *testing.T) { + t.Parallel() + + _, err := customProviderModels([]customProviderModelFile{ + {ID: "deepseek-coder"}, + {ID: " DeepSeek-Coder "}, + }) + if err == nil || !strings.Contains(err.Error(), "duplicated") { + t.Fatalf("expected duplicate id validation error, got %v", err) + } +} + func TestProviderByNameCaseInsensitive(t *testing.T) { t.Parallel() diff --git a/internal/config/state/auto_compact_threshold.go b/internal/config/state/auto_compact_threshold.go new file mode 100644 index 00000000..490a2d54 --- /dev/null +++ b/internal/config/state/auto_compact_threshold.go @@ -0,0 +1,90 @@ +package state + +import ( + "context" + "strings" + + "neo-code/internal/config" + "neo-code/internal/provider" +) + +// AutoCompactThresholdSource 标识自动压缩阈值最终采用的来源。 +type AutoCompactThresholdSource string + +const ( + AutoCompactThresholdSourceDisabled AutoCompactThresholdSource = "disabled" + AutoCompactThresholdSourceExplicit AutoCompactThresholdSource = "explicit" + AutoCompactThresholdSourceDerived AutoCompactThresholdSource = "derived" + AutoCompactThresholdSourceFallback AutoCompactThresholdSource = "fallback" +) + +// AutoCompactThresholdResolution 描述自动压缩阈值的解析结果,供 runtime 直接消费。 +type AutoCompactThresholdResolution struct { + Threshold int + Source AutoCompactThresholdSource + ContextWindow int + ModelID string +} + +// fallbackAutoCompactThresholdResolution 构造自动推导失败时使用的保底阈值结果。 +func fallbackAutoCompactThresholdResolution(cfg config.Config) AutoCompactThresholdResolution { + return AutoCompactThresholdResolution{ + Threshold: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + Source: AutoCompactThresholdSourceFallback, + ModelID: strings.TrimSpace(cfg.CurrentModel), + } +} + +// ResolveAutoCompactThreshold 基于当前选择的 provider/model 和模型目录快照解析最终阈值。 +func ResolveAutoCompactThreshold( + ctx context.Context, + cfg config.Config, + catalogs ModelCatalog, +) (AutoCompactThresholdResolution, error) { + autoCompact := cfg.Context.AutoCompact + if !autoCompact.Enabled { + return AutoCompactThresholdResolution{Source: AutoCompactThresholdSourceDisabled}, nil + } + + if autoCompact.InputTokenThreshold > 0 { + return AutoCompactThresholdResolution{ + Threshold: autoCompact.InputTokenThreshold, + Source: AutoCompactThresholdSourceExplicit, + ModelID: strings.TrimSpace(cfg.CurrentModel), + }, nil + } + + resolution := fallbackAutoCompactThresholdResolution(cfg) + providerCfg, err := selectedProviderConfig(cfg) + if err != nil { + return resolution, nil + } + if catalogs == nil { + return resolution, nil + } + + input, err := catalogInputFromProvider(providerCfg) + if err != nil { + return resolution, nil + } + + models, err := catalogs.ListProviderModelsSnapshot(ctx, input) + if err != nil { + return resolution, nil + } + + modelID := provider.NormalizeKey(cfg.CurrentModel) + for _, model := range models { + if provider.NormalizeKey(model.ID) != modelID { + continue + } + resolution.ContextWindow = model.ContextWindow + if model.ContextWindow > autoCompact.ReserveTokens { + resolution.Threshold = model.ContextWindow - autoCompact.ReserveTokens + resolution.Source = AutoCompactThresholdSourceDerived + } + return resolution, nil + } + + return resolution, nil +} diff --git a/internal/config/state/auto_compact_threshold_test.go b/internal/config/state/auto_compact_threshold_test.go new file mode 100644 index 00000000..24bd9efc --- /dev/null +++ b/internal/config/state/auto_compact_threshold_test.go @@ -0,0 +1,166 @@ +package state + +import ( + "context" + "errors" + "testing" + + configpkg "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" +) + +func TestResolveAutoCompactThresholdDisabled(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = false + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 0 || resolution.Source != AutoCompactThresholdSourceDisabled { + t.Fatalf("expected disabled resolution, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdExplicitWins(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 42000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 42000 || resolution.Source != AutoCompactThresholdSourceExplicit { + t.Fatalf("expected explicit resolution, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdDerivedFromContextWindow(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.CurrentModel = "deepseek-coder" + cfg.Providers[0].Model = "deepseek-coder" + cfg.Providers[0].Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + ContextWindow: 131072, + }} + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: cfg.Providers[0].Models, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 118072 || resolution.Source != AutoCompactThresholdSourceDerived { + t.Fatalf("expected derived threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenWindowTooSmall(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.ReserveTokens = 13000 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "small-model" + cfg.Providers[0].Model = "small-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ + ID: "small-model", + ContextWindow: 8000, + }}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected fallback threshold, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenModelMissing(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.CurrentModel = "missing-model" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ID: "other-model", ContextWindow: 131072}}, + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected missing model to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenSelectedProviderInvalid(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.SelectedProvider = "missing-provider" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid selection to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenCatalogInputResolutionFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.Providers[0].BaseURL = "" + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected invalid catalog input to use fallback, got %+v", resolution) + } +} + +func TestResolveAutoCompactThresholdFallsBackWhenSnapshotLookupFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.AutoCompact.Enabled = true + cfg.Context.AutoCompact.InputTokenThreshold = 0 + cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + + resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ + snapshotErr: errors.New("snapshot failed"), + }) + if err != nil { + t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) + } + if resolution.Threshold != 88000 || resolution.Source != AutoCompactThresholdSourceFallback { + t.Fatalf("expected snapshot error to use fallback, got %+v", resolution) + } +} diff --git a/internal/gateway/handlers/wake.go b/internal/gateway/handlers/wake.go index db7d7f40..2d72ce2d 100644 --- a/internal/gateway/handlers/wake.go +++ b/internal/gateway/handlers/wake.go @@ -97,7 +97,7 @@ func isSafeReviewPath(path string) bool { if hasBlockedWindowsPathPrefix(trimmed) { return false } - if filepath.IsAbs(trimmed) || strings.HasPrefix(trimmed, "/") || strings.HasPrefix(trimmed, "\\") { + if isAbsoluteReviewPath(trimmed) { return false } if containsParentTraversalSegment(trimmed) { @@ -113,7 +113,7 @@ func isSafeReviewPath(path string) bool { return true } -// hasWindowsDriveLetterPrefix 检查是否为 Windows 盘符前缀路径(如 C:foo),避免平台差异导致漏拦截。 +// hasWindowsDriveLetterPrefix 检查是否为 Windows 盘符前缀路径,如 C:foo。 func hasWindowsDriveLetterPrefix(path string) bool { trimmed := strings.TrimSpace(path) if len(trimmed) < 2 { @@ -123,12 +123,21 @@ func hasWindowsDriveLetterPrefix(path string) bool { return ((drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z')) && trimmed[1] == ':' } -// hasBlockedWindowsPathPrefix 检查是否命中 Windows 底层设备路径前缀,避免绕过常规路径校验。 +// hasBlockedWindowsPathPrefix 检查是否命中 Windows 设备路径前缀,避免绕过常规路径校验。 func hasBlockedWindowsPathPrefix(path string) bool { normalized := strings.ReplaceAll(strings.TrimSpace(path), "/", "\\") return strings.HasPrefix(normalized, `\\?\`) || strings.HasPrefix(normalized, `\\.\`) } +// isAbsoluteReviewPath 统一识别跨平台绝对路径,避免在 Windows 下漏判 Unix 风格前导斜杠。 +func isAbsoluteReviewPath(path string) bool { + if filepath.IsAbs(path) { + return true + } + normalized := normalizePath(path) + return strings.HasPrefix(normalized, "/") || strings.HasPrefix(normalized, "\\") +} + // containsParentTraversalSegment 按路径段语义识别目录回退段,避免子串匹配导致误伤。 func containsParentTraversalSegment(path string) bool { normalized := normalizePath(path) diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index df826fc7..3488dd3a 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -103,6 +103,34 @@ func TestListProviderModelsMergesConfiguredMetadataAfterDiscovery(t *testing.T) } } +func TestListProviderModelsUsesConfiguredContextWindowWhenDiscoveryMissesIt(t *testing.T) { + t.Setenv(testAPIKeyEnv, "test-key") + + registry := newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + return []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: "Server DeepSeek", + ContextWindow: 0, + }}, nil + }) + + service := NewService("", registry, newMemoryStore()) + providerCfg := customGatewayProvider() + providerCfg.Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + Name: "DeepSeek Coder", + ContextWindow: 131072, + }} + + models, err := service.ListProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) + if err != nil { + t.Fatalf("ListProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ContextWindow != 131072 { + t.Fatalf("expected configured context window to fill discovery gap, got %+v", models) + } +} + func TestListProviderModelsSnapshotReturnsDefaultAndRefreshesInBackgroundOnMiss(t *testing.T) { t.Setenv(testAPIKeyEnv, "test-key") diff --git a/internal/runtime/run.go b/internal/runtime/run.go index e624dd20..93415b36 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -184,7 +184,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - AutoCompactThreshold: autoCompactThreshold(cfg), + AutoCompactThreshold: s.autoCompactThreshold(ctx, cfg), MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, @@ -350,11 +350,20 @@ func (s *Service) applyCompactForState( } // autoCompactThreshold 返回当前配置下的自动 compact 触发阈值。 -func autoCompactThreshold(cfg config.Config) int { - if cfg.Context.AutoCompact.Enabled && cfg.Context.AutoCompact.InputTokenThreshold > 0 { +func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) int { + if !cfg.Context.AutoCompact.Enabled { + return 0 + } + if cfg.Context.AutoCompact.InputTokenThreshold > 0 { return cfg.Context.AutoCompact.InputTokenThreshold } - return 0 + if s != nil && s.autoCompactThresholdResolver != nil { + threshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) + if err == nil { + return threshold + } + } + return fallbackAutoCompactThreshold(cfg) } // degradeKeepRecentMessages 根据 reactive compact 尝试次数逐步减少保留消息数。 @@ -367,3 +376,11 @@ func degradeKeepRecentMessages(base int, attempt int) int { } return base } + +// fallbackAutoCompactThreshold 返回自动推导失败时仍可继续使用的保底阈值。 +func fallbackAutoCompactThreshold(cfg config.Config) int { + if cfg.Context.AutoCompact.FallbackInputTokenThreshold > 0 { + return cfg.Context.AutoCompact.FallbackInputTokenThreshold + } + return 0 +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index ca619ffc..ce5546a4 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -60,16 +60,21 @@ type MemoExtractor interface { } // Service 是 runtime 的默认实现,负责组织一次完整的 agent 运行闭环。 +type AutoCompactThresholdResolver interface { + ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) +} + type Service struct { - configManager *config.Manager - sessionStore agentsession.Store - toolManager tools.Manager - providerFactory ProviderFactory - contextBuilder agentcontext.Builder - compactRunner contextcompact.Runner - approvalBroker *approval.Broker - memoExtractor MemoExtractor - skillsRegistry skills.Registry + configManager *config.Manager + sessionStore agentsession.Store + toolManager tools.Manager + providerFactory ProviderFactory + contextBuilder agentcontext.Builder + compactRunner contextcompact.Runner + approvalBroker *approval.Broker + memoExtractor MemoExtractor + skillsRegistry skills.Registry + autoCompactThresholdResolver AutoCompactThresholdResolver events chan RuntimeEvent sessionMu sync.Mutex @@ -171,3 +176,8 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess } return session, nil } + +// SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。 +func (s *Service) SetAutoCompactThresholdResolver(resolver AutoCompactThresholdResolver) { + s.autoCompactThresholdResolver = resolver +} diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 43c43114..8ad229d7 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -39,6 +39,12 @@ type failingStore struct { ignoreContextErr bool } +type autoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) + +func (f autoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { + return f(ctx, cfg) +} + func newMemoryStore() *memoryStore { return &memoryStore{sessions: map[string]agentsession.Session{}} } @@ -4186,6 +4192,7 @@ func TestRestoreSessionTokensNewSession(t *testing.T) { func TestAutoCompactThresholdEnabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4195,7 +4202,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 50000 { t.Fatalf("expected threshold == 50000, got %d", threshold) } @@ -4204,6 +4211,7 @@ func TestAutoCompactThresholdEnabled(t *testing.T) { func TestAutoCompactThresholdDisabled(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4213,7 +4221,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", threshold) } @@ -4222,6 +4230,7 @@ func TestAutoCompactThresholdDisabled(t *testing.T) { func TestAutoCompactThresholdZeroValue(t *testing.T) { t.Parallel() + service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ AutoCompact: config.AutoCompactConfig{ @@ -4231,12 +4240,83 @@ func TestAutoCompactThresholdZeroValue(t *testing.T) { }, } - threshold := autoCompactThreshold(cfg) + threshold := service.autoCompactThreshold(context.Background(), cfg) if threshold != 0 { t.Fatalf("expected threshold == 0, got %d", threshold) } } +func TestAutoCompactThresholdUsesResolver(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 88000, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected resolver threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverErrors(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, errors.New("resolver failed") + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing.T) { + t.Parallel() + + service := &Service{} + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected implicit mode fallback threshold == 88000, got %d", threshold) + } +} + func TestTokenUsageRecordedOnMessageDone(t *testing.T) { t.Parallel() From 3055d8f46053584eb0408640bf85b96a2e8dccdb Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 14:24:51 +0000 Subject: [PATCH 31/33] fix(runtime): guard non-positive resolved auto-compact threshold - fallback when resolver returns zero/negative without error - add regression tests for zero/negative resolver outputs - align compact/config docs with implicit-threshold behavior Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Yumiue <188874804+Yumiue@users.noreply.github.com> --- docs/context-compact.md | 2 +- docs/guides/configuration.md | 2 ++ internal/runtime/run.go | 2 +- internal/runtime/runtime_test.go | 52 ++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/docs/context-compact.md b/docs/context-compact.md index af991d91..86c20a63 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -47,7 +47,7 @@ context: - `auto_compact.enabled` 控制是否启用基于 token 阈值的自动压缩;默认关闭。 - `auto_compact.input_token_threshold` - 当会话累计输入 token 数达到此阈值时触发自动压缩;默认 100000。 + 当会话累计输入 token 数达到此阈值时触发自动压缩;默认 `0`(自动推导),推导失败时回退到 `fallback_input_token_threshold`(默认 `100000`)。 ## 自动压缩 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 4a0191e1..c03f3209 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -83,6 +83,8 @@ context: | `context.compact.micro_compact_disabled` | 是否关闭默认启用的 micro compact | | `context.auto_compact.enabled` | 是否启用自动压缩 | | `context.auto_compact.input_token_threshold` | 自动压缩输入 token 阈值 | +| `context.auto_compact.reserve_tokens` | 自动阈值推导时预留 token 缓冲(`resolved_threshold = context_window - reserve_tokens`) | +| `context.auto_compact.fallback_input_token_threshold` | 自动推导失败时使用的保底阈值 | ### `runtime` 字段 diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 93415b36..031eb092 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -359,7 +359,7 @@ func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) i } if s != nil && s.autoCompactThresholdResolver != nil { threshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) - if err == nil { + if err == nil && threshold > 0 { return threshold } } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 8ad229d7..a67f70de 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4297,6 +4297,58 @@ func TestAutoCompactThresholdFallsBackWhenResolverErrors(t *testing.T) { } } +func TestAutoCompactThresholdFallsBackWhenResolverReturnsZeroWithoutError(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return 0, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + +func TestAutoCompactThresholdFallsBackWhenResolverReturnsNegativeWithoutError(t *testing.T) { + t.Parallel() + + service := &Service{} + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + return -1, nil + }, + )) + + cfg := config.Config{ + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + FallbackInputTokenThreshold: 88000, + }, + }, + } + + threshold := service.autoCompactThreshold(context.Background(), cfg) + if threshold != 88000 { + t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + } +} + func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing.T) { t.Parallel() From afce7e3b43f410661b4989fda4e5beb8c5d4575e Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 14:38:35 +0000 Subject: [PATCH 32/33] fix(runtime): align repeat-cycle breaker semantics and tests Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/runtime/controlplane/progress.go | 32 ++-- .../runtime/controlplane/progress_test.go | 71 ++++++--- internal/runtime/runtime_progress_test.go | 141 +++++++++++++----- 3 files changed, 174 insertions(+), 70 deletions(-) diff --git a/internal/runtime/controlplane/progress.go b/internal/runtime/controlplane/progress.go index f6d35164..784496ce 100644 --- a/internal/runtime/controlplane/progress.go +++ b/internal/runtime/controlplane/progress.go @@ -30,27 +30,29 @@ type ProgressState struct { // ApplyProgressEvidence 根据证据更新分值与 streak。 func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord, currentSignature string) ProgressState { next := state.LastScore - isRepeated := false - - if len(records) > 0 { - if currentSignature != "" && currentSignature == state.LastSignature { - isRepeated = true + hasToolAttempt := currentSignature != "" + isRepeated := hasToolAttempt && state.LastSignature != "" && currentSignature == state.LastSignature + + if hasToolAttempt { + if isRepeated { + next.RepeatCycleStreak++ + } else { + next.RepeatCycleStreak = 1 } + } else { + next.RepeatCycleStreak = 0 } - nextSignature := currentSignature + nextSignature := "" + if hasToolAttempt { + nextSignature = currentSignature + } - if len(records) == 0 { - next.NoProgressStreak++ - next.RepeatCycleStreak = 0 - nextSignature = "" // Clear signature on failure to only count consecutive successes - } else if isRepeated { - next.NoProgressStreak++ - next.RepeatCycleStreak++ - } else { + if len(records) > 0 && !isRepeated { next.NoProgressStreak = 0 - next.RepeatCycleStreak = 0 next.ScoreDelta++ + } else { + next.NoProgressStreak++ } return ProgressState{ diff --git a/internal/runtime/controlplane/progress_test.go b/internal/runtime/controlplane/progress_test.go index c51a5a56..f457a0be 100644 --- a/internal/runtime/controlplane/progress_test.go +++ b/internal/runtime/controlplane/progress_test.go @@ -4,10 +4,15 @@ import "testing" func TestApplyProgressEvidenceNoEvidenceIncrementsNoProgress(t *testing.T) { t.Parallel() - state := ProgressState{} - next := ApplyProgressEvidence(state, nil, "") - if next.LastScore.NoProgressStreak != 1 { - t.Fatalf("expected no_progress_streak 1, got %d", next.LastScore.NoProgressStreak) + got := ApplyProgressEvidence(ProgressState{}, nil, "") + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 1, + RepeatCycleStreak: 0, + }, + } + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } @@ -16,14 +21,19 @@ func TestApplyProgressEvidenceOnlyNonDupResetsNoProgressStreak(t *testing.T) { state := ProgressState{ LastScore: ProgressScore{NoProgressStreak: 3}, } - next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, }, "sig1") - if next.LastScore.NoProgressStreak != 0 { - t.Fatalf("expected streak reset to 0, got %d", next.LastScore.NoProgressStreak) + want := ProgressState{ + LastScore: ProgressScore{ + ScoreDelta: 1, + NoProgressStreak: 0, + RepeatCycleStreak: 1, + }, + LastSignature: "sig1", } - if next.LastScore.ScoreDelta != 1 { - t.Fatalf("expected score_delta 1, got %d", next.LastScore.ScoreDelta) + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } @@ -32,31 +42,52 @@ func TestApplyProgressEvidenceMixedResetsNoProgress(t *testing.T) { state := ProgressState{ LastScore: ProgressScore{NoProgressStreak: 2}, } - next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, {Kind: ProgressEvidenceKind("other_evidence")}, }, "sig1") - if next.LastScore.NoProgressStreak != 0 { - t.Fatalf("expected streak reset, got %d", next.LastScore.NoProgressStreak) + if got.LastScore.NoProgressStreak != 0 { + t.Fatalf("expected streak reset, got %d", got.LastScore.NoProgressStreak) } } func TestApplyProgressEvidenceRepeatCycle(t *testing.T) { t.Parallel() state := ProgressState{ - LastScore: ProgressScore{NoProgressStreak: 1, RepeatCycleStreak: 1}, + LastScore: ProgressScore{NoProgressStreak: 1, RepeatCycleStreak: 2}, LastSignature: "sig1", } - next := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ + got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ {Kind: EvidenceNewInfoNonDup}, }, "sig1") - if next.LastScore.NoProgressStreak != 2 { - t.Fatalf("expected no_progress_streak 2, got %d", next.LastScore.NoProgressStreak) + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 2, + RepeatCycleStreak: 3, + }, + LastSignature: "sig1", + } + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) + } +} + +func TestApplyProgressEvidenceRepeatCycleOnFailureKeepsSignatureTracking(t *testing.T) { + t.Parallel() + state := ProgressState{ + LastScore: ProgressScore{NoProgressStreak: 2, RepeatCycleStreak: 1}, + LastSignature: "sig1", } - if next.LastScore.RepeatCycleStreak != 2 { - t.Fatalf("expected repeat_cycle_streak 2, got %d", next.LastScore.RepeatCycleStreak) + + got := ApplyProgressEvidence(state, nil, "sig1") + want := ProgressState{ + LastScore: ProgressScore{ + NoProgressStreak: 3, + RepeatCycleStreak: 2, + }, + LastSignature: "sig1", } - if next.LastSignature != "sig1" { - t.Fatalf("expected signature sig1, got %s", next.LastSignature) + if got != want { + t.Fatalf("expected %+v, got %+v", want, got) } } diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index 5481f6e6..89ec5189 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -3,6 +3,7 @@ package runtime import ( "context" "errors" + "strconv" "strings" "sync/atomic" "testing" @@ -34,15 +35,21 @@ func TestProgressStreakStopsRun(t *testing.T) { } var promptInjected bool + var signatureSeq int32 providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + seq := atomic.AddInt32(&signatureSeq, 1) if strings.Contains(req.SystemPrompt, selfHealingReminder) { promptInjected = true } // the model always decides to call the tool events <- providertypes.NewToolCallStartStreamEvent(0, "call_err", "tool_error") - events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_err", "{}") + events <- providertypes.NewToolCallDeltaStreamEvent( + 0, + "call_err", + `{"seq":`+strconv.FormatInt(int64(seq), 10)+`}`, + ) events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) return nil }, @@ -76,19 +83,7 @@ func TestProgressStreakStopsRun(t *testing.T) { events := collectRuntimeEvents(service.Events()) // Verify StopReason is error and specifies the streak limit - assertEventContains(t, events, EventStopReasonDecided) - - for _, e := range events { - if e.Type == EventStopReasonDecided { - payload := e.Payload.(StopReasonDecidedPayload) - if payload.Reason != controlplane.StopReasonError { - t.Errorf("expected StopReasonError, got %s", payload.Reason) - } - if payload.Detail != ErrNoProgressStreakLimit.Error() { - t.Errorf("expected detail to be %q, got %q", ErrNoProgressStreakLimit.Error(), payload.Detail) - } - } - } + assertStopReasonDecided(t, events, controlplane.StopReasonError, ErrNoProgressStreakLimit.Error()) if !promptInjected { t.Error("expected self-healing prompt to be injected before streak limit is reached, but it wasn't") @@ -119,13 +114,19 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } var providerCalls int32 + var signatureSeq int32 providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { call := int(atomic.AddInt32(&providerCalls, 1)) if call <= 4 { + seq := atomic.AddInt32(&signatureSeq, 1) events <- providertypes.NewToolCallStartStreamEvent(0, "call_mixed", "tool_mixed") - events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_mixed", "{}") + events <- providertypes.NewToolCallDeltaStreamEvent( + 0, + "call_mixed", + `{"seq":`+strconv.FormatInt(int64(seq), 10)+`}`, + ) events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) return nil } @@ -161,14 +162,7 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } events := collectRuntimeEvents(service.Events()) - for _, e := range events { - if e.Type == EventStopReasonDecided { - payload := e.Payload.(StopReasonDecidedPayload) - if payload.Reason != controlplane.StopReasonSuccess { - t.Fatalf("expected stop reason success, got %s", payload.Reason) - } - } - } + assertStopReasonDecided(t, events, controlplane.StopReasonSuccess, "") } func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { @@ -184,11 +178,14 @@ func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { }, } + var executeCalls int32 + var providerCalls int32 toolManager := &stubToolManager{ specs: []providertypes.ToolSpec{ {Name: "tool_repeat"}, }, executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + atomic.AddInt32(&executeCalls, 1) return tools.ToolResult{Name: input.Name, Content: "ok", IsError: false}, nil }, } @@ -197,6 +194,7 @@ func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { providerFactory := &scriptedProviderFactory{ provider: &scriptedProvider{ chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + atomic.AddInt32(&providerCalls, 1) if strings.Contains(req.SystemPrompt, selfHealingRepeatReminder) { promptInjected = true } @@ -230,22 +228,78 @@ func TestRepeatCycleStreakStopsRunAndInjectsReminder(t *testing.T) { events := collectRuntimeEvents(service.Events()) - assertEventContains(t, events, EventStopReasonDecided) - for _, e := range events { - if e.Type == EventStopReasonDecided { - payload := e.Payload.(StopReasonDecidedPayload) - if payload.Reason != controlplane.StopReasonError { - t.Errorf("expected StopReasonError, got %s", payload.Reason) - } - if payload.Detail != ErrRepeatCycleLimit.Error() { - t.Errorf("expected detail to be %q, got %q", ErrRepeatCycleLimit.Error(), payload.Detail) - } - } - } + assertStopReasonDecided(t, events, controlplane.StopReasonError, ErrRepeatCycleLimit.Error()) if !promptInjected { t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") } + if executeCalls != 3 { + t.Fatalf("expected break on the 3rd identical tool execution, got %d", executeCalls) + } + if providerCalls != 3 { + t.Fatalf("expected 3 provider turns before repeat breaker, got %d", providerCalls) + } +} + +func TestRepeatCycleStreakCountsFailedToolCalls(t *testing.T) { + t.Setenv("TEST_KEY", "dummy") + + cfg := config.Config{ + Providers: []config.ProviderConfig{{Name: "test-repeat-fail", Driver: "test", BaseURL: "http://localhost", Model: "test", APIKeyEnv: "TEST_KEY"}}, + SelectedProvider: "test-repeat-fail", + Workdir: t.TempDir(), + Runtime: config.RuntimeConfig{ + MaxNoProgressStreak: 10, + MaxRepeatCycleStreak: 3, + }, + } + + var executeCalls int32 + toolManager := &stubToolManager{ + specs: []providertypes.ToolSpec{ + {Name: "tool_repeat_fail"}, + }, + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + atomic.AddInt32(&executeCalls, 1) + return tools.ToolResult{Name: input.Name, Content: "error", IsError: true}, nil + }, + } + + var providerCalls int32 + providerFactory := &scriptedProviderFactory{ + provider: &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + atomic.AddInt32(&providerCalls, 1) + events <- providertypes.NewToolCallStartStreamEvent(0, "call_repeat_fail", "tool_repeat_fail") + events <- providertypes.NewToolCallDeltaStreamEvent(0, "call_repeat_fail", `{"path":"x"}`) + events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil) + return nil + }, + }, + } + + manager := config.NewManager(config.NewLoader(t.TempDir(), &cfg)) + service := NewWithFactory( + manager, + toolManager, + newMemoryStore(), + providerFactory, + nil, + ) + + err := service.Run(context.Background(), UserInput{ + RunID: "run-repeat-fail-streak", + Content: "trigger repeat fail loop", + }) + if !errors.Is(err, ErrRepeatCycleLimit) { + t.Fatalf("expected ErrRepeatCycleLimit, got %v", err) + } + if executeCalls != 3 { + t.Fatalf("expected failed repeated calls to break on the 3rd execution, got %d", executeCalls) + } + if providerCalls != 3 { + t.Fatalf("expected 3 provider turns before repeat breaker, got %d", providerCalls) + } } func TestComputeToolSignatureNormalizationAndFallback(t *testing.T) { @@ -360,3 +414,20 @@ func TestResolveStreakLimitDefaults(t *testing.T) { t.Fatalf("expected explicit repeat limit 6, got %d", got) } } + +func assertStopReasonDecided(t *testing.T, events []RuntimeEvent, wantReason controlplane.StopReason, wantDetail string) { + t.Helper() + assertEventContains(t, events, EventStopReasonDecided) + for _, e := range events { + if e.Type != EventStopReasonDecided { + continue + } + payload := e.Payload.(StopReasonDecidedPayload) + if payload.Reason != wantReason { + t.Fatalf("expected stop reason %s, got %s", wantReason, payload.Reason) + } + if wantDetail != "" && payload.Detail != wantDetail { + t.Fatalf("expected detail to be %q, got %q", wantDetail, payload.Detail) + } + } +} From 047bb8b11d03e6e19bb10079189dd0e192ae5049 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 15 Apr 2026 15:24:28 +0000 Subject: [PATCH 33/33] fix(runtime): cache auto-compact threshold within run - cache derived auto compact threshold per run using provider/model/auto_compact key - recompute only when key changes to avoid repeated resolver snapshot lookups in hot path - simplify run flow by extracting session lock binding and self-healing prompt injection helpers - add runtime tests for cache hit and key-change recompute behavior Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: Yumiue <188874804+Yumiue@users.noreply.github.com> --- internal/runtime/run.go | 96 +++++++++++++++++++++----------- internal/runtime/runtime_test.go | 78 ++++++++++++++++++++++++++ internal/runtime/state.go | 18 ++++++ 3 files changed, 161 insertions(+), 31 deletions(-) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 031eb092..df37f1bd 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -46,32 +46,18 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { initialCfg := s.configManager.Get() sessionID := strings.TrimSpace(input.SessionID) - releaseSessionLock := func() {} + releaseSessionLock := s.bindSessionLock(sessionID) defer func() { releaseSessionLock() }() - if sessionID != "" { - sessionMu, releaseLockRef := s.acquireSessionLock(sessionID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } - } - session, err := s.loadOrCreateSession(ctx, input.SessionID, input.Content, initialCfg.Workdir, input.Workdir) if err != nil { return s.handleRunError(ctx, input.RunID, input.SessionID, err) } if sessionID == "" { - sessionMu, releaseLockRef := s.acquireSessionLock(session.ID) - sessionMu.Lock() - releaseSessionLock = func() { - sessionMu.Unlock() - releaseLockRef() - } + releaseSessionLock = s.bindSessionLock(session.ID) } state := newRunState(input.RunID, session) @@ -184,7 +170,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - AutoCompactThreshold: s.autoCompactThreshold(ctx, cfg), + AutoCompactThreshold: s.autoCompactThresholdForState(ctx, cfg, state), MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, @@ -220,16 +206,7 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur state.mu.Unlock() limit := resolveNoProgressStreakLimit(cfg.Runtime) - systemPrompt := builtContext.SystemPrompt - - if streak == limit-1 { - trimmed := strings.TrimSpace(systemPrompt) - if trimmed == "" { - systemPrompt = selfHealingReminder - } else { - systemPrompt = trimmed + "\n\n" + selfHealingReminder - } - } + systemPrompt := withSelfHealingReminder(builtContext.SystemPrompt, streak, limit) model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ @@ -351,19 +328,38 @@ func (s *Service) applyCompactForState( // autoCompactThreshold 返回当前配置下的自动 compact 触发阈值。 func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) int { + return s.autoCompactThresholdForState(ctx, cfg, nil) +} + +// autoCompactThresholdForState 返回当前配置下的自动 compact 触发阈值,并在单次 run 内按关键输入缓存结果。 +func (s *Service) autoCompactThresholdForState(ctx context.Context, cfg config.Config, state *runState) int { if !cfg.Context.AutoCompact.Enabled { return 0 } if cfg.Context.AutoCompact.InputTokenThreshold > 0 { return cfg.Context.AutoCompact.InputTokenThreshold } + + key := autoCompactCacheKeyFromConfig(cfg) + if state != nil && state.autoCompactCache.valid && state.autoCompactCache.key == key { + return state.autoCompactCache.threshold + } + + threshold := fallbackAutoCompactThreshold(cfg) if s != nil && s.autoCompactThresholdResolver != nil { - threshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) - if err == nil && threshold > 0 { - return threshold + resolvedThreshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) + if err == nil && resolvedThreshold > 0 { + threshold = resolvedThreshold + } + } + if state != nil { + state.autoCompactCache = autoCompactThresholdCache{ + key: key, + threshold: threshold, + valid: true, } } - return fallbackAutoCompactThreshold(cfg) + return threshold } // degradeKeepRecentMessages 根据 reactive compact 尝试次数逐步减少保留消息数。 @@ -384,3 +380,41 @@ func fallbackAutoCompactThreshold(cfg config.Config) int { } return 0 } + +// bindSessionLock 获取并持有指定会话锁,返回对应的释放函数。 +func (s *Service) bindSessionLock(sessionID string) func() { + id := strings.TrimSpace(sessionID) + if id == "" { + return func() {} + } + sessionMu, releaseLockRef := s.acquireSessionLock(id) + sessionMu.Lock() + return func() { + sessionMu.Unlock() + releaseLockRef() + } +} + +// withSelfHealingReminder 在无进展临界轮次注入自愈提醒,保持提示词拼接规则集中。 +func withSelfHealingReminder(systemPrompt string, streak int, limit int) string { + if streak != limit-1 { + return systemPrompt + } + trimmed := strings.TrimSpace(systemPrompt) + if trimmed == "" { + return selfHealingReminder + } + return trimmed + "\n\n" + selfHealingReminder +} + +// autoCompactCacheKeyFromConfig 提取会影响自动压缩阈值解析的配置维度,用于 run 内缓存命中判断。 +func autoCompactCacheKeyFromConfig(cfg config.Config) autoCompactThresholdCacheKey { + return autoCompactThresholdCacheKey{ + provider: strings.TrimSpace(cfg.SelectedProvider), + model: strings.TrimSpace(cfg.CurrentModel), + autoCompactEnabled: cfg.Context.AutoCompact.Enabled, + autoCompactInputThreshold: cfg.Context.AutoCompact.InputTokenThreshold, + autoCompactReserveTokens: cfg.Context.AutoCompact.ReserveTokens, + autoCompactFallback: cfg.Context.AutoCompact.FallbackInputTokenThreshold, + } +} diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index a67f70de..ee7c340d 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4369,6 +4369,84 @@ func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing. } } +func TestAutoCompactThresholdForStateCachesResolverResultWithinRun(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-hit", newRuntimeSession("session-cache-hit")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 88000 { + t.Fatalf("expected cached resolver threshold == 88000, got %d and %d", threshold1, threshold2) + } + if resolveCalls != 1 { + t.Fatalf("expected resolver to be called once, got %d", resolveCalls) + } +} + +func TestAutoCompactThresholdForStateRecomputesWhenCacheKeyChanges(t *testing.T) { + t.Parallel() + + service := &Service{} + resolveCalls := 0 + service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( + func(ctx context.Context, cfg config.Config) (int, error) { + resolveCalls++ + if strings.TrimSpace(cfg.CurrentModel) == "gpt-5.1" { + return 99000, nil + } + return 88000, nil + }, + )) + + cfg := config.Config{ + SelectedProvider: "openai", + CurrentModel: "gpt-5", + Context: config.ContextConfig{ + AutoCompact: config.AutoCompactConfig{ + Enabled: true, + InputTokenThreshold: 0, + ReserveTokens: 10000, + FallbackInputTokenThreshold: 76000, + }, + }, + } + state := newRunState("run-cache-miss", newRuntimeSession("session-cache-miss")) + + threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + cfg.CurrentModel = "gpt-5.1" + threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + + if threshold1 != 88000 || threshold2 != 99000 { + t.Fatalf("expected thresholds [88000, 99000], got [%d, %d]", threshold1, threshold2) + } + if resolveCalls != 2 { + t.Fatalf("expected resolver to be called twice after key change, got %d", resolveCalls) + } +} + func TestTokenUsageRecordedOnMessageDone(t *testing.T) { t.Parallel() diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 9a202343..2c0974fb 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -21,6 +21,7 @@ type runState struct { session agentsession.Session compactApplied bool reactiveCompactAttempts int + autoCompactCache autoCompactThresholdCache rememberedThisRun bool turn int phase controlplane.Phase @@ -101,3 +102,20 @@ type providerTurnResult struct { inputTokens int outputTokens int } + +// autoCompactThresholdCache 保存当前 run 已解析过的自动压缩阈值,避免热路径重复解析。 +type autoCompactThresholdCache struct { + key autoCompactThresholdCacheKey + threshold int + valid bool +} + +// autoCompactThresholdCacheKey 描述自动压缩阈值解析输入的关键维度。 +type autoCompactThresholdCacheKey struct { + provider string + model string + autoCompactEnabled bool + autoCompactInputThreshold int + autoCompactReserveTokens int + autoCompactFallback int +}