From 6b46bdfa522e0dec3c7c776be3d3b2adb63552b6 Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Fri, 17 Apr 2026 15:27:08 +0800 Subject: [PATCH 01/10] =?UTF-8?q?feat(TUI):=E5=A2=9E=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E8=BE=93=E5=85=A5=E5=85=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/bootstrap.go | 1 + internal/config/loader_test.go | 14 +- internal/config/provider_loader.go | 6 +- .../chatcompletions/provider_test.go | 33 +++ .../openaicompat/chatcompletions/request.go | 42 ++- internal/runtime/events.go | 28 ++ internal/runtime/input_prepare.go | 141 +++++++++ internal/runtime/input_prepare_test.go | 152 ++++++++++ internal/runtime/runtime.go | 34 +++ internal/session/input_preparer.go | 272 ++++++++++++++++++ internal/session/input_preparer_test.go | 242 ++++++++++++++++ internal/tools/todo/write.go | 3 + internal/tools/todo/write_test.go | 14 + internal/tui/bootstrap/builder_test.go | 40 +++ internal/tui/core/app/app.go | 46 ++- internal/tui/core/app/input_features.go | 187 ++++++------ internal/tui/core/app/input_features_test.go | 191 ++++++------ internal/tui/core/app/update.go | 155 +++++++--- .../tui/core/app/update_permission_test.go | 14 + internal/tui/core/app/update_test.go | 167 +++++++---- internal/tui/services/runtime_service.go | 15 + internal/tui/services/services_test.go | 28 ++ 22 files changed, 1497 insertions(+), 328 deletions(-) create mode 100644 internal/runtime/input_prepare.go create mode 100644 internal/runtime/input_prepare_test.go create mode 100644 internal/session/input_preparer.go create mode 100644 internal/session/input_preparer_test.go diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 4ab88b53..24a35b7e 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -169,6 +169,7 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er contextBuilder, ) runtimeSvc.SetSessionAssetStore(sessionStore) + runtimeSvc.SetUserInputPreparer(agentruntime.NewSessionInputPreparer(sessionStore, sessionStore)) runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir())) runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( func(ctx context.Context, cfg config.Config) (int, error) { diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index a93c3d0f..211706ec 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1052,10 +1052,10 @@ func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) { _, err := loadCustomProviders(baseDir) if err == nil { - t.Fatal("expected read providers dir error") + t.Fatal("expected create providers dir error") } - if !strings.Contains(err.Error(), "read providers dir") { - t.Fatalf("expected read providers dir error, got %v", err) + if !strings.Contains(err.Error(), "create providers dir") { + t.Fatalf("expected create providers dir error, got %v", err) } }) @@ -1096,8 +1096,12 @@ func TestLoadCustomProvidersReturnsEmptyWhenProvidersDirMissing(t *testing.T) { if len(providers) != 0 { t.Fatalf("expected no custom providers, got %d", len(providers)) } - if _, err := os.Stat(providersPath); !os.IsNotExist(err) { - t.Fatalf("expected providers dir to remain missing, got %v", err) + info, err := os.Stat(providersPath) + if err != nil { + t.Fatalf("expected providers dir to be created, got %v", err) + } + if !info.IsDir() { + t.Fatalf("expected providers path to be directory") } } diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index dd51cb7d..d73ea428 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -62,11 +62,11 @@ type customProviderSettings struct { // loadCustomProviders 扫描 baseDir/providers 下的一层子目录,并将其中的 provider.yaml 解析为运行时配置。 func loadCustomProviders(baseDir string) ([]ProviderConfig, error) { providersDir := filepath.Join(strings.TrimSpace(baseDir), providersDirName) + if err := os.MkdirAll(providersDir, 0o755); err != nil { + return nil, fmt.Errorf("config: create providers dir: %w", err) + } entries, err := os.ReadDir(providersDir) if err != nil { - if os.IsNotExist(err) { - return nil, nil - } return nil, fmt.Errorf("config: read providers dir: %w", err) } diff --git a/internal/provider/openaicompat/chatcompletions/provider_test.go b/internal/provider/openaicompat/chatcompletions/provider_test.go index 6a379d89..13335aaf 100644 --- a/internal/provider/openaicompat/chatcompletions/provider_test.go +++ b/internal/provider/openaicompat/chatcompletions/provider_test.go @@ -71,6 +71,39 @@ func TestNewAndBuildRequest(t *testing.T) { t.Fatalf("unexpected tools: %+v", payload.Tools) } + toolSchemaWithTopLevelCombinator := map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{"type": "string"}, + }, + "oneOf": []any{ + map[string]any{"required": []string{"action"}}, + }, + } + sanitizedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + Tools: []providertypes.ToolSpec{{ + Name: "todo_write", + Description: "write todos", + Schema: toolSchemaWithTopLevelCombinator, + }}, + }) + if err != nil { + t.Fatalf("BuildRequest() sanitize schema error = %v", err) + } + gotSchema := sanitizedPayload.Tools[0].Function.Parameters + if gotSchema["type"] != "object" { + t.Fatalf("expected sanitized schema type object, got %+v", gotSchema["type"]) + } + if _, ok := gotSchema["oneOf"]; ok { + t.Fatalf("expected sanitized schema to drop top-level oneOf, got %+v", gotSchema) + } + if _, ok := toolSchemaWithTopLevelCombinator["oneOf"]; !ok { + t.Fatalf("expected original schema not to be mutated") + } + withSessionAsset, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ Messages: []providertypes.Message{ { diff --git a/internal/provider/openaicompat/chatcompletions/request.go b/internal/provider/openaicompat/chatcompletions/request.go index d37a658b..38a928a0 100644 --- a/internal/provider/openaicompat/chatcompletions/request.go +++ b/internal/provider/openaicompat/chatcompletions/request.go @@ -67,7 +67,7 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert Function: FunctionDefinition{ Name: spec.Name, Description: spec.Description, - Parameters: spec.Schema, + Parameters: normalizeToolSchemaForOpenAI(spec.Schema), }, } payload.Tools = append(payload.Tools, def) @@ -77,6 +77,46 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert return payload, nil } +// normalizeToolSchemaForOpenAI 归一化工具参数 schema,避免 OpenAI chat-completions 顶层关键字约束报错。 +// 仅收敛顶层结构:保证 type=object,并移除顶层 oneOf/anyOf/allOf/enum/not;嵌套语义保持原样。 +func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any { + normalized := cloneSchemaTopLevel(schema) + if len(normalized) == 0 { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + typeName, _ := normalized["type"].(string) + if strings.TrimSpace(strings.ToLower(typeName)) != "object" { + normalized["type"] = "object" + } + + if _, ok := normalized["properties"].(map[string]any); !ok { + normalized["properties"] = map[string]any{} + } + + delete(normalized, "oneOf") + delete(normalized, "anyOf") + delete(normalized, "allOf") + delete(normalized, "enum") + delete(normalized, "not") + return normalized +} + +// cloneSchemaTopLevel 复制 schema 顶层 map,避免归一化阶段修改调用方原始结构。 +func cloneSchemaTopLevel(schema map[string]any) map[string]any { + if len(schema) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(schema)) + for key, value := range schema { + cloned[key] = value + } + return cloned +} + // ToOpenAIMessage 将通用 Message 转换为 OpenAI 协议消息格式。 func ToOpenAIMessage(ctx context.Context, message providertypes.Message, assetReader providertypes.SessionAssetReader) (Message, error) { msg, _, err := toOpenAIMessageWithBudget(ctx, message, assetReader, maxSessionAssetsTotalBytes) diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 8f72f814..eb41ad79 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -92,6 +92,28 @@ type TodoEventPayload struct { Reason string `json:"reason,omitempty"` } +// InputNormalizedPayload 描述输入归一化完成后的摘要信息。 +type InputNormalizedPayload struct { + TextLength int `json:"text_length"` + ImageCount int `json:"image_count"` +} + +// AssetSavedPayload 描述单个附件成功保存后的结果。 +type AssetSavedPayload struct { + Index int `json:"index"` + Path string `json:"path,omitempty"` + AssetID string `json:"asset_id"` + MimeType string `json:"mime_type,omitempty"` + Size int64 `json:"size,omitempty"` +} + +// AssetSaveFailedPayload 描述单个附件保存失败的结构化信息。 +type AssetSaveFailedPayload struct { + Index int `json:"index"` + Path string `json:"path,omitempty"` + Message string `json:"message"` +} + const ( // EventUserMessage 表示用户消息已写入会话。 EventUserMessage EventType = "user_message" @@ -143,6 +165,12 @@ const ( EventTodoConflict EventType = "todo_conflict" // EventTodoSummaryInjected 表示本轮上下文注入了 Todo 摘要。 EventTodoSummaryInjected EventType = "todo_summary_injected" + // EventInputNormalized 表示用户输入已完成归一化。 + EventInputNormalized EventType = "input_normalized" + // EventAssetSaved 表示本轮用户输入附件已完成持久化。 + EventAssetSaved EventType = "asset_saved" + // EventAssetSaveFailed 表示本轮用户输入附件持久化失败。 + EventAssetSaveFailed EventType = "asset_save_failed" ) // TokenUsagePayload 承载单轮 token 用量统计。 diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go new file mode 100644 index 00000000..30e467f6 --- /dev/null +++ b/internal/runtime/input_prepare.go @@ -0,0 +1,141 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "strings" + + agentsession "neo-code/internal/session" +) + +// NewSessionInputPreparer 创建基于 session 子层实现的输入归一化适配器。 +func NewSessionInputPreparer(store agentsession.Store, assetStore agentsession.AssetStore) UserInputPreparer { + return sessionInputPreparer{ + preparer: agentsession.NewInputPreparer(store, assetStore), + } +} + +// PrepareUserInput 负责在运行前执行输入归一化编排,并发出最小可观测事件。 +// Submit 作为运行时提交入口,统一串联输入归一化与执行,避免上层手动编排两段调用。 +func (s *Service) Submit(ctx context.Context, input PrepareInput) error { + prepared, err := s.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return s.Run(ctx, prepared) +} + +func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) { + if err := ctx.Err(); err != nil { + return UserInput{}, err + } + if s == nil { + return UserInput{}, errors.New("runtime: service is nil") + } + if s.userInputPreparer == nil { + err := errors.New("runtime: user input preparer is not configured") + _ = s.emitPrepareFailure(ctx, input, err) + return UserInput{}, err + } + + defaultWorkdir := "" + if s.configManager != nil { + defaultWorkdir = strings.TrimSpace(s.configManager.Get().Workdir) + } + + prepared, err := s.userInputPreparer.Prepare(ctx, input, defaultWorkdir) + if err != nil { + _ = s.emitPrepareFailure(ctx, input, err) + return UserInput{}, err + } + + runID := strings.TrimSpace(input.RunID) + _ = s.emit(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{ + TextLength: len([]rune(strings.TrimSpace(input.Text))), + ImageCount: len(input.Images), + }) + for index, asset := range prepared.SavedAssets { + path := "" + if index >= 0 && index < len(input.Images) { + path = strings.TrimSpace(input.Images[index].Path) + } + _ = s.emit(ctx, EventAssetSaved, runID, prepared.UserInput.SessionID, AssetSavedPayload{ + Index: index, + Path: path, + AssetID: asset.ID, + MimeType: asset.MimeType, + Size: asset.Size, + }) + } + + return prepared.UserInput, nil +} + +// emitPrepareFailure 统一发送输入归一化阶段的失败事件,避免前置副作用变成黑箱。 +func (s *Service) emitPrepareFailure(ctx context.Context, input PrepareInput, err error) error { + if s == nil { + return nil + } + + runID := strings.TrimSpace(input.RunID) + sessionID := strings.TrimSpace(input.SessionID) + + var saveErr *agentsession.AssetSaveError + if errors.As(err, &saveErr) { + return s.emit(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{ + Index: saveErr.Index, + Path: strings.TrimSpace(saveErr.Path), + Message: strings.TrimSpace(saveErr.Error()), + }) + } + return s.emit(ctx, EventError, runID, sessionID, strings.TrimSpace(err.Error())) +} + +type sessionInputPreparer struct { + preparer *agentsession.InputPreparer +} + +// Prepare 将 runtime 输入 DTO 映射到 session 子层并返回标准 UserInput 结果。 +func (p sessionInputPreparer) Prepare( + ctx context.Context, + input PrepareInput, + defaultWorkdir string, +) (PreparedInputResult, error) { + if p.preparer == nil { + return PreparedInputResult{}, errors.New("runtime: session input preparer is nil") + } + + sessionImages := make([]agentsession.PrepareImageInput, 0, len(input.Images)) + for _, image := range input.Images { + sessionImages = append(sessionImages, agentsession.PrepareImageInput{ + Path: strings.TrimSpace(image.Path), + MimeType: strings.TrimSpace(image.MimeType), + }) + } + + prepared, err := p.preparer.Prepare(ctx, agentsession.PrepareInput{ + SessionID: strings.TrimSpace(input.SessionID), + Text: input.Text, + Images: sessionImages, + DefaultWorkdir: strings.TrimSpace(defaultWorkdir), + RequestedWorkdir: strings.TrimSpace(input.Workdir), + }) + if err != nil { + return PreparedInputResult{}, err + } + + if len(prepared.Parts) == 0 { + return PreparedInputResult{}, fmt.Errorf("runtime: prepared parts is empty") + } + + return PreparedInputResult{ + UserInput: UserInput{ + SessionID: strings.TrimSpace(prepared.SessionID), + RunID: strings.TrimSpace(input.RunID), + Parts: prepared.Parts, + Workdir: strings.TrimSpace(prepared.Workdir), + }, + SavedAssets: append([]agentsession.AssetMeta(nil), prepared.SavedAssets...), + }, nil +} diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go new file mode 100644 index 00000000..d7a2e2c3 --- /dev/null +++ b/internal/runtime/input_prepare_test.go @@ -0,0 +1,152 @@ +package runtime + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +func TestServicePrepareUserInputEmitsNormalizeAndAssetSaved(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + + imagePath := filepath.Join(workdir, "img.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-1", + Text: "hello", + Images: []UserImageInput{{Path: imagePath, MimeType: "image/png"}}, + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if input.SessionID == "" || input.RunID != "run-prepare-1" { + t.Fatalf("unexpected prepared user input: %+v", input) + } + if len(input.Parts) != 2 || input.Parts[0].Kind != providertypes.ContentPartText || input.Parts[1].Kind != providertypes.ContentPartImage { + t.Fatalf("unexpected prepared parts: %+v", input.Parts) + } + + normalizedEvent := mustReadRuntimeEvent(t, svc.Events()) + if normalizedEvent.Type != EventInputNormalized { + t.Fatalf("expected first event %q, got %q", EventInputNormalized, normalizedEvent.Type) + } + normalizedPayload, ok := normalizedEvent.Payload.(InputNormalizedPayload) + if !ok || normalizedPayload.ImageCount != 1 { + t.Fatalf("unexpected normalized payload: %#v", normalizedEvent.Payload) + } + + assetSavedEvent := mustReadRuntimeEvent(t, svc.Events()) + if assetSavedEvent.Type != EventAssetSaved { + t.Fatalf("expected second event %q, got %q", EventAssetSaved, assetSavedEvent.Type) + } + assetSavedPayload, ok := assetSavedEvent.Payload.(AssetSavedPayload) + if !ok || assetSavedPayload.AssetID == "" || assetSavedPayload.MimeType != "image/png" { + t.Fatalf("unexpected asset_saved payload: %#v", assetSavedEvent.Payload) + } +} + +func TestServicePrepareUserInputEmitsAssetSaveFailed(t *testing.T) { + t.Parallel() + + svc, _ := newPrepareTestService(t, t.TempDir(), true) + _, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-2", + Text: "hello", + Images: []UserImageInput{{Path: filepath.Join(t.TempDir(), "missing.png"), MimeType: "image/png"}}, + }) + if err == nil { + t.Fatalf("expected PrepareUserInput() to fail") + } + + failedEvent := mustReadRuntimeEvent(t, svc.Events()) + if failedEvent.Type != EventAssetSaveFailed { + t.Fatalf("expected event %q, got %q", EventAssetSaveFailed, failedEvent.Type) + } + payload, ok := failedEvent.Payload.(AssetSaveFailedPayload) + if !ok || payload.Index != 0 { + t.Fatalf("unexpected asset_save_failed payload: %#v", failedEvent.Payload) + } +} + +func TestServicePrepareUserInputWithoutPreparerEmitsErrorEvent(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, false) + + _, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-3", + Text: "hello", + }) + if err == nil { + t.Fatalf("expected PrepareUserInput() to fail without preparer") + } + + errorEvent := mustReadRuntimeEvent(t, svc.Events()) + if errorEvent.Type != EventError { + t.Fatalf("expected event %q, got %q", EventError, errorEvent.Type) + } +} + +func TestServiceSubmitWithoutPreparerReturnsError(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, false) + + err := svc.Submit(context.Background(), PrepareInput{ + RunID: "run-submit-1", + Text: "hello", + }) + if err == nil { + t.Fatalf("expected Submit() to fail without preparer") + } + + errorEvent := mustReadRuntimeEvent(t, svc.Events()) + if errorEvent.Type != EventError { + t.Fatalf("expected event %q, got %q", EventError, errorEvent.Type) + } +} + +func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Service, *agentsession.JSONStore) { + t.Helper() + + cfg := config.StaticDefaults() + cfg.Workdir = workdir + loader := config.NewLoader(t.TempDir(), cfg) + manager := config.NewManager(loader) + if _, err := manager.Load(context.Background()); err != nil { + t.Fatalf("load config: %v", err) + } + + store := agentsession.NewStore(t.TempDir(), workdir) + svc := NewWithFactory(manager, nil, store, nil, nil) + svc.SetSessionAssetStore(store) + if withPreparer { + svc.SetUserInputPreparer(NewSessionInputPreparer(store, store)) + } + return svc, store +} + +func mustReadRuntimeEvent(t *testing.T, events <-chan RuntimeEvent) RuntimeEvent { + t.Helper() + select { + case event := <-events: + return event + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for runtime event") + return RuntimeEvent{} + } +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 910507c4..ebe74ef6 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -28,6 +28,8 @@ const ( // Runtime 定义 runtime 对外暴露的运行、压缩与审批接口。 type Runtime interface { + Submit(ctx context.Context, input PrepareInput) error + PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) Run(ctx context.Context, input UserInput) error Compact(ctx context.Context, input CompactInput) (CompactResult, error) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error @@ -51,6 +53,32 @@ type UserInput struct { CapabilityToken *security.CapabilityToken } +// UserImageInput 表示用户输入中附带的单个图片引用(路径 + MIME)。 +type UserImageInput struct { + Path string + MimeType string +} + +// PrepareInput 表示进入 runtime 归一化前的领域输入(仅包含文本/图片/会话上下文)。 +type PrepareInput struct { + SessionID string + RunID string + Workdir string + Text string + Images []UserImageInput +} + +// PreparedInputResult 描述输入归一化完成后的结果快照(标准 UserInput + 本轮保存附件元数据)。 +type PreparedInputResult struct { + UserInput UserInput + SavedAssets []agentsession.AssetMeta +} + +// UserInputPreparer 定义 runtime 输入归一化能力:会话绑定、附件持久化与 parts 组装。 +type UserInputPreparer interface { + Prepare(ctx context.Context, input PrepareInput, defaultWorkdir string) (PreparedInputResult, error) +} + // ProviderFactory 负责基于运行期配置创建 provider 实例。 type ProviderFactory interface { Build(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) @@ -72,6 +100,7 @@ type Service struct { configManager *config.Manager sessionStore agentsession.Store sessionAssetStore agentsession.AssetStore + userInputPreparer UserInputPreparer toolManager tools.Manager providerFactory ProviderFactory contextBuilder agentcontext.Builder @@ -146,6 +175,11 @@ func (s *Service) SetSessionAssetStore(store agentsession.AssetStore) { s.sessionAssetStore = store } +// SetUserInputPreparer 设置输入归一化能力实现;runtime 仅做编排调用,不承载具体存储细节。 +func (s *Service) SetUserInputPreparer(preparer UserInputPreparer) { + s.userInputPreparer = preparer +} + // SetSkillsRegistry 设置运行时可选的 skills registry,用于激活校验与上下文注入。 func (s *Service) SetSkillsRegistry(registry skills.Registry) { s.skillsRegistry = registry diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go new file mode 100644 index 00000000..28fe7910 --- /dev/null +++ b/internal/session/input_preparer.go @@ -0,0 +1,272 @@ +package session + +import ( + "context" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + providertypes "neo-code/internal/provider/types" +) + +const imageOnlySessionTitle = "Image Message" + +// PrepareImageInput 表示一次用户输入中附带的本地图片引用。 +type PrepareImageInput struct { + Path string + MimeType string +} + +// PrepareInput 定义会话输入归一化的领域输入参数。 +type PrepareInput struct { + SessionID string + Text string + Images []PrepareImageInput + DefaultWorkdir string + RequestedWorkdir string +} + +// PreparedInput 表示归一化完成后可直接进入 runtime 的标准输入结果。 +type PreparedInput struct { + SessionID string + Workdir string + Parts []providertypes.ContentPart + SavedAssets []AssetMeta +} + +// AssetSaveError 描述图片落盘阶段的结构化失败信息,便于上层统一事件化处理。 +type AssetSaveError struct { + Index int + Path string + Err error +} + +func (e *AssetSaveError) Error() string { + if e == nil { + return "session: asset save failed" + } + if strings.TrimSpace(e.Path) == "" { + return fmt.Sprintf("session: save asset at index %d: %v", e.Index, e.Err) + } + return fmt.Sprintf("session: save asset %q at index %d: %v", e.Path, e.Index, e.Err) +} + +func (e *AssetSaveError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +// InputPreparer 负责把用户文本/图片输入归一化为会话级标准 parts。 +type InputPreparer struct { + store Store + assetStore AssetStore +} + +// NewInputPreparer 创建会话输入归一化组件。 +func NewInputPreparer(store Store, assetStore AssetStore) *InputPreparer { + return &InputPreparer{ + store: store, + assetStore: assetStore, + } +} + +// Prepare 负责会话解析/创建、附件落盘与 parts 组装。 +func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (PreparedInput, error) { + if err := ctx.Err(); err != nil { + return PreparedInput{}, err + } + if p == nil || p.store == nil { + return PreparedInput{}, fmt.Errorf("session: input preparer store is not configured") + } + if len(input.Images) > 0 && p.assetStore == nil { + return PreparedInput{}, fmt.Errorf("session: asset store is not configured") + } + + trimmedText := strings.TrimSpace(input.Text) + if trimmedText == "" && len(input.Images) == 0 { + return PreparedInput{}, fmt.Errorf("session: input content is empty") + } + + sessionTitle := buildSessionTitle(trimmedText, len(input.Images) > 0) + session, err := p.loadOrCreateSession( + ctx, + input.SessionID, + sessionTitle, + input.DefaultWorkdir, + input.RequestedWorkdir, + ) + if err != nil { + return PreparedInput{}, err + } + + parts := make([]providertypes.ContentPart, 0, 1+len(input.Images)) + if trimmedText != "" { + parts = append(parts, providertypes.NewTextPart(trimmedText)) + } + + savedAssets := make([]AssetMeta, 0, len(input.Images)) + for index, image := range input.Images { + path := strings.TrimSpace(image.Path) + if path == "" { + return PreparedInput{}, &AssetSaveError{ + Index: index, + Path: path, + Err: fmt.Errorf("image path is empty"), + } + } + mimeType := strings.TrimSpace(image.MimeType) + + meta, err := p.saveImageAsset(ctx, session.ID, path, mimeType) + if err != nil { + return PreparedInput{}, &AssetSaveError{ + Index: index, + Path: path, + Err: err, + } + } + savedAssets = append(savedAssets, meta) + parts = append(parts, providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType)) + } + + if err := providertypes.ValidateParts(parts); err != nil { + return PreparedInput{}, fmt.Errorf("session: normalize parts: %w", err) + } + + return PreparedInput{ + SessionID: session.ID, + Workdir: session.Workdir, + Parts: parts, + SavedAssets: savedAssets, + }, nil +} + +func (p *InputPreparer) saveImageAsset(ctx context.Context, sessionID string, path string, mimeType string) (AssetMeta, error) { + absolutePath, err := filepath.Abs(path) + if err != nil { + return AssetMeta{}, fmt.Errorf("resolve image path: %w", err) + } + + file, err := os.Open(absolutePath) + if err != nil { + return AssetMeta{}, fmt.Errorf("open image file: %w", err) + } + defer func() { + _ = file.Close() + }() + + resolvedMimeType, err := resolveImageMimeType(path, mimeType, file) + if err != nil { + return AssetMeta{}, err + } + + meta, err := p.assetStore.SaveAsset(ctx, sessionID, file, resolvedMimeType) + if err != nil { + return AssetMeta{}, err + } + return meta, nil +} + +// resolveImageMimeType 解析图片 MIME 类型,优先使用显式传入值,其次回退到扩展名与文件头探测。 +func resolveImageMimeType(path string, declared string, file *os.File) (string, error) { + if normalized := strings.ToLower(strings.TrimSpace(declared)); normalized != "" { + return normalized, nil + } + + extMime := strings.ToLower(strings.TrimSpace(mime.TypeByExtension(strings.ToLower(filepath.Ext(path))))) + if extMime != "" { + if idx := strings.Index(extMime, ";"); idx >= 0 { + extMime = strings.TrimSpace(extMime[:idx]) + } + if strings.HasPrefix(extMime, "image/") { + return extMime, nil + } + } + + buffer := make([]byte, 512) + n, readErr := file.Read(buffer) + if readErr != nil && readErr != io.EOF { + return "", fmt.Errorf("detect image mime type: %w", readErr) + } + if _, err := file.Seek(0, io.SeekStart); err != nil { + return "", fmt.Errorf("reset image reader: %w", err) + } + + detected := strings.ToLower(strings.TrimSpace(http.DetectContentType(buffer[:n]))) + if strings.HasPrefix(detected, "image/") { + return detected, nil + } + return "", fmt.Errorf("unsupported image format") +} + +func (p *InputPreparer) loadOrCreateSession( + ctx context.Context, + sessionID string, + title string, + defaultWorkdir string, + requestedWorkdir string, +) (Session, error) { + if strings.TrimSpace(sessionID) == "" { + sessionWorkdir, err := resolveWorkdirForInput(defaultWorkdir, "", requestedWorkdir) + if err != nil { + return Session{}, err + } + session := NewWithWorkdir(title, sessionWorkdir) + if err := p.store.Save(ctx, &session); err != nil { + return Session{}, err + } + return session, nil + } + + session, err := p.store.Load(ctx, sessionID) + if err != nil { + return Session{}, err + } + if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { + return session, nil + } + + resolved, err := resolveWorkdirForInput(defaultWorkdir, session.Workdir, requestedWorkdir) + if err != nil { + return Session{}, err + } + if session.Workdir == resolved { + return session, nil + } + + session.Workdir = resolved + session.UpdatedAt = time.Now() + if err := p.store.Save(ctx, &session); err != nil { + return Session{}, err + } + return session, nil +} + +func resolveWorkdirForInput(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { + base := EffectiveWorkdir(currentWorkdir, defaultWorkdir) + if strings.TrimSpace(requestedWorkdir) == "" { + return ResolveExistingDir(base) + } + + target := strings.TrimSpace(requestedWorkdir) + if !filepath.IsAbs(target) { + target = filepath.Join(base, target) + } + return ResolveExistingDir(target) +} + +func buildSessionTitle(text string, hasImages bool) string { + if strings.TrimSpace(text) != "" { + return strings.TrimSpace(text) + } + if hasImages { + return imageOnlySessionTitle + } + return "New Session" +} diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go new file mode 100644 index 00000000..8ef0d29d --- /dev/null +++ b/internal/session/input_preparer_test.go @@ -0,0 +1,242 @@ +package session + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInputPreparerPrepareTextOnly(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "hello world", + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if result.SessionID == "" { + t.Fatalf("expected non-empty session id") + } + if len(result.Parts) != 1 || result.Parts[0].Kind != providertypes.ContentPartText || result.Parts[0].Text != "hello world" { + t.Fatalf("unexpected prepared parts: %+v", result.Parts) + } + if len(result.SavedAssets) != 0 { + t.Fatalf("expected no saved assets, got %+v", result.SavedAssets) + } +} + +func TestInputPreparerPrepareTextAndImage(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "img.png") + payload := []byte("fake-png") + if err := os.WriteFile(imagePath, payload, 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "with image", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 { + t.Fatalf("expected one saved asset, got %+v", result.SavedAssets) + } + if len(result.Parts) != 2 { + t.Fatalf("expected 2 parts, got %+v", result.Parts) + } + imagePart := result.Parts[1] + if imagePart.Kind != providertypes.ContentPartImage || imagePart.Image == nil || imagePart.Image.Asset == nil { + t.Fatalf("expected session asset image part, got %+v", imagePart) + } + if imagePart.Image.Asset.ID != result.SavedAssets[0].ID { + t.Fatalf("expected image part asset id %q, got %+v", result.SavedAssets[0].ID, imagePart.Image.Asset) + } + + rc, meta, err := store.Open(context.Background(), result.SessionID, result.SavedAssets[0].ID) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer func() { _ = rc.Close() }() + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if meta.MimeType != "image/png" || string(got) != string(payload) { + t.Fatalf("unexpected stored asset mime=%q payload=%q", meta.MimeType, string(got)) + } +} + +func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "auto.png") + if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "infer mime", + Images: []PrepareImageInput{{Path: imagePath}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 { + t.Fatalf("expected one saved asset, got %+v", result.SavedAssets) + } + if result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("expected inferred mime image/png, got %q", result.SavedAssets[0].MimeType) + } +} + +func TestInputPreparerPrepareImageOnlyUsesImageTitle(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "only.png") + if err := os.WriteFile(imagePath, []byte("img"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.Parts) != 1 || result.Parts[0].Kind != providertypes.ContentPartImage { + t.Fatalf("expected one image part, got %+v", result.Parts) + } + + session, err := store.Load(context.Background(), result.SessionID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if session.Title != imageOnlySessionTitle { + t.Fatalf("expected image-only title %q, got %q", imageOnlySessionTitle, session.Title) + } +} + +func TestInputPreparerPrepareErrors(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + + t.Run("missing store", func(t *testing.T) { + preparer := NewInputPreparer(nil, nil) + if _, err := preparer.Prepare(context.Background(), PrepareInput{Text: "x", DefaultWorkdir: workdir}); err == nil { + t.Fatalf("expected missing store error") + } + }) + + t.Run("missing asset store", func(t *testing.T) { + preparer := NewInputPreparer(store, nil) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "x", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing asset store error") + } + }) + + t.Run("empty content", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + if _, err := preparer.Prepare(context.Background(), PrepareInput{DefaultWorkdir: workdir}); err == nil { + t.Fatalf("expected empty content error") + } + }) + + t.Run("asset save error is structured", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + var saveErr *AssetSaveError + if !errors.As(err, &saveErr) { + t.Fatalf("expected AssetSaveError, got %T %v", err, err) + } + if saveErr.Index != 0 { + t.Fatalf("expected save error index 0, got %d", saveErr.Index) + } + }) +} + +func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { + t.Parallel() + + base := t.TempDir() + defaultWorkdir := filepath.Join(base, "workspace") + if err := os.MkdirAll(defaultWorkdir, 0o755); err != nil { + t.Fatalf("mkdir default workdir: %v", err) + } + currentWorkdir := filepath.Join(defaultWorkdir, "current") + if err := os.MkdirAll(currentWorkdir, 0o755); err != nil { + t.Fatalf("mkdir current workdir: %v", err) + } + targetWorkdir := filepath.Join(currentWorkdir, "nested") + if err := os.MkdirAll(targetWorkdir, 0o755); err != nil { + t.Fatalf("mkdir nested workdir: %v", err) + } + + store := NewStore(t.TempDir(), defaultWorkdir) + session := NewWithWorkdir("existing", currentWorkdir) + if err := store.Save(context.Background(), &session); err != nil { + t.Fatalf("Save() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "update workdir", + DefaultWorkdir: defaultWorkdir, + RequestedWorkdir: "nested", + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if result.Workdir != targetWorkdir { + t.Fatalf("expected target workdir %q, got %q", targetWorkdir, result.Workdir) + } + + loaded, err := store.Load(context.Background(), session.ID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Workdir != targetWorkdir { + t.Fatalf("expected persisted workdir %q, got %q", targetWorkdir, loaded.Workdir) + } +} diff --git a/internal/tools/todo/write.go b/internal/tools/todo/write.go index c073d72c..8a94812e 100644 --- a/internal/tools/todo/write.go +++ b/internal/tools/todo/write.go @@ -138,6 +138,9 @@ func (t *Tool) Schema() map[string]any { }, "artifacts": map[string]any{ "type": "array", + "items": map[string]any{ + "type": "string", + }, }, "reason": map[string]any{ "type": "string", diff --git a/internal/tools/todo/write_test.go b/internal/tools/todo/write_test.go index 85d38a95..634335b7 100644 --- a/internal/tools/todo/write_test.go +++ b/internal/tools/todo/write_test.go @@ -202,6 +202,20 @@ func TestToolMetadataMethods(t *testing.T) { if _, ok := properties["items"]; !ok { t.Fatalf("Schema() should include items property") } + artifacts, ok := properties["artifacts"].(map[string]any) + if !ok { + t.Fatalf("Schema() artifacts should be object, got %T", properties["artifacts"]) + } + if artifacts["type"] != "array" { + t.Fatalf("Schema() artifacts.type = %+v, want array", artifacts["type"]) + } + items, ok := artifacts["items"].(map[string]any) + if !ok { + t.Fatalf("Schema() artifacts.items should be object, got %T", artifacts["items"]) + } + if items["type"] != "string" { + t.Fatalf("Schema() artifacts.items.type = %+v, want string", items["type"]) + } } func TestToolExecuteActionSequence(t *testing.T) { diff --git a/internal/tui/bootstrap/builder_test.go b/internal/tui/bootstrap/builder_test.go index 7a688b78..b6bf3677 100644 --- a/internal/tui/bootstrap/builder_test.go +++ b/internal/tui/bootstrap/builder_test.go @@ -15,6 +15,26 @@ import ( type testRuntime struct{} +func (r *testRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }, nil +} + +func (r *testRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return r.Run(ctx, agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }) +} + func (r *testRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } @@ -206,6 +226,26 @@ func (f errorFactory) BuildProvider(mode Mode, current ProviderService) (Provide type noopRuntime struct{} +func (r noopRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }, nil +} + +func (r noopRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return r.Run(ctx, agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }) +} + func (r noopRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index 394eaa82..71953775 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -89,27 +89,26 @@ type appComponents struct { // appRuntimeState 聚合运行期易变字段,降低 App 顶层字段密度。 type appRuntimeState struct { - codeCopyBlocks map[int]string - pendingCopyID int - deferredEventCmd tea.Cmd - nowFn func() time.Time - lastInputEditAt time.Time - lastPasteLikeAt time.Time - inputBurstStart time.Time - inputBurstCount int - pasteMode bool - activeMessages []providertypes.Message - activities []tuistate.ActivityEntry - fileCandidates []string - modelRefreshID string - focus panel - runProgressValue float64 - runProgressKnown bool - runProgressLabel string - pendingPermission *permissionPromptState - pendingImageAttachments []pendingImageAttachment - currentModelCapabilities modelCapabilityState - providerAddForm *providerAddFormState + codeCopyBlocks map[int]string + pendingCopyID int + deferredEventCmd tea.Cmd + nowFn func() time.Time + lastInputEditAt time.Time + lastPasteLikeAt time.Time + inputBurstStart time.Time + inputBurstCount int + pasteMode bool + activeMessages []providertypes.Message + activities []tuistate.ActivityEntry + fileCandidates []string + modelRefreshID string + focus panel + runProgressValue float64 + runProgressKnown bool + runProgressLabel string + pendingPermission *permissionPromptState + pendingImageAttachments []pendingImageAttachment + providerAddForm *providerAddFormState } type pendingImageAttachment struct { @@ -119,11 +118,6 @@ type pendingImageAttachment struct { Name string } -type modelCapabilityState struct { - supportsImageInput bool - checked bool -} - // providerAddFormState 保存添加新 provider 表单的状态。 type providerAddFormState struct { Step int // 当前聚焦字段在“当前 driver 可见字段列表”中的索引 diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index c52d4a3b..892e5c01 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -24,7 +24,6 @@ const ( maxWorkspaceFiles = 4000 maxFileSuggestions = 6 maxImageAttachments = 3 - imageMaxSizeBytes = 5 * 1024 * 1024 // 5 MiB ) type tokenSelector int @@ -37,7 +36,6 @@ const ( var workspaceCommandExecutor = defaultWorkspaceCommandExecutor var readClipboardImage = tuiinfra.ReadClipboardImage var saveClipboardImageToTempFile = tuiinfra.SaveImageToTempFile -var detectImageMimeType = tuiinfra.DetectImageMimeType func isWorkspaceCommandInput(input string) bool { return strings.HasPrefix(strings.TrimSpace(input), workspaceCommandPrefix) @@ -219,6 +217,99 @@ func (a *App) applyImageReference(input string) error { return a.addImageAttachment(path) } +// absorbInlineImageReferences 会把输入文本中的 @ 令牌吸收到附件队列,并返回移除令牌后的文本。 +// 仅根据令牌语法与扩展名做轻量识别,避免把文件系统硬校验放到 TUI 层。 +func (a *App) absorbInlineImageReferences(input string) (string, int, error) { + tokens := strings.Fields(input) + if len(tokens) == 0 { + return strings.TrimSpace(input), 0, nil + } + + kept := make([]string, 0, len(tokens)) + absorbed := 0 + for _, token := range tokens { + imagePath, ok := a.parseInlineImagePathToken(token) + if !ok { + kept = append(kept, token) + continue + } + if err := a.queueImageAttachmentForPrepare(imagePath); err != nil { + return "", absorbed, err + } + absorbed++ + } + + return strings.TrimSpace(strings.Join(kept, " ")), absorbed, nil +} + +// parseInlineImagePathToken 识别 @ 形式的图片路径令牌,并映射为待发送路径。 +func (a *App) parseInlineImagePathToken(token string) (string, bool) { + trimmed := strings.TrimSpace(token) + if !strings.HasPrefix(trimmed, fileReferencePrefix) || strings.HasPrefix(trimmed, imageReferencePrefix) { + return "", false + } + + path := strings.TrimPrefix(trimmed, fileReferencePrefix) + path = strings.Trim(path, `"'`) + path = strings.TrimSpace(path) + if path == "" || !looksLikeImagePath(path) { + return "", false + } + + resolved := path + if !filepath.IsAbs(resolved) { + base := strings.TrimSpace(a.state.CurrentWorkdir) + if base == "" { + return "", false + } + resolved = filepath.Join(base, resolved) + } + return resolved, true +} + +// queueImageAttachmentForPrepare 将图片路径排队为待发送附件,不在 TUI 层做文件系统和 MIME 硬校验。 +// 真正的可用性校验与错误语义统一在 runtime/session 归一化阶段完成。 +func (a *App) queueImageAttachmentForPrepare(path string) error { + path = strings.TrimSpace(path) + if path == "" { + return fmt.Errorf("image path is empty") + } + if len(a.pendingImageAttachments) >= maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + resolved := path + if !filepath.IsAbs(resolved) { + base := strings.TrimSpace(a.state.CurrentWorkdir) + if base != "" { + resolved = filepath.Join(base, resolved) + } + } + absPath, err := filepath.Abs(resolved) + if err != nil { + return fmt.Errorf("invalid image path: %w", err) + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: absPath, + MimeType: "", + Size: 0, + Name: filepath.Base(absPath), + }) + a.refreshImageAttachmentDisplay() + return nil +} + +// looksLikeImagePath 使用扩展名快速判断路径是否是常见图片文件。 +func looksLikeImagePath(path string) bool { + switch strings.ToLower(filepath.Ext(strings.TrimSpace(path))) { + case ".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp": + return true + default: + return false + } +} + func (a *App) applyFileReference(path string) error { path = strings.TrimSpace(path) if path == "" { @@ -275,43 +366,12 @@ func extractImageReference(input string) string { } func (a *App) addImageAttachment(path string) error { - path = strings.TrimSpace(path) - if path == "" { - return fmt.Errorf("image path is empty") - } - - if len(a.pendingImageAttachments) >= maxImageAttachments { - return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) - } - - absPath, err := filepath.Abs(path) - if err != nil { - return fmt.Errorf("invalid image path: %w", err) - } - - info, err := tuiinfra.GetFileInfo(absPath) - if err != nil { - return fmt.Errorf("cannot read image file: %w", err) - } - - if info.Size() > imageMaxSizeBytes { - return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + if err := a.queueImageAttachmentForPrepare(path); err != nil { + return err } - - mimeType := detectImageMimeType(absPath) - if mimeType == "" { - return fmt.Errorf("unsupported image format") + if count := len(a.pendingImageAttachments); count > 0 { + a.state.StatusText = fmt.Sprintf("[System] Added image: %s", a.pendingImageAttachments[count-1].Name) } - - a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ - Path: absPath, - MimeType: mimeType, - Size: info.Size(), - Name: filepath.Base(absPath), - }) - - a.refreshImageAttachmentDisplay() - a.state.StatusText = fmt.Sprintf("[System] Added image: %s", filepath.Base(absPath)) return nil } @@ -370,62 +430,13 @@ func (a *App) addImageFromClipboard() error { return fmt.Errorf("no image in clipboard") } - if int64(len(data)) > imageMaxSizeBytes { - return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) - } - tmpPath, err := saveClipboardImageToTempFile(data, "paste") if err != nil { return fmt.Errorf("failed to save clipboard image: %w", err) } - - mimeType := detectImageMimeType(tmpPath) - if mimeType == "" { - return fmt.Errorf("unsupported image format from clipboard") + if err := a.queueImageAttachmentForPrepare(tmpPath); err != nil { + return err } - - a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ - Path: tmpPath, - MimeType: mimeType, - Size: int64(len(data)), - Name: "clipboard_image.png", - }) - - a.refreshImageAttachmentDisplay() a.state.StatusText = "[System] Added image from clipboard" return nil } - -func (a *App) checkModelImageSupport() bool { - if a.currentModelCapabilities.checked { - return a.currentModelCapabilities.supportsImageInput - } - - models, err := a.providerSvc.ListModelsSnapshot(context.Background()) - if err != nil { - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = false - return false - } - - for _, m := range models { - if m.ID == a.state.CurrentModel { - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = m.CapabilityHints.ImageInput == "supported" - return a.currentModelCapabilities.supportsImageInput - } - } - - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = false - return false -} - -func (a *App) canSendImageInput() bool { - return a.checkModelImageSupport() -} - -// invalidateModelCapabilityCache 在 provider 或 model 变化时清理图片能力缓存,避免复用旧结果。 -func (a *App) invalidateModelCapabilityCache() { - a.currentModelCapabilities = modelCapabilityState{} -} diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index b38e0edc..7e7e4aaf 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -12,19 +12,8 @@ import ( tea "github.com/charmbracelet/bubbletea" "neo-code/internal/config" - configstate "neo-code/internal/config/state" - providertypes "neo-code/internal/provider/types" ) -type snapshotErrProviderService struct { - stubProviderService - err error -} - -func (s snapshotErrProviderService) ListModelsSnapshot(ctx context.Context) ([]providertypes.ModelDescriptor, error) { - return nil, s.err -} - func TestTokenAndReferenceParsing(t *testing.T) { start, end, token, ok := tokenRange(" @file/path", tokenSelectorFirst) if !ok || start != 2 || end != len(" @file/path") || token != "@file/path" { @@ -165,41 +154,6 @@ func TestAddImageAttachmentLimit(t *testing.T) { } } -func TestCanSendImageInputCacheInvalidationOnModelChange(t *testing.T) { - app, _ := newTestApp(t) - providerID := app.state.CurrentProvider - - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, - models: []providertypes.ModelDescriptor{{ - ID: "model-a", - Name: "model-a", - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateSupported, - }, - }}, - } - app.state.CurrentModel = "model-a" - if !app.canSendImageInput() { - t.Fatalf("expected model-a to support images") - } - - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, - models: []providertypes.ModelDescriptor{{ - ID: "model-b", - Name: "model-b", - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateUnsupported, - }, - }}, - } - app.syncConfigState(config.Config{SelectedProvider: providerID, CurrentModel: "model-b", Workdir: app.state.CurrentWorkdir}) - if app.canSendImageInput() { - t.Fatalf("expected model-b to be unsupported after cache invalidation") - } -} - func TestApplyImageReference(t *testing.T) { app, _ := newTestApp(t) root := t.TempDir() @@ -218,6 +172,73 @@ func TestApplyImageReference(t *testing.T) { } } +func TestAbsorbInlineImageReferences(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + imagePath := filepath.Join(root, "chart.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @chart.png 趋势") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected one absorbed image, got %d", absorbed) + } + if normalized != "请分析 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending image attachment, got %d", app.getImageAttachmentCount()) + } +} + +func TestAbsorbInlineImageReferencesKeepsNonImageToken(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences("查看 @README.md 内容") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 0 { + t.Fatalf("expected absorbed image count to be 0, got %d", absorbed) + } + if normalized != "查看 @README.md 内容" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no pending image attachments") + } +} + +func TestAbsorbInlineImageReferencesDoesNotRequireFileExistenceInTUI(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentWorkdir = t.TempDir() + + normalized, absorbed, err := app.absorbInlineImageReferences("处理 @not-exist.png") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected one absorbed image, got %d", absorbed) + } + if normalized != "处理" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending attachment") + } + if app.getImageAttachments()[0].MimeType != "" { + t.Fatalf("expected mime type to stay empty before runtime/session validation") + } +} + func TestGetAndClearImageAttachments(t *testing.T) { app, _ := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ @@ -250,7 +271,6 @@ func TestAddImageFromClipboardSuccess(t *testing.T) { app, _ := newTestApp(t) originalRead := readClipboardImage originalSave := saveClipboardImageToTempFile - originalDetect := detectImageMimeType readClipboardImage = func() ([]byte, error) { return []byte("image-bytes"), nil } @@ -261,11 +281,9 @@ func TestAddImageFromClipboardSuccess(t *testing.T) { } return path, nil } - detectImageMimeType = func(path string) string { return "image/png" } defer func() { readClipboardImage = originalRead saveClipboardImageToTempFile = originalSave - detectImageMimeType = originalDetect }() if err := app.addImageFromClipboard(); err != nil { @@ -280,11 +298,9 @@ func TestAddImageFromClipboardBranches(t *testing.T) { app, _ := newTestApp(t) originalRead := readClipboardImage originalSave := saveClipboardImageToTempFile - originalDetect := detectImageMimeType defer func() { readClipboardImage = originalRead saveClipboardImageToTempFile = originalSave - detectImageMimeType = originalDetect }() readClipboardImage = func() ([]byte, error) { return nil, nil } @@ -292,20 +308,6 @@ func TestAddImageFromClipboardBranches(t *testing.T) { t.Fatalf("expected no image in clipboard error") } - readClipboardImage = func() ([]byte, error) { return make([]byte, imageMaxSizeBytes+1), nil } - if err := app.addImageFromClipboard(); err == nil { - t.Fatalf("expected image size limit error") - } - - readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } - saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { - return filepath.Join(t.TempDir(), "clipboard.bin"), nil - } - detectImageMimeType = func(path string) string { return "" } - if err := app.addImageFromClipboard(); err == nil { - t.Fatalf("expected unsupported image format error") - } - readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { return "", errors.New("save failed") @@ -315,31 +317,6 @@ func TestAddImageFromClipboardBranches(t *testing.T) { } } -func TestCheckModelImageSupportErrorAndModelNotFound(t *testing.T) { - app, _ := newTestApp(t) - app.providerSvc = snapshotErrProviderService{ - stubProviderService: stubProviderService{}, - err: errors.New("boom"), - } - if app.checkModelImageSupport() { - t.Fatalf("expected false when provider snapshot fails") - } - if !app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be marked checked after failure") - } - - app.currentModelCapabilities = modelCapabilityState{} - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, - models: []providertypes.ModelDescriptor{{ - ID: "other-model", - }}, - } - if app.checkModelImageSupport() { - t.Fatalf("expected false when current model is missing from snapshot") - } -} - func TestExecuteWorkspaceCommand(t *testing.T) { app, _ := newTestApp(t) original := workspaceCommandExecutor @@ -393,7 +370,7 @@ func TestRunWorkspaceCommandCmd(t *testing.T) { } } -func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { +func TestUpdateSendWithImageAttachmentsRunsThroughPreparePipeline(t *testing.T) { app, runtime := newTestApp(t) root := t.TempDir() imagePath := filepath.Join(root, "queued.png") @@ -403,17 +380,6 @@ func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { if err := app.addImageAttachment(imagePath); err != nil { t.Fatalf("addImageAttachment() error = %v", err) } - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, - models: []providertypes.ModelDescriptor{{ - ID: app.state.CurrentModel, - Name: app.state.CurrentModel, - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateSupported, - }, - }}, - } - app.input.SetValue("hello") app.state.InputText = "hello" @@ -423,18 +389,21 @@ func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { } app = model.(App) if app.hasImageAttachments() { - t.Fatalf("expected attachments cleared after unsupported send path") + t.Fatalf("expected attachments cleared after send") + } + if !app.state.IsAgentRunning { + t.Fatalf("expected image send to enter running state") } - if app.state.IsAgentRunning { - t.Fatalf("expected image send to be blocked until session assets are available") + if app.state.StatusText != statusThinking { + t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if len(app.activeMessages) != 0 { - t.Fatalf("expected no text fallback message in transcript, got %+v", app.activeMessages) + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) } - if len(runtime.runInputs) != 0 { - t.Fatalf("expected no runtime input with image metadata fallback, got %+v", runtime.runInputs) + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "" { + t.Fatalf("expected one queued image in prepare input, got %+v", runtime.prepareInputs[0].Images) } - if app.state.StatusText != "Image attachments need session asset support" { - t.Fatalf("unexpected status text: %q", app.state.StatusText) + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime input after prepare, got %+v", runtime.runInputs) } } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 9fea342e..aa0d93f0 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -319,7 +319,8 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te effectiveTyped = tea.KeyMsg{Type: tea.KeyEnter, Paste: true} } else { input := strings.TrimSpace(a.input.Value()) - if input == "" || a.isBusy() { + hasImages := a.hasImageAttachments() + if (input == "" && !hasImages) || a.isBusy() { return a, tea.Batch(cmds...) } @@ -411,21 +412,18 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) } - if a.hasImageAttachments() && !a.canSendImageInput() { - a.state.ExecutionError = "current model does not support image input" - a.state.StatusText = "Model does not support images" - a.appendActivity("multimodal", "Image input not supported", fmt.Sprintf("Model %s does not support image input", a.state.CurrentModel), true) - a.clearImageAttachments() + normalizedInput, absorbedImages, err := a.absorbInlineImageReferences(input) + if err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to absorb inline image reference", err.Error(), true) return a, tea.Batch(cmds...) } - if a.hasImageAttachments() { - a.state.ExecutionError = "image attachments require session asset storage before sending" - a.state.StatusText = "Image attachments need session asset support" - a.appendActivity("multimodal", "Image attachments not sent", "Session asset storage is not available yet; images were not converted to text.", true) - a.clearImageAttachments() - return a, tea.Batch(cmds...) + if absorbedImages > 0 { + input = normalizedInput } + // image capability precheck is intentionally disabled. // 如果不是立即执行的命令,再执行常规的输入重置 a.input.Reset() a.state.InputText = "" @@ -442,12 +440,23 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te a.state.StatusText = statusThinking a.state.CurrentTool = "" - a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart(input)}}) - a.rebuildTranscript() runID := fmt.Sprintf("run-%d", a.now().UnixNano()) a.state.ActiveRunID = runID requestedWorkdir := tuiutils.RequestedWorkdirForRun(a.state.CurrentWorkdir) - cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, input)) + images := make([]agentruntime.UserImageInput, 0, len(a.pendingImageAttachments)) + for _, attachment := range a.pendingImageAttachments { + images = append(images, agentruntime.UserImageInput{ + Path: attachment.Path, + MimeType: attachment.MimeType, + }) + } + cmds = append(cmds, runAgent(a.runtime, agentruntime.PrepareInput{ + SessionID: a.state.ActiveSessionID, + RunID: runID, + Workdir: requestedWorkdir, + Text: input, + Images: images, + })) a.clearImageAttachments() return a, tea.Batch(cmds...) } @@ -726,6 +735,7 @@ func (a *App) refreshSessionPicker() error { } func (a *App) refreshMessages() error { + a.resetSessionRuntimeState() if strings.TrimSpace(a.state.ActiveSessionID) == "" { a.activeMessages = nil a.clearActivities() @@ -745,6 +755,19 @@ func (a *App) refreshMessages() error { return nil } +// resetSessionRuntimeState 在切换/刷新会话前清理运行态缓存,避免跨会话残留工具与用量展示。 +func (a *App) resetSessionRuntimeState() { + a.state.IsAgentRunning = false + a.state.StreamingReply = false + a.state.CurrentTool = "" + a.state.ActiveRunID = "" + a.state.ToolStates = nil + a.state.RunContext = tuistate.ContextWindowState{} + a.state.TokenUsage = tuistate.TokenUsageState{} + a.pendingPermission = nil + a.clearRunProgress() +} + func (a *App) activateSelectedSession() error { item, ok := a.sessionPicker.SelectedItem().(sessionItem) if !ok { @@ -789,10 +812,6 @@ func (a *App) syncActiveSessionTitle() { } func (a *App) syncConfigState(cfg config.Config) { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(cfg.SelectedProvider)) || - !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(cfg.CurrentModel)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentProvider = cfg.SelectedProvider a.state.CurrentModel = cfg.CurrentModel if strings.TrimSpace(a.state.CurrentWorkdir) == "" { @@ -868,6 +887,9 @@ type runtimeRunSnapshotSource interface { var runtimeEventHandlerRegistry = map[agentruntime.EventType]func(*App, agentruntime.RuntimeEvent) bool{ agentruntime.EventUserMessage: runtimeEventUserMessageHandler, + agentruntime.EventInputNormalized: runtimeEventInputNormalizedHandler, + agentruntime.EventAssetSaved: runtimeEventAssetSavedHandler, + agentruntime.EventAssetSaveFailed: runtimeEventAssetSaveFailedHandler, agentruntime.EventType(tuiservices.RuntimeEventRunContext): runtimeEventRunContextHandler, agentruntime.EventType(tuiservices.RuntimeEventToolStatus): runtimeEventToolStatusHandler, agentruntime.EventType(tuiservices.RuntimeEventUsage): runtimeEventUsageHandler, @@ -951,6 +973,59 @@ func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { } // runtimeEventUserMessageHandler 处理用户消息进入运行队列后的状态同步。 +// runtimeEventInputNormalizedHandler 处理输入归一化完成事件并更新运行态提示。 +func runtimeEventInputNormalizedHandler(a *App, event agentruntime.RuntimeEvent) bool { + if strings.TrimSpace(event.RunID) != "" { + a.state.ActiveRunID = strings.TrimSpace(event.RunID) + } + payload, ok := event.Payload.(agentruntime.InputNormalizedPayload) + if !ok { + return false + } + if payload.ImageCount > 0 { + a.appendActivity( + "multimodal", + "Input normalized", + fmt.Sprintf("text=%d chars, images=%d", payload.TextLength, payload.ImageCount), + false, + ) + } + return false +} + +// runtimeEventAssetSavedHandler 处理附件保存成功事件并写入活动面板。 +func runtimeEventAssetSavedHandler(a *App, event agentruntime.RuntimeEvent) bool { + payload, ok := event.Payload.(agentruntime.AssetSavedPayload) + if !ok { + return false + } + detail := strings.TrimSpace(payload.AssetID) + if detail == "" { + detail = "asset saved" + } + if strings.TrimSpace(payload.Path) != "" { + detail = fmt.Sprintf("%s (%s)", detail, filepath.Base(payload.Path)) + } + a.appendActivity("multimodal", "Saved attachment", detail, false) + return false +} + +// runtimeEventAssetSaveFailedHandler 处理附件保存失败事件并同步错误状态。 +func runtimeEventAssetSaveFailedHandler(a *App, event agentruntime.RuntimeEvent) bool { + payload, ok := event.Payload.(agentruntime.AssetSaveFailedPayload) + if !ok { + return false + } + message := strings.TrimSpace(payload.Message) + if message == "" { + message = "failed to save attachment" + } + a.state.ExecutionError = message + a.state.StatusText = message + a.appendActivity("multimodal", "Failed to save attachment", message, true) + return false +} + func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) bool { if strings.TrimSpace(event.RunID) != "" { a.state.ActiveRunID = strings.TrimSpace(event.RunID) @@ -960,7 +1035,19 @@ func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) boo a.state.CurrentTool = "" a.state.ExecutionError = "" a.setRunProgress(0.15, "Queued") - return false + payload, ok := event.Payload.(providertypes.Message) + if !ok { + return false + } + content := renderMessagePartsForDisplay(payload.Parts) + if strings.TrimSpace(content) == "" || a.lastUserMatches(content) { + return false + } + a.activeMessages = append(a.activeMessages, providertypes.Message{ + Role: roleUser, + Parts: providertypes.CloneParts(payload.Parts), + }) + return true } // runtimeEventRunContextHandler 处理 runtime 上下文事件并回填界面状态。 @@ -975,15 +1062,9 @@ func runtimeEventRunContextHandler(a *App, event agentruntime.RuntimeEvent) bool a.state.ActiveRunID = mapped.RunID } if strings.TrimSpace(mapped.Provider) != "" { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(mapped.Provider)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentProvider = mapped.Provider } if strings.TrimSpace(mapped.Model) != "" { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(mapped.Model)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentModel = mapped.Model } if strings.TrimSpace(mapped.Workdir) != "" { @@ -1112,7 +1193,7 @@ func runtimeEventAgentDoneHandler(a *App, event agentruntime.RuntimeEvent) bool } // runtimeEventRunCanceledHandler 处理运行取消事件。 -func runtimeEventRunCanceledHandler(a *App, event agentruntime.RuntimeEvent) bool { +func runtimeEventRunCanceledHandler(a *App) bool { a.state.IsAgentRunning = false a.state.StreamingReply = false a.state.CurrentTool = "" @@ -1332,6 +1413,15 @@ func (a *App) lastAssistantMatches(content string) bool { return last.Role == roleAssistant && strings.TrimSpace(renderMessagePartsForDisplay(last.Parts)) == strings.TrimSpace(content) } +// lastUserMatches 判断末条用户消息是否与给定文本一致,避免重复渲染。 +func (a *App) lastUserMatches(content string) bool { + if len(a.activeMessages) == 0 { + return false + } + last := a.activeMessages[len(a.activeMessages)-1] + return last.Role == roleUser && strings.TrimSpace(renderMessagePartsForDisplay(last.Parts)) == strings.TrimSpace(content) +} + func (a *App) handleViewportKeys(vp *viewport.Model, msg tea.KeyMsg) { switch { case key.Matches(msg, a.keys.ScrollUp): @@ -1895,15 +1985,10 @@ func ListenForRuntimeEvent(sub <-chan agentruntime.RuntimeEvent) tea.Cmd { ) } -func runAgent(runtime agentruntime.Runtime, runID string, sessionID string, workdir string, content string) tea.Cmd { - return tuiservices.RunAgentCmd( +func runAgent(runtime agentruntime.Runtime, input agentruntime.PrepareInput) tea.Cmd { + return tuiservices.RunSubmitCmd( runtime, - agentruntime.UserInput{ - SessionID: sessionID, - RunID: strings.TrimSpace(runID), - Parts: []providertypes.ContentPart{providertypes.NewTextPart(content)}, - Workdir: workdir, - }, + input, func(err error) tea.Msg { return runFinishedMsg{Err: err} }, ) } diff --git a/internal/tui/core/app/update_permission_test.go b/internal/tui/core/app/update_permission_test.go index 47c18475..5b24ef62 100644 --- a/internal/tui/core/app/update_permission_test.go +++ b/internal/tui/core/app/update_permission_test.go @@ -23,6 +23,20 @@ type permissionTestRuntime struct { lastResolved agentruntime.PermissionResolutionInput } +func (r *permissionTestRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Parts: nil, + Workdir: input.Workdir, + }, nil +} + +func (r *permissionTestRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + return err +} + func (r *permissionTestRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 8ff0d3a5..ccf23b8a 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -80,6 +80,9 @@ func (s stubProviderService) SetCurrentModel(ctx context.Context, modelID string type stubRuntime struct { events chan agentruntime.RuntimeEvent + prepareInputs []agentruntime.PrepareInput + prepareErr error + preparedOutput agentruntime.UserInput runInputs []agentruntime.UserInput resolveCalls []agentruntime.PermissionResolutionInput resolveErr error @@ -101,6 +104,38 @@ func newStubRuntime() *stubRuntime { return &stubRuntime{events: make(chan agentruntime.RuntimeEvent)} } +func (s *stubRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + s.prepareInputs = append(s.prepareInputs, input) + if s.prepareErr != nil { + return agentruntime.UserInput{}, s.prepareErr + } + if len(s.preparedOutput.Parts) > 0 { + return s.preparedOutput, nil + } + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + sessionID = "session-prepared" + } + content := strings.TrimSpace(input.Text) + if content == "" { + content = "image input" + } + return agentruntime.UserInput{ + SessionID: sessionID, + RunID: strings.TrimSpace(input.RunID), + Parts: []providertypes.ContentPart{providertypes.NewTextPart(content)}, + Workdir: strings.TrimSpace(input.Workdir), + }, nil +} + +func (s *stubRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + prepared, err := s.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return s.Run(ctx, prepared) +} + func (s *stubRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { s.runInputs = append(s.runInputs, input) return nil @@ -1474,44 +1509,6 @@ func TestRuntimeEventRunContextHandler(t *testing.T) { } } -func TestRuntimeEventRunContextHandlerInvalidatesModelCapabilityCache(t *testing.T) { - app, _ := newTestApp(t) - app.state.CurrentProvider = "provider-a" - app.state.CurrentModel = "model-a" - app.currentModelCapabilities = modelCapabilityState{ - checked: true, - supportsImageInput: true, - } - - payload := tuiservices.RuntimeRunContextPayload{ - Provider: "provider-b", - Model: "model-b", - } - _ = runtimeEventRunContextHandler(&app, agentruntime.RuntimeEvent{Payload: payload}) - if app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be invalidated when provider/model changes") - } -} - -func TestSyncConfigStateInvalidatesModelCapabilityCache(t *testing.T) { - app, _ := newTestApp(t) - app.state.CurrentProvider = "provider-a" - app.state.CurrentModel = "model-a" - app.currentModelCapabilities = modelCapabilityState{ - checked: true, - supportsImageInput: true, - } - - app.syncConfigState(config.Config{ - SelectedProvider: "provider-b", - CurrentModel: "model-b", - Workdir: app.state.CurrentWorkdir, - }) - if app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be invalidated") - } -} - func TestUpdatePasteImageShortcutFailure(t *testing.T) { app, _ := newTestApp(t) model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyCtrlV}) @@ -1563,8 +1560,8 @@ func TestUpdateEnterImageReferencePath(t *testing.T) { } } -func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { - app, _ := newTestApp(t) +func TestUpdateSendWithUnsupportedImageInputDoesNotPreBlock(t *testing.T) { + app, runtime := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, } @@ -1586,25 +1583,25 @@ func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { _ = cmd() } app = model.(App) - if app.state.IsAgentRunning { - t.Fatalf("expected send to be blocked for unsupported model image input") + if !app.state.IsAgentRunning { + t.Fatalf("expected send not to be pre-blocked by model capability hints") } if app.hasImageAttachments() { - t.Fatalf("expected pending image attachments to be cleared on unsupported model") + t.Fatalf("expected pending image attachments to be cleared after send") } - if app.state.StatusText != "Model does not support images" { + if app.state.StatusText != statusThinking { t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if app.input.Value() != "hello" { - t.Fatalf("expected input to be preserved when send is blocked, got %q", app.input.Value()) + if app.input.Value() != "" || app.state.InputText != "" { + t.Fatalf("expected input to reset after send, got input=%q state=%q", app.input.Value(), app.state.InputText) } - if app.state.InputText != "hello" { - t.Fatalf("expected state input text to be preserved, got %q", app.state.InputText) + if len(runtime.prepareInputs) != 1 || len(runtime.prepareInputs[0].Images) != 1 { + t.Fatalf("expected image to flow into prepare pipeline, got %+v", runtime.prepareInputs) } } -func TestUpdateSendWithImageAttachmentsWithoutSessionAssets(t *testing.T) { - app, _ := newTestApp(t) +func TestUpdateSendWithImageAttachmentsUsesPreparePipeline(t *testing.T) { + app, runtime := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, } @@ -1626,20 +1623,72 @@ func TestUpdateSendWithImageAttachmentsWithoutSessionAssets(t *testing.T) { _ = cmd() } app = model.(App) - if app.state.IsAgentRunning { - t.Fatalf("expected send to be blocked when session assets are unavailable") + if !app.state.IsAgentRunning { + t.Fatalf("expected send to enter running state") } if app.hasImageAttachments() { - t.Fatalf("expected pending image attachments to be cleared when storage is unavailable") + t.Fatalf("expected pending image attachments to be cleared after send") } - if app.state.StatusText != "Image attachments need session asset support" { + if app.state.StatusText != statusThinking { t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if app.input.Value() != "hello" { - t.Fatalf("expected input to be preserved when send is blocked, got %q", app.input.Value()) + if app.input.Value() != "" { + t.Fatalf("expected input to be reset after send, got %q", app.input.Value()) + } + if app.state.InputText != "" { + t.Fatalf("expected state input text to reset after send, got %q", app.state.InputText) + } + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) + } + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "image/png" { + t.Fatalf("expected image metadata to flow through prepare input, got %+v", runtime.prepareInputs[0].Images) + } + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime run input, got %+v", runtime.runInputs) + } +} + +func TestUpdateSendWithInlineImageReferenceUsesPreparePipeline(t *testing.T) { + app, runtime := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + imagePath := filepath.Join(root, "burn.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + + app.input.SetValue("请分析 @burn.png") + app.state.InputText = "请分析 @burn.png" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) + } + if runtime.prepareInputs[0].Text != "请分析" { + t.Fatalf("expected inline image token removed from text, got %q", runtime.prepareInputs[0].Text) + } + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "" { + t.Fatalf("expected one promoted image in prepare input, got %+v", runtime.prepareInputs[0].Images) } - if app.state.InputText != "hello" { - t.Fatalf("expected state input text to be preserved, got %q", app.state.InputText) + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime run input, got %+v", runtime.runInputs) } } @@ -1778,7 +1827,7 @@ func TestRuntimeEventAgentChunkHandler(t *testing.T) { func TestRuntimeEventRunCanceledHandler(t *testing.T) { app, _ := newTestApp(t) app.state.ActiveRunID = "run-3" - runtimeEventRunCanceledHandler(&app, agentruntime.RuntimeEvent{}) + runtimeEventRunCanceledHandler(&app) if app.state.StatusText != statusCanceled { t.Fatalf("expected canceled status") } diff --git a/internal/tui/services/runtime_service.go b/internal/tui/services/runtime_service.go index d4982d4d..16d5caa1 100644 --- a/internal/tui/services/runtime_service.go +++ b/internal/tui/services/runtime_service.go @@ -16,6 +16,12 @@ type Runner interface { Run(ctx context.Context, input agentruntime.UserInput) error } +// PreparedRunner 定义“输入归一化 + run”链路所需最小能力。 +// Submitter 定义 runtime 单入口提交所需的最小能力。 +type Submitter interface { + Submit(ctx context.Context, input agentruntime.PrepareInput) error +} + // Compactor 定义执行 runtime compact 所需最小能力。 type Compactor interface { Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) @@ -53,6 +59,15 @@ func RunAgentCmd( } } +// RunPreparedAgentCmd 先执行输入归一化,再执行 runtime run,并将结果映射为 UI 消息。 +// RunSubmitCmd 执行 runtime 单入口提交,并将结果映射为 UI 消息。 +func RunSubmitCmd(runtime Submitter, input agentruntime.PrepareInput, doneMsg func(error) tea.Msg) tea.Cmd { + return func() tea.Msg { + err := runtime.Submit(context.Background(), input) + return doneMsg(err) + } +} + // RunCompactCmd 执行 runtime compact,并将结果映射为 UI 消息。 func RunCompactCmd( runtime Compactor, diff --git a/internal/tui/services/services_test.go b/internal/tui/services/services_test.go index 42b5eec4..bb5eefcb 100644 --- a/internal/tui/services/services_test.go +++ b/internal/tui/services/services_test.go @@ -26,6 +26,16 @@ func (s *stubRunner) Run(ctx context.Context, input agentruntime.UserInput) erro return s.err } +type stubSubmitter struct { + lastInput agentruntime.PrepareInput + err error +} + +func (s *stubSubmitter) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + s.lastInput = input + return s.err +} + type stubCompactor struct { lastInput agentruntime.CompactInput err error @@ -105,6 +115,24 @@ func TestRunAgentCmd(t *testing.T) { } } +func TestRunSubmitCmd(t *testing.T) { + runner := &stubSubmitter{err: errors.New("run failed")} + prepareInput := agentruntime.PrepareInput{ + SessionID: "s1", + RunID: "run-1", + Workdir: "D:/", + Text: "hello", + Images: []agentruntime.UserImageInput{{Path: "C:/a.png", MimeType: "image/png"}}, + } + msg := RunSubmitCmd(runner, prepareInput, func(err error) tea.Msg { return err })() + if runner.lastInput.RunID != "run-1" || len(runner.lastInput.Images) != 1 { + t.Fatalf("unexpected submit input: %+v", runner.lastInput) + } + if err, ok := msg.(error); !ok || err == nil || err.Error() != "run failed" { + t.Fatalf("expected forwarded run error, got %T %#v", msg, msg) + } +} + func TestRunCompactCmd(t *testing.T) { compactor := &stubCompactor{err: errors.New("compact failed")} input := agentruntime.CompactInput{SessionID: "s2"} From 20a53122a209b0aa0088f9534b90ab418ded9481 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 07:43:53 +0000 Subject: [PATCH 02/10] fix(tui,session,config): address review findings and add regression coverage - require explicit @image: token and preserve whitespace when absorbing inline image refs - dedupe user-message rendering by run id instead of message content - rollback newly created session when asset save fails - make custom provider directory load read-only friendly when directory is missing - add/adjust tests for all updated behaviors Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/loader_test.go | 14 +++--- internal/config/provider_loader.go | 6 +-- internal/session/input_preparer.go | 44 ++++++++++++++----- internal/session/input_preparer_test.go | 40 +++++++++++++++++ internal/tui/core/app/app.go | 1 + internal/tui/core/app/input_features.go | 42 +++++++++++++----- internal/tui/core/app/input_features_test.go | 45 ++++++++++++++++++-- internal/tui/core/app/update.go | 24 +++++------ internal/tui/core/app/update_test.go | 37 +++++++++++++++- 9 files changed, 202 insertions(+), 51 deletions(-) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 211706ec..a93c3d0f 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1052,10 +1052,10 @@ func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) { _, err := loadCustomProviders(baseDir) if err == nil { - t.Fatal("expected create providers dir error") + t.Fatal("expected read providers dir error") } - if !strings.Contains(err.Error(), "create providers dir") { - t.Fatalf("expected create providers dir error, got %v", err) + if !strings.Contains(err.Error(), "read providers dir") { + t.Fatalf("expected read providers dir error, got %v", err) } }) @@ -1096,12 +1096,8 @@ func TestLoadCustomProvidersReturnsEmptyWhenProvidersDirMissing(t *testing.T) { if len(providers) != 0 { t.Fatalf("expected no custom providers, got %d", len(providers)) } - info, err := os.Stat(providersPath) - if err != nil { - t.Fatalf("expected providers dir to be created, got %v", err) - } - if !info.IsDir() { - t.Fatalf("expected providers path to be directory") + if _, err := os.Stat(providersPath); !os.IsNotExist(err) { + t.Fatalf("expected providers dir to remain missing, got %v", err) } } diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index d73ea428..dd51cb7d 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -62,11 +62,11 @@ type customProviderSettings struct { // loadCustomProviders 扫描 baseDir/providers 下的一层子目录,并将其中的 provider.yaml 解析为运行时配置。 func loadCustomProviders(baseDir string) ([]ProviderConfig, error) { providersDir := filepath.Join(strings.TrimSpace(baseDir), providersDirName) - if err := os.MkdirAll(providersDir, 0o755); err != nil { - return nil, fmt.Errorf("config: create providers dir: %w", err) - } entries, err := os.ReadDir(providersDir) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, fmt.Errorf("config: read providers dir: %w", err) } diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 28fe7910..17e4a01d 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -95,7 +95,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar } sessionTitle := buildSessionTitle(trimmedText, len(input.Images) > 0) - session, err := p.loadOrCreateSession( + session, sessionCreated, err := p.loadOrCreateSession( ctx, input.SessionID, sessionTitle, @@ -115,6 +115,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar for index, image := range input.Images { path := strings.TrimSpace(image.Path) if path == "" { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) return PreparedInput{}, &AssetSaveError{ Index: index, Path: path, @@ -125,6 +126,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar meta, err := p.saveImageAsset(ctx, session.ID, path, mimeType) if err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) return PreparedInput{}, &AssetSaveError{ Index: index, Path: path, @@ -136,6 +138,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar } if err := providertypes.ValidateParts(parts); err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) return PreparedInput{}, fmt.Errorf("session: normalize parts: %w", err) } @@ -211,41 +214,60 @@ func (p *InputPreparer) loadOrCreateSession( title string, defaultWorkdir string, requestedWorkdir string, -) (Session, error) { +) (Session, bool, error) { if strings.TrimSpace(sessionID) == "" { sessionWorkdir, err := resolveWorkdirForInput(defaultWorkdir, "", requestedWorkdir) if err != nil { - return Session{}, err + return Session{}, false, err } session := NewWithWorkdir(title, sessionWorkdir) if err := p.store.Save(ctx, &session); err != nil { - return Session{}, err + return Session{}, false, err } - return session, nil + return session, true, nil } session, err := p.store.Load(ctx, sessionID) if err != nil { - return Session{}, err + return Session{}, false, err } if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { - return session, nil + return session, false, nil } resolved, err := resolveWorkdirForInput(defaultWorkdir, session.Workdir, requestedWorkdir) if err != nil { - return Session{}, err + return Session{}, false, err } if session.Workdir == resolved { - return session, nil + return session, false, nil } session.Workdir = resolved session.UpdatedAt = time.Now() if err := p.store.Save(ctx, &session); err != nil { - return Session{}, err + return Session{}, false, err } - return session, nil + return session, false, nil +} + +// rollbackCreatedSession 在本次 Prepare 新建会话后发生错误时回滚会话目录,避免残留孤儿会话。 +func (p *InputPreparer) rollbackCreatedSession(ctx context.Context, sessionID string, created bool) { + if !created { + return + } + store, ok := p.store.(*JSONStore) + if !ok { + return + } + if err := ctx.Err(); err != nil { + return + } + target := store.sessionDir(sessionID) + if err := ensurePathWithinBase(store.baseDir, target); err != nil { + return + } + _ = os.RemoveAll(target) } func resolveWorkdirForInput(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index 8ef0d29d..1dd033bd 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -193,6 +193,46 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Fatalf("expected save error index 0, got %d", saveErr.Index) } }) + + t.Run("new session is rolled back when asset save fails", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + + summaries, listErr := store.ListSummaries(context.Background()) + if listErr != nil { + t.Fatalf("ListSummaries() error = %v", listErr) + } + if len(summaries) != 0 { + t.Fatalf("expected no persisted session after rollback, got %+v", summaries) + } + }) + + t.Run("existing session is kept when asset save fails", func(t *testing.T) { + existing := NewWithWorkdir("existing", workdir) + if err := store.Save(context.Background(), &existing); err != nil { + t.Fatalf("Save() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + + if _, loadErr := store.Load(context.Background(), existing.ID); loadErr != nil { + t.Fatalf("expected existing session to remain, load error = %v", loadErr) + } + }) } func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index 71953775..37769e16 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -106,6 +106,7 @@ type appRuntimeState struct { runProgressValue float64 runProgressKnown bool runProgressLabel string + lastUserMessageRunID string pendingPermission *permissionPromptState pendingImageAttachments []pendingImageAttachment providerAddForm *providerAddFormState diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index 892e5c01..e0da0f8b 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -217,20 +217,30 @@ func (a *App) applyImageReference(input string) error { return a.addImageAttachment(path) } -// absorbInlineImageReferences 会把输入文本中的 @ 令牌吸收到附件队列,并返回移除令牌后的文本。 -// 仅根据令牌语法与扩展名做轻量识别,避免把文件系统硬校验放到 TUI 层。 +// absorbInlineImageReferences 会把输入文本中的 @image: 令牌吸收到附件队列,并返回移除令牌后的文本。 +// 该实现保留原始空白布局,仅移除命中的图片令牌,避免改变用户提示词语义。 func (a *App) absorbInlineImageReferences(input string) (string, int, error) { - tokens := strings.Fields(input) - if len(tokens) == 0 { + if strings.TrimSpace(input) == "" { return strings.TrimSpace(input), 0, nil } - kept := make([]string, 0, len(tokens)) + var builder strings.Builder absorbed := 0 - for _, token := range tokens { + for i := 0; i < len(input); { + if isInlineTokenSpace(input[i]) { + builder.WriteByte(input[i]) + i++ + continue + } + + start := i + for i < len(input) && !isInlineTokenSpace(input[i]) { + i++ + } + token := input[start:i] imagePath, ok := a.parseInlineImagePathToken(token) if !ok { - kept = append(kept, token) + builder.WriteString(token) continue } if err := a.queueImageAttachmentForPrepare(imagePath); err != nil { @@ -239,17 +249,27 @@ func (a *App) absorbInlineImageReferences(input string) (string, int, error) { absorbed++ } - return strings.TrimSpace(strings.Join(kept, " ")), absorbed, nil + return strings.TrimSpace(builder.String()), absorbed, nil +} + +// isInlineTokenSpace 判断字符是否属于输入令牌分隔空白字符。 +func isInlineTokenSpace(ch byte) bool { + switch ch { + case ' ', '\t', '\r', '\n': + return true + default: + return false + } } -// parseInlineImagePathToken 识别 @ 形式的图片路径令牌,并映射为待发送路径。 +// parseInlineImagePathToken 识别 @image: 形式的图片路径令牌,并映射为待发送路径。 func (a *App) parseInlineImagePathToken(token string) (string, bool) { trimmed := strings.TrimSpace(token) - if !strings.HasPrefix(trimmed, fileReferencePrefix) || strings.HasPrefix(trimmed, imageReferencePrefix) { + if !strings.HasPrefix(trimmed, imageReferencePrefix) { return "", false } - path := strings.TrimPrefix(trimmed, fileReferencePrefix) + path := strings.TrimPrefix(trimmed, imageReferencePrefix) path = strings.Trim(path, `"'`) path = strings.TrimSpace(path) if path == "" || !looksLikeImagePath(path) { diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index 7e7e4aaf..329e5023 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -182,14 +182,14 @@ func TestAbsorbInlineImageReferences(t *testing.T) { t.Fatalf("write image: %v", err) } - normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @chart.png 趋势") + normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @image:chart.png 趋势") if err != nil { t.Fatalf("absorbInlineImageReferences() error = %v", err) } if absorbed != 1 { t.Fatalf("expected one absorbed image, got %d", absorbed) } - if normalized != "请分析 趋势" { + if normalized != "请分析 趋势" { t.Fatalf("unexpected normalized text: %q", normalized) } if app.getImageAttachmentCount() != 1 { @@ -197,6 +197,26 @@ func TestAbsorbInlineImageReferences(t *testing.T) { } } +func TestAbsorbInlineImageReferencesRequiresExplicitPrefix(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @chart.png 趋势") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 0 { + t.Fatalf("expected absorbed image count to be 0, got %d", absorbed) + } + if normalized != "请分析 @chart.png 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no pending image attachments") + } +} + func TestAbsorbInlineImageReferencesKeepsNonImageToken(t *testing.T) { app, _ := newTestApp(t) root := t.TempDir() @@ -221,7 +241,7 @@ func TestAbsorbInlineImageReferencesDoesNotRequireFileExistenceInTUI(t *testing. app, _ := newTestApp(t) app.state.CurrentWorkdir = t.TempDir() - normalized, absorbed, err := app.absorbInlineImageReferences("处理 @not-exist.png") + normalized, absorbed, err := app.absorbInlineImageReferences("处理 @image:not-exist.png") if err != nil { t.Fatalf("absorbInlineImageReferences() error = %v", err) } @@ -239,6 +259,25 @@ func TestAbsorbInlineImageReferencesDoesNotRequireFileExistenceInTUI(t *testing. } } +func TestAbsorbInlineImageReferencesPreservesWhitespaceLayout(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentWorkdir = t.TempDir() + + normalized, absorbed, err := app.absorbInlineImageReferences("A @image:x.png\nB\t @image:y.jpg C") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 2 { + t.Fatalf("expected absorbed image count to be 2, got %d", absorbed) + } + if normalized != "A \nB\t C" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 2 { + t.Fatalf("expected two pending image attachments") + } +} + func TestGetAndClearImageAttachments(t *testing.T) { app, _ := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index aa0d93f0..2c7291c0 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -761,6 +761,7 @@ func (a *App) resetSessionRuntimeState() { a.state.StreamingReply = false a.state.CurrentTool = "" a.state.ActiveRunID = "" + a.lastUserMessageRunID = "" a.state.ToolStates = nil a.state.RunContext = tuistate.ContextWindowState{} a.state.TokenUsage = tuistate.TokenUsageState{} @@ -1027,8 +1028,9 @@ func runtimeEventAssetSaveFailedHandler(a *App, event agentruntime.RuntimeEvent) } func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) bool { - if strings.TrimSpace(event.RunID) != "" { - a.state.ActiveRunID = strings.TrimSpace(event.RunID) + runID := strings.TrimSpace(event.RunID) + if runID != "" { + a.state.ActiveRunID = runID } a.state.StatusText = statusThinking a.state.StreamingReply = false @@ -1040,13 +1042,19 @@ func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) boo return false } content := renderMessagePartsForDisplay(payload.Parts) - if strings.TrimSpace(content) == "" || a.lastUserMatches(content) { + if strings.TrimSpace(content) == "" { + return false + } + if runID != "" && strings.EqualFold(a.lastUserMessageRunID, runID) { return false } a.activeMessages = append(a.activeMessages, providertypes.Message{ Role: roleUser, Parts: providertypes.CloneParts(payload.Parts), }) + if runID != "" { + a.lastUserMessageRunID = runID + } return true } @@ -1413,15 +1421,6 @@ func (a *App) lastAssistantMatches(content string) bool { return last.Role == roleAssistant && strings.TrimSpace(renderMessagePartsForDisplay(last.Parts)) == strings.TrimSpace(content) } -// lastUserMatches 判断末条用户消息是否与给定文本一致,避免重复渲染。 -func (a *App) lastUserMatches(content string) bool { - if len(a.activeMessages) == 0 { - return false - } - last := a.activeMessages[len(a.activeMessages)-1] - return last.Role == roleUser && strings.TrimSpace(renderMessagePartsForDisplay(last.Parts)) == strings.TrimSpace(content) -} - func (a *App) handleViewportKeys(vp *viewport.Model, msg tea.KeyMsg) { switch { case key.Matches(msg, a.keys.ScrollUp): @@ -1949,6 +1948,7 @@ func (a *App) startDraftSession() { a.state.ExecutionError = "" a.state.CurrentTool = "" a.state.ActiveRunID = "" + a.lastUserMessageRunID = "" a.state.ToolStates = nil a.state.RunContext = tuistate.ContextWindowState{} a.state.TokenUsage = tuistate.TokenUsageState{} diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index ccf23b8a..269f33f4 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -1492,6 +1492,39 @@ func TestRuntimeEventUserMessageHandler(t *testing.T) { } } +func TestRuntimeEventUserMessageHandlerDeduplicatesByRunID(t *testing.T) { + app, _ := newTestApp(t) + payload := providertypes.Message{ + Role: roleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("same content")}, + } + event := agentruntime.RuntimeEvent{RunID: "run-1", Payload: payload} + + handled := runtimeEventUserMessageHandler(&app, event) + if !handled { + t.Fatalf("expected first user message to be rendered") + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected one user message, got %d", len(app.activeMessages)) + } + + handled = runtimeEventUserMessageHandler(&app, event) + if handled { + t.Fatalf("expected duplicate run id to be ignored") + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected one user message after duplicate event, got %d", len(app.activeMessages)) + } + + handled = runtimeEventUserMessageHandler(&app, agentruntime.RuntimeEvent{RunID: "run-2", Payload: payload}) + if !handled { + t.Fatalf("expected same content with new run id to be rendered") + } + if len(app.activeMessages) != 2 { + t.Fatalf("expected two user messages after new run id, got %d", len(app.activeMessages)) + } +} + func TestRuntimeEventRunContextHandler(t *testing.T) { app, _ := newTestApp(t) payload := tuiservices.RuntimeRunContextPayload{ @@ -1669,8 +1702,8 @@ func TestUpdateSendWithInlineImageReferenceUsesPreparePipeline(t *testing.T) { }}, } - app.input.SetValue("请分析 @burn.png") - app.state.InputText = "请分析 @burn.png" + app.input.SetValue("请分析 @image:burn.png") + app.state.InputText = "请分析 @image:burn.png" model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) if cmd != nil { From 56fb8c9625555b516a0687365a6e13375704c56d Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 08:35:54 +0000 Subject: [PATCH 03/10] fix(runtime/session/tui): harden multimodal input safety and event isolation Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- .../chatcompletions/provider_test.go | 27 +++- .../openaicompat/chatcompletions/request.go | 14 +- internal/runtime/input_prepare.go | 3 + internal/runtime/input_prepare_test.go | 19 ++- .../runtime/runtime_internal_helpers_test.go | 4 + .../runtime_remaining_branches_test.go | 8 ++ internal/runtime/runtime_test.go | 18 +++ internal/runtime/todo_mutator_test.go | 10 ++ internal/session/input_preparer.go | 124 +++++++++++++----- internal/session/input_preparer_test.go | 93 ++++++++++++- internal/session/store.go | 23 ++++ internal/tui/core/app/input_features.go | 118 +++++++++++++---- internal/tui/core/app/input_features_test.go | 23 ++++ internal/tui/core/app/update.go | 19 +++ .../core/app/update_runtime_events_test.go | 30 +++++ 15 files changed, 465 insertions(+), 68 deletions(-) diff --git a/internal/provider/openaicompat/chatcompletions/provider_test.go b/internal/provider/openaicompat/chatcompletions/provider_test.go index 13335aaf..def97a35 100644 --- a/internal/provider/openaicompat/chatcompletions/provider_test.go +++ b/internal/provider/openaicompat/chatcompletions/provider_test.go @@ -97,13 +97,36 @@ func TestNewAndBuildRequest(t *testing.T) { if gotSchema["type"] != "object" { t.Fatalf("expected sanitized schema type object, got %+v", gotSchema["type"]) } - if _, ok := gotSchema["oneOf"]; ok { - t.Fatalf("expected sanitized schema to drop top-level oneOf, got %+v", gotSchema) + if _, ok := gotSchema["oneOf"]; !ok { + t.Fatalf("expected top-level oneOf to be preserved, got %+v", gotSchema) } if _, ok := toolSchemaWithTopLevelCombinator["oneOf"]; !ok { t.Fatalf("expected original schema not to be mutated") } + downgradedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + Tools: []providertypes.ToolSpec{{ + Name: "non_object_schema", + Description: "schema root is string", + Schema: map[string]any{ + "type": "string", + }, + }}, + }) + if err != nil { + t.Fatalf("BuildRequest() downgrade schema error = %v", err) + } + downgradedSchema := downgradedPayload.Tools[0].Function.Parameters + if downgradedSchema["type"] != "object" { + t.Fatalf("expected downgraded schema type object, got %+v", downgradedSchema["type"]) + } + if downgradedSchema["x-neocode-schema-downgraded"] != true { + t.Fatalf("expected downgrade marker, got %+v", downgradedSchema) + } + withSessionAsset, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ Messages: []providertypes.Message{ { diff --git a/internal/provider/openaicompat/chatcompletions/request.go b/internal/provider/openaicompat/chatcompletions/request.go index 38a928a0..6c17eb36 100644 --- a/internal/provider/openaicompat/chatcompletions/request.go +++ b/internal/provider/openaicompat/chatcompletions/request.go @@ -77,8 +77,8 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert return payload, nil } -// normalizeToolSchemaForOpenAI 归一化工具参数 schema,避免 OpenAI chat-completions 顶层关键字约束报错。 -// 仅收敛顶层结构:保证 type=object,并移除顶层 oneOf/anyOf/allOf/enum/not;嵌套语义保持原样。 +// normalizeToolSchemaForOpenAI 归一化工具参数 schema,避免修改调用方原始结构并尽量保持语义。 +// 仅在缺失 schema 或明显非法(非 object 顶层)时做最小兼容降级,不再删除顶层组合约束关键字。 func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any { normalized := cloneSchemaTopLevel(schema) if len(normalized) == 0 { @@ -91,17 +91,15 @@ func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any { typeName, _ := normalized["type"].(string) if strings.TrimSpace(strings.ToLower(typeName)) != "object" { normalized["type"] = "object" + normalized["x-neocode-schema-downgraded"] = true } if _, ok := normalized["properties"].(map[string]any); !ok { normalized["properties"] = map[string]any{} + if strings.TrimSpace(strings.ToLower(typeName)) != "object" { + normalized["x-neocode-schema-downgraded"] = true + } } - - delete(normalized, "oneOf") - delete(normalized, "anyOf") - delete(normalized, "allOf") - delete(normalized, "enum") - delete(normalized, "not") return normalized } diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 30e467f6..74f6e4ff 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -83,6 +83,9 @@ func (s *Service) emitPrepareFailure(ctx context.Context, input PrepareInput, er var saveErr *agentsession.AssetSaveError if errors.As(err, &saveErr) { + if session := strings.TrimSpace(saveErr.SessionID); session != "" { + sessionID = session + } return s.emit(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{ Index: saveErr.Index, Path: strings.TrimSpace(saveErr.Path), diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go index d7a2e2c3..d9f6ad2c 100644 --- a/internal/runtime/input_prepare_test.go +++ b/internal/runtime/input_prepare_test.go @@ -19,7 +19,7 @@ func TestServicePrepareUserInputEmitsNormalizeAndAssetSaved(t *testing.T) { svc, _ := newPrepareTestService(t, workdir, true) imagePath := filepath.Join(workdir, "img.png") - if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + if err := os.WriteFile(imagePath, minimalPNGBytesForRuntimeTest(), 0o644); err != nil { t.Fatalf("write image: %v", err) } @@ -74,6 +74,9 @@ func TestServicePrepareUserInputEmitsAssetSaveFailed(t *testing.T) { if failedEvent.Type != EventAssetSaveFailed { t.Fatalf("expected event %q, got %q", EventAssetSaveFailed, failedEvent.Type) } + if failedEvent.SessionID == "" { + t.Fatalf("expected asset_save_failed event to include session id") + } payload, ok := failedEvent.Payload.(AssetSaveFailedPayload) if !ok || payload.Index != 0 { t.Fatalf("unexpected asset_save_failed payload: %#v", failedEvent.Payload) @@ -150,3 +153,17 @@ func mustReadRuntimeEvent(t *testing.T, events <-chan RuntimeEvent) RuntimeEvent return RuntimeEvent{} } } + +func minimalPNGBytesForRuntimeTest() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 2e136e67..235665a9 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -40,6 +40,10 @@ func (s *lockProbeStore) ListSummaries(ctx context.Context) ([]agentsession.Summ return nil, errors.New("not implemented") } +func (s *lockProbeStore) DeleteSession(ctx context.Context, id string) error { + return errors.New("not implemented") +} + func (s *stubMemoExtractor) Schedule(_ string, messages []providertypes.Message) { s.mu.Lock() s.calls++ diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index 5e1fd701..0673073f 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -80,6 +80,10 @@ func (s *saveHookStore) ListSummaries(ctx context.Context) ([]agentsession.Summa return s.base.ListSummaries(ctx) } +func (s *saveHookStore) DeleteSession(ctx context.Context, id string) error { + return s.base.DeleteSession(ctx, id) +} + func (s *postSaveHookStore) Load(ctx context.Context, id string) (agentsession.Session, error) { return s.base.Load(ctx, id) } @@ -88,6 +92,10 @@ func (s *postSaveHookStore) ListSummaries(ctx context.Context) ([]agentsession.S return s.base.ListSummaries(ctx) } +func (s *postSaveHookStore) DeleteSession(ctx context.Context, id string) error { + return s.base.DeleteSession(ctx, id) +} + func TestResolveCompactProviderSelectionResolveErrorBranch(t *testing.T) { t.Parallel() diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 2bc2f93e..eaa9e30b 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -102,6 +102,14 @@ func (s *memoryStore) ListSummaries(ctx context.Context) ([]agentsession.Summary return summaries, nil } +func (s *memoryStore) DeleteSession(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + delete(s.sessions, id) + return nil +} + // blockingLoadStore 用于并发测试:首次 Load 阻塞,以验证同 session 的锁时序。 type blockingLoadStore struct { mu sync.Mutex @@ -177,6 +185,16 @@ func (s *blockingLoadStore) ListSummaries(ctx context.Context) ([]agentsession.S return summaries, nil } +func (s *blockingLoadStore) DeleteSession(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + delete(s.sessions, id) + s.mu.Unlock() + return nil +} + type scriptedProvider struct { name string streams [][]providertypes.StreamEvent diff --git a/internal/runtime/todo_mutator_test.go b/internal/runtime/todo_mutator_test.go index cbcf002a..8c5e4cce 100644 --- a/internal/runtime/todo_mutator_test.go +++ b/internal/runtime/todo_mutator_test.go @@ -45,6 +45,16 @@ func (s *mutatorStore) ListSummaries(ctx context.Context) ([]agentsession.Summar return nil, nil } +func (s *mutatorStore) DeleteSession(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + if s.last.ID == id { + s.last = agentsession.Session{} + } + return nil +} + func TestRuntimeSessionMutatorMutateAndSave(t *testing.T) { t.Parallel() diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 17e4a01d..5d65a6ec 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -41,9 +41,10 @@ type PreparedInput struct { // AssetSaveError 描述图片落盘阶段的结构化失败信息,便于上层统一事件化处理。 type AssetSaveError struct { - Index int - Path string - Err error + SessionID string + Index int + Path string + Err error } func (e *AssetSaveError) Error() string { @@ -117,20 +118,22 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar if path == "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) return PreparedInput{}, &AssetSaveError{ - Index: index, - Path: path, - Err: fmt.Errorf("image path is empty"), + SessionID: session.ID, + Index: index, + Path: path, + Err: fmt.Errorf("image path is empty"), } } mimeType := strings.TrimSpace(image.MimeType) - meta, err := p.saveImageAsset(ctx, session.ID, path, mimeType) + meta, err := p.saveImageAsset(ctx, session.ID, session.Workdir, path, mimeType) if err != nil { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) return PreparedInput{}, &AssetSaveError{ - Index: index, - Path: path, - Err: err, + SessionID: session.ID, + Index: index, + Path: path, + Err: err, } } savedAssets = append(savedAssets, meta) @@ -150,10 +153,17 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar }, nil } -func (p *InputPreparer) saveImageAsset(ctx context.Context, sessionID string, path string, mimeType string) (AssetMeta, error) { - absolutePath, err := filepath.Abs(path) +// saveImageAsset 按会话工作目录解析并校验图片路径后落盘,禁止越界访问工作目录外文件。 +func (p *InputPreparer) saveImageAsset( + ctx context.Context, + sessionID string, + workdir string, + path string, + mimeType string, +) (AssetMeta, error) { + absolutePath, err := resolveImagePath(workdir, path) if err != nil { - return AssetMeta{}, fmt.Errorf("resolve image path: %w", err) + return AssetMeta{}, err } file, err := os.Open(absolutePath) @@ -176,22 +186,33 @@ func (p *InputPreparer) saveImageAsset(ctx context.Context, sessionID string, pa return meta, nil } -// resolveImageMimeType 解析图片 MIME 类型,优先使用显式传入值,其次回退到扩展名与文件头探测。 +// resolveImageMimeType 解析图片 MIME 类型,仅允许 image/*,并要求声明值与文件头探测一致。 func resolveImageMimeType(path string, declared string, file *os.File) (string, error) { - if normalized := strings.ToLower(strings.TrimSpace(declared)); normalized != "" { - return normalized, nil + detected, err := detectImageMimeTypeFromFile(file) + if err != nil { + return "", err } - extMime := strings.ToLower(strings.TrimSpace(mime.TypeByExtension(strings.ToLower(filepath.Ext(path))))) - if extMime != "" { - if idx := strings.Index(extMime, ";"); idx >= 0 { - extMime = strings.TrimSpace(extMime[:idx]) + declaredMime := normalizeMimeType(declared) + if declaredMime != "" { + if !strings.HasPrefix(declaredMime, "image/") { + return "", fmt.Errorf("declared mime type %q is not an image", declared) } - if strings.HasPrefix(extMime, "image/") { - return extMime, nil + if declaredMime != detected { + return "", fmt.Errorf("declared mime type %q mismatches detected %q", declaredMime, detected) } + return detected, nil + } + + extMime := normalizeMimeType(mime.TypeByExtension(strings.ToLower(filepath.Ext(path)))) + if extMime != "" && strings.HasPrefix(extMime, "image/") && extMime != detected { + return "", fmt.Errorf("file extension mime %q mismatches detected %q", extMime, detected) } + return detected, nil +} +// detectImageMimeTypeFromFile 根据文件头探测 MIME,且要求结果为 image/*。 +func detectImageMimeTypeFromFile(file *os.File) (string, error) { buffer := make([]byte, 512) n, readErr := file.Read(buffer) if readErr != nil && readErr != io.EOF { @@ -208,6 +229,55 @@ func resolveImageMimeType(path string, declared string, file *os.File) (string, return "", fmt.Errorf("unsupported image format") } +// normalizeMimeType 清洗 MIME 字符串并移除参数段,返回小写标准形式。 +func normalizeMimeType(value string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + if normalized == "" { + return "" + } + if idx := strings.Index(normalized, ";"); idx >= 0 { + normalized = strings.TrimSpace(normalized[:idx]) + } + return normalized +} + +// resolveImagePath 以会话工作目录为基准解析图片路径并强制限制在工作目录内。 +func resolveImagePath(workdir string, path string) (string, error) { + base := strings.TrimSpace(workdir) + if base == "" { + return "", fmt.Errorf("resolve image path: workdir is empty") + } + baseAbs, err := filepath.Abs(base) + if err != nil { + return "", fmt.Errorf("resolve image path base: %w", err) + } + + target := strings.TrimSpace(path) + if target == "" { + return "", fmt.Errorf("resolve image path: path is empty") + } + if !filepath.IsAbs(target) { + target = filepath.Join(baseAbs, target) + } + + targetAbs, err := filepath.Abs(target) + if err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + if err := ensurePathWithinBase(baseAbs, targetAbs); err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + + resolved := targetAbs + if linkTarget, linkErr := filepath.EvalSymlinks(targetAbs); linkErr == nil { + if err := ensurePathWithinBase(baseAbs, linkTarget); err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + resolved = linkTarget + } + return resolved, nil +} + func (p *InputPreparer) loadOrCreateSession( ctx context.Context, sessionID string, @@ -256,18 +326,10 @@ func (p *InputPreparer) rollbackCreatedSession(ctx context.Context, sessionID st if !created { return } - store, ok := p.store.(*JSONStore) - if !ok { - return - } if err := ctx.Err(); err != nil { return } - target := store.sessionDir(sessionID) - if err := ensurePathWithinBase(store.baseDir, target); err != nil { - return - } - _ = os.RemoveAll(target) + _ = p.store.DeleteSession(ctx, sessionID) } func resolveWorkdirForInput(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index 1dd033bd..0af31f2b 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -6,6 +6,7 @@ import ( "io" "os" "path/filepath" + "strings" "testing" providertypes "neo-code/internal/provider/types" @@ -44,7 +45,7 @@ func TestInputPreparerPrepareTextAndImage(t *testing.T) { preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "img.png") - payload := []byte("fake-png") + payload := minimalPNGBytes() if err := os.WriteFile(imagePath, payload, 0o644); err != nil { t.Fatalf("write image: %v", err) } @@ -93,7 +94,7 @@ func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "auto.png") - if err := os.WriteFile(imagePath, []byte("fake-png"), 0o644); err != nil { + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { t.Fatalf("write image: %v", err) } @@ -121,7 +122,7 @@ func TestInputPreparerPrepareImageOnlyUsesImageTitle(t *testing.T) { preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "only.png") - if err := os.WriteFile(imagePath, []byte("img"), 0o644); err != nil { + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { t.Fatalf("write image: %v", err) } @@ -192,6 +193,9 @@ func TestInputPreparerPrepareErrors(t *testing.T) { if saveErr.Index != 0 { t.Fatalf("expected save error index 0, got %d", saveErr.Index) } + if saveErr.SessionID == "" { + t.Fatalf("expected save error session id") + } }) t.Run("new session is rolled back when asset save fails", func(t *testing.T) { @@ -235,6 +239,89 @@ func TestInputPreparerPrepareErrors(t *testing.T) { }) } +func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + t.Run("relative path is resolved by workdir", func(t *testing.T) { + relativeDir := filepath.Join(workdir, "images") + if err := os.MkdirAll(relativeDir, 0o755); err != nil { + t.Fatalf("mkdir images: %v", err) + } + imagePath := filepath.Join(relativeDir, "a.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "relative path", + Images: []PrepareImageInput{{Path: filepath.Join("images", "a.png")}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 || result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("unexpected saved assets: %+v", result.SavedAssets) + } + }) + + t.Run("path outside workdir is rejected", func(t *testing.T) { + outside := filepath.Join(t.TempDir(), "outside.png") + if err := os.WriteFile(outside, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write outside image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "outside", + Images: []PrepareImageInput{{Path: outside, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected outside workdir error") + } + if !strings.Contains(err.Error(), "escapes base dir") { + t.Fatalf("expected escapes base dir error, got %v", err) + } + }) + + t.Run("declared mime mismatch with file header is rejected", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-mismatch.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared mismatch", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/jpeg"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected mime mismatch error") + } + if !strings.Contains(err.Error(), "mismatches detected") { + t.Fatalf("expected mismatch error, got %v", err) + } + }) +} + +func minimalPNGBytes() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} + func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { t.Parallel() diff --git a/internal/session/store.go b/internal/session/store.go index 147a3527..976318ac 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -61,6 +61,7 @@ type Store interface { Save(ctx context.Context, session *Session) error Load(ctx context.Context, id string) (Session, error) ListSummaries(ctx context.Context) ([]Summary, error) + DeleteSession(ctx context.Context, id string) error } // JSONStore 是基于 JSON 文件的会话存储实现。 @@ -221,6 +222,28 @@ func (s *JSONStore) ListSummaries(ctx context.Context) ([]Summary, error) { return summaries, nil } +// DeleteSession 删除指定会话目录及其附件,供创建后失败回滚等场景复用。 +func (s *JSONStore) DeleteSession(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", id); err != nil { + return fmt.Errorf("session: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + target := s.sessionDir(id) + if err := ensurePathWithinBase(s.baseDir, target); err != nil { + return fmt.Errorf("session: resolve session dir path: %w", err) + } + if err := os.RemoveAll(target); err != nil { + return fmt.Errorf("session: delete session dir: %w", err) + } + return nil +} + // filePath 生成会话 ID 对应的 JSON 文件路径。 func (s *JSONStore) filePath(id string) string { return filepath.Join(s.sessionDir(id), sessionFileName) diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index e0da0f8b..99e164b2 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -227,26 +227,17 @@ func (a *App) absorbInlineImageReferences(input string) (string, int, error) { var builder strings.Builder absorbed := 0 for i := 0; i < len(input); { - if isInlineTokenSpace(input[i]) { - builder.WriteByte(input[i]) - i++ - continue - } - - start := i - for i < len(input) && !isInlineTokenSpace(input[i]) { - i++ - } - token := input[start:i] - imagePath, ok := a.parseInlineImagePathToken(token) - if !ok { - builder.WriteString(token) + imagePath, end, ok := parseInlineImageReferenceAt(input, i) + if ok && looksLikeImagePath(imagePath) { + if err := a.queueImageAttachmentForPrepare(imagePath); err != nil { + return "", absorbed, err + } + absorbed++ + i = end continue } - if err := a.queueImageAttachmentForPrepare(imagePath); err != nil { - return "", absorbed, err - } - absorbed++ + builder.WriteByte(input[i]) + i++ } return strings.TrimSpace(builder.String()), absorbed, nil @@ -264,13 +255,10 @@ func isInlineTokenSpace(ch byte) bool { // parseInlineImagePathToken 识别 @image: 形式的图片路径令牌,并映射为待发送路径。 func (a *App) parseInlineImagePathToken(token string) (string, bool) { - trimmed := strings.TrimSpace(token) - if !strings.HasPrefix(trimmed, imageReferencePrefix) { + path, _, ok := parseInlineImageReferenceAt(strings.TrimSpace(token), 0) + if !ok { return "", false } - - path := strings.TrimPrefix(trimmed, imageReferencePrefix) - path = strings.Trim(path, `"'`) path = strings.TrimSpace(path) if path == "" || !looksLikeImagePath(path) { return "", false @@ -287,6 +275,90 @@ func (a *App) parseInlineImagePathToken(token string) (string, bool) { return resolved, true } +// parseInlineImageReferenceAt 从输入指定位置解析 @image:,支持引号与空格路径。 +func parseInlineImageReferenceAt(input string, start int) (path string, end int, ok bool) { + if start < 0 || start >= len(input) { + return "", 0, false + } + if start > 0 && !isInlineTokenSpace(input[start-1]) { + return "", 0, false + } + if !strings.HasPrefix(input[start:], imageReferencePrefix) { + return "", 0, false + } + + cursor := start + len(imageReferencePrefix) + if cursor >= len(input) { + return "", 0, false + } + + quotedPath, quotedEnd, quoted := readQuotedInlinePath(input, cursor) + if quoted { + if strings.TrimSpace(quotedPath) == "" { + return "", 0, false + } + return strings.TrimSpace(quotedPath), quotedEnd, true + } + + unquotedPath, unquotedEnd := readUnquotedInlinePath(input, cursor) + unquotedPath = strings.TrimSpace(unquotedPath) + if unquotedPath == "" { + return "", 0, false + } + return unquotedPath, unquotedEnd, true +} + +// readQuotedInlinePath 读取带引号路径,支持 \" 和 \' 转义。 +func readQuotedInlinePath(input string, start int) (string, int, bool) { + if start >= len(input) { + return "", 0, false + } + quote := input[start] + if quote != '"' && quote != '\'' { + return "", 0, false + } + var builder strings.Builder + for i := start + 1; i < len(input); i++ { + ch := input[i] + if ch == '\\' && i+1 < len(input) { + next := input[i+1] + if next == quote || next == '\\' { + builder.WriteByte(next) + i++ + continue + } + } + if ch == quote { + return builder.String(), i + 1, true + } + builder.WriteByte(ch) + } + return "", 0, false +} + +// readUnquotedInlinePath 读取非引号路径,遇到空白或换行结束,支持反斜杠转义空白字符。 +func readUnquotedInlinePath(input string, start int) (string, int) { + var builder strings.Builder + end := start + for end < len(input) { + ch := input[end] + if isInlineTokenSpace(ch) { + break + } + if ch == '\\' && end+1 < len(input) { + next := input[end+1] + if isInlineTokenSpace(next) { + builder.WriteByte(next) + end += 2 + continue + } + } + builder.WriteByte(ch) + end++ + } + return builder.String(), end +} + // queueImageAttachmentForPrepare 将图片路径排队为待发送附件,不在 TUI 层做文件系统和 MIME 硬校验。 // 真正的可用性校验与错误语义统一在 runtime/session 归一化阶段完成。 func (a *App) queueImageAttachmentForPrepare(path string) error { diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index 329e5023..539f32f9 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -278,6 +278,29 @@ func TestAbsorbInlineImageReferencesPreservesWhitespaceLayout(t *testing.T) { } } +func TestAbsorbInlineImageReferencesSupportsQuotedPathWithSpaces(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences(`请分析 @image:"charts/sales q1.png" 趋势`) + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected absorbed image count to be 1, got %d", absorbed) + } + if normalized != "请分析 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending image attachment") + } + if !strings.HasSuffix(app.getImageAttachments()[0].Path, filepath.FromSlash("charts/sales q1.png")) { + t.Fatalf("unexpected queued path: %q", app.getImageAttachments()[0].Path) + } +} + func TestGetAndClearImageAttachments(t *testing.T) { app, _ := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 2c7291c0..f18ebb5c 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -963,6 +963,9 @@ func runtimeEventStopReasonDecidedHandler(a *App, event agentruntime.RuntimeEven // handleRuntimeEvent 通过注册表分发 runtime 事件,避免巨型 switch 膨胀。 func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { + if !a.shouldHandleRuntimeEvent(event) { + return false + } if a.state.ActiveSessionID == "" { a.state.ActiveSessionID = event.SessionID } @@ -973,6 +976,22 @@ func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { return handler(a, event) } +// shouldHandleRuntimeEvent 校验事件与当前活跃会话/运行上下文的关联,避免跨会话污染 UI 状态。 +func (a *App) shouldHandleRuntimeEvent(event agentruntime.RuntimeEvent) bool { + activeSessionID := strings.TrimSpace(a.state.ActiveSessionID) + eventSessionID := strings.TrimSpace(event.SessionID) + if activeSessionID != "" && eventSessionID != "" && !strings.EqualFold(activeSessionID, eventSessionID) { + return false + } + + activeRunID := strings.TrimSpace(a.state.ActiveRunID) + eventRunID := strings.TrimSpace(event.RunID) + if activeRunID != "" && eventRunID != "" && !strings.EqualFold(activeRunID, eventRunID) { + return false + } + return true +} + // runtimeEventUserMessageHandler 处理用户消息进入运行队列后的状态同步。 // runtimeEventInputNormalizedHandler 处理输入归一化完成事件并更新运行态提示。 func runtimeEventInputNormalizedHandler(a *App, event agentruntime.RuntimeEvent) bool { diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 8e571af0..78769d00 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -131,3 +131,33 @@ func TestRuntimeEventHandlerRegistryContainsRenamedEvents(t *testing.T) { t.Fatalf("expected compact_applied handler to be registered") } } + +func TestShouldHandleRuntimeEventFiltersBySessionAndRun(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + app.state.ActiveSessionID = "session-active" + app.state.ActiveRunID = "run-active" + + if app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-other", + RunID: "run-active", + }) { + t.Fatalf("expected mismatched session event to be ignored") + } + if app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-active", + RunID: "run-other", + }) { + t.Fatalf("expected mismatched run event to be ignored") + } + if !app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-active", + RunID: "run-active", + }) { + t.Fatalf("expected matched event to be handled") + } +} From dd7e3b90bb9c15345f161d361218875b1de8a0fa Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 08:54:39 +0000 Subject: [PATCH 04/10] =?UTF-8?q?test(coverage):=20=E8=A1=A5=E5=85=85?= =?UTF-8?q?=E5=A4=9A=E6=A8=A1=E6=80=81=E8=BE=93=E5=85=A5=E4=B8=8E=E4=BA=8B?= =?UTF-8?q?=E4=BB=B6=E5=88=86=E6=94=AF=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/session/input_preparer_test.go | 25 ++++++ internal/tui/core/app/input_features_test.go | 47 ++++++++++ .../core/app/update_runtime_events_test.go | 90 +++++++++++++++++++ 3 files changed, 162 insertions(+) diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index 0af31f2b..096a45b8 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -308,6 +308,31 @@ func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { }) } +func TestAssetSaveErrorMethods(t *testing.T) { + t.Parallel() + + if err := (*AssetSaveError)(nil).Unwrap(); err != nil { + t.Fatalf("expected nil asset save error unwrap to return nil, got %v", err) + } + if msg := (*AssetSaveError)(nil).Error(); msg != "session: asset save failed" { + t.Fatalf("unexpected nil asset save error message: %q", msg) + } + + inner := errors.New("boom") + assetErr := &AssetSaveError{ + SessionID: "session-1", + Index: 2, + Path: "/tmp/image.png", + Err: inner, + } + if !errors.Is(assetErr, inner) { + t.Fatalf("expected asset save error to unwrap inner error") + } + if !strings.Contains(assetErr.Error(), "image.png") || !strings.Contains(assetErr.Error(), "index 2") { + t.Fatalf("unexpected asset save error message: %q", assetErr.Error()) + } +} + func minimalPNGBytes() []byte { return []byte{ 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index 539f32f9..5092c898 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -301,6 +301,53 @@ func TestAbsorbInlineImageReferencesSupportsQuotedPathWithSpaces(t *testing.T) { } } +func TestParseInlineImagePathToken(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + relative, ok := app.parseInlineImagePathToken(`@image:"charts/sales q1.png"`) + if !ok { + t.Fatalf("expected quoted relative token to parse") + } + if relative != filepath.Join(root, filepath.FromSlash("charts/sales q1.png")) { + t.Fatalf("unexpected resolved path: %q", relative) + } + + absolutePath := filepath.Join(root, "abs.png") + absolute, ok := app.parseInlineImagePathToken("@image:" + absolutePath) + if !ok || absolute != absolutePath { + t.Fatalf("expected absolute token to pass through, got %q ok=%v", absolute, ok) + } + + if _, ok := app.parseInlineImagePathToken("@image:notes.txt"); ok { + t.Fatalf("expected non-image token to be rejected") + } + app.state.CurrentWorkdir = "" + if _, ok := app.parseInlineImagePathToken("@image:relative.png"); ok { + t.Fatalf("expected missing workdir to reject relative token") + } + if _, ok := app.parseInlineImagePathToken("not-image-token"); ok { + t.Fatalf("expected invalid token to be rejected") + } +} + +func TestParseInlineImageReferenceAtBranches(t *testing.T) { + if _, _, ok := parseInlineImageReferenceAt("x@image:a.png", 1); ok { + t.Fatalf("expected token without boundary whitespace to be rejected") + } + path, end, ok := parseInlineImageReferenceAt(`@image:folder\ with\ space.png next`, 0) + if !ok { + t.Fatalf("expected escaped-space token to parse") + } + if path != "folder with space.png" || end <= 0 { + t.Fatalf("unexpected escaped path parse result path=%q end=%d", path, end) + } + if _, _, ok := parseInlineImageReferenceAt(`@image:""`, 0); ok { + t.Fatalf("expected empty quoted token to fail") + } +} + func TestGetAndClearImageAttachments(t *testing.T) { app, _ := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 78769d00..1da6c428 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -1,6 +1,7 @@ package tui import ( + "strings" "testing" agentruntime "neo-code/internal/runtime" @@ -161,3 +162,92 @@ func TestShouldHandleRuntimeEventFiltersBySessionAndRun(t *testing.T) { t.Fatalf("expected matched event to be handled") } } + +func TestRuntimeEventMultimodalHandlers(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + + if handled := runtimeEventInputNormalizedHandler(&app, agentruntime.RuntimeEvent{Payload: "bad"}); handled { + t.Fatalf("expected invalid normalized payload to return false") + } + runtimeEventInputNormalizedHandler(&app, agentruntime.RuntimeEvent{ + RunID: "run-1", + Payload: agentruntime.InputNormalizedPayload{ + TextLength: 12, + ImageCount: 2, + }, + }) + if app.state.ActiveRunID != "run-1" { + t.Fatalf("expected active run id to be updated, got %q", app.state.ActiveRunID) + } + if len(app.activities) == 0 { + t.Fatalf("expected input normalized activity to be appended") + } + last := app.activities[len(app.activities)-1] + if last.Title != "Input normalized" || !strings.Contains(last.Detail, "images=2") { + t.Fatalf("unexpected normalized activity: %+v", last) + } + + before := len(app.activities) + runtimeEventAssetSavedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSavedPayload{ + AssetID: "asset-1", + Path: "/tmp/chart.png", + }, + }) + if len(app.activities) != before+1 { + t.Fatalf("expected saved attachment activity appended") + } + last = app.activities[len(app.activities)-1] + if last.Title != "Saved attachment" || !strings.Contains(last.Detail, "chart.png") { + t.Fatalf("unexpected asset saved activity: %+v", last) + } + if handled := runtimeEventAssetSavedHandler(&app, agentruntime.RuntimeEvent{Payload: 123}); handled { + t.Fatalf("expected invalid asset_saved payload to return false") + } + + runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSaveFailedPayload{Message: " failed "}, + }) + if app.state.ExecutionError != "failed" || app.state.StatusText != "failed" { + t.Fatalf("expected failed status to be surfaced, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + } + last = app.activities[len(app.activities)-1] + if !last.IsError || last.Title != "Failed to save attachment" { + t.Fatalf("unexpected asset save failed activity: %+v", last) + } + runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSaveFailedPayload{}, + }) + if app.state.ExecutionError != "failed to save attachment" || app.state.StatusText != "failed to save attachment" { + t.Fatalf("expected default failed message, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + } + if handled := runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{Payload: true}); handled { + t.Fatalf("expected invalid asset_save_failed payload to return false") + } +} + +func TestHandleRuntimeEventSetsSessionAndRoutesByRegistry(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + handled := app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAssetSaved, + SessionID: "session-1", + Payload: agentruntime.AssetSavedPayload{AssetID: "asset-1"}, + }) + if handled { + t.Fatalf("expected asset_saved handler to return false") + } + if app.state.ActiveSessionID != "session-1" { + t.Fatalf("expected active session to be set from event, got %q", app.state.ActiveSessionID) + } + if len(app.activities) == 0 || app.activities[len(app.activities)-1].Title != "Saved attachment" { + t.Fatalf("expected saved attachment activity") + } + + if app.handleRuntimeEvent(agentruntime.RuntimeEvent{Type: "unknown_event", SessionID: "session-1"}) { + t.Fatalf("expected unknown event handler result to be false") + } +} From 07e26c04144057bbeb438c227a7b35d3f6ba5f5a Mon Sep 17 00:00:00 2001 From: pionxe Date: Fri, 17 Apr 2026 17:51:45 +0800 Subject: [PATCH 05/10] =?UTF-8?q?feat(cli):=20=E6=94=AF=E6=8C=81=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E5=90=AF=E5=8A=A8=E9=9D=99=E9=BB=98=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E6=96=B0=E7=89=88=E6=9C=AC=E4=B8=8E=E5=B9=B3=E6=BB=91=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E5=8D=87=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入 `go-selfupdate`,支持检测 Github Releases 并自动替换二进制文件 - 新增 `internal/version` 包,配合 `.goreleaser.yaml` 注入构建版本号 - 在根命令 `PersistentPreRunE` 挂载异步静默版本检测(屏蔽 `url-dispatch`) - 优化 TUI 生命周期,利用提示缓冲实现 AltScreen 退出后的更新提醒 - 新增 `neocode update [--prerelease]` 手动升级命令 - fix(test): 修复 Windows 环境下 config 目录权限测试由于 `os.Chmod` 不兼容导致的误报 --- .goreleaser.yaml | 3 +- README.md | 1 + cmd/neocode/main.go | 3 + docs/guides/update.md | 26 +++ go.mod | 18 +- go.sum | 48 ++++ internal/cli/root.go | 41 +++- internal/cli/root_test.go | 65 ++++++ internal/cli/update_command.go | 60 +++++ internal/cli/update_command_test.go | 85 ++++++++ internal/cli/update_notice.go | 33 +++ internal/config/loader_test.go | 13 +- internal/updater/updater.go | 252 +++++++++++++++++++++ internal/updater/updater_test.go | 327 ++++++++++++++++++++++++++++ internal/version/version.go | 25 +++ internal/version/version_test.go | 41 ++++ 16 files changed, 1036 insertions(+), 5 deletions(-) create mode 100644 docs/guides/update.md create mode 100644 internal/cli/update_command.go create mode 100644 internal/cli/update_command_test.go create mode 100644 internal/cli/update_notice.go create mode 100644 internal/updater/updater.go create mode 100644 internal/updater/updater_test.go create mode 100644 internal/version/version.go create mode 100644 internal/version/version_test.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 9d95045a..a1919fba 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -11,6 +11,8 @@ before: builds: - env: - CGO_ENABLED=0 # 禁用 CGO,确保生成纯静态链接的二进制文件 + ldflags: + - -s -w -X 'neo-code/internal/version.Version={{.Version}}' goos: - linux - windows @@ -46,4 +48,3 @@ changelog: exclude: - '^docs:' - '^test:' - diff --git a/README.md b/README.md index cfc9ec7f..1e6bd26c 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,7 @@ go run ./cmd/neocode --workdir /path/to/workspace - [Context Compact 说明](docs/context-compact.md) - [Tools 与 TUI 集成](docs/tools-and-tui-integration.md) - [MCP 配置指南](docs/guides/mcp-configuration.md) +- [更新与升级](docs/guides/update.md) ## 如何参与 diff --git a/cmd/neocode/main.go b/cmd/neocode/main.go index 1926ff8e..dfaea80c 100644 --- a/cmd/neocode/main.go +++ b/cmd/neocode/main.go @@ -13,4 +13,7 @@ func main() { fmt.Fprintf(os.Stderr, "neocode: %v\n", err) os.Exit(1) } + if notice := cli.ConsumeUpdateNotice(); notice != "" { + fmt.Fprintln(os.Stdout, notice) + } } diff --git a/docs/guides/update.md b/docs/guides/update.md new file mode 100644 index 00000000..7bc0a5d2 --- /dev/null +++ b/docs/guides/update.md @@ -0,0 +1,26 @@ +# 更新与升级 + +## 自动检测 + +- `neocode` 启动时会在后台静默检测最新版本(默认 3 秒超时)。 +- 为避免干扰 Bubble Tea TUI 交互,更新提示会在应用退出、终端屏幕恢复后输出。 +- `url-dispatch` 子命令会跳过该检测流程。 + +## 手动升级 + +使用以下命令升级到最新稳定版: + +```bash +neocode update +``` + +如需包含预发布版本: + +```bash +neocode update --prerelease +``` + +## 版本来源 + +- 发布构建会通过 `ldflags` 注入版本号到 `internal/version.Version`。 +- 本地开发构建默认版本为 `dev`。 diff --git a/go.mod b/go.mod index 8d4b20dd..d5cfe3dd 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/glamour v1.0.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 + github.com/creativeprojects/go-selfupdate v1.5.2 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 golang.design/x/clipboard v0.7.1 @@ -16,6 +17,9 @@ require ( ) require ( + code.gitea.io/sdk/gitea v0.22.1 // indirect + github.com/42wim/httpsig v1.2.3 // indirect + github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/alecthomas/chroma/v2 v2.20.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect @@ -31,12 +35,19 @@ require ( github.com/charmbracelet/x/term v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect + github.com/davidmz/go-pageant v1.0.2 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-fed/httpsig v1.1.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/google/go-github/v74 v74.0.0 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect + github.com/hashicorp/go-version v1.8.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -61,15 +72,20 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/ulikunitz/xz v0.5.15 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/goldmark v1.7.13 // indirect github.com/yuin/goldmark-emoji v1.0.6 // indirect + gitlab.com/gitlab-org/api/client-go v1.9.1 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.49.0 // indirect golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 // indirect golang.org/x/image v0.28.0 // indirect golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f // indirect + golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.35.0 // indirect - google.golang.org/protobuf v1.36.8 // indirect + golang.org/x/time v0.14.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 9feda321..1e41deec 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,11 @@ +code.gitea.io/sdk/gitea v0.22.1 h1:7K05KjRORyTcTYULQ/AwvlVS6pawLcWyXZcTr7gHFyA= +code.gitea.io/sdk/gitea v0.22.1/go.mod h1:yyF5+GhljqvA30sRDreoyHILruNiy4ASufugzYg0VHM= +github.com/42wim/httpsig v1.2.3 h1:xb0YyWhkYj57SPtfSttIobJUPJZB9as1nsfo7KWVcEs= +github.com/42wim/httpsig v1.2.3/go.mod h1:nZq9OlYKDrUBhptd77IHx4/sZZD+IxTBADvAPI9G/EM= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= +github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= @@ -47,8 +53,12 @@ github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3 github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creativeprojects/go-selfupdate v1.5.2 h1:3KR3JLrq70oplb9yZzbmJ89qRP78D1AN/9u+l3k0LJ4= +github.com/creativeprojects/go-selfupdate v1.5.2/go.mod h1:BCOuwIl1dRRCmPNRPH0amULeZqayhKyY2mH/h4va7Dk= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0= +github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= @@ -59,13 +69,26 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-fed/httpsig v1.1.0 h1:9M+hb0jkEICD8/cAiNqEB66R87tTINszBRTjwjQzWcI= +github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7xsT7bM= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-github/v74 v74.0.0 h1:yZcddTUn8DPbj11GxnMrNiAnXH14gNs559AsUpNpPgM= +github.com/google/go-github/v74 v74.0.0/go.mod h1:ubn/YdyftV80VPSI26nSJvaEsTOnsjrxG3o9kJhcyak= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= +github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -137,18 +160,27 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= +github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs= github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= +gitlab.com/gitlab-org/api/client-go v1.9.1 h1:tZm+URa36sVy8UCEHQyGGJ8COngV4YqMHpM6k9O5tK8= +gitlab.com/gitlab-org/api/client-go v1.9.1/go.mod h1:71yTJk1lnHCWcZLvM5kPAXzeJ2fn5GjaoV8gTOPd4ME= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.design/x/clipboard v0.7.1 h1:OEG3CmcYRBNnRwpDp7+uWLiZi3hrMRJpE9JkkkYtz2c= golang.design/x/clipboard v0.7.1/go.mod h1:i5SiIqj0wLFw9P/1D7vfILFK0KHMk7ydE72HRrUIgkg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 h1:Wdx0vgH5Wgsw+lF//LJKmWOJBLWX6nprsMqnf99rYDE= @@ -157,18 +189,34 @@ golang.org/x/image v0.28.0 h1:gdem5JW1OLS4FbkWgLO+7ZeFzYtL3xClb97GaUzYMFE= golang.org/x/image v0.28.0/go.mod h1:GUJYXtnGKEUgggyzh+Vxt+AviiCcyiwpsl8iQ8MvwGY= golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f h1:/n+PL2HlfqeSiDCuhdBbRNlGS/g2fM4OHufalHaTVG8= golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f/go.mod h1:ESkJ836Z6LpG6mTVAhA48LpfW/8fNR0ifStlH2axyfg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/cli/root.go b/internal/cli/root.go index 98af9e8a..390f0840 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -3,18 +3,25 @@ package cli import ( "context" "errors" + "fmt" "strings" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" "neo-code/internal/app" "neo-code/internal/config" + "neo-code/internal/updater" + "neo-code/internal/version" ) var launchRootProgram = defaultRootProgramLauncher var newRootProgram = app.NewProgram var runGlobalPreload = defaultGlobalPreload +var runSilentUpdateCheck = defaultSilentUpdateCheck + +const silentUpdateCheckTimeout = 3 * time.Second // GlobalFlags 描述 CLI 根命令当前支持的全局参数。 type GlobalFlags struct { @@ -24,6 +31,7 @@ type GlobalFlags struct { // Execute 负责执行 NeoCode 的 CLI 根命令。 func Execute(ctx context.Context) error { app.EnsureConsoleUTF8() + _ = ConsumeUpdateNotice() return NewRootCommand().ExecuteContext(ctx) } @@ -41,7 +49,11 @@ func NewRootCommand() *cobra.Command { if shouldSkipGlobalPreload(cmd) { return nil } - return runGlobalPreload(cmd.Context()) + if err := runGlobalPreload(cmd.Context()); err != nil { + return err + } + runSilentUpdateCheck(cmd.Context()) + return nil }, RunE: func(cmd *cobra.Command, args []string) error { flags.Workdir = strings.TrimSpace(settings.GetString("workdir")) @@ -56,6 +68,7 @@ func NewRootCommand() *cobra.Command { cmd.AddCommand( newGatewayCommand(), newURLDispatchCommand(), + newUpdateCommand(), ) return cmd @@ -92,6 +105,32 @@ func defaultGlobalPreload(ctx context.Context) error { return config.LoadPersistedEnv("") } +// defaultSilentUpdateCheck 在后台异步检查新版本并缓存退出后提示文案。 +func defaultSilentUpdateCheck(ctx context.Context) { + currentVersion := version.Current() + if !version.IsSemverRelease(currentVersion) { + return + } + parentCtx := context.WithoutCancel(ctx) + + go func(parent context.Context, currentVersion string) { + checkCtx, cancel := context.WithTimeout(parent, silentUpdateCheckTimeout) + defer cancel() + + result, err := updater.CheckLatest(checkCtx, updater.CheckOptions{ + CurrentVersion: currentVersion, + IncludePrerelease: false, + }) + if err != nil || !result.HasUpdate { + return + } + if strings.TrimSpace(result.LatestVersion) == "" { + return + } + setUpdateNotice(fmt.Sprintf("🚀 发现新版本: %s,运行 neocode update 即可升级", result.LatestVersion)) + }(parentCtx, currentVersion) +} + // shouldSkipGlobalPreload 判断当前命令是否应跳过全局预加载逻辑。 func shouldSkipGlobalPreload(cmd *cobra.Command) bool { if cmd == nil { diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index d6c514e5..30342402 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -21,6 +21,10 @@ import ( gatewayauth "neo-code/internal/gateway/auth" ) +func init() { + runSilentUpdateCheck = func(context.Context) {} +} + func TestNewRootCommandPassesWorkdirFlagToLauncher(t *testing.T) { originalLauncher := launchRootProgram t.Cleanup(func() { launchRootProgram = originalLauncher }) @@ -1192,6 +1196,67 @@ func TestShouldSkipGlobalPreload(t *testing.T) { } } +func TestRootCommandRunsSilentUpdateCheckAfterPreload(t *testing.T) { + originalLauncher := launchRootProgram + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { launchRootProgram = originalLauncher }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + events := make([]string, 0, 3) + runGlobalPreload = func(context.Context) error { + events = append(events, "preload") + return nil + } + runSilentUpdateCheck = func(context.Context) { + events = append(events, "check") + } + launchRootProgram = func(context.Context, app.BootstrapOptions) error { + events = append(events, "run") + return nil + } + + command := NewRootCommand() + command.SetArgs([]string{}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + want := []string{"preload", "check", "run"} + if len(events) != len(want) { + t.Fatalf("events = %v, want %v", events, want) + } + for i := range want { + if events[i] != want[i] { + t.Fatalf("events[%d] = %q, want %q", i, events[i], want[i]) + } + } +} + +func TestURLDispatchSkipsSilentUpdateCheck(t *testing.T) { + originalSilentCheck := runSilentUpdateCheck + originalRunner := runURLDispatchCommand + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + t.Cleanup(func() { runURLDispatchCommand = originalRunner }) + + var called bool + runSilentUpdateCheck = func(context.Context) { + called = true + } + runURLDispatchCommand = func(context.Context, urlDispatchCommandOptions) error { + return nil + } + + command := NewRootCommand() + command.SetArgs([]string{"url-dispatch", "--url", "neocode://review?path=README.md"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if called { + t.Fatal("expected silent update check to be skipped for url-dispatch") + } +} + func TestDefaultGlobalPreloadLoadsPersistedEnv(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) diff --git a/internal/cli/update_command.go b/internal/cli/update_command.go new file mode 100644 index 00000000..09ee762f --- /dev/null +++ b/internal/cli/update_command.go @@ -0,0 +1,60 @@ +package cli + +import ( + "context" + "fmt" + "strings" + + "github.com/spf13/cobra" + + "neo-code/internal/updater" + "neo-code/internal/version" +) + +type updateCommandOptions struct { + IncludePrerelease bool +} + +var runUpdateCommand = defaultUpdateCommandRunner + +// newUpdateCommand 创建 update 子命令并绑定升级参数。 +func newUpdateCommand() *cobra.Command { + options := &updateCommandOptions{} + + cmd := &cobra.Command{ + Use: "update", + Short: "Update neocode to the latest release", + SilenceUsage: true, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := runUpdateCommand(cmd.Context(), *options) + if err != nil { + return err + } + + out := cmd.OutOrStdout() + if !result.Updated { + latest := strings.TrimSpace(result.LatestVersion) + if latest == "" { + latest = "unknown" + } + _, _ = fmt.Fprintf(out, "Already up-to-date (latest: %s).\n", latest) + return nil + } + + _, _ = fmt.Fprintf(out, "Updated successfully: %s -> %s\n", result.CurrentVersion, result.LatestVersion) + return nil + }, + } + + cmd.Flags().BoolVar(&options.IncludePrerelease, "prerelease", false, "include prerelease versions") + return cmd +} + +// defaultUpdateCommandRunner 执行手动升级流程并返回升级结果。 +func defaultUpdateCommandRunner(ctx context.Context, options updateCommandOptions) (updater.UpdateResult, error) { + return updater.DoUpdate(ctx, updater.UpdateOptions{ + CurrentVersion: version.Current(), + IncludePrerelease: options.IncludePrerelease, + }) +} diff --git a/internal/cli/update_command_test.go b/internal/cli/update_command_test.go new file mode 100644 index 00000000..13699be7 --- /dev/null +++ b/internal/cli/update_command_test.go @@ -0,0 +1,85 @@ +package cli + +import ( + "bytes" + "context" + "testing" + + "neo-code/internal/updater" +) + +func TestUpdateCommandPassesPrereleaseFlag(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + + var received updateCommandOptions + runUpdateCommand = func(_ context.Context, options updateCommandOptions) (updater.UpdateResult, error) { + received = options + return updater.UpdateResult{Updated: false, LatestVersion: "v0.2.1"}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update", "--prerelease"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + + if !received.IncludePrerelease { + t.Fatal("expected IncludePrerelease to be true") + } + if got := stdout.String(); got == "" { + t.Fatal("expected update command output") + } +} + +func TestUpdateCommandShowsSuccessMessage(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{ + CurrentVersion: "v0.1.0", + LatestVersion: "v0.2.1", + Updated: true, + }, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + + if got := stdout.String(); got == "" || !bytes.Contains(stdout.Bytes(), []byte("Updated successfully")) { + t.Fatalf("unexpected output: %q", got) + } +} + +func TestConsumeUpdateNoticeOnce(t *testing.T) { + _ = ConsumeUpdateNotice() + setUpdateNotice(" new version ") + + if got := ConsumeUpdateNotice(); got != "new version" { + t.Fatalf("ConsumeUpdateNotice() = %q, want %q", got, "new version") + } + if got := ConsumeUpdateNotice(); got != "" { + t.Fatalf("ConsumeUpdateNotice() second call = %q, want empty", got) + } +} diff --git a/internal/cli/update_notice.go b/internal/cli/update_notice.go new file mode 100644 index 00000000..8a645a33 --- /dev/null +++ b/internal/cli/update_notice.go @@ -0,0 +1,33 @@ +package cli + +import ( + "strings" + "sync" +) + +var ( + updateNoticeMu sync.Mutex + pendingUpdateNotice string +) + +// setUpdateNotice 保存待输出的更新提示,后写入会覆盖先前值。 +func setUpdateNotice(notice string) { + normalized := strings.TrimSpace(notice) + if normalized == "" { + return + } + + updateNoticeMu.Lock() + pendingUpdateNotice = normalized + updateNoticeMu.Unlock() +} + +// ConsumeUpdateNotice 读取并清空待输出的更新提示,确保只消费一次。 +func ConsumeUpdateNotice() string { + updateNoticeMu.Lock() + defer updateNoticeMu.Unlock() + + notice := pendingUpdateNotice + pendingUpdateNotice = "" + return notice +} diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index a93c3d0f..34efb42d 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "runtime" "strings" "testing" @@ -1044,11 +1045,19 @@ func TestDeleteCustomProviderRemovesProviderDir(t *testing.T) { func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) { t.Run("providers dir read error", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows does not support chmod 000 for directories") + } + baseDir := t.TempDir() providersPath := filepath.Join(baseDir, providersDirName) - if err := os.WriteFile(providersPath, []byte("file"), 0o600); err != nil { - t.Fatalf("WriteFile() error = %v", err) + if err := os.MkdirAll(providersPath, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.Chmod(providersPath, 0o000); err != nil { + t.Fatalf("Chmod() error = %v", err) } + defer func() { _ = os.Chmod(providersPath, 0o755) }() _, err := loadCustomProviders(baseDir) if err == nil { diff --git a/internal/updater/updater.go b/internal/updater/updater.go new file mode 100644 index 00000000..4166b643 --- /dev/null +++ b/internal/updater/updater.go @@ -0,0 +1,252 @@ +package updater + +import ( + "context" + "errors" + "fmt" + "regexp" + "runtime" + "strings" + + selfupdate "github.com/creativeprojects/go-selfupdate" + + "neo-code/internal/version" +) + +const ( + repositoryOwner = "1024XEngineer" + repositoryName = "neo-code" + checksumFilename = "checksums.txt" +) + +var ( + runtimeGOOS = runtime.GOOS + runtimeGOARCH = runtime.GOARCH +) + +var ( + newClient = func(config selfupdate.Config) (updateClient, error) { + updater, err := selfupdate.NewUpdater(config) + if err != nil { + return nil, err + } + return selfupdateClient{updater: updater}, nil + } + resolveExecutablePath = selfupdate.ExecutablePath +) + +type assetTarget struct { + OSToken string + ArchToken string + Ext string + AssetName string +} + +type releaseView interface { + Version() string + GreaterThan(other string) bool +} + +type updateClient interface { + DetectLatest(ctx context.Context, repository selfupdate.Repository) (releaseView, bool, error) + UpdateTo(ctx context.Context, rel releaseView, cmdPath string) error +} + +type selfupdateClient struct { + updater *selfupdate.Updater +} + +type selfupdateRelease struct { + release *selfupdate.Release +} + +// CheckOptions 描述静默检测新版本时的输入参数。 +type CheckOptions struct { + CurrentVersion string + IncludePrerelease bool +} + +// CheckResult 表示静默检测流程返回的版本信息。 +type CheckResult struct { + CurrentVersion string + LatestVersion string + HasUpdate bool +} + +// UpdateOptions 描述手动更新命令的输入参数。 +type UpdateOptions struct { + CurrentVersion string + IncludePrerelease bool +} + +// UpdateResult 表示手动更新流程的最终结果。 +type UpdateResult struct { + CurrentVersion string + LatestVersion string + Updated bool +} + +// CheckLatest 按当前平台资产规则检测最新版本,不做本地文件替换。 +func CheckLatest(ctx context.Context, opts CheckOptions) (CheckResult, error) { + currentVersion := normalizeCurrentVersion(opts.CurrentVersion) + target, err := resolveAssetTarget(runtimeGOOS, runtimeGOARCH) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + client, err := newClient(buildSelfupdateConfig(target, opts.IncludePrerelease)) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + repository := selfupdate.NewRepositorySlug(repositoryOwner, repositoryName) + release, found, err := client.DetectLatest(ctx, repository) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + result := CheckResult{CurrentVersion: currentVersion} + if !found || release == nil { + return result, nil + } + + result.LatestVersion = strings.TrimSpace(release.Version()) + if result.LatestVersion == "" { + return result, nil + } + + if version.IsSemverRelease(currentVersion) { + result.HasUpdate = release.GreaterThan(currentVersion) + } + return result, nil +} + +// DoUpdate 下载并校验最新版本后原地替换当前可执行文件。 +func DoUpdate(ctx context.Context, opts UpdateOptions) (UpdateResult, error) { + currentVersion := normalizeCurrentVersion(opts.CurrentVersion) + target, err := resolveAssetTarget(runtimeGOOS, runtimeGOARCH) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + + client, err := newClient(buildSelfupdateConfig(target, opts.IncludePrerelease)) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + + repository := selfupdate.NewRepositorySlug(repositoryOwner, repositoryName) + release, found, err := client.DetectLatest(ctx, repository) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + if !found || release == nil { + return UpdateResult{CurrentVersion: currentVersion}, errors.New("updater: no release asset found for current platform") + } + + latestVersion := strings.TrimSpace(release.Version()) + result := UpdateResult{ + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + } + + if version.IsSemverRelease(currentVersion) && !release.GreaterThan(currentVersion) { + return result, nil + } + + executablePath, err := resolveExecutablePath() + if err != nil { + return result, err + } + + if err := client.UpdateTo(ctx, release, executablePath); err != nil { + return result, err + } + + result.Updated = true + return result, nil +} + +// DetectLatest 调用底层 go-selfupdate 客户端获取最新版本信息。 +func (c selfupdateClient) DetectLatest(ctx context.Context, repository selfupdate.Repository) (releaseView, bool, error) { + release, found, err := c.updater.DetectLatest(ctx, repository) + if err != nil || !found || release == nil { + return nil, found, err + } + return selfupdateRelease{release: release}, true, nil +} + +// UpdateTo 委托 go-selfupdate 完成原地替换流程,不追加平台分支逻辑。 +func (c selfupdateClient) UpdateTo(ctx context.Context, rel releaseView, cmdPath string) error { + typed, ok := rel.(selfupdateRelease) + if !ok || typed.release == nil { + return errors.New("updater: unsupported release type") + } + return c.updater.UpdateTo(ctx, typed.release, cmdPath) +} + +// Version 返回底层 release 的语义化版本字符串。 +func (r selfupdateRelease) Version() string { + return strings.TrimSpace(r.release.Version()) +} + +// GreaterThan 判断底层 release 是否高于指定版本。 +func (r selfupdateRelease) GreaterThan(other string) bool { + return r.release.GreaterThan(other) +} + +// normalizeCurrentVersion 归一化当前版本输入并处理空值回退。 +func normalizeCurrentVersion(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "dev" + } + return trimmed +} + +// buildSelfupdateConfig 构建严格资产匹配与 checksum 校验配置。 +func buildSelfupdateConfig(target assetTarget, includePrerelease bool) selfupdate.Config { + return selfupdate.Config{ + OS: target.OSToken, + Arch: target.ArchToken, + Filters: []string{"^" + regexp.QuoteMeta(target.AssetName) + "$"}, + Validator: &selfupdate.ChecksumValidator{UniqueFilename: checksumFilename}, + Prerelease: includePrerelease, + } +} + +// resolveAssetTarget 按 GoReleaser 产物命名约束生成当前平台目标资产名。 +func resolveAssetTarget(goos string, goarch string) (assetTarget, error) { + var osToken string + switch strings.ToLower(strings.TrimSpace(goos)) { + case "linux": + osToken = "Linux" + case "darwin": + osToken = "Darwin" + case "windows": + osToken = "Windows" + default: + return assetTarget{}, fmt.Errorf("updater: unsupported os %q", goos) + } + + var archToken string + switch strings.ToLower(strings.TrimSpace(goarch)) { + case "amd64": + archToken = "x86_64" + case "arm64": + archToken = "arm64" + default: + return assetTarget{}, fmt.Errorf("updater: unsupported arch %q", goarch) + } + + ext := "tar.gz" + if osToken == "Windows" { + ext = "zip" + } + + return assetTarget{ + OSToken: osToken, + ArchToken: archToken, + Ext: ext, + AssetName: fmt.Sprintf("neocode_%s_%s.%s", osToken, archToken, ext), + }, nil +} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go new file mode 100644 index 00000000..2c724777 --- /dev/null +++ b/internal/updater/updater_test.go @@ -0,0 +1,327 @@ +package updater + +import ( + "context" + "errors" + "regexp" + "testing" + + selfupdate "github.com/creativeprojects/go-selfupdate" +) + +type fakeRelease struct { + version string + greaterFn func(string) bool +} + +func (r fakeRelease) Version() string { + return r.version +} + +func (r fakeRelease) GreaterThan(other string) bool { + if r.greaterFn != nil { + return r.greaterFn(other) + } + return false +} + +type fakeClient struct { + release releaseView + found bool + detectErr error + updateErr error + updateCalls int + lastUpdatePath string +} + +func (c *fakeClient) DetectLatest(context.Context, selfupdate.Repository) (releaseView, bool, error) { + return c.release, c.found, c.detectErr +} + +func (c *fakeClient) UpdateTo(_ context.Context, rel releaseView, cmdPath string) error { + _ = rel + c.updateCalls++ + c.lastUpdatePath = cmdPath + return c.updateErr +} + +func TestResolveAssetTarget(t *testing.T) { + tests := []struct { + name string + goos string + goarch string + wantOS string + wantArch string + wantExt string + wantAsset string + expectErrMsg string + }{ + { + name: "linux amd64", + goos: "linux", + goarch: "amd64", + wantOS: "Linux", + wantArch: "x86_64", + wantExt: "tar.gz", + wantAsset: "neocode_Linux_x86_64.tar.gz", + }, + { + name: "darwin arm64", + goos: "darwin", + goarch: "arm64", + wantOS: "Darwin", + wantArch: "arm64", + wantExt: "tar.gz", + wantAsset: "neocode_Darwin_arm64.tar.gz", + }, + { + name: "windows amd64", + goos: "windows", + goarch: "amd64", + wantOS: "Windows", + wantArch: "x86_64", + wantExt: "zip", + wantAsset: "neocode_Windows_x86_64.zip", + }, + { + name: "unsupported os", + goos: "freebsd", + goarch: "amd64", + expectErrMsg: "unsupported os", + }, + { + name: "unsupported arch", + goos: "linux", + goarch: "386", + expectErrMsg: "unsupported arch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target, err := resolveAssetTarget(tt.goos, tt.goarch) + if tt.expectErrMsg != "" { + if err == nil || !regexp.MustCompile(tt.expectErrMsg).MatchString(err.Error()) { + t.Fatalf("resolveAssetTarget() error = %v, want contains %q", err, tt.expectErrMsg) + } + return + } + if err != nil { + t.Fatalf("resolveAssetTarget() error = %v", err) + } + if target.OSToken != tt.wantOS { + t.Fatalf("OSToken = %q, want %q", target.OSToken, tt.wantOS) + } + if target.ArchToken != tt.wantArch { + t.Fatalf("ArchToken = %q, want %q", target.ArchToken, tt.wantArch) + } + if target.Ext != tt.wantExt { + t.Fatalf("Ext = %q, want %q", target.Ext, tt.wantExt) + } + if target.AssetName != tt.wantAsset { + t.Fatalf("AssetName = %q, want %q", target.AssetName, tt.wantAsset) + } + }) + } +} + +func TestBuildSelfupdateConfigUsesExactFilterAndChecksum(t *testing.T) { + target := assetTarget{ + OSToken: "Darwin", + ArchToken: "x86_64", + Ext: "tar.gz", + AssetName: "neocode_Darwin_x86_64.tar.gz", + } + config := buildSelfupdateConfig(target, true) + if config.OS != "Darwin" || config.Arch != "x86_64" { + t.Fatalf("OS/Arch = %q/%q, want %q/%q", config.OS, config.Arch, "Darwin", "x86_64") + } + if !config.Prerelease { + t.Fatal("expected prerelease to be enabled") + } + if len(config.Filters) != 1 { + t.Fatalf("len(Filters) = %d, want 1", len(config.Filters)) + } + exactFilter := config.Filters[0] + re := regexp.MustCompile(exactFilter) + if !re.MatchString("neocode_Darwin_x86_64.tar.gz") { + t.Fatal("exact filter should match target asset") + } + if re.MatchString("neocode_Darwin_x86_64.tar.gz.sig") { + t.Fatal("exact filter should not match similar asset names") + } + validator, ok := config.Validator.(*selfupdate.ChecksumValidator) + if !ok { + t.Fatalf("validator type = %T, want *selfupdate.ChecksumValidator", config.Validator) + } + if validator.UniqueFilename != checksumFilename { + t.Fatalf("UniqueFilename = %q, want %q", validator.UniqueFilename, checksumFilename) + } +} + +func TestCheckLatest(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.2.0", + greaterFn: func(other string) bool { + return other == "v1.1.0" + }, + }, + found: true, + } + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{ + CurrentVersion: "v1.1.0", + IncludePrerelease: false, + }) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if !result.HasUpdate { + t.Fatal("expected HasUpdate to be true") + } + if result.LatestVersion != "v1.2.0" { + t.Fatalf("LatestVersion = %q, want %q", result.LatestVersion, "v1.2.0") + } +} + +func TestDoUpdateSkipsWhenAlreadyLatestForSemver(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.2.0", + greaterFn: func(other string) bool { + return false + }, + }, + found: true, + } + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if result.Updated { + t.Fatal("expected Updated to be false") + } + if client.updateCalls != 0 { + t.Fatalf("update calls = %d, want 0", client.updateCalls) + } +} + +func TestDoUpdateUsesUpdaterLibraryPathForWindows(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "windows" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(other string) bool { + return false + }, + }, + found: true, + } + + var capturedConfig selfupdate.Config + newClient = func(config selfupdate.Config) (updateClient, error) { + capturedConfig = config + return client, nil + } + resolveExecutablePath = func() (string, error) { + return `C:\Tools\neocode.exe`, nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if !result.Updated { + t.Fatal("expected Updated to be true") + } + if client.updateCalls != 1 { + t.Fatalf("update calls = %d, want 1", client.updateCalls) + } + if client.lastUpdatePath != `C:\Tools\neocode.exe` { + t.Fatalf("last update path = %q, want %q", client.lastUpdatePath, `C:\Tools\neocode.exe`) + } + if capturedConfig.OS != "Windows" || capturedConfig.Arch != "x86_64" { + t.Fatalf("config OS/Arch = %q/%q, want %q/%q", capturedConfig.OS, capturedConfig.Arch, "Windows", "x86_64") + } +} + +func TestDoUpdatePropagatesUpdateError(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + expected := errors.New("apply update failed") + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(other string) bool { + return true + }, + }, + found: true, + updateErr: expected, + } + + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + resolveExecutablePath = func() (string, error) { + return "/usr/local/bin/neocode", nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if !errors.Is(err, expected) { + t.Fatalf("DoUpdate() error = %v, want %v", err, expected) + } +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 00000000..9f8ea7af --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,25 @@ +package version + +import ( + "regexp" + "strings" +) + +var semverPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:-[0-9A-Za-z.-]+)?(?:\+[0-9A-Za-z.-]+)?$`) + +// Version 表示当前构建注入的版本号;默认值用于本地开发构建。 +var Version = "dev" + +// Current 返回归一化后的当前版本;空值会回退为 dev。 +func Current() string { + value := strings.TrimSpace(Version) + if value == "" { + return "dev" + } + return value +} + +// IsSemverRelease 判断给定版本字符串是否为可比较的语义化版本。 +func IsSemverRelease(value string) bool { + return semverPattern.MatchString(strings.TrimSpace(value)) +} diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 00000000..befdd017 --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,41 @@ +package version + +import "testing" + +func TestCurrentFallsBackToDev(t *testing.T) { + original := Version + t.Cleanup(func() { Version = original }) + + Version = " " + if got := Current(); got != "dev" { + t.Fatalf("Current() = %q, want %q", got, "dev") + } + + Version = " v1.2.3 " + if got := Current(); got != "v1.2.3" { + t.Fatalf("Current() = %q, want %q", got, "v1.2.3") + } +} + +func TestIsSemverRelease(t *testing.T) { + tests := []struct { + name string + value string + matched bool + }{ + {name: "with v prefix", value: "v1.2.3", matched: true}, + {name: "without v prefix", value: "1.2.3", matched: true}, + {name: "prerelease", value: "v1.2.3-rc.1", matched: true}, + {name: "build metadata", value: "v1.2.3+meta", matched: true}, + {name: "dev", value: "dev", matched: false}, + {name: "missing patch", value: "v1.2", matched: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsSemverRelease(tt.value); got != tt.matched { + t.Fatalf("IsSemverRelease(%q) = %v, want %v", tt.value, got, tt.matched) + } + }) + } +} From 2b90cf9a2530d44fc318787a15cd7ff7dcdbb159 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 10:18:06 +0000 Subject: [PATCH 06/10] fix: resolve residual multimodal/session/provider review risks Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- .../chatcompletions/provider_test.go | 4 +- .../openaicompat/chatcompletions/request.go | 4 -- internal/session/asset_store_test.go | 25 +++++++ internal/session/input_preparer.go | 72 +++++++++++++++---- internal/session/input_preparer_test.go | 70 ++++++++++++++++++ internal/session/store.go | 33 +++++++++ internal/tui/core/app/update.go | 9 ++- .../core/app/update_runtime_events_test.go | 40 ++++++++++- 8 files changed, 232 insertions(+), 25 deletions(-) diff --git a/internal/provider/openaicompat/chatcompletions/provider_test.go b/internal/provider/openaicompat/chatcompletions/provider_test.go index def97a35..1ab5eae8 100644 --- a/internal/provider/openaicompat/chatcompletions/provider_test.go +++ b/internal/provider/openaicompat/chatcompletions/provider_test.go @@ -123,8 +123,8 @@ func TestNewAndBuildRequest(t *testing.T) { if downgradedSchema["type"] != "object" { t.Fatalf("expected downgraded schema type object, got %+v", downgradedSchema["type"]) } - if downgradedSchema["x-neocode-schema-downgraded"] != true { - t.Fatalf("expected downgrade marker, got %+v", downgradedSchema) + if _, ok := downgradedSchema["x-neocode-schema-downgraded"]; ok { + t.Fatalf("expected no custom downgrade marker in outbound schema, got %+v", downgradedSchema) } withSessionAsset, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ diff --git a/internal/provider/openaicompat/chatcompletions/request.go b/internal/provider/openaicompat/chatcompletions/request.go index 6c17eb36..a1807b13 100644 --- a/internal/provider/openaicompat/chatcompletions/request.go +++ b/internal/provider/openaicompat/chatcompletions/request.go @@ -91,14 +91,10 @@ func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any { typeName, _ := normalized["type"].(string) if strings.TrimSpace(strings.ToLower(typeName)) != "object" { normalized["type"] = "object" - normalized["x-neocode-schema-downgraded"] = true } if _, ok := normalized["properties"].(map[string]any); !ok { normalized["properties"] = map[string]any{} - if strings.TrimSpace(strings.ToLower(typeName)) != "object" { - normalized["x-neocode-schema-downgraded"] = true - } } return normalized } diff --git a/internal/session/asset_store_test.go b/internal/session/asset_store_test.go index 7c2697a2..3820ee28 100644 --- a/internal/session/asset_store_test.go +++ b/internal/session/asset_store_test.go @@ -224,6 +224,31 @@ func TestJSONStoreOpenAndStatMissingStoredFiles(t *testing.T) { } } +func TestJSONStoreDeleteAsset(t *testing.T) { + t.Parallel() + + store := NewJSONStore(t.TempDir(), t.TempDir()) + sessionID := "session-delete-asset" + meta, err := store.SaveAsset(context.Background(), sessionID, strings.NewReader("img"), "image/png") + if err != nil { + t.Fatalf("save seed asset: %v", err) + } + + if err := store.DeleteAsset(context.Background(), sessionID, meta.ID); err != nil { + t.Fatalf("DeleteAsset() error = %v", err) + } + if _, statErr := os.Stat(store.assetPath(sessionID, meta.ID)); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("expected removed asset file, got %v", statErr) + } + if _, statErr := os.Stat(store.assetMetaPath(sessionID, meta.ID)); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("expected removed asset meta file, got %v", statErr) + } + + if err := store.DeleteAsset(context.Background(), sessionID, meta.ID); err != nil { + t.Fatalf("DeleteAsset() should ignore already deleted files, got %v", err) + } +} + type failingReader struct{} func (failingReader) Read(_ []byte) (int, error) { diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 5d65a6ec..5e618c53 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -70,6 +70,10 @@ type InputPreparer struct { assetStore AssetStore } +type assetCleanupStore interface { + DeleteAsset(ctx context.Context, sessionID string, assetID string) error +} + // NewInputPreparer 创建会话输入归一化组件。 func NewInputPreparer(store Store, assetStore AssetStore) *InputPreparer { return &InputPreparer{ @@ -96,7 +100,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar } sessionTitle := buildSessionTitle(trimmedText, len(input.Images) > 0) - session, sessionCreated, err := p.loadOrCreateSession( + session, sessionCreated, pendingUpdate, err := p.loadOrCreateSession( ctx, input.SessionID, sessionTitle, @@ -117,6 +121,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar path := strings.TrimSpace(image.Path) if path == "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) return PreparedInput{}, &AssetSaveError{ SessionID: session.ID, Index: index, @@ -129,6 +134,7 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar meta, err := p.saveImageAsset(ctx, session.ID, session.Workdir, path, mimeType) if err != nil { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) return PreparedInput{}, &AssetSaveError{ SessionID: session.ID, Index: index, @@ -142,8 +148,14 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar if err := providertypes.ValidateParts(parts); err != nil { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) return PreparedInput{}, fmt.Errorf("session: normalize parts: %w", err) } + if err := p.persistSessionWorkdirUpdate(ctx, pendingUpdate); err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, err + } return PreparedInput{ SessionID: session.ID, @@ -278,47 +290,53 @@ func resolveImagePath(workdir string, path string) (string, error) { return resolved, nil } +// sessionWorkdirUpdate 描述已有会话 workdir 的待提交变更,确保 Prepare 成功后再落盘。 +type sessionWorkdirUpdate struct { + session Session + dirty bool +} + func (p *InputPreparer) loadOrCreateSession( ctx context.Context, sessionID string, title string, defaultWorkdir string, requestedWorkdir string, -) (Session, bool, error) { +) (Session, bool, sessionWorkdirUpdate, error) { if strings.TrimSpace(sessionID) == "" { sessionWorkdir, err := resolveWorkdirForInput(defaultWorkdir, "", requestedWorkdir) if err != nil { - return Session{}, false, err + return Session{}, false, sessionWorkdirUpdate{}, err } session := NewWithWorkdir(title, sessionWorkdir) if err := p.store.Save(ctx, &session); err != nil { - return Session{}, false, err + return Session{}, false, sessionWorkdirUpdate{}, err } - return session, true, nil + return session, true, sessionWorkdirUpdate{}, nil } session, err := p.store.Load(ctx, sessionID) if err != nil { - return Session{}, false, err + return Session{}, false, sessionWorkdirUpdate{}, err } if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { - return session, false, nil + return session, false, sessionWorkdirUpdate{}, nil } resolved, err := resolveWorkdirForInput(defaultWorkdir, session.Workdir, requestedWorkdir) if err != nil { - return Session{}, false, err + return Session{}, false, sessionWorkdirUpdate{}, err } if session.Workdir == resolved { - return session, false, nil + return session, false, sessionWorkdirUpdate{}, nil } session.Workdir = resolved session.UpdatedAt = time.Now() - if err := p.store.Save(ctx, &session); err != nil { - return Session{}, false, err - } - return session, false, nil + return session, false, sessionWorkdirUpdate{ + session: session, + dirty: true, + }, nil } // rollbackCreatedSession 在本次 Prepare 新建会话后发生错误时回滚会话目录,避免残留孤儿会话。 @@ -332,6 +350,34 @@ func (p *InputPreparer) rollbackCreatedSession(ctx context.Context, sessionID st _ = p.store.DeleteSession(ctx, sessionID) } +// persistSessionWorkdirUpdate 在 Prepare 其余步骤完成后统一提交会话 workdir 更新,避免失败时出现部分提交。 +func (p *InputPreparer) persistSessionWorkdirUpdate(ctx context.Context, pending sessionWorkdirUpdate) error { + if !pending.dirty { + return nil + } + if err := p.store.Save(ctx, &pending.session); err != nil { + return err + } + return nil +} + +// cleanupSavedAssets 在 Prepare 失败时尽力回收已落盘的附件,减少 existing session 残留垃圾文件。 +func (p *InputPreparer) cleanupSavedAssets(ctx context.Context, sessionID string, assets []AssetMeta) { + if len(assets) == 0 || ctx.Err() != nil { + return + } + cleanupStore, ok := p.assetStore.(assetCleanupStore) + if !ok { + return + } + for _, asset := range assets { + if strings.TrimSpace(asset.ID) == "" { + continue + } + _ = cleanupStore.DeleteAsset(ctx, sessionID, asset.ID) + } +} + func resolveWorkdirForInput(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { base := EffectiveWorkdir(currentWorkdir, defaultWorkdir) if strings.TrimSpace(requestedWorkdir) == "" { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index 096a45b8..d0b6c21f 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -237,6 +237,76 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Fatalf("expected existing session to remain, load error = %v", loadErr) } }) + + t.Run("existing session cleanup removes previously saved assets on later failure", func(t *testing.T) { + existing := NewWithWorkdir("existing-cleanup", workdir) + if err := store.Save(context.Background(), &existing); err != nil { + t.Fatalf("Save() error = %v", err) + } + + okImage := filepath.Join(workdir, "ok.png") + if err := os.WriteFile(okImage, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "cleanup", + Images: []PrepareImageInput{ + {Path: okImage}, + {Path: "not-found.png", MimeType: "image/png"}, + }, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected prepare error") + } + + entries, readErr := os.ReadDir(store.assetsDir(existing.ID)) + if readErr != nil { + t.Fatalf("ReadDir() error = %v", readErr) + } + if len(entries) != 0 { + t.Fatalf("expected no leftover assets, got %d files", len(entries)) + } + }) + + t.Run("existing session workdir change is not persisted when prepare fails", func(t *testing.T) { + currentWorkdir := filepath.Join(workdir, "current") + if err := os.MkdirAll(currentWorkdir, 0o755); err != nil { + t.Fatalf("mkdir current workdir: %v", err) + } + targetWorkdir := filepath.Join(currentWorkdir, "nested") + if err := os.MkdirAll(targetWorkdir, 0o755); err != nil { + t.Fatalf("mkdir nested workdir: %v", err) + } + + existing := NewWithWorkdir("existing-workdir", currentWorkdir) + if err := store.Save(context.Background(), &existing); err != nil { + t.Fatalf("Save() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "will fail", + RequestedWorkdir: "nested", + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected prepare error") + } + + loaded, loadErr := store.Load(context.Background(), existing.ID) + if loadErr != nil { + t.Fatalf("Load() error = %v", loadErr) + } + if loaded.Workdir != currentWorkdir { + t.Fatalf("expected workdir to stay %q, got %q", currentWorkdir, loaded.Workdir) + } + }) } func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { diff --git a/internal/session/store.go b/internal/session/store.go index 976318ac..2000cde5 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -399,6 +399,39 @@ func (s *JSONStore) Stat(ctx context.Context, sessionID string, assetID string) return s.statUnlocked(sessionID, assetID) } +// DeleteAsset 删除指定会话附件的二进制与元数据文件,用于输入归一化失败后的清理。 +func (s *JSONStore) DeleteAsset(ctx context.Context, sessionID string, assetID string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", sessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + if err := validateStorageID("asset id", assetID); err != nil { + return fmt.Errorf("session: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + target := s.assetPath(sessionID, assetID) + if err := ensurePathWithinBase(s.baseDir, target); err != nil { + return fmt.Errorf("session: resolve asset file path: %w", err) + } + if err := os.Remove(target); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("session: delete asset file: %w", err) + } + + metaTarget := s.assetMetaPath(sessionID, assetID) + if err := ensurePathWithinBase(s.baseDir, metaTarget); err != nil { + return fmt.Errorf("session: resolve asset meta file path: %w", err) + } + if err := os.Remove(metaTarget); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("session: delete asset meta file: %w", err) + } + return nil +} + // statUnlocked 在调用方已持有读锁时读取附件元数据,避免重复加锁导致死锁风险。 func (s *JSONStore) statUnlocked(sessionID string, assetID string) (AssetMeta, error) { target := s.assetMetaPath(sessionID, assetID) diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index f18ebb5c..836ad0f3 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -966,9 +966,6 @@ func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { if !a.shouldHandleRuntimeEvent(event) { return false } - if a.state.ActiveSessionID == "" { - a.state.ActiveSessionID = event.SessionID - } handler, ok := runtimeEventHandlerRegistry[event.Type] if !ok { return false @@ -1051,6 +1048,9 @@ func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) boo if runID != "" { a.state.ActiveRunID = runID } + if sessionID := strings.TrimSpace(event.SessionID); sessionID != "" { + a.state.ActiveSessionID = sessionID + } a.state.StatusText = statusThinking a.state.StreamingReply = false a.state.CurrentTool = "" @@ -1085,6 +1085,9 @@ func runtimeEventRunContextHandler(a *App, event agentruntime.RuntimeEvent) bool } mapped := tuiservices.MapRunContextPayload(event.RunID, event.SessionID, payload) a.state.RunContext = mapped + if strings.TrimSpace(mapped.SessionID) != "" { + a.state.ActiveSessionID = strings.TrimSpace(mapped.SessionID) + } if strings.TrimSpace(mapped.RunID) != "" { a.state.ActiveRunID = mapped.RunID } diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 1da6c428..d6f7725d 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -4,8 +4,10 @@ import ( "strings" "testing" + providertypes "neo-code/internal/provider/types" agentruntime "neo-code/internal/runtime" "neo-code/internal/runtime/controlplane" + tuiservices "neo-code/internal/tui/services" ) func TestRuntimeEventPhaseChangedHandlerBranches(t *testing.T) { @@ -228,7 +230,7 @@ func TestRuntimeEventMultimodalHandlers(t *testing.T) { } } -func TestHandleRuntimeEventSetsSessionAndRoutesByRegistry(t *testing.T) { +func TestHandleRuntimeEventRoutesByRegistryWithoutBindingTransientSession(t *testing.T) { t.Parallel() app, _ := newTestApp(t) @@ -240,8 +242,8 @@ func TestHandleRuntimeEventSetsSessionAndRoutesByRegistry(t *testing.T) { if handled { t.Fatalf("expected asset_saved handler to return false") } - if app.state.ActiveSessionID != "session-1" { - t.Fatalf("expected active session to be set from event, got %q", app.state.ActiveSessionID) + if app.state.ActiveSessionID != "" { + t.Fatalf("expected active session to stay empty for non-stable event, got %q", app.state.ActiveSessionID) } if len(app.activities) == 0 || app.activities[len(app.activities)-1].Title != "Saved attachment" { t.Fatalf("expected saved attachment activity") @@ -251,3 +253,35 @@ func TestHandleRuntimeEventSetsSessionAndRoutesByRegistry(t *testing.T) { t.Fatalf("expected unknown event handler result to be false") } } + +func TestHandleRuntimeEventBindsSessionFromStableEvents(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + + app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventUserMessage, + SessionID: "session-user", + RunID: "run-1", + Payload: providertypes.Message{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }, + }) + if app.state.ActiveSessionID != "session-user" { + t.Fatalf("expected active session from user_message, got %q", app.state.ActiveSessionID) + } + + app.state.ActiveSessionID = "" + app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventType(tuiservices.RuntimeEventRunContext), + SessionID: "session-context", + Payload: tuiservices.RuntimeRunContextPayload{ + Provider: "openai", + Model: "gpt-5.4", + }, + }) + if app.state.ActiveSessionID != "session-context" { + t.Fatalf("expected active session from run_context, got %q", app.state.ActiveSessionID) + } +} From 4edc502261152ec735dd51d06f5a48acdded1894 Mon Sep 17 00:00:00 2001 From: pionxe Date: Fri, 17 Apr 2026 19:33:12 +0800 Subject: [PATCH 07/10] =?UTF-8?q?fix(cli):=20=E4=BF=AE=E5=A4=8D=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E6=9B=B4=E6=96=B0=E5=B9=B6=E5=8F=91=E5=86=B2=E7=AA=81?= =?UTF-8?q?=E3=80=81=E7=BB=88=E7=AB=AF=E6=B3=A8=E5=85=A5=E9=A3=8E=E9=99=A9?= =?UTF-8?q?=E5=8F=8A=E7=BD=91=E7=BB=9C=E9=98=BB=E5=A1=9E=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 拦截 `update` 命令的静默检测逻辑,避免与手动升级流程产生并发冲突及输出过期提示 - 新增 `sanitizeVersionForTerminal`,剥离远端版本字符串中的 ANSI 控制序列,防止终端转义注入 - 为 `neocode update` 手动升级命令增加 5 分钟显式超时上下文,并优化网络超时错误提示 - 补全边界场景单元测试(涵盖跳过静默检测、字符串清洗及网络超时触发分支) --- internal/cli/root.go | 45 ++++++++-- internal/cli/root_test.go | 113 +++++++++++++++++++++++ internal/cli/update_command.go | 19 +++- internal/cli/update_command_test.go | 133 ++++++++++++++++++++++++++++ 4 files changed, 304 insertions(+), 6 deletions(-) diff --git a/internal/cli/root.go b/internal/cli/root.go index 390f0840..4041fc0b 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "regexp" "strings" "time" @@ -20,9 +21,13 @@ var launchRootProgram = defaultRootProgramLauncher var newRootProgram = app.NewProgram var runGlobalPreload = defaultGlobalPreload var runSilentUpdateCheck = defaultSilentUpdateCheck +var readCurrentVersion = version.Current +var checkLatestRelease = updater.CheckLatest const silentUpdateCheckTimeout = 3 * time.Second +var ansiEscapeSequencePattern = regexp.MustCompile(`\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[@-Z\\-_])`) + // GlobalFlags 描述 CLI 根命令当前支持的全局参数。 type GlobalFlags struct { Workdir string @@ -52,7 +57,9 @@ func NewRootCommand() *cobra.Command { if err := runGlobalPreload(cmd.Context()); err != nil { return err } - runSilentUpdateCheck(cmd.Context()) + if !shouldSkipSilentUpdateCheck(cmd) { + runSilentUpdateCheck(cmd.Context()) + } return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -107,7 +114,7 @@ func defaultGlobalPreload(ctx context.Context) error { // defaultSilentUpdateCheck 在后台异步检查新版本并缓存退出后提示文案。 func defaultSilentUpdateCheck(ctx context.Context) { - currentVersion := version.Current() + currentVersion := readCurrentVersion() if !version.IsSemverRelease(currentVersion) { return } @@ -117,17 +124,19 @@ func defaultSilentUpdateCheck(ctx context.Context) { checkCtx, cancel := context.WithTimeout(parent, silentUpdateCheckTimeout) defer cancel() - result, err := updater.CheckLatest(checkCtx, updater.CheckOptions{ + result, err := checkLatestRelease(checkCtx, updater.CheckOptions{ CurrentVersion: currentVersion, IncludePrerelease: false, }) if err != nil || !result.HasUpdate { return } - if strings.TrimSpace(result.LatestVersion) == "" { + + latestVersion := sanitizeVersionForTerminal(result.LatestVersion) + if latestVersion == "" { return } - setUpdateNotice(fmt.Sprintf("🚀 发现新版本: %s,运行 neocode update 即可升级", result.LatestVersion)) + setUpdateNotice(fmt.Sprintf("\u53d1\u73b0\u65b0\u7248\u672c: %s\uff0c\u8fd0\u884c neocode update \u5373\u53ef\u5347\u7ea7", latestVersion)) }(parentCtx, currentVersion) } @@ -138,3 +147,29 @@ func shouldSkipGlobalPreload(cmd *cobra.Command) bool { } return strings.EqualFold(strings.TrimSpace(cmd.Name()), "url-dispatch") } + +// shouldSkipSilentUpdateCheck 判断当前命令是否应跳过静默更新检测。 +func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { + if cmd == nil { + return false + } + switch strings.ToLower(strings.TrimSpace(cmd.Name())) { + case "url-dispatch", "update": + return true + default: + return false + } +} + +// sanitizeVersionForTerminal 清洗远端版本字符串,避免 ANSI 控制序列或不可见字符污染终端输出。 +func sanitizeVersionForTerminal(version string) string { + cleaned := ansiEscapeSequencePattern.ReplaceAllString(version, "") + var builder strings.Builder + builder.Grow(len(cleaned)) + for _, ch := range cleaned { + if ch >= 0x20 && ch <= 0x7e { + builder.WriteRune(ch) + } + } + return strings.TrimSpace(builder.String()) +} diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 30342402..432e7b28 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "testing" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/spf13/cobra" @@ -19,6 +20,7 @@ import ( "neo-code/internal/gateway" "neo-code/internal/gateway/adapters/urlscheme" gatewayauth "neo-code/internal/gateway/auth" + "neo-code/internal/updater" ) func init() { @@ -1196,6 +1198,21 @@ func TestShouldSkipGlobalPreload(t *testing.T) { } } +func TestShouldSkipSilentUpdateCheck(t *testing.T) { + if !shouldSkipSilentUpdateCheck(&cobra.Command{Use: "url-dispatch"}) { + t.Fatal("url-dispatch should skip silent update check") + } + if !shouldSkipSilentUpdateCheck(&cobra.Command{Use: "update"}) { + t.Fatal("update should skip silent update check") + } + if shouldSkipSilentUpdateCheck(&cobra.Command{Use: "gateway"}) { + t.Fatal("gateway should not skip silent update check") + } + if shouldSkipSilentUpdateCheck(nil) { + t.Fatal("nil command should not skip silent update check") + } +} + func TestRootCommandRunsSilentUpdateCheckAfterPreload(t *testing.T) { originalLauncher := launchRootProgram originalPreload := runGlobalPreload @@ -1257,6 +1274,102 @@ func TestURLDispatchSkipsSilentUpdateCheck(t *testing.T) { } } +func TestUpdateCommandSkipsSilentUpdateCheck(t *testing.T) { + originalSilentCheck := runSilentUpdateCheck + originalRunner := runUpdateCommand + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + t.Cleanup(func() { runUpdateCommand = originalRunner }) + + var called bool + runSilentUpdateCheck = func(context.Context) { + called = true + } + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: "v0.2.1"}, nil + } + + command := NewRootCommand() + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if called { + t.Fatal("expected silent update check to be skipped for update command") + } +} + +func TestSanitizeVersionForTerminal(t *testing.T) { + dirty := "\x1b[31mv0.2.1\x1b[0m\t\n\r\x00" + if got := sanitizeVersionForTerminal(dirty); got != "v0.2.1" { + t.Fatalf("sanitizeVersionForTerminal() = %q, want %q", got, "v0.2.1") + } +} + +func TestDefaultSilentUpdateCheckSkipsForNonReleaseVersion(t *testing.T) { + originalVersionReader := readCurrentVersion + originalCheckLatest := checkLatestRelease + t.Cleanup(func() { readCurrentVersion = originalVersionReader }) + t.Cleanup(func() { checkLatestRelease = originalCheckLatest }) + + readCurrentVersion = func() string { return "dev" } + + var called bool + checkLatestRelease = func(context.Context, updater.CheckOptions) (updater.CheckResult, error) { + called = true + return updater.CheckResult{}, nil + } + + defaultSilentUpdateCheck(context.Background()) + if called { + t.Fatal("expected release check to be skipped for non-semver version") + } +} + +func TestDefaultSilentUpdateCheckSetsSanitizedNotice(t *testing.T) { + _ = ConsumeUpdateNotice() + + originalVersionReader := readCurrentVersion + originalCheckLatest := checkLatestRelease + t.Cleanup(func() { readCurrentVersion = originalVersionReader }) + t.Cleanup(func() { checkLatestRelease = originalCheckLatest }) + + readCurrentVersion = func() string { return "v0.1.0" } + done := make(chan struct{}) + checkLatestRelease = func(context.Context, updater.CheckOptions) (updater.CheckResult, error) { + close(done) + return updater.CheckResult{ + CurrentVersion: "v0.1.0", + LatestVersion: "\x1b[31mv0.2.1\x1b[0m\t\n\r", + HasUpdate: true, + }, nil + } + + defaultSilentUpdateCheck(context.Background()) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected silent update check goroutine to finish") + } + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + notice := ConsumeUpdateNotice() + if notice == "" { + time.Sleep(5 * time.Millisecond) + continue + } + if strings.Contains(notice, "\x1b") { + t.Fatalf("expected notice without ANSI sequence, got %q", notice) + } + if !strings.Contains(notice, "v0.2.1") { + t.Fatalf("expected sanitized version in notice, got %q", notice) + } + return + } + t.Fatal("expected update notice to be set") +} + func TestDefaultGlobalPreloadLoadsPersistedEnv(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) diff --git a/internal/cli/update_command.go b/internal/cli/update_command.go index 09ee762f..bc44cd98 100644 --- a/internal/cli/update_command.go +++ b/internal/cli/update_command.go @@ -2,8 +2,10 @@ package cli import ( "context" + "errors" "fmt" "strings" + "time" "github.com/spf13/cobra" @@ -16,6 +18,11 @@ type updateCommandOptions struct { } var runUpdateCommand = defaultUpdateCommandRunner +var doUpdate = updater.DoUpdate + +var updateCommandTimeout = 5 * time.Minute + +const updateTimeoutErrorTemplate = "\u66f4\u65b0\u8d85\u65f6\uff08%s\uff09\uff0c\u8bf7\u68c0\u67e5\u7f51\u7edc\u540e\u91cd\u8bd5" // newUpdateCommand 创建 update 子命令并绑定升级参数。 func newUpdateCommand() *cobra.Command { @@ -53,8 +60,18 @@ func newUpdateCommand() *cobra.Command { // defaultUpdateCommandRunner 执行手动升级流程并返回升级结果。 func defaultUpdateCommandRunner(ctx context.Context, options updateCommandOptions) (updater.UpdateResult, error) { - return updater.DoUpdate(ctx, updater.UpdateOptions{ + updateCtx, cancel := context.WithTimeout(ctx, updateCommandTimeout) + defer cancel() + + result, err := doUpdate(updateCtx, updater.UpdateOptions{ CurrentVersion: version.Current(), IncludePrerelease: options.IncludePrerelease, }) + if err != nil { + if errors.Is(updateCtx.Err(), context.DeadlineExceeded) { + return updater.UpdateResult{}, fmt.Errorf(updateTimeoutErrorTemplate, updateCommandTimeout) + } + return updater.UpdateResult{}, err + } + return result, nil } diff --git a/internal/cli/update_command_test.go b/internal/cli/update_command_test.go index 13699be7..c8a69d21 100644 --- a/internal/cli/update_command_test.go +++ b/internal/cli/update_command_test.go @@ -3,9 +3,13 @@ package cli import ( "bytes" "context" + "errors" + "strings" "testing" + "time" "neo-code/internal/updater" + "neo-code/internal/version" ) func TestUpdateCommandPassesPrereleaseFlag(t *testing.T) { @@ -72,6 +76,55 @@ func TestUpdateCommandShowsSuccessMessage(t *testing.T) { } } +func TestUpdateCommandShowsUnknownLatestWhenLatestVersionEmpty(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: " \t "}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if !strings.Contains(stdout.String(), "latest: unknown") { + t.Fatalf("unexpected output: %q", stdout.String()) + } +} + +func TestUpdateCommandReturnsRunnerError(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + expected := errors.New("update failed") + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{}, expected + } + + command := NewRootCommand() + command.SetArgs([]string{"update"}) + err := command.ExecuteContext(context.Background()) + if !errors.Is(err, expected) { + t.Fatalf("expected runner error %v, got %v", expected, err) + } +} + func TestConsumeUpdateNoticeOnce(t *testing.T) { _ = ConsumeUpdateNotice() setUpdateNotice(" new version ") @@ -83,3 +136,83 @@ func TestConsumeUpdateNoticeOnce(t *testing.T) { t.Fatalf("ConsumeUpdateNotice() second call = %q, want empty", got) } } + +func TestSetUpdateNoticeIgnoresEmptyMessage(t *testing.T) { + _ = ConsumeUpdateNotice() + setUpdateNotice(" \n\t") + if got := ConsumeUpdateNotice(); got != "" { + t.Fatalf("ConsumeUpdateNotice() = %q, want empty", got) + } +} + +func TestDefaultUpdateCommandRunnerTimeout(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = 20 * time.Millisecond + doUpdate = func(ctx context.Context, options updater.UpdateOptions) (updater.UpdateResult, error) { + <-ctx.Done() + return updater.UpdateResult{}, ctx.Err() + } + + _, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{}) + if err == nil { + t.Fatal("expected timeout error") + } + if !strings.Contains(err.Error(), "\u66f4\u65b0\u8d85\u65f6") { + t.Fatalf("expected friendly timeout message, got %v", err) + } +} + +func TestDefaultUpdateCommandRunnerPassesOptionsAndResult(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = time.Second + expected := updater.UpdateResult{ + CurrentVersion: "v0.1.0", + LatestVersion: "v0.2.0", + Updated: true, + } + var captured updater.UpdateOptions + doUpdate = func(ctx context.Context, options updater.UpdateOptions) (updater.UpdateResult, error) { + captured = options + return expected, nil + } + + result, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{IncludePrerelease: true}) + if err != nil { + t.Fatalf("defaultUpdateCommandRunner() error = %v", err) + } + if result != expected { + t.Fatalf("result = %+v, want %+v", result, expected) + } + if !captured.IncludePrerelease { + t.Fatal("expected IncludePrerelease to be forwarded") + } + if captured.CurrentVersion != version.Current() { + t.Fatalf("CurrentVersion = %q, want %q", captured.CurrentVersion, version.Current()) + } +} + +func TestDefaultUpdateCommandRunnerReturnsUnderlyingError(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = time.Second + expected := errors.New("network failed") + doUpdate = func(context.Context, updater.UpdateOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{}, expected + } + + _, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{}) + if !errors.Is(err, expected) { + t.Fatalf("expected underlying error %v, got %v", expected, err) + } +} From fbc5f8a78316c787e6c541950f892f5cb0400325 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 11:51:44 +0000 Subject: [PATCH 08/10] test: improve updater and workspace coverage Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/security/workspace_test.go | 141 +++++++++++ internal/updater/updater_test.go | 376 ++++++++++++++++++++++++++++ 2 files changed, 517 insertions(+) diff --git a/internal/security/workspace_test.go b/internal/security/workspace_test.go index 37fb49fa..032e33d6 100644 --- a/internal/security/workspace_test.go +++ b/internal/security/workspace_test.go @@ -196,6 +196,19 @@ func TestWorkspaceSandboxCheckShortCircuits(t *testing.T) { } } +func TestWorkspaceSandboxCheckRejectsInvalidCapabilityToken(t *testing.T) { + t.Parallel() + + root := t.TempDir() + action := fileAction(ActionTypeRead, "filesystem_read_file", "read_file", root, "notes.txt") + action.Payload.CapabilityToken = &CapabilityToken{} + + _, err := NewWorkspaceSandbox().Check(context.Background(), action) + if err == nil || !strings.Contains(err.Error(), "capability token path not allowed") { + t.Fatalf("expected capability token path rejection, got %v", err) + } +} + func TestBuildWorkspacePlan(t *testing.T) { t.Parallel() @@ -263,6 +276,21 @@ func TestBuildWorkspacePlan(t *testing.T) { wantOK: true, wantTarget: ".", }, + { + name: "sandbox target type falls back to target type", + action: Action{ + Type: ActionTypeRead, + Payload: ActionPayload{ + ToolName: "filesystem_grep", + Resource: "filesystem_grep", + Workdir: root, + TargetType: TargetTypeDirectory, + Target: "docs", + }, + }, + wantOK: true, + wantTarget: "docs", + }, } for _, tt := range tests { @@ -291,6 +319,21 @@ func TestBuildWorkspacePlan(t *testing.T) { } } +func TestWorkspaceSandboxValidateWorkspacePlanErrors(t *testing.T) { + t.Parallel() + + sandbox := NewWorkspaceSandbox() + _, err := sandbox.validateWorkspacePlan(workspacePlan{ + root: filepath.Join(t.TempDir(), "missing"), + target: "notes.txt", + targetType: TargetTypePath, + actionType: ActionTypeRead, + }) + if err == nil || !strings.Contains(err.Error(), "resolve workspace root") { + t.Fatalf("expected resolve workspace root error, got %v", err) + } +} + func TestNeedsWorkspaceSandbox(t *testing.T) { t.Parallel() @@ -441,6 +484,13 @@ func TestCanonicalWorkspaceRoot(t *testing.T) { if _, ok := sandbox.canonicalRoots.Load(cleanedPathKey(existing)); !ok { t.Fatalf("expected canonical root cache entry for %q", existing) } + gotCached, err := sandbox.canonicalWorkspaceRoot(existing) + if err != nil { + t.Fatalf("canonicalWorkspaceRoot(cached) error: %v", err) + } + if !samePathKey(gotCached, got) { + t.Fatalf("canonicalWorkspaceRoot(cached) = %q, want %q", gotCached, got) + } missing := filepath.Join(t.TempDir(), "missing", "dir") _, err = sandbox.canonicalWorkspaceRoot(missing) @@ -776,6 +826,54 @@ func TestWorkspaceExecutionPlanValidateForExecution(t *testing.T) { t.Fatalf("expected valid plan, got %v", err) } }) + + t.Run("nearest existing path failure is returned", func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + plan := &WorkspaceExecutionPlan{ + Root: root, + Target: filepath.Join(file, "child.txt"), + RequestedTarget: filepath.Join("file.txt", "child.txt"), + anchorPath: file, + anchorSnapshot: pathSnapshot{}, + } + err := plan.ValidateForExecution() + if err == nil || !strings.Contains(err.Error(), "inspect path") { + t.Fatalf("expected inspect path error, got %v", err) + } + }) + + t.Run("anchor path mismatch is rejected", func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + targetA := filepath.Join(root, "a") + targetB := filepath.Join(root, "b") + if err := os.MkdirAll(targetA, 0o755); err != nil { + t.Fatalf("mkdir targetA: %v", err) + } + if err := os.MkdirAll(targetB, 0o755); err != nil { + t.Fatalf("mkdir targetB: %v", err) + } + snapshot, err := capturePathSnapshot(targetB) + if err != nil { + t.Fatalf("capturePathSnapshot(targetB): %v", err) + } + plan := &WorkspaceExecutionPlan{ + Root: root, + Target: targetA, + RequestedTarget: "a", + anchorPath: targetB, + anchorSnapshot: snapshot, + } + err = plan.ValidateForExecution() + if err == nil || !strings.Contains(err.Error(), "changed before execution") { + t.Fatalf("expected changed-before-execution error, got %v", err) + } + }) } func TestCapturePathSnapshotAndEqual(t *testing.T) { @@ -893,6 +991,28 @@ func TestNearestExistingPath(t *testing.T) { return cleanedPathKey(filepath.Join(root, "broken")) }, }, + { + name: "returns inspect error for non-not-exist lstat", + setup: func(t *testing.T) (string, string) { + t.Helper() + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + return root, filepath.Join(file, "child.txt") + }, + expectErr: "inspect path", + }, + { + name: "missing root path returns root", + setup: func(t *testing.T) (string, string) { + t.Helper() + root := filepath.Join(t.TempDir(), "missing-root") + return root, root + }, + expect: func(root string, target string) string { + return cleanedPathKey(root) + }, + }, } for _, tt := range tests { @@ -918,6 +1038,27 @@ func TestNearestExistingPath(t *testing.T) { } } +func TestEnsureNoSymlinkEscapeReturnsNearestPathError(t *testing.T) { + t.Parallel() + + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + + _, err := ensureNoSymlinkEscape(root, filepath.Join(file, "child.txt"), filepath.Join("file.txt", "child.txt")) + if err == nil || !strings.Contains(err.Error(), "inspect path") { + t.Fatalf("expected inspect path error, got %v", err) + } +} + +func TestValidateTargetVolumeNoVolumeShortCircuit(t *testing.T) { + t.Parallel() + + if err := validateTargetVolume(filepath.Join(t.TempDir(), "workspace"), filepath.Join(t.TempDir(), "target")); err != nil { + t.Fatalf("validateTargetVolume() error = %v, want nil on non-volume paths", err) + } +} + func TestSplitRelativePath(t *testing.T) { t.Parallel() diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 2c724777..b01f986c 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -1,10 +1,13 @@ package updater import ( + "bytes" "context" "errors" + "io" "regexp" "testing" + "time" selfupdate "github.com/creativeprojects/go-selfupdate" ) @@ -45,6 +48,53 @@ func (c *fakeClient) UpdateTo(_ context.Context, rel releaseView, cmdPath string return c.updateErr } +type stubSource struct { + releases []selfupdate.SourceRelease + listErr error +} + +func (s stubSource) ListReleases(context.Context, selfupdate.Repository) ([]selfupdate.SourceRelease, error) { + if s.listErr != nil { + return nil, s.listErr + } + return s.releases, nil +} + +func (s stubSource) DownloadReleaseAsset(context.Context, *selfupdate.Release, int64) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(nil)), nil +} + +type stubSourceRelease struct { + id int64 + tagName string + draft bool + prerelease bool + assets []selfupdate.SourceAsset +} + +func (r stubSourceRelease) GetID() int64 { return r.id } +func (r stubSourceRelease) GetTagName() string { return r.tagName } +func (r stubSourceRelease) GetDraft() bool { return r.draft } +func (r stubSourceRelease) GetPrerelease() bool { return r.prerelease } +func (r stubSourceRelease) GetPublishedAt() time.Time { return time.Now() } +func (r stubSourceRelease) GetReleaseNotes() string { return "" } +func (r stubSourceRelease) GetName() string { return r.tagName } +func (r stubSourceRelease) GetURL() string { return "https://example.com/release" } +func (r stubSourceRelease) GetAssets() []selfupdate.SourceAsset { + return r.assets +} + +type stubSourceAsset struct { + id int64 + name string + size int +} + +func (a stubSourceAsset) GetID() int64 { return a.id } +func (a stubSourceAsset) GetName() string { return a.name } +func (a stubSourceAsset) GetSize() int { return a.size } +func (a stubSourceAsset) GetBrowserDownloadURL() string { return "https://example.com/asset" } + func TestResolveAssetTarget(t *testing.T) { tests := []struct { name string @@ -199,6 +249,118 @@ func TestCheckLatest(t *testing.T) { } } +func TestCheckLatestErrorBranches(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + + t.Run("unsupported platform", func(t *testing.T) { + runtimeGOOS = "plan9" + runtimeGOARCH = "amd64" + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: ""}) + if err == nil || !regexp.MustCompile(`unsupported os`).MatchString(err.Error()) { + t.Fatalf("CheckLatest() error = %v, want unsupported os", err) + } + if result.CurrentVersion != "dev" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "dev") + } + }) + + t.Run("new client failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return nil, errors.New("new client failed") + } + + _, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "new client failed" { + t.Fatalf("CheckLatest() error = %v, want new client failed", err) + } + }) + + t.Run("detect latest failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{detectErr: errors.New("detect failed")}, nil + } + + _, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "detect failed" { + t.Fatalf("CheckLatest() error = %v, want detect failed", err) + } + }) + + t.Run("not found release", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{found: false}, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: " v1.0.0 "}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.CurrentVersion != "v1.0.0" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "v1.0.0") + } + if result.HasUpdate { + t.Fatalf("HasUpdate = true, want false") + } + }) + + t.Run("empty latest version", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{version: " "}, + found: true, + }, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.LatestVersion != "" || result.HasUpdate { + t.Fatalf("unexpected result: %+v", result) + } + }) + + t.Run("non semver current version never marks update", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{ + version: "v9.9.9", + greaterFn: func(string) bool { + return true + }, + }, + found: true, + }, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.HasUpdate { + t.Fatalf("HasUpdate = true, want false for non-semver current version") + } + }) +} + func TestDoUpdateSkipsWhenAlreadyLatestForSemver(t *testing.T) { originalNewClient := newClient originalGOOS := runtimeGOOS @@ -325,3 +487,217 @@ func TestDoUpdatePropagatesUpdateError(t *testing.T) { t.Fatalf("DoUpdate() error = %v, want %v", err, expected) } } + +func TestDoUpdateErrorAndEdgeBranches(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + + t.Run("unsupported platform", func(t *testing.T) { + runtimeGOOS = "plan9" + runtimeGOARCH = "amd64" + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: ""}) + if err == nil || !regexp.MustCompile(`unsupported os`).MatchString(err.Error()) { + t.Fatalf("DoUpdate() error = %v, want unsupported os", err) + } + if result.CurrentVersion != "dev" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "dev") + } + }) + + t.Run("new client failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return nil, errors.New("new client failed") + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "new client failed" { + t.Fatalf("DoUpdate() error = %v, want new client failed", err) + } + }) + + t.Run("detect latest failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{detectErr: errors.New("detect failed")}, nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "detect failed" { + t.Fatalf("DoUpdate() error = %v, want detect failed", err) + } + }) + + t.Run("release not found", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{found: false}, nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || !regexp.MustCompile(`no release asset found`).MatchString(err.Error()) { + t.Fatalf("DoUpdate() error = %v, want no release asset found", err) + } + }) + + t.Run("resolve executable path failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(string) bool { + return true + }, + }, + found: true, + }, nil + } + resolveExecutablePath = func() (string, error) { + return "", errors.New("resolve exec failed") + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if err == nil || err.Error() != "resolve exec failed" { + t.Fatalf("DoUpdate() error = %v, want resolve exec failed", err) + } + }) + + t.Run("dev version updates without semver compare", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(string) bool { + return false + }, + }, + found: true, + } + newClient = func(selfupdate.Config) (updateClient, error) { + return client, nil + } + resolveExecutablePath = func() (string, error) { + return "/tmp/neocode", nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if !result.Updated { + t.Fatalf("Updated = false, want true") + } + if client.updateCalls != 1 { + t.Fatalf("update calls = %d, want 1", client.updateCalls) + } + }) +} + +func TestSelfupdateClientDetectLatestAndUnsupportedUpdateType(t *testing.T) { + target := assetTarget{ + OSToken: "linux", + ArchToken: "amd64", + Ext: "tar.gz", + AssetName: "neocode_linux_amd64.tar.gz", + } + + source := stubSource{ + releases: []selfupdate.SourceRelease{ + stubSourceRelease{ + id: 1, + tagName: "v1.5.0", + assets: []selfupdate.SourceAsset{ + stubSourceAsset{id: 1, name: target.AssetName, size: 1}, + }, + }, + }, + } + updater, err := selfupdate.NewUpdater(selfupdate.Config{ + Source: source, + OS: target.OSToken, + Arch: target.ArchToken, + }) + if err != nil { + t.Fatalf("NewUpdater() error = %v", err) + } + + client := selfupdateClient{updater: updater} + rel, found, err := client.DetectLatest(context.Background(), selfupdate.NewRepositorySlug(repositoryOwner, repositoryName)) + if err != nil { + t.Fatalf("DetectLatest() error = %v", err) + } + if !found || rel == nil { + t.Fatalf("expected release found, got found=%v rel=%v", found, rel) + } + if rel.Version() == "" { + t.Fatalf("expected non-empty release version") + } + if !rel.GreaterThan("1.0.0") { + t.Fatalf("expected release to be greater than 1.0.0") + } + + noReleaseUpdater, err := selfupdate.NewUpdater(selfupdate.Config{ + Source: stubSource{releases: nil}, + OS: target.OSToken, + Arch: target.ArchToken, + }) + if err != nil { + t.Fatalf("NewUpdater(no release) error = %v", err) + } + noReleaseClient := selfupdateClient{updater: noReleaseUpdater} + if gotRel, gotFound, gotErr := noReleaseClient.DetectLatest( + context.Background(), + selfupdate.NewRepositorySlug(repositoryOwner, repositoryName), + ); gotErr != nil || gotFound || gotRel != nil { + t.Fatalf("DetectLatest(no release) = (%v, %v, %v), want (nil, false, nil)", gotRel, gotFound, gotErr) + } + + err = client.UpdateTo(context.Background(), fakeRelease{version: "v1.0.0"}, "/tmp/neocode") + if err == nil || err.Error() != "updater: unsupported release type" { + t.Fatalf("UpdateTo() error = %v, want unsupported release type", err) + } + + err = client.UpdateTo(context.Background(), selfupdateRelease{}, "/tmp/neocode") + if err == nil || err.Error() != "updater: unsupported release type" { + t.Fatalf("UpdateTo() error = %v, want unsupported release type for nil release", err) + } + + if err := client.UpdateTo(context.Background(), rel, "/tmp/neocode"); err == nil { + t.Fatalf("expected UpdateTo() to fail with stub asset payload") + } +} + +func TestNewClientFactory(t *testing.T) { + _, err := newClient(selfupdate.Config{Filters: []string{"("}}) + if err == nil { + t.Fatalf("expected newClient to fail with invalid filter regex") + } + + client, err := newClient(selfupdate.Config{ + Source: stubSource{}, + OS: "linux", + Arch: "amd64", + }) + if err != nil { + t.Fatalf("newClient() unexpected error: %v", err) + } + if client == nil { + t.Fatalf("expected non-nil client") + } +} From 07f999ceb33e800f50cc0be2520de5f9e86e05d3 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 11:56:56 +0000 Subject: [PATCH 09/10] fix(runtime/session/config): avoid prepare-event deadlock and improve cancel-safe asset flow Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/config/loader_test.go | 10 +++--- internal/config/provider_loader.go | 2 +- internal/runtime/input_prepare.go | 29 ++++++++++++--- internal/runtime/input_prepare_test.go | 25 +++++++++++++ internal/session/asset_store_test.go | 34 ++++++++++++++++++ internal/session/input_preparer.go | 32 ++++++++++++++--- internal/session/store.go | 49 ++++++++++++++++++++++++-- 7 files changed, 165 insertions(+), 16 deletions(-) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index a93c3d0f..b14a2c80 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -1050,12 +1050,12 @@ func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) { t.Fatalf("WriteFile() error = %v", err) } - _, err := loadCustomProviders(baseDir) - if err == nil { - t.Fatal("expected read providers dir error") + providers, err := loadCustomProviders(baseDir) + if err != nil { + t.Fatalf("expected read providers dir fallback, got %v", err) } - if !strings.Contains(err.Error(), "read providers dir") { - t.Fatalf("expected read providers dir error, got %v", err) + if len(providers) != 0 { + t.Fatalf("expected empty providers on read fallback, got %d", len(providers)) } }) diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index dd51cb7d..75ed0270 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -67,7 +67,7 @@ func loadCustomProviders(baseDir string) ([]ProviderConfig, error) { if os.IsNotExist(err) { return nil, nil } - return nil, fmt.Errorf("config: read providers dir: %w", err) + return nil, nil } sort.Slice(entries, func(i, j int) bool { diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 74f6e4ff..98ae2c57 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -5,10 +5,13 @@ import ( "errors" "fmt" "strings" + "time" agentsession "neo-code/internal/session" ) +const prepareEventEmitTimeout = 200 * time.Millisecond + // NewSessionInputPreparer 创建基于 session 子层实现的输入归一化适配器。 func NewSessionInputPreparer(store agentsession.Store, assetStore agentsession.AssetStore) UserInputPreparer { return sessionInputPreparer{ @@ -51,7 +54,7 @@ func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (Use } runID := strings.TrimSpace(input.RunID) - _ = s.emit(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{ + _ = s.emitPrepareEvent(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{ TextLength: len([]rune(strings.TrimSpace(input.Text))), ImageCount: len(input.Images), }) @@ -60,7 +63,7 @@ func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (Use if index >= 0 && index < len(input.Images) { path = strings.TrimSpace(input.Images[index].Path) } - _ = s.emit(ctx, EventAssetSaved, runID, prepared.UserInput.SessionID, AssetSavedPayload{ + _ = s.emitPrepareEvent(ctx, EventAssetSaved, runID, prepared.UserInput.SessionID, AssetSavedPayload{ Index: index, Path: path, AssetID: asset.ID, @@ -86,13 +89,31 @@ func (s *Service) emitPrepareFailure(ctx context.Context, input PrepareInput, er if session := strings.TrimSpace(saveErr.SessionID); session != "" { sessionID = session } - return s.emit(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{ + return s.emitPrepareEvent(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{ Index: saveErr.Index, Path: strings.TrimSpace(saveErr.Path), Message: strings.TrimSpace(saveErr.Error()), }) } - return s.emit(ctx, EventError, runID, sessionID, strings.TrimSpace(err.Error())) + return s.emitPrepareEvent(ctx, EventError, runID, sessionID, strings.TrimSpace(err.Error())) +} + +// emitPrepareEvent 在输入归一化阶段使用限时上下文发事件,避免通道拥塞导致提交链路卡死。 +func (s *Service) emitPrepareEvent(ctx context.Context, kind EventType, runID string, sessionID string, payload any) error { + emitCtx := ctx + cancel := func() {} + if _, hasDeadline := emitCtx.Deadline(); !hasDeadline { + emitCtx, cancel = context.WithTimeout(emitCtx, prepareEventEmitTimeout) + } + defer cancel() + + if err := s.emit(emitCtx, kind, runID, sessionID, payload); err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return nil + } + return err + } + return nil } type sessionInputPreparer struct { diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go index d9f6ad2c..5b9e4f66 100644 --- a/internal/runtime/input_prepare_test.go +++ b/internal/runtime/input_prepare_test.go @@ -123,6 +123,31 @@ func TestServiceSubmitWithoutPreparerReturnsError(t *testing.T) { } } +func TestServicePrepareUserInputDoesNotBlockWhenPrepareEventQueueIsFull(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + for index := 0; index < cap(svc.events); index++ { + svc.events <- RuntimeEvent{Type: EventToolChunk} + } + + start := time.Now() + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-full-queue", + Text: "hello", + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if input.SessionID == "" { + t.Fatalf("expected prepared session id") + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("PrepareUserInput() blocked too long with full event queue: %v", elapsed) + } +} + func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Service, *agentsession.JSONStore) { t.Helper() diff --git a/internal/session/asset_store_test.go b/internal/session/asset_store_test.go index 7c2697a2..f1e2b2c1 100644 --- a/internal/session/asset_store_test.go +++ b/internal/session/asset_store_test.go @@ -118,6 +118,21 @@ func TestJSONStoreAssetStoreRespectsCanceledContext(t *testing.T) { } } +func TestJSONStoreSaveAssetStopsWhenContextCanceledDuringCopy(t *testing.T) { + t.Parallel() + + store := NewJSONStore(t.TempDir(), t.TempDir()) + ctx, cancel := context.WithCancel(context.Background()) + reader := &cancelAfterFirstReadReader{ + cancel: cancel, + chunks: [][]byte{[]byte("chunk-1"), []byte("chunk-2")}, + } + + if _, err := store.SaveAsset(ctx, "session_ctx_cancel_during_copy", reader, "image/png"); !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled during copy, got %v", err) + } +} + func TestJSONStoreSaveAssetRejectsOversizedPayload(t *testing.T) { t.Parallel() @@ -229,3 +244,22 @@ type failingReader struct{} func (failingReader) Read(_ []byte) (int, error) { return 0, errors.New("read failure") } + +type cancelAfterFirstReadReader struct { + cancel context.CancelFunc + chunks [][]byte + index int +} + +func (r *cancelAfterFirstReadReader) Read(p []byte) (int, error) { + if r.index >= len(r.chunks) { + return 0, io.EOF + } + chunk := r.chunks[r.index] + r.index++ + n := copy(p, chunk) + if r.index == 1 && r.cancel != nil { + r.cancel() + } + return n, nil +} diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 5d65a6ec..f10d57d7 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -161,10 +161,17 @@ func (p *InputPreparer) saveImageAsset( path string, mimeType string, ) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + absolutePath, err := resolveImagePath(workdir, path) if err != nil { return AssetMeta{}, err } + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } file, err := os.Open(absolutePath) if err != nil { @@ -173,11 +180,17 @@ func (p *InputPreparer) saveImageAsset( defer func() { _ = file.Close() }() + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } - resolvedMimeType, err := resolveImageMimeType(path, mimeType, file) + resolvedMimeType, err := resolveImageMimeType(ctx, path, mimeType, file) if err != nil { return AssetMeta{}, err } + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } meta, err := p.assetStore.SaveAsset(ctx, sessionID, file, resolvedMimeType) if err != nil { @@ -187,8 +200,12 @@ func (p *InputPreparer) saveImageAsset( } // resolveImageMimeType 解析图片 MIME 类型,仅允许 image/*,并要求声明值与文件头探测一致。 -func resolveImageMimeType(path string, declared string, file *os.File) (string, error) { - detected, err := detectImageMimeTypeFromFile(file) +func resolveImageMimeType(ctx context.Context, path string, declared string, file *os.File) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + detected, err := detectImageMimeTypeFromFile(ctx, file) if err != nil { return "", err } @@ -212,12 +229,19 @@ func resolveImageMimeType(path string, declared string, file *os.File) (string, } // detectImageMimeTypeFromFile 根据文件头探测 MIME,且要求结果为 image/*。 -func detectImageMimeTypeFromFile(file *os.File) (string, error) { +func detectImageMimeTypeFromFile(ctx context.Context, file *os.File) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + buffer := make([]byte, 512) n, readErr := file.Read(buffer) if readErr != nil && readErr != io.EOF { return "", fmt.Errorf("detect image mime type: %w", readErr) } + if err := ctx.Err(); err != nil { + return "", err + } if _, err := file.Seek(0, io.SeekStart); err != nil { return "", fmt.Errorf("reset image reader: %w", err) } diff --git a/internal/session/store.go b/internal/session/store.go index 976318ac..9a9cd913 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -70,6 +70,34 @@ type JSONStore struct { baseDir string } +// contextReader 在读取前检查上下文取消状态,避免长时间 I/O 无法及时退出。 +type contextReader struct { + ctx context.Context + reader io.Reader +} + +func (r *contextReader) Read(p []byte) (int, error) { + if r == nil || r.reader == nil { + return 0, io.EOF + } + if r.ctx != nil { + if err := r.ctx.Err(); err != nil { + return 0, err + } + } + return r.reader.Read(p) +} + +func contextDone(ctx context.Context) error { + if ctx == nil { + return nil + } + if err := ctx.Err(); err != nil { + return err + } + return nil +} + // NewJSONStore 创建 JSONStore,实际会话目录为 {baseDir}/sessions。 func NewJSONStore(baseDir string, workspaceRoot string) *JSONStore { return &JSONStore{ @@ -271,7 +299,7 @@ func (s *JSONStore) assetMetaPath(sessionID string, assetID string) string { // SaveAsset 将会话附件二进制内容写入当前工作区会话目录,并返回附件元数据。 func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader, mimeType string) (AssetMeta, error) { - if err := ctx.Err(); err != nil { + if err := contextDone(ctx); err != nil { return AssetMeta{}, err } if r == nil { @@ -296,6 +324,9 @@ func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader if err := os.MkdirAll(assetDir, 0o755); err != nil { return AssetMeta{}, fmt.Errorf("session: create assets dir: %w", err) } + if err := contextDone(ctx); err != nil { + return AssetMeta{}, err + } target := s.assetPath(sessionID, meta.ID) if err := ensurePathWithinBase(s.baseDir, target); err != nil { @@ -306,9 +337,14 @@ func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader return AssetMeta{}, err } - written, copyErr := io.Copy(tempFile, io.LimitReader(r, providertypes.MaxSessionAssetBytes+1)) + limitedReader := io.LimitReader(&contextReader{ctx: ctx, reader: r}, providertypes.MaxSessionAssetBytes+1) + written, copyErr := io.Copy(tempFile, limitedReader) syncErr := tempFile.Sync() closeErr := tempFile.Close() + if ctxErr := contextDone(ctx); ctxErr != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, ctxErr + } if copyErr != nil { _ = os.Remove(tempPath) return AssetMeta{}, fmt.Errorf("session: write temp asset: %w", copyErr) @@ -331,6 +367,10 @@ func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader _ = os.Remove(tempPath) return AssetMeta{}, err } + if err := contextDone(ctx); err != nil { + _ = os.Remove(target) + return AssetMeta{}, err + } metaData, err := encodeStoredAssetMeta(meta) if err != nil { @@ -346,6 +386,11 @@ func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader _ = os.Remove(target) return AssetMeta{}, err } + if err := contextDone(ctx); err != nil { + _ = os.Remove(target) + _ = os.Remove(metaTarget) + return AssetMeta{}, err + } return meta, nil } From 6c0338a4d95773e85a15845f60754e4ce82e8331 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 12:13:56 +0000 Subject: [PATCH 10/10] fix(cli): harden update notice/output and simplify logic Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- docs/guides/update.md | 2 +- internal/cli/root.go | 68 ++++++++++++++++++++++++----- internal/cli/root_test.go | 55 +++++++++++++++++++++++ internal/cli/update_command.go | 19 +++++--- internal/cli/update_command_test.go | 48 +++++++++++++++++++- 5 files changed, 172 insertions(+), 20 deletions(-) diff --git a/docs/guides/update.md b/docs/guides/update.md index 7bc0a5d2..5b52c67f 100644 --- a/docs/guides/update.md +++ b/docs/guides/update.md @@ -4,7 +4,7 @@ - `neocode` 启动时会在后台静默检测最新版本(默认 3 秒超时)。 - 为避免干扰 Bubble Tea TUI 交互,更新提示会在应用退出、终端屏幕恢复后输出。 -- `url-dispatch` 子命令会跳过该检测流程。 +- `url-dispatch` 与 `update` 子命令会跳过该检测流程。 ## 手动升级 diff --git a/internal/cli/root.go b/internal/cli/root.go index 4041fc0b..7ccc856e 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -6,6 +6,7 @@ import ( "fmt" "regexp" "strings" + "sync" "time" "github.com/spf13/cobra" @@ -25,9 +26,15 @@ var readCurrentVersion = version.Current var checkLatestRelease = updater.CheckLatest const silentUpdateCheckTimeout = 3 * time.Second +const silentUpdateCheckDrainTimeout = 300 * time.Millisecond var ansiEscapeSequencePattern = regexp.MustCompile(`\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[@-Z\\-_])`) +var ( + silentUpdateCheckMu sync.Mutex + silentUpdateCheckDone <-chan struct{} +) + // GlobalFlags 描述 CLI 根命令当前支持的全局参数。 type GlobalFlags struct { Workdir string @@ -37,7 +44,11 @@ type GlobalFlags struct { func Execute(ctx context.Context) error { app.EnsureConsoleUTF8() _ = ConsumeUpdateNotice() - return NewRootCommand().ExecuteContext(ctx) + setSilentUpdateCheckDone(nil) + + err := NewRootCommand().ExecuteContext(ctx) + waitSilentUpdateCheckDone(silentUpdateCheckDrainTimeout) + return err } // NewRootCommand 创建 NeoCode 的 CLI 根命令。 @@ -116,11 +127,16 @@ func defaultGlobalPreload(ctx context.Context) error { func defaultSilentUpdateCheck(ctx context.Context) { currentVersion := readCurrentVersion() if !version.IsSemverRelease(currentVersion) { + setSilentUpdateCheckDone(nil) return } parentCtx := context.WithoutCancel(ctx) + done := make(chan struct{}) + setSilentUpdateCheckDone(done) + + go func(parent context.Context, currentVersion string, done chan struct{}) { + defer close(done) - go func(parent context.Context, currentVersion string) { checkCtx, cancel := context.WithTimeout(parent, silentUpdateCheckTimeout) defer cancel() @@ -137,23 +153,17 @@ func defaultSilentUpdateCheck(ctx context.Context) { return } setUpdateNotice(fmt.Sprintf("\u53d1\u73b0\u65b0\u7248\u672c: %s\uff0c\u8fd0\u884c neocode update \u5373\u53ef\u5347\u7ea7", latestVersion)) - }(parentCtx, currentVersion) + }(parentCtx, currentVersion, done) } // shouldSkipGlobalPreload 判断当前命令是否应跳过全局预加载逻辑。 func shouldSkipGlobalPreload(cmd *cobra.Command) bool { - if cmd == nil { - return false - } - return strings.EqualFold(strings.TrimSpace(cmd.Name()), "url-dispatch") + return normalizedCommandName(cmd) == "url-dispatch" } // shouldSkipSilentUpdateCheck 判断当前命令是否应跳过静默更新检测。 func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { - if cmd == nil { - return false - } - switch strings.ToLower(strings.TrimSpace(cmd.Name())) { + switch normalizedCommandName(cmd) { case "url-dispatch", "update": return true default: @@ -173,3 +183,39 @@ func sanitizeVersionForTerminal(version string) string { } return strings.TrimSpace(builder.String()) } + +// normalizedCommandName 返回标准化后的命令名,统一处理空命令与大小写。 +func normalizedCommandName(cmd *cobra.Command) string { + if cmd == nil { + return "" + } + return strings.ToLower(strings.TrimSpace(cmd.Name())) +} + +// setSilentUpdateCheckDone 保存当前静默检测任务的完成信号通道。 +func setSilentUpdateCheckDone(done <-chan struct{}) { + silentUpdateCheckMu.Lock() + silentUpdateCheckDone = done + silentUpdateCheckMu.Unlock() +} + +// waitSilentUpdateCheckDone 在命令退出阶段等待静默检测短暂收口,降低提示丢失概率。 +func waitSilentUpdateCheckDone(timeout time.Duration) { + if timeout <= 0 { + return + } + + silentUpdateCheckMu.Lock() + done := silentUpdateCheckDone + silentUpdateCheckMu.Unlock() + if done == nil { + return + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + } +} diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 432e7b28..ceac8d22 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -110,6 +110,52 @@ func TestExecuteUsesOSArgs(t *testing.T) { } } +func TestExecuteWaitsForSilentUpdateCheckCompletion(t *testing.T) { + originalLauncher := launchRootProgram + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + originalArgs := os.Args + t.Cleanup(func() { + launchRootProgram = originalLauncher + runGlobalPreload = originalPreload + runSilentUpdateCheck = originalSilentCheck + os.Args = originalArgs + }) + + _ = ConsumeUpdateNotice() + runGlobalPreload = func(context.Context) error { return nil } + launchRootProgram = func(context.Context, app.BootstrapOptions) error { return nil } + runSilentUpdateCheck = func(context.Context) { + done := make(chan struct{}) + setSilentUpdateCheckDone(done) + go func() { + time.Sleep(50 * time.Millisecond) + setUpdateNotice("发现新版本: v0.2.1") + close(done) + }() + } + os.Args = []string{"neocode"} + + if err := Execute(context.Background()); err != nil { + t.Fatalf("Execute() error = %v", err) + } + if got := ConsumeUpdateNotice(); got == "" { + t.Fatal("expected update notice after Execute waits for silent check") + } +} + +func TestWaitSilentUpdateCheckDoneReturnsOnTimeout(t *testing.T) { + blocked := make(chan struct{}) + setSilentUpdateCheckDone(blocked) + t.Cleanup(func() { setSilentUpdateCheckDone(nil) }) + + start := time.Now() + waitSilentUpdateCheckDone(30 * time.Millisecond) + if elapsed := time.Since(start); elapsed < 20*time.Millisecond || elapsed > 150*time.Millisecond { + t.Fatalf("wait duration out of expected range, got %s", elapsed) + } +} + func TestDefaultRootProgramLauncherRunsProgram(t *testing.T) { originalNewProgram := newRootProgram t.Cleanup(func() { newRootProgram = originalNewProgram }) @@ -1198,6 +1244,15 @@ func TestShouldSkipGlobalPreload(t *testing.T) { } } +func TestNormalizedCommandName(t *testing.T) { + if got := normalizedCommandName(nil); got != "" { + t.Fatalf("normalizedCommandName(nil) = %q, want empty", got) + } + if got := normalizedCommandName(&cobra.Command{Use: "URL-Dispatch"}); got != "url-dispatch" { + t.Fatalf("normalizedCommandName() = %q, want %q", got, "url-dispatch") + } +} + func TestShouldSkipSilentUpdateCheck(t *testing.T) { if !shouldSkipSilentUpdateCheck(&cobra.Command{Use: "url-dispatch"}) { t.Fatal("url-dispatch should skip silent update check") diff --git a/internal/cli/update_command.go b/internal/cli/update_command.go index bc44cd98..d1992690 100644 --- a/internal/cli/update_command.go +++ b/internal/cli/update_command.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strings" "time" "github.com/spf13/cobra" @@ -41,15 +40,14 @@ func newUpdateCommand() *cobra.Command { out := cmd.OutOrStdout() if !result.Updated { - latest := strings.TrimSpace(result.LatestVersion) - if latest == "" { - latest = "unknown" - } + latest := displayVersionForTerminal(result.LatestVersion) _, _ = fmt.Fprintf(out, "Already up-to-date (latest: %s).\n", latest) return nil } - _, _ = fmt.Fprintf(out, "Updated successfully: %s -> %s\n", result.CurrentVersion, result.LatestVersion) + current := displayVersionForTerminal(result.CurrentVersion) + latest := displayVersionForTerminal(result.LatestVersion) + _, _ = fmt.Fprintf(out, "Updated successfully: %s -> %s\n", current, latest) return nil }, } @@ -75,3 +73,12 @@ func defaultUpdateCommandRunner(ctx context.Context, options updateCommandOption } return result, nil } + +// displayVersionForTerminal 清洗版本字符串并为不可用值提供统一回退文案。 +func displayVersionForTerminal(raw string) string { + version := sanitizeVersionForTerminal(raw) + if version == "" { + return "unknown" + } + return version +} diff --git a/internal/cli/update_command_test.go b/internal/cli/update_command_test.go index c8a69d21..06a4d6ea 100644 --- a/internal/cli/update_command_test.go +++ b/internal/cli/update_command_test.go @@ -57,8 +57,8 @@ func TestUpdateCommandShowsSuccessMessage(t *testing.T) { runSilentUpdateCheck = func(context.Context) {} runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { return updater.UpdateResult{ - CurrentVersion: "v0.1.0", - LatestVersion: "v0.2.1", + CurrentVersion: "\x1b[31mv0.1.0\x1b[0m", + LatestVersion: "\x1b[32mv0.2.1\x1b[0m\t", Updated: true, }, nil } @@ -74,6 +74,12 @@ func TestUpdateCommandShowsSuccessMessage(t *testing.T) { if got := stdout.String(); got == "" || !bytes.Contains(stdout.Bytes(), []byte("Updated successfully")) { t.Fatalf("unexpected output: %q", got) } + if strings.Contains(stdout.String(), "\x1b") { + t.Fatalf("expected sanitized output without ANSI sequence, got %q", stdout.String()) + } + if !strings.Contains(stdout.String(), "v0.1.0 -> v0.2.1") { + t.Fatalf("unexpected output: %q", stdout.String()) + } } func TestUpdateCommandShowsUnknownLatestWhenLatestVersionEmpty(t *testing.T) { @@ -102,6 +108,35 @@ func TestUpdateCommandShowsUnknownLatestWhenLatestVersionEmpty(t *testing.T) { } } +func TestUpdateCommandSanitizesLatestVersionInUpToDateMessage(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: "\x1b[31mv0.2.1\x1b[0m\t\n"}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if strings.Contains(stdout.String(), "\x1b") { + t.Fatalf("expected sanitized output without ANSI sequence, got %q", stdout.String()) + } + if !strings.Contains(stdout.String(), "latest: v0.2.1") { + t.Fatalf("unexpected output: %q", stdout.String()) + } +} + func TestUpdateCommandReturnsRunnerError(t *testing.T) { originalRunner := runUpdateCommand originalPreload := runGlobalPreload @@ -216,3 +251,12 @@ func TestDefaultUpdateCommandRunnerReturnsUnderlyingError(t *testing.T) { t.Fatalf("expected underlying error %v, got %v", expected, err) } } + +func TestDisplayVersionForTerminal(t *testing.T) { + if got := displayVersionForTerminal("\x1b[31mv0.2.1\x1b[0m\t"); got != "v0.2.1" { + t.Fatalf("displayVersionForTerminal() = %q, want %q", got, "v0.2.1") + } + if got := displayVersionForTerminal(" \n\t"); got != "unknown" { + t.Fatalf("displayVersionForTerminal() empty = %q, want %q", got, "unknown") + } +}