diff --git a/docs/session-persistence-design.md b/docs/session-persistence-design.md index 03ce5aa9..8a970dc9 100644 --- a/docs/session-persistence-design.md +++ b/docs/session-persistence-design.md @@ -2,106 +2,144 @@ ## 模块职责与边界 -- `internal/session` 是会话领域模型、存储抽象与 JSON 持久化实现的唯一归属层。 -- `internal/runtime` 负责决定保存时机、恢复会话状态和编排主循环,不承载文件存储细节。 -- `internal/tui` 只消费 runtime 暴露的会话数据,不直接读写会话文件。 +- `internal/session` 是会话领域模型、SQLite 存储实现和资产持久化的唯一归属层。 +- `internal/runtime` 只决定何时创建会话、追加消息、更新会话头和替换 transcript,不关心底层表结构。 +- `internal/tui` 只消费 runtime 暴露的会话数据,不直接读取数据库或资产文件。 ## 存储策略 -NeoCode 当前使用本地 JSON 文件持久化会话,以保持实现简单、可调试且跨平台可移植。 +NeoCode 当前使用工作区级 SQLite 数据库持久化会话,不再使用 `session.json` 文件。 -- 默认目录按工作区隔离:`~/.neocode/projects//sessions/` -- 工作区哈希基于启动阶段确定的工作区根目录生成 -- `session.Workdir` 表示会话最近一次运行实际使用的目录,由启动 `workdir` 或请求级覆盖值写回,但不参与分桶 -- 旧的全局 `~/.neocode/sessions/` 开发期数据不迁移、不回读 +- 数据库路径:`~/.neocode/projects//session.db` +- 资产目录:`~/.neocode/projects//assets//.bin` +- 工作区哈希基于启动时确定的工作区根目录生成 +- `session.Workdir` 记录该会话最近一次运行实际使用的目录,但不参与分桶 +- 开发阶段遗留的旧 `sessions/` JSON 目录不迁移、不回读、不兼容 + +SQLite 初始化固定使用以下 PRAGMA: + +- `journal_mode = WAL` +- `synchronous = NORMAL` +- `foreign_keys = ON` +- `busy_timeout = 5000` +- `user_version = 1` ## 数据模型 -`internal/session.Session` 持久化以下核心字段: +### sessions + +会话头保存摘要和 durable 状态: -- `schema_version` -- `id`、`title` -- `provider`、`model` -- `created_at`、`updated_at` +- `id` +- `title` +- `created_at_ms` +- `updated_at_ms` +- `provider` +- `model` - `workdir` -- `task_state` -- `todos` -- `messages` +- `task_state_json` +- `todos_json` +- `activated_skills_json` - `token_input_total` - `token_output_total` +- `last_seq` +- `message_count` -其中: +### messages -- `schema_version` 为开发期强校验字段;当前实现只接受当前版本,不兼容旧 session 文件 -- `provider` / `model` 记录最近一次成功运行会话时使用的配置,供 compact 等流程优先复用 -- `task_state` 是会话级 durable task state,由 runtime 维护、session 持久化、context 只读投影 -- `token_input_total` / `token_output_total` 分别表示会话累计输入与输出 token -- token 字段仍使用 `omitempty`,但不再承担旧版 session JSON 兼容职责 +消息正文按行存储,一条消息对应一行: -`internal/session.Summary` 只保留会话列表渲染所需的轻量字段,不加载完整消息历史。 +- `session_id` +- `seq` +- `role` +- `parts_json` +- `tool_calls_json` +- `tool_call_id` +- `is_error` +- `tool_metadata_json` +- `created_at_ms` -`task_state` 固定包含以下字段: +### session_assets -- `goal` -- `progress` -- `open_items` -- `next_step` -- `blockers` -- `key_artifacts` -- `decisions` -- `user_constraints` -- `last_updated_at` +资产元数据入库,二进制内容落盘: -`todos` 固定包含以下要点: - `id` -- `content` -- `status` -- `dependencies` -- `created_at` -- `updated_at` -- 可选 `priority` +- `session_id` +- `mime_type` +- `size_bytes` +- `relative_path` +- `created_at_ms` + +## 运行时读写语义 + +### 创建会话 + +runtime 在新会话开始时调用 `CreateSession`,只写入一条空会话头,不写消息正文。 + +### 追加消息 + +runtime 在以下时机调用 `AppendMessages`: + +- 用户消息提交后 +- assistant 完整回复后 +- 每个 tool result 完成后 + +一次调用会在同一事务内完成两件事: + +- 追加 1..N 条消息 +- 更新会话头上的 `updated_at`、`provider`、`model`、`workdir`、token 增量和消息计数 + +因此常规写入不再与历史消息总量线性耦合。 + +### 更新会话头 + +runtime 在以下场景调用 `UpdateSessionState`: + +- workdir 变更 +- task_state 变更 +- todo 列表变更 +- skill 激活状态变更 +- assistant 本轮没有正文,但 provider/model 或 token 统计发生变化 + +该操作不写消息,只覆盖会话头字段。 -其中 `status` 当前固定为: -- `pending` -- `in_progress` -- `completed` +### 替换 transcript -同时,当 session JSON 缺失 `todos` 字段时,`Load` 会按空 Todo 列表兼容加载。 +compact 成功后,runtime 调用 `ReplaceTranscript`,在单事务内: -## 读写行为 +- 删除该会话原有全部消息 +- 按新顺序写回 compact 后的消息 +- 同步更新 `task_state`、token 统计、provider/model/workdir 和消息计数 -- `Save` 使用“临时文件 + 原子替换”写入完整会话 JSON -- `Load` 在用户真正进入某个会话时读取完整历史,并严格要求 `schema_version` 与 `task_state` 字段存在 -- `ListSummaries` 只解析摘要字段,并按 `updated_at` 倒序返回;不合法的旧 session 文件会被直接跳过 +这是低频路径,允许重写整段 transcript。 -## Token 计数持久化 +### 加载会话 -- runtime 在 provider 调用完成后更新 session 的累计 token 字段 -- 会话保存时,token 计数随 session 一起持久化 -- 会话重新加载时,runtime 从 session 恢复累计 token -- 自动 compact 成功后,runtime 会重置累计 token,并将重置后的值持久化 +- `ListSummaries` 只查询 `sessions` 表,并按 `updated_at` 倒序返回摘要 +- `LoadSession` 先读取会话头,再按 `seq` 顺序加载消息并组装完整 `Session` -## TaskState 与 compact +## Token 持久化 -- `TaskState` 是继续执行多轮任务时的唯一 durable truth,不依赖聊天消息本身长期保存 -- compact 成功后,runtime 会同时回写 `session.TaskState` 和压缩后的 `session.Messages` -- `messages` 中的 `[compact_summary]` 只是展示层,不再是唯一续航载体 -- context 构建时会优先注入 `TaskState`,再注入 memo、最近消息和必要工具结果 -- 只有当 `TaskState` 已建立后,读时 micro compact 才允许清理旧的可重建 tool payload +- runtime 在 assistant 调用完成后累计输入和输出 token +- `AppendMessages` 可以原子地追加消息并累加 token +- `UpdateSessionState` 和 `ReplaceTranscript` 可以直接覆盖 token 总量 +- compact 成功后,runtime 会将 token 总量重置为 0 并持久化 -## 保存时机 +## TaskState 与 Todo -- 用户消息提交后保存 -- assistant 完整回复后保存 -- 每个工具结果完成后保存 -- 避免在高频 UI 刷新路径中直接做磁盘 I/O +- `TaskState` 是 compact 与多轮续航依赖的 durable summary +- `Todo` 是结构化任务状态,独立持久化在 `sessions.todos_json` +- 二者都属于会话头,不写入 `messages` 表 +- context 构建时优先读取 `TaskState`、`Todo`、最近消息和必要工具结果 ## 并发约束 -- `internal/session` 的存储实现自行保护共享访问 -- 保存时机统一由 runtime 决定,TUI 不直接触发磁盘写入 +- SQLite 负责单工作区数据库的一致性和事务边界 +- runtime 继续通过会话锁串行化同一 session 的关键写入路径 +- 不同 session 可以并行运行 ## 演进约束 -- 新增存储实现时,应优先在 `internal/session` 内扩展并通过接口注入 -- 不应把持久化逻辑重新分散到 `runtime`、`tui` 或其他上层模块 +- 新增持久化行为时,优先扩展 `internal/session.Store` 的意图型接口 +- 不要把 SQL、事务或表结构细节泄漏到 `runtime`、`tui` 或其他上层模块 +- 如需进一步优化读路径,应继续在 `internal/session` 内演进,而不是重新引入文件级快照保存 diff --git a/docs/session-todo-design.md b/docs/session-todo-design.md index 50c6d9f6..055e059c 100644 --- a/docs/session-todo-design.md +++ b/docs/session-todo-design.md @@ -1,47 +1,60 @@ # Session Todo 设计说明 -本文档补充说明 `internal/session` 中 Todo 的数据模型、持久化语义与边界约束。 +本文档补充说明 `internal/session` 中 Todo 的数据模型、持久化语义和边界约束。 ## 设计目标 -- Todo 归属于 `Session`,不单独引入新的持久化子系统。 -- Todo 只表示结构化待办状态,不替代现有 `TaskState`。 -- Todo 的校验、规范化和基础增删改查统一收敛在 `internal/session`。 +- Todo 归属于 `Session`,不单独引入新的持久化子系统 +- Todo 只表示结构化待办状态,不替代 `TaskState` +- Todo 的校验、规范化和基础增删改查统一收敛在 `internal/session` ## 数据模型 -`Session` 新增 `todos` 字段,对应 `[]TodoItem`。 +`Session` 包含 `Todos []TodoItem` 字段。 -单个 `TodoItem` 目前包含: +单个 `TodoItem` 当前包含: - `id` - `content` - `status` - `dependencies` +- `priority` +- `owner_type` +- `owner_id` +- `artifacts` +- `failure_reason` +- `revision` - `created_at` - `updated_at` -- 可选 `priority` -其中 `status` 固定为以下三个值: +其中 `status` 当前固定为: - `pending` - `in_progress` - `completed` +- `failed` ## 持久化语义 -- Todo 跟随 `Session` 一起通过现有 JSONStore 保存和加载。 -- `Save` 前会对 Todo 执行统一规范化与校验: - - `id`、`content` 去空白 - - 空状态默认收敛为 `pending` - - `dependencies` 去空白、去重、保持顺序 - - 拒绝重复 ID - - 拒绝自依赖 - - 拒绝引用不存在的依赖项 -- `Load` 允许 session JSON 缺失 `todos` 字段,并按空 Todo 列表处理。 +- Todo 跟随会话头一起保存在 SQLite `sessions.todos_json` +- runtime 修改 Todo 时只调用 `UpdateSessionState`,不会写入 `messages` 表 +- `LoadSession` 时会把 `todos_json` 还原为完整 `[]TodoItem` + +## 规范化与校验 + +写入前会统一执行 Todo 校验和规范化,包括: + +- `id`、`content` 去空白 +- 空状态收敛为 `pending` +- `dependencies` 去空白、去重并保持顺序 +- 拒绝重复 ID +- 拒绝自依赖 +- 拒绝引用不存在的依赖项 +- 使用 `revision` 保障更新时的乐观并发校验 ## 与 TaskState 的关系 -- `TaskState` 仍是 runtime/context 用于 compact 与续航的 durable summary。 -- `Todo` 是更细粒度的结构化状态,不直接注入 context,不写入消息历史。 -- 如果未来需要收敛两者关系,应通过单独演进,让 `TaskState` 从 `Todo` 派生摘要,而不是直接复用同一字段。 +- `TaskState` 仍是 runtime/context 用于 compact 和续航的 durable summary +- `Todo` 是更细粒度的结构化执行状态 +- `Todo` 不直接拼入模型消息历史 +- 如需让 `TaskState` 汇总 Todo,应在 runtime/context 层显式投影,而不是复用同一个字段 diff --git a/go.mod b/go.mod index d5cfe3dd..884cdc24 100644 --- a/go.mod +++ b/go.mod @@ -3,26 +3,28 @@ module neo-code go 1.25.0 require ( + github.com/Microsoft/go-winio v0.6.2 + github.com/atotto/clipboard v0.1.4 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/glamour v1.0.0 github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/creativeprojects/go-selfupdate v1.5.2 + github.com/prometheus/client_golang v1.23.2 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 golang.design/x/clipboard v0.7.1 golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.48.2 ) require ( code.gitea.io/sdk/gitea v0.22.1 // indirect github.com/42wim/httpsig v1.2.3 // indirect github.com/Masterminds/semver/v3 v3.4.0 // indirect - github.com/Microsoft/go-winio v0.6.2 // indirect github.com/alecthomas/chroma/v2 v2.20.0 // indirect - github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -44,6 +46,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/go-github/v74 v74.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.8 // indirect @@ -59,11 +62,12 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/prometheus/client_golang v1.23.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect @@ -88,4 +92,7 @@ require ( golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.14.0 // indirect google.golang.org/protobuf v1.36.11 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 1e41deec..6cdac7df 100644 --- a/go.sum +++ b/go.sum @@ -77,10 +77,15 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github/v74 v74.0.0 h1:yZcddTUn8DPbj11GxnMrNiAnXH14gNs559AsUpNpPgM= github.com/google/go-github/v74 v74.0.0/go.mod h1:ubn/YdyftV80VPSI26nSJvaEsTOnsjrxG3o9kJhcyak= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= @@ -89,10 +94,14 @@ github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVU github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -120,6 +129,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -132,10 +143,14 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -151,9 +166,9 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= -github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -170,6 +185,8 @@ github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9 github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= gitlab.com/gitlab-org/api/client-go v1.9.1 h1:tZm+URa36sVy8UCEHQyGGJ8COngV4YqMHpM6k9O5tK8= gitlab.com/gitlab-org/api/client-go v1.9.1/go.mod h1:71yTJk1lnHCWcZLvM5kPAXzeJ2fn5GjaoV8gTOPd4ME= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -183,25 +200,31 @@ golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 h1:Wdx0vgH5Wgsw+lF//LJKmWOJBLWX6nprsMqnf99rYDE= golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= golang.org/x/image v0.28.0 h1:gdem5JW1OLS4FbkWgLO+7ZeFzYtL3xClb97GaUzYMFE= golang.org/x/image v0.28.0/go.mod h1:GUJYXtnGKEUgggyzh+Vxt+AviiCcyiwpsl8iQ8MvwGY= golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f h1:/n+PL2HlfqeSiDCuhdBbRNlGS/g2fM4OHufalHaTVG8= golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f/go.mod h1:ESkJ836Z6LpG6mTVAhA48LpfW/8fNR0ifStlH2axyfg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= @@ -212,13 +235,45 @@ golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.48.2 h1:5CnW4uP8joZtA0LedVqLbZV5GD7F/0x91AXeSyjoh5c= +modernc.org/sqlite v1.48.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 6600db2c..9a9f217d 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -194,13 +194,15 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er } needCleanup = false + closeBundle := combineRuntimeClosers(toolsCleanup, sessionStore.Close) + return RuntimeBundle{ Config: cfg, ConfigManager: manager, Runtime: runtimeSvc, ProviderSelection: providerSelection, MemoService: memoSvc, - Close: toolsCleanup, + Close: closeBundle, }, nil } @@ -328,3 +330,19 @@ type runtimeAutoCompactThresholdResolverFunc func(ctx context.Context, cfg confi func (f runtimeAutoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { return f(ctx, cfg) } + +// combineRuntimeClosers 按顺序执行 runtime 初始化阶段注册的清理函数。 +func combineRuntimeClosers(closers ...func() error) func() error { + return func() error { + var firstErr error + for _, closer := range closers { + if closer == nil { + continue + } + if err := closer(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr + } +} diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index ad84cb48..29265de6 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -736,6 +736,13 @@ func TestBuildRuntimeSucceedsWhenSkillsRootMissing(t *testing.T) { if err != nil { t.Fatalf("BuildRuntime() error = %v", err) } + if bundle.Close != nil { + t.Cleanup(func() { + if err := bundle.Close(); err != nil { + t.Fatalf("bundle.Close() error = %v", err) + } + }) + } if bundle.Runtime == nil { t.Fatalf("expected runtime bundle to be created") } @@ -748,8 +755,20 @@ func TestBuildRuntimeSucceedsWhenSkillsRootMissing(t *testing.T) { } store := agentsession.NewStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Fatalf("store.Close() error = %v", err) + } + }) session := agentsession.New("missing root session") - if err := store.Save(context.Background(), &session); err != nil { + session, err = store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Workdir: session.Workdir, + }) + if err != nil { t.Fatalf("save session: %v", err) } @@ -787,10 +806,29 @@ func TestBuildRuntimeInjectsSkillsRegistryWhenRootExists(t *testing.T) { if err != nil { t.Fatalf("BuildRuntime() error = %v", err) } + if bundle.Close != nil { + t.Cleanup(func() { + if err := bundle.Close(); err != nil { + t.Fatalf("bundle.Close() error = %v", err) + } + }) + } store := agentsession.NewStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Fatalf("store.Close() error = %v", err) + } + }) session := agentsession.New("skill session") - if err := store.Save(context.Background(), &session); err != nil { + session, err = store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Workdir: session.Workdir, + }) + if err != nil { t.Fatalf("save session: %v", err) } @@ -804,7 +842,7 @@ func TestBuildRuntimeInjectsSkillsRegistryWhenRootExists(t *testing.T) { t.Fatalf("ActivateSessionSkill() error = %v", err) } - loaded, err := store.Load(context.Background(), session.ID) + loaded, err := store.LoadSession(context.Background(), session.ID) if err != nil { t.Fatalf("load session: %v", err) } diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index 8af6a7c2..1d16c9ee 100644 --- a/internal/runtime/compact.go +++ b/internal/runtime/compact.go @@ -76,7 +76,7 @@ func (s *Service) Compact(ctx context.Context, input CompactInput) (CompactResul }() cfg := s.configManager.Get() - session, err := s.sessionStore.Load(ctx, input.SessionID) + session, err := s.sessionStore.LoadSession(ctx, input.SessionID) if err != nil { return CompactResult{}, err } @@ -144,7 +144,7 @@ func (s *Service) runCompactForSession( session.TokenInputTotal = 0 session.TokenOutputTotal = 0 session.UpdatedAt = time.Now() - if err := s.sessionStore.Save(ctx, &session); err != nil { + if err := s.sessionStore.ReplaceTranscript(ctx, replaceTranscriptInputFromSession(session)); err != nil { session.Messages = originalMessages session.TaskState = originalTaskState session.TokenInputTotal = originalTokenInputTotal diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go index 5b9e4f66..d3f53b4f 100644 --- a/internal/runtime/input_prepare_test.go +++ b/internal/runtime/input_prepare_test.go @@ -148,7 +148,7 @@ func TestServicePrepareUserInputDoesNotBlockWhenPrepareEventQueueIsFull(t *testi } } -func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Service, *agentsession.JSONStore) { +func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Service, *agentsession.SQLiteStore) { t.Helper() cfg := config.StaticDefaults() @@ -160,6 +160,9 @@ func newPrepareTestService(t *testing.T, workdir string, withPreparer bool) (*Se } store := agentsession.NewStore(t.TempDir(), workdir) + t.Cleanup(func() { + _ = store.Close() + }) svc := NewWithFactory(manager, nil, store, nil, nil) svc.SetSessionAssetStore(store) if withPreparer { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 580664af..6d6eb421 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -139,8 +139,14 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if strings.TrimSpace(turnResult.assistant.Role) == "" { turnResult.assistant.Role = providertypes.RoleAssistant } - state.recordUsage(turnResult.inputTokens, turnResult.outputTokens) - if err := s.appendAssistantMessageAndSave(ctx, &state, snapshot, turnResult.assistant); err != nil { + if err := s.appendAssistantMessageAndSave( + ctx, + &state, + snapshot, + turnResult.assistant, + turnResult.inputTokens, + turnResult.outputTokens, + ); err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } s.emitTokenUsage(ctx, &state, turnResult) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index ebe74ef6..07c6cbc1 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -214,7 +214,7 @@ func (s *Service) ListSessions(ctx context.Context) ([]agentsession.Summary, err // LoadSession 按 id 加载完整会话内容。 func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { - session, err := s.sessionStore.Load(ctx, id) + session, err := s.sessionStore.LoadSession(ctx, id) if err != nil { return agentsession.Session{}, err } diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 235665a9..8e21a062 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -22,17 +22,21 @@ type stubMemoExtractor struct { } type lockProbeStore struct { - saveFn func(ctx context.Context, session *agentsession.Session) error + appendFn func(ctx context.Context, input agentsession.AppendMessagesInput) error } -func (s *lockProbeStore) Save(ctx context.Context, session *agentsession.Session) error { - if s.saveFn == nil { +func (s *lockProbeStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + return agentsession.Session{}, errors.New("not implemented") +} + +func (s *lockProbeStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + if s.appendFn == nil { return nil } - return s.saveFn(ctx, session) + return s.appendFn(ctx, input) } -func (s *lockProbeStore) Load(ctx context.Context, id string) (agentsession.Session, error) { +func (s *lockProbeStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { return agentsession.Session{}, errors.New("not implemented") } @@ -40,7 +44,16 @@ func (s *lockProbeStore) ListSummaries(ctx context.Context) ([]agentsession.Summ return nil, errors.New("not implemented") } -func (s *lockProbeStore) DeleteSession(ctx context.Context, id string) error { +// UpdateSessionWorkdir 仅为接口占位,当前测试不会走到该分支。 +func (s *lockProbeStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + return errors.New("not implemented") +} + +func (s *lockProbeStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + return errors.New("not implemented") +} + +func (s *lockProbeStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { return errors.New("not implemented") } @@ -141,7 +154,14 @@ func TestAppendAssistantMessageAndSaveMetadataBranches(t *testing.T) { model: "gpt-4.1", } - if err := service.appendAssistantMessageAndSave(context.Background(), &state, snapshot, providertypes.Message{Role: providertypes.RoleAssistant}); err != nil { + if err := service.appendAssistantMessageAndSave( + context.Background(), + &state, + snapshot, + providertypes.Message{Role: providertypes.RoleAssistant}, + 0, + 0, + ); err != nil { t.Fatalf("appendAssistantMessageAndSave() error = %v", err) } if store.saves != 1 { @@ -151,7 +171,14 @@ func TestAppendAssistantMessageAndSaveMetadataBranches(t *testing.T) { store.saves = 0 state.session.Provider = snapshot.providerConfig.Name state.session.Model = snapshot.model - if err := service.appendAssistantMessageAndSave(context.Background(), &state, snapshot, providertypes.Message{Role: providertypes.RoleAssistant}); err != nil { + if err := service.appendAssistantMessageAndSave( + context.Background(), + &state, + snapshot, + providertypes.Message{Role: providertypes.RoleAssistant}, + 0, + 0, + ); err != nil { t.Fatalf("appendAssistantMessageAndSave() error = %v", err) } if store.saves != 0 { @@ -313,7 +340,7 @@ func TestAppendToolMessageAndSaveUnlocksStateBeforePersist(t *testing.T) { state := newRunState("run-append-tool-lock", session) store := &lockProbeStore{ - saveFn: func(_ context.Context, _ *agentsession.Session) error { + appendFn: func(_ context.Context, _ agentsession.AppendMessagesInput) error { locked := make(chan struct{}) go func() { state.mu.Lock() diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index 4fade2ea..e5fd395c 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -55,11 +55,14 @@ type saveHookStore struct { saveHook func() } -func (s *saveHookStore) Save(ctx context.Context, session *agentsession.Session) error { +func (s *saveHookStore) beforeSave(ctx context.Context) error { + if err := ctx.Err(); err != nil { + return err + } if s.saveHook != nil { s.saveHook() } - return s.base.Save(ctx, session) + return nil } type postSaveHookStore struct { @@ -68,36 +71,85 @@ type postSaveHookStore struct { once sync.Once } -func (s *postSaveHookStore) Save(ctx context.Context, session *agentsession.Session) error { - err := s.base.Save(ctx, session) +func (s *postSaveHookStore) afterSave(err error) error { if err == nil && s.saveHook != nil { s.once.Do(s.saveHook) } return err } -func (s *saveHookStore) Load(ctx context.Context, id string) (agentsession.Session, error) { - return s.base.Load(ctx, id) +func (s *saveHookStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + if err := s.beforeSave(ctx); err != nil { + return agentsession.Session{}, err + } + return s.base.CreateSession(ctx, input) +} + +func (s *saveHookStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + return s.base.LoadSession(ctx, id) } func (s *saveHookStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { return s.base.ListSummaries(ctx) } -func (s *saveHookStore) DeleteSession(ctx context.Context, id string) error { - return s.base.DeleteSession(ctx, id) +func (s *saveHookStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + if err := s.beforeSave(ctx); err != nil { + return err + } + return s.base.AppendMessages(ctx, input) +} + +// UpdateSessionWorkdir 在写前注入回调,再转发给底层内存 store。 +func (s *saveHookStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + if err := s.beforeSave(ctx); err != nil { + return err + } + return s.base.UpdateSessionWorkdir(ctx, input) } -func (s *postSaveHookStore) Load(ctx context.Context, id string) (agentsession.Session, error) { - return s.base.Load(ctx, id) +func (s *saveHookStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + if err := s.beforeSave(ctx); err != nil { + return err + } + return s.base.UpdateSessionState(ctx, input) +} + +func (s *saveHookStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { + if err := s.beforeSave(ctx); err != nil { + return err + } + return s.base.ReplaceTranscript(ctx, input) +} + +func (s *postSaveHookStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + session, err := s.base.CreateSession(ctx, input) + return session, s.afterSave(err) +} + +func (s *postSaveHookStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + return s.base.LoadSession(ctx, id) } func (s *postSaveHookStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { return s.base.ListSummaries(ctx) } -func (s *postSaveHookStore) DeleteSession(ctx context.Context, id string) error { - return s.base.DeleteSession(ctx, id) +func (s *postSaveHookStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + return s.afterSave(s.base.AppendMessages(ctx, input)) +} + +// UpdateSessionWorkdir 在底层写入完成后执行一次 post-save 钩子。 +func (s *postSaveHookStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + return s.afterSave(s.base.UpdateSessionWorkdir(ctx, input)) +} + +func (s *postSaveHookStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + return s.afterSave(s.base.UpdateSessionState(ctx, input)) +} + +func (s *postSaveHookStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { + return s.afterSave(s.base.ReplaceTranscript(ctx, input)) } func TestResolveCompactProviderSelectionResolveErrorBranch(t *testing.T) { diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 11d73f25..e771ff01 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -27,6 +27,7 @@ import ( ) type memoryStore struct { + mu sync.Mutex sessions map[string]agentsession.Session saves int } @@ -49,7 +50,8 @@ func newMemoryStore() *memoryStore { return &memoryStore{sessions: map[string]agentsession.Session{}} } -func (s *failingStore) Save(ctx context.Context, session *agentsession.Session) error { +// nextSaveError 模拟旧 save hook 语义,对所有持久化写操作统一计数注入失败。 +func (s *failingStore) nextSaveError(ctx context.Context) error { s.saveCalls++ if s.failOnSave > 0 && s.saveCalls == s.failOnSave { return s.saveErr @@ -57,28 +59,50 @@ func (s *failingStore) Save(ctx context.Context, session *agentsession.Session) if s.ignoreContextErr && s.saveErr != nil { return s.saveErr } - if s.Store == nil { - return nil + if err := ctx.Err(); err != nil { + return err } - return s.Store.Save(ctx, session) + return nil } -func (s *memoryStore) Save(ctx context.Context, session *agentsession.Session) error { +// CreateSession 在内存中创建一条完整会话记录,供 runtime 测试使用。 +func (s *memoryStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { if err := ctx.Err(); err != nil { - return err + return agentsession.Session{}, err + } + session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + if strings.TrimSpace(input.ID) != "" { + session.ID = input.ID } - if session == nil { - return errors.New("nil session") + if !input.CreatedAt.IsZero() { + session.CreatedAt = input.CreatedAt } + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal + session.Messages = []providertypes.Message{} + + s.mu.Lock() + defer s.mu.Unlock() s.saves++ - s.sessions[session.ID] = cloneSession(*session) - return nil + s.sessions[session.ID] = cloneSession(session) + return cloneSession(session), nil } -func (s *memoryStore) Load(ctx context.Context, id string) (agentsession.Session, error) { +// LoadSession 从内存快照返回完整会话副本。 +func (s *memoryStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { if err := ctx.Err(); err != nil { return agentsession.Session{}, err } + s.mu.Lock() + defer s.mu.Unlock() session, ok := s.sessions[id] if !ok { return agentsession.Session{}, errors.New("not found") @@ -86,10 +110,18 @@ func (s *memoryStore) Load(ctx context.Context, id string) (agentsession.Session return cloneSession(session), nil } +// Load 作为测试辅助别名保留,便于沿用现有断言代码。 +func (s *memoryStore) Load(ctx context.Context, id string) (agentsession.Session, error) { + return s.LoadSession(ctx, id) +} + +// ListSummaries 返回所有会话摘要。 func (s *memoryStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { if err := ctx.Err(); err != nil { return nil, err } + s.mu.Lock() + defer s.mu.Unlock() summaries := make([]agentsession.Summary, 0, len(s.sessions)) for _, session := range s.sessions { summaries = append(summaries, agentsession.Summary{ @@ -102,14 +134,182 @@ func (s *memoryStore) ListSummaries(ctx context.Context) ([]agentsession.Summary return summaries, nil } -func (s *memoryStore) DeleteSession(ctx context.Context, id string) error { +// AppendMessages 追加消息并同步更新会话头的增量字段。 +func (s *memoryStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Messages = append(session.Messages, cloneMessagesForPersistence(input.Messages)...) + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TokenInputTotal += input.TokenInputDelta + session.TokenOutputTotal += input.TokenOutputDelta + s.saves++ + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +// UpdateSessionState 只覆写会话头字段,不改消息正文。 +// UpdateSessionWorkdir 仅更新会话 workdir 与时间,避免输入归一化覆盖其他会话头字段。 +func (s *memoryStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Workdir = input.Workdir + s.saves++ + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +func (s *memoryStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Title = input.Title + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal + s.saves++ + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +// ReplaceTranscript 用新的消息切片替换原会话 transcript,并同步会话头状态。 +func (s *memoryStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { if err := ctx.Err(); err != nil { return err } - delete(s.sessions, id) + s.mu.Lock() + defer s.mu.Unlock() + + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Messages = cloneMessagesForPersistence(input.Messages) + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal + s.saves++ + s.sessions[input.SessionID] = cloneSession(session) return nil } +// CreateSession 转发到底层 Store,并按旧 save 计数规则注入失败。 +func (s *failingStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + if err := s.nextSaveError(ctx); err != nil { + return agentsession.Session{}, err + } + if s.Store == nil { + return agentsession.Session{}, nil + } + return s.Store.CreateSession(ctx, input) +} + +// LoadSession 直接透传到底层 Store。 +func (s *failingStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if s.Store == nil { + return agentsession.Session{}, errors.New("not found") + } + return s.Store.LoadSession(ctx, id) +} + +// ListSummaries 直接透传到底层 Store。 +func (s *failingStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { + if s.Store == nil { + return nil, nil + } + return s.Store.ListSummaries(ctx) +} + +// AppendMessages 转发到底层 Store,并按写入次数注入失败。 +func (s *failingStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + if err := s.nextSaveError(ctx); err != nil { + return err + } + if s.Store == nil { + return nil + } + return s.Store.AppendMessages(ctx, input) +} + +// UpdateSessionState 转发到底层 Store,并按写入次数注入失败。 +// UpdateSessionWorkdir 杞彂鍒板簳灞?Store锛屽苟鎸夊啓鍏ユ鏁版敞鍏ュけ璐ャ€? +func (s *failingStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + if err := s.nextSaveError(ctx); err != nil { + return err + } + if s.Store == nil { + return nil + } + return s.Store.UpdateSessionWorkdir(ctx, input) +} + +func (s *failingStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + if err := s.nextSaveError(ctx); err != nil { + return err + } + if s.Store == nil { + return nil + } + return s.Store.UpdateSessionState(ctx, input) +} + +// ReplaceTranscript 转发到底层 Store,并按写入次数注入失败。 +func (s *failingStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { + if err := s.nextSaveError(ctx); err != nil { + return err + } + if s.Store == nil { + return nil + } + return s.Store.ReplaceTranscript(ctx, input) +} + // blockingLoadStore 用于并发测试:首次 Load 阻塞,以验证同 session 的锁时序。 type blockingLoadStore struct { mu sync.Mutex @@ -128,20 +328,36 @@ func newBlockingLoadStore() *blockingLoadStore { } } -func (s *blockingLoadStore) Save(ctx context.Context, session *agentsession.Session) error { +// CreateSession 在阻塞加载测试桩中创建会话记录。 +func (s *blockingLoadStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { if err := ctx.Err(); err != nil { - return err + return agentsession.Session{}, err + } + session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + if strings.TrimSpace(input.ID) != "" { + session.ID = input.ID } - if session == nil { - return errors.New("nil session") + if !input.CreatedAt.IsZero() { + session.CreatedAt = input.CreatedAt } + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal s.mu.Lock() - s.sessions[session.ID] = cloneSession(*session) + s.sessions[session.ID] = cloneSession(session) s.mu.Unlock() - return nil + return cloneSession(session), nil } -func (s *blockingLoadStore) Load(ctx context.Context, id string) (agentsession.Session, error) { +// LoadSession 首次调用时阻塞,用于验证同 session 锁时序。 +func (s *blockingLoadStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { if err := ctx.Err(); err != nil { return agentsession.Session{}, err } @@ -167,6 +383,103 @@ func (s *blockingLoadStore) Load(ctx context.Context, id string) (agentsession.S return cloneSession(session), nil } +// AppendMessages 在阻塞加载测试桩中追加消息。 +func (s *blockingLoadStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Messages = append(session.Messages, cloneMessagesForPersistence(input.Messages)...) + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TokenInputTotal += input.TokenInputDelta + session.TokenOutputTotal += input.TokenOutputDelta + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +// UpdateSessionState 在阻塞加载测试桩中更新会话头。 +// UpdateSessionWorkdir 鍦ㄩ樆濉炲姞杞芥祴璇曟々涓粎鏇存柊 workdir 涓庢椂闂淬€? +func (s *blockingLoadStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Workdir = input.Workdir + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +func (s *blockingLoadStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Title = input.Title + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + +// ReplaceTranscript 在阻塞加载测试桩中重写会话消息。 +func (s *blockingLoadStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { + if err := ctx.Err(); err != nil { + return err + } + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[input.SessionID] + if !ok { + return errors.New("not found") + } + session.Messages = cloneMessagesForPersistence(input.Messages) + if !input.UpdatedAt.IsZero() { + session.UpdatedAt = input.UpdatedAt + } + session.Provider = input.Provider + session.Model = input.Model + session.Workdir = input.Workdir + session.TaskState = input.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + session.Todos = cloneTodosForPersistence(input.Todos) + session.TokenInputTotal = input.TokenInputTotal + session.TokenOutputTotal = input.TokenOutputTotal + s.sessions[input.SessionID] = cloneSession(session) + return nil +} + func (s *blockingLoadStore) ListSummaries(ctx context.Context) ([]agentsession.Summary, error) { if err := ctx.Err(); err != nil { return nil, err @@ -185,16 +498,6 @@ func (s *blockingLoadStore) ListSummaries(ctx context.Context) ([]agentsession.S return summaries, nil } -func (s *blockingLoadStore) DeleteSession(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - s.mu.Lock() - delete(s.sessions, id) - s.mu.Unlock() - return nil -} - type scriptedProvider struct { name string streams [][]providertypes.StreamEvent diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 1738fa13..3a6c65c8 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -11,7 +11,7 @@ import ( const toolNameMetadataKey = "tool_name" -// appendUserMessageAndSave 将用户消息追加到会话并立即持久化。 +// appendUserMessageAndSave 将用户消息追加到会话,并立即落盘为一条增量消息。 func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, parts []providertypes.ContentPart) error { message := providertypes.Message{ Role: providertypes.RoleUser, @@ -19,37 +19,57 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, } state.session.Messages = append(state.session.Messages, message) state.touchSession() - if err := s.sessionStore.Save(ctx, &state.session); err != nil { + if err := s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + SessionID: state.session.ID, + Messages: []providertypes.Message{message}, + UpdatedAt: state.session.UpdatedAt, + Provider: state.session.Provider, + Model: state.session.Model, + Workdir: state.session.Workdir, + }); err != nil { return err } s.emitRunScoped(ctx, EventUserMessage, state, message) return nil } -// appendAssistantMessageAndSave 将 assistant 消息和本轮模型元数据写回会话。 +// appendAssistantMessageAndSave 将 assistant 消息和本轮 token/provider/model 增量写回会话。 func (s *Service) appendAssistantMessageAndSave( ctx context.Context, state *runState, snapshot turnSnapshot, assistant providertypes.Message, + inputTokens int, + outputTokens int, ) error { metadataChanged := state.session.Provider != snapshot.providerConfig.Name || state.session.Model != snapshot.model state.session.Provider = snapshot.providerConfig.Name state.session.Model = snapshot.model + state.recordUsage(inputTokens, outputTokens) if !assistant.IsEmpty() { state.session.Messages = append(state.session.Messages, assistant) state.touchSession() - return s.sessionStore.Save(ctx, &state.session) + return s.sessionStore.AppendMessages(ctx, agentsession.AppendMessagesInput{ + SessionID: state.session.ID, + Messages: []providertypes.Message{assistant}, + UpdatedAt: state.session.UpdatedAt, + Provider: state.session.Provider, + Model: state.session.Model, + Workdir: state.session.Workdir, + TokenInputDelta: inputTokens, + TokenOutputDelta: outputTokens, + }) } - if metadataChanged { + + if metadataChanged || inputTokens != 0 || outputTokens != 0 { state.touchSession() - return s.sessionStore.Save(ctx, &state.session) + return s.sessionStore.UpdateSessionState(ctx, sessionStateInputFromSession(state.session)) } return nil } -// appendToolMessageAndSave 将工具原始结果写回会话,避免污染持久化对话内容。 +// appendToolMessageAndSave 将工具原始结果写回会话,持久化时仅追加一条 tool message。 func (s *Service) appendToolMessageAndSave( ctx context.Context, state *runState, @@ -60,9 +80,16 @@ func (s *Service) appendToolMessageAndSave( toolMessage := normalizeToolMessageForPersistence(call, result) state.session.Messages = append(state.session.Messages, toolMessage) state.touchSession() - sessionSnapshot := cloneSessionForPersistence(state.session) + input := agentsession.AppendMessagesInput{ + SessionID: state.session.ID, + Messages: []providertypes.Message{toolMessage}, + UpdatedAt: state.session.UpdatedAt, + Provider: state.session.Provider, + Model: state.session.Model, + Workdir: state.session.Workdir, + } state.mu.Unlock() - return s.sessionStore.Save(ctx, &sessionSnapshot) + return s.sessionStore.AppendMessages(ctx, input) } // normalizeToolMessageForPersistence 负责在写入会话前收敛工具结果,避免成功结果落成完全空语义消息。 @@ -97,7 +124,59 @@ func hasNonToolNameToolMetadata(metadata map[string]string) bool { return false } -// cloneSessionForPersistence 复制会话快照,避免持久化阶段与并发写入共享可变切片/映射。 +// createSessionInputFromSession 将运行态 session 转为建库时使用的会话头输入。 +func createSessionInputFromSession(session agentsession.Session) agentsession.CreateSessionInput { + return agentsession.CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState.Clone(), + ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), + Todos: cloneTodosForPersistence(session.Todos), + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + } +} + +// sessionStateInputFromSession 将运行态 session 映射为只更新会话头的持久化输入。 +func sessionStateInputFromSession(session agentsession.Session) agentsession.UpdateSessionStateInput { + return agentsession.UpdateSessionStateInput{ + SessionID: session.ID, + Title: session.Title, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState.Clone(), + ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), + Todos: cloneTodosForPersistence(session.Todos), + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + } +} + +// replaceTranscriptInputFromSession 将完整 session 映射为 transcript 原子替换输入。 +func replaceTranscriptInputFromSession(session agentsession.Session) agentsession.ReplaceTranscriptInput { + return agentsession.ReplaceTranscriptInput{ + SessionID: session.ID, + Messages: cloneMessagesForPersistence(session.Messages), + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState.Clone(), + ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), + Todos: cloneTodosForPersistence(session.Todos), + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + } +} + +// cloneSessionForPersistence 复制会话快照,避免并发读写共享底层切片和映射。 func cloneSessionForPersistence(session agentsession.Session) agentsession.Session { cloned := session cloned.Messages = cloneMessagesForPersistence(session.Messages) @@ -107,7 +186,7 @@ func cloneSessionForPersistence(session agentsession.Session) agentsession.Sessi return cloned } -// agentsessionCloneSkillActivations 深拷贝会话中的 skill 激活列表,避免持久化阶段共享底层切片。 +// agentsessionCloneSkillActivations 深拷贝会话中的 skill 激活列表,避免共享底层切片。 func agentsessionCloneSkillActivations(items []agentsession.SkillActivation) []agentsession.SkillActivation { if len(items) == 0 { return nil @@ -119,7 +198,7 @@ func agentsessionCloneSkillActivations(items []agentsession.SkillActivation) []a return cloned } -// cloneMessagesForPersistence 深拷贝消息切片及其嵌套字段,确保 Save 期间读取稳定。 +// cloneMessagesForPersistence 深拷贝消息切片及其嵌套字段,确保持久化和测试读取稳定。 func cloneMessagesForPersistence(messages []providertypes.Message) []providertypes.Message { if len(messages) == 0 { return nil @@ -146,7 +225,7 @@ func cloneMessagesForPersistence(messages []providertypes.Message) []providertyp return cloned } -// cloneTodosForPersistence 深拷贝 Todo 列表,确保持久化快照不与运行态共享底层切片。 +// cloneTodosForPersistence 深拷贝 Todo 列表,确保会话快照与运行态隔离。 func cloneTodosForPersistence(todos []agentsession.TodoItem) []agentsession.TodoItem { if len(todos) == 0 { return nil diff --git a/internal/runtime/session_scheduler.go b/internal/runtime/session_scheduler.go index ed783622..d790a9e3 100644 --- a/internal/runtime/session_scheduler.go +++ b/internal/runtime/session_scheduler.go @@ -24,13 +24,14 @@ func (s *Service) loadOrCreateSession( return agentsession.Session{}, err } session := agentsession.NewWithWorkdir(title, sessionWorkdir) - if err := s.sessionStore.Save(ctx, &session); err != nil { + session, err = s.sessionStore.CreateSession(ctx, createSessionInputFromSession(session)) + if err != nil { return agentsession.Session{}, err } return session, nil } - session, err := s.sessionStore.Load(ctx, sessionID) + session, err := s.sessionStore.LoadSession(ctx, sessionID) if err != nil { return agentsession.Session{}, err } @@ -48,7 +49,7 @@ func (s *Service) loadOrCreateSession( session.Workdir = resolved session.UpdatedAt = time.Now() - if err := s.sessionStore.Save(ctx, &session); err != nil { + if err := s.sessionStore.UpdateSessionState(ctx, sessionStateInputFromSession(session)); err != nil { return agentsession.Session{}, err } return session, nil diff --git a/internal/runtime/skills.go b/internal/runtime/skills.go index 6f5bcd4a..e5eda75b 100644 --- a/internal/runtime/skills.go +++ b/internal/runtime/skills.go @@ -78,7 +78,7 @@ func (s *Service) ListSessionSkills(ctx context.Context, sessionID string) ([]Se return nil, errors.New("runtime: session id is empty") } - session, err := s.sessionStore.Load(ctx, sessionID) + session, err := s.sessionStore.LoadSession(ctx, sessionID) if err != nil { return nil, err } @@ -180,7 +180,7 @@ func (s *Service) mutateSessionSkills( releaseLockRef() }() - session, err := s.sessionStore.Load(ctx, sessionID) + session, err := s.sessionStore.LoadSession(ctx, sessionID) if err != nil { return agentsession.Session{}, false, err } @@ -189,7 +189,7 @@ func (s *Service) mutateSessionSkills( } session.UpdatedAt = time.Now() - if err := s.sessionStore.Save(ctx, &session); err != nil { + if err := s.sessionStore.UpdateSessionState(ctx, sessionStateInputFromSession(session)); err != nil { return agentsession.Session{}, false, err } return session, true, nil diff --git a/internal/runtime/todo_mutator.go b/internal/runtime/todo_mutator.go index 5fe7439e..7688909b 100644 --- a/internal/runtime/todo_mutator.go +++ b/internal/runtime/todo_mutator.go @@ -120,7 +120,7 @@ func (m *runtimeSessionMutator) mutateAndSave(mutate func(session *agentsession. return err } sessionSnapshot.UpdatedAt = time.Now() - if err := m.service.sessionStore.Save(m.ctx, &sessionSnapshot); err != nil { + if err := m.service.sessionStore.UpdateSessionState(m.ctx, sessionStateInputFromSession(sessionSnapshot)); err != nil { m.state.mu.Unlock() return err } diff --git a/internal/runtime/todo_mutator_test.go b/internal/runtime/todo_mutator_test.go index 8c5e4cce..7c9b9d8c 100644 --- a/internal/runtime/todo_mutator_test.go +++ b/internal/runtime/todo_mutator_test.go @@ -14,21 +14,22 @@ type mutatorStore struct { err error } -func (s *mutatorStore) Save(ctx context.Context, session *agentsession.Session) error { +func (s *mutatorStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { if err := ctx.Err(); err != nil { - return err + return agentsession.Session{}, err } if s.err != nil { - return s.err + return agentsession.Session{}, s.err } - if session == nil { - return errors.New("nil session") + session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + if input.ID != "" { + session.ID = input.ID } - s.last = cloneSessionForPersistence(*session) - return nil + s.last = cloneSessionForPersistence(session) + return cloneSessionForPersistence(session), nil } -func (s *mutatorStore) Load(ctx context.Context, id string) (agentsession.Session, error) { +func (s *mutatorStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { if err := ctx.Err(); err != nil { return agentsession.Session{}, err } @@ -45,13 +46,88 @@ func (s *mutatorStore) ListSummaries(ctx context.Context) ([]agentsession.Summar return nil, nil } -func (s *mutatorStore) DeleteSession(ctx context.Context, id string) error { +func (s *mutatorStore) AppendMessages(ctx context.Context, input agentsession.AppendMessagesInput) error { if err := ctx.Err(); err != nil { return err } - if s.last.ID == id { - s.last = agentsession.Session{} + if s.err != nil { + return s.err + } + if s.last.ID != input.SessionID { + return errors.New("not found") + } + s.last.Messages = append(s.last.Messages, cloneMessagesForPersistence(input.Messages)...) + s.last.UpdatedAt = input.UpdatedAt + s.last.Provider = input.Provider + s.last.Model = input.Model + s.last.Workdir = input.Workdir + s.last.TokenInputTotal += input.TokenInputDelta + s.last.TokenOutputTotal += input.TokenOutputDelta + return nil +} + +// UpdateSessionWorkdir 仅更新 workdir 与更新时间,模拟最小粒度持久化。 +func (s *mutatorStore) UpdateSessionWorkdir(ctx context.Context, input agentsession.UpdateSessionWorkdirInput) error { + if err := ctx.Err(); err != nil { + return err + } + if s.err != nil { + return s.err + } + if s.last.ID != "" && s.last.ID != input.SessionID { + return errors.New("not found") + } + s.last.ID = input.SessionID + s.last.UpdatedAt = input.UpdatedAt + s.last.Workdir = input.Workdir + return nil +} + +func (s *mutatorStore) UpdateSessionState(ctx context.Context, input agentsession.UpdateSessionStateInput) error { + if err := ctx.Err(); err != nil { + return err + } + if s.err != nil { + return s.err + } + if s.last.ID != "" && s.last.ID != input.SessionID { + return errors.New("not found") + } + s.last.ID = input.SessionID + s.last.Title = input.Title + s.last.UpdatedAt = input.UpdatedAt + s.last.Provider = input.Provider + s.last.Model = input.Model + s.last.Workdir = input.Workdir + s.last.TaskState = input.TaskState.Clone() + s.last.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + s.last.Todos = cloneTodosForPersistence(input.Todos) + s.last.TokenInputTotal = input.TokenInputTotal + s.last.TokenOutputTotal = input.TokenOutputTotal + return nil +} + +func (s *mutatorStore) ReplaceTranscript(ctx context.Context, input agentsession.ReplaceTranscriptInput) error { + if err := ctx.Err(); err != nil { + return err + } + if s.err != nil { + return s.err } + if s.last.ID != "" && s.last.ID != input.SessionID { + return errors.New("not found") + } + s.last.ID = input.SessionID + s.last.Messages = cloneMessagesForPersistence(input.Messages) + s.last.UpdatedAt = input.UpdatedAt + s.last.Provider = input.Provider + s.last.Model = input.Model + s.last.Workdir = input.Workdir + s.last.TaskState = input.TaskState.Clone() + s.last.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) + s.last.Todos = cloneTodosForPersistence(input.Todos) + s.last.TokenInputTotal = input.TokenInputTotal + s.last.TokenOutputTotal = input.TokenOutputTotal return nil } diff --git a/internal/session/asset_store.go b/internal/session/asset_store.go index d63b4c16..6fee5d3d 100644 --- a/internal/session/asset_store.go +++ b/internal/session/asset_store.go @@ -2,13 +2,12 @@ package session import ( "context" - "encoding/json" "fmt" "io" "strings" ) -// AssetMeta 描述会话附件最小元数据,用于 provider 请求阶段定位与发送。 +// AssetMeta 描述会话附件的最小元数据。 type AssetMeta struct { ID string `json:"id"` MimeType string `json:"mime_type"` @@ -22,12 +21,6 @@ type AssetStore interface { Stat(ctx context.Context, sessionID string, assetID string) (AssetMeta, error) } -type storedAssetMeta struct { - ID string `json:"id"` - MimeType string `json:"mime_type"` - Size int64 `json:"size"` -} - // newAssetMeta 生成新的会话附件元数据,并校验 MIME 约束。 func newAssetMeta(mimeType string) (AssetMeta, error) { normalized := strings.ToLower(strings.TrimSpace(mimeType)) @@ -42,46 +35,3 @@ func newAssetMeta(mimeType string) (AssetMeta, error) { MimeType: normalized, }, nil } - -// decodeStoredAssetMeta 将磁盘上的附件元数据反序列化为运行时结构。 -func decodeStoredAssetMeta(data []byte) (AssetMeta, error) { - var stored storedAssetMeta - if err := json.Unmarshal(data, &stored); err != nil { - return AssetMeta{}, fmt.Errorf("session: decode asset meta: %w", err) - } - if err := validateStorageID("asset id", stored.ID); err != nil { - return AssetMeta{}, fmt.Errorf("session: %w", err) - } - normalizedMime := strings.ToLower(strings.TrimSpace(stored.MimeType)) - if normalizedMime == "" { - return AssetMeta{}, fmt.Errorf("session: asset meta mime_type is empty") - } - if !strings.HasPrefix(normalizedMime, "image/") { - return AssetMeta{}, fmt.Errorf("session: unsupported asset mime type %q", stored.MimeType) - } - return AssetMeta{ - ID: stored.ID, - MimeType: normalizedMime, - Size: stored.Size, - }, nil -} - -// encodeStoredAssetMeta 将运行时附件元数据编码为可持久化 JSON。 -func encodeStoredAssetMeta(meta AssetMeta) ([]byte, error) { - if err := validateStorageID("asset id", meta.ID); err != nil { - return nil, fmt.Errorf("session: %w", err) - } - normalizedMime := strings.ToLower(strings.TrimSpace(meta.MimeType)) - if normalizedMime == "" { - return nil, fmt.Errorf("session: asset mime type is empty") - } - payload, err := json.MarshalIndent(storedAssetMeta{ - ID: meta.ID, - MimeType: normalizedMime, - Size: meta.Size, - }, "", " ") - if err != nil { - return nil, fmt.Errorf("session: marshal asset meta: %w", err) - } - return append(payload, '\n'), nil -} diff --git a/internal/session/asset_store_test.go b/internal/session/asset_store_test.go index edf7d8b4..999ef6c5 100644 --- a/internal/session/asset_store_test.go +++ b/internal/session/asset_store_test.go @@ -9,282 +9,180 @@ import ( "path/filepath" "strings" "testing" - - providertypes "neo-code/internal/provider/types" ) -func TestJSONStoreSaveAssetOpenAndStat(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - sessionID := "session_asset_round_trip" - payload := []byte("fake-image-bytes") +func TestSQLiteStoreSaveAssetOpenAndStat(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets", Title: "assets"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } - meta, err := store.SaveAsset(context.Background(), sessionID, bytes.NewReader(payload), "image/png") + payload := []byte("image-bytes") + meta, err := store.SaveAsset(ctx, session.ID, bytes.NewReader(payload), "image/png") if err != nil { t.Fatalf("SaveAsset() error = %v", err) } - if meta.ID == "" || meta.MimeType != "image/png" || meta.Size != int64(len(payload)) { - t.Fatalf("unexpected saved meta: %+v", meta) + if meta.ID == "" || meta.Size != int64(len(payload)) { + t.Fatalf("unexpected asset meta: %+v", meta) } - statMeta, err := store.Stat(context.Background(), sessionID, meta.ID) + statMeta, err := store.Stat(ctx, session.ID, meta.ID) if err != nil { t.Fatalf("Stat() error = %v", err) } if statMeta != meta { - t.Fatalf("expected stat meta %+v, got %+v", meta, statMeta) + t.Fatalf("Stat() = %+v, want %+v", statMeta, meta) } - rc, openMeta, err := store.Open(context.Background(), sessionID, meta.ID) + rc, openMeta, err := store.Open(ctx, session.ID, meta.ID) if err != nil { t.Fatalf("Open() error = %v", err) } - defer func() { _ = rc.Close() }() + defer rc.Close() data, err := io.ReadAll(rc) if err != nil { - t.Fatalf("ReadAll(open) error = %v", err) + t.Fatalf("ReadAll() error = %v", err) } if string(data) != string(payload) { - t.Fatalf("unexpected open payload: got %q want %q", string(data), string(payload)) + t.Fatalf("unexpected open payload %q, want %q", string(data), string(payload)) } if openMeta != meta { - t.Fatalf("expected open meta %+v, got %+v", meta, openMeta) - } - - assetPath := store.assetPath(sessionID, meta.ID) - if _, err := os.Stat(assetPath); err != nil { - t.Fatalf("expected asset file %q exists, err=%v", assetPath, err) - } - if _, err := os.Stat(store.assetMetaPath(sessionID, meta.ID)); err != nil { - t.Fatalf("expected asset meta file exists, err=%v", err) + t.Fatalf("Open() meta = %+v, want %+v", openMeta, meta) } } -func TestJSONStoreSaveAssetRejectsInvalidInput(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - sessionID := "session_asset_invalid" +func TestSQLiteStoreSaveAssetRejectsInvalidInput(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_invalid", Title: "assets"}); err != nil { + t.Fatalf("CreateSession() error = %v", err) + } - if _, err := store.SaveAsset(context.Background(), sessionID, nil, "image/png"); err == nil { + if _, err := store.SaveAsset(ctx, "session_assets_invalid", nil, "image/png"); err == nil { t.Fatalf("expected nil reader error") } - if _, err := store.SaveAsset(context.Background(), sessionID, strings.NewReader("x"), ""); err == nil { - t.Fatalf("expected empty mime error") + if _, err := store.SaveAsset(ctx, "session_assets_invalid", strings.NewReader("x"), ""); err == nil { + t.Fatalf("expected empty mime type error") + } + if _, err := store.SaveAsset(ctx, "session_assets_invalid", strings.NewReader("x"), "text/plain"); err == nil { + t.Fatalf("expected unsupported mime type error") } - if _, err := store.SaveAsset(context.Background(), sessionID, strings.NewReader("x"), "text/plain"); err == nil { - t.Fatalf("expected unsupported mime error") + if _, err := store.SaveAsset(ctx, "missing", strings.NewReader("x"), "image/png"); err == nil { + t.Fatalf("expected missing session error") + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) } - if _, err := store.SaveAsset(context.Background(), "../bad", strings.NewReader("x"), "image/png"); err == nil { + if _, _, err := store.Open(ctx, "bad/session", "asset_ok"); err == nil { t.Fatalf("expected invalid session id error") } + if _, err := store.Stat(ctx, "session_assets_invalid", "../bad"); err == nil { + t.Fatalf("expected invalid asset id error") + } } -func TestJSONStoreAssetOpenAndStatRejectInvalidID(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - - if _, _, err := store.Open(context.Background(), "bad/session", "asset-ok"); err == nil { - t.Fatalf("expected invalid session id on open") - } - if _, _, err := store.Open(context.Background(), "session_ok", "../bad"); err == nil { - t.Fatalf("expected invalid asset id on open") - } - if _, err := store.Stat(context.Background(), "bad/session", "asset-ok"); err == nil { - t.Fatalf("expected invalid session id on stat") +func TestSQLiteStoreSaveAssetRejectsOversizedPayload(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_big", Title: "assets"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) } - if _, err := store.Stat(context.Background(), "session_ok", "../bad"); err == nil { - t.Fatalf("expected invalid asset id on stat") + + oversized := bytes.NewReader(bytes.Repeat([]byte("x"), int(1+MaxSessionAssetBytesForTest()))) + if _, err := store.SaveAsset(ctx, session.ID, oversized, "image/png"); err == nil { + t.Fatalf("expected oversize error") } } -func TestJSONStoreAssetStoreRespectsCanceledContext(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - if _, err := store.SaveAsset(ctx, "session_ctx_cancel", strings.NewReader("x"), "image/png"); err == nil { - t.Fatalf("expected canceled SaveAsset error") - } - if _, _, err := store.Open(ctx, "session_ctx_cancel", "asset_x"); err == nil { - t.Fatalf("expected canceled Open error") +func TestSQLiteStoreOpenReturnsFileErrorWhenPayloadMissing(t *testing.T) { + ctx := context.Background() + baseDir, err := os.MkdirTemp("", "session-base-") + if err != nil { + t.Fatalf("MkdirTemp() baseDir error = %v", err) } - if _, err := store.Stat(ctx, "session_ctx_cancel", "asset_x"); err == nil { - t.Fatalf("expected canceled Stat error") + workspaceRoot, err := os.MkdirTemp("", "session-workspace-") + if err != nil { + t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } -} - -func TestJSONStoreSaveAssetStopsWhenContextCanceledDuringCopy(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - ctx, cancel := context.WithCancel(context.Background()) - reader := &cancelAfterFirstReadReader{ - cancel: cancel, - chunks: [][]byte{[]byte("chunk-1"), []byte("chunk-2")}, + store := NewStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = store.Close() + _ = os.RemoveAll(baseDir) + _ = os.RemoveAll(workspaceRoot) + }) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_missing_file", Title: "assets"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) } - if _, err := store.SaveAsset(ctx, "session_ctx_cancel_during_copy", reader, "image/png"); !errors.Is(err, context.Canceled) { - t.Fatalf("expected context canceled during copy, got %v", err) + meta, err := store.SaveAsset(ctx, session.ID, strings.NewReader("img"), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) } -} - -func TestJSONStoreSaveAssetRejectsOversizedPayload(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - oversized := bytes.NewReader(make([]byte, providertypes.MaxSessionAssetBytes+1)) - - if _, err := store.SaveAsset(context.Background(), "session_oversize", oversized, "image/png"); err == nil || - !strings.Contains(err.Error(), "asset size exceeds") { - t.Fatalf("expected oversized payload rejection, got %v", err) + target := filepath.Join(assetsDirectory(baseDir, workspaceRoot), session.ID, meta.ID+".bin") + if err := os.Remove(target); err != nil { + t.Fatalf("remove target asset: %v", err) } -} - -func TestDecodeStoredAssetMetaRejectsNonImageMIME(t *testing.T) { - t.Parallel() - _, err := decodeStoredAssetMeta([]byte(`{"id":"asset_ok","mime_type":"text/plain","size":1}`)) - if err == nil || !strings.Contains(err.Error(), "unsupported asset mime type") { - t.Fatalf("expected non-image mime rejection, got %v", err) + if _, _, err := store.Open(ctx, session.ID, meta.ID); err == nil { + t.Fatalf("expected missing payload file error") } } -func TestDecodeAndEncodeStoredAssetMetaValidation(t *testing.T) { - t.Parallel() +func TestSQLiteStoreAssetMethodsRespectContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() - if _, err := decodeStoredAssetMeta([]byte(`{`)); err == nil || !strings.Contains(err.Error(), "decode asset meta") { - t.Fatalf("expected decode error, got %v", err) - } - if _, err := decodeStoredAssetMeta([]byte(`{"id":"bad/asset","mime_type":"image/png","size":1}`)); err == nil || - !strings.Contains(err.Error(), "unsupported characters") { - t.Fatalf("expected invalid asset id error, got %v", err) - } - if _, err := decodeStoredAssetMeta([]byte(`{"id":"asset_ok","mime_type":" ","size":1}`)); err == nil || - !strings.Contains(err.Error(), "mime_type is empty") { - t.Fatalf("expected empty mime_type error, got %v", err) + store := newTestStore(t) + if _, err := store.SaveAsset(ctx, "session_assets_ctx", strings.NewReader("x"), "image/png"); err == nil { + t.Fatalf("expected canceled SaveAsset error") } - - if _, err := encodeStoredAssetMeta(AssetMeta{ID: "bad/asset", MimeType: "image/png", Size: 1}); err == nil || - !strings.Contains(err.Error(), "unsupported characters") { - t.Fatalf("expected invalid asset id encode error, got %v", err) + if _, _, err := store.Open(ctx, "session_assets_ctx", "asset_x"); err == nil { + t.Fatalf("expected canceled Open error") } - if _, err := encodeStoredAssetMeta(AssetMeta{ID: "asset_ok", MimeType: " ", Size: 1}); err == nil || - !strings.Contains(err.Error(), "asset mime type is empty") { - t.Fatalf("expected empty mime encode error, got %v", err) + if _, err := store.Stat(ctx, "session_assets_ctx", "asset_x"); err == nil { + t.Fatalf("expected canceled Stat error") } } -func TestJSONStoreSaveAssetFailurePaths(t *testing.T) { - t.Parallel() - - t.Run("create assets dir failed", func(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - assetsDir := store.assetsDir("session_assets_dir_fail") - if err := os.MkdirAll(filepath.Dir(assetsDir), 0o755); err != nil { - t.Fatalf("mkdir parent: %v", err) - } - if err := os.WriteFile(assetsDir, []byte("blocked"), 0o644); err != nil { - t.Fatalf("write blocker file: %v", err) - } - - if _, err := store.SaveAsset(context.Background(), "session_assets_dir_fail", strings.NewReader("x"), "image/png"); err == nil || - !strings.Contains(err.Error(), "create assets dir") { - t.Fatalf("expected create assets dir error, got %v", err) - } - }) - - t.Run("copy temp asset failed", func(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - if _, err := store.SaveAsset(context.Background(), "session_copy_fail", failingReader{}, "image/png"); err == nil || - !strings.Contains(err.Error(), "write temp asset") { - t.Fatalf("expected write temp asset error, got %v", err) - } - }) +func MaxSessionAssetBytesForTest() int64 { + return 20 * 1024 * 1024 } -func TestJSONStoreOpenAndStatMissingStoredFiles(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - sessionID := "session_missing_files" - meta, err := store.SaveAsset(context.Background(), sessionID, strings.NewReader("img"), "image/png") - if err != nil { - t.Fatalf("save seed asset: %v", err) - } - - if err := os.Remove(store.assetPath(sessionID, meta.ID)); err != nil { - t.Fatalf("remove asset file: %v", err) - } - if _, _, err := store.Open(context.Background(), sessionID, meta.ID); err == nil { - t.Fatalf("expected open failure when asset binary is missing") +func TestSQLiteStoreOpenMissingAssetReturnsNotExist(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_missing_meta", Title: "assets"}); err != nil { + t.Fatalf("CreateSession() error = %v", err) } - if err := os.Remove(store.assetMetaPath(sessionID, meta.ID)); err != nil { - t.Fatalf("remove asset meta file: %v", err) - } - if _, err := store.Stat(context.Background(), sessionID, meta.ID); err == nil { - t.Fatalf("expected stat failure when asset meta is missing") + _, _, err := store.Open(ctx, "session_assets_missing_meta", "asset_missing") + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) } } -func TestJSONStoreDeleteAsset(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - sessionID := "session-delete-asset" - meta, err := store.SaveAsset(context.Background(), sessionID, strings.NewReader("img"), "image/png") +func TestSQLiteStoreAssetMetaRejectsEscapedRelativePath(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_escape", Title: "assets"}) if err != nil { - t.Fatalf("save seed asset: %v", err) - } - - if err := store.DeleteAsset(context.Background(), sessionID, meta.ID); err != nil { - t.Fatalf("DeleteAsset() error = %v", err) - } - if _, statErr := os.Stat(store.assetPath(sessionID, meta.ID)); !errors.Is(statErr, os.ErrNotExist) { - t.Fatalf("expected removed asset file, got %v", statErr) + t.Fatalf("CreateSession() error = %v", err) } - if _, statErr := os.Stat(store.assetMetaPath(sessionID, meta.ID)); !errors.Is(statErr, os.ErrNotExist) { - t.Fatalf("expected removed asset meta file, got %v", statErr) + db, err := store.ensureDB(ctx) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) } - - if err := store.DeleteAsset(context.Background(), sessionID, meta.ID); err != nil { - t.Fatalf("DeleteAsset() should ignore already deleted files, got %v", err) + if _, err := db.ExecContext(ctx, ` +INSERT INTO session_assets (id, session_id, mime_type, size_bytes, relative_path, created_at_ms) +VALUES ('asset_escape', ?, 'image/png', 4, '../escape.bin', 0) +`, session.ID); err != nil { + t.Fatalf("insert escaped asset meta: %v", err) } -} -type failingReader struct{} - -func (failingReader) Read(_ []byte) (int, error) { - return 0, errors.New("read failure") -} - -type cancelAfterFirstReadReader struct { - cancel context.CancelFunc - chunks [][]byte - index int -} - -func (r *cancelAfterFirstReadReader) Read(p []byte) (int, error) { - if r.index >= len(r.chunks) { - return 0, io.EOF - } - chunk := r.chunks[r.index] - r.index++ - n := copy(p, chunk) - if r.index == 1 && r.cancel != nil { - r.cancel() + if _, err := store.Stat(ctx, session.ID, "asset_escape"); err == nil || !strings.Contains(err.Error(), "escapes base dir") { + t.Fatalf("expected escaped relative path error, got %v", err) } - return n, nil } diff --git a/internal/session/helpers_test.go b/internal/session/helpers_test.go new file mode 100644 index 00000000..6bbc4ef3 --- /dev/null +++ b/internal/session/helpers_test.go @@ -0,0 +1,98 @@ +package session + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestTaskStateEstablishedAndTruncateHelpers(t *testing.T) { + t.Parallel() + + if (TaskState{}).Established() { + t.Fatalf("empty task state should not be established") + } + if !(TaskState{Goal: "ship"}).Established() { + t.Fatalf("goal should mark task state as established") + } + if !(TaskState{Progress: []string{"done"}}).Established() { + t.Fatalf("progress should mark task state as established") + } + if truncateRunes("abc", 0) != "" { + t.Fatalf("truncateRunes with zero limit should return empty string") + } + if truncateRunes("abc", 3) != "abc" { + t.Fatalf("truncateRunes should keep exact-length string") + } +} + +func TestWorkspaceHelpersAndPathKey(t *testing.T) { + t.Parallel() + + if NormalizeWorkspaceRoot(" ") != "" { + t.Fatalf("blank workspace root should normalize to empty") + } + dir := t.TempDir() + key := WorkspacePathKey(dir) + if key == "" { + t.Fatalf("workspace path key should not be empty") + } + if runtime.GOOS == "windows" && key != strings.ToLower(key) { + t.Fatalf("windows workspace path key should be lower-cased, got %q", key) + } + if got := EffectiveWorkdir(" session ", "default"); got != "session" { + t.Fatalf("EffectiveWorkdir should prefer session workdir, got %q", got) + } +} + +func TestStorageHelpers(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + inside := filepath.Join(baseDir, "sub", "file.txt") + if err := os.MkdirAll(filepath.Dir(inside), 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := ensurePathWithinBase(baseDir, inside); err != nil { + t.Fatalf("ensurePathWithinBase() inside error = %v", err) + } + if err := ensurePathWithinBase(baseDir, filepath.Dir(baseDir)); err == nil { + t.Fatalf("expected path escape error") + } + + if _, _, err := createTempFile(filepath.Join(baseDir, "missing"), "tmp-*", "temp"); err == nil { + t.Fatalf("expected createTempFile() error for missing dir") + } + + tempFile, tempPath, err := createTempFile(baseDir, "tmp-*", "temp") + if err != nil { + t.Fatalf("createTempFile() error = %v", err) + } + if err := tempFile.Close(); err != nil { + t.Fatalf("tempFile.Close() error = %v", err) + } + target := filepath.Join(baseDir, "target.txt") + if err := os.WriteFile(target, []byte("old"), 0o644); err != nil { + t.Fatalf("WriteFile(target) error = %v", err) + } + if err := replaceFileWithTemp(tempPath, target, "target"); err != nil { + t.Fatalf("replaceFileWithTemp() error = %v", err) + } + if _, err := os.Stat(target); err != nil { + t.Fatalf("expected replaced target to exist, got %v", err) + } + if err := replaceFileWithTemp(filepath.Join(baseDir, "missing.tmp"), filepath.Join(baseDir, "missing.txt"), "missing"); err == nil { + t.Fatalf("expected replaceFileWithTemp() error for missing temp file") + } + + outsideDir := t.TempDir() + linkPath := filepath.Join(baseDir, "link") + if err := os.Symlink(outsideDir, linkPath); err != nil { + t.Skipf("symlink not supported in current environment: %v", err) + } + if err := ensurePathWithinBase(baseDir, filepath.Join(linkPath, "escape.txt")); err == nil { + t.Fatalf("expected symlink path escape error") + } +} diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index e834cc60..68451347 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -74,6 +74,10 @@ type assetCleanupStore interface { DeleteAsset(ctx context.Context, sessionID string, assetID string) error } +type sessionCleanupStore interface { + DeleteSession(ctx context.Context, sessionID string) error +} + // NewInputPreparer 创建会话输入归一化组件。 func NewInputPreparer(store Store, assetStore AssetStore) *InputPreparer { return &InputPreparer{ @@ -333,13 +337,27 @@ func (p *InputPreparer) loadOrCreateSession( return Session{}, false, sessionWorkdirUpdate{}, err } session := NewWithWorkdir(title, sessionWorkdir) - if err := p.store.Save(ctx, &session); err != nil { + created, err := p.store.CreateSession(ctx, CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState, + ActivatedSkills: session.ActivatedSkills, + Todos: session.Todos, + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + }) + if err != nil { return Session{}, false, sessionWorkdirUpdate{}, err } - return session, true, sessionWorkdirUpdate{}, nil + return created, true, sessionWorkdirUpdate{}, nil } - session, err := p.store.Load(ctx, sessionID) + session, err := p.store.LoadSession(ctx, sessionID) if err != nil { return Session{}, false, sessionWorkdirUpdate{}, err } @@ -371,7 +389,11 @@ func (p *InputPreparer) rollbackCreatedSession(ctx context.Context, sessionID st if err := ctx.Err(); err != nil { return } - _ = p.store.DeleteSession(ctx, sessionID) + cleanupStore, ok := p.store.(sessionCleanupStore) + if !ok { + return + } + _ = cleanupStore.DeleteSession(ctx, sessionID) } // persistSessionWorkdirUpdate 在 Prepare 其余步骤完成后统一提交会话 workdir 更新,避免失败时出现部分提交。 @@ -379,7 +401,11 @@ func (p *InputPreparer) persistSessionWorkdirUpdate(ctx context.Context, pending if !pending.dirty { return nil } - if err := p.store.Save(ctx, &pending.session); err != nil { + if err := p.store.UpdateSessionWorkdir(ctx, UpdateSessionWorkdirInput{ + SessionID: pending.session.ID, + UpdatedAt: pending.session.UpdatedAt, + Workdir: pending.session.Workdir, + }); err != nil { return err } return nil diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index d0b6c21f..e7b96c95 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "testing" + "time" providertypes "neo-code/internal/provider/types" ) @@ -16,7 +17,7 @@ func TestInputPreparerPrepareTextOnly(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) preparer := NewInputPreparer(store, store) result, err := preparer.Prepare(context.Background(), PrepareInput{ @@ -41,7 +42,7 @@ func TestInputPreparerPrepareTextAndImage(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "img.png") @@ -90,7 +91,7 @@ func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "auto.png") @@ -118,7 +119,7 @@ func TestInputPreparerPrepareImageOnlyUsesImageTitle(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) preparer := NewInputPreparer(store, store) imagePath := filepath.Join(workdir, "only.png") @@ -137,7 +138,7 @@ func TestInputPreparerPrepareImageOnlyUsesImageTitle(t *testing.T) { t.Fatalf("expected one image part, got %+v", result.Parts) } - session, err := store.Load(context.Background(), result.SessionID) + session, err := store.LoadSession(context.Background(), result.SessionID) if err != nil { t.Fatalf("Load() error = %v", err) } @@ -150,7 +151,7 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) t.Run("missing store", func(t *testing.T) { preparer := NewInputPreparer(nil, nil) @@ -219,8 +220,8 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Run("existing session is kept when asset save fails", func(t *testing.T) { existing := NewWithWorkdir("existing", workdir) - if err := store.Save(context.Background(), &existing); err != nil { - t.Fatalf("Save() error = %v", err) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) } preparer := NewInputPreparer(store, store) @@ -233,15 +234,15 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Fatalf("expected asset save error") } - if _, loadErr := store.Load(context.Background(), existing.ID); loadErr != nil { + if _, loadErr := store.LoadSession(context.Background(), existing.ID); loadErr != nil { t.Fatalf("expected existing session to remain, load error = %v", loadErr) } }) t.Run("existing session cleanup removes previously saved assets on later failure", func(t *testing.T) { existing := NewWithWorkdir("existing-cleanup", workdir) - if err := store.Save(context.Background(), &existing); err != nil { - t.Fatalf("Save() error = %v", err) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) } okImage := filepath.Join(workdir, "ok.png") @@ -263,7 +264,7 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Fatalf("expected prepare error") } - entries, readErr := os.ReadDir(store.assetsDir(existing.ID)) + entries, readErr := os.ReadDir(filepath.Join(store.assetsDir, existing.ID)) if readErr != nil { t.Fatalf("ReadDir() error = %v", readErr) } @@ -283,8 +284,8 @@ func TestInputPreparerPrepareErrors(t *testing.T) { } existing := NewWithWorkdir("existing-workdir", currentWorkdir) - if err := store.Save(context.Background(), &existing); err != nil { - t.Fatalf("Save() error = %v", err) + if err := createSessionForPreparerTest(context.Background(), store, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) } preparer := NewInputPreparer(store, store) @@ -299,7 +300,7 @@ func TestInputPreparerPrepareErrors(t *testing.T) { t.Fatalf("expected prepare error") } - loaded, loadErr := store.Load(context.Background(), existing.ID) + loaded, loadErr := store.LoadSession(context.Background(), existing.ID) if loadErr != nil { t.Fatalf("Load() error = %v", loadErr) } @@ -313,7 +314,7 @@ func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { t.Parallel() workdir := t.TempDir() - store := NewStore(t.TempDir(), workdir) + store := newInputPreparerTestStore(t, workdir) preparer := NewInputPreparer(store, store) t.Run("relative path is resolved by workdir", func(t *testing.T) { @@ -434,10 +435,28 @@ func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { t.Fatalf("mkdir nested workdir: %v", err) } - store := NewStore(t.TempDir(), defaultWorkdir) + store := newInputPreparerTestStore(t, defaultWorkdir) session := NewWithWorkdir("existing", currentWorkdir) - if err := store.Save(context.Background(), &session); err != nil { - t.Fatalf("Save() error = %v", err) + session.Provider = "provider-a" + session.Model = "model-a" + session.TokenInputTotal = 13 + session.TokenOutputTotal = 21 + session.TaskState = TaskState{ + Goal: "keep original state", + Progress: []string{"captured"}, + NextStep: "verify workdir-only update", + Blockers: []string{"none"}, + Decisions: []string{"preserve head fields"}, + } + session.Todos = []TodoItem{{ + ID: "todo-preserve", + Content: "must survive prepare workdir update", + Status: TodoStatusInProgress, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + }} + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) } preparer := NewInputPreparer(store, store) @@ -454,11 +473,171 @@ func TestInputPreparerPrepareUpdatesExistingSessionWorkdir(t *testing.T) { t.Fatalf("expected target workdir %q, got %q", targetWorkdir, result.Workdir) } - loaded, err := store.Load(context.Background(), session.ID) + loaded, err := store.LoadSession(context.Background(), session.ID) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Workdir != targetWorkdir { + t.Fatalf("expected persisted workdir %q, got %q", targetWorkdir, loaded.Workdir) + } + if loaded.Provider != session.Provider { + t.Fatalf("expected provider %q, got %q", session.Provider, loaded.Provider) + } + if loaded.Model != session.Model { + t.Fatalf("expected model %q, got %q", session.Model, loaded.Model) + } + if loaded.TokenInputTotal != session.TokenInputTotal || loaded.TokenOutputTotal != session.TokenOutputTotal { + t.Fatalf("expected token totals %d/%d, got %d/%d", + session.TokenInputTotal, + session.TokenOutputTotal, + loaded.TokenInputTotal, + loaded.TokenOutputTotal, + ) + } + if loaded.TaskState.Goal != session.TaskState.Goal || loaded.TaskState.NextStep != session.TaskState.NextStep { + t.Fatalf("expected task state to remain unchanged, got %+v", loaded.TaskState) + } + if len(loaded.Todos) != 1 || loaded.Todos[0].ID != session.Todos[0].ID || loaded.Todos[0].Status != session.Todos[0].Status { + t.Fatalf("expected todos to remain unchanged, got %+v", loaded.Todos) + } +} + +func TestInputPreparerPrepareWorkdirUpdatePreservesConcurrentSessionHeadChanges(t *testing.T) { + t.Parallel() + + base := t.TempDir() + defaultWorkdir := filepath.Join(base, "workspace") + if err := os.MkdirAll(defaultWorkdir, 0o755); err != nil { + t.Fatalf("mkdir default workdir: %v", err) + } + currentWorkdir := filepath.Join(defaultWorkdir, "current") + if err := os.MkdirAll(currentWorkdir, 0o755); err != nil { + t.Fatalf("mkdir current workdir: %v", err) + } + targetWorkdir := filepath.Join(currentWorkdir, "nested") + if err := os.MkdirAll(targetWorkdir, 0o755); err != nil { + t.Fatalf("mkdir nested workdir: %v", err) + } + + store := newInputPreparerTestStore(t, defaultWorkdir) + session := NewWithWorkdir("existing", currentWorkdir) + session.Provider = "provider-before" + session.Model = "model-before" + session.TokenInputTotal = 3 + session.TokenOutputTotal = 5 + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + + concurrentState := UpdateSessionStateInput{ + SessionID: session.ID, + Title: session.Title, + UpdatedAt: session.UpdatedAt.Add(time.Minute), + Provider: "provider-after", + Model: "model-after", + Workdir: currentWorkdir, + TaskState: TaskState{ + Goal: "newer task state", + NextStep: "must survive workdir update", + }, + Todos: []TodoItem{{ + ID: "todo-newer", + Content: "written by concurrent run", + Status: TodoStatusCompleted, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt.Add(time.Minute), + }}, + TokenInputTotal: 55, + TokenOutputTotal: 89, + } + + preparerStore := &workdirRaceStore{ + SQLiteStore: store, + beforeWorkdirUpdate: func(ctx context.Context) error { + return store.UpdateSessionState(ctx, concurrentState) + }, + } + preparer := NewInputPreparer(preparerStore, store) + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "update workdir", + DefaultWorkdir: defaultWorkdir, + RequestedWorkdir: "nested", + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if result.Workdir != targetWorkdir { + t.Fatalf("expected target workdir %q, got %q", targetWorkdir, result.Workdir) + } + + loaded, err := store.LoadSession(context.Background(), session.ID) if err != nil { t.Fatalf("Load() error = %v", err) } if loaded.Workdir != targetWorkdir { t.Fatalf("expected persisted workdir %q, got %q", targetWorkdir, loaded.Workdir) } + if loaded.Provider != concurrentState.Provider || loaded.Model != concurrentState.Model { + t.Fatalf("expected provider/model %q/%q, got %q/%q", + concurrentState.Provider, concurrentState.Model, loaded.Provider, loaded.Model) + } + if loaded.TokenInputTotal != concurrentState.TokenInputTotal || loaded.TokenOutputTotal != concurrentState.TokenOutputTotal { + t.Fatalf("expected token totals %d/%d, got %d/%d", + concurrentState.TokenInputTotal, + concurrentState.TokenOutputTotal, + loaded.TokenInputTotal, + loaded.TokenOutputTotal, + ) + } + if loaded.TaskState.Goal != concurrentState.TaskState.Goal || loaded.TaskState.NextStep != concurrentState.TaskState.NextStep { + t.Fatalf("expected newer task state to survive, got %+v", loaded.TaskState) + } + if len(loaded.Todos) != 1 || loaded.Todos[0].ID != concurrentState.Todos[0].ID || loaded.Todos[0].Status != concurrentState.Todos[0].Status { + t.Fatalf("expected newer todos to survive, got %+v", loaded.Todos) + } +} + +func createSessionForPreparerTest(ctx context.Context, store *SQLiteStore, session Session) error { + _, err := store.CreateSession(ctx, CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Provider: session.Provider, + Model: session.Model, + Workdir: session.Workdir, + TaskState: session.TaskState, + ActivatedSkills: session.ActivatedSkills, + Todos: session.Todos, + TokenInputTotal: session.TokenInputTotal, + TokenOutputTotal: session.TokenOutputTotal, + }) + return err +} + +func newInputPreparerTestStore(t *testing.T, workdir string) *SQLiteStore { + t.Helper() + + store := NewStore(t.TempDir(), workdir) + t.Cleanup(func() { + _ = store.Close() + }) + return store +} + +type workdirRaceStore struct { + *SQLiteStore + beforeWorkdirUpdate func(ctx context.Context) error +} + +// UpdateSessionWorkdir 在真正更新 workdir 前注入一次更晚的会话头写入,用于回归 stale snapshot 覆盖问题。 +func (s *workdirRaceStore) UpdateSessionWorkdir(ctx context.Context, input UpdateSessionWorkdirInput) error { + if s.beforeWorkdirUpdate != nil { + if err := s.beforeWorkdirUpdate(ctx); err != nil { + return err + } + } + return s.SQLiteStore.UpdateSessionWorkdir(ctx, input) } diff --git a/internal/session/skill_activation_test.go b/internal/session/skill_activation_test.go index 19704a6b..f34d106c 100644 --- a/internal/session/skill_activation_test.go +++ b/internal/session/skill_activation_test.go @@ -2,8 +2,6 @@ package session import ( "context" - "os" - "strings" "testing" "time" @@ -11,8 +9,6 @@ import ( ) func TestSessionSkillActivationHelpers(t *testing.T) { - t.Parallel() - session := New("skills") if !session.ActivateSkill(" Go_Review ") { t.Fatalf("expected first activation to report change") @@ -37,94 +33,46 @@ func TestSessionSkillActivationHelpers(t *testing.T) { } } -func TestJSONStoreSaveLoadRoundTripActivatedSkills(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "skills-round-trip", - Title: "Skills Round Trip", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now(), - TaskState: TaskState{}, +func TestSQLiteStoreRoundTripActivatedSkills(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ + ID: "skills_round_trip", + Title: "Skills Round Trip", + CreatedAt: time.Now().Add(-time.Minute), + UpdatedAt: time.Now(), ActivatedSkills: []SkillActivation{ {SkillID: " zeta "}, {SkillID: "go_review"}, {SkillID: "go-review"}, }, - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - } - - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save session with activated skills: %v", err) + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) } if got := session.ActiveSkillIDs(); len(got) != 2 || got[0] != "go-review" || got[1] != "zeta" { t.Fatalf("expected normalized in-memory activations, got %+v", got) } - loaded, err := store.Load(context.Background(), session.ID) - if err != nil { - t.Fatalf("load session with activated skills: %v", err) - } - if got := loaded.ActiveSkillIDs(); len(got) != 2 || got[0] != "go-review" || got[1] != "zeta" { - t.Fatalf("expected normalized loaded activations, got %+v", got) - } - - rawPath := sessionFilePathForTest(baseDir, workspaceRoot, session.ID) - raw, err := os.ReadFile(rawPath) - if err != nil { - t.Fatalf("read saved session: %v", err) - } - if !strings.Contains(string(raw), "\"activated_skills\"") { - t.Fatalf("expected persisted activated_skills field, got:\n%s", string(raw)) + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, + Messages: []providertypes.Message{ + {Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, + }, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) } -} - -func TestJSONStoreLoadAllowsMissingActivatedSkillsField(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "no-activated-skills"), strings.Join([]string{ - `{`, - ` "schema_version": 2,`, - ` "id": "no-activated-skills",`, - ` "title": "No Activated Skills",`, - ` "created_at": "2026-04-15T10:00:00Z",`, - ` "updated_at": "2026-04-15T10:05:00Z",`, - ` "task_state": {`, - ` "goal": "",`, - ` "progress": [],`, - ` "open_items": [],`, - ` "next_step": "",`, - ` "blockers": [],`, - ` "key_artifacts": [],`, - ` "decisions": [],`, - ` "user_constraints": [],`, - ` "last_updated_at": "2026-04-15T10:05:00Z"`, - ` },`, - ` "messages": []`, - `}`, - }, "\n")) - - loaded, err := store.Load(context.Background(), "no-activated-skills") + loaded, err := store.LoadSession(ctx, session.ID) if err != nil { - t.Fatalf("load session without activated_skills field: %v", err) + t.Fatalf("LoadSession() error = %v", err) } - if len(loaded.ActiveSkillIDs()) != 0 { - t.Fatalf("expected no activated skills, got %+v", loaded.ActiveSkillIDs()) + if got := loaded.ActiveSkillIDs(); len(got) != 2 || got[0] != "go-review" || got[1] != "zeta" { + t.Fatalf("expected normalized loaded activations, got %+v", got) } } func TestSkillActivationHelpersHandleNilSessionAndBlankInput(t *testing.T) { - t.Parallel() - var nilSession *Session if nilSession.ActivateSkill("go-review") { t.Fatalf("expected nil session activate to be no-op") @@ -143,8 +91,6 @@ func TestSkillActivationHelpersHandleNilSessionAndBlankInput(t *testing.T) { } func TestSkillActivationCloneHelpers(t *testing.T) { - t.Parallel() - original := []SkillActivation{{SkillID: "go-review"}, {SkillID: "zeta"}} cloned := cloneSkillActivations(original) if len(cloned) != len(original) { diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go new file mode 100644 index 00000000..e6f3fd54 --- /dev/null +++ b/internal/session/sqlite_store.go @@ -0,0 +1,1179 @@ +package session + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + "time" + + sqlitedriver "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" + providertypes "neo-code/internal/provider/types" +) + +type sqliteSessionRow struct { + ID string + Title string + Provider string + Model string + CreatedAtMS int64 + UpdatedAtMS int64 + Workdir string + TaskStateJSON string + ActivatedJSON string + TodosJSON string + TokenInputTotal int + TokenOutputTotal int +} + +type sqliteMessageRow struct { + Role string + PartsJSON string + ToolCallsJSON string + ToolCallID string + IsError bool + ToolMetadataJSON string +} + +// SQLiteStore 使用单个工作区级 SQLite 数据库持久化会话。 +type SQLiteStore struct { + projectDir string + assetsDir string + dbPath string + + initMu sync.Mutex + db *sql.DB +} + +// Close 释放数据库连接,供测试和上层生命周期管理复用。 +func (s *SQLiteStore) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +// CreateSession 创建并持久化一个新的空会话头。 +func (s *SQLiteStore) CreateSession(ctx context.Context, input CreateSessionInput) (Session, error) { + if err := ctx.Err(); err != nil { + return Session{}, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return Session{}, err + } + + session, err := normalizeCreateSessionInput(input) + if err != nil { + return Session{}, err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return Session{}, fmt.Errorf("session: begin create session tx: %w", err) + } + defer rollbackTx(tx) + + _, err = tx.ExecContext(ctx, ` +INSERT INTO sessions ( + id, title, created_at_ms, updated_at_ms, provider, model, workdir, + task_state_json, todos_json, activated_skills_json, + token_input_total, token_output_total, last_seq, message_count +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 0, 0) +`, + session.ID, + session.Title, + toUnixMillis(session.CreatedAt), + toUnixMillis(session.UpdatedAt), + session.Provider, + session.Model, + session.Workdir, + mustJSONString(session.TaskState), + mustJSONString(session.Todos), + mustJSONString(session.ActivatedSkills), + session.TokenInputTotal, + session.TokenOutputTotal, + ) + if err != nil { + return Session{}, fmt.Errorf("session: insert session %s: %w", session.ID, err) + } + if err := tx.Commit(); err != nil { + return Session{}, fmt.Errorf("session: commit create session %s: %w", session.ID, err) + } + + return cloneSessionValue(session), nil +} + +// LoadSession 加载完整会话头和全部消息。 +func (s *SQLiteStore) LoadSession(ctx context.Context, id string) (Session, error) { + if err := ctx.Err(); err != nil { + return Session{}, err + } + if err := validateStorageID("session id", id); err != nil { + return Session{}, fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return Session{}, err + } + + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return Session{}, fmt.Errorf("session: begin load session tx: %w", err) + } + defer rollbackTx(tx) + + row, err := loadSessionRow(ctx, tx, stringsTrimSpace(id)) + if err != nil { + return Session{}, err + } + messages, err := loadMessages(ctx, tx, stringsTrimSpace(id)) + if err != nil { + return Session{}, err + } + session, err := buildSessionFromRow(row, messages) + if err != nil { + return Session{}, err + } + if err := tx.Commit(); err != nil { + return Session{}, fmt.Errorf("session: commit load session %s: %w", id, err) + } + return session, nil +} + +// ListSummaries 仅查询会话摘要元数据。 +func (s *SQLiteStore) ListSummaries(ctx context.Context) ([]Summary, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + db, err := s.ensureDB(ctx) + if err != nil { + return nil, err + } + + rows, err := db.QueryContext(ctx, ` +SELECT id, title, created_at_ms, updated_at_ms +FROM sessions +ORDER BY updated_at_ms DESC, id DESC +`) + if err != nil { + return nil, fmt.Errorf("session: list summaries: %w", err) + } + defer rows.Close() + + summaries := make([]Summary, 0) + for rows.Next() { + var summary Summary + var createdAtMS int64 + var updatedAtMS int64 + if err := rows.Scan(&summary.ID, &summary.Title, &createdAtMS, &updatedAtMS); err != nil { + return nil, fmt.Errorf("session: scan summary: %w", err) + } + summary.CreatedAt = fromUnixMillis(createdAtMS) + summary.UpdatedAt = fromUnixMillis(updatedAtMS) + summaries = append(summaries, summary) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("session: iterate summaries: %w", err) + } + return summaries, nil +} + +// AppendMessages 在单事务内追加消息并更新会话头增量字段。 +func (s *SQLiteStore) AppendMessages(ctx context.Context, input AppendMessagesInput) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", input.SessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + if len(input.Messages) == 0 { + return errors.New("session: append messages input is empty") + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + normalizedMessages, err := normalizeMessages(input.Messages) + if err != nil { + return err + } + updatedAt := resolveUpdatedAt(input.UpdatedAt) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin append messages tx: %w", err) + } + defer rollbackTx(tx) + + lastSeq, err := currentLastSeq(ctx, tx, input.SessionID) + if err != nil { + return err + } + + for _, message := range normalizedMessages { + lastSeq++ + if err := insertMessage(ctx, tx, input.SessionID, lastSeq, updatedAt, message); err != nil { + return err + } + } + + result, err := tx.ExecContext(ctx, ` +UPDATE sessions +SET updated_at_ms = ?, + provider = ?, + model = ?, + workdir = ?, + token_input_total = token_input_total + ?, + token_output_total = token_output_total + ?, + last_seq = ?, + message_count = message_count + ? +WHERE id = ? +`, + toUnixMillis(updatedAt), + stringsTrimSpace(input.Provider), + stringsTrimSpace(input.Model), + stringsTrimSpace(input.Workdir), + input.TokenInputDelta, + input.TokenOutputDelta, + lastSeq, + len(normalizedMessages), + input.SessionID, + ) + if err != nil { + return fmt.Errorf("session: update session after append %s: %w", input.SessionID, err) + } + if err := expectRowsAffected(result, input.SessionID); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit append messages %s: %w", input.SessionID, err) + } + return nil +} + +// UpdateSessionWorkdir 仅更新会话 workdir 与更新时间,避免 Prepare 阶段覆盖其他会话头字段。 +func (s *SQLiteStore) UpdateSessionWorkdir(ctx context.Context, input UpdateSessionWorkdirInput) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", input.SessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + result, err := db.ExecContext(ctx, ` +UPDATE sessions +SET updated_at_ms = ?, + workdir = ? +WHERE id = ? +`, + toUnixMillis(resolveUpdatedAt(input.UpdatedAt)), + stringsTrimSpace(input.Workdir), + input.SessionID, + ) + if err != nil { + return fmt.Errorf("session: update session workdir %s: %w", input.SessionID, err) + } + return expectRowsAffected(result, input.SessionID) +} + +// UpdateSessionState 仅更新会话头字段,不写入消息。 +func (s *SQLiteStore) UpdateSessionState(ctx context.Context, input UpdateSessionStateInput) error { + if err := ctx.Err(); err != nil { + return err + } + row, err := normalizeUpdateSessionStateInput(input) + if err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + result, err := db.ExecContext(ctx, ` +UPDATE sessions +SET title = ?, + updated_at_ms = ?, + provider = ?, + model = ?, + workdir = ?, + task_state_json = ?, + todos_json = ?, + activated_skills_json = ?, + token_input_total = ?, + token_output_total = ? +WHERE id = ? +`, + row.Title, + row.UpdatedAtMS, + row.Provider, + row.Model, + row.Workdir, + row.TaskStateJSON, + row.TodosJSON, + row.ActivatedJSON, + row.TokenInputTotal, + row.TokenOutputTotal, + row.ID, + ) + if err != nil { + return fmt.Errorf("session: update session state %s: %w", row.ID, err) + } + return expectRowsAffected(result, row.ID) +} + +// ReplaceTranscript 用于 compact 后整段 transcript 的原子替换。 +func (s *SQLiteStore) ReplaceTranscript(ctx context.Context, input ReplaceTranscriptInput) error { + if err := ctx.Err(); err != nil { + return err + } + row, messages, err := normalizeReplaceTranscriptInput(input) + if err != nil { + return err + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin replace transcript tx: %w", err) + } + defer rollbackTx(tx) + + if _, err := tx.ExecContext(ctx, `DELETE FROM messages WHERE session_id = ?`, row.ID); err != nil { + return fmt.Errorf("session: delete transcript %s: %w", row.ID, err) + } + if _, err := currentLastSeq(ctx, tx, row.ID); err != nil { + return err + } + + lastSeq := 0 + for _, message := range messages { + lastSeq++ + if err := insertMessage(ctx, tx, row.ID, lastSeq, fromUnixMillis(row.UpdatedAtMS), message); err != nil { + return err + } + } + + result, err := tx.ExecContext(ctx, ` +UPDATE sessions +SET updated_at_ms = ?, + provider = ?, + model = ?, + workdir = ?, + task_state_json = ?, + todos_json = ?, + activated_skills_json = ?, + token_input_total = ?, + token_output_total = ?, + last_seq = ?, + message_count = ? +WHERE id = ? +`, + row.UpdatedAtMS, + row.Provider, + row.Model, + row.Workdir, + row.TaskStateJSON, + row.TodosJSON, + row.ActivatedJSON, + row.TokenInputTotal, + row.TokenOutputTotal, + lastSeq, + len(messages), + row.ID, + ) + if err != nil { + return fmt.Errorf("session: update session during replace transcript %s: %w", row.ID, err) + } + if err := expectRowsAffected(result, row.ID); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit replace transcript %s: %w", row.ID, err) + } + return nil +} + +// SaveAsset 将附件二进制内容落盘,并将元数据写入数据库。 +func (s *SQLiteStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader, mimeType string) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + if r == nil { + return AssetMeta{}, errors.New("session: asset reader is nil") + } + if err := validateStorageID("session id", sessionID); err != nil { + return AssetMeta{}, fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return AssetMeta{}, err + } + + meta, err := newAssetMeta(mimeType) + if err != nil { + return AssetMeta{}, err + } + + assetDir := filepath.Join(s.assetsDir, sessionID) + if err := ensurePathWithinBase(s.projectDir, assetDir); err != nil { + return AssetMeta{}, fmt.Errorf("session: resolve assets dir path: %w", err) + } + if err := os.MkdirAll(assetDir, 0o755); err != nil { + return AssetMeta{}, fmt.Errorf("session: create assets dir: %w", err) + } + + tempFile, tempPath, err := createTempFile(assetDir, "asset-*.tmp", "create temp asset") + if err != nil { + return AssetMeta{}, err + } + + written, copyErr := io.Copy(tempFile, io.LimitReader(r, providertypes.MaxSessionAssetBytes+1)) + syncErr := tempFile.Sync() + closeErr := tempFile.Close() + if copyErr != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, fmt.Errorf("session: write temp asset: %w", copyErr) + } + if written > providertypes.MaxSessionAssetBytes { + _ = os.Remove(tempPath) + return AssetMeta{}, fmt.Errorf("session: asset size exceeds %d bytes", providertypes.MaxSessionAssetBytes) + } + if syncErr != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, fmt.Errorf("session: sync temp asset: %w", syncErr) + } + if closeErr != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, fmt.Errorf("session: close temp asset: %w", closeErr) + } + + target := filepath.Join(assetDir, meta.ID+".bin") + if err := ensurePathWithinBase(s.projectDir, target); err != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, fmt.Errorf("session: resolve asset file path: %w", err) + } + if err := replaceFileWithTemp(tempPath, target, "asset file"); err != nil { + _ = os.Remove(tempPath) + return AssetMeta{}, err + } + + meta.Size = written + relativePath, err := filepath.Rel(s.projectDir, target) + if err != nil { + _ = os.Remove(target) + return AssetMeta{}, fmt.Errorf("session: compute relative asset path: %w", err) + } + + result, err := db.ExecContext(ctx, ` +INSERT INTO session_assets (id, session_id, mime_type, size_bytes, relative_path, created_at_ms) +VALUES (?, ?, ?, ?, ?, ?) +`, + meta.ID, + sessionID, + meta.MimeType, + meta.Size, + filepath.ToSlash(relativePath), + toUnixMillis(time.Now()), + ) + if err != nil { + _ = os.Remove(target) + return AssetMeta{}, mapSessionAssetInsertError(meta.ID, err) + } + if err := expectRowsAffected(result, sessionID); err != nil { + _ = os.Remove(target) + return AssetMeta{}, err + } + return meta, nil +} + +// Open 读取指定会话附件的二进制内容与元数据。 +func (s *SQLiteStore) Open(ctx context.Context, sessionID string, assetID string) (io.ReadCloser, AssetMeta, error) { + if err := ctx.Err(); err != nil { + return nil, AssetMeta{}, err + } + meta, path, err := s.loadAssetMeta(ctx, sessionID, assetID) + if err != nil { + return nil, AssetMeta{}, err + } + file, err := os.Open(path) + if err != nil { + return nil, AssetMeta{}, err + } + return file, meta, nil +} + +// Stat 返回指定会话附件的元数据。 +func (s *SQLiteStore) Stat(ctx context.Context, sessionID string, assetID string) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + meta, _, err := s.loadAssetMeta(ctx, sessionID, assetID) + if err != nil { + return AssetMeta{}, err + } + return meta, nil +} + +// DeleteAsset 删除指定会话附件的元数据与二进制文件,缺失目标按幂等处理。 +func (s *SQLiteStore) DeleteAsset(ctx context.Context, sessionID string, assetID string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", sessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + if err := validateStorageID("asset id", assetID); err != nil { + return fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + meta, path, err := s.loadAssetMeta(ctx, sessionID, assetID) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + result, execErr := db.ExecContext(ctx, `DELETE FROM session_assets WHERE session_id = ? AND id = ?`, sessionID, assetID) + if execErr != nil { + return fmt.Errorf("session: delete asset meta %s: %w", assetID, execErr) + } + if affected, affErr := result.RowsAffected(); affErr == nil && affected == 0 && errors.Is(err, os.ErrNotExist) { + return nil + } + + if strings.TrimSpace(meta.ID) == "" { + return nil + } + if removeErr := os.Remove(path); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) { + return fmt.Errorf("session: delete asset file %s: %w", assetID, removeErr) + } + return nil +} + +// DeleteSession 删除会话头、消息、附件元数据,并清理对应附件目录。 +func (s *SQLiteStore) DeleteSession(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + if err := validateStorageID("session id", sessionID); err != nil { + return fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return err + } + + if _, err := db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, sessionID); err != nil { + return fmt.Errorf("session: delete session %s: %w", sessionID, err) + } + + assetDir := filepath.Join(s.assetsDir, sessionID) + if err := ensurePathWithinBase(s.projectDir, assetDir); err != nil { + return fmt.Errorf("session: resolve assets dir path: %w", err) + } + if err := os.RemoveAll(assetDir); err != nil { + return fmt.Errorf("session: delete assets dir %s: %w", sessionID, err) + } + return nil +} + +// ensureDB 懒加载数据库并执行 schema 初始化。 +func (s *SQLiteStore) ensureDB(ctx context.Context) (*sql.DB, error) { + s.initMu.Lock() + defer s.initMu.Unlock() + if s.db != nil { + return s.db, nil + } + if err := s.initialize(ctx); err != nil { + return nil, err + } + return s.db, nil +} + +// initialize 打开数据库、设置 PRAGMA 并初始化 schema。 +func (s *SQLiteStore) initialize(ctx context.Context) error { + if err := os.MkdirAll(s.projectDir, 0o755); err != nil { + return fmt.Errorf("session: create project dir: %w", err) + } + if err := os.MkdirAll(s.assetsDir, 0o755); err != nil { + return fmt.Errorf("session: create assets dir: %w", err) + } + + db, err := sql.Open("sqlite", s.dbPath) + if err != nil { + return fmt.Errorf("session: open sqlite db: %w", err) + } + db.SetMaxOpenConns(4) + db.SetMaxIdleConns(4) + + if err := applySQLitePragmas(ctx, db); err != nil { + _ = db.Close() + return err + } + if err := initializeSQLiteSchema(ctx, db); err != nil { + _ = db.Close() + return err + } + + s.db = db + return nil +} + +// loadAssetMeta 查询附件元数据并解析绝对路径。 +func (s *SQLiteStore) loadAssetMeta(ctx context.Context, sessionID string, assetID string) (AssetMeta, string, error) { + if err := validateStorageID("session id", sessionID); err != nil { + return AssetMeta{}, "", fmt.Errorf("session: %w", err) + } + if err := validateStorageID("asset id", assetID); err != nil { + return AssetMeta{}, "", fmt.Errorf("session: %w", err) + } + db, err := s.ensureDB(ctx) + if err != nil { + return AssetMeta{}, "", err + } + + var meta AssetMeta + var relativePath string + err = db.QueryRowContext(ctx, ` +SELECT mime_type, size_bytes, relative_path +FROM session_assets +WHERE session_id = ? AND id = ? +`, + sessionID, + assetID, + ).Scan(&meta.MimeType, &meta.Size, &relativePath) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return AssetMeta{}, "", os.ErrNotExist + } + return AssetMeta{}, "", fmt.Errorf("session: query asset meta %s: %w", assetID, err) + } + meta.ID = assetID + target := filepath.Join(s.projectDir, filepath.FromSlash(relativePath)) + if err := ensurePathWithinBase(s.projectDir, target); err != nil { + return AssetMeta{}, "", fmt.Errorf("session: resolve asset file path: %w", err) + } + return meta, target, nil +} + +// applySQLitePragmas 设置会话数据库的固定运行参数。 +func applySQLitePragmas(ctx context.Context, db *sql.DB) error { + pragmas := []string{ + `PRAGMA journal_mode=WAL`, + `PRAGMA synchronous=NORMAL`, + `PRAGMA foreign_keys=ON`, + `PRAGMA busy_timeout=5000`, + } + for _, pragma := range pragmas { + if _, err := db.ExecContext(ctx, pragma); err != nil { + return fmt.Errorf("session: apply pragma %q: %w", pragma, err) + } + } + return nil +} + +// initializeSQLiteSchema 初始化数据库 schema,并拒绝未知版本。 +func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { + var userVersion int + if err := db.QueryRowContext(ctx, `PRAGMA user_version`).Scan(&userVersion); err != nil { + return fmt.Errorf("session: read sqlite user_version: %w", err) + } + if userVersion != 0 && userVersion != sqliteSchemaVersion { + return fmt.Errorf("session: unsupported sqlite schema version %d", userVersion) + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin schema tx: %w", err) + } + defer rollbackTx(tx) + + statements := []string{ + `CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + created_at_ms INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL, + provider TEXT NOT NULL DEFAULT '', + model TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + task_state_json TEXT NOT NULL, + todos_json TEXT NOT NULL, + activated_skills_json TEXT NOT NULL, + token_input_total INTEGER NOT NULL DEFAULT 0, + token_output_total INTEGER NOT NULL DEFAULT 0, + last_seq INTEGER NOT NULL DEFAULT 0, + message_count INTEGER NOT NULL DEFAULT 0 + )`, + `CREATE TABLE IF NOT EXISTS messages ( + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + parts_json TEXT NOT NULL, + tool_calls_json TEXT NOT NULL DEFAULT '', + tool_call_id TEXT NOT NULL DEFAULT '', + is_error INTEGER NOT NULL DEFAULT 0, + tool_metadata_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + PRIMARY KEY(session_id, seq), + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE IF NOT EXISTS session_assets ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + mime_type TEXT NOT NULL, + size_bytes INTEGER NOT NULL, + relative_path TEXT NOT NULL, + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at_ms DESC)`, + `CREATE INDEX IF NOT EXISTS idx_messages_session_seq_desc ON messages(session_id, seq DESC)`, + `CREATE INDEX IF NOT EXISTS idx_assets_session_id ON session_assets(session_id)`, + fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion), + } + for _, statement := range statements { + if _, err := tx.ExecContext(ctx, statement); err != nil { + return fmt.Errorf("session: apply schema statement: %w", err) + } + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit schema tx: %w", err) + } + return nil +} + +// normalizeCreateSessionInput 规范化创建会话输入并生成最终会话头。 +func normalizeCreateSessionInput(input CreateSessionInput) (Session, error) { + session := Session{ + ID: stringsTrimSpace(input.ID), + Title: sanitizeTitle(input.Title), + Provider: stringsTrimSpace(input.Provider), + Model: stringsTrimSpace(input.Model), + CreatedAt: input.CreatedAt, + UpdatedAt: input.UpdatedAt, + Workdir: stringsTrimSpace(input.Workdir), + TaskState: normalizeAndClampTaskState(input.TaskState), + ActivatedSkills: normalizeSkillActivations(input.ActivatedSkills), + TokenInputTotal: input.TokenInputTotal, + TokenOutputTotal: input.TokenOutputTotal, + } + if session.ID == "" { + session.ID = NewID("session") + } + if err := validateStorageID("session id", session.ID); err != nil { + return Session{}, fmt.Errorf("session: %w", err) + } + now := time.Now() + if session.CreatedAt.IsZero() { + session.CreatedAt = now + } + if session.UpdatedAt.IsZero() { + session.UpdatedAt = session.CreatedAt + } + todos, err := normalizeAndValidateTodos(input.Todos) + if err != nil { + return Session{}, err + } + session.Todos = todos + if len(session.Todos) > 0 { + session.TodoVersion = CurrentTodoVersion + } + return session, nil +} + +// normalizeUpdateSessionStateInput 规范化会话头更新输入。 +func normalizeUpdateSessionStateInput(input UpdateSessionStateInput) (sqliteSessionRow, error) { + if err := validateStorageID("session id", input.SessionID); err != nil { + return sqliteSessionRow{}, fmt.Errorf("session: %w", err) + } + todos, err := normalizeAndValidateTodos(input.Todos) + if err != nil { + return sqliteSessionRow{}, err + } + return sqliteSessionRow{ + ID: stringsTrimSpace(input.SessionID), + Title: sanitizeTitle(input.Title), + Provider: stringsTrimSpace(input.Provider), + Model: stringsTrimSpace(input.Model), + UpdatedAtMS: toUnixMillis(resolveUpdatedAt(input.UpdatedAt)), + Workdir: stringsTrimSpace(input.Workdir), + TaskStateJSON: mustJSONString(normalizeAndClampTaskState(input.TaskState)), + TodosJSON: mustJSONString(todos), + ActivatedJSON: mustJSONString(normalizeSkillActivations(input.ActivatedSkills)), + TokenInputTotal: input.TokenInputTotal, + TokenOutputTotal: input.TokenOutputTotal, + }, nil +} + +// normalizeReplaceTranscriptInput 规范化 compact 后的 transcript 替换输入。 +func normalizeReplaceTranscriptInput(input ReplaceTranscriptInput) (sqliteSessionRow, []providertypes.Message, error) { + row, err := normalizeUpdateSessionStateInput(UpdateSessionStateInput{ + SessionID: input.SessionID, + Title: "", + UpdatedAt: input.UpdatedAt, + Provider: input.Provider, + Model: input.Model, + Workdir: input.Workdir, + TaskState: input.TaskState, + ActivatedSkills: input.ActivatedSkills, + Todos: input.Todos, + TokenInputTotal: input.TokenInputTotal, + TokenOutputTotal: input.TokenOutputTotal, + }) + if err != nil { + return sqliteSessionRow{}, nil, err + } + messages, err := normalizeMessages(input.Messages) + if err != nil { + return sqliteSessionRow{}, nil, err + } + return row, messages, nil +} + +// normalizeMessages 校验并深拷贝待持久化消息。 +func normalizeMessages(messages []providertypes.Message) ([]providertypes.Message, error) { + if len(messages) == 0 { + return nil, nil + } + cloned := make([]providertypes.Message, len(messages)) + for idx, message := range messages { + if err := providertypes.ValidateParts(message.Parts); err != nil { + return nil, fmt.Errorf("session: invalid message parts at index %d: %w", idx, err) + } + cloned[idx] = cloneMessage(message) + } + return cloned, nil +} + +// loadSessionRow 查询单条会话头记录。 +func loadSessionRow(ctx context.Context, tx *sql.Tx, sessionID string) (sqliteSessionRow, error) { + var row sqliteSessionRow + err := tx.QueryRowContext(ctx, ` +SELECT id, title, provider, model, created_at_ms, updated_at_ms, workdir, + task_state_json, activated_skills_json, todos_json, token_input_total, token_output_total +FROM sessions +WHERE id = ? +`, + sessionID, + ).Scan( + &row.ID, + &row.Title, + &row.Provider, + &row.Model, + &row.CreatedAtMS, + &row.UpdatedAtMS, + &row.Workdir, + &row.TaskStateJSON, + &row.ActivatedJSON, + &row.TodosJSON, + &row.TokenInputTotal, + &row.TokenOutputTotal, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return sqliteSessionRow{}, os.ErrNotExist + } + return sqliteSessionRow{}, fmt.Errorf("session: query session %s: %w", sessionID, err) + } + return row, nil +} + +// loadMessages 查询指定会话的全部消息并按顺序返回。 +func loadMessages(ctx context.Context, tx *sql.Tx, sessionID string) ([]sqliteMessageRow, error) { + rows, err := tx.QueryContext(ctx, ` +SELECT role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json +FROM messages +WHERE session_id = ? +ORDER BY seq ASC +`, + sessionID, + ) + if err != nil { + return nil, fmt.Errorf("session: query messages for %s: %w", sessionID, err) + } + defer rows.Close() + + messages := make([]sqliteMessageRow, 0) + for rows.Next() { + var row sqliteMessageRow + if err := rows.Scan( + &row.Role, + &row.PartsJSON, + &row.ToolCallsJSON, + &row.ToolCallID, + &row.IsError, + &row.ToolMetadataJSON, + ); err != nil { + return nil, fmt.Errorf("session: scan message row: %w", err) + } + messages = append(messages, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("session: iterate messages: %w", err) + } + return messages, nil +} + +// buildSessionFromRow 由数据库行构建完整会话对象。 +func buildSessionFromRow(row sqliteSessionRow, messages []sqliteMessageRow) (Session, error) { + var taskState TaskState + if err := json.Unmarshal([]byte(row.TaskStateJSON), &taskState); err != nil { + return Session{}, fmt.Errorf("session: decode task_state for %s: %w", row.ID, err) + } + var activated []SkillActivation + if err := json.Unmarshal([]byte(row.ActivatedJSON), &activated); err != nil { + return Session{}, fmt.Errorf("session: decode activated_skills for %s: %w", row.ID, err) + } + var todos []TodoItem + if err := json.Unmarshal([]byte(row.TodosJSON), &todos); err != nil { + return Session{}, fmt.Errorf("session: decode todos for %s: %w", row.ID, err) + } + normalizedTodos, err := normalizeAndValidateTodos(todos) + if err != nil { + return Session{}, err + } + + result := Session{ + ID: row.ID, + Title: row.Title, + Provider: row.Provider, + Model: row.Model, + CreatedAt: fromUnixMillis(row.CreatedAtMS), + UpdatedAt: fromUnixMillis(row.UpdatedAtMS), + Workdir: row.Workdir, + TaskState: normalizeAndClampTaskState(taskState), + ActivatedSkills: normalizeSkillActivations(activated), + Todos: normalizedTodos, + TokenInputTotal: row.TokenInputTotal, + TokenOutputTotal: row.TokenOutputTotal, + } + if len(result.Todos) > 0 { + result.TodoVersion = CurrentTodoVersion + } + + if len(messages) == 0 { + return result, nil + } + result.Messages = make([]providertypes.Message, 0, len(messages)) + for _, messageRow := range messages { + message, err := buildMessageFromRow(messageRow) + if err != nil { + return Session{}, err + } + result.Messages = append(result.Messages, message) + } + return result, nil +} + +// buildMessageFromRow 由数据库消息行恢复 provider 消息结构。 +func buildMessageFromRow(row sqliteMessageRow) (providertypes.Message, error) { + var parts []providertypes.ContentPart + if err := json.Unmarshal([]byte(row.PartsJSON), &parts); err != nil { + return providertypes.Message{}, fmt.Errorf("session: decode message parts: %w", err) + } + var toolCalls []providertypes.ToolCall + if row.ToolCallsJSON != "" { + if err := json.Unmarshal([]byte(row.ToolCallsJSON), &toolCalls); err != nil { + return providertypes.Message{}, fmt.Errorf("session: decode tool calls: %w", err) + } + } + var metadata map[string]string + if row.ToolMetadataJSON != "" { + if err := json.Unmarshal([]byte(row.ToolMetadataJSON), &metadata); err != nil { + return providertypes.Message{}, fmt.Errorf("session: decode tool metadata: %w", err) + } + } + return providertypes.Message{ + Role: row.Role, + Parts: parts, + ToolCalls: toolCalls, + ToolCallID: row.ToolCallID, + IsError: row.IsError, + ToolMetadata: metadata, + }, nil +} + +// currentLastSeq 读取当前会话的最后消息序号。 +func currentLastSeq(ctx context.Context, tx *sql.Tx, sessionID string) (int, error) { + var lastSeq int + err := tx.QueryRowContext(ctx, `SELECT last_seq FROM sessions WHERE id = ?`, sessionID).Scan(&lastSeq) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, os.ErrNotExist + } + return 0, fmt.Errorf("session: query last_seq for %s: %w", sessionID, err) + } + return lastSeq, nil +} + +// insertMessage 在事务内插入单条消息记录。 +func insertMessage( + ctx context.Context, + tx *sql.Tx, + sessionID string, + seq int, + createdAt time.Time, + message providertypes.Message, +) error { + result, err := tx.ExecContext(ctx, ` +INSERT INTO messages ( + session_id, seq, role, parts_json, tool_calls_json, tool_call_id, is_error, tool_metadata_json, created_at_ms +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) +`, + sessionID, + seq, + message.Role, + mustJSONString(message.Parts), + mustJSONString(message.ToolCalls), + message.ToolCallID, + boolToInt(message.IsError), + mustJSONString(message.ToolMetadata), + toUnixMillis(createdAt), + ) + if err != nil { + return fmt.Errorf("session: insert message %s/%d: %w", sessionID, seq, err) + } + return expectRowsAffected(result, sessionID) +} + +// expectRowsAffected 校验写操作是否命中会话记录。 +func expectRowsAffected(result sql.Result, sessionID string) error { + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("session: inspect rows affected for %s: %w", sessionID, err) + } + if rowsAffected == 0 { + return os.ErrNotExist + } + return nil +} + +// cloneMessage 深拷贝消息,避免共享底层切片和映射。 +// mapSessionAssetInsertError 统一收敛附件元数据插入阶段的缺失会话语义,避免向上泄漏底层 SQLite 错误。 +func mapSessionAssetInsertError(assetID string, err error) error { + if isSQLiteForeignKeyConstraintError(err) { + return fmt.Errorf("session: insert asset meta %s: %w", assetID, os.ErrNotExist) + } + return fmt.Errorf("session: insert asset meta %s: %w", assetID, err) +} + +// isSQLiteForeignKeyConstraintError 判断底层错误是否为 SQLite 外键约束失败。 +func isSQLiteForeignKeyConstraintError(err error) bool { + var sqliteErr *sqlitedriver.Error + if errors.As(err, &sqliteErr) { + return sqliteErr.Code() == sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY + } + return false +} + +func cloneMessage(message providertypes.Message) providertypes.Message { + next := message + next.Parts = providertypes.CloneParts(message.Parts) + next.ToolCalls = append([]providertypes.ToolCall(nil), message.ToolCalls...) + if len(message.ToolMetadata) > 0 { + next.ToolMetadata = make(map[string]string, len(message.ToolMetadata)) + for key, value := range message.ToolMetadata { + next.ToolMetadata[key] = value + } + } else { + next.ToolMetadata = nil + } + return next +} + +// cloneSessionValue 深拷贝会话值,确保调用方拿到独立副本。 +func cloneSessionValue(session Session) Session { + cloned := session + cloned.TaskState = session.TaskState.Clone() + cloned.ActivatedSkills = cloneSkillActivations(session.ActivatedSkills) + cloned.Todos = session.ListTodos() + if len(session.Messages) > 0 { + cloned.Messages = make([]providertypes.Message, len(session.Messages)) + for idx, message := range session.Messages { + cloned.Messages[idx] = cloneMessage(message) + } + } + return cloned +} + +// mustJSONString 将值编码为 JSON 字符串;调用方已保证输入可序列化。 +func mustJSONString(value any) string { + switch typed := value.(type) { + case nil: + return "[]" + case map[string]string: + if typed == nil { + return "{}" + } + case []providertypes.ToolCall: + if typed == nil { + return "[]" + } + } + data, err := json.Marshal(value) + if err != nil { + panic(err) + } + return string(data) +} + +// resolveUpdatedAt 统一为写入选择更新时间,缺省时使用当前时间。 +func resolveUpdatedAt(value time.Time) time.Time { + if value.IsZero() { + return time.Now() + } + return value +} + +// toUnixMillis 将时间转换为 UTC 毫秒时间戳。 +func toUnixMillis(value time.Time) int64 { + return value.UTC().UnixMilli() +} + +// fromUnixMillis 将毫秒时间戳还原为 UTC 时间。 +func fromUnixMillis(value int64) time.Time { + if value == 0 { + return time.Time{} + } + return time.UnixMilli(value).UTC() +} + +// boolToInt 将布尔值映射为 SQLite 整数布尔位。 +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + +// rollbackTx 在返回前吞掉回滚错误,避免覆盖主错误。 +func rollbackTx(tx *sql.Tx) { + if tx != nil { + _ = tx.Rollback() + } +} + +// stringsTrimSpace 集中收敛字符串字段的空白规范化。 +func stringsTrimSpace(value string) string { + return strings.TrimSpace(value) +} diff --git a/internal/session/sqlite_store_additional_test.go b/internal/session/sqlite_store_additional_test.go new file mode 100644 index 00000000..c7d0eb10 --- /dev/null +++ b/internal/session/sqlite_store_additional_test.go @@ -0,0 +1,241 @@ +package session + +import ( + "context" + "database/sql" + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" +) + +func TestSQLiteStoreMethodsRespectCanceledContext(t *testing.T) { + store := newTestStore(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "cancel_ctx", Title: "cancel"}); err == nil { + t.Fatalf("expected CreateSession canceled error") + } + if _, err := store.LoadSession(ctx, "cancel_ctx"); err == nil { + t.Fatalf("expected LoadSession canceled error") + } + if _, err := store.ListSummaries(ctx); err == nil { + t.Fatalf("expected ListSummaries canceled error") + } + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: "cancel_ctx", + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}}, + }, + }); err == nil { + t.Fatalf("expected AppendMessages canceled error") + } + if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{SessionID: "cancel_ctx", Title: "x"}); err == nil { + t.Fatalf("expected UpdateSessionState canceled error") + } + if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{SessionID: "cancel_ctx"}); err == nil { + t.Fatalf("expected ReplaceTranscript canceled error") + } +} + +func TestSQLiteStoreMethodsRejectInvalidSessionID(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + + if _, err := store.LoadSession(ctx, "bad/id"); err == nil { + t.Fatalf("expected LoadSession invalid id error") + } + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: "bad/id", + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}}, + }, + }); err == nil { + t.Fatalf("expected AppendMessages invalid id error") + } + if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{SessionID: "bad/id", Title: "x"}); err == nil { + t.Fatalf("expected UpdateSessionState invalid id error") + } + if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{SessionID: "bad/id"}); err == nil { + t.Fatalf("expected ReplaceTranscript invalid id error") + } +} + +func TestSQLiteHelperBranches(t *testing.T) { + if got, err := normalizeMessages(nil); err != nil || got != nil { + t.Fatalf("normalizeMessages(nil) = (%v, %v), want (nil, nil)", got, err) + } + if !fromUnixMillis(0).IsZero() { + t.Fatalf("fromUnixMillis(0) should return zero time") + } + if boolToInt(true) != 1 || boolToInt(false) != 0 { + t.Fatalf("boolToInt conversion mismatch") + } + + withoutMetadata := cloneMessage(providertypes.Message{Role: providertypes.RoleAssistant}) + if withoutMetadata.ToolMetadata != nil { + t.Fatalf("expected nil metadata clone for empty metadata input") + } + + original := Session{ + ID: "clone_test", + TaskState: TaskState{ + Goal: "goal", + }, + ActivatedSkills: []SkillActivation{{SkillID: "go-review"}}, + Todos: []TodoItem{{ID: "todo-1", Content: "a"}}, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("x")}, + ToolMetadata: map[string]string{"k": "v"}, + }, + }, + } + cloned := cloneSessionValue(original) + cloned.TaskState.Goal = "updated" + cloned.ActivatedSkills[0].SkillID = "other" + cloned.Todos[0].Content = "b" + cloned.Messages[0].ToolMetadata["k"] = "changed" + if original.TaskState.Goal != "goal" { + t.Fatalf("cloneSessionValue should deep-clone task state") + } + if original.ActivatedSkills[0].SkillID != "go-review" { + t.Fatalf("cloneSessionValue should deep-clone activated skills") + } + if original.Todos[0].Content != "a" { + t.Fatalf("cloneSessionValue should deep-clone todos") + } + if original.Messages[0].ToolMetadata["k"] != "v" { + t.Fatalf("cloneSessionValue should deep-clone message metadata") + } + + if got := mustJSONString(nil); got != "[]" { + t.Fatalf("mustJSONString(nil) = %q, want []", got) + } + var nilMap map[string]string + if got := mustJSONString(nilMap); got != "{}" { + t.Fatalf("mustJSONString(nil map) = %q, want {}", got) + } + var nilCalls []providertypes.ToolCall + if got := mustJSONString(nilCalls); got != "[]" { + t.Fatalf("mustJSONString(nil tool calls) = %q, want []", got) + } + + defer func() { + recovered := recover() + if recovered == nil { + t.Fatalf("mustJSONString should panic for unsupported value") + } + }() + _ = mustJSONString(func() {}) +} + +type fakeResult struct { + rows int64 + err error +} + +func (f fakeResult) LastInsertId() (int64, error) { + return 0, nil +} + +func (f fakeResult) RowsAffected() (int64, error) { + if f.err != nil { + return 0, f.err + } + return f.rows, nil +} + +func TestExpectRowsAffectedBranches(t *testing.T) { + if err := expectRowsAffected(fakeResult{err: errors.New("boom")}, "s1"); err == nil || !strings.Contains(err.Error(), "inspect rows affected") { + t.Fatalf("expected rows affected inspect error, got %v", err) + } + if err := expectRowsAffected(fakeResult{rows: 0}, "s1"); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist when rows=0, got %v", err) + } + if err := expectRowsAffected(fakeResult{rows: 1}, "s1"); err != nil { + t.Fatalf("expected rows=1 to pass, got %v", err) + } +} + +func TestStorageHelpersAdditionalErrorBranches(t *testing.T) { + baseDir := t.TempDir() + if err := ensurePathWithinBase(filepath.Join(baseDir, "missing"), filepath.Join(baseDir, "target")); err == nil { + t.Fatalf("expected ensurePathWithinBase to fail with missing base dir") + } + + missingParentTarget := filepath.Join(baseDir, "missing-parent", "child.txt") + if _, err := resolvePathForContainment(missingParentTarget); err == nil || !strings.Contains(err.Error(), "eval parent symlinks") { + t.Fatalf("expected parent symlink resolution error, got %v", err) + } + + tempFile, tempPath, err := createTempFile(baseDir, "replace-*.tmp", "create temp") + if err != nil { + t.Fatalf("createTempFile() error = %v", err) + } + if err := tempFile.Close(); err != nil { + t.Fatalf("tempFile.Close() error = %v", err) + } + targetDir := filepath.Join(baseDir, "target-dir") + if err := os.MkdirAll(targetDir, 0o755); err != nil { + t.Fatalf("MkdirAll(targetDir) error = %v", err) + } + if err := os.WriteFile(filepath.Join(targetDir, "existing.txt"), []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile(existing) error = %v", err) + } + if err := replaceFileWithTemp(tempPath, targetDir, "target-dir"); err == nil { + t.Fatalf("expected replaceFileWithTemp remove-target error") + } +} + +func TestInitializeSQLiteSchemaOnClosedDB(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "closed.db") + db, err := sql.Open("sqlite", dbPath) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + if err := db.Close(); err != nil { + t.Fatalf("db.Close() error = %v", err) + } + if err := initializeSQLiteSchema(context.Background(), db); err == nil { + t.Fatalf("expected initializeSQLiteSchema on closed db to fail") + } +} + +func TestNormalizeCreateSessionInputDefaultsGeneratedID(t *testing.T) { + t.Parallel() + + session, err := normalizeCreateSessionInput(CreateSessionInput{ + Title: " test ", + Todos: []TodoItem{{ID: "todo-1", Content: "a"}}, + }) + if err != nil { + t.Fatalf("normalizeCreateSessionInput() error = %v", err) + } + if session.ID == "" || !strings.HasPrefix(session.ID, "session_") { + t.Fatalf("expected generated session id, got %q", session.ID) + } + if session.CreatedAt.IsZero() || session.UpdatedAt.IsZero() { + t.Fatalf("expected default timestamps to be populated") + } + if session.TodoVersion != CurrentTodoVersion { + t.Fatalf("expected todo version %d, got %d", CurrentTodoVersion, session.TodoVersion) + } +} + +func TestResolveUpdatedAtReturnsProvidedValue(t *testing.T) { + t.Parallel() + + provided := time.Now().UTC().Add(-time.Minute).Truncate(time.Millisecond) + if got := resolveUpdatedAt(provided); !got.Equal(provided) { + t.Fatalf("resolveUpdatedAt should keep non-zero value, got %v want %v", got, provided) + } +} diff --git a/internal/session/storage_helpers.go b/internal/session/storage_helpers.go new file mode 100644 index 00000000..637ff6d3 --- /dev/null +++ b/internal/session/storage_helpers.go @@ -0,0 +1,77 @@ +package session + +import ( + "errors" + "fmt" + "os" + "path/filepath" +) + +// ensurePathWithinBase 校验目标路径位于指定基目录内,避免路径越界。 +func ensurePathWithinBase(baseDir string, target string) error { + baseResolved, err := resolvePathForContainment(baseDir) + if err != nil { + return fmt.Errorf("resolve base dir %q: %w", baseDir, err) + } + targetResolved, err := resolvePathForContainment(target) + if err != nil { + return fmt.Errorf("resolve target path %q: %w", target, err) + } + rel, err := filepath.Rel(baseResolved, targetResolved) + if err != nil { + return fmt.Errorf("compute relative path %q -> %q: %w", baseResolved, targetResolved, err) + } + if rel == "." { + return nil + } + if !filepath.IsLocal(rel) { + return fmt.Errorf("target path %q escapes base dir %q", targetResolved, baseResolved) + } + return nil +} + +// resolvePathForContainment 将路径归一化为绝对路径并解析软链接,确保包含性校验基于真实路径。 +func resolvePathForContainment(path string) (string, error) { + absPath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("resolve absolute path: %w", err) + } + resolved, err := filepath.EvalSymlinks(absPath) + if err == nil { + return resolved, nil + } + if !errors.Is(err, os.ErrNotExist) { + return "", fmt.Errorf("eval symlinks: %w", err) + } + parent := filepath.Dir(absPath) + resolvedParent, parentErr := filepath.EvalSymlinks(parent) + if parentErr != nil { + return "", fmt.Errorf("eval parent symlinks: %w", parentErr) + } + return filepath.Join(resolvedParent, filepath.Base(absPath)), nil +} + +// createTempFile 在目标目录中创建唯一临时文件。 +func createTempFile(dir string, pattern string, op string) (*os.File, string, error) { + file, err := os.CreateTemp(dir, pattern) + if err != nil { + return nil, "", fmt.Errorf("session: %s: %w", op, err) + } + if err := ensurePathWithinBase(dir, file.Name()); err != nil { + _ = file.Close() + _ = os.Remove(file.Name()) + return nil, "", fmt.Errorf("session: %s: %w", op, err) + } + return file, file.Name(), nil +} + +// replaceFileWithTemp 使用原子重命名替换目标文件。 +func replaceFileWithTemp(tempPath string, target string, label string) error { + if err := os.Remove(target); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("session: replace %s: %w", label, err) + } + if err := os.Rename(tempPath, target); err != nil { + return fmt.Errorf("session: commit %s: %w", label, err) + } + return nil +} diff --git a/internal/session/store.go b/internal/session/store.go index e6a52d8d..75e755f0 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -2,492 +2,136 @@ package session import ( "context" - "encoding/json" - "errors" "fmt" - "io" - "os" - "path/filepath" "regexp" - "sort" "strings" - "sync" "time" providertypes "neo-code/internal/provider/types" ) -const sessionsDirName = "sessions" - const ( - sessionFileName = "session.json" - assetsDirName = "assets" + sessionDatabaseFileName = "session.db" + assetsDirName = "assets" + sqliteSchemaVersion = 1 ) var storageIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,127}$`) -// Session 表示单个会话的持久化模型,包含基础元数据与消息历史。 -// Provider / Model 用于在 compact 等流程中优先复用会话最近一次成功运行的模型配置。 +// Session 表示单个会话的运行态与持久化聚合模型。 type Session struct { - SchemaVersion int `json:"schema_version"` - ID string `json:"id"` - Title string `json:"title"` - // Provider 记录最近一次成功运行会话时使用的 provider,用于 compact 优先复用历史配置。 - Provider string `json:"provider,omitempty"` - // Model 记录最近一次成功运行会话时使用的 model,用于 compact 优先复用历史配置。 - Model string `json:"model,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Workdir string `json:"workdir,omitempty"` - TaskState TaskState `json:"task_state"` - ActivatedSkills []SkillActivation `json:"activated_skills,omitempty"` - TodoVersion int `json:"todo_version,omitempty"` - Todos []TodoItem `json:"todos,omitempty"` - Messages []providertypes.Message `json:"messages"` - TokenInputTotal int `json:"token_input_total,omitempty"` - TokenOutputTotal int `json:"token_output_total,omitempty"` -} - -// Summary 表示会话列表视图所需的轻量摘要信息。 + ID string + Title string + Provider string + Model string + CreatedAt time.Time + UpdatedAt time.Time + Workdir string + TaskState TaskState + ActivatedSkills []SkillActivation + TodoVersion int + Todos []TodoItem + Messages []providertypes.Message + TokenInputTotal int + TokenOutputTotal int +} + +// Summary 表示会话列表视图需要的轻量摘要。 type Summary struct { - ID string `json:"id"` - Title string `json:"title"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// Store 定义会话持久化抽象。 + ID string + Title string + CreatedAt time.Time + UpdatedAt time.Time +} + +// CreateSessionInput 描述新建空会话头时需要写入的字段。 +type CreateSessionInput struct { + ID string + Title string + CreatedAt time.Time + UpdatedAt time.Time + Provider string + Model string + Workdir string + TaskState TaskState + ActivatedSkills []SkillActivation + Todos []TodoItem + TokenInputTotal int + TokenOutputTotal int +} + +// AppendMessagesInput 描述一次原子追加消息及会话头增量更新。 +type AppendMessagesInput struct { + SessionID string + Messages []providertypes.Message + UpdatedAt time.Time + Provider string + Model string + Workdir string + TokenInputDelta int + TokenOutputDelta int +} + +// UpdateSessionStateInput 描述一次只更新会话头状态的写入。 +type UpdateSessionStateInput struct { + SessionID string + Title string + UpdatedAt time.Time + Provider string + Model string + Workdir string + TaskState TaskState + ActivatedSkills []SkillActivation + Todos []TodoItem + TokenInputTotal int + TokenOutputTotal int +} + +// UpdateSessionWorkdirInput 描述一次仅更新会话 workdir 的最小粒度写入。 +type UpdateSessionWorkdirInput struct { + SessionID string + UpdatedAt time.Time + Workdir string +} + +// ReplaceTranscriptInput 描述 compact 后整段 transcript 的原子替换。 +type ReplaceTranscriptInput struct { + SessionID string + Messages []providertypes.Message + UpdatedAt time.Time + Provider string + Model string + Workdir string + TaskState TaskState + ActivatedSkills []SkillActivation + Todos []TodoItem + TokenInputTotal int + TokenOutputTotal int +} + +// Store 定义会话持久化的意图型接口。 type Store interface { - Save(ctx context.Context, session *Session) error - Load(ctx context.Context, id string) (Session, error) + CreateSession(ctx context.Context, input CreateSessionInput) (Session, error) + LoadSession(ctx context.Context, id string) (Session, error) ListSummaries(ctx context.Context) ([]Summary, error) - DeleteSession(ctx context.Context, id string) error -} - -// JSONStore 是基于 JSON 文件的会话存储实现。 -type JSONStore struct { - mu sync.RWMutex - baseDir string -} - -// contextReader 在读取前检查上下文取消状态,避免长时间 I/O 无法及时退出。 -type contextReader struct { - ctx context.Context - reader io.Reader -} - -func (r *contextReader) Read(p []byte) (int, error) { - if r == nil || r.reader == nil { - return 0, io.EOF - } - if r.ctx != nil { - if err := r.ctx.Err(); err != nil { - return 0, err - } - } - return r.reader.Read(p) -} - -func contextDone(ctx context.Context) error { - if ctx == nil { - return nil - } - if err := ctx.Err(); err != nil { - return err - } - return nil -} - -// NewJSONStore 创建 JSONStore,实际会话目录为 {baseDir}/sessions。 -func NewJSONStore(baseDir string, workspaceRoot string) *JSONStore { - return &JSONStore{ - baseDir: sessionDirectory(baseDir, workspaceRoot), - } -} - -// NewStore 返回默认会话存储实现(当前为 JSONStore)。 -func NewStore(baseDir string, workspaceRoot string) *JSONStore { - return NewJSONStore(baseDir, workspaceRoot) -} - -// Save 持久化会话到 JSON 文件,采用临时文件 + 原子替换策略。 -func (s *JSONStore) Save(ctx context.Context, session *Session) error { - if err := ctx.Err(); err != nil { - return err - } - if session == nil { - return errors.New("session: session is nil") - } - if err := validateSessionSchema(*session); err != nil { - return err - } - if err := validateStorageID("session id", session.ID); err != nil { - return fmt.Errorf("session: %w", err) - } - - session.TaskState = normalizeAndClampTaskState(session.TaskState) - session.ActivatedSkills = normalizeSkillActivations(session.ActivatedSkills) - normalizedTodos, err := normalizeAndValidateTodos(session.Todos) - if err != nil { - return err - } - session.Todos = normalizedTodos - if len(session.Todos) > 0 && session.TodoVersion <= 0 { - session.TodoVersion = CurrentTodoVersion - } - - s.mu.Lock() - defer s.mu.Unlock() - - if err := os.MkdirAll(s.baseDir, 0o755); err != nil { - return fmt.Errorf("session: create sessions dir: %w", err) - } - - payload, err := json.MarshalIndent(session, "", " ") - if err != nil { - return fmt.Errorf("session: marshal session: %w", err) - } - payload = append(payload, '\n') - - target := s.filePath(session.ID) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return fmt.Errorf("session: resolve session file path: %w", err) - } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { - return fmt.Errorf("session: create session dir: %w", err) - } - if err := writeFileAtomically(target, "session-*.tmp", payload, 0o644); err != nil { - return err - } - - return nil -} - -// Load 读取并反序列化指定 ID 的会话文件。 -func (s *JSONStore) Load(ctx context.Context, id string) (Session, error) { - if err := ctx.Err(); err != nil { - return Session{}, err - } - - if err := validateStorageID("session id", id); err != nil { - return Session{}, fmt.Errorf("session: %w", err) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - target := s.filePath(id) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return Session{}, fmt.Errorf("session: resolve session file path: %w", err) - } - - data, err := os.ReadFile(target) - if err != nil { - return Session{}, err - } - - session, err := decodeStoredSession(data) - if err != nil { - return Session{}, fmt.Errorf("session: decode session %s: %w", id, err) - } - return session, nil -} - -// ListSummaries 列出所有会话摘要,并按 UpdatedAt 倒序返回。 -func (s *JSONStore) ListSummaries(ctx context.Context) ([]Summary, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if err := os.MkdirAll(s.baseDir, 0o755); err != nil { - return nil, fmt.Errorf("session: create sessions dir: %w", err) - } - - entries, err := os.ReadDir(s.baseDir) - if err != nil { - return nil, fmt.Errorf("session: list sessions dir: %w", err) - } - - summaries := make([]Summary, 0, len(entries)) - for _, entry := range entries { - if !entry.IsDir() { - continue - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - target := filepath.Join(s.baseDir, entry.Name(), sessionFileName) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - continue - } - - data, readErr := os.ReadFile(target) - if readErr != nil { - continue - } - - summary, err := decodeStoredSummary(data) - if err != nil { - continue - } - if strings.TrimSpace(summary.ID) == "" { - continue - } - summaries = append(summaries, summary) - } - - sort.Slice(summaries, func(i, j int) bool { - return summaries[i].UpdatedAt.After(summaries[j].UpdatedAt) - }) - - return summaries, nil -} - -// DeleteSession 删除指定会话目录及其附件,供创建后失败回滚等场景复用。 -func (s *JSONStore) DeleteSession(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - if err := validateStorageID("session id", id); err != nil { - return fmt.Errorf("session: %w", err) - } - - s.mu.Lock() - defer s.mu.Unlock() - - target := s.sessionDir(id) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return fmt.Errorf("session: resolve session dir path: %w", err) - } - if err := os.RemoveAll(target); err != nil { - return fmt.Errorf("session: delete session dir: %w", err) - } - return nil -} - -// filePath 生成会话 ID 对应的 JSON 文件路径。 -func (s *JSONStore) filePath(id string) string { - return filepath.Join(s.sessionDir(id), sessionFileName) -} - -// sessionDir 返回指定会话在当前工作区分桶下的目录路径。 -func (s *JSONStore) sessionDir(id string) string { - return filepath.Join(s.baseDir, id) -} - -// assetsDir 返回指定会话附件目录路径。 -func (s *JSONStore) assetsDir(sessionID string) string { - return filepath.Join(s.sessionDir(sessionID), assetsDirName) -} - -// assetPath 返回指定会话附件二进制文件路径。 -func (s *JSONStore) assetPath(sessionID string, assetID string) string { - return filepath.Join(s.assetsDir(sessionID), assetID+".bin") -} - -// assetMetaPath 返回指定会话附件元数据文件路径。 -func (s *JSONStore) assetMetaPath(sessionID string, assetID string) string { - return filepath.Join(s.assetsDir(sessionID), assetID+".json") -} - -// SaveAsset 将会话附件二进制内容写入当前工作区会话目录,并返回附件元数据。 -func (s *JSONStore) SaveAsset(ctx context.Context, sessionID string, r io.Reader, mimeType string) (AssetMeta, error) { - if err := contextDone(ctx); err != nil { - return AssetMeta{}, err - } - if r == nil { - return AssetMeta{}, errors.New("session: asset reader is nil") - } - if err := validateStorageID("session id", sessionID); err != nil { - return AssetMeta{}, fmt.Errorf("session: %w", err) - } - - meta, err := newAssetMeta(mimeType) - if err != nil { - return AssetMeta{}, err - } - - s.mu.Lock() - defer s.mu.Unlock() - - assetDir := s.assetsDir(sessionID) - if err := ensurePathWithinBase(s.baseDir, assetDir); err != nil { - return AssetMeta{}, fmt.Errorf("session: resolve assets dir path: %w", err) - } - if err := os.MkdirAll(assetDir, 0o755); err != nil { - return AssetMeta{}, fmt.Errorf("session: create assets dir: %w", err) - } - if err := contextDone(ctx); err != nil { - return AssetMeta{}, err - } - - target := s.assetPath(sessionID, meta.ID) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return AssetMeta{}, fmt.Errorf("session: resolve asset file path: %w", err) - } - tempFile, tempPath, err := createTempFile(assetDir, "asset-*.tmp", "create temp asset") - if err != nil { - return AssetMeta{}, err - } - - limitedReader := io.LimitReader(&contextReader{ctx: ctx, reader: r}, providertypes.MaxSessionAssetBytes+1) - written, copyErr := io.Copy(tempFile, limitedReader) - syncErr := tempFile.Sync() - closeErr := tempFile.Close() - if ctxErr := contextDone(ctx); ctxErr != nil { - _ = os.Remove(tempPath) - return AssetMeta{}, ctxErr - } - if copyErr != nil { - _ = os.Remove(tempPath) - return AssetMeta{}, fmt.Errorf("session: write temp asset: %w", copyErr) - } - if written > providertypes.MaxSessionAssetBytes { - _ = os.Remove(tempPath) - return AssetMeta{}, fmt.Errorf("session: asset size exceeds %d bytes", providertypes.MaxSessionAssetBytes) - } - if syncErr != nil { - _ = os.Remove(tempPath) - return AssetMeta{}, fmt.Errorf("session: sync temp asset: %w", syncErr) - } - if closeErr != nil { - _ = os.Remove(tempPath) - return AssetMeta{}, fmt.Errorf("session: close temp asset: %w", closeErr) - } - - meta.Size = written - if err := replaceFileWithTemp(tempPath, target, "asset file"); err != nil { - _ = os.Remove(tempPath) - return AssetMeta{}, err - } - if err := contextDone(ctx); err != nil { - _ = os.Remove(target) - return AssetMeta{}, err - } - - metaData, err := encodeStoredAssetMeta(meta) - if err != nil { - _ = os.Remove(target) - return AssetMeta{}, err - } - metaTarget := s.assetMetaPath(sessionID, meta.ID) - if err := ensurePathWithinBase(s.baseDir, metaTarget); err != nil { - _ = os.Remove(target) - return AssetMeta{}, fmt.Errorf("session: resolve asset meta file path: %w", err) - } - if err := writeFileAtomically(metaTarget, "asset-meta-*.tmp", metaData, 0o644); err != nil { - _ = os.Remove(target) - return AssetMeta{}, err - } - if err := contextDone(ctx); err != nil { - _ = os.Remove(target) - _ = os.Remove(metaTarget) - return AssetMeta{}, err - } - - return meta, nil + AppendMessages(ctx context.Context, input AppendMessagesInput) error + UpdateSessionWorkdir(ctx context.Context, input UpdateSessionWorkdirInput) error + UpdateSessionState(ctx context.Context, input UpdateSessionStateInput) error + ReplaceTranscript(ctx context.Context, input ReplaceTranscriptInput) error } -// Open 读取会话附件二进制内容并返回可关闭流与附件元数据。 -func (s *JSONStore) Open(ctx context.Context, sessionID string, assetID string) (io.ReadCloser, AssetMeta, error) { - if err := ctx.Err(); err != nil { - return nil, AssetMeta{}, err - } - if err := validateStorageID("session id", sessionID); err != nil { - return nil, AssetMeta{}, fmt.Errorf("session: %w", err) - } - if err := validateStorageID("asset id", assetID); err != nil { - return nil, AssetMeta{}, fmt.Errorf("session: %w", err) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - meta, err := s.statUnlocked(sessionID, assetID) - if err != nil { - return nil, AssetMeta{}, err - } - - target := s.assetPath(sessionID, assetID) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return nil, AssetMeta{}, fmt.Errorf("session: resolve asset file path: %w", err) - } - file, err := os.Open(target) - if err != nil { - return nil, AssetMeta{}, err +// NewSQLiteStore 创建基于 SQLite 的会话存储实现。 +func NewSQLiteStore(baseDir string, workspaceRoot string) *SQLiteStore { + return &SQLiteStore{ + projectDir: projectDirectory(baseDir, workspaceRoot), + assetsDir: assetsDirectory(baseDir, workspaceRoot), + dbPath: databasePath(baseDir, workspaceRoot), } - return file, meta, nil } -// Stat 返回会话附件的元数据而不读取实际内容。 -func (s *JSONStore) Stat(ctx context.Context, sessionID string, assetID string) (AssetMeta, error) { - if err := ctx.Err(); err != nil { - return AssetMeta{}, err - } - if err := validateStorageID("session id", sessionID); err != nil { - return AssetMeta{}, fmt.Errorf("session: %w", err) - } - if err := validateStorageID("asset id", assetID); err != nil { - return AssetMeta{}, fmt.Errorf("session: %w", err) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - return s.statUnlocked(sessionID, assetID) -} - -// DeleteAsset 删除指定会话附件的二进制与元数据文件,用于输入归一化失败后的清理。 -func (s *JSONStore) DeleteAsset(ctx context.Context, sessionID string, assetID string) error { - if err := ctx.Err(); err != nil { - return err - } - if err := validateStorageID("session id", sessionID); err != nil { - return fmt.Errorf("session: %w", err) - } - if err := validateStorageID("asset id", assetID); err != nil { - return fmt.Errorf("session: %w", err) - } - - s.mu.Lock() - defer s.mu.Unlock() - - target := s.assetPath(sessionID, assetID) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return fmt.Errorf("session: resolve asset file path: %w", err) - } - if err := os.Remove(target); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("session: delete asset file: %w", err) - } - - metaTarget := s.assetMetaPath(sessionID, assetID) - if err := ensurePathWithinBase(s.baseDir, metaTarget); err != nil { - return fmt.Errorf("session: resolve asset meta file path: %w", err) - } - if err := os.Remove(metaTarget); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("session: delete asset meta file: %w", err) - } - return nil -} - -// statUnlocked 在调用方已持有读锁时读取附件元数据,避免重复加锁导致死锁风险。 -func (s *JSONStore) statUnlocked(sessionID string, assetID string) (AssetMeta, error) { - target := s.assetMetaPath(sessionID, assetID) - if err := ensurePathWithinBase(s.baseDir, target); err != nil { - return AssetMeta{}, fmt.Errorf("session: resolve asset meta file path: %w", err) - } - data, err := os.ReadFile(target) - if err != nil { - return AssetMeta{}, err - } - return decodeStoredAssetMeta(data) +// NewStore 返回默认会话存储实现。 +func NewStore(baseDir string, workspaceRoot string) *SQLiteStore { + return NewSQLiteStore(baseDir, workspaceRoot) } // New 创建一个默认标题策略的新会话对象。 @@ -495,11 +139,10 @@ func New(title string) Session { return NewWithWorkdir(title, "") } -// NewWithWorkdir 创建一个包含运行目录的会话对象。 +// NewWithWorkdir 创建一个带运行目录的会话对象。 func NewWithWorkdir(title string, workdir string) Session { now := time.Now() return Session{ - SchemaVersion: CurrentSchemaVersion, ID: NewID("session"), Title: sanitizeTitle(title), CreatedAt: now, @@ -512,7 +155,7 @@ func NewWithWorkdir(title string, workdir string) Session { } } -// sanitizeTitle 规范化会话标题:去空白、空标题回退默认值、超长截断。 +// sanitizeTitle 规范化会话标题,保证空标题和超长标题都有稳定表现。 func sanitizeTitle(title string) string { title = strings.TrimSpace(title) if title == "" { @@ -525,122 +168,7 @@ func sanitizeTitle(title string) string { return title } -// validateSessionSchema 校验会话持久化版本,开发阶段只接受当前结构版本。 -func validateSessionSchema(session Session) error { - if session.SchemaVersion != CurrentSchemaVersion { - return fmt.Errorf( - "session: unsupported schema_version %d, expected %d", - session.SchemaVersion, - CurrentSchemaVersion, - ) - } - return nil -} - -// decodeStoredSession 严格校验持久化会话所需字段,并拒绝缺少 schema_version 或 task_state 的旧数据。 -func decodeStoredSession(data []byte) (Session, error) { - type storedSession struct { - SchemaVersion *int `json:"schema_version"` - ID string `json:"id"` - Title string `json:"title"` - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Workdir string `json:"workdir,omitempty"` - TaskState *TaskState `json:"task_state"` - ActivatedSkills []SkillActivation `json:"activated_skills,omitempty"` - TodoVersion *int `json:"todo_version,omitempty"` - Todos []TodoItem `json:"todos,omitempty"` - Messages []providertypes.Message `json:"messages"` - TokenInput int `json:"token_input_total,omitempty"` - TokenOutput int `json:"token_output_total,omitempty"` - } - - var stored storedSession - if err := json.Unmarshal(data, &stored); err != nil { - return Session{}, err - } - - if stored.SchemaVersion == nil { - return Session{}, errors.New("missing required field schema_version") - } - if stored.TaskState == nil { - return Session{}, errors.New("missing required field task_state") - } - - session := Session{ - SchemaVersion: *stored.SchemaVersion, - ID: stored.ID, - Title: stored.Title, - Provider: stored.Provider, - Model: stored.Model, - CreatedAt: stored.CreatedAt, - UpdatedAt: stored.UpdatedAt, - Workdir: stored.Workdir, - TaskState: *stored.TaskState, - ActivatedSkills: stored.ActivatedSkills, - TodoVersion: 0, - Todos: stored.Todos, - Messages: stored.Messages, - TokenInputTotal: stored.TokenInput, - TokenOutputTotal: stored.TokenOutput, - } - if stored.TodoVersion != nil { - session.TodoVersion = *stored.TodoVersion - } - if err := validateSessionSchema(session); err != nil { - return Session{}, err - } - session.TaskState = normalizeAndClampTaskState(session.TaskState) - session.ActivatedSkills = normalizeSkillActivations(session.ActivatedSkills) - normalizedTodos, err := normalizeAndValidateTodos(session.Todos) - if err != nil { - return Session{}, err - } - session.Todos = normalizedTodos - if len(session.Todos) > 0 && session.TodoVersion <= 0 { - session.TodoVersion = CurrentTodoVersion - } - return session, nil -} - -// normalizeAndClampTaskState 先规范化再限幅,保证持久化前后的 task_state 行为一致。 -func normalizeAndClampTaskState(state TaskState) TaskState { - return ClampTaskStateBoundaries(NormalizeTaskState(state)) -} - -// decodeStoredSummary 只解析会话列表所需的摘要元数据,避免为列表视图反序列化完整消息历史。 -func decodeStoredSummary(data []byte) (Summary, error) { - var stored struct { - SchemaVersion *int `json:"schema_version"` - ID string `json:"id"` - Title string `json:"title"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - TaskState json.RawMessage `json:"task_state"` - } - if err := json.Unmarshal(data, &stored); err != nil { - return Summary{}, err - } - if stored.SchemaVersion == nil { - return Summary{}, errors.New("missing required field schema_version") - } - if len(stored.TaskState) == 0 { - return Summary{}, errors.New("missing required field task_state") - } - if err := validateSessionSchema(Session{SchemaVersion: *stored.SchemaVersion}); err != nil { - return Summary{}, err - } - return Summary{ - ID: stored.ID, - Title: stored.Title, - CreatedAt: stored.CreatedAt, - UpdatedAt: stored.UpdatedAt, - }, nil -} - -// validateStorageID 校验会话/附件 ID,避免路径穿越和非法文件名。 +// validateStorageID 校验会话和附件 ID,避免路径穿越与非法文件名。 func validateStorageID(label string, id string) error { trimmed := strings.TrimSpace(id) if trimmed == "" { @@ -651,85 +179,3 @@ func validateStorageID(label string, id string) error { } return nil } - -// ensurePathWithinBase 校验目标路径在给定基目录内,作为 ID 白名单之外的二次路径约束。 -func ensurePathWithinBase(baseDir string, target string) error { - baseAbs, err := filepath.Abs(baseDir) - if err != nil { - return fmt.Errorf("resolve base dir %q: %w", baseDir, err) - } - targetAbs, err := filepath.Abs(target) - if err != nil { - return fmt.Errorf("resolve target path %q: %w", target, err) - } - rel, err := filepath.Rel(baseAbs, targetAbs) - if err != nil { - return fmt.Errorf("compute relative path %q -> %q: %w", baseAbs, targetAbs, err) - } - if rel == "." { - return nil - } - if !filepath.IsLocal(rel) { - return fmt.Errorf("target path %q escapes base dir %q", targetAbs, baseAbs) - } - return nil -} - -// createTempFile 在目标目录创建唯一临时文件,避免固定 *.tmp 命名在并发场景下冲突。 -func createTempFile(dir string, pattern string, op string) (*os.File, string, error) { - file, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, "", fmt.Errorf("session: %s: %w", op, err) - } - if err := ensurePathWithinBase(dir, file.Name()); err != nil { - _ = file.Close() - _ = os.Remove(file.Name()) - return nil, "", fmt.Errorf("session: %s: %w", op, err) - } - return file, file.Name(), nil -} - -// replaceFileWithTemp 使用原子重命名替换目标文件,兼容 Windows 需要先删除旧文件的行为。 -func replaceFileWithTemp(tempPath string, target string, label string) error { - if err := os.Remove(target); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("session: replace %s: %w", label, err) - } - if err := os.Rename(tempPath, target); err != nil { - return fmt.Errorf("session: commit %s: %w", label, err) - } - return nil -} - -// writeFileAtomically 将字节数据写入唯一临时文件并原子替换目标文件,避免中间态文件暴露。 -func writeFileAtomically(target string, tempPattern string, payload []byte, perm os.FileMode) error { - dir := filepath.Dir(target) - tempFile, tempPath, err := createTempFile(dir, tempPattern, "create temp file") - if err != nil { - return err - } - - _, writeErr := tempFile.Write(payload) - syncErr := tempFile.Sync() - closeErr := tempFile.Close() - if writeErr != nil { - _ = os.Remove(tempPath) - return fmt.Errorf("session: write temp file: %w", writeErr) - } - if syncErr != nil { - _ = os.Remove(tempPath) - return fmt.Errorf("session: sync temp file: %w", syncErr) - } - if closeErr != nil { - _ = os.Remove(tempPath) - return fmt.Errorf("session: close temp file: %w", closeErr) - } - if err := os.Chmod(tempPath, perm); err != nil { - _ = os.Remove(tempPath) - return fmt.Errorf("session: chmod temp file: %w", err) - } - if err := replaceFileWithTemp(tempPath, target, "file"); err != nil { - _ = os.Remove(tempPath) - return err - } - return nil -} diff --git a/internal/session/store_test.go b/internal/session/store_test.go index 54d92baf..e38193b5 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -2,11 +2,10 @@ package session import ( "context" - "encoding/json" + "database/sql" "errors" "os" "path/filepath" - goruntime "runtime" "strings" "testing" "time" @@ -14,1256 +13,468 @@ import ( providertypes "neo-code/internal/provider/types" ) -func TestJSONStoreSaveLoadAndListSummaries(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := filepath.Join(t.TempDir(), "workspace") - if err := os.MkdirAll(workspaceRoot, 0o755); err != nil { - t.Fatalf("mkdir workspace root: %v", err) +func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + createdAt := time.Now().Add(-2 * time.Minute).UTC().Truncate(time.Millisecond) + updatedAt := createdAt.Add(time.Minute) + + session, err := store.CreateSession(ctx, CreateSessionInput{ + ID: "session_roundtrip", + Title: " Session Roundtrip ", + CreatedAt: createdAt, + UpdatedAt: updatedAt, + Provider: "openai", + Model: "gpt-5", + Workdir: "/repo", + TaskState: TaskState{ + Goal: "ship sqlite migration", + Progress: []string{"draft plan"}, + }, + ActivatedSkills: []SkillActivation{{SkillID: "go_review"}, {SkillID: "go-review"}}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "implement store"}, + }, + TokenInputTotal: 11, + TokenOutputTotal: 7, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + if session.ID != "session_roundtrip" || session.Title != "Session Roundtrip" { + t.Fatalf("unexpected created session: %+v", session) } - store := NewJSONStore(baseDir, workspaceRoot) - older := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "session-old", - Title: "Old Session", - CreatedAt: time.Now().Add(-2 * time.Hour), - UpdatedAt: time.Now().Add(-1 * time.Hour), + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, Messages: []providertypes.Message{ - {Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, - {Role: "assistant", Parts: []providertypes.ContentPart{providertypes.NewTextPart("world")}}, + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }, + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("world"), + }, + ToolCalls: []providertypes.ToolCall{{ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}}, + }, }, - } - newer := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "session-new", - Title: "New Session", - CreatedAt: time.Now().Add(-30 * time.Minute), - UpdatedAt: time.Now(), - Workdir: t.TempDir(), - Messages: []providertypes.Message{ - {Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("new")}}, + UpdatedAt: updatedAt.Add(time.Minute), + Provider: "openai", + Model: "gpt-5.1", + Workdir: "/repo/subdir", + TokenInputDelta: 3, + TokenOutputDelta: 5, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) + } + + if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{ + SessionID: session.ID, + Title: "SQLite Ready", + UpdatedAt: updatedAt.Add(2 * time.Minute), + Provider: "openai", + Model: "gpt-5.1", + Workdir: "/repo/final", + TaskState: TaskState{ + Goal: "ship sqlite migration", + Progress: []string{"draft plan", "replace store"}, + UserConstraints: []string{"no legacy compatibility"}, }, + ActivatedSkills: []SkillActivation{{SkillID: "go-review"}}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "implement store", Status: TodoStatusInProgress}, + }, + TokenInputTotal: 99, + TokenOutputTotal: 42, + }); err != nil { + t.Fatalf("UpdateSessionState() error = %v", err) } - if err := store.Save(context.Background(), older); err != nil { - t.Fatalf("Save older session: %v", err) - } - if err := store.Save(context.Background(), newer); err != nil { - t.Fatalf("Save newer session: %v", err) - } - - loaded, err := store.Load(context.Background(), older.ID) + loaded, err := store.LoadSession(ctx, session.ID) if err != nil { - t.Fatalf("Load() error: %v", err) - } - if loaded.Title != older.Title { - t.Fatalf("expected title %q, got %q", older.Title, loaded.Title) + t.Fatalf("LoadSession() error = %v", err) } - if loaded.Workdir != older.Workdir { - t.Fatalf("expected persisted workdir %q, got %q", older.Workdir, loaded.Workdir) + if loaded.Title != "SQLite Ready" || loaded.Workdir != "/repo/final" { + t.Fatalf("unexpected loaded header: %+v", loaded) } - if len(loaded.Messages) != 2 || renderPartsForTest(loaded.Messages[1].Parts) != "world" { - t.Fatalf("unexpected loaded messages: %+v", loaded.Messages) + if loaded.Provider != "openai" || loaded.Model != "gpt-5.1" { + t.Fatalf("unexpected provider/model: %+v", loaded) } - - rawPath := sessionFilePathForTest(baseDir, workspaceRoot, newer.ID) - raw, err := os.ReadFile(rawPath) - if err != nil { - t.Fatalf("read saved session: %v", err) + if loaded.TokenInputTotal != 99 || loaded.TokenOutputTotal != 42 { + t.Fatalf("unexpected token totals: in=%d out=%d", loaded.TokenInputTotal, loaded.TokenOutputTotal) } - if !strings.Contains(string(raw), "\"workdir\"") { - t.Fatalf("expected persisted session file to include workdir, got:\n%s", string(raw)) + if got := loaded.ActiveSkillIDs(); len(got) != 1 || got[0] != "go-review" { + t.Fatalf("unexpected active skills: %+v", got) } - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "invalid"), "{invalid") - if err := os.MkdirAll(filepath.Join(sessionDirectory(baseDir, workspaceRoot), "directory"), 0o755); err != nil { - t.Fatalf("mkdir stray directory: %v", err) + if len(loaded.Todos) != 1 || loaded.Todos[0].Status != TodoStatusInProgress { + t.Fatalf("unexpected todos: %+v", loaded.Todos) } - - summaries, err := store.ListSummaries(context.Background()) - if err != nil { - t.Fatalf("ListSummaries() error: %v", err) + if len(loaded.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(loaded.Messages)) } - if len(summaries) != 2 { - t.Fatalf("expected 2 summaries, got %d", len(summaries)) + if renderSessionMessageParts(loaded.Messages[0]) != "hello" || renderSessionMessageParts(loaded.Messages[1]) != "world" { + t.Fatalf("unexpected messages: %+v", loaded.Messages) } - if summaries[0].ID != newer.ID || summaries[1].ID != older.ID { - t.Fatalf("expected summaries sorted by UpdatedAt desc, got %+v", summaries) + if len(loaded.Messages[1].ToolCalls) != 1 || loaded.Messages[1].ToolCalls[0].ID != "call-1" { + t.Fatalf("unexpected tool calls: %+v", loaded.Messages[1].ToolCalls) } } -func TestJSONStoreScopesSessionsByWorkspaceRoot(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceA := filepath.Join(t.TempDir(), "中文项目A") - workspaceB := filepath.Join(t.TempDir(), "中文项目B") - if err := os.MkdirAll(workspaceA, 0o755); err != nil { - t.Fatalf("mkdir workspaceA: %v", err) - } - if err := os.MkdirAll(workspaceB, 0o755); err != nil { - t.Fatalf("mkdir workspaceB: %v", err) - } - - storeA := NewJSONStore(baseDir, workspaceA) - storeB := NewJSONStore(baseDir, workspaceB) - - sessionA := &Session{SchemaVersion: CurrentSchemaVersion, ID: "session-a", Title: "A", CreatedAt: time.Now(), UpdatedAt: time.Now()} - sessionB := &Session{SchemaVersion: CurrentSchemaVersion, ID: "session-b", Title: "B", CreatedAt: time.Now(), UpdatedAt: time.Now()} - if err := storeA.Save(context.Background(), sessionA); err != nil { - t.Fatalf("save sessionA: %v", err) - } - if err := storeB.Save(context.Background(), sessionB); err != nil { - t.Fatalf("save sessionB: %v", err) - } - - summariesA, err := storeA.ListSummaries(context.Background()) +func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) { + ctx := context.Background() + baseDir, err := os.MkdirTemp("", "session-base-") if err != nil { - t.Fatalf("ListSummaries() for storeA error: %v", err) + t.Fatalf("MkdirTemp() baseDir error = %v", err) } - if len(summariesA) != 1 || summariesA[0].ID != sessionA.ID { - t.Fatalf("expected storeA to only list sessionA, got %+v", summariesA) - } - - summariesB, err := storeB.ListSummaries(context.Background()) + workspaceRoot, err := os.MkdirTemp("", "session-workspace-") if err != nil { - t.Fatalf("ListSummaries() for storeB error: %v", err) - } - if len(summariesB) != 1 || summariesB[0].ID != sessionB.ID { - t.Fatalf("expected storeB to only list sessionB, got %+v", summariesB) - } - - if _, err := storeA.Load(context.Background(), sessionB.ID); err == nil { - t.Fatalf("expected storeA to fail loading session from another workspace bucket") - } -} - -func TestHashWorkspaceRootNormalizesChinesePathVariants(t *testing.T) { - t.Parallel() - - base := filepath.Join(t.TempDir(), "中文项目") - if err := os.MkdirAll(base, 0o755); err != nil { - t.Fatalf("mkdir base: %v", err) + t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } + store := NewStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = store.Close() + _ = os.RemoveAll(baseDir) + _ = os.RemoveAll(workspaceRoot) + }) - normalized := NormalizeWorkspaceRoot(base) - slashVariant := strings.ReplaceAll(normalized, `\`, `/`) - if got, want := HashWorkspaceRoot(normalized), HashWorkspaceRoot(slashVariant); got != want { - t.Fatalf("expected slash variants to hash equally, got %q and %q", got, want) + legacyPath := filepath.Join(projectDirectory(baseDir, workspaceRoot), "sessions", "legacy", "session.json") + if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil { + t.Fatalf("mkdir legacy path: %v", err) } - - upperVariant := strings.ToUpper(normalized) - lowerVariant := strings.ToLower(normalized) - gotCaseUpper := HashWorkspaceRoot(upperVariant) - gotCaseLower := HashWorkspaceRoot(lowerVariant) - if goruntime.GOOS == "windows" { - if gotCaseUpper != gotCaseLower { - t.Fatalf("expected case variants to hash equally on windows, got %q and %q", gotCaseUpper, gotCaseLower) - } - } else { - if gotCaseUpper == gotCaseLower { - t.Fatalf("expected case variants to hash differently on case-sensitive platforms, got %q", gotCaseUpper) - } + if err := os.WriteFile(legacyPath, []byte(`{"id":"legacy"}`), 0o644); err != nil { + t.Fatalf("write legacy file: %v", err) } -} - -func TestWorkspaceHelpersHandleEmptyAndRelativePath(t *testing.T) { - t.Parallel() - if got := WorkspacePathKey(" "); got != "" { - t.Fatalf("expected empty workspace key, got %q", got) + firstTime := time.Now().Add(-2 * time.Hour).UTC() + secondTime := firstTime.Add(time.Hour) + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "s1", Title: "Older", CreatedAt: firstTime, UpdatedAt: firstTime}); err != nil { + t.Fatalf("CreateSession(s1) error = %v", err) } - if got := NormalizeWorkspaceRoot(" "); got != "" { - t.Fatalf("expected empty normalized workspace root, got %q", got) + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "s2", Title: "Newer", CreatedAt: secondTime, UpdatedAt: secondTime}); err != nil { + t.Fatalf("CreateSession(s2) error = %v", err) } - workingDir, err := os.Getwd() + summaries, err := store.ListSummaries(ctx) if err != nil { - t.Fatalf("getwd: %v", err) - } - relative := "." - normalized := NormalizeWorkspaceRoot(relative) - if normalized != filepath.Clean(workingDir) { - t.Fatalf("expected relative path to normalize to %q, got %q", filepath.Clean(workingDir), normalized) - } - - if got, want := HashWorkspaceRoot(""), HashWorkspaceRoot(" "); got != want { - t.Fatalf("expected empty workspace root variants to share fallback hash, got %q want %q", got, want) - } -} - -func TestJSONStoreErrors(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - store := NewJSONStore(baseDir, t.TempDir()) - - cancelledCtx, cancel := context.WithCancel(context.Background()) - cancel() - - if err := store.Save(cancelledCtx, &Session{ID: "x"}); err == nil { - t.Fatalf("expected cancelled save to fail") - } - if err := store.Save(context.Background(), nil); err == nil { - t.Fatalf("expected nil session save to fail") + t.Fatalf("ListSummaries() error = %v", err) } - if _, err := store.Load(cancelledCtx, "missing"); err == nil { - t.Fatalf("expected cancelled load to fail") - } - if _, err := store.ListSummaries(cancelledCtx); err == nil { - t.Fatalf("expected cancelled list to fail") + if len(summaries) != 2 { + t.Fatalf("expected 2 summaries, got %d", len(summaries)) } - if _, err := store.Load(context.Background(), " "); err == nil || !strings.Contains(err.Error(), "session id is empty") { - t.Fatalf("expected empty session id load error, got %v", err) + if summaries[0].ID != "s2" || summaries[1].ID != "s1" { + t.Fatalf("unexpected summary order: %+v", summaries) } } -func TestJSONStoreCorruptedSessionBehaviors(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - valid := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "valid-session", - Title: "Valid Session", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now(), - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - } - if err := store.Save(context.Background(), valid); err != nil { - t.Fatalf("Save valid session: %v", err) - } - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "broken"), "{broken") - - _, err := store.Load(context.Background(), "broken") - if err == nil || !strings.Contains(err.Error(), "decode session broken") { - t.Fatalf("expected corrupted session decode error, got %v", err) - } - - summaries, err := store.ListSummaries(context.Background()) +func TestSQLiteStoreReplaceTranscriptAndPragmas(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "replace_me", Title: "replace me"}) if err != nil { - t.Fatalf("ListSummaries() error: %v", err) - } - if len(summaries) != 1 || summaries[0].ID != valid.ID { - t.Fatalf("expected corrupted session file to be skipped, got %+v", summaries) + t.Fatalf("CreateSession() error = %v", err) } -} - -func TestJSONStoreSaveInvalidBaseDir(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - baseFile := filepath.Join(tempDir, "not-a-directory") - if err := os.WriteFile(baseFile, []byte("x"), 0o644); err != nil { - t.Fatalf("write base file: %v", err) + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("before")}}, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("before-response")}}, + }, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) } - store := NewJSONStore(baseFile, t.TempDir()) - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "session-x", - Title: "Broken Save", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - if err == nil || !strings.Contains(err.Error(), "create sessions dir") { - t.Fatalf("expected invalid base dir error, got %v", err) + if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{ + SessionID: session.ID, + UpdatedAt: time.Now().UTC(), + Provider: "openai", + Model: "gpt-5.2", + Workdir: "/repo", + TaskState: TaskState{Goal: "after compact"}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "after compact"}, + }, + Messages: []providertypes.Message{ + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("after")}}, + }, + TokenInputTotal: 0, + TokenOutputTotal: 0, + }); err != nil { + t.Fatalf("ReplaceTranscript() error = %v", err) } -} - -func TestJSONStoreSaveReplaceFailureWhenTargetIsNonEmptyDirectory(t *testing.T) { - t.Parallel() - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - targetDir := sessionFilePathForTest(baseDir, workspaceRoot, "blocked") - if err := os.MkdirAll(targetDir, 0o755); err != nil { - t.Fatalf("mkdir target dir: %v", err) - } - if err := os.WriteFile(filepath.Join(targetDir, "child.txt"), []byte("x"), 0o644); err != nil { - t.Fatalf("write child file: %v", err) + loaded, err := store.LoadSession(ctx, session.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) } - - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "blocked", - Title: "Blocked", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - if err == nil || !strings.Contains(err.Error(), "replace file") { - t.Fatalf("expected replace failure, got %v", err) + if loaded.Title != "replace me" { + t.Fatalf("expected title to be preserved after replace, got %q", loaded.Title) } -} - -func TestJSONStoreSaveOverwritesExistingSessionFile(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "overwrite", - Title: "First", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now().Add(-time.Minute), - } - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save initial session: %v", err) + if len(loaded.Messages) != 1 || renderSessionMessageParts(loaded.Messages[0]) != "after" { + t.Fatalf("unexpected messages after replace: %+v", loaded.Messages) } - - session.Title = "Second" - session.UpdatedAt = time.Now() - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save updated session: %v", err) + if loaded.TaskState.Goal != "after compact" { + t.Fatalf("unexpected task state after replace: %+v", loaded.TaskState) } - loaded, err := store.Load(context.Background(), session.ID) + db, err := store.ensureDB(ctx) if err != nil { - t.Fatalf("load updated session: %v", err) - } - if loaded.Title != "Second" { - t.Fatalf("expected overwritten session title %q, got %q", "Second", loaded.Title) + t.Fatalf("ensureDB() error = %v", err) } + assertPragmaString(t, db, "journal_mode", "wal") + assertPragmaInt(t, db, "foreign_keys", 1) + assertPragmaInt(t, db, "busy_timeout", 5000) + assertPragmaInt(t, db, "user_version", sqliteSchemaVersion) } -func TestJSONStoreSaveWriteTempFailure(t *testing.T) { - t.Parallel() - if goruntime.GOOS == "windows" { - t.Skip("chmod write permission behavior is platform-specific on windows") +func TestSQLiteStoreAppendMessagesRollbackOnTriggerFailure(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "rollback_me", Title: "rollback"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) } - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - sessionDir := filepath.Join(sessionDirectory(baseDir, workspaceRoot), "temp-blocked") - if err := os.MkdirAll(sessionDir, 0o755); err != nil { - t.Fatalf("mkdir session dir: %v", err) + db, err := store.ensureDB(ctx) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) } - if err := os.Chmod(sessionDir, 0o555); err != nil { - t.Fatalf("chmod session dir readonly: %v", err) + if _, err := db.ExecContext(ctx, ` +CREATE TRIGGER fail_second_insert +BEFORE INSERT ON messages +WHEN NEW.seq = 2 +BEGIN + SELECT RAISE(ABORT, 'boom'); +END +`); err != nil { + t.Fatalf("create trigger: %v", err) } - t.Cleanup(func() { - _ = os.Chmod(sessionDir, 0o755) - }) - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "temp-blocked", - Title: "Temp Blocked", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + err = store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("one")}}, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("two")}}, + }, }) - if err == nil || !strings.Contains(err.Error(), "create temp file") { - t.Fatalf("expected temp write failure, got %v", err) - } -} - -func TestEnsurePathWithinBaseRejectsEscapedPath(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - escaped := filepath.Join(baseDir, "..", "escaped", "session.json") - - if err := ensurePathWithinBase(baseDir, escaped); err == nil || !strings.Contains(err.Error(), "escapes base dir") { - t.Fatalf("expected escaped path rejection, got %v", err) - } -} - -func TestJSONStoreLoadMissingFileReturnsError(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - if _, err := store.Load(context.Background(), "missing"); err == nil { - t.Fatalf("expected missing file load to fail") - } -} - -func TestJSONStoreLoadRejectsInvalidSessionID(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - if _, err := store.Load(context.Background(), "bad/id"); err == nil || !strings.Contains(err.Error(), "unsupported characters") { - t.Fatalf("expected invalid session id error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("AppendMessages() err = %v, want trigger failure", err) } -} - -func TestNewUsesDefaultWorkdirAndEmptyMessages(t *testing.T) { - t.Parallel() - session := New("hello title") - - if session.ID == "" { - t.Fatalf("expected non-empty id") - } - if !strings.HasPrefix(session.ID, "session_") { - t.Fatalf("expected id with session_ prefix, got %q", session.ID) - } - if session.SchemaVersion != CurrentSchemaVersion { - t.Fatalf("expected schema version %d, got %d", CurrentSchemaVersion, session.SchemaVersion) - } - if session.Title != "hello title" { - t.Fatalf("expected title %q, got %q", "hello title", session.Title) - } - if session.Workdir != "" { - t.Fatalf("expected empty workdir, got %q", session.Workdir) - } - if len(session.Messages) != 0 { - t.Fatalf("expected empty messages, got %+v", session.Messages) - } - if session.TaskState.Established() { - t.Fatalf("expected empty task state, got %+v", session.TaskState) - } - if session.CreatedAt.IsZero() || session.UpdatedAt.IsZero() { - t.Fatalf("expected non-zero timestamps, got created=%v updated=%v", session.CreatedAt, session.UpdatedAt) + loaded, err := store.LoadSession(ctx, session.ID) + if err != nil { + t.Fatalf("LoadSession() error = %v", err) } - if session.UpdatedAt.Before(session.CreatedAt) { - t.Fatalf("expected UpdatedAt >= CreatedAt, got created=%v updated=%v", session.CreatedAt, session.UpdatedAt) + if len(loaded.Messages) != 0 { + t.Fatalf("expected rollback to leave zero messages, got %+v", loaded.Messages) } } -func TestNewWithWorkdirTrimAndTitleSanitize(t *testing.T) { - t.Parallel() - - tooLong := strings.Repeat("测", 45) - workdir := " /tmp/workdir " +func TestSQLiteStoreErrors(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) - session := NewWithWorkdir(tooLong, workdir) - - if session.Workdir != "/tmp/workdir" { - t.Fatalf("expected trimmed workdir %q, got %q", "/tmp/workdir", session.Workdir) + if _, err := store.CreateSession(ctx, CreateSessionInput{ID: "bad/id", Title: "x"}); err == nil { + t.Fatalf("expected invalid create session id error") } - if got := len([]rune(session.Title)); got != 40 { - t.Fatalf("expected title rune length 40, got %d (title=%q)", got, session.Title) + if err := store.AppendMessages(ctx, AppendMessagesInput{SessionID: "missing"}); err == nil { + t.Fatalf("expected append empty messages error") } -} - -func TestNewWithWorkdirFallsBackDefaultTitle(t *testing.T) { - t.Parallel() - - session := NewWithWorkdir(" \n\t ", "") - - if session.Title != "New Session" { - t.Fatalf("expected default title %q, got %q", "New Session", session.Title) + if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{SessionID: "missing", Title: "x"}); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected update missing session to return os.ErrNotExist, got %v", err) } -} - -func TestNewStoreReturnsJSONStore(t *testing.T) { - t.Parallel() - - store := NewStore(t.TempDir(), t.TempDir()) - if store == nil { - t.Fatalf("expected non-nil store") + if _, err := store.LoadSession(ctx, "missing"); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected load missing session to return os.ErrNotExist, got %v", err) } } -func TestJSONStoreListSummariesReadDirFailure(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - sessionsPath := sessionDirectory(baseDir, workspaceRoot) - mustWriteSessionFile(t, sessionsPath, "not-a-dir") - - _, err := store.ListSummaries(context.Background()) - if err == nil || !strings.Contains(err.Error(), "create sessions dir") { - t.Fatalf("expected create sessions dir error, got %v", err) - } -} - -func TestJSONStoreListSummariesContextCanceledDuringIteration(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - store := NewJSONStore(baseDir, t.TempDir()) - - for i := 0; i < 10; i++ { - s := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "session-iter-" + strings.Repeat("x", i+1), - Title: "iter", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - if err := store.Save(context.Background(), s); err != nil { - t.Fatalf("save session %d: %v", i, err) - } - } - - ctx, cancel := context.WithCancel(context.Background()) +func TestSQLiteStoreEnsureDBCanRetryAfterInitFailure(t *testing.T) { + store := newTestStore(t) + canceledCtx, cancel := context.WithCancel(context.Background()) cancel() - _, err := store.ListSummaries(ctx) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context canceled, got %v", err) + if _, err := store.ensureDB(canceledCtx); err == nil { + t.Fatalf("expected ensureDB() with canceled context to fail") } -} - -func TestJSONStoreLoadDecodeErrorWithNonJSONPayload(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "decode-bad"), "{not-json") - - _, err := store.Load(context.Background(), "decode-bad") - if err == nil || !strings.Contains(err.Error(), "decode session decode-bad") { - t.Fatalf("expected decode session error, got %v", err) - } -} - -func TestJSONStoreLoadRejectsMissingSchemaVersion(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile( - t, - sessionFilePathForTest(baseDir, workspaceRoot, "missing-schema"), - `{"id":"missing-schema","title":"x","task_state":{"goal":"","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[],"last_updated_at":"0001-01-01T00:00:00Z"},"messages":[]}`, - ) - - _, err := store.Load(context.Background(), "missing-schema") - if err == nil || !strings.Contains(err.Error(), "missing required field schema_version") { - t.Fatalf("expected missing schema_version rejection, got %v", err) - } -} - -func TestJSONStoreLoadRejectsMissingTaskState(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile( - t, - sessionFilePathForTest(baseDir, workspaceRoot, "missing-task-state"), - `{"schema_version":2,"id":"missing-task-state","title":"x","messages":[]}`, - ) - - _, err := store.Load(context.Background(), "missing-task-state") - if err == nil || !strings.Contains(err.Error(), "missing required field task_state") { - t.Fatalf("expected missing task_state rejection, got %v", err) - } -} - -func TestJSONStoreListSummariesSkipsUnreadableAndMalformedEntries(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - valid := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "valid-summary", - Title: "Valid", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now(), - } - if err := store.Save(context.Background(), valid); err != nil { - t.Fatalf("save valid session: %v", err) - } - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "malformed"), "{malformed") - mustWriteSessionFile( - t, - sessionFilePathForTest(baseDir, workspaceRoot, "empty-id"), - `{"schema_version":2,"id":" ","title":"x","created_at":"2026-04-13T00:00:00Z","updated_at":"2026-04-13T00:00:00Z","task_state":{"goal":"","progress":[],"open_items":[],"next_step":"","blockers":[],"key_artifacts":[],"decisions":[],"user_constraints":[],"last_updated_at":"2026-04-13T00:00:00Z"}}`, - ) - mustWriteSessionFile( - t, - sessionFilePathForTest(baseDir, workspaceRoot, "missing-task-state-summary"), - `{"schema_version":2,"id":"missing-task-state-summary","title":"x","created_at":"2026-04-13T00:00:00Z","updated_at":"2026-04-13T00:00:00Z"}`, - ) - - summaries, err := store.ListSummaries(context.Background()) + db, err := store.ensureDB(context.Background()) if err != nil { - t.Fatalf("ListSummaries() error: %v", err) - } - if len(summaries) != 1 || summaries[0].ID != valid.ID { - t.Fatalf("expected only valid summary, got %+v", summaries) - } -} - -func TestJSONStoreSaveRejectsInvalidSchemaAndID(t *testing.T) { - t.Parallel() - - store := NewJSONStore(t.TempDir(), t.TempDir()) - - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion + 1, - ID: "session-invalid-schema", - Title: "Invalid Schema", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - if err == nil || !strings.Contains(err.Error(), "unsupported schema_version") { - t.Fatalf("expected schema version error, got %v", err) + t.Fatalf("ensureDB() retry with healthy context error = %v", err) } - - err = store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "bad/session", - Title: "Invalid ID", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - if err == nil || !strings.Contains(err.Error(), "unsupported characters") { - t.Fatalf("expected invalid id error, got %v", err) - } -} - -func TestJSONStoreSaveCreateSessionDirFailure(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - sessionID := "session-dir-failed" - - sessionDirPath := filepath.Join(sessionDirectory(baseDir, workspaceRoot), sessionID) - if err := os.MkdirAll(filepath.Dir(sessionDirPath), 0o755); err != nil { - t.Fatalf("mkdir session parent: %v", err) - } - if err := os.WriteFile(sessionDirPath, []byte("blocked"), 0o644); err != nil { - t.Fatalf("write session dir blocker: %v", err) - } - - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: sessionID, - Title: "Blocked SessionDir", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - if err == nil || !strings.Contains(err.Error(), "create session dir") { - t.Fatalf("expected create session dir error, got %v", err) - } -} - -func TestDecodeStoredSessionAndSummarySchemaValidation(t *testing.T) { - t.Parallel() - - _, err := decodeStoredSession([]byte(`{ - "schema_version": 999, - "id": "decode-invalid-schema", - "title": "x", - "created_at": "2026-04-13T08:00:00Z", - "updated_at": "2026-04-13T08:00:00Z", - "task_state": { - "goal": "", - "progress": [], - "open_items": [], - "next_step": "", - "blockers": [], - "key_artifacts": [], - "decisions": [], - "user_constraints": [], - "last_updated_at": "2026-04-13T08:00:00Z" - }, - "messages": [] -}`)) - if err == nil || !strings.Contains(err.Error(), "unsupported schema_version") { - t.Fatalf("expected decodeStoredSession schema error, got %v", err) - } - - _, err = decodeStoredSummary([]byte(`{ - "schema_version": 999, - "id": "summary-invalid-schema", - "title": "Summary", - "created_at": "2026-04-13T08:00:00Z", - "updated_at": "2026-04-13T08:00:00Z", - "task_state": { - "goal": "", - "progress": [], - "open_items": [], - "next_step": "", - "blockers": [], - "key_artifacts": [], - "decisions": [], - "user_constraints": [], - "last_updated_at": "2026-04-13T08:00:00Z" - } -}`)) - if err == nil || !strings.Contains(err.Error(), "unsupported schema_version") { - t.Fatalf("expected decodeStoredSummary schema error, got %v", err) + if db == nil { + t.Fatalf("expected ensureDB() retry to return non-nil db") } } -func TestJSONStoreSavePersistsProviderModelAndMessages(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) +func TestSQLiteStoreLoadSessionRejectsCorruptHeaderAndMessageData(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "persist-full-fields", - Title: "Persist Fields", - Provider: "openai", - Model: "gpt-4.1", - Workdir: "/tmp/persist-workdir", - CreatedAt: time.Now().Add(-time.Hour), - UpdatedAt: time.Now(), - Messages: []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, - { - Role: providertypes.RoleAssistant, - Parts: []providertypes.ContentPart{providertypes.NewTextPart("calling tool")}, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "webfetch", Arguments: `{"url":"https://example.com"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}, - ToolMetadata: map[string]string{ - "tool_name": "webfetch", - "http_status": "200", - }, - }, - }, - } - - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save session: %v", err) - } - - rawPath := sessionFilePathForTest(baseDir, workspaceRoot, session.ID) - raw, err := os.ReadFile(rawPath) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "corrupt_header", Title: "header"}) if err != nil { - t.Fatalf("read raw file: %v", err) + t.Fatalf("CreateSession(corrupt_header) error = %v", err) } - - var decoded map[string]any - if err := json.Unmarshal(raw, &decoded); err != nil { - t.Fatalf("decode raw json: %v", err) - } - - if decoded["provider"] != "openai" { - t.Fatalf("expected provider persisted, got %+v", decoded["provider"]) - } - if decoded["model"] != "gpt-4.1" { - t.Fatalf("expected model persisted, got %+v", decoded["model"]) + db, err := store.ensureDB(ctx) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) } - if _, ok := decoded["messages"]; !ok { - t.Fatalf("expected messages field persisted, got %+v", decoded) + if _, err := db.ExecContext(ctx, `UPDATE sessions SET task_state_json = '{' WHERE id = ?`, session.ID); err != nil { + t.Fatalf("corrupt task_state_json: %v", err) } - if decoded["workdir"] != session.Workdir { - t.Fatalf("expected workdir persisted as %q, got %+v", session.Workdir, decoded["workdir"]) + if _, err := store.LoadSession(ctx, session.ID); err == nil || !strings.Contains(err.Error(), "decode task_state") { + t.Fatalf("expected task_state decode error, got %v", err) } - loaded, err := store.Load(context.Background(), session.ID) + session, err = store.CreateSession(ctx, CreateSessionInput{ID: "corrupt_message", Title: "message"}) if err != nil { - t.Fatalf("load saved session: %v", err) + t.Fatalf("CreateSession(corrupt_message) error = %v", err) } - if loaded.Messages[2].ToolMetadata["tool_name"] != "webfetch" || loaded.Messages[2].ToolMetadata["http_status"] != "200" { - t.Fatalf("expected tool metadata round-trip, got %+v", loaded.Messages[2].ToolMetadata) - } -} - -func TestJSONStoreSaveRoundTripsMetadataOnlyToolMessage(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "metadata-only-tool-message", - Title: "Metadata Only Tool Message", - CreatedAt: time.Now().Add(-time.Hour), - UpdatedAt: time.Now(), + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, Messages: []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("inspect")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("")}, - ToolMetadata: map[string]string{ - "tool_name": "filesystem_read_file", - "path": "README.md", - }, - }, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}}, }, + }); err != nil { + t.Fatalf("AppendMessages() error = %v", err) } - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save session: %v", err) - } - - loaded, err := store.Load(context.Background(), session.ID) - if err != nil { - t.Fatalf("load saved session: %v", err) + if _, err := db.ExecContext(ctx, `UPDATE messages SET parts_json = '{' WHERE session_id = ?`, session.ID); err != nil { + t.Fatalf("corrupt parts_json: %v", err) } - if renderPartsForTest(loaded.Messages[2].Parts) != "" { - t.Fatalf("expected empty content to round-trip, got %q", renderPartsForTest(loaded.Messages[2].Parts)) + if _, err := store.LoadSession(ctx, session.ID); err == nil || !strings.Contains(err.Error(), "decode message parts") { + t.Fatalf("expected message parts decode error, got %v", err) } - if loaded.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" || - loaded.Messages[2].ToolMetadata["path"] != "README.md" { - t.Fatalf("expected metadata-only tool message round-trip, got %+v", loaded.Messages[2].ToolMetadata) - } -} - -func TestDecodeStoredSummaryUsesLightweightMetadataPath(t *testing.T) { - t.Parallel() - summary, err := decodeStoredSummary([]byte(`{ - "schema_version": 2, - "id": "summary-only", - "title": "Summary Only", - "created_at": "2026-04-13T08:00:00Z", - "updated_at": "2026-04-13T09:00:00Z", - "task_state": { - "goal": "persist task state", - "progress": [], - "open_items": [], - "next_step": "", - "blockers": [], - "key_artifacts": [], - "decisions": [], - "user_constraints": [], - "last_updated_at": "2026-04-13T09:00:00Z" - } -}`)) - if err != nil { - t.Fatalf("decodeStoredSummary() error: %v", err) + if _, err := db.ExecContext(ctx, `UPDATE messages SET parts_json = '[]', tool_calls_json = '{' WHERE session_id = ?`, session.ID); err != nil { + t.Fatalf("corrupt tool_calls_json: %v", err) } - - if summary.ID != "summary-only" { - t.Fatalf("expected summary id %q, got %q", "summary-only", summary.ID) + if _, err := store.LoadSession(ctx, session.ID); err == nil || !strings.Contains(err.Error(), "decode tool calls") { + t.Fatalf("expected tool calls decode error, got %v", err) } - if summary.Title != "Summary Only" { - t.Fatalf("expected summary title %q, got %q", "Summary Only", summary.Title) + + if _, err := db.ExecContext(ctx, `UPDATE messages SET tool_calls_json = '[]', tool_metadata_json = '{' WHERE session_id = ?`, session.ID); err != nil { + t.Fatalf("corrupt tool_metadata_json: %v", err) } - if summary.CreatedAt.IsZero() || summary.UpdatedAt.IsZero() { - t.Fatalf("expected non-zero timestamps, got created=%v updated=%v", summary.CreatedAt, summary.UpdatedAt) + if _, err := store.LoadSession(ctx, session.ID); err == nil || !strings.Contains(err.Error(), "decode tool metadata") { + t.Fatalf("expected tool metadata decode error, got %v", err) } } -func TestJSONStoreSaveClampsOversizedTaskState(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) +func TestSQLiteStoreAppendReplaceAndSchemaErrors(t *testing.T) { + ctx := context.Background() + store := newTestStore(t) - progress := make([]string, 0, taskStateMaxListItems+8) - for i := 0; i < taskStateMaxListItems+8; i++ { - progress = append(progress, strings.Repeat("p", taskStateMaxListItemChars-4)+buildIndexedSuffix(i)) - } - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "task-state-clamp-save", - Title: "Clamp Save", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now(), - TaskState: TaskState{ - Goal: strings.Repeat("g", taskStateMaxFieldChars+50), - NextStep: strings.Repeat("n", taskStateMaxFieldChars+50), - Progress: progress, - OpenItems: progress, + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: "missing_session", + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}}, }, + }); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected append missing session to return os.ErrNotExist, got %v", err) } - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save session: %v", err) - } - - if len([]rune(session.TaskState.Goal)) != taskStateMaxFieldChars { - t.Fatalf("expected goal to be clamped to %d runes, got %d", taskStateMaxFieldChars, len([]rune(session.TaskState.Goal))) - } - if len(session.TaskState.Progress) != taskStateMaxListItems { - t.Fatalf("expected progress list clamped to %d, got %d", taskStateMaxListItems, len(session.TaskState.Progress)) - } - if len([]rune(session.TaskState.Progress[0])) != taskStateMaxListItemChars { - t.Fatalf( - "expected progress item clamped to %d runes, got %d", - taskStateMaxListItemChars, - len([]rune(session.TaskState.Progress[0])), - ) - } -} - -func TestJSONStoreLoadClampsOversizedTaskState(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - payload := strings.Join([]string{ - `{`, - ` "schema_version": 2,`, - ` "id": "task-state-clamp-load",`, - ` "title": "Clamp Load",`, - ` "created_at": "2026-04-13T08:00:00Z",`, - ` "updated_at": "2026-04-13T09:00:00Z",`, - ` "task_state": {`, - ` "goal": "` + strings.Repeat("g", taskStateMaxFieldChars+30) + `",`, - ` "progress": [` + buildQuotedRepeatedWithIndex("p", taskStateMaxListItemChars+30, taskStateMaxListItems+3) + `],`, - ` "open_items": [],`, - ` "next_step": "",`, - ` "blockers": [],`, - ` "key_artifacts": [],`, - ` "decisions": [],`, - ` "user_constraints": [],`, - ` "last_updated_at": "2026-04-13T09:00:00Z"`, - ` },`, - ` "messages": []`, - `}`, - }, "\n") - mustWriteSessionFile( - t, - sessionFilePathForTest(baseDir, workspaceRoot, "task-state-clamp-load"), - payload, - ) - - loaded, err := store.Load(context.Background(), "task-state-clamp-load") + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "invalid_message", Title: "invalid"}) if err != nil { - t.Fatalf("load session: %v", err) - } - if len([]rune(loaded.TaskState.Goal)) != taskStateMaxFieldChars { - t.Fatalf("expected loaded goal to be clamped to %d runes, got %d", taskStateMaxFieldChars, len([]rune(loaded.TaskState.Goal))) - } - if len(loaded.TaskState.Progress) != taskStateMaxListItems { - t.Fatalf("expected loaded progress list clamped to %d, got %d", taskStateMaxListItems, len(loaded.TaskState.Progress)) - } -} - -func TestJSONStoreSaveLoadRoundTripTodos(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - createdAt := time.Date(2026, 4, 14, 10, 0, 0, 0, time.UTC) - updatedAt := createdAt.Add(5 * time.Minute) - session := &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "todos-round-trip", - Title: "Todos Round Trip", - CreatedAt: createdAt, - UpdatedAt: updatedAt, - TaskState: TaskState{}, + t.Fatalf("CreateSession() error = %v", err) + } + invalidPart := providertypes.ContentPart{Kind: "unknown"} + if err := store.AppendMessages(ctx, AppendMessagesInput{ + SessionID: session.ID, + Messages: []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{invalidPart}}}, + }); err == nil { + t.Fatalf("expected invalid message parts error") + } + if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{ + SessionID: session.ID, + Title: "x", Todos: []TodoItem{ - { - ID: "todo-1", - Content: " design session todo model ", - Status: TodoStatusPending, - Dependencies: []string{"todo-2", "todo-2", " "}, - CreatedAt: createdAt, - UpdatedAt: updatedAt, - }, - { - ID: "todo-2", Content: "persist todos in session", - Status: TodoStatusInProgress, - Priority: 2, - CreatedAt: createdAt, - UpdatedAt: updatedAt, - }, + {ID: "dup", Content: "a"}, + {ID: "dup", Content: "b"}, }, - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - } - - if err := store.Save(context.Background(), session); err != nil { - t.Fatalf("save session with todos: %v", err) - } - if got := session.Todos[0].Dependencies; len(got) != 1 || got[0] != "todo-2" { - t.Fatalf("expected dependencies normalized in-memory, got %+v", got) - } - if got := session.Todos[0].Content; got != "design session todo model" { - t.Fatalf("expected content normalized, got %q", got) + }); err == nil { + t.Fatalf("expected invalid todos error") } - - loaded, err := store.Load(context.Background(), session.ID) - if err != nil { - t.Fatalf("load session with todos: %v", err) - } - if len(loaded.Todos) != 2 { - t.Fatalf("expected 2 todos, got %d", len(loaded.Todos)) + if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{ + SessionID: session.ID, + Messages: []providertypes.Message{{Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{invalidPart}}}, + }); err == nil { + t.Fatalf("expected replace transcript invalid message error") } - if loaded.Todos[0].ID != "todo-1" || loaded.Todos[1].ID != "todo-2" { - t.Fatalf("unexpected todo ids: %+v", loaded.Todos) - } - if loaded.Todos[1].Priority != 2 { - t.Fatalf("expected priority 2, got %d", loaded.Todos[1].Priority) + if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{ + SessionID: "missing_session", + Messages: []providertypes.Message{ + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}}, + }, + }); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected replace transcript missing session to return os.ErrNotExist, got %v", err) } } -func TestJSONStoreLoadAllowsMissingTodosField(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "no-todos"), strings.Join([]string{ - `{`, - ` "schema_version": 2,`, - ` "id": "no-todos",`, - ` "title": "No Todos",`, - ` "created_at": "2026-04-14T10:00:00Z",`, - ` "updated_at": "2026-04-14T10:05:00Z",`, - ` "task_state": {`, - ` "goal": "",`, - ` "progress": [],`, - ` "open_items": [],`, - ` "next_step": "",`, - ` "blockers": [],`, - ` "key_artifacts": [],`, - ` "decisions": [],`, - ` "user_constraints": [],`, - ` "last_updated_at": "2026-04-14T10:05:00Z"`, - ` },`, - ` "messages": []`, - `}`, - }, "\n")) - - loaded, err := store.Load(context.Background(), "no-todos") +func TestSQLiteStoreInitializationRejectsUnsupportedSchemaVersion(t *testing.T) { + ctx := context.Background() + baseDir, err := os.MkdirTemp("", "session-base-") if err != nil { - t.Fatalf("load session without todos field: %v", err) - } - if len(loaded.Todos) != 0 { - t.Fatalf("expected no todos, got %+v", loaded.Todos) + t.Fatalf("MkdirTemp() baseDir error = %v", err) } -} - -func TestJSONStoreLoadBackfillsTodoVersionWhenMissing(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "todos-no-version"), strings.Join([]string{ - `{`, - ` "schema_version": 2,`, - ` "id": "todos-no-version",`, - ` "title": "Todos No Version",`, - ` "created_at": "2026-04-14T10:00:00Z",`, - ` "updated_at": "2026-04-14T10:05:00Z",`, - ` "task_state": {`, - ` "goal": "",`, - ` "progress": [],`, - ` "open_items": [],`, - ` "next_step": "",`, - ` "blockers": [],`, - ` "key_artifacts": [],`, - ` "decisions": [],`, - ` "user_constraints": [],`, - ` "last_updated_at": "2026-04-14T10:05:00Z"`, - ` },`, - ` "todos": [`, - ` {`, - ` "id": "todo-1",`, - ` "content": "todo item",`, - ` "status": "pending",`, - ` "created_at": "2026-04-14T10:00:00Z",`, - ` "updated_at": "2026-04-14T10:05:00Z"`, - ` }`, - ` ],`, - ` "messages": []`, - `}`, - }, "\n")) - - loaded, err := store.Load(context.Background(), "todos-no-version") + workspaceRoot, err := os.MkdirTemp("", "session-workspace-") if err != nil { - t.Fatalf("load session with todos and missing todo_version: %v", err) + t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } - if loaded.TodoVersion != CurrentTodoVersion { - t.Fatalf("expected todo version %d, got %d", CurrentTodoVersion, loaded.TodoVersion) - } -} - -func TestJSONStoreSaveRejectsInvalidTodos(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - err := store.Save(context.Background(), &Session{ - SchemaVersion: CurrentSchemaVersion, - ID: "invalid-todos", - Title: "Invalid Todos", - CreatedAt: time.Now().Add(-time.Minute), - UpdatedAt: time.Now(), - TaskState: TaskState{}, - Todos: []TodoItem{ - {ID: "todo-1", Content: "first", Dependencies: []string{"missing"}}, - }, + store := NewStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = store.Close() + _ = os.RemoveAll(baseDir) + _ = os.RemoveAll(workspaceRoot) }) - if err == nil || !strings.Contains(err.Error(), `unknown dependency "missing"`) { - t.Fatalf("expected invalid dependency error, got %v", err) - } -} -func TestDecodeStoredSummaryRejectsMissingSchemaVersion(t *testing.T) { - t.Parallel() - - _, err := decodeStoredSummary([]byte(`{"id":"summary-no-schema","task_state":{}}`)) - if err == nil || !strings.Contains(err.Error(), "missing required field schema_version") { - t.Fatalf("expected missing schema_version error, got %v", err) - } -} - -func TestCreateTempFileAndAtomicReplaceFailureBranches(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - missingDir := filepath.Join(baseDir, "missing", "dir") - if _, _, err := createTempFile(missingDir, "tmp-*.tmp", "create temp file"); err == nil || - !strings.Contains(err.Error(), "create temp file") { - t.Fatalf("expected create temp file error for missing dir, got %v", err) + projectDir := projectDirectory(baseDir, workspaceRoot) + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll(projectDir) error = %v", err) } - - source := filepath.Join(baseDir, "source.tmp") - if err := os.WriteFile(source, []byte("payload"), 0o644); err != nil { - t.Fatalf("write source temp: %v", err) + db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) } - blockerDir := filepath.Join(baseDir, "target") - if err := os.MkdirAll(filepath.Join(blockerDir, "child"), 0o755); err != nil { - t.Fatalf("mkdir blocker dir: %v", err) + if _, err := db.ExecContext(ctx, `PRAGMA user_version=999`); err != nil { + t.Fatalf("set user_version: %v", err) } - if err := replaceFileWithTemp(source, blockerDir, "file"); err == nil || !strings.Contains(err.Error(), "replace file") { - t.Fatalf("expected replace failure for non-empty target dir, got %v", err) + if err := db.Close(); err != nil { + t.Fatalf("db.Close() error = %v", err) } -} -func TestReplaceFileWithTempCommitFailure(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - target := filepath.Join(baseDir, "target.txt") - if err := replaceFileWithTemp(filepath.Join(baseDir, "missing.tmp"), target, "file"); err == nil || - !strings.Contains(err.Error(), "commit file") { - t.Fatalf("expected commit failure when temp file missing, got %v", err) + if _, err := store.ListSummaries(ctx); err == nil || !strings.Contains(err.Error(), "unsupported sqlite schema version") { + t.Fatalf("expected unsupported schema version error, got %v", err) } } -func TestWriteFileAtomicallyCreateTempFailure(t *testing.T) { - t.Parallel() - - target := filepath.Join(t.TempDir(), "missing", "nested", "session.json") - err := writeFileAtomically(target, "session-*.tmp", []byte("data"), 0o644) - if err == nil || !strings.Contains(err.Error(), "create temp file") { - t.Fatalf("expected create temp file error, got %v", err) - } -} - -func TestJSONStoreRejectsInvalidBasePathDuringAssetAndSessionResolve(t *testing.T) { - t.Parallel() - - store := &JSONStore{baseDir: string([]byte{'b', 'a', 'd', 0})} - if _, err := store.Load(context.Background(), "session-ok"); err == nil { - t.Fatalf("expected load error for invalid base path") - } - - if _, err := store.SaveAsset(context.Background(), "session-ok", strings.NewReader("img"), "image/png"); err == nil { - t.Fatalf("expected save asset error for invalid base path") +func assertPragmaString(t *testing.T, db *sql.DB, name string, want string) { + t.Helper() + var got string + if err := db.QueryRow(`PRAGMA ` + name).Scan(&got); err != nil { + t.Fatalf("PRAGMA %s scan error = %v", name, err) } - - if _, err := store.Stat(context.Background(), "session-ok", "asset-ok"); err == nil { - t.Fatalf("expected stat error for invalid base path") + if got != want { + t.Fatalf("PRAGMA %s = %q, want %q", name, got, want) } } -func TestJSONStoreLoadRejectsInvalidTodos(t *testing.T) { - t.Parallel() - - baseDir := t.TempDir() - workspaceRoot := t.TempDir() - store := NewJSONStore(baseDir, workspaceRoot) - - mustWriteSessionFile(t, sessionFilePathForTest(baseDir, workspaceRoot, "invalid-todos-load"), strings.Join([]string{ - `{`, - ` "schema_version": 2,`, - ` "id": "invalid-todos-load",`, - ` "title": "Invalid Todos Load",`, - ` "created_at": "2026-04-14T10:00:00Z",`, - ` "updated_at": "2026-04-14T10:05:00Z",`, - ` "task_state": {`, - ` "goal": "",`, - ` "progress": [],`, - ` "open_items": [],`, - ` "next_step": "",`, - ` "blockers": [],`, - ` "key_artifacts": [],`, - ` "decisions": [],`, - ` "user_constraints": [],`, - ` "last_updated_at": "2026-04-14T10:05:00Z"`, - ` },`, - ` "todos": [`, - ` {`, - ` "id": "todo-1",`, - ` "content": "broken todo",`, - ` "status": "paused",`, - ` "created_at": "2026-04-14T10:00:00Z",`, - ` "updated_at": "2026-04-14T10:05:00Z"`, - ` }`, - ` ],`, - ` "messages": []`, - `}`, - }, "\n")) - - _, err := store.Load(context.Background(), "invalid-todos-load") - if err == nil || !strings.Contains(err.Error(), `invalid todo status "paused"`) { - t.Fatalf("expected invalid todo status load error, got %v", err) +func assertPragmaInt(t *testing.T, db *sql.DB, name string, want int) { + t.Helper() + var got int + if err := db.QueryRow(`PRAGMA ` + name).Scan(&got); err != nil { + t.Fatalf("PRAGMA %s scan error = %v", name, err) } -} - -func buildQuotedRepeatedWithIndex(ch string, itemLen int, count int) string { - items := make([]string, 0, count) - for i := 0; i < count; i++ { - items = append(items, `"`+strings.Repeat(ch, itemLen-4)+buildIndexedSuffix(i)+`"`) + if got != want { + t.Fatalf("PRAGMA %s = %d, want %d", name, got, want) } - return strings.Join(items, ",") } -func buildIndexedSuffix(index int) string { - chars := []rune("abcdefghijklmnopqrstuvwxyz0123456789") - hi := chars[(index/len(chars))%len(chars)] - lo := chars[index%len(chars)] - return string([]rune{hi, lo, 'x', 'x'}) -} - -func mustWriteSessionFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - t.Fatalf("mkdir %s: %v", filepath.Dir(path), err) +func renderSessionMessageParts(message providertypes.Message) string { + if len(message.Parts) == 0 { + return "" } - if err := os.WriteFile(path, []byte(content), 0o644); err != nil { - t.Fatalf("write %s: %v", path, err) + var builder strings.Builder + for _, part := range message.Parts { + builder.WriteString(part.Text) } -} - -func sessionFilePathForTest(baseDir string, workspaceRoot string, sessionID string) string { - return filepath.Join(sessionDirectory(baseDir, workspaceRoot), sessionID, sessionFileName) + return builder.String() } diff --git a/internal/session/task_state.go b/internal/session/task_state.go index 7bb33379..b64b2c33 100644 --- a/internal/session/task_state.go +++ b/internal/session/task_state.go @@ -6,9 +6,6 @@ import ( ) const ( - // CurrentSchemaVersion 表示当前会话持久化结构的唯一合法版本。 - CurrentSchemaVersion = 2 - // taskStateMaxFieldChars 限制 TaskState 单值字段的最大字符数,避免异常大文本污染持久化与后续 prompt。 taskStateMaxFieldChars = 2000 // taskStateMaxListItems 限制 TaskState 列表字段的最大条目数,避免模型输出超大数组导致上下文膨胀。 @@ -79,6 +76,11 @@ func ClampTaskStateBoundaries(state TaskState) TaskState { return state } +// normalizeAndClampTaskState 先规范化再限幅,保证持久化前后的 task_state 行为一致。 +func normalizeAndClampTaskState(state TaskState) TaskState { + return ClampTaskStateBoundaries(NormalizeTaskState(state)) +} + // normalizeTaskStateList 对任务状态中的字符串列表做去空、去重并保留顺序。 func normalizeTaskStateList(items []string) []string { if len(items) == 0 { diff --git a/internal/session/test_helpers_test.go b/internal/session/test_helpers_test.go new file mode 100644 index 00000000..1b78de3b --- /dev/null +++ b/internal/session/test_helpers_test.go @@ -0,0 +1,30 @@ +package session + +import ( + "fmt" + "os" + "testing" +) + +func buildIndexedSuffix(index int) string { + return fmt.Sprintf("-%02d", index) +} + +func newTestStore(t *testing.T) *SQLiteStore { + t.Helper() + baseDir, err := os.MkdirTemp("", "session-base-") + if err != nil { + t.Fatalf("MkdirTemp() baseDir error = %v", err) + } + workspaceRoot, err := os.MkdirTemp("", "session-workspace-") + if err != nil { + t.Fatalf("MkdirTemp() error = %v", err) + } + store := NewStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = store.Close() + _ = os.RemoveAll(baseDir) + _ = os.RemoveAll(workspaceRoot) + }) + return store +} diff --git a/internal/session/workspace.go b/internal/session/workspace.go index 8dc635c1..d16e0278 100644 --- a/internal/session/workspace.go +++ b/internal/session/workspace.go @@ -12,9 +12,19 @@ import ( const projectsDirName = "projects" -// sessionDirectory 负责根据工作区根目录计算会话分桶目录。 -func sessionDirectory(baseDir string, workspaceRoot string) string { - return filepath.Join(baseDir, projectsDirName, HashWorkspaceRoot(workspaceRoot), sessionsDirName) +// projectDirectory 负责根据工作区根目录计算当前会话数据库所在目录。 +func projectDirectory(baseDir string, workspaceRoot string) string { + return filepath.Join(baseDir, projectsDirName, HashWorkspaceRoot(workspaceRoot)) +} + +// databasePath 返回当前工作区级 SQLite 数据库文件路径。 +func databasePath(baseDir string, workspaceRoot string) string { + return filepath.Join(projectDirectory(baseDir, workspaceRoot), sessionDatabaseFileName) +} + +// assetsDirectory 返回当前工作区附件根目录。 +func assetsDirectory(baseDir string, workspaceRoot string) string { + return filepath.Join(projectDirectory(baseDir, workspaceRoot), assetsDirName) } // HashWorkspaceRoot 为规范化后的工作区根目录生成稳定哈希,供 session 和 memo 等包共享。 diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 836ad0f3..19e3e7be 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -44,6 +44,8 @@ const ( const providerAddSelectTimeout = 10 * time.Second +const sessionSwitchBusyMessage = "cannot switch sessions while run or compact is active" + var panelOrder = []panel{panelTranscript, panelActivity, panelInput} var persistProviderUserEnvVar = config.PersistUserEnvVar var deleteProviderUserEnvVar = config.DeleteUserEnvVar @@ -378,6 +380,12 @@ func (a App) updateInputPanel(msg tea.Msg, typed tea.KeyMsg, cmds []tea.Cmd) (te } return a, tea.Batch(cmds...) case slashCommandSession: + if err := a.ensureSessionSwitchAllowed(""); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("session", "Failed to open session picker", err.Error(), true) + return a, tea.Batch(cmds...) + } if err := a.refreshSessionPicker(); err != nil { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() @@ -774,6 +782,9 @@ func (a *App) activateSelectedSession() error { if !ok { return nil } + if err := a.ensureSessionSwitchAllowed(item.Summary.ID); err != nil { + return err + } a.state.ActiveSessionID = item.Summary.ID a.state.ActiveSessionTitle = item.Summary.Title @@ -784,6 +795,9 @@ func (a *App) activateSelectedSession() error { } func (a *App) activateSessionByID(sessionID string) error { + if err := a.ensureSessionSwitchAllowed(sessionID); err != nil { + return err + } for _, s := range a.state.Sessions { if s.ID == sessionID { a.state.ActiveSessionID = s.ID @@ -796,6 +810,16 @@ func (a *App) activateSessionByID(sessionID string) error { return fmt.Errorf("session not found: %s", sessionID) } +// ensureSessionSwitchAllowed 统一阻止运行中切换到其他会话,避免 UI 脱离仍在执行的 run 上下文。 +func (a *App) ensureSessionSwitchAllowed(targetSessionID string) error { + targetSessionID = strings.TrimSpace(targetSessionID) + activeSessionID := strings.TrimSpace(a.state.ActiveSessionID) + if !a.isBusy() || (targetSessionID != "" && strings.EqualFold(targetSessionID, activeSessionID)) { + return nil + } + return fmt.Errorf(sessionSwitchBusyMessage) +} + func (a *App) syncActiveSessionTitle() { if strings.TrimSpace(a.state.ActiveSessionID) == "" { if strings.TrimSpace(a.state.ActiveSessionTitle) == "" { @@ -1894,6 +1918,12 @@ func (a *App) handleImmediateSlashCommand(input string) (bool, tea.Cmd) { case slashCommandForget: return true, a.handleForgetCommand(rest) case slashCommandSession: + if err := a.ensureSessionSwitchAllowed(""); err != nil { + a.state.ExecutionError = err.Error() + a.state.StatusText = err.Error() + a.appendActivity("session", "Failed to open session picker", err.Error(), true) + return true, nil + } if err := a.refreshSessionPicker(); err != nil { a.state.ExecutionError = err.Error() a.state.StatusText = err.Error() diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 269f33f4..7bbdbc2e 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -1761,6 +1761,35 @@ func TestUpdatePickerSessionEnterActivatesSelectedSession(t *testing.T) { } } +func TestUpdatePickerSessionEnterWhileBusyRejectsSwitch(t *testing.T) { + app, runtime := newTestApp(t) + now := time.Now() + runtime.listSessions = []agentsession.Summary{ + {ID: "s1", Title: "One", UpdatedAt: now.Add(-time.Minute)}, + {ID: "s2", Title: "Two", UpdatedAt: now}, + } + if err := app.refreshSessionPicker(); err != nil { + t.Fatalf("refreshSessionPicker() error = %v", err) + } + app.state.ActiveSessionID = "s1" + app.state.ActiveSessionTitle = "One" + app.state.IsAgentRunning = true + app.openPicker(pickerSession, statusChooseSession, &app.sessionPicker, "s1") + app.sessionPicker.Select(1) + + model, cmd := app.updatePicker(tea.KeyMsg{Type: tea.KeyEnter}) + if cmd != nil { + t.Fatalf("expected nil cmd for rejected session switch") + } + app = model.(App) + if app.state.ActiveSessionID != "s1" { + t.Fatalf("expected active session to remain unchanged, got %q", app.state.ActiveSessionID) + } + if !strings.Contains(app.state.ExecutionError, sessionSwitchBusyMessage) { + t.Fatalf("expected busy session switch error, got %q", app.state.ExecutionError) + } +} + func TestActivateSessionByIDNotFound(t *testing.T) { app, _ := newTestApp(t) app.state.Sessions = []agentsession.Summary{{ID: "s1", Title: "one"}} @@ -1786,6 +1815,25 @@ func TestHandleImmediateSlashCommandSession(t *testing.T) { } } +func TestHandleImmediateSlashCommandSessionWhileBusy(t *testing.T) { + app, _ := newTestApp(t) + app.state.IsAgentRunning = true + + handled, cmd := app.handleImmediateSlashCommand("/session") + if !handled { + t.Fatalf("expected /session to be handled immediately") + } + if cmd != nil { + t.Fatalf("expected busy /session to avoid returning cmd") + } + if app.state.ActivePicker != pickerNone { + t.Fatalf("expected session picker to stay closed while busy") + } + if !strings.Contains(app.state.ExecutionError, sessionSwitchBusyMessage) { + t.Fatalf("expected busy session switch error, got %q", app.state.ExecutionError) + } +} + func TestRuntimeEventToolStatusHandler(t *testing.T) { app, _ := newTestApp(t) payload := tuiservices.RuntimeToolStatusPayload{ToolCallID: "tool-1", ToolName: "bash", Status: string(tuistate.ToolLifecyclePlanned)}