diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 9d95045a..a1919fba 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -11,6 +11,8 @@ before: builds: - env: - CGO_ENABLED=0 # 禁用 CGO,确保生成纯静态链接的二进制文件 + ldflags: + - -s -w -X 'neo-code/internal/version.Version={{.Version}}' goos: - linux - windows @@ -46,4 +48,3 @@ changelog: exclude: - '^docs:' - '^test:' - diff --git a/README.md b/README.md index cfc9ec7f..1e6bd26c 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,7 @@ go run ./cmd/neocode --workdir /path/to/workspace - [Context Compact 说明](docs/context-compact.md) - [Tools 与 TUI 集成](docs/tools-and-tui-integration.md) - [MCP 配置指南](docs/guides/mcp-configuration.md) +- [更新与升级](docs/guides/update.md) ## 如何参与 diff --git a/cmd/neocode/main.go b/cmd/neocode/main.go index 1926ff8e..dfaea80c 100644 --- a/cmd/neocode/main.go +++ b/cmd/neocode/main.go @@ -13,4 +13,7 @@ func main() { fmt.Fprintf(os.Stderr, "neocode: %v\n", err) os.Exit(1) } + if notice := cli.ConsumeUpdateNotice(); notice != "" { + fmt.Fprintln(os.Stdout, notice) + } } diff --git a/docs/guides/update.md b/docs/guides/update.md new file mode 100644 index 00000000..5b52c67f --- /dev/null +++ b/docs/guides/update.md @@ -0,0 +1,26 @@ +# 更新与升级 + +## 自动检测 + +- `neocode` 启动时会在后台静默检测最新版本(默认 3 秒超时)。 +- 为避免干扰 Bubble Tea TUI 交互,更新提示会在应用退出、终端屏幕恢复后输出。 +- `url-dispatch` 与 `update` 子命令会跳过该检测流程。 + +## 手动升级 + +使用以下命令升级到最新稳定版: + +```bash +neocode update +``` + +如需包含预发布版本: + +```bash +neocode update --prerelease +``` + +## 版本来源 + +- 发布构建会通过 `ldflags` 注入版本号到 `internal/version.Version`。 +- 本地开发构建默认版本为 `dev`。 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index abaca41f..9a9f217d 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -172,6 +172,7 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er contextBuilder, ) runtimeSvc.SetSessionAssetStore(sessionStore) + runtimeSvc.SetUserInputPreparer(agentruntime.NewSessionInputPreparer(sessionStore, sessionStore)) runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir())) runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( func(ctx context.Context, cfg config.Config) (int, error) { diff --git a/internal/cli/root.go b/internal/cli/root.go index 98af9e8a..7ccc856e 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -3,18 +3,37 @@ package cli import ( "context" "errors" + "fmt" + "regexp" "strings" + "sync" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" "neo-code/internal/app" "neo-code/internal/config" + "neo-code/internal/updater" + "neo-code/internal/version" ) var launchRootProgram = defaultRootProgramLauncher var newRootProgram = app.NewProgram var runGlobalPreload = defaultGlobalPreload +var runSilentUpdateCheck = defaultSilentUpdateCheck +var readCurrentVersion = version.Current +var checkLatestRelease = updater.CheckLatest + +const silentUpdateCheckTimeout = 3 * time.Second +const silentUpdateCheckDrainTimeout = 300 * time.Millisecond + +var ansiEscapeSequencePattern = regexp.MustCompile(`\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[@-Z\\-_])`) + +var ( + silentUpdateCheckMu sync.Mutex + silentUpdateCheckDone <-chan struct{} +) // GlobalFlags 描述 CLI 根命令当前支持的全局参数。 type GlobalFlags struct { @@ -24,7 +43,12 @@ type GlobalFlags struct { // Execute 负责执行 NeoCode 的 CLI 根命令。 func Execute(ctx context.Context) error { app.EnsureConsoleUTF8() - return NewRootCommand().ExecuteContext(ctx) + _ = ConsumeUpdateNotice() + setSilentUpdateCheckDone(nil) + + err := NewRootCommand().ExecuteContext(ctx) + waitSilentUpdateCheckDone(silentUpdateCheckDrainTimeout) + return err } // NewRootCommand 创建 NeoCode 的 CLI 根命令。 @@ -41,7 +65,13 @@ func NewRootCommand() *cobra.Command { if shouldSkipGlobalPreload(cmd) { return nil } - return runGlobalPreload(cmd.Context()) + if err := runGlobalPreload(cmd.Context()); err != nil { + return err + } + if !shouldSkipSilentUpdateCheck(cmd) { + runSilentUpdateCheck(cmd.Context()) + } + return nil }, RunE: func(cmd *cobra.Command, args []string) error { flags.Workdir = strings.TrimSpace(settings.GetString("workdir")) @@ -56,6 +86,7 @@ func NewRootCommand() *cobra.Command { cmd.AddCommand( newGatewayCommand(), newURLDispatchCommand(), + newUpdateCommand(), ) return cmd @@ -92,10 +123,99 @@ func defaultGlobalPreload(ctx context.Context) error { return config.LoadPersistedEnv("") } +// defaultSilentUpdateCheck 在后台异步检查新版本并缓存退出后提示文案。 +func defaultSilentUpdateCheck(ctx context.Context) { + currentVersion := readCurrentVersion() + if !version.IsSemverRelease(currentVersion) { + setSilentUpdateCheckDone(nil) + return + } + parentCtx := context.WithoutCancel(ctx) + done := make(chan struct{}) + setSilentUpdateCheckDone(done) + + go func(parent context.Context, currentVersion string, done chan struct{}) { + defer close(done) + + checkCtx, cancel := context.WithTimeout(parent, silentUpdateCheckTimeout) + defer cancel() + + result, err := checkLatestRelease(checkCtx, updater.CheckOptions{ + CurrentVersion: currentVersion, + IncludePrerelease: false, + }) + if err != nil || !result.HasUpdate { + return + } + + latestVersion := sanitizeVersionForTerminal(result.LatestVersion) + if latestVersion == "" { + return + } + setUpdateNotice(fmt.Sprintf("\u53d1\u73b0\u65b0\u7248\u672c: %s\uff0c\u8fd0\u884c neocode update \u5373\u53ef\u5347\u7ea7", latestVersion)) + }(parentCtx, currentVersion, done) +} + // shouldSkipGlobalPreload 判断当前命令是否应跳过全局预加载逻辑。 func shouldSkipGlobalPreload(cmd *cobra.Command) bool { - if cmd == nil { + return normalizedCommandName(cmd) == "url-dispatch" +} + +// shouldSkipSilentUpdateCheck 判断当前命令是否应跳过静默更新检测。 +func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { + switch normalizedCommandName(cmd) { + case "url-dispatch", "update": + return true + default: return false } - return strings.EqualFold(strings.TrimSpace(cmd.Name()), "url-dispatch") +} + +// sanitizeVersionForTerminal 清洗远端版本字符串,避免 ANSI 控制序列或不可见字符污染终端输出。 +func sanitizeVersionForTerminal(version string) string { + cleaned := ansiEscapeSequencePattern.ReplaceAllString(version, "") + var builder strings.Builder + builder.Grow(len(cleaned)) + for _, ch := range cleaned { + if ch >= 0x20 && ch <= 0x7e { + builder.WriteRune(ch) + } + } + return strings.TrimSpace(builder.String()) +} + +// normalizedCommandName 返回标准化后的命令名,统一处理空命令与大小写。 +func normalizedCommandName(cmd *cobra.Command) string { + if cmd == nil { + return "" + } + return strings.ToLower(strings.TrimSpace(cmd.Name())) +} + +// setSilentUpdateCheckDone 保存当前静默检测任务的完成信号通道。 +func setSilentUpdateCheckDone(done <-chan struct{}) { + silentUpdateCheckMu.Lock() + silentUpdateCheckDone = done + silentUpdateCheckMu.Unlock() +} + +// waitSilentUpdateCheckDone 在命令退出阶段等待静默检测短暂收口,降低提示丢失概率。 +func waitSilentUpdateCheckDone(timeout time.Duration) { + if timeout <= 0 { + return + } + + silentUpdateCheckMu.Lock() + done := silentUpdateCheckDone + silentUpdateCheckMu.Unlock() + if done == nil { + return + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-done: + case <-timer.C: + } } diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index d6c514e5..ceac8d22 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "testing" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/spf13/cobra" @@ -19,8 +20,13 @@ import ( "neo-code/internal/gateway" "neo-code/internal/gateway/adapters/urlscheme" gatewayauth "neo-code/internal/gateway/auth" + "neo-code/internal/updater" ) +func init() { + runSilentUpdateCheck = func(context.Context) {} +} + func TestNewRootCommandPassesWorkdirFlagToLauncher(t *testing.T) { originalLauncher := launchRootProgram t.Cleanup(func() { launchRootProgram = originalLauncher }) @@ -104,6 +110,52 @@ func TestExecuteUsesOSArgs(t *testing.T) { } } +func TestExecuteWaitsForSilentUpdateCheckCompletion(t *testing.T) { + originalLauncher := launchRootProgram + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + originalArgs := os.Args + t.Cleanup(func() { + launchRootProgram = originalLauncher + runGlobalPreload = originalPreload + runSilentUpdateCheck = originalSilentCheck + os.Args = originalArgs + }) + + _ = ConsumeUpdateNotice() + runGlobalPreload = func(context.Context) error { return nil } + launchRootProgram = func(context.Context, app.BootstrapOptions) error { return nil } + runSilentUpdateCheck = func(context.Context) { + done := make(chan struct{}) + setSilentUpdateCheckDone(done) + go func() { + time.Sleep(50 * time.Millisecond) + setUpdateNotice("发现新版本: v0.2.1") + close(done) + }() + } + os.Args = []string{"neocode"} + + if err := Execute(context.Background()); err != nil { + t.Fatalf("Execute() error = %v", err) + } + if got := ConsumeUpdateNotice(); got == "" { + t.Fatal("expected update notice after Execute waits for silent check") + } +} + +func TestWaitSilentUpdateCheckDoneReturnsOnTimeout(t *testing.T) { + blocked := make(chan struct{}) + setSilentUpdateCheckDone(blocked) + t.Cleanup(func() { setSilentUpdateCheckDone(nil) }) + + start := time.Now() + waitSilentUpdateCheckDone(30 * time.Millisecond) + if elapsed := time.Since(start); elapsed < 20*time.Millisecond || elapsed > 150*time.Millisecond { + t.Fatalf("wait duration out of expected range, got %s", elapsed) + } +} + func TestDefaultRootProgramLauncherRunsProgram(t *testing.T) { originalNewProgram := newRootProgram t.Cleanup(func() { newRootProgram = originalNewProgram }) @@ -1192,6 +1244,187 @@ func TestShouldSkipGlobalPreload(t *testing.T) { } } +func TestNormalizedCommandName(t *testing.T) { + if got := normalizedCommandName(nil); got != "" { + t.Fatalf("normalizedCommandName(nil) = %q, want empty", got) + } + if got := normalizedCommandName(&cobra.Command{Use: "URL-Dispatch"}); got != "url-dispatch" { + t.Fatalf("normalizedCommandName() = %q, want %q", got, "url-dispatch") + } +} + +func TestShouldSkipSilentUpdateCheck(t *testing.T) { + if !shouldSkipSilentUpdateCheck(&cobra.Command{Use: "url-dispatch"}) { + t.Fatal("url-dispatch should skip silent update check") + } + if !shouldSkipSilentUpdateCheck(&cobra.Command{Use: "update"}) { + t.Fatal("update should skip silent update check") + } + if shouldSkipSilentUpdateCheck(&cobra.Command{Use: "gateway"}) { + t.Fatal("gateway should not skip silent update check") + } + if shouldSkipSilentUpdateCheck(nil) { + t.Fatal("nil command should not skip silent update check") + } +} + +func TestRootCommandRunsSilentUpdateCheckAfterPreload(t *testing.T) { + originalLauncher := launchRootProgram + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { launchRootProgram = originalLauncher }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + events := make([]string, 0, 3) + runGlobalPreload = func(context.Context) error { + events = append(events, "preload") + return nil + } + runSilentUpdateCheck = func(context.Context) { + events = append(events, "check") + } + launchRootProgram = func(context.Context, app.BootstrapOptions) error { + events = append(events, "run") + return nil + } + + command := NewRootCommand() + command.SetArgs([]string{}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + want := []string{"preload", "check", "run"} + if len(events) != len(want) { + t.Fatalf("events = %v, want %v", events, want) + } + for i := range want { + if events[i] != want[i] { + t.Fatalf("events[%d] = %q, want %q", i, events[i], want[i]) + } + } +} + +func TestURLDispatchSkipsSilentUpdateCheck(t *testing.T) { + originalSilentCheck := runSilentUpdateCheck + originalRunner := runURLDispatchCommand + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + t.Cleanup(func() { runURLDispatchCommand = originalRunner }) + + var called bool + runSilentUpdateCheck = func(context.Context) { + called = true + } + runURLDispatchCommand = func(context.Context, urlDispatchCommandOptions) error { + return nil + } + + command := NewRootCommand() + command.SetArgs([]string{"url-dispatch", "--url", "neocode://review?path=README.md"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if called { + t.Fatal("expected silent update check to be skipped for url-dispatch") + } +} + +func TestUpdateCommandSkipsSilentUpdateCheck(t *testing.T) { + originalSilentCheck := runSilentUpdateCheck + originalRunner := runUpdateCommand + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + t.Cleanup(func() { runUpdateCommand = originalRunner }) + + var called bool + runSilentUpdateCheck = func(context.Context) { + called = true + } + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: "v0.2.1"}, nil + } + + command := NewRootCommand() + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if called { + t.Fatal("expected silent update check to be skipped for update command") + } +} + +func TestSanitizeVersionForTerminal(t *testing.T) { + dirty := "\x1b[31mv0.2.1\x1b[0m\t\n\r\x00" + if got := sanitizeVersionForTerminal(dirty); got != "v0.2.1" { + t.Fatalf("sanitizeVersionForTerminal() = %q, want %q", got, "v0.2.1") + } +} + +func TestDefaultSilentUpdateCheckSkipsForNonReleaseVersion(t *testing.T) { + originalVersionReader := readCurrentVersion + originalCheckLatest := checkLatestRelease + t.Cleanup(func() { readCurrentVersion = originalVersionReader }) + t.Cleanup(func() { checkLatestRelease = originalCheckLatest }) + + readCurrentVersion = func() string { return "dev" } + + var called bool + checkLatestRelease = func(context.Context, updater.CheckOptions) (updater.CheckResult, error) { + called = true + return updater.CheckResult{}, nil + } + + defaultSilentUpdateCheck(context.Background()) + if called { + t.Fatal("expected release check to be skipped for non-semver version") + } +} + +func TestDefaultSilentUpdateCheckSetsSanitizedNotice(t *testing.T) { + _ = ConsumeUpdateNotice() + + originalVersionReader := readCurrentVersion + originalCheckLatest := checkLatestRelease + t.Cleanup(func() { readCurrentVersion = originalVersionReader }) + t.Cleanup(func() { checkLatestRelease = originalCheckLatest }) + + readCurrentVersion = func() string { return "v0.1.0" } + done := make(chan struct{}) + checkLatestRelease = func(context.Context, updater.CheckOptions) (updater.CheckResult, error) { + close(done) + return updater.CheckResult{ + CurrentVersion: "v0.1.0", + LatestVersion: "\x1b[31mv0.2.1\x1b[0m\t\n\r", + HasUpdate: true, + }, nil + } + + defaultSilentUpdateCheck(context.Background()) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("expected silent update check goroutine to finish") + } + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + notice := ConsumeUpdateNotice() + if notice == "" { + time.Sleep(5 * time.Millisecond) + continue + } + if strings.Contains(notice, "\x1b") { + t.Fatalf("expected notice without ANSI sequence, got %q", notice) + } + if !strings.Contains(notice, "v0.2.1") { + t.Fatalf("expected sanitized version in notice, got %q", notice) + } + return + } + t.Fatal("expected update notice to be set") +} + func TestDefaultGlobalPreloadLoadsPersistedEnv(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) diff --git a/internal/cli/update_command.go b/internal/cli/update_command.go new file mode 100644 index 00000000..d1992690 --- /dev/null +++ b/internal/cli/update_command.go @@ -0,0 +1,84 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/spf13/cobra" + + "neo-code/internal/updater" + "neo-code/internal/version" +) + +type updateCommandOptions struct { + IncludePrerelease bool +} + +var runUpdateCommand = defaultUpdateCommandRunner +var doUpdate = updater.DoUpdate + +var updateCommandTimeout = 5 * time.Minute + +const updateTimeoutErrorTemplate = "\u66f4\u65b0\u8d85\u65f6\uff08%s\uff09\uff0c\u8bf7\u68c0\u67e5\u7f51\u7edc\u540e\u91cd\u8bd5" + +// newUpdateCommand 创建 update 子命令并绑定升级参数。 +func newUpdateCommand() *cobra.Command { + options := &updateCommandOptions{} + + cmd := &cobra.Command{ + Use: "update", + Short: "Update neocode to the latest release", + SilenceUsage: true, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := runUpdateCommand(cmd.Context(), *options) + if err != nil { + return err + } + + out := cmd.OutOrStdout() + if !result.Updated { + latest := displayVersionForTerminal(result.LatestVersion) + _, _ = fmt.Fprintf(out, "Already up-to-date (latest: %s).\n", latest) + return nil + } + + current := displayVersionForTerminal(result.CurrentVersion) + latest := displayVersionForTerminal(result.LatestVersion) + _, _ = fmt.Fprintf(out, "Updated successfully: %s -> %s\n", current, latest) + return nil + }, + } + + cmd.Flags().BoolVar(&options.IncludePrerelease, "prerelease", false, "include prerelease versions") + return cmd +} + +// defaultUpdateCommandRunner 执行手动升级流程并返回升级结果。 +func defaultUpdateCommandRunner(ctx context.Context, options updateCommandOptions) (updater.UpdateResult, error) { + updateCtx, cancel := context.WithTimeout(ctx, updateCommandTimeout) + defer cancel() + + result, err := doUpdate(updateCtx, updater.UpdateOptions{ + CurrentVersion: version.Current(), + IncludePrerelease: options.IncludePrerelease, + }) + if err != nil { + if errors.Is(updateCtx.Err(), context.DeadlineExceeded) { + return updater.UpdateResult{}, fmt.Errorf(updateTimeoutErrorTemplate, updateCommandTimeout) + } + return updater.UpdateResult{}, err + } + return result, nil +} + +// displayVersionForTerminal 清洗版本字符串并为不可用值提供统一回退文案。 +func displayVersionForTerminal(raw string) string { + version := sanitizeVersionForTerminal(raw) + if version == "" { + return "unknown" + } + return version +} diff --git a/internal/cli/update_command_test.go b/internal/cli/update_command_test.go new file mode 100644 index 00000000..06a4d6ea --- /dev/null +++ b/internal/cli/update_command_test.go @@ -0,0 +1,262 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + "time" + + "neo-code/internal/updater" + "neo-code/internal/version" +) + +func TestUpdateCommandPassesPrereleaseFlag(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + + var received updateCommandOptions + runUpdateCommand = func(_ context.Context, options updateCommandOptions) (updater.UpdateResult, error) { + received = options + return updater.UpdateResult{Updated: false, LatestVersion: "v0.2.1"}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update", "--prerelease"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + + if !received.IncludePrerelease { + t.Fatal("expected IncludePrerelease to be true") + } + if got := stdout.String(); got == "" { + t.Fatal("expected update command output") + } +} + +func TestUpdateCommandShowsSuccessMessage(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{ + CurrentVersion: "\x1b[31mv0.1.0\x1b[0m", + LatestVersion: "\x1b[32mv0.2.1\x1b[0m\t", + Updated: true, + }, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + + if got := stdout.String(); got == "" || !bytes.Contains(stdout.Bytes(), []byte("Updated successfully")) { + t.Fatalf("unexpected output: %q", got) + } + if strings.Contains(stdout.String(), "\x1b") { + t.Fatalf("expected sanitized output without ANSI sequence, got %q", stdout.String()) + } + if !strings.Contains(stdout.String(), "v0.1.0 -> v0.2.1") { + t.Fatalf("unexpected output: %q", stdout.String()) + } +} + +func TestUpdateCommandShowsUnknownLatestWhenLatestVersionEmpty(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: " \t "}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if !strings.Contains(stdout.String(), "latest: unknown") { + t.Fatalf("unexpected output: %q", stdout.String()) + } +} + +func TestUpdateCommandSanitizesLatestVersionInUpToDateMessage(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{Updated: false, LatestVersion: "\x1b[31mv0.2.1\x1b[0m\t\n"}, nil + } + + command := NewRootCommand() + var stdout bytes.Buffer + command.SetOut(&stdout) + command.SetArgs([]string{"update"}) + if err := command.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if strings.Contains(stdout.String(), "\x1b") { + t.Fatalf("expected sanitized output without ANSI sequence, got %q", stdout.String()) + } + if !strings.Contains(stdout.String(), "latest: v0.2.1") { + t.Fatalf("unexpected output: %q", stdout.String()) + } +} + +func TestUpdateCommandReturnsRunnerError(t *testing.T) { + originalRunner := runUpdateCommand + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { runUpdateCommand = originalRunner }) + t.Cleanup(func() { runGlobalPreload = originalPreload }) + t.Cleanup(func() { runSilentUpdateCheck = originalSilentCheck }) + + expected := errors.New("update failed") + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + runUpdateCommand = func(context.Context, updateCommandOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{}, expected + } + + command := NewRootCommand() + command.SetArgs([]string{"update"}) + err := command.ExecuteContext(context.Background()) + if !errors.Is(err, expected) { + t.Fatalf("expected runner error %v, got %v", expected, err) + } +} + +func TestConsumeUpdateNoticeOnce(t *testing.T) { + _ = ConsumeUpdateNotice() + setUpdateNotice(" new version ") + + if got := ConsumeUpdateNotice(); got != "new version" { + t.Fatalf("ConsumeUpdateNotice() = %q, want %q", got, "new version") + } + if got := ConsumeUpdateNotice(); got != "" { + t.Fatalf("ConsumeUpdateNotice() second call = %q, want empty", got) + } +} + +func TestSetUpdateNoticeIgnoresEmptyMessage(t *testing.T) { + _ = ConsumeUpdateNotice() + setUpdateNotice(" \n\t") + if got := ConsumeUpdateNotice(); got != "" { + t.Fatalf("ConsumeUpdateNotice() = %q, want empty", got) + } +} + +func TestDefaultUpdateCommandRunnerTimeout(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = 20 * time.Millisecond + doUpdate = func(ctx context.Context, options updater.UpdateOptions) (updater.UpdateResult, error) { + <-ctx.Done() + return updater.UpdateResult{}, ctx.Err() + } + + _, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{}) + if err == nil { + t.Fatal("expected timeout error") + } + if !strings.Contains(err.Error(), "\u66f4\u65b0\u8d85\u65f6") { + t.Fatalf("expected friendly timeout message, got %v", err) + } +} + +func TestDefaultUpdateCommandRunnerPassesOptionsAndResult(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = time.Second + expected := updater.UpdateResult{ + CurrentVersion: "v0.1.0", + LatestVersion: "v0.2.0", + Updated: true, + } + var captured updater.UpdateOptions + doUpdate = func(ctx context.Context, options updater.UpdateOptions) (updater.UpdateResult, error) { + captured = options + return expected, nil + } + + result, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{IncludePrerelease: true}) + if err != nil { + t.Fatalf("defaultUpdateCommandRunner() error = %v", err) + } + if result != expected { + t.Fatalf("result = %+v, want %+v", result, expected) + } + if !captured.IncludePrerelease { + t.Fatal("expected IncludePrerelease to be forwarded") + } + if captured.CurrentVersion != version.Current() { + t.Fatalf("CurrentVersion = %q, want %q", captured.CurrentVersion, version.Current()) + } +} + +func TestDefaultUpdateCommandRunnerReturnsUnderlyingError(t *testing.T) { + originalDoUpdate := doUpdate + originalTimeout := updateCommandTimeout + t.Cleanup(func() { doUpdate = originalDoUpdate }) + t.Cleanup(func() { updateCommandTimeout = originalTimeout }) + + updateCommandTimeout = time.Second + expected := errors.New("network failed") + doUpdate = func(context.Context, updater.UpdateOptions) (updater.UpdateResult, error) { + return updater.UpdateResult{}, expected + } + + _, err := defaultUpdateCommandRunner(context.Background(), updateCommandOptions{}) + if !errors.Is(err, expected) { + t.Fatalf("expected underlying error %v, got %v", expected, err) + } +} + +func TestDisplayVersionForTerminal(t *testing.T) { + if got := displayVersionForTerminal("\x1b[31mv0.2.1\x1b[0m\t"); got != "v0.2.1" { + t.Fatalf("displayVersionForTerminal() = %q, want %q", got, "v0.2.1") + } + if got := displayVersionForTerminal(" \n\t"); got != "unknown" { + t.Fatalf("displayVersionForTerminal() empty = %q, want %q", got, "unknown") + } +} diff --git a/internal/cli/update_notice.go b/internal/cli/update_notice.go new file mode 100644 index 00000000..8a645a33 --- /dev/null +++ b/internal/cli/update_notice.go @@ -0,0 +1,33 @@ +package cli + +import ( + "strings" + "sync" +) + +var ( + updateNoticeMu sync.Mutex + pendingUpdateNotice string +) + +// setUpdateNotice 保存待输出的更新提示,后写入会覆盖先前值。 +func setUpdateNotice(notice string) { + normalized := strings.TrimSpace(notice) + if normalized == "" { + return + } + + updateNoticeMu.Lock() + pendingUpdateNotice = normalized + updateNoticeMu.Unlock() +} + +// ConsumeUpdateNotice 读取并清空待输出的更新提示,确保只消费一次。 +func ConsumeUpdateNotice() string { + updateNoticeMu.Lock() + defer updateNoticeMu.Unlock() + + notice := pendingUpdateNotice + pendingUpdateNotice = "" + return notice +} diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index a93c3d0f..b1d126a5 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "runtime" "strings" "testing" @@ -1044,18 +1045,26 @@ func TestDeleteCustomProviderRemovesProviderDir(t *testing.T) { func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) { t.Run("providers dir read error", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows does not support chmod 000 for directories") + } + baseDir := t.TempDir() providersPath := filepath.Join(baseDir, providersDirName) - if err := os.WriteFile(providersPath, []byte("file"), 0o600); err != nil { - t.Fatalf("WriteFile() error = %v", err) + if err := os.MkdirAll(providersPath, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) } + if err := os.Chmod(providersPath, 0o000); err != nil { + t.Fatalf("Chmod() error = %v", err) + } + defer func() { _ = os.Chmod(providersPath, 0o755) }() - _, err := loadCustomProviders(baseDir) - if err == nil { - t.Fatal("expected read providers dir error") + providers, err := loadCustomProviders(baseDir) + if err != nil { + t.Fatalf("expected read providers dir fallback, got %v", err) } - if !strings.Contains(err.Error(), "read providers dir") { - t.Fatalf("expected read providers dir error, got %v", err) + if len(providers) != 0 { + t.Fatalf("expected empty providers on read fallback, got %d", len(providers)) } }) diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index dd51cb7d..75ed0270 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -67,7 +67,7 @@ func loadCustomProviders(baseDir string) ([]ProviderConfig, error) { if os.IsNotExist(err) { return nil, nil } - return nil, fmt.Errorf("config: read providers dir: %w", err) + return nil, nil } sort.Slice(entries, func(i, j int) bool { diff --git a/internal/provider/openaicompat/chatcompletions/provider_test.go b/internal/provider/openaicompat/chatcompletions/provider_test.go index 6a379d89..1ab5eae8 100644 --- a/internal/provider/openaicompat/chatcompletions/provider_test.go +++ b/internal/provider/openaicompat/chatcompletions/provider_test.go @@ -71,6 +71,62 @@ func TestNewAndBuildRequest(t *testing.T) { t.Fatalf("unexpected tools: %+v", payload.Tools) } + toolSchemaWithTopLevelCombinator := map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{"type": "string"}, + }, + "oneOf": []any{ + map[string]any{"required": []string{"action"}}, + }, + } + sanitizedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + Tools: []providertypes.ToolSpec{{ + Name: "todo_write", + Description: "write todos", + Schema: toolSchemaWithTopLevelCombinator, + }}, + }) + if err != nil { + t.Fatalf("BuildRequest() sanitize schema error = %v", err) + } + gotSchema := sanitizedPayload.Tools[0].Function.Parameters + if gotSchema["type"] != "object" { + t.Fatalf("expected sanitized schema type object, got %+v", gotSchema["type"]) + } + if _, ok := gotSchema["oneOf"]; !ok { + t.Fatalf("expected top-level oneOf to be preserved, got %+v", gotSchema) + } + if _, ok := toolSchemaWithTopLevelCombinator["oneOf"]; !ok { + t.Fatalf("expected original schema not to be mutated") + } + + downgradedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + Tools: []providertypes.ToolSpec{{ + Name: "non_object_schema", + Description: "schema root is string", + Schema: map[string]any{ + "type": "string", + }, + }}, + }) + if err != nil { + t.Fatalf("BuildRequest() downgrade schema error = %v", err) + } + downgradedSchema := downgradedPayload.Tools[0].Function.Parameters + if downgradedSchema["type"] != "object" { + t.Fatalf("expected downgraded schema type object, got %+v", downgradedSchema["type"]) + } + if _, ok := downgradedSchema["x-neocode-schema-downgraded"]; ok { + t.Fatalf("expected no custom downgrade marker in outbound schema, got %+v", downgradedSchema) + } + withSessionAsset, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{ Messages: []providertypes.Message{ { diff --git a/internal/provider/openaicompat/chatcompletions/request.go b/internal/provider/openaicompat/chatcompletions/request.go index d37a658b..a1807b13 100644 --- a/internal/provider/openaicompat/chatcompletions/request.go +++ b/internal/provider/openaicompat/chatcompletions/request.go @@ -67,7 +67,7 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert Function: FunctionDefinition{ Name: spec.Name, Description: spec.Description, - Parameters: spec.Schema, + Parameters: normalizeToolSchemaForOpenAI(spec.Schema), }, } payload.Tools = append(payload.Tools, def) @@ -77,6 +77,40 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert return payload, nil } +// normalizeToolSchemaForOpenAI 归一化工具参数 schema,避免修改调用方原始结构并尽量保持语义。 +// 仅在缺失 schema 或明显非法(非 object 顶层)时做最小兼容降级,不再删除顶层组合约束关键字。 +func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any { + normalized := cloneSchemaTopLevel(schema) + if len(normalized) == 0 { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + + typeName, _ := normalized["type"].(string) + if strings.TrimSpace(strings.ToLower(typeName)) != "object" { + normalized["type"] = "object" + } + + if _, ok := normalized["properties"].(map[string]any); !ok { + normalized["properties"] = map[string]any{} + } + return normalized +} + +// cloneSchemaTopLevel 复制 schema 顶层 map,避免归一化阶段修改调用方原始结构。 +func cloneSchemaTopLevel(schema map[string]any) map[string]any { + if len(schema) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(schema)) + for key, value := range schema { + cloned[key] = value + } + return cloned +} + // ToOpenAIMessage 将通用 Message 转换为 OpenAI 协议消息格式。 func ToOpenAIMessage(ctx context.Context, message providertypes.Message, assetReader providertypes.SessionAssetReader) (Message, error) { msg, _, err := toOpenAIMessageWithBudget(ctx, message, assetReader, maxSessionAssetsTotalBytes) diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 8f72f814..eb41ad79 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -92,6 +92,28 @@ type TodoEventPayload struct { Reason string `json:"reason,omitempty"` } +// InputNormalizedPayload 描述输入归一化完成后的摘要信息。 +type InputNormalizedPayload struct { + TextLength int `json:"text_length"` + ImageCount int `json:"image_count"` +} + +// AssetSavedPayload 描述单个附件成功保存后的结果。 +type AssetSavedPayload struct { + Index int `json:"index"` + Path string `json:"path,omitempty"` + AssetID string `json:"asset_id"` + MimeType string `json:"mime_type,omitempty"` + Size int64 `json:"size,omitempty"` +} + +// AssetSaveFailedPayload 描述单个附件保存失败的结构化信息。 +type AssetSaveFailedPayload struct { + Index int `json:"index"` + Path string `json:"path,omitempty"` + Message string `json:"message"` +} + const ( // EventUserMessage 表示用户消息已写入会话。 EventUserMessage EventType = "user_message" @@ -143,6 +165,12 @@ const ( EventTodoConflict EventType = "todo_conflict" // EventTodoSummaryInjected 表示本轮上下文注入了 Todo 摘要。 EventTodoSummaryInjected EventType = "todo_summary_injected" + // EventInputNormalized 表示用户输入已完成归一化。 + EventInputNormalized EventType = "input_normalized" + // EventAssetSaved 表示本轮用户输入附件已完成持久化。 + EventAssetSaved EventType = "asset_saved" + // EventAssetSaveFailed 表示本轮用户输入附件持久化失败。 + EventAssetSaveFailed EventType = "asset_save_failed" ) // TokenUsagePayload 承载单轮 token 用量统计。 diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go new file mode 100644 index 00000000..98ae2c57 --- /dev/null +++ b/internal/runtime/input_prepare.go @@ -0,0 +1,165 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + agentsession "neo-code/internal/session" +) + +const prepareEventEmitTimeout = 200 * time.Millisecond + +// NewSessionInputPreparer 创建基于 session 子层实现的输入归一化适配器。 +func NewSessionInputPreparer(store agentsession.Store, assetStore agentsession.AssetStore) UserInputPreparer { + return sessionInputPreparer{ + preparer: agentsession.NewInputPreparer(store, assetStore), + } +} + +// PrepareUserInput 负责在运行前执行输入归一化编排,并发出最小可观测事件。 +// Submit 作为运行时提交入口,统一串联输入归一化与执行,避免上层手动编排两段调用。 +func (s *Service) Submit(ctx context.Context, input PrepareInput) error { + prepared, err := s.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return s.Run(ctx, prepared) +} + +func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) { + if err := ctx.Err(); err != nil { + return UserInput{}, err + } + if s == nil { + return UserInput{}, errors.New("runtime: service is nil") + } + if s.userInputPreparer == nil { + err := errors.New("runtime: user input preparer is not configured") + _ = s.emitPrepareFailure(ctx, input, err) + return UserInput{}, err + } + + defaultWorkdir := "" + if s.configManager != nil { + defaultWorkdir = strings.TrimSpace(s.configManager.Get().Workdir) + } + + prepared, err := s.userInputPreparer.Prepare(ctx, input, defaultWorkdir) + if err != nil { + _ = s.emitPrepareFailure(ctx, input, err) + return UserInput{}, err + } + + runID := strings.TrimSpace(input.RunID) + _ = s.emitPrepareEvent(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{ + TextLength: len([]rune(strings.TrimSpace(input.Text))), + ImageCount: len(input.Images), + }) + for index, asset := range prepared.SavedAssets { + path := "" + if index >= 0 && index < len(input.Images) { + path = strings.TrimSpace(input.Images[index].Path) + } + _ = s.emitPrepareEvent(ctx, EventAssetSaved, runID, prepared.UserInput.SessionID, AssetSavedPayload{ + Index: index, + Path: path, + AssetID: asset.ID, + MimeType: asset.MimeType, + Size: asset.Size, + }) + } + + return prepared.UserInput, nil +} + +// emitPrepareFailure 统一发送输入归一化阶段的失败事件,避免前置副作用变成黑箱。 +func (s *Service) emitPrepareFailure(ctx context.Context, input PrepareInput, err error) error { + if s == nil { + return nil + } + + runID := strings.TrimSpace(input.RunID) + sessionID := strings.TrimSpace(input.SessionID) + + var saveErr *agentsession.AssetSaveError + if errors.As(err, &saveErr) { + if session := strings.TrimSpace(saveErr.SessionID); session != "" { + sessionID = session + } + return s.emitPrepareEvent(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{ + Index: saveErr.Index, + Path: strings.TrimSpace(saveErr.Path), + Message: strings.TrimSpace(saveErr.Error()), + }) + } + return s.emitPrepareEvent(ctx, EventError, runID, sessionID, strings.TrimSpace(err.Error())) +} + +// emitPrepareEvent 在输入归一化阶段使用限时上下文发事件,避免通道拥塞导致提交链路卡死。 +func (s *Service) emitPrepareEvent(ctx context.Context, kind EventType, runID string, sessionID string, payload any) error { + emitCtx := ctx + cancel := func() {} + if _, hasDeadline := emitCtx.Deadline(); !hasDeadline { + emitCtx, cancel = context.WithTimeout(emitCtx, prepareEventEmitTimeout) + } + defer cancel() + + if err := s.emit(emitCtx, kind, runID, sessionID, payload); err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return nil + } + return err + } + return nil +} + +type sessionInputPreparer struct { + preparer *agentsession.InputPreparer +} + +// Prepare 将 runtime 输入 DTO 映射到 session 子层并返回标准 UserInput 结果。 +func (p sessionInputPreparer) Prepare( + ctx context.Context, + input PrepareInput, + defaultWorkdir string, +) (PreparedInputResult, error) { + if p.preparer == nil { + return PreparedInputResult{}, errors.New("runtime: session input preparer is nil") + } + + sessionImages := make([]agentsession.PrepareImageInput, 0, len(input.Images)) + for _, image := range input.Images { + sessionImages = append(sessionImages, agentsession.PrepareImageInput{ + Path: strings.TrimSpace(image.Path), + MimeType: strings.TrimSpace(image.MimeType), + }) + } + + prepared, err := p.preparer.Prepare(ctx, agentsession.PrepareInput{ + SessionID: strings.TrimSpace(input.SessionID), + Text: input.Text, + Images: sessionImages, + DefaultWorkdir: strings.TrimSpace(defaultWorkdir), + RequestedWorkdir: strings.TrimSpace(input.Workdir), + }) + if err != nil { + return PreparedInputResult{}, err + } + + if len(prepared.Parts) == 0 { + return PreparedInputResult{}, fmt.Errorf("runtime: prepared parts is empty") + } + + return PreparedInputResult{ + UserInput: UserInput{ + SessionID: strings.TrimSpace(prepared.SessionID), + RunID: strings.TrimSpace(input.RunID), + Parts: prepared.Parts, + Workdir: strings.TrimSpace(prepared.Workdir), + }, + SavedAssets: append([]agentsession.AssetMeta(nil), prepared.SavedAssets...), + }, nil +} diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go new file mode 100644 index 00000000..4c7f32b9 --- /dev/null +++ b/internal/runtime/input_prepare_test.go @@ -0,0 +1,194 @@ +package runtime + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +func TestServicePrepareUserInputEmitsNormalizeAndAssetSaved(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + + imagePath := filepath.Join(workdir, "img.png") + if err := os.WriteFile(imagePath, minimalPNGBytesForRuntimeTest(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-1", + Text: "hello", + Images: []UserImageInput{{Path: imagePath, MimeType: "image/png"}}, + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if input.SessionID == "" || input.RunID != "run-prepare-1" { + t.Fatalf("unexpected prepared user input: %+v", input) + } + if len(input.Parts) != 2 || input.Parts[0].Kind != providertypes.ContentPartText || input.Parts[1].Kind != providertypes.ContentPartImage { + t.Fatalf("unexpected prepared parts: %+v", input.Parts) + } + + normalizedEvent := mustReadRuntimeEvent(t, svc.Events()) + if normalizedEvent.Type != EventInputNormalized { + t.Fatalf("expected first event %q, got %q", EventInputNormalized, normalizedEvent.Type) + } + normalizedPayload, ok := normalizedEvent.Payload.(InputNormalizedPayload) + if !ok || normalizedPayload.ImageCount != 1 { + t.Fatalf("unexpected normalized payload: %#v", normalizedEvent.Payload) + } + + assetSavedEvent := mustReadRuntimeEvent(t, svc.Events()) + if assetSavedEvent.Type != EventAssetSaved { + t.Fatalf("expected second event %q, got %q", EventAssetSaved, assetSavedEvent.Type) + } + assetSavedPayload, ok := assetSavedEvent.Payload.(AssetSavedPayload) + if !ok || assetSavedPayload.AssetID == "" || assetSavedPayload.MimeType != "image/png" { + t.Fatalf("unexpected asset_saved payload: %#v", assetSavedEvent.Payload) + } +} + +func TestServicePrepareUserInputEmitsAssetSaveFailed(t *testing.T) { + t.Parallel() + + svc, _ := newPrepareTestService(t, t.TempDir(), true) + _, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-2", + Text: "hello", + Images: []UserImageInput{{Path: filepath.Join(t.TempDir(), "missing.png"), MimeType: "image/png"}}, + }) + if err == nil { + t.Fatalf("expected PrepareUserInput() to fail") + } + + failedEvent := mustReadRuntimeEvent(t, svc.Events()) + if failedEvent.Type != EventAssetSaveFailed { + t.Fatalf("expected event %q, got %q", EventAssetSaveFailed, failedEvent.Type) + } + if failedEvent.SessionID == "" { + t.Fatalf("expected asset_save_failed event to include session id") + } + payload, ok := failedEvent.Payload.(AssetSaveFailedPayload) + if !ok || payload.Index != 0 { + t.Fatalf("unexpected asset_save_failed payload: %#v", failedEvent.Payload) + } +} + +func TestServicePrepareUserInputWithoutPreparerEmitsErrorEvent(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, false) + + _, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-3", + Text: "hello", + }) + if err == nil { + t.Fatalf("expected PrepareUserInput() to fail without preparer") + } + + errorEvent := mustReadRuntimeEvent(t, svc.Events()) + if errorEvent.Type != EventError { + t.Fatalf("expected event %q, got %q", EventError, errorEvent.Type) + } +} + +func TestServiceSubmitWithoutPreparerReturnsError(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, false) + + err := svc.Submit(context.Background(), PrepareInput{ + RunID: "run-submit-1", + Text: "hello", + }) + if err == nil { + t.Fatalf("expected Submit() to fail without preparer") + } + + errorEvent := mustReadRuntimeEvent(t, svc.Events()) + if errorEvent.Type != EventError { + t.Fatalf("expected event %q, got %q", EventError, errorEvent.Type) + } +} + +func TestServicePrepareUserInputDoesNotBlockWhenPrepareEventQueueIsFull(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + for index := 0; index < cap(svc.events); index++ { + svc.events <- RuntimeEvent{Type: EventToolChunk} + } + + start := time.Now() + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-full-queue", + Text: "hello", + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if input.SessionID == "" { + t.Fatalf("expected prepared session id") + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("PrepareUserInput() blocked too long with full event queue: %v", elapsed) + } +} + +func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Service, *agentsession.SQLiteStore) { + t.Helper() + + cfg := config.StaticDefaults() + cfg.Workdir = workdir + loader := config.NewLoader(t.TempDir(), cfg) + manager := config.NewManager(loader) + if _, err := manager.Load(context.Background()); err != nil { + t.Fatalf("load config: %v", err) + } + + store := agentsession.NewStore(t.TempDir(), workdir) + svc := NewWithFactory(manager, nil, store, nil, nil) + svc.SetSessionAssetStore(store) + if withPreparer { + svc.SetUserInputPreparer(NewSessionInputPreparer(store, store)) + } + return svc, store +} + +func mustReadRuntimeEvent(t *testing.T, events <-chan RuntimeEvent) RuntimeEvent { + t.Helper() + select { + case event := <-events: + return event + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for runtime event") + return RuntimeEvent{} + } +} + +func minimalPNGBytesForRuntimeTest() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index edf86035..07c6cbc1 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -28,6 +28,8 @@ const ( // Runtime 定义 runtime 对外暴露的运行、压缩与审批接口。 type Runtime interface { + Submit(ctx context.Context, input PrepareInput) error + PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) Run(ctx context.Context, input UserInput) error Compact(ctx context.Context, input CompactInput) (CompactResult, error) ResolvePermission(ctx context.Context, input PermissionResolutionInput) error @@ -51,6 +53,32 @@ type UserInput struct { CapabilityToken *security.CapabilityToken } +// UserImageInput 表示用户输入中附带的单个图片引用(路径 + MIME)。 +type UserImageInput struct { + Path string + MimeType string +} + +// PrepareInput 表示进入 runtime 归一化前的领域输入(仅包含文本/图片/会话上下文)。 +type PrepareInput struct { + SessionID string + RunID string + Workdir string + Text string + Images []UserImageInput +} + +// PreparedInputResult 描述输入归一化完成后的结果快照(标准 UserInput + 本轮保存附件元数据)。 +type PreparedInputResult struct { + UserInput UserInput + SavedAssets []agentsession.AssetMeta +} + +// UserInputPreparer 定义 runtime 输入归一化能力:会话绑定、附件持久化与 parts 组装。 +type UserInputPreparer interface { + Prepare(ctx context.Context, input PrepareInput, defaultWorkdir string) (PreparedInputResult, error) +} + // ProviderFactory 负责基于运行期配置创建 provider 实例。 type ProviderFactory interface { Build(ctx context.Context, cfg provider.RuntimeConfig) (provider.Provider, error) @@ -72,6 +100,7 @@ type Service struct { configManager *config.Manager sessionStore agentsession.Store sessionAssetStore agentsession.AssetStore + userInputPreparer UserInputPreparer toolManager tools.Manager providerFactory ProviderFactory contextBuilder agentcontext.Builder @@ -146,6 +175,11 @@ func (s *Service) SetSessionAssetStore(store agentsession.AssetStore) { s.sessionAssetStore = store } +// SetUserInputPreparer 设置输入归一化能力实现;runtime 仅做编排调用,不承载具体存储细节。 +func (s *Service) SetUserInputPreparer(preparer UserInputPreparer) { + s.userInputPreparer = preparer +} + // SetSkillsRegistry 设置运行时可选的 skills registry,用于激活校验与上下文注入。 func (s *Service) SetSkillsRegistry(registry skills.Registry) { s.skillsRegistry = registry diff --git a/internal/security/workspace_test.go b/internal/security/workspace_test.go index 37fb49fa..032e33d6 100644 --- a/internal/security/workspace_test.go +++ b/internal/security/workspace_test.go @@ -196,6 +196,19 @@ func TestWorkspaceSandboxCheckShortCircuits(t *testing.T) { } } +func TestWorkspaceSandboxCheckRejectsInvalidCapabilityToken(t *testing.T) { + t.Parallel() + + root := t.TempDir() + action := fileAction(ActionTypeRead, "filesystem_read_file", "read_file", root, "notes.txt") + action.Payload.CapabilityToken = &CapabilityToken{} + + _, err := NewWorkspaceSandbox().Check(context.Background(), action) + if err == nil || !strings.Contains(err.Error(), "capability token path not allowed") { + t.Fatalf("expected capability token path rejection, got %v", err) + } +} + func TestBuildWorkspacePlan(t *testing.T) { t.Parallel() @@ -263,6 +276,21 @@ func TestBuildWorkspacePlan(t *testing.T) { wantOK: true, wantTarget: ".", }, + { + name: "sandbox target type falls back to target type", + action: Action{ + Type: ActionTypeRead, + Payload: ActionPayload{ + ToolName: "filesystem_grep", + Resource: "filesystem_grep", + Workdir: root, + TargetType: TargetTypeDirectory, + Target: "docs", + }, + }, + wantOK: true, + wantTarget: "docs", + }, } for _, tt := range tests { @@ -291,6 +319,21 @@ func TestBuildWorkspacePlan(t *testing.T) { } } +func TestWorkspaceSandboxValidateWorkspacePlanErrors(t *testing.T) { + t.Parallel() + + sandbox := NewWorkspaceSandbox() + _, err := sandbox.validateWorkspacePlan(workspacePlan{ + root: filepath.Join(t.TempDir(), "missing"), + target: "notes.txt", + targetType: TargetTypePath, + actionType: ActionTypeRead, + }) + if err == nil || !strings.Contains(err.Error(), "resolve workspace root") { + t.Fatalf("expected resolve workspace root error, got %v", err) + } +} + func TestNeedsWorkspaceSandbox(t *testing.T) { t.Parallel() @@ -441,6 +484,13 @@ func TestCanonicalWorkspaceRoot(t *testing.T) { if _, ok := sandbox.canonicalRoots.Load(cleanedPathKey(existing)); !ok { t.Fatalf("expected canonical root cache entry for %q", existing) } + gotCached, err := sandbox.canonicalWorkspaceRoot(existing) + if err != nil { + t.Fatalf("canonicalWorkspaceRoot(cached) error: %v", err) + } + if !samePathKey(gotCached, got) { + t.Fatalf("canonicalWorkspaceRoot(cached) = %q, want %q", gotCached, got) + } missing := filepath.Join(t.TempDir(), "missing", "dir") _, err = sandbox.canonicalWorkspaceRoot(missing) @@ -776,6 +826,54 @@ func TestWorkspaceExecutionPlanValidateForExecution(t *testing.T) { t.Fatalf("expected valid plan, got %v", err) } }) + + t.Run("nearest existing path failure is returned", func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + plan := &WorkspaceExecutionPlan{ + Root: root, + Target: filepath.Join(file, "child.txt"), + RequestedTarget: filepath.Join("file.txt", "child.txt"), + anchorPath: file, + anchorSnapshot: pathSnapshot{}, + } + err := plan.ValidateForExecution() + if err == nil || !strings.Contains(err.Error(), "inspect path") { + t.Fatalf("expected inspect path error, got %v", err) + } + }) + + t.Run("anchor path mismatch is rejected", func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + targetA := filepath.Join(root, "a") + targetB := filepath.Join(root, "b") + if err := os.MkdirAll(targetA, 0o755); err != nil { + t.Fatalf("mkdir targetA: %v", err) + } + if err := os.MkdirAll(targetB, 0o755); err != nil { + t.Fatalf("mkdir targetB: %v", err) + } + snapshot, err := capturePathSnapshot(targetB) + if err != nil { + t.Fatalf("capturePathSnapshot(targetB): %v", err) + } + plan := &WorkspaceExecutionPlan{ + Root: root, + Target: targetA, + RequestedTarget: "a", + anchorPath: targetB, + anchorSnapshot: snapshot, + } + err = plan.ValidateForExecution() + if err == nil || !strings.Contains(err.Error(), "changed before execution") { + t.Fatalf("expected changed-before-execution error, got %v", err) + } + }) } func TestCapturePathSnapshotAndEqual(t *testing.T) { @@ -893,6 +991,28 @@ func TestNearestExistingPath(t *testing.T) { return cleanedPathKey(filepath.Join(root, "broken")) }, }, + { + name: "returns inspect error for non-not-exist lstat", + setup: func(t *testing.T) (string, string) { + t.Helper() + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + return root, filepath.Join(file, "child.txt") + }, + expectErr: "inspect path", + }, + { + name: "missing root path returns root", + setup: func(t *testing.T) (string, string) { + t.Helper() + root := filepath.Join(t.TempDir(), "missing-root") + return root, root + }, + expect: func(root string, target string) string { + return cleanedPathKey(root) + }, + }, } for _, tt := range tests { @@ -918,6 +1038,27 @@ func TestNearestExistingPath(t *testing.T) { } } +func TestEnsureNoSymlinkEscapeReturnsNearestPathError(t *testing.T) { + t.Parallel() + + root := t.TempDir() + file := filepath.Join(root, "file.txt") + mustWriteWorkspaceFile(t, file, "x") + + _, err := ensureNoSymlinkEscape(root, filepath.Join(file, "child.txt"), filepath.Join("file.txt", "child.txt")) + if err == nil || !strings.Contains(err.Error(), "inspect path") { + t.Fatalf("expected inspect path error, got %v", err) + } +} + +func TestValidateTargetVolumeNoVolumeShortCircuit(t *testing.T) { + t.Parallel() + + if err := validateTargetVolume(filepath.Join(t.TempDir(), "workspace"), filepath.Join(t.TempDir(), "target")); err != nil { + t.Fatalf("validateTargetVolume() error = %v, want nil on non-volume paths", err) + } +} + func TestSplitRelativePath(t *testing.T) { t.Parallel() diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go new file mode 100644 index 00000000..3168414b --- /dev/null +++ b/internal/session/input_preparer.go @@ -0,0 +1,460 @@ +package session + +import ( + "context" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + providertypes "neo-code/internal/provider/types" +) + +const imageOnlySessionTitle = "Image Message" + +// PrepareImageInput 表示一次用户输入中附带的本地图片引用。 +type PrepareImageInput struct { + Path string + MimeType string +} + +// PrepareInput 定义会话输入归一化的领域输入参数。 +type PrepareInput struct { + SessionID string + Text string + Images []PrepareImageInput + DefaultWorkdir string + RequestedWorkdir string +} + +// PreparedInput 表示归一化完成后可直接进入 runtime 的标准输入结果。 +type PreparedInput struct { + SessionID string + Workdir string + Parts []providertypes.ContentPart + SavedAssets []AssetMeta +} + +// AssetSaveError 描述图片落盘阶段的结构化失败信息,便于上层统一事件化处理。 +type AssetSaveError struct { + SessionID string + Index int + Path string + Err error +} + +func (e *AssetSaveError) Error() string { + if e == nil { + return "session: asset save failed" + } + if strings.TrimSpace(e.Path) == "" { + return fmt.Sprintf("session: save asset at index %d: %v", e.Index, e.Err) + } + return fmt.Sprintf("session: save asset %q at index %d: %v", e.Path, e.Index, e.Err) +} + +func (e *AssetSaveError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +// InputPreparer 负责把用户文本/图片输入归一化为会话级标准 parts。 +type InputPreparer struct { + store Store + assetStore AssetStore +} + +type assetCleanupStore interface { + DeleteAsset(ctx context.Context, sessionID string, assetID string) error +} + +type sessionCleanupStore interface { + DeleteSession(ctx context.Context, sessionID string) error +} + +// NewInputPreparer 创建会话输入归一化组件。 +func NewInputPreparer(store Store, assetStore AssetStore) *InputPreparer { + return &InputPreparer{ + store: store, + assetStore: assetStore, + } +} + +// Prepare 负责会话解析/创建、附件落盘与 parts 组装。 +func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (PreparedInput, error) { + if err := ctx.Err(); err != nil { + return PreparedInput{}, err + } + if p == nil || p.store == nil { + return PreparedInput{}, fmt.Errorf("session: input preparer store is not configured") + } + if len(input.Images) > 0 && p.assetStore == nil { + return PreparedInput{}, fmt.Errorf("session: asset store is not configured") + } + + trimmedText := strings.TrimSpace(input.Text) + if trimmedText == "" && len(input.Images) == 0 { + return PreparedInput{}, fmt.Errorf("session: input content is empty") + } + + sessionTitle := buildSessionTitle(trimmedText, len(input.Images) > 0) + session, sessionCreated, pendingUpdate, err := p.loadOrCreateSession( + ctx, + input.SessionID, + sessionTitle, + input.DefaultWorkdir, + input.RequestedWorkdir, + ) + if err != nil { + return PreparedInput{}, err + } + + parts := make([]providertypes.ContentPart, 0, 1+len(input.Images)) + if trimmedText != "" { + parts = append(parts, providertypes.NewTextPart(trimmedText)) + } + + savedAssets := make([]AssetMeta, 0, len(input.Images)) + for index, image := range input.Images { + path := strings.TrimSpace(image.Path) + if path == "" { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: path, + Err: fmt.Errorf("image path is empty"), + } + } + mimeType := strings.TrimSpace(image.MimeType) + + meta, err := p.saveImageAsset(ctx, session.ID, session.Workdir, path, mimeType) + if err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: path, + Err: err, + } + } + savedAssets = append(savedAssets, meta) + parts = append(parts, providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType)) + } + + if err := providertypes.ValidateParts(parts); err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, fmt.Errorf("session: normalize parts: %w", err) + } + if err := p.persistSessionWorkdirUpdate(ctx, pendingUpdate); err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, err + } + + return PreparedInput{ + SessionID: session.ID, + Workdir: session.Workdir, + Parts: parts, + SavedAssets: savedAssets, + }, nil +} + +// saveImageAsset 按会话工作目录解析并校验图片路径后落盘,禁止越界访问工作目录外文件。 +func (p *InputPreparer) saveImageAsset( + ctx context.Context, + sessionID string, + workdir string, + path string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + absolutePath, err := resolveImagePath(workdir, path) + if err != nil { + return AssetMeta{}, err + } + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + file, err := os.Open(absolutePath) + if err != nil { + return AssetMeta{}, fmt.Errorf("open image file: %w", err) + } + defer func() { + _ = file.Close() + }() + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + resolvedMimeType, err := resolveImageMimeType(ctx, path, mimeType, file) + if err != nil { + return AssetMeta{}, err + } + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + meta, err := p.assetStore.SaveAsset(ctx, sessionID, file, resolvedMimeType) + if err != nil { + return AssetMeta{}, err + } + return meta, nil +} + +// resolveImageMimeType 解析图片 MIME 类型,仅允许 image/*,并要求声明值与文件头探测一致。 +func resolveImageMimeType(ctx context.Context, path string, declared string, file *os.File) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + detected, err := detectImageMimeTypeFromFile(ctx, file) + if err != nil { + return "", err + } + + declaredMime := normalizeMimeType(declared) + if declaredMime != "" { + if !strings.HasPrefix(declaredMime, "image/") { + return "", fmt.Errorf("declared mime type %q is not an image", declared) + } + if declaredMime != detected { + return "", fmt.Errorf("declared mime type %q mismatches detected %q", declaredMime, detected) + } + return detected, nil + } + + extMime := normalizeMimeType(mime.TypeByExtension(strings.ToLower(filepath.Ext(path)))) + if extMime != "" && strings.HasPrefix(extMime, "image/") && extMime != detected { + return "", fmt.Errorf("file extension mime %q mismatches detected %q", extMime, detected) + } + return detected, nil +} + +// detectImageMimeTypeFromFile 根据文件头探测 MIME,且要求结果为 image/*。 +func detectImageMimeTypeFromFile(ctx context.Context, file *os.File) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + buffer := make([]byte, 512) + n, readErr := file.Read(buffer) + if readErr != nil && readErr != io.EOF { + return "", fmt.Errorf("detect image mime type: %w", readErr) + } + if err := ctx.Err(); err != nil { + return "", err + } + if _, err := file.Seek(0, io.SeekStart); err != nil { + return "", fmt.Errorf("reset image reader: %w", err) + } + + detected := strings.ToLower(strings.TrimSpace(http.DetectContentType(buffer[:n]))) + if strings.HasPrefix(detected, "image/") { + return detected, nil + } + return "", fmt.Errorf("unsupported image format") +} + +// normalizeMimeType 清洗 MIME 字符串并移除参数段,返回小写标准形式。 +func normalizeMimeType(value string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + if normalized == "" { + return "" + } + if idx := strings.Index(normalized, ";"); idx >= 0 { + normalized = strings.TrimSpace(normalized[:idx]) + } + return normalized +} + +// resolveImagePath 以会话工作目录为基准解析图片路径并强制限制在工作目录内。 +func resolveImagePath(workdir string, path string) (string, error) { + base := strings.TrimSpace(workdir) + if base == "" { + return "", fmt.Errorf("resolve image path: workdir is empty") + } + baseAbs, err := filepath.Abs(base) + if err != nil { + return "", fmt.Errorf("resolve image path base: %w", err) + } + + target := strings.TrimSpace(path) + if target == "" { + return "", fmt.Errorf("resolve image path: path is empty") + } + if !filepath.IsAbs(target) { + target = filepath.Join(baseAbs, target) + } + + targetAbs, err := filepath.Abs(target) + if err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + if err := ensurePathWithinBase(baseAbs, targetAbs); err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + + resolved := targetAbs + if linkTarget, linkErr := filepath.EvalSymlinks(targetAbs); linkErr == nil { + if err := ensurePathWithinBase(baseAbs, linkTarget); err != nil { + return "", fmt.Errorf("resolve image path: %w", err) + } + resolved = linkTarget + } + return resolved, nil +} + +// sessionWorkdirUpdate 描述已有会话 workdir 的待提交变更,确保 Prepare 成功后再落盘。 +type sessionWorkdirUpdate struct { + session Session + dirty bool +} + +func (p *InputPreparer) loadOrCreateSession( + ctx context.Context, + sessionID string, + title string, + defaultWorkdir string, + requestedWorkdir string, +) (Session, bool, sessionWorkdirUpdate, error) { + if strings.TrimSpace(sessionID) == "" { + sessionWorkdir, err := resolveWorkdirForInput(defaultWorkdir, "", requestedWorkdir) + if err != nil { + return Session{}, false, sessionWorkdirUpdate{}, err + } + session := NewWithWorkdir(title, sessionWorkdir) + created, err := p.store.CreateSession(ctx, CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState, + ActivatedSkills: session.ActivatedSkills, + Todos: session.Todos, + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + }) + if err != nil { + return Session{}, false, sessionWorkdirUpdate{}, err + } + return created, true, sessionWorkdirUpdate{}, nil + } + + session, err := p.store.LoadSession(ctx, sessionID) + if err != nil { + return Session{}, false, sessionWorkdirUpdate{}, err + } + if strings.TrimSpace(requestedWorkdir) == "" && strings.TrimSpace(session.Workdir) != "" { + return session, false, sessionWorkdirUpdate{}, nil + } + + resolved, err := resolveWorkdirForInput(defaultWorkdir, session.Workdir, requestedWorkdir) + if err != nil { + return Session{}, false, sessionWorkdirUpdate{}, err + } + if session.Workdir == resolved { + return session, false, sessionWorkdirUpdate{}, nil + } + + session.Workdir = resolved + session.UpdatedAt = time.Now() + return session, false, sessionWorkdirUpdate{ + session: session, + dirty: true, + }, nil +} + +// rollbackCreatedSession 在本次 Prepare 新建会话后发生错误时回滚会话目录,避免残留孤儿会话。 +func (p *InputPreparer) rollbackCreatedSession(ctx context.Context, sessionID string, created bool) { + if !created { + return + } + if err := ctx.Err(); err != nil { + return + } + cleanupStore, ok := p.store.(sessionCleanupStore) + if !ok { + return + } + _ = cleanupStore.DeleteSession(ctx, sessionID) +} + +// persistSessionWorkdirUpdate 在 Prepare 其余步骤完成后统一提交会话 workdir 更新,避免失败时出现部分提交。 +func (p *InputPreparer) persistSessionWorkdirUpdate(ctx context.Context, pending sessionWorkdirUpdate) error { + if !pending.dirty { + return nil + } + if err := p.store.UpdateSessionState(ctx, UpdateSessionStateInput{ + SessionID: pending.session.ID, + Title: pending.session.Title, + UpdatedAt: pending.session.UpdatedAt, + Provider: pending.session.Provider, + Model: pending.session.Model, + Workdir: pending.session.Workdir, + TaskState: pending.session.TaskState, + ActivatedSkills: pending.session.ActivatedSkills, + Todos: pending.session.Todos, + TokenInputTotal: pending.session.TokenInputTotal, + TokenOutputTotal: pending.session.TokenOutputTotal, + }); err != nil { + return err + } + return nil +} + +// cleanupSavedAssets 在 Prepare 失败时尽力回收已落盘的附件,减少 existing session 残留垃圾文件。 +func (p *InputPreparer) cleanupSavedAssets(ctx context.Context, sessionID string, assets []AssetMeta) { + if len(assets) == 0 || ctx.Err() != nil { + return + } + cleanupStore, ok := p.assetStore.(assetCleanupStore) + if !ok { + return + } + for _, asset := range assets { + if strings.TrimSpace(asset.ID) == "" { + continue + } + _ = cleanupStore.DeleteAsset(ctx, sessionID, asset.ID) + } +} + +func resolveWorkdirForInput(defaultWorkdir string, currentWorkdir string, requestedWorkdir string) (string, error) { + base := EffectiveWorkdir(currentWorkdir, defaultWorkdir) + if strings.TrimSpace(requestedWorkdir) == "" { + return ResolveExistingDir(base) + } + + target := strings.TrimSpace(requestedWorkdir) + if !filepath.IsAbs(target) { + target = filepath.Join(base, target) + } + return ResolveExistingDir(target) +} + +func buildSessionTitle(text string, hasImages bool) string { + if strings.TrimSpace(text) != "" { + return strings.TrimSpace(text) + } + if hasImages { + return imageOnlySessionTitle + } + return "New Session" +} diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go new file mode 100644 index 00000000..3b177748 --- /dev/null +++ b/internal/session/input_preparer_test.go @@ -0,0 +1,482 @@ +package session + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestInputPreparerPrepareTextOnly(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "hello world", + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if result.SessionID == "" { + t.Fatalf("expected non-empty session id") + } + if len(result.Parts) != 1 || result.Parts[0].Kind != providertypes.ContentPartText || result.Parts[0].Text != "hello world" { + t.Fatalf("unexpected prepared parts: %+v", result.Parts) + } + if len(result.SavedAssets) != 0 { + t.Fatalf("expected no saved assets, got %+v", result.SavedAssets) + } +} + +func TestInputPreparerPrepareTextAndImage(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "img.png") + payload := minimalPNGBytes() + if err := os.WriteFile(imagePath, payload, 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "with image", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 { + t.Fatalf("expected one saved asset, got %+v", result.SavedAssets) + } + if len(result.Parts) != 2 { + t.Fatalf("expected 2 parts, got %+v", result.Parts) + } + imagePart := result.Parts[1] + if imagePart.Kind != providertypes.ContentPartImage || imagePart.Image == nil || imagePart.Image.Asset == nil { + t.Fatalf("expected session asset image part, got %+v", imagePart) + } + if imagePart.Image.Asset.ID != result.SavedAssets[0].ID { + t.Fatalf("expected image part asset id %q, got %+v", result.SavedAssets[0].ID, imagePart.Image.Asset) + } + + rc, meta, err := store.Open(context.Background(), result.SessionID, result.SavedAssets[0].ID) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer func() { _ = rc.Close() }() + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if meta.MimeType != "image/png" || string(got) != string(payload) { + t.Fatalf("unexpected stored asset mime=%q payload=%q", meta.MimeType, string(got)) + } +} + +func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "auto.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "infer mime", + Images: []PrepareImageInput{{Path: imagePath}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 { + t.Fatalf("expected one saved asset, got %+v", result.SavedAssets) + } + if result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("expected inferred mime image/png, got %q", result.SavedAssets[0].MimeType) + } +} + +func TestInputPreparerPrepareImageOnlyUsesImageTitle(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + imagePath := filepath.Join(workdir, "only.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.Parts) != 1 || result.Parts[0].Kind != providertypes.ContentPartImage { + t.Fatalf("expected one image part, got %+v", result.Parts) + } + + session, err := store.LoadSession(context.Background(), result.SessionID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if session.Title != imageOnlySessionTitle { + t.Fatalf("expected image-only title %q, got %q", imageOnlySessionTitle, session.Title) + } +} + +func TestInputPreparerPrepareErrors(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + + t.Run("missing store", func(t *testing.T) { + preparer := NewInputPreparer(nil, nil) + if _, err := preparer.Prepare(context.Background(), PrepareInput{Text: "x", DefaultWorkdir: workdir}); err == nil { + t.Fatalf("expected missing store error") + } + }) + + t.Run("missing asset store", func(t *testing.T) { + preparer := NewInputPreparer(store, nil) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "x", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing asset store error") + } + }) + + t.Run("empty content", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + if _, err := preparer.Prepare(context.Background(), PrepareInput{DefaultWorkdir: workdir}); err == nil { + t.Fatalf("expected empty content error") + } + }) + + t.Run("asset save error is structured", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + var saveErr *AssetSaveError + if !errors.As(err, &saveErr) { + t.Fatalf("expected AssetSaveError, got %T %v", err, err) + } + if saveErr.Index != 0 { + t.Fatalf("expected save error index 0, got %d", saveErr.Index) + } + if saveErr.SessionID == "" { + t.Fatalf("expected save error session id") + } + }) + + t.Run("new session is rolled back when asset save fails", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + + summaries, listErr := store.ListSummaries(context.Background()) + if listErr != nil { + t.Fatalf("ListSummaries() error = %v", listErr) + } + if len(summaries) != 0 { + t.Fatalf("expected no persisted session after rollback, got %+v", summaries) + } + }) + + t.Run("existing session is kept when asset save fails", func(t *testing.T) { + existing := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset save error") + } + + if _, loadErr := store.LoadSession(context.Background(), existing.ID); loadErr != nil { + t.Fatalf("expected existing session to remain, load error = %v", loadErr) + } + }) + + t.Run("existing session cleanup removes previously saved assets on later failure", func(t *testing.T) { + existing := NewWithWorkdir("existing-cleanup", workdir) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + + okImage := filepath.Join(workdir, "ok.png") + if err := os.WriteFile(okImage, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "cleanup", + Images: []PrepareImageInput{ + {Path: okImage}, + {Path: "not-found.png", MimeType: "image/png"}, + }, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected prepare error") + } + + entries, readErr := os.ReadDir(filepath.Join(store.assetsDir, existing.ID)) + if readErr != nil { + t.Fatalf("ReadDir() error = %v", readErr) + } + if len(entries) != 0 { + t.Fatalf("expected no leftover assets, got %d files", len(entries)) + } + }) + + t.Run("existing session workdir change is not persisted when prepare fails", func(t *testing.T) { + currentWorkdir := filepath.Join(workdir, "current") + if err := os.MkdirAll(currentWorkdir, 0o755); err != nil { + t.Fatalf("mkdir current workdir: %v", err) + } + targetWorkdir := filepath.Join(currentWorkdir, "nested") + if err := os.MkdirAll(targetWorkdir, 0o755); err != nil { + t.Fatalf("mkdir nested workdir: %v", err) + } + + existing := NewWithWorkdir("existing-workdir", currentWorkdir) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "will fail", + RequestedWorkdir: "nested", + Images: []PrepareImageInput{{Path: "not-found.png", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected prepare error") + } + + loaded, loadErr := store.LoadSession(context.Background(), existing.ID) + if loadErr != nil { + t.Fatalf("Load() error = %v", loadErr) + } + if loaded.Workdir != currentWorkdir { + t.Fatalf("expected workdir to stay %q, got %q", currentWorkdir, loaded.Workdir) + } + }) +} + +func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := NewStore(t.TempDir(), workdir) + preparer := NewInputPreparer(store, store) + + t.Run("relative path is resolved by workdir", func(t *testing.T) { + relativeDir := filepath.Join(workdir, "images") + if err := os.MkdirAll(relativeDir, 0o755); err != nil { + t.Fatalf("mkdir images: %v", err) + } + imagePath := filepath.Join(relativeDir, "a.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "relative path", + Images: []PrepareImageInput{{Path: filepath.Join("images", "a.png")}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 || result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("unexpected saved assets: %+v", result.SavedAssets) + } + }) + + t.Run("path outside workdir is rejected", func(t *testing.T) { + outside := filepath.Join(t.TempDir(), "outside.png") + if err := os.WriteFile(outside, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write outside image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "outside", + Images: []PrepareImageInput{{Path: outside, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected outside workdir error") + } + if !strings.Contains(err.Error(), "escapes base dir") { + t.Fatalf("expected escapes base dir error, got %v", err) + } + }) + + t.Run("declared mime mismatch with file header is rejected", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-mismatch.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared mismatch", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "image/jpeg"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected mime mismatch error") + } + if !strings.Contains(err.Error(), "mismatches detected") { + t.Fatalf("expected mismatch error, got %v", err) + } + }) +} + +func TestAssetSaveErrorMethods(t *testing.T) { + t.Parallel() + + if err := (*AssetSaveError)(nil).Unwrap(); err != nil { + t.Fatalf("expected nil asset save error unwrap to return nil, got %v", err) + } + if msg := (*AssetSaveError)(nil).Error(); msg != "session: asset save failed" { + t.Fatalf("unexpected nil asset save error message: %q", msg) + } + + inner := errors.New("boom") + assetErr := &AssetSaveError{ + SessionID: "session-1", + Index: 2, + Path: "/tmp/image.png", + Err: inner, + } + if !errors.Is(assetErr, inner) { + t.Fatalf("expected asset save error to unwrap inner error") + } + if !strings.Contains(assetErr.Error(), "image.png") || !strings.Contains(assetErr.Error(), "index 2") { + t.Fatalf("unexpected asset save error message: %q", assetErr.Error()) + } +} + +func minimalPNGBytes() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} + +func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { + t.Parallel() + + base := t.TempDir() + defaultWorkdir := filepath.Join(base, "workspace") + if err := os.MkdirAll(defaultWorkdir, 0o755); err != nil { + t.Fatalf("mkdir default workdir: %v", err) + } + currentWorkdir := filepath.Join(defaultWorkdir, "current") + if err := os.MkdirAll(currentWorkdir, 0o755); err != nil { + t.Fatalf("mkdir current workdir: %v", err) + } + targetWorkdir := filepath.Join(currentWorkdir, "nested") + if err := os.MkdirAll(targetWorkdir, 0o755); err != nil { + t.Fatalf("mkdir nested workdir: %v", err) + } + + store := NewStore(t.TempDir(), defaultWorkdir) + session := NewWithWorkdir("existing", currentWorkdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "update workdir", + DefaultWorkdir: defaultWorkdir, + RequestedWorkdir: "nested", + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if result.Workdir != targetWorkdir { + t.Fatalf("expected target workdir %q, got %q", targetWorkdir, result.Workdir) + } + + loaded, err := store.LoadSession(context.Background(), session.ID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Workdir != targetWorkdir { + t.Fatalf("expected persisted workdir %q, got %q", targetWorkdir, loaded.Workdir) + } +} + +func createSessionForPreparerTest(ctx context.Context, store *SQLiteStore, session Session) error { + _, err := store.CreateSession(ctx, CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState, + ActivatedSkills: session.ActivatedSkills, + Todos: session.Todos, + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + }) + return err +} diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 826ef923..ffc0d6df 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -501,6 +501,71 @@ func (s *SQLiteStore) Stat(ctx context.Context, sessionID string, assetID string return meta, nil } +// DeleteAsset 删除指定会话附件的元数据与二进制文件,缺失目标按幂等处理。 +func (s *SQLiteStore) DeleteAsset(ctx context.Context, sessionID string, assetID string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", sessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + if err := validateStorageID("asset id", assetID); err != nil { + return fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + meta, path, err := s.loadAssetMeta(ctx, sessionID, assetID) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + result, execErr := db.ExecContext(ctx, `DELETE FROM session_assets WHERE session_id = ? AND id = ?`, sessionID, assetID) + if execErr != nil { + return fmt.Errorf("session: delete asset meta %s: %w", assetID, execErr) + } + if affected, affErr := result.RowsAffected(); affErr == nil && affected == 0 && errors.Is(err, os.ErrNotExist) { + return nil + } + + if strings.TrimSpace(meta.ID) == "" { + return nil + } + if removeErr := os.Remove(path); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) { + return fmt.Errorf("session: delete asset file %s: %w", assetID, removeErr) + } + return nil +} + +// DeleteSession 删除会话头、消息、附件元数据,并清理对应附件目录。 +func (s *SQLiteStore) DeleteSession(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", sessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + if _, err := db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, sessionID); err != nil { + return fmt.Errorf("session: delete session %s: %w", sessionID, err) + } + + assetDir := filepath.Join(s.assetsDir, sessionID) + if err := ensurePathWithinBase(s.projectDir, assetDir); err != nil { + return fmt.Errorf("session: resolve assets dir path: %w", err) + } + if err := os.RemoveAll(assetDir); err != nil { + return fmt.Errorf("session: delete assets dir %s: %w", sessionID, err) + } + return nil +} + // ensureDB 懒加载数据库并执行 schema 初始化。 func (s *SQLiteStore) ensureDB(ctx context.Context) (*sql.DB, error) { s.initMu.Lock() diff --git a/internal/tools/todo/write.go b/internal/tools/todo/write.go index c073d72c..8a94812e 100644 --- a/internal/tools/todo/write.go +++ b/internal/tools/todo/write.go @@ -138,6 +138,9 @@ func (t *Tool) Schema() map[string]any { }, "artifacts": map[string]any{ "type": "array", + "items": map[string]any{ + "type": "string", + }, }, "reason": map[string]any{ "type": "string", diff --git a/internal/tools/todo/write_test.go b/internal/tools/todo/write_test.go index 85d38a95..634335b7 100644 --- a/internal/tools/todo/write_test.go +++ b/internal/tools/todo/write_test.go @@ -202,6 +202,20 @@ func TestToolMetadataMethods(t *testing.T) { if _, ok := properties["items"]; !ok { t.Fatalf("Schema() should include items property") } + artifacts, ok := properties["artifacts"].(map[string]any) + if !ok { + t.Fatalf("Schema() artifacts should be object, got %T", properties["artifacts"]) + } + if artifacts["type"] != "array" { + t.Fatalf("Schema() artifacts.type = %+v, want array", artifacts["type"]) + } + items, ok := artifacts["items"].(map[string]any) + if !ok { + t.Fatalf("Schema() artifacts.items should be object, got %T", artifacts["items"]) + } + if items["type"] != "string" { + t.Fatalf("Schema() artifacts.items.type = %+v, want string", items["type"]) + } } func TestToolExecuteActionSequence(t *testing.T) { diff --git a/internal/tui/bootstrap/builder_test.go b/internal/tui/bootstrap/builder_test.go index 7a688b78..b6bf3677 100644 --- a/internal/tui/bootstrap/builder_test.go +++ b/internal/tui/bootstrap/builder_test.go @@ -15,6 +15,26 @@ import ( type testRuntime struct{} +func (r *testRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }, nil +} + +func (r *testRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return r.Run(ctx, agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }) +} + func (r *testRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } @@ -206,6 +226,26 @@ func (f errorFactory) BuildProvider(mode Mode, current ProviderService) (Provide type noopRuntime struct{} +func (r noopRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }, nil +} + +func (r noopRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return r.Run(ctx, agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Workdir: input.Workdir, + }) +} + func (r noopRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } diff --git a/internal/tui/core/app/app.go b/internal/tui/core/app/app.go index 394eaa82..37769e16 100644 --- a/internal/tui/core/app/app.go +++ b/internal/tui/core/app/app.go @@ -89,27 +89,27 @@ type appComponents struct { // appRuntimeState 聚合运行期易变字段,降低 App 顶层字段密度。 type appRuntimeState struct { - codeCopyBlocks map[int]string - pendingCopyID int - deferredEventCmd tea.Cmd - nowFn func() time.Time - lastInputEditAt time.Time - lastPasteLikeAt time.Time - inputBurstStart time.Time - inputBurstCount int - pasteMode bool - activeMessages []providertypes.Message - activities []tuistate.ActivityEntry - fileCandidates []string - modelRefreshID string - focus panel - runProgressValue float64 - runProgressKnown bool - runProgressLabel string - pendingPermission *permissionPromptState - pendingImageAttachments []pendingImageAttachment - currentModelCapabilities modelCapabilityState - providerAddForm *providerAddFormState + codeCopyBlocks map[int]string + pendingCopyID int + deferredEventCmd tea.Cmd + nowFn func() time.Time + lastInputEditAt time.Time + lastPasteLikeAt time.Time + inputBurstStart time.Time + inputBurstCount int + pasteMode bool + activeMessages []providertypes.Message + activities []tuistate.ActivityEntry + fileCandidates []string + modelRefreshID string + focus panel + runProgressValue float64 + runProgressKnown bool + runProgressLabel string + lastUserMessageRunID string + pendingPermission *permissionPromptState + pendingImageAttachments []pendingImageAttachment + providerAddForm *providerAddFormState } type pendingImageAttachment struct { @@ -119,11 +119,6 @@ type pendingImageAttachment struct { Name string } -type modelCapabilityState struct { - supportsImageInput bool - checked bool -} - // providerAddFormState 保存添加新 provider 表单的状态。 type providerAddFormState struct { Step int // 当前聚焦字段在“当前 driver 可见字段列表”中的索引 diff --git a/internal/tui/core/app/input_features.go b/internal/tui/core/app/input_features.go index c52d4a3b..99e164b2 100644 --- a/internal/tui/core/app/input_features.go +++ b/internal/tui/core/app/input_features.go @@ -24,7 +24,6 @@ const ( maxWorkspaceFiles = 4000 maxFileSuggestions = 6 maxImageAttachments = 3 - imageMaxSizeBytes = 5 * 1024 * 1024 // 5 MiB ) type tokenSelector int @@ -37,7 +36,6 @@ const ( var workspaceCommandExecutor = defaultWorkspaceCommandExecutor var readClipboardImage = tuiinfra.ReadClipboardImage var saveClipboardImageToTempFile = tuiinfra.SaveImageToTempFile -var detectImageMimeType = tuiinfra.DetectImageMimeType func isWorkspaceCommandInput(input string) bool { return strings.HasPrefix(strings.TrimSpace(input), workspaceCommandPrefix) @@ -219,6 +217,191 @@ func (a *App) applyImageReference(input string) error { return a.addImageAttachment(path) } +// absorbInlineImageReferences 会把输入文本中的 @image: 令牌吸收到附件队列,并返回移除令牌后的文本。 +// 该实现保留原始空白布局,仅移除命中的图片令牌,避免改变用户提示词语义。 +func (a *App) absorbInlineImageReferences(input string) (string, int, error) { + if strings.TrimSpace(input) == "" { + return strings.TrimSpace(input), 0, nil + } + + var builder strings.Builder + absorbed := 0 + for i := 0; i < len(input); { + imagePath, end, ok := parseInlineImageReferenceAt(input, i) + if ok && looksLikeImagePath(imagePath) { + if err := a.queueImageAttachmentForPrepare(imagePath); err != nil { + return "", absorbed, err + } + absorbed++ + i = end + continue + } + builder.WriteByte(input[i]) + i++ + } + + return strings.TrimSpace(builder.String()), absorbed, nil +} + +// isInlineTokenSpace 判断字符是否属于输入令牌分隔空白字符。 +func isInlineTokenSpace(ch byte) bool { + switch ch { + case ' ', '\t', '\r', '\n': + return true + default: + return false + } +} + +// parseInlineImagePathToken 识别 @image: 形式的图片路径令牌,并映射为待发送路径。 +func (a *App) parseInlineImagePathToken(token string) (string, bool) { + path, _, ok := parseInlineImageReferenceAt(strings.TrimSpace(token), 0) + if !ok { + return "", false + } + path = strings.TrimSpace(path) + if path == "" || !looksLikeImagePath(path) { + return "", false + } + + resolved := path + if !filepath.IsAbs(resolved) { + base := strings.TrimSpace(a.state.CurrentWorkdir) + if base == "" { + return "", false + } + resolved = filepath.Join(base, resolved) + } + return resolved, true +} + +// parseInlineImageReferenceAt 从输入指定位置解析 @image:,支持引号与空格路径。 +func parseInlineImageReferenceAt(input string, start int) (path string, end int, ok bool) { + if start < 0 || start >= len(input) { + return "", 0, false + } + if start > 0 && !isInlineTokenSpace(input[start-1]) { + return "", 0, false + } + if !strings.HasPrefix(input[start:], imageReferencePrefix) { + return "", 0, false + } + + cursor := start + len(imageReferencePrefix) + if cursor >= len(input) { + return "", 0, false + } + + quotedPath, quotedEnd, quoted := readQuotedInlinePath(input, cursor) + if quoted { + if strings.TrimSpace(quotedPath) == "" { + return "", 0, false + } + return strings.TrimSpace(quotedPath), quotedEnd, true + } + + unquotedPath, unquotedEnd := readUnquotedInlinePath(input, cursor) + unquotedPath = strings.TrimSpace(unquotedPath) + if unquotedPath == "" { + return "", 0, false + } + return unquotedPath, unquotedEnd, true +} + +// readQuotedInlinePath 读取带引号路径,支持 \" 和 \' 转义。 +func readQuotedInlinePath(input string, start int) (string, int, bool) { + if start >= len(input) { + return "", 0, false + } + quote := input[start] + if quote != '"' && quote != '\'' { + return "", 0, false + } + var builder strings.Builder + for i := start + 1; i < len(input); i++ { + ch := input[i] + if ch == '\\' && i+1 < len(input) { + next := input[i+1] + if next == quote || next == '\\' { + builder.WriteByte(next) + i++ + continue + } + } + if ch == quote { + return builder.String(), i + 1, true + } + builder.WriteByte(ch) + } + return "", 0, false +} + +// readUnquotedInlinePath 读取非引号路径,遇到空白或换行结束,支持反斜杠转义空白字符。 +func readUnquotedInlinePath(input string, start int) (string, int) { + var builder strings.Builder + end := start + for end < len(input) { + ch := input[end] + if isInlineTokenSpace(ch) { + break + } + if ch == '\\' && end+1 < len(input) { + next := input[end+1] + if isInlineTokenSpace(next) { + builder.WriteByte(next) + end += 2 + continue + } + } + builder.WriteByte(ch) + end++ + } + return builder.String(), end +} + +// queueImageAttachmentForPrepare 将图片路径排队为待发送附件,不在 TUI 层做文件系统和 MIME 硬校验。 +// 真正的可用性校验与错误语义统一在 runtime/session 归一化阶段完成。 +func (a *App) queueImageAttachmentForPrepare(path string) error { + path = strings.TrimSpace(path) + if path == "" { + return fmt.Errorf("image path is empty") + } + if len(a.pendingImageAttachments) >= maxImageAttachments { + return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) + } + + resolved := path + if !filepath.IsAbs(resolved) { + base := strings.TrimSpace(a.state.CurrentWorkdir) + if base != "" { + resolved = filepath.Join(base, resolved) + } + } + absPath, err := filepath.Abs(resolved) + if err != nil { + return fmt.Errorf("invalid image path: %w", err) + } + + a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ + Path: absPath, + MimeType: "", + Size: 0, + Name: filepath.Base(absPath), + }) + a.refreshImageAttachmentDisplay() + return nil +} + +// looksLikeImagePath 使用扩展名快速判断路径是否是常见图片文件。 +func looksLikeImagePath(path string) bool { + switch strings.ToLower(filepath.Ext(strings.TrimSpace(path))) { + case ".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp": + return true + default: + return false + } +} + func (a *App) applyFileReference(path string) error { path = strings.TrimSpace(path) if path == "" { @@ -275,43 +458,12 @@ func extractImageReference(input string) string { } func (a *App) addImageAttachment(path string) error { - path = strings.TrimSpace(path) - if path == "" { - return fmt.Errorf("image path is empty") - } - - if len(a.pendingImageAttachments) >= maxImageAttachments { - return fmt.Errorf("maximum %d image attachments allowed", maxImageAttachments) - } - - absPath, err := filepath.Abs(path) - if err != nil { - return fmt.Errorf("invalid image path: %w", err) - } - - info, err := tuiinfra.GetFileInfo(absPath) - if err != nil { - return fmt.Errorf("cannot read image file: %w", err) - } - - if info.Size() > imageMaxSizeBytes { - return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) + if err := a.queueImageAttachmentForPrepare(path); err != nil { + return err } - - mimeType := detectImageMimeType(absPath) - if mimeType == "" { - return fmt.Errorf("unsupported image format") + if count := len(a.pendingImageAttachments); count > 0 { + a.state.StatusText = fmt.Sprintf("[System] Added image: %s", a.pendingImageAttachments[count-1].Name) } - - a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ - Path: absPath, - MimeType: mimeType, - Size: info.Size(), - Name: filepath.Base(absPath), - }) - - a.refreshImageAttachmentDisplay() - a.state.StatusText = fmt.Sprintf("[System] Added image: %s", filepath.Base(absPath)) return nil } @@ -370,62 +522,13 @@ func (a *App) addImageFromClipboard() error { return fmt.Errorf("no image in clipboard") } - if int64(len(data)) > imageMaxSizeBytes { - return fmt.Errorf("image size exceeds %d MB limit", imageMaxSizeBytes/(1024*1024)) - } - tmpPath, err := saveClipboardImageToTempFile(data, "paste") if err != nil { return fmt.Errorf("failed to save clipboard image: %w", err) } - - mimeType := detectImageMimeType(tmpPath) - if mimeType == "" { - return fmt.Errorf("unsupported image format from clipboard") + if err := a.queueImageAttachmentForPrepare(tmpPath); err != nil { + return err } - - a.pendingImageAttachments = append(a.pendingImageAttachments, pendingImageAttachment{ - Path: tmpPath, - MimeType: mimeType, - Size: int64(len(data)), - Name: "clipboard_image.png", - }) - - a.refreshImageAttachmentDisplay() a.state.StatusText = "[System] Added image from clipboard" return nil } - -func (a *App) checkModelImageSupport() bool { - if a.currentModelCapabilities.checked { - return a.currentModelCapabilities.supportsImageInput - } - - models, err := a.providerSvc.ListModelsSnapshot(context.Background()) - if err != nil { - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = false - return false - } - - for _, m := range models { - if m.ID == a.state.CurrentModel { - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = m.CapabilityHints.ImageInput == "supported" - return a.currentModelCapabilities.supportsImageInput - } - } - - a.currentModelCapabilities.checked = true - a.currentModelCapabilities.supportsImageInput = false - return false -} - -func (a *App) canSendImageInput() bool { - return a.checkModelImageSupport() -} - -// invalidateModelCapabilityCache 在 provider 或 model 变化时清理图片能力缓存,避免复用旧结果。 -func (a *App) invalidateModelCapabilityCache() { - a.currentModelCapabilities = modelCapabilityState{} -} diff --git a/internal/tui/core/app/input_features_test.go b/internal/tui/core/app/input_features_test.go index b38e0edc..5092c898 100644 --- a/internal/tui/core/app/input_features_test.go +++ b/internal/tui/core/app/input_features_test.go @@ -12,19 +12,8 @@ import ( tea "github.com/charmbracelet/bubbletea" "neo-code/internal/config" - configstate "neo-code/internal/config/state" - providertypes "neo-code/internal/provider/types" ) -type snapshotErrProviderService struct { - stubProviderService - err error -} - -func (s snapshotErrProviderService) ListModelsSnapshot(ctx context.Context) ([]providertypes.ModelDescriptor, error) { - return nil, s.err -} - func TestTokenAndReferenceParsing(t *testing.T) { start, end, token, ok := tokenRange(" @file/path", tokenSelectorFirst) if !ok || start != 2 || end != len(" @file/path") || token != "@file/path" { @@ -165,41 +154,6 @@ func TestAddImageAttachmentLimit(t *testing.T) { } } -func TestCanSendImageInputCacheInvalidationOnModelChange(t *testing.T) { - app, _ := newTestApp(t) - providerID := app.state.CurrentProvider - - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, - models: []providertypes.ModelDescriptor{{ - ID: "model-a", - Name: "model-a", - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateSupported, - }, - }}, - } - app.state.CurrentModel = "model-a" - if !app.canSendImageInput() { - t.Fatalf("expected model-a to support images") - } - - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: providerID, Name: providerID}}, - models: []providertypes.ModelDescriptor{{ - ID: "model-b", - Name: "model-b", - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateUnsupported, - }, - }}, - } - app.syncConfigState(config.Config{SelectedProvider: providerID, CurrentModel: "model-b", Workdir: app.state.CurrentWorkdir}) - if app.canSendImageInput() { - t.Fatalf("expected model-b to be unsupported after cache invalidation") - } -} - func TestApplyImageReference(t *testing.T) { app, _ := newTestApp(t) root := t.TempDir() @@ -218,6 +172,182 @@ func TestApplyImageReference(t *testing.T) { } } +func TestAbsorbInlineImageReferences(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + imagePath := filepath.Join(root, "chart.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @image:chart.png 趋势") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected one absorbed image, got %d", absorbed) + } + if normalized != "请分析 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending image attachment, got %d", app.getImageAttachmentCount()) + } +} + +func TestAbsorbInlineImageReferencesRequiresExplicitPrefix(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences("请分析 @chart.png 趋势") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 0 { + t.Fatalf("expected absorbed image count to be 0, got %d", absorbed) + } + if normalized != "请分析 @chart.png 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no pending image attachments") + } +} + +func TestAbsorbInlineImageReferencesKeepsNonImageToken(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences("查看 @README.md 内容") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 0 { + t.Fatalf("expected absorbed image count to be 0, got %d", absorbed) + } + if normalized != "查看 @README.md 内容" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 0 { + t.Fatalf("expected no pending image attachments") + } +} + +func TestAbsorbInlineImageReferencesDoesNotRequireFileExistenceInTUI(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentWorkdir = t.TempDir() + + normalized, absorbed, err := app.absorbInlineImageReferences("处理 @image:not-exist.png") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected one absorbed image, got %d", absorbed) + } + if normalized != "处理" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending attachment") + } + if app.getImageAttachments()[0].MimeType != "" { + t.Fatalf("expected mime type to stay empty before runtime/session validation") + } +} + +func TestAbsorbInlineImageReferencesPreservesWhitespaceLayout(t *testing.T) { + app, _ := newTestApp(t) + app.state.CurrentWorkdir = t.TempDir() + + normalized, absorbed, err := app.absorbInlineImageReferences("A @image:x.png\nB\t @image:y.jpg C") + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 2 { + t.Fatalf("expected absorbed image count to be 2, got %d", absorbed) + } + if normalized != "A \nB\t C" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 2 { + t.Fatalf("expected two pending image attachments") + } +} + +func TestAbsorbInlineImageReferencesSupportsQuotedPathWithSpaces(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + normalized, absorbed, err := app.absorbInlineImageReferences(`请分析 @image:"charts/sales q1.png" 趋势`) + if err != nil { + t.Fatalf("absorbInlineImageReferences() error = %v", err) + } + if absorbed != 1 { + t.Fatalf("expected absorbed image count to be 1, got %d", absorbed) + } + if normalized != "请分析 趋势" { + t.Fatalf("unexpected normalized text: %q", normalized) + } + if app.getImageAttachmentCount() != 1 { + t.Fatalf("expected one pending image attachment") + } + if !strings.HasSuffix(app.getImageAttachments()[0].Path, filepath.FromSlash("charts/sales q1.png")) { + t.Fatalf("unexpected queued path: %q", app.getImageAttachments()[0].Path) + } +} + +func TestParseInlineImagePathToken(t *testing.T) { + app, _ := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + relative, ok := app.parseInlineImagePathToken(`@image:"charts/sales q1.png"`) + if !ok { + t.Fatalf("expected quoted relative token to parse") + } + if relative != filepath.Join(root, filepath.FromSlash("charts/sales q1.png")) { + t.Fatalf("unexpected resolved path: %q", relative) + } + + absolutePath := filepath.Join(root, "abs.png") + absolute, ok := app.parseInlineImagePathToken("@image:" + absolutePath) + if !ok || absolute != absolutePath { + t.Fatalf("expected absolute token to pass through, got %q ok=%v", absolute, ok) + } + + if _, ok := app.parseInlineImagePathToken("@image:notes.txt"); ok { + t.Fatalf("expected non-image token to be rejected") + } + app.state.CurrentWorkdir = "" + if _, ok := app.parseInlineImagePathToken("@image:relative.png"); ok { + t.Fatalf("expected missing workdir to reject relative token") + } + if _, ok := app.parseInlineImagePathToken("not-image-token"); ok { + t.Fatalf("expected invalid token to be rejected") + } +} + +func TestParseInlineImageReferenceAtBranches(t *testing.T) { + if _, _, ok := parseInlineImageReferenceAt("x@image:a.png", 1); ok { + t.Fatalf("expected token without boundary whitespace to be rejected") + } + path, end, ok := parseInlineImageReferenceAt(`@image:folder\ with\ space.png next`, 0) + if !ok { + t.Fatalf("expected escaped-space token to parse") + } + if path != "folder with space.png" || end <= 0 { + t.Fatalf("unexpected escaped path parse result path=%q end=%d", path, end) + } + if _, _, ok := parseInlineImageReferenceAt(`@image:""`, 0); ok { + t.Fatalf("expected empty quoted token to fail") + } +} + func TestGetAndClearImageAttachments(t *testing.T) { app, _ := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ @@ -250,7 +380,6 @@ func TestAddImageFromClipboardSuccess(t *testing.T) { app, _ := newTestApp(t) originalRead := readClipboardImage originalSave := saveClipboardImageToTempFile - originalDetect := detectImageMimeType readClipboardImage = func() ([]byte, error) { return []byte("image-bytes"), nil } @@ -261,11 +390,9 @@ func TestAddImageFromClipboardSuccess(t *testing.T) { } return path, nil } - detectImageMimeType = func(path string) string { return "image/png" } defer func() { readClipboardImage = originalRead saveClipboardImageToTempFile = originalSave - detectImageMimeType = originalDetect }() if err := app.addImageFromClipboard(); err != nil { @@ -280,11 +407,9 @@ func TestAddImageFromClipboardBranches(t *testing.T) { app, _ := newTestApp(t) originalRead := readClipboardImage originalSave := saveClipboardImageToTempFile - originalDetect := detectImageMimeType defer func() { readClipboardImage = originalRead saveClipboardImageToTempFile = originalSave - detectImageMimeType = originalDetect }() readClipboardImage = func() ([]byte, error) { return nil, nil } @@ -292,20 +417,6 @@ func TestAddImageFromClipboardBranches(t *testing.T) { t.Fatalf("expected no image in clipboard error") } - readClipboardImage = func() ([]byte, error) { return make([]byte, imageMaxSizeBytes+1), nil } - if err := app.addImageFromClipboard(); err == nil { - t.Fatalf("expected image size limit error") - } - - readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } - saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { - return filepath.Join(t.TempDir(), "clipboard.bin"), nil - } - detectImageMimeType = func(path string) string { return "" } - if err := app.addImageFromClipboard(); err == nil { - t.Fatalf("expected unsupported image format error") - } - readClipboardImage = func() ([]byte, error) { return []byte("x"), nil } saveClipboardImageToTempFile = func(data []byte, prefix string) (string, error) { return "", errors.New("save failed") @@ -315,31 +426,6 @@ func TestAddImageFromClipboardBranches(t *testing.T) { } } -func TestCheckModelImageSupportErrorAndModelNotFound(t *testing.T) { - app, _ := newTestApp(t) - app.providerSvc = snapshotErrProviderService{ - stubProviderService: stubProviderService{}, - err: errors.New("boom"), - } - if app.checkModelImageSupport() { - t.Fatalf("expected false when provider snapshot fails") - } - if !app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be marked checked after failure") - } - - app.currentModelCapabilities = modelCapabilityState{} - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, - models: []providertypes.ModelDescriptor{{ - ID: "other-model", - }}, - } - if app.checkModelImageSupport() { - t.Fatalf("expected false when current model is missing from snapshot") - } -} - func TestExecuteWorkspaceCommand(t *testing.T) { app, _ := newTestApp(t) original := workspaceCommandExecutor @@ -393,7 +479,7 @@ func TestRunWorkspaceCommandCmd(t *testing.T) { } } -func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { +func TestUpdateSendWithImageAttachmentsRunsThroughPreparePipeline(t *testing.T) { app, runtime := newTestApp(t) root := t.TempDir() imagePath := filepath.Join(root, "queued.png") @@ -403,17 +489,6 @@ func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { if err := app.addImageAttachment(imagePath); err != nil { t.Fatalf("addImageAttachment() error = %v", err) } - app.providerSvc = stubProviderService{ - providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, - models: []providertypes.ModelDescriptor{{ - ID: app.state.CurrentModel, - Name: app.state.CurrentModel, - CapabilityHints: providertypes.ModelCapabilityHints{ - ImageInput: providertypes.ModelCapabilityStateSupported, - }, - }}, - } - app.input.SetValue("hello") app.state.InputText = "hello" @@ -423,18 +498,21 @@ func TestUpdateSendWithImageAttachmentsBlocksUntilSessionAssets(t *testing.T) { } app = model.(App) if app.hasImageAttachments() { - t.Fatalf("expected attachments cleared after unsupported send path") + t.Fatalf("expected attachments cleared after send") } - if app.state.IsAgentRunning { - t.Fatalf("expected image send to be blocked until session assets are available") + if !app.state.IsAgentRunning { + t.Fatalf("expected image send to enter running state") } - if len(app.activeMessages) != 0 { - t.Fatalf("expected no text fallback message in transcript, got %+v", app.activeMessages) + if app.state.StatusText != statusThinking { + t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if len(runtime.runInputs) != 0 { - t.Fatalf("expected no runtime input with image metadata fallback, got %+v", runtime.runInputs) + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) } - if app.state.StatusText != "Image attachments need session asset support" { - t.Fatalf("unexpected status text: %q", app.state.StatusText) + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "" { + t.Fatalf("expected one queued image in prepare input, got %+v", runtime.prepareInputs[0].Images) + } + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime input after prepare, got %+v", runtime.runInputs) } } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 9fea342e..836ad0f3 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -319,7 +319,8 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te effectiveTyped = tea.KeyMsg{Type: tea.KeyEnter, Paste: true} } else { input := strings.TrimSpace(a.input.Value()) - if input == "" || a.isBusy() { + hasImages := a.hasImageAttachments() + if (input == "" && !hasImages) || a.isBusy() { return a, tea.Batch(cmds...) } @@ -411,21 +412,18 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te return a, tea.Batch(cmds...) } - if a.hasImageAttachments() && !a.canSendImageInput() { - a.state.ExecutionError = "current model does not support image input" - a.state.StatusText = "Model does not support images" - a.appendActivity("multimodal", "Image input not supported", fmt.Sprintf("Model %s does not support image input", a.state.CurrentModel), true) - a.clearImageAttachments() + normalizedInput, absorbedImages, err := a.absorbInlineImageReferences(input) + if err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("multimodal", "Failed to absorb inline image reference", err.Error(), true) return a, tea.Batch(cmds...) } - if a.hasImageAttachments() { - a.state.ExecutionError = "image attachments require session asset storage before sending" - a.state.StatusText = "Image attachments need session asset support" - a.appendActivity("multimodal", "Image attachments not sent", "Session asset storage is not available yet; images were not converted to text.", true) - a.clearImageAttachments() - return a, tea.Batch(cmds...) + if absorbedImages > 0 { + input = normalizedInput } + // image capability precheck is intentionally disabled. // 如果不是立即执行的命令,再执行常规的输入重置 a.input.Reset() a.state.InputText = "" @@ -442,12 +440,23 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te a.state.StatusText = statusThinking a.state.CurrentTool = "" - a.activeMessages = append(a.activeMessages, providertypes.Message{Role: roleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart(input)}}) - a.rebuildTranscript() runID := fmt.Sprintf("run-%d", a.now().UnixNano()) a.state.ActiveRunID = runID requestedWorkdir := tuiutils.RequestedWorkdirForRun(a.state.CurrentWorkdir) - cmds = append(cmds, runAgent(a.runtime, runID, a.state.ActiveSessionID, requestedWorkdir, input)) + images := make([]agentruntime.UserImageInput, 0, len(a.pendingImageAttachments)) + for _, attachment := range a.pendingImageAttachments { + images = append(images, agentruntime.UserImageInput{ + Path: attachment.Path, + MimeType: attachment.MimeType, + }) + } + cmds = append(cmds, runAgent(a.runtime, agentruntime.PrepareInput{ + SessionID: a.state.ActiveSessionID, + RunID: runID, + Workdir: requestedWorkdir, + Text: input, + Images: images, + })) a.clearImageAttachments() return a, tea.Batch(cmds...) } @@ -726,6 +735,7 @@ func (a *App) refreshSessionPicker() error { } func (a *App) refreshMessages() error { + a.resetSessionRuntimeState() if strings.TrimSpace(a.state.ActiveSessionID) == "" { a.activeMessages = nil a.clearActivities() @@ -745,6 +755,20 @@ func (a *App) refreshMessages() error { return nil } +// resetSessionRuntimeState 在切换/刷新会话前清理运行态缓存,避免跨会话残留工具与用量展示。 +func (a *App) resetSessionRuntimeState() { + a.state.IsAgentRunning = false + a.state.StreamingReply = false + a.state.CurrentTool = "" + a.state.ActiveRunID = "" + a.lastUserMessageRunID = "" + a.state.ToolStates = nil + a.state.RunContext = tuistate.ContextWindowState{} + a.state.TokenUsage = tuistate.TokenUsageState{} + a.pendingPermission = nil + a.clearRunProgress() +} + func (a *App) activateSelectedSession() error { item, ok := a.sessionPicker.SelectedItem().(sessionItem) if !ok { @@ -789,10 +813,6 @@ func (a *App) syncActiveSessionTitle() { } func (a *App) syncConfigState(cfg config.Config) { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(cfg.SelectedProvider)) || - !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(cfg.CurrentModel)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentProvider = cfg.SelectedProvider a.state.CurrentModel = cfg.CurrentModel if strings.TrimSpace(a.state.CurrentWorkdir) == "" { @@ -868,6 +888,9 @@ type runtimeRunSnapshotSource interface { var runtimeEventHandlerRegistry = map[agentruntime.EventType]func(*App, agentruntime.RuntimeEvent) bool{ agentruntime.EventUserMessage: runtimeEventUserMessageHandler, + agentruntime.EventInputNormalized: runtimeEventInputNormalizedHandler, + agentruntime.EventAssetSaved: runtimeEventAssetSavedHandler, + agentruntime.EventAssetSaveFailed: runtimeEventAssetSaveFailedHandler, agentruntime.EventType(tuiservices.RuntimeEventRunContext): runtimeEventRunContextHandler, agentruntime.EventType(tuiservices.RuntimeEventToolStatus): runtimeEventToolStatusHandler, agentruntime.EventType(tuiservices.RuntimeEventUsage): runtimeEventUsageHandler, @@ -940,8 +963,8 @@ func runtimeEventStopReasonDecidedHandler(a *App, event agentruntime.RuntimeEven // handleRuntimeEvent 通过注册表分发 runtime 事件,避免巨型 switch 膨胀。 func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { - if a.state.ActiveSessionID == "" { - a.state.ActiveSessionID = event.SessionID + if !a.shouldHandleRuntimeEvent(event) { + return false } handler, ok := runtimeEventHandlerRegistry[event.Type] if !ok { @@ -950,17 +973,108 @@ func (a *App) handleRuntimeEvent(event agentruntime.RuntimeEvent) bool { return handler(a, event) } +// shouldHandleRuntimeEvent 校验事件与当前活跃会话/运行上下文的关联,避免跨会话污染 UI 状态。 +func (a *App) shouldHandleRuntimeEvent(event agentruntime.RuntimeEvent) bool { + activeSessionID := strings.TrimSpace(a.state.ActiveSessionID) + eventSessionID := strings.TrimSpace(event.SessionID) + if activeSessionID != "" && eventSessionID != "" && !strings.EqualFold(activeSessionID, eventSessionID) { + return false + } + + activeRunID := strings.TrimSpace(a.state.ActiveRunID) + eventRunID := strings.TrimSpace(event.RunID) + if activeRunID != "" && eventRunID != "" && !strings.EqualFold(activeRunID, eventRunID) { + return false + } + return true +} + // runtimeEventUserMessageHandler 处理用户消息进入运行队列后的状态同步。 -func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) bool { +// runtimeEventInputNormalizedHandler 处理输入归一化完成事件并更新运行态提示。 +func runtimeEventInputNormalizedHandler(a *App, event agentruntime.RuntimeEvent) bool { if strings.TrimSpace(event.RunID) != "" { a.state.ActiveRunID = strings.TrimSpace(event.RunID) } + payload, ok := event.Payload.(agentruntime.InputNormalizedPayload) + if !ok { + return false + } + if payload.ImageCount > 0 { + a.appendActivity( + "multimodal", + "Input normalized", + fmt.Sprintf("text=%d chars, images=%d", payload.TextLength, payload.ImageCount), + false, + ) + } + return false +} + +// runtimeEventAssetSavedHandler 处理附件保存成功事件并写入活动面板。 +func runtimeEventAssetSavedHandler(a *App, event agentruntime.RuntimeEvent) bool { + payload, ok := event.Payload.(agentruntime.AssetSavedPayload) + if !ok { + return false + } + detail := strings.TrimSpace(payload.AssetID) + if detail == "" { + detail = "asset saved" + } + if strings.TrimSpace(payload.Path) != "" { + detail = fmt.Sprintf("%s (%s)", detail, filepath.Base(payload.Path)) + } + a.appendActivity("multimodal", "Saved attachment", detail, false) + return false +} + +// runtimeEventAssetSaveFailedHandler 处理附件保存失败事件并同步错误状态。 +func runtimeEventAssetSaveFailedHandler(a *App, event agentruntime.RuntimeEvent) bool { + payload, ok := event.Payload.(agentruntime.AssetSaveFailedPayload) + if !ok { + return false + } + message := strings.TrimSpace(payload.Message) + if message == "" { + message = "failed to save attachment" + } + a.state.ExecutionError = message + a.state.StatusText = message + a.appendActivity("multimodal", "Failed to save attachment", message, true) + return false +} + +func runtimeEventUserMessageHandler(a *App, event agentruntime.RuntimeEvent) bool { + runID := strings.TrimSpace(event.RunID) + if runID != "" { + a.state.ActiveRunID = runID + } + if sessionID := strings.TrimSpace(event.SessionID); sessionID != "" { + a.state.ActiveSessionID = sessionID + } a.state.StatusText = statusThinking a.state.StreamingReply = false a.state.CurrentTool = "" a.state.ExecutionError = "" a.setRunProgress(0.15, "Queued") - return false + payload, ok := event.Payload.(providertypes.Message) + if !ok { + return false + } + content := renderMessagePartsForDisplay(payload.Parts) + if strings.TrimSpace(content) == "" { + return false + } + if runID != "" && strings.EqualFold(a.lastUserMessageRunID, runID) { + return false + } + a.activeMessages = append(a.activeMessages, providertypes.Message{ + Role: roleUser, + Parts: providertypes.CloneParts(payload.Parts), + }) + if runID != "" { + a.lastUserMessageRunID = runID + } + return true } // runtimeEventRunContextHandler 处理 runtime 上下文事件并回填界面状态。 @@ -971,19 +1085,16 @@ func runtimeEventRunContextHandler(a *App, event agentruntime.RuntimeEvent) bool } mapped := tuiservices.MapRunContextPayload(event.RunID, event.SessionID, payload) a.state.RunContext = mapped + if strings.TrimSpace(mapped.SessionID) != "" { + a.state.ActiveSessionID = strings.TrimSpace(mapped.SessionID) + } if strings.TrimSpace(mapped.RunID) != "" { a.state.ActiveRunID = mapped.RunID } if strings.TrimSpace(mapped.Provider) != "" { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentProvider), strings.TrimSpace(mapped.Provider)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentProvider = mapped.Provider } if strings.TrimSpace(mapped.Model) != "" { - if !strings.EqualFold(strings.TrimSpace(a.state.CurrentModel), strings.TrimSpace(mapped.Model)) { - a.invalidateModelCapabilityCache() - } a.state.CurrentModel = mapped.Model } if strings.TrimSpace(mapped.Workdir) != "" { @@ -1112,7 +1223,7 @@ func runtimeEventAgentDoneHandler(a *App, event agentruntime.RuntimeEvent) bool } // runtimeEventRunCanceledHandler 处理运行取消事件。 -func runtimeEventRunCanceledHandler(a *App, event agentruntime.RuntimeEvent) bool { +func runtimeEventRunCanceledHandler(a *App) bool { a.state.IsAgentRunning = false a.state.StreamingReply = false a.state.CurrentTool = "" @@ -1859,6 +1970,7 @@ func (a *App) startDraftSession() { a.state.ExecutionError = "" a.state.CurrentTool = "" a.state.ActiveRunID = "" + a.lastUserMessageRunID = "" a.state.ToolStates = nil a.state.RunContext = tuistate.ContextWindowState{} a.state.TokenUsage = tuistate.TokenUsageState{} @@ -1895,15 +2007,10 @@ func ListenForRuntimeEvent(sub <-chan agentruntime.RuntimeEvent) tea.Cmd { ) } -func runAgent(runtime agentruntime.Runtime, runID string, sessionID string, workdir string, content string) tea.Cmd { - return tuiservices.RunAgentCmd( +func runAgent(runtime agentruntime.Runtime, input agentruntime.PrepareInput) tea.Cmd { + return tuiservices.RunSubmitCmd( runtime, - agentruntime.UserInput{ - SessionID: sessionID, - RunID: strings.TrimSpace(runID), - Parts: []providertypes.ContentPart{providertypes.NewTextPart(content)}, - Workdir: workdir, - }, + input, func(err error) tea.Msg { return runFinishedMsg{Err: err} }, ) } diff --git a/internal/tui/core/app/update_permission_test.go b/internal/tui/core/app/update_permission_test.go index 47c18475..5b24ef62 100644 --- a/internal/tui/core/app/update_permission_test.go +++ b/internal/tui/core/app/update_permission_test.go @@ -23,6 +23,20 @@ type permissionTestRuntime struct { lastResolved agentruntime.PermissionResolutionInput } +func (r *permissionTestRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return agentruntime.UserInput{ + SessionID: input.SessionID, + RunID: input.RunID, + Parts: nil, + Workdir: input.Workdir, + }, nil +} + +func (r *permissionTestRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + _, err := r.PrepareUserInput(ctx, input) + return err +} + func (r *permissionTestRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { return nil } diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 8e571af0..d6f7725d 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -1,10 +1,13 @@ package tui import ( + "strings" "testing" + providertypes "neo-code/internal/provider/types" agentruntime "neo-code/internal/runtime" "neo-code/internal/runtime/controlplane" + tuiservices "neo-code/internal/tui/services" ) func TestRuntimeEventPhaseChangedHandlerBranches(t *testing.T) { @@ -131,3 +134,154 @@ func TestRuntimeEventHandlerRegistryContainsRenamedEvents(t *testing.T) { t.Fatalf("expected compact_applied handler to be registered") } } + +func TestShouldHandleRuntimeEventFiltersBySessionAndRun(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + app.state.ActiveSessionID = "session-active" + app.state.ActiveRunID = "run-active" + + if app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-other", + RunID: "run-active", + }) { + t.Fatalf("expected mismatched session event to be ignored") + } + if app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-active", + RunID: "run-other", + }) { + t.Fatalf("expected mismatched run event to be ignored") + } + if !app.shouldHandleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAgentChunk, + SessionID: "session-active", + RunID: "run-active", + }) { + t.Fatalf("expected matched event to be handled") + } +} + +func TestRuntimeEventMultimodalHandlers(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + + if handled := runtimeEventInputNormalizedHandler(&app, agentruntime.RuntimeEvent{Payload: "bad"}); handled { + t.Fatalf("expected invalid normalized payload to return false") + } + runtimeEventInputNormalizedHandler(&app, agentruntime.RuntimeEvent{ + RunID: "run-1", + Payload: agentruntime.InputNormalizedPayload{ + TextLength: 12, + ImageCount: 2, + }, + }) + if app.state.ActiveRunID != "run-1" { + t.Fatalf("expected active run id to be updated, got %q", app.state.ActiveRunID) + } + if len(app.activities) == 0 { + t.Fatalf("expected input normalized activity to be appended") + } + last := app.activities[len(app.activities)-1] + if last.Title != "Input normalized" || !strings.Contains(last.Detail, "images=2") { + t.Fatalf("unexpected normalized activity: %+v", last) + } + + before := len(app.activities) + runtimeEventAssetSavedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSavedPayload{ + AssetID: "asset-1", + Path: "/tmp/chart.png", + }, + }) + if len(app.activities) != before+1 { + t.Fatalf("expected saved attachment activity appended") + } + last = app.activities[len(app.activities)-1] + if last.Title != "Saved attachment" || !strings.Contains(last.Detail, "chart.png") { + t.Fatalf("unexpected asset saved activity: %+v", last) + } + if handled := runtimeEventAssetSavedHandler(&app, agentruntime.RuntimeEvent{Payload: 123}); handled { + t.Fatalf("expected invalid asset_saved payload to return false") + } + + runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSaveFailedPayload{Message: " failed "}, + }) + if app.state.ExecutionError != "failed" || app.state.StatusText != "failed" { + t.Fatalf("expected failed status to be surfaced, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + } + last = app.activities[len(app.activities)-1] + if !last.IsError || last.Title != "Failed to save attachment" { + t.Fatalf("unexpected asset save failed activity: %+v", last) + } + runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.AssetSaveFailedPayload{}, + }) + if app.state.ExecutionError != "failed to save attachment" || app.state.StatusText != "failed to save attachment" { + t.Fatalf("expected default failed message, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + } + if handled := runtimeEventAssetSaveFailedHandler(&app, agentruntime.RuntimeEvent{Payload: true}); handled { + t.Fatalf("expected invalid asset_save_failed payload to return false") + } +} + +func TestHandleRuntimeEventRoutesByRegistryWithoutBindingTransientSession(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + handled := app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventAssetSaved, + SessionID: "session-1", + Payload: agentruntime.AssetSavedPayload{AssetID: "asset-1"}, + }) + if handled { + t.Fatalf("expected asset_saved handler to return false") + } + if app.state.ActiveSessionID != "" { + t.Fatalf("expected active session to stay empty for non-stable event, got %q", app.state.ActiveSessionID) + } + if len(app.activities) == 0 || app.activities[len(app.activities)-1].Title != "Saved attachment" { + t.Fatalf("expected saved attachment activity") + } + + if app.handleRuntimeEvent(agentruntime.RuntimeEvent{Type: "unknown_event", SessionID: "session-1"}) { + t.Fatalf("expected unknown event handler result to be false") + } +} + +func TestHandleRuntimeEventBindsSessionFromStableEvents(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + + app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventUserMessage, + SessionID: "session-user", + RunID: "run-1", + Payload: providertypes.Message{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }, + }) + if app.state.ActiveSessionID != "session-user" { + t.Fatalf("expected active session from user_message, got %q", app.state.ActiveSessionID) + } + + app.state.ActiveSessionID = "" + app.handleRuntimeEvent(agentruntime.RuntimeEvent{ + Type: agentruntime.EventType(tuiservices.RuntimeEventRunContext), + SessionID: "session-context", + Payload: tuiservices.RuntimeRunContextPayload{ + Provider: "openai", + Model: "gpt-5.4", + }, + }) + if app.state.ActiveSessionID != "session-context" { + t.Fatalf("expected active session from run_context, got %q", app.state.ActiveSessionID) + } +} diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 8ff0d3a5..269f33f4 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -80,6 +80,9 @@ func (s stubProviderService) SetCurrentModel(ctx context.Context, modelID string type stubRuntime struct { events chan agentruntime.RuntimeEvent + prepareInputs []agentruntime.PrepareInput + prepareErr error + preparedOutput agentruntime.UserInput runInputs []agentruntime.UserInput resolveCalls []agentruntime.PermissionResolutionInput resolveErr error @@ -101,6 +104,38 @@ func newStubRuntime() *stubRuntime { return &stubRuntime{events: make(chan agentruntime.RuntimeEvent)} } +func (s *stubRuntime) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + s.prepareInputs = append(s.prepareInputs, input) + if s.prepareErr != nil { + return agentruntime.UserInput{}, s.prepareErr + } + if len(s.preparedOutput.Parts) > 0 { + return s.preparedOutput, nil + } + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + sessionID = "session-prepared" + } + content := strings.TrimSpace(input.Text) + if content == "" { + content = "image input" + } + return agentruntime.UserInput{ + SessionID: sessionID, + RunID: strings.TrimSpace(input.RunID), + Parts: []providertypes.ContentPart{providertypes.NewTextPart(content)}, + Workdir: strings.TrimSpace(input.Workdir), + }, nil +} + +func (s *stubRuntime) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + prepared, err := s.PrepareUserInput(ctx, input) + if err != nil { + return err + } + return s.Run(ctx, prepared) +} + func (s *stubRuntime) Run(ctx context.Context, input agentruntime.UserInput) error { s.runInputs = append(s.runInputs, input) return nil @@ -1457,6 +1492,39 @@ func TestRuntimeEventUserMessageHandler(t *testing.T) { } } +func TestRuntimeEventUserMessageHandlerDeduplicatesByRunID(t *testing.T) { + app, _ := newTestApp(t) + payload := providertypes.Message{ + Role: roleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("same content")}, + } + event := agentruntime.RuntimeEvent{RunID: "run-1", Payload: payload} + + handled := runtimeEventUserMessageHandler(&app, event) + if !handled { + t.Fatalf("expected first user message to be rendered") + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected one user message, got %d", len(app.activeMessages)) + } + + handled = runtimeEventUserMessageHandler(&app, event) + if handled { + t.Fatalf("expected duplicate run id to be ignored") + } + if len(app.activeMessages) != 1 { + t.Fatalf("expected one user message after duplicate event, got %d", len(app.activeMessages)) + } + + handled = runtimeEventUserMessageHandler(&app, agentruntime.RuntimeEvent{RunID: "run-2", Payload: payload}) + if !handled { + t.Fatalf("expected same content with new run id to be rendered") + } + if len(app.activeMessages) != 2 { + t.Fatalf("expected two user messages after new run id, got %d", len(app.activeMessages)) + } +} + func TestRuntimeEventRunContextHandler(t *testing.T) { app, _ := newTestApp(t) payload := tuiservices.RuntimeRunContextPayload{ @@ -1474,44 +1542,6 @@ func TestRuntimeEventRunContextHandler(t *testing.T) { } } -func TestRuntimeEventRunContextHandlerInvalidatesModelCapabilityCache(t *testing.T) { - app, _ := newTestApp(t) - app.state.CurrentProvider = "provider-a" - app.state.CurrentModel = "model-a" - app.currentModelCapabilities = modelCapabilityState{ - checked: true, - supportsImageInput: true, - } - - payload := tuiservices.RuntimeRunContextPayload{ - Provider: "provider-b", - Model: "model-b", - } - _ = runtimeEventRunContextHandler(&app, agentruntime.RuntimeEvent{Payload: payload}) - if app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be invalidated when provider/model changes") - } -} - -func TestSyncConfigStateInvalidatesModelCapabilityCache(t *testing.T) { - app, _ := newTestApp(t) - app.state.CurrentProvider = "provider-a" - app.state.CurrentModel = "model-a" - app.currentModelCapabilities = modelCapabilityState{ - checked: true, - supportsImageInput: true, - } - - app.syncConfigState(config.Config{ - SelectedProvider: "provider-b", - CurrentModel: "model-b", - Workdir: app.state.CurrentWorkdir, - }) - if app.currentModelCapabilities.checked { - t.Fatalf("expected capability cache to be invalidated") - } -} - func TestUpdatePasteImageShortcutFailure(t *testing.T) { app, _ := newTestApp(t) model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyCtrlV}) @@ -1563,8 +1593,8 @@ func TestUpdateEnterImageReferencePath(t *testing.T) { } } -func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { - app, _ := newTestApp(t) +func TestUpdateSendWithUnsupportedImageInputDoesNotPreBlock(t *testing.T) { + app, runtime := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, } @@ -1586,25 +1616,25 @@ func TestUpdateSendWithUnsupportedImageInput(t *testing.T) { _ = cmd() } app = model.(App) - if app.state.IsAgentRunning { - t.Fatalf("expected send to be blocked for unsupported model image input") + if !app.state.IsAgentRunning { + t.Fatalf("expected send not to be pre-blocked by model capability hints") } if app.hasImageAttachments() { - t.Fatalf("expected pending image attachments to be cleared on unsupported model") + t.Fatalf("expected pending image attachments to be cleared after send") } - if app.state.StatusText != "Model does not support images" { + if app.state.StatusText != statusThinking { t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if app.input.Value() != "hello" { - t.Fatalf("expected input to be preserved when send is blocked, got %q", app.input.Value()) + if app.input.Value() != "" || app.state.InputText != "" { + t.Fatalf("expected input to reset after send, got input=%q state=%q", app.input.Value(), app.state.InputText) } - if app.state.InputText != "hello" { - t.Fatalf("expected state input text to be preserved, got %q", app.state.InputText) + if len(runtime.prepareInputs) != 1 || len(runtime.prepareInputs[0].Images) != 1 { + t.Fatalf("expected image to flow into prepare pipeline, got %+v", runtime.prepareInputs) } } -func TestUpdateSendWithImageAttachmentsWithoutSessionAssets(t *testing.T) { - app, _ := newTestApp(t) +func TestUpdateSendWithImageAttachmentsUsesPreparePipeline(t *testing.T) { + app, runtime := newTestApp(t) app.pendingImageAttachments = []pendingImageAttachment{ {Name: "a.png", MimeType: "image/png", Path: "/tmp/a.png", Size: 1}, } @@ -1626,20 +1656,72 @@ func TestUpdateSendWithImageAttachmentsWithoutSessionAssets(t *testing.T) { _ = cmd() } app = model.(App) - if app.state.IsAgentRunning { - t.Fatalf("expected send to be blocked when session assets are unavailable") + if !app.state.IsAgentRunning { + t.Fatalf("expected send to enter running state") } if app.hasImageAttachments() { - t.Fatalf("expected pending image attachments to be cleared when storage is unavailable") + t.Fatalf("expected pending image attachments to be cleared after send") } - if app.state.StatusText != "Image attachments need session asset support" { + if app.state.StatusText != statusThinking { t.Fatalf("unexpected status text: %q", app.state.StatusText) } - if app.input.Value() != "hello" { - t.Fatalf("expected input to be preserved when send is blocked, got %q", app.input.Value()) + if app.input.Value() != "" { + t.Fatalf("expected input to be reset after send, got %q", app.input.Value()) + } + if app.state.InputText != "" { + t.Fatalf("expected state input text to reset after send, got %q", app.state.InputText) + } + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) + } + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "image/png" { + t.Fatalf("expected image metadata to flow through prepare input, got %+v", runtime.prepareInputs[0].Images) + } + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime run input, got %+v", runtime.runInputs) + } +} + +func TestUpdateSendWithInlineImageReferenceUsesPreparePipeline(t *testing.T) { + app, runtime := newTestApp(t) + root := t.TempDir() + app.state.CurrentWorkdir = root + + imagePath := filepath.Join(root, "burn.png") + if err := os.WriteFile(imagePath, []byte("png"), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + app.providerSvc = stubProviderService{ + providers: []configstate.ProviderOption{{ID: app.state.CurrentProvider, Name: app.state.CurrentProvider}}, + models: []providertypes.ModelDescriptor{{ + ID: app.state.CurrentModel, + Name: app.state.CurrentModel, + CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }, + }}, + } + + app.input.SetValue("请分析 @image:burn.png") + app.state.InputText = "请分析 @image:burn.png" + + model, cmd := app.Update(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + _ = cmd() + } + app = model.(App) + + if len(runtime.prepareInputs) != 1 { + t.Fatalf("expected one prepare input, got %+v", runtime.prepareInputs) + } + if runtime.prepareInputs[0].Text != "请分析" { + t.Fatalf("expected inline image token removed from text, got %q", runtime.prepareInputs[0].Text) + } + if len(runtime.prepareInputs[0].Images) != 1 || runtime.prepareInputs[0].Images[0].MimeType != "" { + t.Fatalf("expected one promoted image in prepare input, got %+v", runtime.prepareInputs[0].Images) } - if app.state.InputText != "hello" { - t.Fatalf("expected state input text to be preserved, got %q", app.state.InputText) + if len(runtime.runInputs) != 1 { + t.Fatalf("expected one runtime run input, got %+v", runtime.runInputs) } } @@ -1778,7 +1860,7 @@ func TestRuntimeEventAgentChunkHandler(t *testing.T) { func TestRuntimeEventRunCanceledHandler(t *testing.T) { app, _ := newTestApp(t) app.state.ActiveRunID = "run-3" - runtimeEventRunCanceledHandler(&app, agentruntime.RuntimeEvent{}) + runtimeEventRunCanceledHandler(&app) if app.state.StatusText != statusCanceled { t.Fatalf("expected canceled status") } diff --git a/internal/tui/services/runtime_service.go b/internal/tui/services/runtime_service.go index d4982d4d..16d5caa1 100644 --- a/internal/tui/services/runtime_service.go +++ b/internal/tui/services/runtime_service.go @@ -16,6 +16,12 @@ type Runner interface { Run(ctx context.Context, input agentruntime.UserInput) error } +// PreparedRunner 定义“输入归一化 + run”链路所需最小能力。 +// Submitter 定义 runtime 单入口提交所需的最小能力。 +type Submitter interface { + Submit(ctx context.Context, input agentruntime.PrepareInput) error +} + // Compactor 定义执行 runtime compact 所需最小能力。 type Compactor interface { Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) @@ -53,6 +59,15 @@ func RunAgentCmd( } } +// RunPreparedAgentCmd 先执行输入归一化,再执行 runtime run,并将结果映射为 UI 消息。 +// RunSubmitCmd 执行 runtime 单入口提交,并将结果映射为 UI 消息。 +func RunSubmitCmd(runtime Submitter, input agentruntime.PrepareInput, doneMsg func(error) tea.Msg) tea.Cmd { + return func() tea.Msg { + err := runtime.Submit(context.Background(), input) + return doneMsg(err) + } +} + // RunCompactCmd 执行 runtime compact,并将结果映射为 UI 消息。 func RunCompactCmd( runtime Compactor, diff --git a/internal/tui/services/services_test.go b/internal/tui/services/services_test.go index 42b5eec4..bb5eefcb 100644 --- a/internal/tui/services/services_test.go +++ b/internal/tui/services/services_test.go @@ -26,6 +26,16 @@ func (s *stubRunner) Run(ctx context.Context, input agentruntime.UserInput) erro return s.err } +type stubSubmitter struct { + lastInput agentruntime.PrepareInput + err error +} + +func (s *stubSubmitter) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + s.lastInput = input + return s.err +} + type stubCompactor struct { lastInput agentruntime.CompactInput err error @@ -105,6 +115,24 @@ func TestRunAgentCmd(t *testing.T) { } } +func TestRunSubmitCmd(t *testing.T) { + runner := &stubSubmitter{err: errors.New("run failed")} + prepareInput := agentruntime.PrepareInput{ + SessionID: "s1", + RunID: "run-1", + Workdir: "D:/", + Text: "hello", + Images: []agentruntime.UserImageInput{{Path: "C:/a.png", MimeType: "image/png"}}, + } + msg := RunSubmitCmd(runner, prepareInput, func(err error) tea.Msg { return err })() + if runner.lastInput.RunID != "run-1" || len(runner.lastInput.Images) != 1 { + t.Fatalf("unexpected submit input: %+v", runner.lastInput) + } + if err, ok := msg.(error); !ok || err == nil || err.Error() != "run failed" { + t.Fatalf("expected forwarded run error, got %T %#v", msg, msg) + } +} + func TestRunCompactCmd(t *testing.T) { compactor := &stubCompactor{err: errors.New("compact failed")} input := agentruntime.CompactInput{SessionID: "s2"} diff --git a/internal/updater/updater.go b/internal/updater/updater.go new file mode 100644 index 00000000..4166b643 --- /dev/null +++ b/internal/updater/updater.go @@ -0,0 +1,252 @@ +package updater + +import ( + "context" + "errors" + "fmt" + "regexp" + "runtime" + "strings" + + selfupdate "github.com/creativeprojects/go-selfupdate" + + "neo-code/internal/version" +) + +const ( + repositoryOwner = "1024XEngineer" + repositoryName = "neo-code" + checksumFilename = "checksums.txt" +) + +var ( + runtimeGOOS = runtime.GOOS + runtimeGOARCH = runtime.GOARCH +) + +var ( + newClient = func(config selfupdate.Config) (updateClient, error) { + updater, err := selfupdate.NewUpdater(config) + if err != nil { + return nil, err + } + return selfupdateClient{updater: updater}, nil + } + resolveExecutablePath = selfupdate.ExecutablePath +) + +type assetTarget struct { + OSToken string + ArchToken string + Ext string + AssetName string +} + +type releaseView interface { + Version() string + GreaterThan(other string) bool +} + +type updateClient interface { + DetectLatest(ctx context.Context, repository selfupdate.Repository) (releaseView, bool, error) + UpdateTo(ctx context.Context, rel releaseView, cmdPath string) error +} + +type selfupdateClient struct { + updater *selfupdate.Updater +} + +type selfupdateRelease struct { + release *selfupdate.Release +} + +// CheckOptions 描述静默检测新版本时的输入参数。 +type CheckOptions struct { + CurrentVersion string + IncludePrerelease bool +} + +// CheckResult 表示静默检测流程返回的版本信息。 +type CheckResult struct { + CurrentVersion string + LatestVersion string + HasUpdate bool +} + +// UpdateOptions 描述手动更新命令的输入参数。 +type UpdateOptions struct { + CurrentVersion string + IncludePrerelease bool +} + +// UpdateResult 表示手动更新流程的最终结果。 +type UpdateResult struct { + CurrentVersion string + LatestVersion string + Updated bool +} + +// CheckLatest 按当前平台资产规则检测最新版本,不做本地文件替换。 +func CheckLatest(ctx context.Context, opts CheckOptions) (CheckResult, error) { + currentVersion := normalizeCurrentVersion(opts.CurrentVersion) + target, err := resolveAssetTarget(runtimeGOOS, runtimeGOARCH) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + client, err := newClient(buildSelfupdateConfig(target, opts.IncludePrerelease)) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + repository := selfupdate.NewRepositorySlug(repositoryOwner, repositoryName) + release, found, err := client.DetectLatest(ctx, repository) + if err != nil { + return CheckResult{CurrentVersion: currentVersion}, err + } + + result := CheckResult{CurrentVersion: currentVersion} + if !found || release == nil { + return result, nil + } + + result.LatestVersion = strings.TrimSpace(release.Version()) + if result.LatestVersion == "" { + return result, nil + } + + if version.IsSemverRelease(currentVersion) { + result.HasUpdate = release.GreaterThan(currentVersion) + } + return result, nil +} + +// DoUpdate 下载并校验最新版本后原地替换当前可执行文件。 +func DoUpdate(ctx context.Context, opts UpdateOptions) (UpdateResult, error) { + currentVersion := normalizeCurrentVersion(opts.CurrentVersion) + target, err := resolveAssetTarget(runtimeGOOS, runtimeGOARCH) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + + client, err := newClient(buildSelfupdateConfig(target, opts.IncludePrerelease)) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + + repository := selfupdate.NewRepositorySlug(repositoryOwner, repositoryName) + release, found, err := client.DetectLatest(ctx, repository) + if err != nil { + return UpdateResult{CurrentVersion: currentVersion}, err + } + if !found || release == nil { + return UpdateResult{CurrentVersion: currentVersion}, errors.New("updater: no release asset found for current platform") + } + + latestVersion := strings.TrimSpace(release.Version()) + result := UpdateResult{ + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + } + + if version.IsSemverRelease(currentVersion) && !release.GreaterThan(currentVersion) { + return result, nil + } + + executablePath, err := resolveExecutablePath() + if err != nil { + return result, err + } + + if err := client.UpdateTo(ctx, release, executablePath); err != nil { + return result, err + } + + result.Updated = true + return result, nil +} + +// DetectLatest 调用底层 go-selfupdate 客户端获取最新版本信息。 +func (c selfupdateClient) DetectLatest(ctx context.Context, repository selfupdate.Repository) (releaseView, bool, error) { + release, found, err := c.updater.DetectLatest(ctx, repository) + if err != nil || !found || release == nil { + return nil, found, err + } + return selfupdateRelease{release: release}, true, nil +} + +// UpdateTo 委托 go-selfupdate 完成原地替换流程,不追加平台分支逻辑。 +func (c selfupdateClient) UpdateTo(ctx context.Context, rel releaseView, cmdPath string) error { + typed, ok := rel.(selfupdateRelease) + if !ok || typed.release == nil { + return errors.New("updater: unsupported release type") + } + return c.updater.UpdateTo(ctx, typed.release, cmdPath) +} + +// Version 返回底层 release 的语义化版本字符串。 +func (r selfupdateRelease) Version() string { + return strings.TrimSpace(r.release.Version()) +} + +// GreaterThan 判断底层 release 是否高于指定版本。 +func (r selfupdateRelease) GreaterThan(other string) bool { + return r.release.GreaterThan(other) +} + +// normalizeCurrentVersion 归一化当前版本输入并处理空值回退。 +func normalizeCurrentVersion(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "dev" + } + return trimmed +} + +// buildSelfupdateConfig 构建严格资产匹配与 checksum 校验配置。 +func buildSelfupdateConfig(target assetTarget, includePrerelease bool) selfupdate.Config { + return selfupdate.Config{ + OS: target.OSToken, + Arch: target.ArchToken, + Filters: []string{"^" + regexp.QuoteMeta(target.AssetName) + "$"}, + Validator: &selfupdate.ChecksumValidator{UniqueFilename: checksumFilename}, + Prerelease: includePrerelease, + } +} + +// resolveAssetTarget 按 GoReleaser 产物命名约束生成当前平台目标资产名。 +func resolveAssetTarget(goos string, goarch string) (assetTarget, error) { + var osToken string + switch strings.ToLower(strings.TrimSpace(goos)) { + case "linux": + osToken = "Linux" + case "darwin": + osToken = "Darwin" + case "windows": + osToken = "Windows" + default: + return assetTarget{}, fmt.Errorf("updater: unsupported os %q", goos) + } + + var archToken string + switch strings.ToLower(strings.TrimSpace(goarch)) { + case "amd64": + archToken = "x86_64" + case "arm64": + archToken = "arm64" + default: + return assetTarget{}, fmt.Errorf("updater: unsupported arch %q", goarch) + } + + ext := "tar.gz" + if osToken == "Windows" { + ext = "zip" + } + + return assetTarget{ + OSToken: osToken, + ArchToken: archToken, + Ext: ext, + AssetName: fmt.Sprintf("neocode_%s_%s.%s", osToken, archToken, ext), + }, nil +} diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go new file mode 100644 index 00000000..b01f986c --- /dev/null +++ b/internal/updater/updater_test.go @@ -0,0 +1,703 @@ +package updater + +import ( + "bytes" + "context" + "errors" + "io" + "regexp" + "testing" + "time" + + selfupdate "github.com/creativeprojects/go-selfupdate" +) + +type fakeRelease struct { + version string + greaterFn func(string) bool +} + +func (r fakeRelease) Version() string { + return r.version +} + +func (r fakeRelease) GreaterThan(other string) bool { + if r.greaterFn != nil { + return r.greaterFn(other) + } + return false +} + +type fakeClient struct { + release releaseView + found bool + detectErr error + updateErr error + updateCalls int + lastUpdatePath string +} + +func (c *fakeClient) DetectLatest(context.Context, selfupdate.Repository) (releaseView, bool, error) { + return c.release, c.found, c.detectErr +} + +func (c *fakeClient) UpdateTo(_ context.Context, rel releaseView, cmdPath string) error { + _ = rel + c.updateCalls++ + c.lastUpdatePath = cmdPath + return c.updateErr +} + +type stubSource struct { + releases []selfupdate.SourceRelease + listErr error +} + +func (s stubSource) ListReleases(context.Context, selfupdate.Repository) ([]selfupdate.SourceRelease, error) { + if s.listErr != nil { + return nil, s.listErr + } + return s.releases, nil +} + +func (s stubSource) DownloadReleaseAsset(context.Context, *selfupdate.Release, int64) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(nil)), nil +} + +type stubSourceRelease struct { + id int64 + tagName string + draft bool + prerelease bool + assets []selfupdate.SourceAsset +} + +func (r stubSourceRelease) GetID() int64 { return r.id } +func (r stubSourceRelease) GetTagName() string { return r.tagName } +func (r stubSourceRelease) GetDraft() bool { return r.draft } +func (r stubSourceRelease) GetPrerelease() bool { return r.prerelease } +func (r stubSourceRelease) GetPublishedAt() time.Time { return time.Now() } +func (r stubSourceRelease) GetReleaseNotes() string { return "" } +func (r stubSourceRelease) GetName() string { return r.tagName } +func (r stubSourceRelease) GetURL() string { return "https://example.com/release" } +func (r stubSourceRelease) GetAssets() []selfupdate.SourceAsset { + return r.assets +} + +type stubSourceAsset struct { + id int64 + name string + size int +} + +func (a stubSourceAsset) GetID() int64 { return a.id } +func (a stubSourceAsset) GetName() string { return a.name } +func (a stubSourceAsset) GetSize() int { return a.size } +func (a stubSourceAsset) GetBrowserDownloadURL() string { return "https://example.com/asset" } + +func TestResolveAssetTarget(t *testing.T) { + tests := []struct { + name string + goos string + goarch string + wantOS string + wantArch string + wantExt string + wantAsset string + expectErrMsg string + }{ + { + name: "linux amd64", + goos: "linux", + goarch: "amd64", + wantOS: "Linux", + wantArch: "x86_64", + wantExt: "tar.gz", + wantAsset: "neocode_Linux_x86_64.tar.gz", + }, + { + name: "darwin arm64", + goos: "darwin", + goarch: "arm64", + wantOS: "Darwin", + wantArch: "arm64", + wantExt: "tar.gz", + wantAsset: "neocode_Darwin_arm64.tar.gz", + }, + { + name: "windows amd64", + goos: "windows", + goarch: "amd64", + wantOS: "Windows", + wantArch: "x86_64", + wantExt: "zip", + wantAsset: "neocode_Windows_x86_64.zip", + }, + { + name: "unsupported os", + goos: "freebsd", + goarch: "amd64", + expectErrMsg: "unsupported os", + }, + { + name: "unsupported arch", + goos: "linux", + goarch: "386", + expectErrMsg: "unsupported arch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target, err := resolveAssetTarget(tt.goos, tt.goarch) + if tt.expectErrMsg != "" { + if err == nil || !regexp.MustCompile(tt.expectErrMsg).MatchString(err.Error()) { + t.Fatalf("resolveAssetTarget() error = %v, want contains %q", err, tt.expectErrMsg) + } + return + } + if err != nil { + t.Fatalf("resolveAssetTarget() error = %v", err) + } + if target.OSToken != tt.wantOS { + t.Fatalf("OSToken = %q, want %q", target.OSToken, tt.wantOS) + } + if target.ArchToken != tt.wantArch { + t.Fatalf("ArchToken = %q, want %q", target.ArchToken, tt.wantArch) + } + if target.Ext != tt.wantExt { + t.Fatalf("Ext = %q, want %q", target.Ext, tt.wantExt) + } + if target.AssetName != tt.wantAsset { + t.Fatalf("AssetName = %q, want %q", target.AssetName, tt.wantAsset) + } + }) + } +} + +func TestBuildSelfupdateConfigUsesExactFilterAndChecksum(t *testing.T) { + target := assetTarget{ + OSToken: "Darwin", + ArchToken: "x86_64", + Ext: "tar.gz", + AssetName: "neocode_Darwin_x86_64.tar.gz", + } + config := buildSelfupdateConfig(target, true) + if config.OS != "Darwin" || config.Arch != "x86_64" { + t.Fatalf("OS/Arch = %q/%q, want %q/%q", config.OS, config.Arch, "Darwin", "x86_64") + } + if !config.Prerelease { + t.Fatal("expected prerelease to be enabled") + } + if len(config.Filters) != 1 { + t.Fatalf("len(Filters) = %d, want 1", len(config.Filters)) + } + exactFilter := config.Filters[0] + re := regexp.MustCompile(exactFilter) + if !re.MatchString("neocode_Darwin_x86_64.tar.gz") { + t.Fatal("exact filter should match target asset") + } + if re.MatchString("neocode_Darwin_x86_64.tar.gz.sig") { + t.Fatal("exact filter should not match similar asset names") + } + validator, ok := config.Validator.(*selfupdate.ChecksumValidator) + if !ok { + t.Fatalf("validator type = %T, want *selfupdate.ChecksumValidator", config.Validator) + } + if validator.UniqueFilename != checksumFilename { + t.Fatalf("UniqueFilename = %q, want %q", validator.UniqueFilename, checksumFilename) + } +} + +func TestCheckLatest(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.2.0", + greaterFn: func(other string) bool { + return other == "v1.1.0" + }, + }, + found: true, + } + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{ + CurrentVersion: "v1.1.0", + IncludePrerelease: false, + }) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if !result.HasUpdate { + t.Fatal("expected HasUpdate to be true") + } + if result.LatestVersion != "v1.2.0" { + t.Fatalf("LatestVersion = %q, want %q", result.LatestVersion, "v1.2.0") + } +} + +func TestCheckLatestErrorBranches(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + + t.Run("unsupported platform", func(t *testing.T) { + runtimeGOOS = "plan9" + runtimeGOARCH = "amd64" + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: ""}) + if err == nil || !regexp.MustCompile(`unsupported os`).MatchString(err.Error()) { + t.Fatalf("CheckLatest() error = %v, want unsupported os", err) + } + if result.CurrentVersion != "dev" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "dev") + } + }) + + t.Run("new client failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return nil, errors.New("new client failed") + } + + _, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "new client failed" { + t.Fatalf("CheckLatest() error = %v, want new client failed", err) + } + }) + + t.Run("detect latest failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{detectErr: errors.New("detect failed")}, nil + } + + _, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "detect failed" { + t.Fatalf("CheckLatest() error = %v, want detect failed", err) + } + }) + + t.Run("not found release", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{found: false}, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: " v1.0.0 "}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.CurrentVersion != "v1.0.0" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "v1.0.0") + } + if result.HasUpdate { + t.Fatalf("HasUpdate = true, want false") + } + }) + + t.Run("empty latest version", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{version: " "}, + found: true, + }, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "v1.0.0"}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.LatestVersion != "" || result.HasUpdate { + t.Fatalf("unexpected result: %+v", result) + } + }) + + t.Run("non semver current version never marks update", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{ + version: "v9.9.9", + greaterFn: func(string) bool { + return true + }, + }, + found: true, + }, nil + } + + result, err := CheckLatest(context.Background(), CheckOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("CheckLatest() error = %v", err) + } + if result.HasUpdate { + t.Fatalf("HasUpdate = true, want false for non-semver current version") + } + }) +} + +func TestDoUpdateSkipsWhenAlreadyLatestForSemver(t *testing.T) { + originalNewClient := newClient + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.2.0", + greaterFn: func(other string) bool { + return false + }, + }, + found: true, + } + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if result.Updated { + t.Fatal("expected Updated to be false") + } + if client.updateCalls != 0 { + t.Fatalf("update calls = %d, want 0", client.updateCalls) + } +} + +func TestDoUpdateUsesUpdaterLibraryPathForWindows(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "windows" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(other string) bool { + return false + }, + }, + found: true, + } + + var capturedConfig selfupdate.Config + newClient = func(config selfupdate.Config) (updateClient, error) { + capturedConfig = config + return client, nil + } + resolveExecutablePath = func() (string, error) { + return `C:\Tools\neocode.exe`, nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if !result.Updated { + t.Fatal("expected Updated to be true") + } + if client.updateCalls != 1 { + t.Fatalf("update calls = %d, want 1", client.updateCalls) + } + if client.lastUpdatePath != `C:\Tools\neocode.exe` { + t.Fatalf("last update path = %q, want %q", client.lastUpdatePath, `C:\Tools\neocode.exe`) + } + if capturedConfig.OS != "Windows" || capturedConfig.Arch != "x86_64" { + t.Fatalf("config OS/Arch = %q/%q, want %q/%q", capturedConfig.OS, capturedConfig.Arch, "Windows", "x86_64") + } +} + +func TestDoUpdatePropagatesUpdateError(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + expected := errors.New("apply update failed") + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(other string) bool { + return true + }, + }, + found: true, + updateErr: expected, + } + + newClient = func(config selfupdate.Config) (updateClient, error) { + return client, nil + } + resolveExecutablePath = func() (string, error) { + return "/usr/local/bin/neocode", nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if !errors.Is(err, expected) { + t.Fatalf("DoUpdate() error = %v, want %v", err, expected) + } +} + +func TestDoUpdateErrorAndEdgeBranches(t *testing.T) { + originalNewClient := newClient + originalExePath := resolveExecutablePath + originalGOOS := runtimeGOOS + originalGOARCH := runtimeGOARCH + t.Cleanup(func() { + newClient = originalNewClient + resolveExecutablePath = originalExePath + runtimeGOOS = originalGOOS + runtimeGOARCH = originalGOARCH + }) + + t.Run("unsupported platform", func(t *testing.T) { + runtimeGOOS = "plan9" + runtimeGOARCH = "amd64" + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: ""}) + if err == nil || !regexp.MustCompile(`unsupported os`).MatchString(err.Error()) { + t.Fatalf("DoUpdate() error = %v, want unsupported os", err) + } + if result.CurrentVersion != "dev" { + t.Fatalf("CurrentVersion = %q, want %q", result.CurrentVersion, "dev") + } + }) + + t.Run("new client failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return nil, errors.New("new client failed") + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "new client failed" { + t.Fatalf("DoUpdate() error = %v, want new client failed", err) + } + }) + + t.Run("detect latest failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{detectErr: errors.New("detect failed")}, nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || err.Error() != "detect failed" { + t.Fatalf("DoUpdate() error = %v, want detect failed", err) + } + }) + + t.Run("release not found", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{found: false}, nil + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.0.0"}) + if err == nil || !regexp.MustCompile(`no release asset found`).MatchString(err.Error()) { + t.Fatalf("DoUpdate() error = %v, want no release asset found", err) + } + }) + + t.Run("resolve executable path failure", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + newClient = func(selfupdate.Config) (updateClient, error) { + return &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(string) bool { + return true + }, + }, + found: true, + }, nil + } + resolveExecutablePath = func() (string, error) { + return "", errors.New("resolve exec failed") + } + + _, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "v1.2.0"}) + if err == nil || err.Error() != "resolve exec failed" { + t.Fatalf("DoUpdate() error = %v, want resolve exec failed", err) + } + }) + + t.Run("dev version updates without semver compare", func(t *testing.T) { + runtimeGOOS = "linux" + runtimeGOARCH = "amd64" + + client := &fakeClient{ + release: fakeRelease{ + version: "v1.3.0", + greaterFn: func(string) bool { + return false + }, + }, + found: true, + } + newClient = func(selfupdate.Config) (updateClient, error) { + return client, nil + } + resolveExecutablePath = func() (string, error) { + return "/tmp/neocode", nil + } + + result, err := DoUpdate(context.Background(), UpdateOptions{CurrentVersion: "dev"}) + if err != nil { + t.Fatalf("DoUpdate() error = %v", err) + } + if !result.Updated { + t.Fatalf("Updated = false, want true") + } + if client.updateCalls != 1 { + t.Fatalf("update calls = %d, want 1", client.updateCalls) + } + }) +} + +func TestSelfupdateClientDetectLatestAndUnsupportedUpdateType(t *testing.T) { + target := assetTarget{ + OSToken: "linux", + ArchToken: "amd64", + Ext: "tar.gz", + AssetName: "neocode_linux_amd64.tar.gz", + } + + source := stubSource{ + releases: []selfupdate.SourceRelease{ + stubSourceRelease{ + id: 1, + tagName: "v1.5.0", + assets: []selfupdate.SourceAsset{ + stubSourceAsset{id: 1, name: target.AssetName, size: 1}, + }, + }, + }, + } + updater, err := selfupdate.NewUpdater(selfupdate.Config{ + Source: source, + OS: target.OSToken, + Arch: target.ArchToken, + }) + if err != nil { + t.Fatalf("NewUpdater() error = %v", err) + } + + client := selfupdateClient{updater: updater} + rel, found, err := client.DetectLatest(context.Background(), selfupdate.NewRepositorySlug(repositoryOwner, repositoryName)) + if err != nil { + t.Fatalf("DetectLatest() error = %v", err) + } + if !found || rel == nil { + t.Fatalf("expected release found, got found=%v rel=%v", found, rel) + } + if rel.Version() == "" { + t.Fatalf("expected non-empty release version") + } + if !rel.GreaterThan("1.0.0") { + t.Fatalf("expected release to be greater than 1.0.0") + } + + noReleaseUpdater, err := selfupdate.NewUpdater(selfupdate.Config{ + Source: stubSource{releases: nil}, + OS: target.OSToken, + Arch: target.ArchToken, + }) + if err != nil { + t.Fatalf("NewUpdater(no release) error = %v", err) + } + noReleaseClient := selfupdateClient{updater: noReleaseUpdater} + if gotRel, gotFound, gotErr := noReleaseClient.DetectLatest( + context.Background(), + selfupdate.NewRepositorySlug(repositoryOwner, repositoryName), + ); gotErr != nil || gotFound || gotRel != nil { + t.Fatalf("DetectLatest(no release) = (%v, %v, %v), want (nil, false, nil)", gotRel, gotFound, gotErr) + } + + err = client.UpdateTo(context.Background(), fakeRelease{version: "v1.0.0"}, "/tmp/neocode") + if err == nil || err.Error() != "updater: unsupported release type" { + t.Fatalf("UpdateTo() error = %v, want unsupported release type", err) + } + + err = client.UpdateTo(context.Background(), selfupdateRelease{}, "/tmp/neocode") + if err == nil || err.Error() != "updater: unsupported release type" { + t.Fatalf("UpdateTo() error = %v, want unsupported release type for nil release", err) + } + + if err := client.UpdateTo(context.Background(), rel, "/tmp/neocode"); err == nil { + t.Fatalf("expected UpdateTo() to fail with stub asset payload") + } +} + +func TestNewClientFactory(t *testing.T) { + _, err := newClient(selfupdate.Config{Filters: []string{"("}}) + if err == nil { + t.Fatalf("expected newClient to fail with invalid filter regex") + } + + client, err := newClient(selfupdate.Config{ + Source: stubSource{}, + OS: "linux", + Arch: "amd64", + }) + if err != nil { + t.Fatalf("newClient() unexpected error: %v", err) + } + if client == nil { + t.Fatalf("expected non-nil client") + } +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 00000000..9f8ea7af --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,25 @@ +package version + +import ( + "regexp" + "strings" +) + +var semverPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:-[0-9A-Za-z.-]+)?(?:\+[0-9A-Za-z.-]+)?$`) + +// Version 表示当前构建注入的版本号;默认值用于本地开发构建。 +var Version = "dev" + +// Current 返回归一化后的当前版本;空值会回退为 dev。 +func Current() string { + value := strings.TrimSpace(Version) + if value == "" { + return "dev" + } + return value +} + +// IsSemverRelease 判断给定版本字符串是否为可比较的语义化版本。 +func IsSemverRelease(value string) bool { + return semverPattern.MatchString(strings.TrimSpace(value)) +} diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 00000000..befdd017 --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,41 @@ +package version + +import "testing" + +func TestCurrentFallsBackToDev(t *testing.T) { + original := Version + t.Cleanup(func() { Version = original }) + + Version = " " + if got := Current(); got != "dev" { + t.Fatalf("Current() = %q, want %q", got, "dev") + } + + Version = " v1.2.3 " + if got := Current(); got != "v1.2.3" { + t.Fatalf("Current() = %q, want %q", got, "v1.2.3") + } +} + +func TestIsSemverRelease(t *testing.T) { + tests := []struct { + name string + value string + matched bool + }{ + {name: "with v prefix", value: "v1.2.3", matched: true}, + {name: "without v prefix", value: "1.2.3", matched: true}, + {name: "prerelease", value: "v1.2.3-rc.1", matched: true}, + {name: "build metadata", value: "v1.2.3+meta", matched: true}, + {name: "dev", value: "dev", matched: false}, + {name: "missing patch", value: "v1.2", matched: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsSemverRelease(tt.value); got != tt.matched { + t.Fatalf("IsSemverRelease(%q) = %v, want %v", tt.value, got, tt.matched) + } + }) + } +}