diff --git a/.gitignore b/.gitignore index 5ab96303..68339b20 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,8 @@ config.local.yaml data/ .cache/ .neocode/projects/**/.transcripts/ +.gocache/ +.gomodcache/ # Editor/IDE .idea/ @@ -43,6 +45,7 @@ workspace.xml .claude/ .windsurf/ .codebuddy/ +.agents/ # VitePress / frontend build artifacts www/.vitepress/cache/ www/.vitepress/dist/ diff --git a/docs/gateway-rpc-api.md b/docs/gateway-rpc-api.md index c3013e96..08a481bd 100644 --- a/docs/gateway-rpc-api.md +++ b/docs/gateway-rpc-api.md @@ -176,11 +176,12 @@ type RunParams struct { } ``` -- 多模态图片约束: +- 多模态附件约束: - `type=image` 时 `media.mime_type` 必填。 - `media.uri` 与 `media.asset_id` 必须二选一,不能同时为空或同时提供。 - - `media.uri` 仅用于后端可读取的本地路径;Web 浏览器上传图片应先通过 `POST /api/session-assets` 保存,再在 `gateway.run` 中使用 `media.asset_id` 引用。 + - `media.uri` 仅用于后端可读取的本地路径;Web 浏览器上传图片或文本应先通过 `POST /api/session-assets` 保存,再在 `gateway.run` 中使用 `media.asset_id` 引用。 - `asset_id` 必须属于当前 `session_id`,不存在或跨 session 引用会在 runtime 输入准备阶段失败。 + - 文本附件(如 `.md`、`.json`、`.csv`)使用 `type=image` + 真实文本 mime 表达,runtime 端按 `session.TextAssetWhitelist` 自动判定并内联为 user message 的 text part;无需新增 `type` 字段。详见 issue #701。 - Response Schema: - Success(受理即返回): @@ -242,13 +243,15 @@ type RunParams struct { - Content-Type: `multipart/form-data` - Fields: - `session_id`: 目标会话 ID,必填。 - - `file`: 图片文件,必填。 + - `file`: 图片或文本文件,必填。 - Server-side validation: - - 仅接受 `image/png`、`image/jpeg`、`image/webp`。 + - 接受 `image/png`、`image/jpeg`、`image/webp`(按文件头嗅探)。 + - 同时接受会话侧文本资产白名单内的扩展名(`.txt`、`.md`、`.json`、`.yaml`、`.yml`、`.csv`)与对应 MIME(`text/plain`、`text/markdown`、`application/json`、`text/yaml`、`application/x-yaml`、`text/csv`)。 + - 文本资产额外校验 UTF-8,非 UTF-8 内容返回 `415`。 - MIME 以服务端文件头检测结果为准,不信任浏览器声明。 - 空文件返回 `400`。 - - 超过 `MaxSessionAssetBytes` 返回 `413`。 - - 非图片或不支持类型返回 `415`。 + - 超过 `MaxSessionAssetBytes`(图片)或 `MaxTextAssetBytes`(文本)返回 `413`。 + - 不在任一白名单内的类型返回 `415`。 - 未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 - 工作区不存在返回 `404 workspace 
not found`;目标 session 不在该工作区返回 `404 session not found`。 - Response: @@ -262,6 +265,8 @@ type RunParams struct { } ``` +文本附件上传成功后,runtime 会在 `PrepareUserInput` 阶段按 `session.TextAssetWhitelist` 命中后自动读取并内联为 user message 的 `text` content part(带文件名边界 + 可选截断提示),Provider 层不感知"文件"概念。配置项 `runtime.assets.text_asset_enabled`(默认 `true`)可关闭该行为,关闭后文本附件会作为普通附件原样提交。详见 `docs/runtime-provider-event-flow.md` 与 issue #701。 + ### GET /api/session-assets/{session_id}/{asset_id} - Auth Required: Yes(`Authorization: Bearer `) diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index c1b08047..9186b76f 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -48,6 +48,9 @@ runtime: assets: max_session_asset_bytes: 20971520 max_session_assets_total_bytes: 20971520 + text_asset_enabled: true + max_text_asset_bytes: 262144 # 256 KiB + max_text_asset_chars: 250000 # ~25 万 UTF-8 字符 tools: webfetch: @@ -112,6 +115,9 @@ context: | `runtime.hooks.items` | user hooks 列表;支持 `builtin/sync` 与 `http/observe` 两种子类型 | | `runtime.assets.max_session_asset_bytes` | 单个 `session_asset` 最大原始字节数,默认 `20971520`(20 MiB);`0` 或未配置时回退默认值 | | `runtime.assets.max_session_assets_total_bytes` | 单次请求可携带的 `session_asset` 原始总字节上限,默认 `20971520`(20 MiB);`0` 或未配置时回退默认值 | +| `runtime.assets.text_asset_enabled` | 是否把文本类 asset 在提交会话前内联为 text part,默认 `true`;关闭后文本 asset 走原图片路径(回滚开关) | +| `runtime.assets.max_text_asset_bytes` | 单个文本 asset 最大字节数,默认 `262144`(256 KiB),硬上限 `4 MiB`;超过保存时返回 413 | +| `runtime.assets.max_text_asset_chars` | 单个文本 asset 在 UTF-8 解码后的最大字符数,默认 `250000`,硬上限 `4_000_000`;超过时截断并保留截断提示 | ### `runtime.hooks.items` 字段约束 diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 36e035a2..f8437356 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -25,6 +25,13 @@ type RuntimeConfig struct { type RuntimeAssetsConfig struct { MaxSessionAssetBytes int64 `yaml:"max_session_asset_bytes,omitempty"` MaxSessionAssetsTotalBytes int64 
`yaml:"max_session_assets_total_bytes,omitempty"` + // TextAssetEnabled 控制是否把文本类 asset 在提交会话前内联为 text part;关闭时文本 asset + // 仅作为图像风格的会话附件存在(保持向后兼容,便于回滚)。 + TextAssetEnabled *bool `yaml:"text_asset_enabled,omitempty"` + // MaxTextAssetBytes 限制单个文本 asset 的字节上限(保存与读取都受此约束)。 + MaxTextAssetBytes int64 `yaml:"max_text_asset_bytes,omitempty"` + // MaxTextAssetChars 限制单个文本 asset 在 UTF-8 解码后允许保留的最大字符数。 + MaxTextAssetChars int `yaml:"max_text_asset_chars,omitempty"` } // defaultRuntimeConfig 返回 runtime 配置的静态默认值。 @@ -40,9 +47,13 @@ func defaultRuntimeConfig() RuntimeConfig { // defaultRuntimeAssetsConfig 返回 runtime 附件限制配置默认值。 func defaultRuntimeAssetsConfig() RuntimeAssetsConfig { + enabled := true return RuntimeAssetsConfig{ MaxSessionAssetBytes: session.MaxSessionAssetBytes, MaxSessionAssetsTotalBytes: provider.MaxSessionAssetsTotalBytes, + TextAssetEnabled: &enabled, + MaxTextAssetBytes: session.DefaultMaxTextAssetBytes, + MaxTextAssetChars: session.DefaultMaxTextAssetChars, } } @@ -107,9 +118,24 @@ func (c RuntimeConfig) ResolveRequestAssetBudget() provider.RequestAssetBudget { return c.Assets.ResolveRequestAssetBudget() } +// ResolveTextAssetPolicy 归一化 runtime 文本附件策略并施加代码硬上限兜底。 +func (c RuntimeConfig) ResolveTextAssetPolicy() session.TextAssetPolicy { + return c.Assets.ResolveTextAssetPolicy() +} + +// IsTextAssetEnabled 返回当前 runtime 文本附件内联开关。 +func (c RuntimeConfig) IsTextAssetEnabled() bool { + return c.Assets.IsTextAssetEnabled() +} + // Clone 复制附件限制配置,避免调用方共享可变状态。 func (c RuntimeAssetsConfig) Clone() RuntimeAssetsConfig { - return c + out := c + if c.TextAssetEnabled != nil { + enabled := *c.TextAssetEnabled + out.TextAssetEnabled = &enabled + } + return out } // ApplyDefaults 在配置缺失、为零或非法时回填附件限制默认值。 @@ -123,6 +149,20 @@ func (c *RuntimeAssetsConfig) ApplyDefaults(defaults RuntimeAssetsConfig) { if c.MaxSessionAssetsTotalBytes <= 0 { c.MaxSessionAssetsTotalBytes = defaults.MaxSessionAssetsTotalBytes } + if c.TextAssetEnabled == nil { + // nil 视为 true,与 
IsTextAssetEnabled() 的 nil-as-true 语义对齐。 + enabled := true + if defaults.TextAssetEnabled != nil { + enabled = *defaults.TextAssetEnabled + } + c.TextAssetEnabled = &enabled + } + if c.MaxTextAssetBytes <= 0 { + c.MaxTextAssetBytes = defaults.MaxTextAssetBytes + } + if c.MaxTextAssetChars <= 0 { + c.MaxTextAssetChars = defaults.MaxTextAssetChars + } } // Validate 校验附件限制配置是否满足最小约束;0 表示使用默认值,仅禁止负数。 @@ -133,9 +173,23 @@ func (c RuntimeAssetsConfig) Validate() error { if c.MaxSessionAssetsTotalBytes < 0 { return errors.New("runtime.assets.max_session_assets_total_bytes must be greater than or equal to 0") } + if c.MaxTextAssetBytes < 0 { + return errors.New("runtime.assets.max_text_asset_bytes must be greater than or equal to 0") + } + if c.MaxTextAssetChars < 0 { + return errors.New("runtime.assets.max_text_asset_chars must be greater than or equal to 0") + } return nil } +// IsTextAssetEnabled 返回文本类 asset 内联开关;nil 视为 true。 +func (c RuntimeAssetsConfig) IsTextAssetEnabled() bool { + if c.TextAssetEnabled == nil { + return true + } + return *c.TextAssetEnabled +} + // ResolveSessionAssetPolicy 归一化附件存储策略并应用代码硬上限。 func (c RuntimeAssetsConfig) ResolveSessionAssetPolicy() session.AssetPolicy { return session.NormalizeAssetPolicy(session.AssetPolicy{ @@ -150,3 +204,12 @@ func (c RuntimeAssetsConfig) ResolveRequestAssetBudget() provider.RequestAssetBu MaxSessionAssetsTotalBytes: c.MaxSessionAssetsTotalBytes, }, assetPolicy.MaxSessionAssetBytes) } + +// ResolveTextAssetPolicy 归一化文本附件策略并应用代码硬上限。 +func (c RuntimeAssetsConfig) ResolveTextAssetPolicy() session.TextAssetPolicy { + return session.NormalizeTextAssetPolicy(session.TextAssetPolicy{ + Whitelist: session.DefaultTextAssetWhitelist(), + MaxTextAssetBytes: c.MaxTextAssetBytes, + MaxTextAssetChars: c.MaxTextAssetChars, + }) +} diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go index e88fdfd3..640341ef 100644 --- a/internal/config/runtime_test.go +++ b/internal/config/runtime_test.go @@ -19,6 
+19,24 @@ func TestRuntimeConfigCloneAndDefaults(t *testing.T) { } } +// TestRuntimeAssetsConfigApplyDefaultsNilAsTrue 验证当 defaults.TextAssetEnabled 为 nil 时, +// ApplyDefaults 将 TextAssetEnabled 设为 true,与 IsTextAssetEnabled() 的 nil-as-true 语义对齐。 +func TestRuntimeAssetsConfigApplyDefaultsNilAsTrue(t *testing.T) { + t.Parallel() + + var zero RuntimeAssetsConfig + zero.ApplyDefaults(RuntimeAssetsConfig{}) + if zero.TextAssetEnabled == nil { + t.Fatal("expected TextAssetEnabled to be non-nil after ApplyDefaults") + } + if !*zero.TextAssetEnabled { + t.Fatalf("expected TextAssetEnabled default true, got false") + } + if !zero.IsTextAssetEnabled() { + t.Fatalf("expected IsTextAssetEnabled() true, got false") + } +} + func TestRuntimeConfigValidate(t *testing.T) { t.Parallel() diff --git a/internal/gateway/network_server.go b/internal/gateway/network_server.go index cefc5a48..e3e1c9d8 100644 --- a/internal/gateway/network_server.go +++ b/internal/gateway/network_server.go @@ -17,6 +17,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/net/websocket" @@ -435,7 +436,7 @@ func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, req return } - file, _, err := request.FormFile("file") + file, fileHeader, err := request.FormFile("file") if err != nil { writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is required"}) return @@ -443,6 +444,10 @@ func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, req defer func() { _ = file.Close() }() + fileName := "" + if fileHeader != nil { + fileName = strings.TrimSpace(fileHeader.Filename) + } payload, err := io.ReadAll(io.LimitReader(file, limit+1)) if err != nil { @@ -460,7 +465,11 @@ func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, req mimeType := detectAllowedUploadImageMime(payload) if mimeType == "" { - writeJSONResponse(writer, 
http.StatusUnsupportedMediaType, map[string]string{"error": "unsupported image type"}) + // 文本附件走白名单嗅探:先按声明/扩展名匹配,再做 UTF-8 校验。 + mimeType = detectAllowedUploadTextMime(payload, fileName) + } + if mimeType == "" { + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": "unsupported asset type"}) return } @@ -626,6 +635,46 @@ func detectAllowedUploadImageMime(payload []byte) string { } } +// detectAllowedUploadTextMime 按会话侧文本资产白名单探测上传文件 MIME。 +// 流程:先按文件扩展名查 mime;若未命中则尝试从 payload 内容头推断(仅对纯文本类文件)。 +// 任一环节命中后必须再次校验 payload 是否为合法 UTF-8,非 UTF-8 返回空。 +// 与 detectAllowedUploadImageMime 并列,互不冲突。 +func detectAllowedUploadTextMime(payload []byte, fileName string) string { + if len(payload) == 0 { + return "" + } + whitelist := agentsession.DefaultTextAssetWhitelist() + if whitelist.IsEmpty() { + return "" + } + mimeType := whitelist.LookupByExtension(fileName) + if mimeType == "" { + // 用 http.DetectContentType 做粗略内容嗅探;命中 text/* 即认为是文本。 + probe := payload + if len(probe) > 512 { + probe = probe[:512] + } + detected := strings.ToLower(strings.TrimSpace(http.DetectContentType(probe))) + if !strings.HasPrefix(detected, "text/") { + return "" + } + // http.DetectContentType 对纯文本会返回 "text/plain; charset=utf-8", + // 需剥离 "; charset=..." 
参数后再与白名单对比,否则永远匹配失败。 + mediaType := strings.TrimSpace(strings.SplitN(detected, ";", 2)[0]) + // 仅接受白名单内的 mime,避免被任意 text/* 通过。 + if !whitelist.LookupByMime(mediaType) { + return "" + } + mimeType = mediaType + } + // UTF-8 校验:文本资产进入 runtime 后会被读取并按 UTF-8 解码,非 UTF-8 会立刻失败。 + // 提前在网关层拒绝,避免无效上传占用存储。 + if !utf8.Valid(payload) { + return "" + } + return mimeType +} + // parseSessionAssetPath 从 /api/session-assets/{session_id}/{asset_id} 提取路径参数。 func parseSessionAssetPath(rawPath string) (string, string, bool) { cleanPath := path.Clean("/" + strings.TrimSpace(rawPath)) diff --git a/internal/gateway/network_server_test.go b/internal/gateway/network_server_test.go index c1024eb5..c65abba2 100644 --- a/internal/gateway/network_server_test.go +++ b/internal/gateway/network_server_test.go @@ -647,8 +647,20 @@ func TestNetworkServerSessionAssetUploadErrors(t *testing.T) { } }) - t.Run("non image", func(t *testing.T) { - request := newSessionAssetUploadRequest(t, "session-1", "bad.txt", []byte("not an image")) + t.Run("non whitelisted binary content rejected", func(t *testing.T) { + // 使用二进制内容:既不是图片,content-sniffing 也不会命中 text/* 白名单。 + request := newSessionAssetUploadRequest(t, "session-1", "blob.bin", []byte{0xFF, 0xFE, 0x00, 0x01}) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnsupportedMediaType { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnsupportedMediaType) + } + }) + + t.Run("non utf8 text asset is rejected", func(t *testing.T) { + // 0xC3 0x28 是非法 UTF-8 序列;扩展名 .txt 在白名单内,但 UTF-8 校验应拒绝。 + request := newSessionAssetUploadRequest(t, "session-1", "broken.txt", []byte{0xC3, 0x28, 0xA0, 0xA1}) request.Header.Set("Authorization", "Bearer gateway-token") recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) @@ -1827,6 +1839,80 @@ func gatewayMinimalPNGBytes() []byte { } } +// TestDetectAllowedUploadTextMime 直接覆盖文本附件 
MIME 嗅探的各分支, +// 重点验证 content-sniffing 回退路径剥离 "; charset=utf-8" 参数后能命中白名单。 +func TestDetectAllowedUploadTextMime(t *testing.T) { + tests := []struct { + name string + payload []byte + fileName string + want string + }{ + { + name: "txt extension success", + payload: []byte("hello world"), + fileName: "note.txt", + want: "text/plain", + }, + { + name: "md extension success", + payload: []byte("# title"), + fileName: "readme.md", + want: "text/markdown", + }, + { + name: "csv extension success", + payload: []byte("a,b\n1,2"), + fileName: "data.csv", + want: "text/csv", + }, + { + name: "json extension success", + payload: []byte(`{"k":"v"}`), + fileName: "cfg.json", + want: "application/json", + }, + { + name: "yaml extension success", + payload: []byte("key: value"), + fileName: "conf.yaml", + want: "text/yaml", + }, + { + name: "content sniffing strips charset param for extensionless text", + payload: []byte("plain text without extension"), + fileName: "notes", + want: "text/plain", + }, + { + name: "non utf8 payload rejected even with whitelisted extension", + payload: []byte{0xC3, 0x28, 0xA0, 0xA1}, + fileName: "broken.txt", + want: "", + }, + { + name: "non text content rejected by sniffing", + payload: []byte(""), + fileName: "page", + want: "", + }, + { + name: "empty payload rejected", + payload: nil, + fileName: "empty.txt", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectAllowedUploadTextMime(tt.payload, tt.fileName) + if got != tt.want { + t.Fatalf("detectAllowedUploadTextMime(%q, %q) = %q, want %q", tt.fileName, tt.payload, got, tt.want) + } + }) + } +} + type staticTokenAuthenticator struct { token string } diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 24c17fb4..f99a8595 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -237,8 +237,9 @@ type TodoEventPayload struct { // InputNormalizedPayload 描述输入归一化完成后的摘要信息。 type InputNormalizedPayload struct { 
- TextLength int `json:"text_length"` - ImageCount int `json:"image_count"` + TextLength int `json:"text_length"` + ImageCount int `json:"image_count"` + TextAssetCount int `json:"text_asset_count,omitempty"` } // AssetSavedPayload 描述单个附件成功保存后的结果。 diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 8b9e7082..305e9cbb 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -13,9 +13,13 @@ import ( const prepareEventEmitTimeout = 200 * time.Millisecond // NewSessionInputPreparer 创建基于 session 子层实现的输入归一化适配器。 +// 文本附件策略(textPolicy)由 PrepareUserInput 在每次调用时通过 sessionTextPolicyInjectable +// 重新注入;构造阶段使用 session 默认值兜底,避免未配置时走到 nil 路径。 func NewSessionInputPreparer(store agentsession.Store, assetStore agentsession.AssetStore) UserInputPreparer { + preparer := agentsession.NewInputPreparer(store, assetStore) + preparer.SetTextAssetPolicy(agentsession.DefaultTextAssetPolicy()) return sessionInputPreparer{ - preparer: agentsession.NewInputPreparer(store, assetStore), + preparer: preparer, } } @@ -44,13 +48,26 @@ func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (Use defaultWorkdir := "" sessionAssetPolicy := agentsession.DefaultAssetPolicy() + textAssetPolicy := agentsession.DefaultTextAssetPolicy() + textAssetEnabled := true if s.configManager != nil { cfg := s.configManager.Get() defaultWorkdir = strings.TrimSpace(cfg.Workdir) sessionAssetPolicy = cfg.Runtime.ResolveSessionAssetPolicy() + textAssetPolicy = cfg.Runtime.ResolveTextAssetPolicy() + textAssetEnabled = cfg.Runtime.IsTextAssetEnabled() } if limitAwareStore, ok := s.sessionAssetStore.(sessionAssetLimitStore); ok { limitAwareStore.SetAssetPolicy(sessionAssetPolicy) + // 同步设置文本附件策略(与图片策略互不影响,按 mime 路由)。 + if textAwareStore, okText := s.sessionAssetStore.(sessionTextAssetLimitStore); okText { + textAwareStore.SetTextAssetPolicy(textAssetPolicy) + } + } + + // 同步把 text policy 注入到 user input preparer,让文本 asset 能被解析。 + if 
injectable, ok := s.userInputPreparer.(sessionTextPolicyInjectable); ok { + injectable.SetTextAssetPolicy(textAssetPolicy) } prepared, err := s.userInputPreparer.Prepare(ctx, input, defaultWorkdir) @@ -59,10 +76,26 @@ func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (Use return UserInput{}, err } + // 文本附件内联:在提交会话前把 prepared.Parts 里的文本 asset 读取并替换为 text part。 + // 关闭开关时跳过内联,并把文本 asset 的 image part 丢弃,避免非 image/* mime 进入 provider 失败。 runID := strings.TrimSpace(input.RunID) + textAssetCount := 0 + if textAssetEnabled { + inlineResult := s.inlineUserInputTextAssets(ctx, prepared.UserInput.SessionID, input, prepared.UserInput.Parts, textAssetPolicy) + prepared.UserInput.Parts = inlineResult.Parts + textAssetCount = inlineResult.Inlined + } else { + // text_asset_enabled=false:丢弃文本 asset image part,emit EventError 告知用户。 + prepared.UserInput.Parts = dropTextAssetImageParts(prepared.UserInput.Parts, textAssetPolicy, func(assetID string, mime string) { + _ = s.emitPrepareEvent(ctx, EventError, runID, prepared.UserInput.SessionID, + "text asset dropped (text_asset_enabled=false): asset_id="+assetID+" mime="+mime) + }) + } + _ = s.emitPrepareEvent(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{ - TextLength: len([]rune(strings.TrimSpace(input.Text))), - ImageCount: len(input.Images), + TextLength: len([]rune(strings.TrimSpace(input.Text))), + ImageCount: len(input.Images), + TextAssetCount: textAssetCount, }) for index, asset := range prepared.SavedAssets { path := "" @@ -130,10 +163,29 @@ type sessionInputPreparer struct { preparer *agentsession.InputPreparer } +// SetTextAssetPolicy 注入文本附件策略到内部 session.InputPreparer。 +// 该方法用于实现 sessionTextPolicyInjectable 接口,注入过程对调用方透明。 +func (p sessionInputPreparer) SetTextAssetPolicy(policy agentsession.TextAssetPolicy) { + if p.preparer == nil { + return + } + p.preparer.SetTextAssetPolicy(policy) +} + type sessionAssetLimitStore interface { SetAssetPolicy(policy 
agentsession.AssetPolicy) } +type sessionTextAssetLimitStore interface { + SetTextAssetPolicy(policy agentsession.TextAssetPolicy) +} + +// sessionTextPolicyInjectable 表示 user input preparer 支持运行时注入文本附件策略。 +// 引入这个内部接口是为了在不改动 UserInputPreparer 公开契约的前提下注入配置。 +type sessionTextPolicyInjectable interface { + SetTextAssetPolicy(policy agentsession.TextAssetPolicy) +} + // Prepare 将 runtime 输入 DTO 映射到 session 子层并返回标准 UserInput 结果。 func (p sessionInputPreparer) Prepare( ctx context.Context, diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go index 46478a9e..731ac790 100644 --- a/internal/runtime/input_prepare_test.go +++ b/internal/runtime/input_prepare_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" @@ -229,3 +230,153 @@ func minimalPNGBytesForRuntimeTest() []byte { 0x44, 0xae, 0x42, 0x60, 0x82, } } + +func TestServicePrepareUserInputInlinesTextAssetAndReportsCount(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + + textPath := filepath.Join(workdir, "notes.md") + if err := os.WriteFile(textPath, []byte("# Title\nbody content"), 0o644); err != nil { + t.Fatalf("write text: %v", err) + } + + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-text-1", + Text: "user query", + Images: []UserImageInput{{Path: textPath, MimeType: "text/markdown"}}, + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if len(input.Parts) != 2 { + t.Fatalf("expected 2 parts (user text + inlined markdown), got %d", len(input.Parts)) + } + // 第一个 part 是用户文本;第二个 part 应是 markdown 内容(被内联为 text part)。 + if input.Parts[0].Kind != providertypes.ContentPartText { + t.Errorf("Parts[0].Kind = %q, want text", input.Parts[0].Kind) + } + if input.Parts[1].Kind != providertypes.ContentPartText { + t.Errorf("Parts[1].Kind = %q, want text (inlined from text/markdown asset)", input.Parts[1].Kind) + } + 
if !strings.Contains(input.Parts[1].Text, "# Title") { + t.Errorf("Parts[1].Text = %q, want to contain markdown content", input.Parts[1].Text) + } + + // 第一个事件必须是 input_normalized,TextAssetCount = 1。 + event := mustReadRuntimeEvent(t, svc.Events()) + if event.Type != EventInputNormalized { + t.Fatalf("expected event %q, got %q", EventInputNormalized, event.Type) + } + payload, ok := event.Payload.(InputNormalizedPayload) + if !ok { + t.Fatalf("unexpected payload type: %T", event.Payload) + } + if payload.TextAssetCount != 1 { + t.Errorf("TextAssetCount = %d, want 1", payload.TextAssetCount) + } + if payload.TextLength == 0 { + t.Errorf("TextLength = 0, want > 0") + } +} + +func TestServicePrepareUserInputMixedTextAndImageAssets(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + svc, _ := newPrepareTestService(t, workdir, true) + + imagePath := filepath.Join(workdir, "img.png") + if err := os.WriteFile(imagePath, minimalPNGBytesForRuntimeTest(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + textPath := filepath.Join(workdir, "data.json") + if err := os.WriteFile(textPath, []byte(`{"k":"v"}`), 0o644); err != nil { + t.Fatalf("write text: %v", err) + } + + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-mixed-1", + Text: "user query", + Images: []UserImageInput{ + {Path: imagePath, MimeType: "image/png"}, + {Path: textPath, MimeType: "application/json"}, + }, + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + // 期望:user text + image asset + inlined json text = 3 parts。 + if len(input.Parts) != 3 { + t.Fatalf("expected 3 parts, got %d: %+v", len(input.Parts), input.Parts) + } + images, texts := 0, 0 + for _, p := range input.Parts { + switch p.Kind { + case providertypes.ContentPartImage: + images++ + case providertypes.ContentPartText: + texts++ + } + } + if images != 1 { + t.Errorf("images = %d, want 1 (image kept)", images) + } + if texts != 2 { + t.Errorf("texts = %d, 
want 2 (user text + inlined json)", texts) + } +} + +func TestServicePrepareUserInputRespectsTextAssetDisabledConfig(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + runtimeCfg := config.StaticDefaults().Runtime + disabled := false + runtimeCfg.Assets.TextAssetEnabled = &disabled + svc, _ := newPrepareTestServiceWithRuntimeConfig(t, workdir, true, runtimeCfg) + + textPath := filepath.Join(workdir, "notes.md") + if err := os.WriteFile(textPath, []byte("# Title"), 0o644); err != nil { + t.Fatalf("write text: %v", err) + } + + input, err := svc.PrepareUserInput(context.Background(), PrepareInput{ + RunID: "run-prepare-text-disabled", + Text: "user query", + Images: []UserImageInput{{Path: textPath, MimeType: "text/markdown"}}, + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + // 关闭开关时文本 asset 被丢弃,只剩用户文本 part,不保留为会失败的 image part。 + if len(input.Parts) != 1 { + t.Fatalf("expected 1 part (text only), got %d: %+v", len(input.Parts), input.Parts) + } + if input.Parts[0].Kind != providertypes.ContentPartText { + t.Errorf("Parts[0].Kind = %q, want text (text asset dropped)", input.Parts[0].Kind) + } + + // 先读到 EventError(文本 asset 被丢弃的通知),再读到 InputNormalized。 + dropEvent := mustReadRuntimeEvent(t, svc.Events()) + if dropEvent.Type != EventError { + t.Fatalf("expected EventError for dropped text asset, got %v", dropEvent.Type) + } + dropMsg, ok := dropEvent.Payload.(string) + if !ok { + t.Fatalf("unexpected drop event payload type: %T", dropEvent.Payload) + } + if !strings.Contains(dropMsg, "text asset dropped") || !strings.Contains(dropMsg, "text_asset_enabled=false") { + t.Errorf("drop event message = %q, want contains 'text asset dropped' and 'text_asset_enabled=false'", dropMsg) + } + + event := mustReadRuntimeEvent(t, svc.Events()) + payload, ok := event.Payload.(InputNormalizedPayload) + if !ok { + t.Fatalf("unexpected payload type: %T", event.Payload) + } + if payload.TextAssetCount != 0 { + t.Errorf("TextAssetCount = %d, want 0 
(text inline disabled)", payload.TextAssetCount) + } +} diff --git a/internal/runtime/text_assets.go b/internal/runtime/text_assets.go new file mode 100644 index 00000000..47b61cf4 --- /dev/null +++ b/internal/runtime/text_assets.go @@ -0,0 +1,180 @@ +package runtime + +import ( + "context" + "strings" + + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +// textAssetInlineResult 描述一次文本附件内联的统计结果,供上层事件与日志使用。 +type textAssetInlineResult struct { + Parts []providertypes.ContentPart + Inlined int + Truncated int + Failed int +} + +// inlineUserInputTextAssets 包装 inlineTextSessionAssets,并把失败回写到 prepare 事件。 +// 返回新的 parts 切片与内联统计;PrepareUserInput 用 Inlined 作为 TextAssetCount 上报。 +func (s *Service) inlineUserInputTextAssets( + ctx context.Context, + sessionID string, + input PrepareInput, + parts []providertypes.ContentPart, + policy agentsession.TextAssetPolicy, +) textAssetInlineResult { + if s == nil { + return textAssetInlineResult{Parts: parts} + } + runID := strings.TrimSpace(input.RunID) + fileNames := textAssetFileNameMap(input.Images) + normalized := agentsession.NormalizeTextAssetPolicy(policy) + newParts, detail := inlineTextSessionAssets( + ctx, + s.sessionAssetStore, + sessionID, + parts, + normalized, + fileNames, + func(assetID string, err error) { + // 文本附件内联失败通过 EventError 上报;不引入新事件类型,避免协议扩散。 + _ = s.emitPrepareEvent(ctx, EventError, runID, sessionID, "text asset inline failed: "+err.Error()) + }, + ) + return textAssetInlineResult{ + Parts: newParts, + Inlined: detail.Inlined, + Truncated: detail.Truncated, + Failed: detail.Failed, + } +} + +// inlineTextSessionAssets 把 prepared.Parts 里 mime 属于会话文本白名单的 session_asset image +// part 读取后替换为 text part,让 Provider 完全不感知"文件"概念。 +// 不在白名单内的 image part 保持原样,函数对非 asset 类型 part(text、remote image)透明。 +// +// 行为细节: +// - 每遇到一个目标 part,调用 session.LoadTextAsset 做 UTF-8 校验 + 字节/字符双阈值截断。 +// - 成功 → 用 NewTextPart(content) 替换原 part。 +// - 失败 → 调用者提供的 onError 钩子被触发(用于 emit 事件),该 
part 被丢弃(避免把坏数据送给 provider)。 +// - onError 为 nil 时失败静默丢弃。 +func inlineTextSessionAssets( + ctx context.Context, + store agentsession.AssetStore, + sessionID string, + parts []providertypes.ContentPart, + policy agentsession.TextAssetPolicy, + originalFileNames map[string]string, + onError func(assetID string, err error), +) ([]providertypes.ContentPart, textAssetInlineResult) { + result := textAssetInlineResult{} + if store == nil || len(parts) == 0 { + return parts, result + } + normalized := agentsession.NormalizeTextAssetPolicy(policy) + if normalized.Whitelist.IsEmpty() { + return parts, result + } + out := make([]providertypes.ContentPart, 0, len(parts)) + for _, part := range parts { + if part.Kind != providertypes.ContentPartImage || part.Image == nil { + out = append(out, part) + continue + } + // 只对 session asset 类的 image part 做内联,remote URL 直接保留。 + if part.Image.SourceType != providertypes.ImageSourceSessionAsset { + out = append(out, part) + continue + } + if part.Image.Asset == nil || strings.TrimSpace(part.Image.Asset.ID) == "" { + out = append(out, part) + continue + } + assetMime := strings.TrimSpace(part.Image.Asset.MimeType) + if !normalized.Whitelist.LookupByMime(assetMime) { + out = append(out, part) + continue + } + opts := agentsession.TextAssetLoadOptions{ + FallbackName: assetMime, + } + if originalFileNames != nil { + if name, ok := originalFileNames[part.Image.Asset.ID]; ok { + opts.FileName = name + } + } + loadResult, err := agentsession.LoadTextAsset(ctx, store, sessionID, part.Image.Asset.ID, normalized, opts) + if err != nil { + result.Failed++ + if onError != nil { + onError(part.Image.Asset.ID, err) + } + // 失败 part 直接丢弃,不影响其它 part。 + continue + } + result.Inlined++ + if loadResult.Truncated { + result.Truncated++ + } + out = append(out, providertypes.NewTextPart(loadResult.Content)) + } + result.Parts = out + return out, result +} + +// textAssetFileNameMap 把 input.Images 里携带的路径映射为 asset_id → 文件名, +// 供 inlineTextSessionAssets 
在截断提示里展示原始文件名。 +// 当 runtime 输入未携带文件名时返回 nil(fallback 到 mime 名称)。 +func textAssetFileNameMap(images []UserImageInput) map[string]string { + if len(images) == 0 { + return nil + } + out := make(map[string]string, len(images)) + for _, img := range images { + if strings.TrimSpace(img.AssetID) == "" { + continue + } + if name := strings.TrimSpace(img.Path); name != "" { + out[img.AssetID] = name + } + } + if len(out) == 0 { + return nil + } + return out +} + +// dropTextAssetImageParts 在 text_asset_enabled=false 时把 prepared.Parts 里属于文本白名单的 +// session_asset image part 丢弃,避免非 image/* mime 进入 provider 的 image-source resolver 失败。 +// 复用 inlineTextSessionAssets 的识别条件(ContentPartImage + ImageSourceSessionAsset + whitelist), +// 但丢弃而非读取替换。onDropped 钩子用于 emit 事件告知调用方,可为 nil。 +func dropTextAssetImageParts( + parts []providertypes.ContentPart, + policy agentsession.TextAssetPolicy, + onDropped func(assetID string, mime string), +) []providertypes.ContentPart { + if len(parts) == 0 { + return parts + } + normalized := agentsession.NormalizeTextAssetPolicy(policy) + if normalized.Whitelist.IsEmpty() { + return parts + } + out := make([]providertypes.ContentPart, 0, len(parts)) + for _, part := range parts { + // 只丢弃 session asset 类的文本 image part;图片 image part 和非 asset part 原样保留。 + if part.Kind == providertypes.ContentPartImage && part.Image != nil && + part.Image.SourceType == providertypes.ImageSourceSessionAsset && + part.Image.Asset != nil && strings.TrimSpace(part.Image.Asset.ID) != "" && + normalized.Whitelist.LookupByMime(part.Image.Asset.MimeType) { + if onDropped != nil { + onDropped(part.Image.Asset.ID, part.Image.Asset.MimeType) + } + continue + } + out = append(out, part) + } + return out +} diff --git a/internal/runtime/text_assets_test.go b/internal/runtime/text_assets_test.go new file mode 100644 index 00000000..5acbe5e9 --- /dev/null +++ b/internal/runtime/text_assets_test.go @@ -0,0 +1,376 @@ +package runtime + +import ( + "bytes" + "context" + "errors" + "io" + 
"strings" + "sync" + "testing" + + providertypes "neo-code/internal/provider/types" + agentsession "neo-code/internal/session" +) + +// textAssetStubStore 是为 inlineTextSessionAssets 设计的最小 AssetStore 桩。 +// 预置的 mimeToBytes / mimeToFileName 决定 Open 返回的内容;openErr 可注入 IO 失败。 +type textAssetStubStore struct { + mu sync.Mutex + payloads map[string]map[string][]byte // sessionID -> assetID -> bytes + mimes map[string]map[string]string // sessionID -> assetID -> mime + fileNames map[string]map[string]string // sessionID -> assetID -> fileName + openErr error + invalidUTF8 bool +} + +func newTextAssetStubStore() *textAssetStubStore { + return &textAssetStubStore{ + payloads: map[string]map[string][]byte{}, + mimes: map[string]map[string]string{}, + fileNames: map[string]map[string]string{}, + } +} + +func (s *textAssetStubStore) SaveAsset(_ context.Context, sessionID string, r io.Reader, mimeType string) (agentsession.AssetMeta, error) { + data, err := io.ReadAll(r) + if err != nil { + return agentsession.AssetMeta{}, err + } + s.mu.Lock() + defer s.mu.Unlock() + if s.payloads[sessionID] == nil { + s.payloads[sessionID] = map[string][]byte{} + } + if s.mimes[sessionID] == nil { + s.mimes[sessionID] = map[string]string{} + } + if s.fileNames[sessionID] == nil { + s.fileNames[sessionID] = map[string]string{} + } + id := "asset-" + mimeType + s.payloads[sessionID][id] = data + s.mimes[sessionID][id] = mimeType + return agentsession.AssetMeta{ID: id, MimeType: mimeType, Size: int64(len(data))}, nil +} + +func (s *textAssetStubStore) Open(_ context.Context, sessionID, assetID string) (io.ReadCloser, agentsession.AssetMeta, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.openErr != nil { + return nil, agentsession.AssetMeta{}, s.openErr + } + payloads, ok := s.payloads[sessionID] + if !ok { + return nil, agentsession.AssetMeta{}, errors.New("session not found") + } + data, ok := payloads[assetID] + if !ok { + return nil, agentsession.AssetMeta{}, errors.New("asset not found") 
+ } + mime := s.mimes[sessionID][assetID] + if s.invalidUTF8 { + // 注入一个非 UTF-8 字节序列用于错误路径测试。 + data = []byte{0xC3, 0x28, 0xFF, 0xFE} + } + return io.NopCloser(bytes.NewReader(data)), agentsession.AssetMeta{ID: assetID, MimeType: mime, Size: int64(len(data))}, nil +} + +func (s *textAssetStubStore) Stat(_ context.Context, sessionID, assetID string) (agentsession.AssetMeta, error) { + s.mu.Lock() + defer s.mu.Unlock() + mime := s.mimes[sessionID][assetID] + return agentsession.AssetMeta{ID: assetID, MimeType: mime, Size: int64(len(s.payloads[sessionID][assetID]))}, nil +} + +// 工具方法:把 inlineTextSessionAssets 返回的结果"扁平化"为 (image parts, text parts) 数量。 +func countPartsByKind(parts []providertypes.ContentPart) (images, texts int) { + for _, p := range parts { + switch p.Kind { + case providertypes.ContentPartImage: + images++ + case providertypes.ContentPartText: + texts++ + } + } + return +} + +func TestInlineTextSessionAssetsReplacesTextAsset(t *testing.T) { + t.Parallel() + + store := newTextAssetStubStore() + ctx := context.Background() + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader("# README\nhello world"), "text/markdown") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + parts := []providertypes.ContentPart{ + providertypes.NewTextPart("user query"), + providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType), + } + + out, result := inlineTextSessionAssets(ctx, store, "s1", parts, agentsession.DefaultTextAssetPolicy(), nil, nil) + if result.Failed != 0 { + t.Fatalf("Failed = %d, want 0", result.Failed) + } + if result.Inlined != 1 { + t.Fatalf("Inlined = %d, want 1", result.Inlined) + } + if result.Truncated != 0 { + t.Fatalf("Truncated = %d, want 0", result.Truncated) + } + images, texts := countPartsByKind(out) + if images != 0 { + t.Errorf("images = %d, want 0 (text asset should be replaced)", images) + } + if texts != 2 { + t.Errorf("texts = %d, want 2 (user query + inlined text asset)", texts) + } + // 内联后的内容应包含原始 markdown 
// TestInlineTextSessionAssetsKeepsRemoteImage feeds a single remote image part
// (no session asset involved, the stub store is empty) through
// inlineTextSessionAssets and checks that nothing is reported as inlined and
// that the output still contains exactly one image part.
func TestInlineTextSessionAssetsKeepsRemoteImage(t *testing.T) {
	t.Parallel()

	store := newTextAssetStubStore()
	parts := []providertypes.ContentPart{
		providertypes.NewRemoteImagePart("https://example.com/cat.png"),
	}
	out, result := inlineTextSessionAssets(context.Background(), store, "s1", parts, agentsession.DefaultTextAssetPolicy(), nil, nil)
	if result.Inlined != 0 {
		t.Errorf("Inlined = %d, want 0", result.Inlined)
	}
	// Only the image count matters here; the text count is discarded.
	images, _ := countPartsByKind(out)
	if images != 1 {
		t.Errorf("images = %d, want 1 (remote image kept)", images)
	}
}
err) + } + + parts := []providertypes.ContentPart{ + providertypes.NewTextPart("user text"), + providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType), + } + + var onErrorCalls int + var onErrorID string + var onErrorErr error + out, result := inlineTextSessionAssets(ctx, store, "s1", parts, agentsession.DefaultTextAssetPolicy(), nil, + func(assetID string, err error) { + onErrorCalls++ + onErrorID = assetID + onErrorErr = err + }, + ) + if result.Failed != 1 { + t.Errorf("Failed = %d, want 1", result.Failed) + } + if result.Inlined != 0 { + t.Errorf("Inlined = %d, want 0", result.Inlined) + } + if onErrorCalls != 1 { + t.Errorf("onError called %d times, want 1", onErrorCalls) + } + if onErrorID != meta.ID { + t.Errorf("onError assetID = %q, want %q", onErrorID, meta.ID) + } + var loadErr *agentsession.AssetTextLoadError + if !errors.As(onErrorErr, &loadErr) { + t.Errorf("onError err type = %T, want *AssetTextLoadError", onErrorErr) + } + images, texts := countPartsByKind(out) + if images != 0 { + t.Errorf("images = %d, want 0 (failed part dropped)", images) + } + if texts != 1 { + t.Errorf("texts = %d, want 1 (user text kept)", texts) + } +} + +func TestInlineTextSessionAssetsTruncatesAndAddsMarker(t *testing.T) { + t.Parallel() + + store := newTextAssetStubStore() + ctx := context.Background() + policy := agentsession.TextAssetPolicy{ + Whitelist: agentsession.DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 4, + MaxTextAssetChars: 1000, + } + original := strings.Repeat("a", 16) + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader(original), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + parts := []providertypes.ContentPart{ + providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType), + } + + out, result := inlineTextSessionAssets(ctx, store, "s1", parts, policy, map[string]string{ + meta.ID: "big.txt", + }, nil) + if result.Inlined != 1 { + t.Errorf("Inlined = %d, want 1", result.Inlined) + } + if 
result.Truncated != 1 { + t.Errorf("Truncated = %d, want 1", result.Truncated) + } + images, _ := countPartsByKind(out) + if images != 0 { + t.Errorf("images = %d, want 0 (truncated text asset still replaced)", images) + } + foundMarker := false + for _, p := range out { + if p.Kind == providertypes.ContentPartText && strings.Contains(p.Text, "[truncated:") && strings.Contains(p.Text, "filename=big.txt") { + foundMarker = true + } + } + if !foundMarker { + t.Errorf("expected inlined part to contain truncation marker with sanitized filename, got %+v", out) + } +} + +func TestInlineTextSessionAssetsMixedTextAndImage(t *testing.T) { + t.Parallel() + + store := newTextAssetStubStore() + ctx := context.Background() + imageMeta, err := store.SaveAsset(ctx, "s1", strings.NewReader("image"), "image/png") + if err != nil { + t.Fatalf("SaveAsset(image) error = %v", err) + } + textMeta, err := store.SaveAsset(ctx, "s1", strings.NewReader("json-body"), "application/json") + if err != nil { + t.Fatalf("SaveAsset(text) error = %v", err) + } + + parts := []providertypes.ContentPart{ + providertypes.NewTextPart("user text"), + providertypes.NewSessionAssetImagePart(imageMeta.ID, imageMeta.MimeType), + providertypes.NewSessionAssetImagePart(textMeta.ID, textMeta.MimeType), + } + out, result := inlineTextSessionAssets(ctx, store, "s1", parts, agentsession.DefaultTextAssetPolicy(), nil, nil) + if result.Inlined != 1 { + t.Errorf("Inlined = %d, want 1 (only text asset should be inlined)", result.Inlined) + } + images, texts := countPartsByKind(out) + if images != 1 { + t.Errorf("images = %d, want 1 (image kept)", images) + } + if texts != 2 { + t.Errorf("texts = %d, want 2 (user text + inlined text asset)", texts) + } +} + +func TestTextAssetFileNameMap(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + images []UserImageInput + wantNil bool + want map[string]string + }{ + { + name: "empty images", + images: nil, + wantNil: true, + }, + { + name: "skip empty asset 
id", + images: []UserImageInput{ + {AssetID: "", Path: "/x.png"}, + }, + wantNil: true, + }, + { + name: "skip empty path", + images: []UserImageInput{ + {AssetID: "a1", Path: " "}, + }, + wantNil: true, + }, + { + name: "map populated", + images: []UserImageInput{ + {AssetID: "a1", Path: "/tmp/notes.md"}, + {AssetID: "a2", Path: "data.csv"}, + }, + want: map[string]string{ + "a1": "/tmp/notes.md", + "a2": "data.csv", + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := textAssetFileNameMap(tc.images) + if tc.wantNil { + if got != nil { + t.Errorf("got = %+v, want nil", got) + } + return + } + if got == nil { + t.Fatalf("got nil, want %+v", tc.want) + } + if len(got) != len(tc.want) { + t.Errorf("len(got)=%d, want %d", len(got), len(tc.want)) + } + for k, v := range tc.want { + if got[k] != v { + t.Errorf("got[%q] = %q, want %q", k, got[k], v) + } + } + }) + } +} diff --git a/internal/session/asset_store.go b/internal/session/asset_store.go index 6fee5d3d..65afe964 100644 --- a/internal/session/asset_store.go +++ b/internal/session/asset_store.go @@ -22,16 +22,23 @@ type AssetStore interface { } // newAssetMeta 生成新的会话附件元数据,并校验 MIME 约束。 +// 接受 image/* 与会话侧 TextAssetWhitelist 命中的文本 MIME;其他输入返回明确错误。 func newAssetMeta(mimeType string) (AssetMeta, error) { normalized := strings.ToLower(strings.TrimSpace(mimeType)) if normalized == "" { return AssetMeta{}, fmt.Errorf("session: asset mime type is empty") } - if !strings.HasPrefix(normalized, "image/") { - return AssetMeta{}, fmt.Errorf("session: unsupported asset mime type %q", mimeType) + if strings.HasPrefix(normalized, "image/") { + return AssetMeta{ + ID: NewID("asset"), + MimeType: normalized, + }, nil } - return AssetMeta{ - ID: NewID("asset"), - MimeType: normalized, - }, nil + if DefaultTextAssetWhitelist().LookupByMime(normalized) { + return AssetMeta{ + ID: NewID("asset"), + MimeType: normalized, + }, nil + } + return AssetMeta{}, fmt.Errorf("session: unsupported asset mime 
type %q", mimeType) } diff --git a/internal/session/asset_store_test.go b/internal/session/asset_store_test.go index e9d0bad0..38c3ab1c 100644 --- a/internal/session/asset_store_test.go +++ b/internal/session/asset_store_test.go @@ -94,7 +94,8 @@ func TestSQLiteStoreSaveAssetRejectsInvalidInput(t *testing.T) { if _, err := store.SaveAsset(ctx, "session_assets_invalid", strings.NewReader("x"), ""); err == nil { t.Fatalf("expected empty mime type error") } - if _, err := store.SaveAsset(ctx, "session_assets_invalid", strings.NewReader("x"), "text/plain"); err == nil { + // text/html 仍不属于会话文本 asset 白名单(避免任意 HTML/可执行脚本被内联为模型上下文)。 + if _, err := store.SaveAsset(ctx, "session_assets_invalid", strings.NewReader("x"), "text/html"); err == nil { t.Fatalf("expected unsupported mime type error") } if _, err := store.SaveAsset(ctx, "missing", strings.NewReader("x"), "image/png"); err == nil { @@ -309,6 +310,116 @@ func MaxSessionAssetBytesForTest() int64 { return MaxSessionAssetBytes } +func TestSQLiteStoreSaveAssetAcceptsWhitelistedTextMime(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := newTestStore(t) + session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_text_ok", Title: "assets"}) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + cases := []struct { + mime string + payload string + wantMime string + }{ + {mime: "text/plain", payload: "hello", wantMime: "text/plain"}, + {mime: "text/markdown", payload: "# title", wantMime: "text/markdown"}, + {mime: "application/json", payload: "{\"k\":1}", wantMime: "application/json"}, + {mime: "text/yaml", payload: "k: v", wantMime: "text/yaml"}, + {mime: "application/x-yaml", payload: "k: v", wantMime: "application/x-yaml"}, + {mime: "text/csv", payload: "a,b\n1,2", wantMime: "text/csv"}, + } + for _, tc := range cases { + meta, err := store.SaveAsset(ctx, session.ID, strings.NewReader(tc.payload), tc.mime) + if err != nil { + t.Fatalf("SaveAsset(mime=%q) error 
// TestSQLiteStoreSaveAssetAppliesTextPolicySizeLimit configures the store with
// a 4-byte text asset limit and checks three things: a text payload exactly at
// the limit is accepted, a text payload one byte over is rejected with a size
// error, and an image payload is still accepted.
func TestSQLiteStoreSaveAssetAppliesTextPolicySizeLimit(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	store := newTestStore(t)
	store.SetTextAssetPolicy(TextAssetPolicy{
		Whitelist:         DefaultTextAssetWhitelist(),
		MaxTextAssetBytes: 4,
		MaxTextAssetChars: 16,
	})
	session, err := store.CreateSession(ctx, CreateSessionInput{ID: "session_assets_text_limit", Title: "assets"})
	if err != nil {
		t.Fatalf("CreateSession() error = %v", err)
	}

	// A 4-byte text payload sits exactly at the limit and must be accepted.
	if _, err := store.SaveAsset(ctx, session.ID, strings.NewReader("abcd"), "text/plain"); err != nil {
		t.Fatalf("SaveAsset(within text limit) error = %v", err)
	}
	// 5 bytes exceed the 4-byte text limit; the error message is expected to
	// mention the 4-byte threshold.
	oversized := bytes.NewReader([]byte("abcde"))
	if _, err := store.SaveAsset(ctx, session.ID, oversized, "text/markdown"); err == nil ||
		!strings.Contains(err.Error(), "asset size exceeds 4 bytes") {
		t.Fatalf("SaveAsset(exceed text limit) error = %v, want size limit error", err)
	}
	// Images are expected to stay under the separate image limit
	// (presumably MaxSessionAssetBytes, 20 MiB — confirm against SaveAsset),
	// so a 4-byte image is accepted.
	if _, err := store.SaveAsset(ctx, session.ID, strings.NewReader("abcd"), "image/png"); err != nil {
		t.Fatalf("SaveAsset(image within image limit) error = %v", err)
	}
}
LoadTextAsset 成功加载并截断后的输出。 +// Content 是按 UTF-8 解码、字节与字符双阈值兜底后的内容(若被截断,会在末尾追加截断提示)。 +// Truncated 为 true 时 Content 已包含截断提示;OriginalBytes 是从 AssetStore 读到的原始字节数 +// (受 LimitReader 控制);KeptChars 是截断后、附加提示前的内容字符数;TotalChars 是最终 Content 的字符数。 +type TextAssetLoadResult struct { + Content string + Truncated bool + OriginalBytes int64 + KeptChars int + TotalChars int +} + +// TextAssetLoadOptions 描述加载文本 asset 时的可调参数。 +// FileName 仅用于截断提示的展示,会通过 SanitizeTextAssetFileName 做安全清洗,不作为可信路径。 +type TextAssetLoadOptions struct { + FileName string + FallbackName string + IncludeBoundary bool // 是否在内容首尾追加 "" 边界;默认 false,runtime 层决定。 +} + +// LoadTextAsset 从 store 中读取指定文本 asset,按 UTF-8 校验并按 policy 截断。 +// 返回结构化结果与错误;非 UTF-8 / 空 / IO 错误都会以 AssetTextLoadError 形式返回。 +func LoadTextAsset( + ctx context.Context, + store AssetStore, + sessionID string, + assetID string, + policy TextAssetPolicy, + opts TextAssetLoadOptions, +) (TextAssetLoadResult, error) { + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "ctx-canceled", Err: err} + } + if store == nil { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "store-nil", Err: errors.New("asset store is nil")} + } + if assetID == "" { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "missing-asset-id", Err: errors.New("asset id is empty")} + } + normalized := NormalizeTextAssetPolicy(policy) + // 调用方提供的 raw policy 本身就空白名单时直接拒,不应用默认值(与 issue 风险节"白名单关闭=拒"对齐)。 + // 注意:NormalizeTextAssetPolicy 会把空白名单填为默认白名单,因此必须在 normalize 之前检查 raw policy。 + if policy.Whitelist.IsEmpty() { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "whitelist-empty", Err: errors.New("text asset whitelist is empty")} + } + + rc, meta, err := store.Open(ctx, sessionID, assetID) + if err != nil { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "open", 
Err: err} + } + defer func() { _ = rc.Close() }() + + // 字节上限保护:避免一次读入超大数据导致 OOM。 + // +1 是为了让"正好等于上限"的场景也能区分"刚好"和"超限"。 + raw, err := io.ReadAll(io.LimitReader(bufio.NewReader(rc), normalized.MaxTextAssetBytes+1)) + if err != nil { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "io", Err: err} + } + if err := ctx.Err(); err != nil { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "ctx-canceled", Err: err} + } + if int64(len(raw)) == 0 { + return TextAssetLoadResult{}, &AssetTextLoadError{AssetID: assetID, Reason: "empty", Err: errors.New("asset payload is empty")} + } + truncated := false + if int64(len(raw)) > normalized.MaxTextAssetBytes { + raw = raw[:normalized.MaxTextAssetBytes] + truncated = true + } + + // 字符级截断(按 rune 计数)。先做 UTF-8 校验以便给出明确错误。 + if !utf8.Valid(raw) { + return TextAssetLoadResult{}, &AssetTextLoadError{ + AssetID: assetID, + Reason: "utf8", + Err: fmt.Errorf("asset payload is not valid UTF-8 (mime=%q, bytes=%d)", meta.MimeType, len(raw)), + } + } + content, charTruncated := truncateByRuneCount(string(raw), normalized.MaxTextAssetChars) + if charTruncated { + truncated = true + } + // KeptChars 反映"截断后、附加提示前"的原内容字符数;TotalChars 反映最终返回 Content 的字符数。 + keptChars := utf8.RuneCountInString(content) + originalBytes := int64(len(raw)) + + displayName := SanitizeTextAssetFileName(opts.FileName, opts.FallbackName) + if displayName == "" { + displayName = SanitizeTextAssetFileName(meta.MimeType, "text-asset") + } + + if truncated { + content = content + fmt.Sprintf( + "\n\n[truncated: original=%d bytes, kept=%d chars; filename=%s]", + originalBytes, keptChars, displayName, + ) + } + + return TextAssetLoadResult{ + Content: content, + Truncated: truncated, + OriginalBytes: originalBytes, + KeptChars: keptChars, + TotalChars: utf8.RuneCountInString(content), + }, nil +} + +// truncateByRuneCount 按 UTF-8 rune 数截断字符串;返回 (截断后, 是否截断)。 +// 仅在字符数超过上限时触发;不会破坏多字节字符边界。 +func truncateByRuneCount(s string, 
maxChars int) (string, bool) { + if maxChars <= 0 { + return "", true + } + count := utf8.RuneCountInString(s) + if count <= maxChars { + return s, false + } + var b []byte + runeIdx := 0 + for i := 0; i < len(s); { + _, size := utf8.DecodeRuneInString(s[i:]) + if runeIdx >= maxChars { + break + } + b = append(b, s[i:i+size]...) + i += size + runeIdx++ + } + return string(b), true +} diff --git a/internal/session/asset_text_test.go b/internal/session/asset_text_test.go new file mode 100644 index 00000000..864c3412 --- /dev/null +++ b/internal/session/asset_text_test.go @@ -0,0 +1,363 @@ +package session + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" +) + +// stubAssetStore 是用于单元测试 LoadTextAsset 的最小 AssetStore 桩。 +// SaveAsset 直接写内存字节流;Open/Stat 返回同样内容。 +type stubAssetStore struct { + payloads map[string]map[string][]byte // sessionID -> assetID -> bytes + mimes map[string]map[string]string +} + +func newStubAssetStore() *stubAssetStore { + return &stubAssetStore{ + payloads: map[string]map[string][]byte{}, + mimes: map[string]map[string]string{}, + } +} + +func (s *stubAssetStore) SaveAsset(_ context.Context, sessionID string, r io.Reader, mimeType string) (AssetMeta, error) { + data, err := io.ReadAll(r) + if err != nil { + return AssetMeta{}, err + } + if s.payloads[sessionID] == nil { + s.payloads[sessionID] = map[string][]byte{} + } + if s.mimes[sessionID] == nil { + s.mimes[sessionID] = map[string]string{} + } + id := "asset-" + mimeType + "-" + itoa(len(s.payloads[sessionID])) + s.payloads[sessionID][id] = data + s.mimes[sessionID][id] = mimeType + return AssetMeta{ID: id, MimeType: mimeType, Size: int64(len(data))}, nil +} + +func (s *stubAssetStore) Open(_ context.Context, sessionID, assetID string) (io.ReadCloser, AssetMeta, error) { + payloads, ok := s.payloads[sessionID] + if !ok { + return nil, AssetMeta{}, errors.New("session not found") + } + data, ok := payloads[assetID] + if !ok { + return nil, AssetMeta{}, 
// itoa renders a non-negative int in decimal without pulling in strconv.
// Zero yields "0"; a negative n yields the empty string because the digit
// loop never runs.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	// 20 bytes is enough for the digits of any 64-bit positive integer.
	var buf [20]byte
	pos := len(buf)
	// Fill the buffer from the right, least significant digit first.
	for v := n; v > 0; v /= 10 {
		pos--
		buf[pos] = byte('0' + v%10)
	}
	return string(buf[pos:])
}
&loadErr) { + t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "empty" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "empty") + } +} + +func TestLoadTextAsset_NonUTF8Payload(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + ctx := context.Background() + // 0xC3 0x28 是非法 UTF-8 序列(缺第二字节)。 + bad := []byte{0xC3, 0x28, 0xA0, 0xA1} + meta, err := store.SaveAsset(ctx, "s1", bytes.NewReader(bad), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + _, err = LoadTextAsset(ctx, store, "s1", meta.ID, DefaultTextAssetPolicy(), TextAssetLoadOptions{}) + var loadErr *AssetTextLoadError + if !errors.As(err, &loadErr) { + t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "utf8" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "utf8") + } +} + +func TestLoadTextAsset_TruncatesByBytes(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + ctx := context.Background() + policy := TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 8, + MaxTextAssetChars: 1000, + } + original := strings.Repeat("a", 16) // 16 字节,> 8 字节上限 + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader(original), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + result, err := LoadTextAsset(ctx, store, "s1", meta.ID, policy, TextAssetLoadOptions{FileName: "long.txt"}) + if err != nil { + t.Fatalf("LoadTextAsset() error = %v", err) + } + if !result.Truncated { + t.Errorf("Truncated = false, want true") + } + if !strings.Contains(result.Content, "[truncated:") { + t.Errorf("content missing truncation marker, got %q", result.Content) + } + if !strings.Contains(result.Content, "filename=long.txt") { + t.Errorf("content missing sanitized filename, got %q", result.Content) + } + if result.OriginalBytes != 8 { + t.Errorf("OriginalBytes = %d, want 8 (post-byte-cap)", result.OriginalBytes) + } +} + +func 
TestLoadTextAsset_TruncatesByChars(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + ctx := context.Background() + policy := TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 1024, + MaxTextAssetChars: 4, + } + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader("abcdefghij"), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + result, err := LoadTextAsset(ctx, store, "s1", meta.ID, policy, TextAssetLoadOptions{}) + if err != nil { + t.Fatalf("LoadTextAsset() error = %v", err) + } + if !result.Truncated { + t.Errorf("Truncated = false, want true") + } + // 截断提示里应包含 kept=4 chars。 + if !strings.Contains(result.Content, "kept=4 chars") { + t.Errorf("content missing char count marker, got %q", result.Content) + } +} + +func TestLoadTextAsset_PreservesMultibyteBoundary(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + ctx := context.Background() + // 5 个中文字符 = 15 字节。字符上限设为 3 → 截到 9 字节且不切碎字符。 + policy := TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 1024, + MaxTextAssetChars: 3, + } + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader("一二三四五"), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + result, err := LoadTextAsset(ctx, store, "s1", meta.ID, policy, TextAssetLoadOptions{}) + if err != nil { + t.Fatalf("LoadTextAsset() error = %v", err) + } + if result.KeptChars != 3 { + t.Errorf("KeptChars = %d, want 3", result.KeptChars) + } + // 截断后剩余内容(不含截断提示)必须是 3 个完整中文字符。 + // 截断提示一定包含 newline + "[truncated:",所以按最后一段截取前段。 + body := result.Content + if idx := strings.Index(body, "\n\n[truncated:"); idx >= 0 { + body = body[:idx] + } + if body != "一二三" { + t.Errorf("truncated body = %q, want %q", body, "一二三") + } +} + +func TestLoadTextAsset_SanitizesFileNameInMarker(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + ctx := context.Background() + policy := TextAssetPolicy{ + Whitelist: 
DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 4, + MaxTextAssetChars: 1000, + } + original := strings.Repeat("x", 16) + meta, err := store.SaveAsset(ctx, "s1", strings.NewReader(original), "text/plain") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + result, err := LoadTextAsset(ctx, store, "s1", meta.ID, policy, TextAssetLoadOptions{ + FileName: "../etc/passwd", + }) + if err != nil { + t.Fatalf("LoadTextAsset() error = %v", err) + } + if strings.Contains(result.Content, "../") { + t.Errorf("content still contains path separator, got %q", result.Content) + } + if !strings.Contains(result.Content, "filename=passwd") { + t.Errorf("content missing sanitized basename, got %q", result.Content) + } +} + +func TestLoadTextAsset_RejectsOpenError(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + _, err := LoadTextAsset(context.Background(), store, "s1", "missing", DefaultTextAssetPolicy(), TextAssetLoadOptions{}) + var loadErr *AssetTextLoadError + if !errors.As(err, &loadErr) { + t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "open" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "open") + } +} + +func TestLoadTextAsset_RejectsEmptyWhitelist(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + _, err := LoadTextAsset( + context.Background(), + store, + "s1", + "missing", + TextAssetPolicy{Whitelist: NewTextAssetWhitelist(nil)}, + TextAssetLoadOptions{}, + ) + var loadErr *AssetTextLoadError + if !errors.As(err, &loadErr) { + t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "whitelist-empty" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "whitelist-empty") + } +} + +func TestLoadTextAsset_RejectsEmptyAssetID(t *testing.T) { + t.Parallel() + + store := newStubAssetStore() + _, err := LoadTextAsset(context.Background(), store, "s1", "", DefaultTextAssetPolicy(), TextAssetLoadOptions{}) + var loadErr *AssetTextLoadError + if !errors.As(err, &loadErr) { 
+ t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "missing-asset-id" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "missing-asset-id") + } +} + +func TestLoadTextAsset_RejectsNilStore(t *testing.T) { + t.Parallel() + + _, err := LoadTextAsset(context.Background(), nil, "s1", "a1", DefaultTextAssetPolicy(), TextAssetLoadOptions{}) + var loadErr *AssetTextLoadError + if !errors.As(err, &loadErr) { + t.Fatalf("expected AssetTextLoadError, got %v", err) + } + if loadErr.Reason != "store-nil" { + t.Errorf("Reason = %q, want %q", loadErr.Reason, "store-nil") + } +} + +func TestTruncateByRuneCount(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + in string + max int + want string + wantTrunc bool + }{ + {name: "no truncate when within limit", in: "abc", max: 10, want: "abc", wantTrunc: false}, + {name: "exact limit no truncate", in: "abcde", max: 5, want: "abcde", wantTrunc: false}, + {name: "truncate at boundary", in: "abcdef", max: 4, want: "abcd", wantTrunc: true}, + {name: "multibyte preserved", in: "一二三", max: 2, want: "一二", wantTrunc: true}, + {name: "zero max forces truncate", in: "abc", max: 0, want: "", wantTrunc: true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, trunc := truncateByRuneCount(tc.in, tc.max) + if got != tc.want { + t.Errorf("got = %q, want %q", got, tc.want) + } + if trunc != tc.wantTrunc { + t.Errorf("trunc = %v, want %v", trunc, tc.wantTrunc) + } + }) + } +} diff --git a/internal/session/assets_policy.go b/internal/session/assets_policy.go index 2ca3433f..4ffa8cb5 100644 --- a/internal/session/assets_policy.go +++ b/internal/session/assets_policy.go @@ -1,8 +1,25 @@ package session +import ( + "fmt" + "path/filepath" + "strings" +) + const ( // MaxSessionAssetBytes 定义 session_asset 在读写链路中的统一大小上限(20 MiB)。 MaxSessionAssetBytes int64 = 20 * 1024 * 1024 + // DefaultMaxTextAssetBytes 定义 session 文本 asset 的默认字节上限(256 KiB)。 + // 该默认值用于避免单轮会话上下文被大体量文本文件击穿;硬上限由 
MaxTextAssetBytesHardLimit 兜底。 + DefaultMaxTextAssetBytes int64 = 256 * 1024 + // MaxTextAssetBytesHardLimit 是文本 asset 字节上限的硬编码兜底,防止配置或调用方把上限放得过大。 + MaxTextAssetBytesHardLimit int64 = 4 * 1024 * 1024 + // DefaultMaxTextAssetChars 定义 session 文本 asset 的默认字符上限(约 25 万字符,按 UTF-8 解码后计数)。 + DefaultMaxTextAssetChars int = 250_000 + // MaxTextAssetCharsHardLimit 是文本 asset 字符上限的硬编码兜底。 + MaxTextAssetCharsHardLimit int = 4_000_000 + // MaxTextAssetFileNameBytes 定义文本 asset 文件名(用于上下文标签)的最大字节数。 + MaxTextAssetFileNameBytes int = 128 ) // AssetPolicy 描述 session_asset 在单文件维度的存储与读写策略。 @@ -28,3 +45,203 @@ func NormalizeAssetPolicy(policy AssetPolicy) AssetPolicy { } return normalized } + +// TextAssetWhitelist 描述文本类附件的扩展名与 MIME 双向白名单。 +// 同时提供按扩展名(无 "." 前缀、小写)和按 MIME(小写)查询的能力, +// 允许调用方在不修改核心代码的前提下扩展支持类型(与默认项做 union 合并)。 +type TextAssetWhitelist struct { + // extensionToMime 记录扩展名到 MIME 的映射。 + extensionToMime map[string]string + // mimeSet 记录支持的全部 MIME 集合,便于按 MIME 反向查扩展名。 + mimeSet map[string]struct{} +} + +// DefaultTextAssetWhitelist 返回内置的文本附件白名单(与 issue #701 验收清单对齐)。 +func DefaultTextAssetWhitelist() TextAssetWhitelist { + return NewTextAssetWhitelist(map[string]string{ + "txt": "text/plain", + "md": "text/markdown", + "json": "application/json", + "yaml": "text/yaml", + "yml": "application/x-yaml", + "csv": "text/csv", + }) +} + +// NewTextAssetWhitelist 通过扩展名→MIME 映射构造白名单;空映射会得到一个空的、拒绝一切的实例。 +func NewTextAssetWhitelist(extensionToMime map[string]string) TextAssetWhitelist { + normalized := make(map[string]string, len(extensionToMime)) + mimes := make(map[string]struct{}, len(extensionToMime)) + for ext, mime := range extensionToMime { + cleanExt := strings.ToLower(strings.TrimSpace(strings.TrimPrefix(ext, "."))) + cleanMime := strings.ToLower(strings.TrimSpace(mime)) + if cleanExt == "" || cleanMime == "" { + continue + } + normalized[cleanExt] = cleanMime + mimes[cleanMime] = struct{}{} + } + return TextAssetWhitelist{ + extensionToMime: normalized, + mimeSet: mimes, + } +} + 
+// WithExtensions 在现有白名单上追加扩展名→MIME 项,返回新的白名单实例(不可变)。 +// 重复键以 base 为准;返回的实例不与 base 共享内部 map。 +func (w TextAssetWhitelist) WithExtensions(extensionToMime map[string]string) TextAssetWhitelist { + merged := make(map[string]string, len(w.extensionToMime)+len(extensionToMime)) + for k, v := range w.extensionToMime { + merged[k] = v + } + for k, v := range extensionToMime { + cleanExt := strings.ToLower(strings.TrimSpace(strings.TrimPrefix(k, "."))) + cleanMime := strings.ToLower(strings.TrimSpace(v)) + if cleanExt == "" || cleanMime == "" { + continue + } + merged[cleanExt] = cleanMime + } + mimes := make(map[string]struct{}, len(merged)) + for _, mime := range merged { + mimes[mime] = struct{}{} + } + return TextAssetWhitelist{extensionToMime: merged, mimeSet: mimes} +} + +// IsEmpty 报告白名单是否为空。空白名单会拒绝所有文本 asset。 +func (w TextAssetWhitelist) IsEmpty() bool { + return len(w.extensionToMime) == 0 +} + +// LookupByExtension 按文件名(不区分大小写,自动去除路径分隔符)解析扩展名对应的 MIME;未命中返回空。 +func (w TextAssetWhitelist) LookupByExtension(fileName string) string { + ext := strings.ToLower(strings.TrimSpace(strings.TrimPrefix(filepath.Ext(strings.TrimSpace(fileName)), "."))) + if ext == "" { + return "" + } + return w.extensionToMime[ext] +} + +// LookupByMime 按 MIME(不区分大小写)判断是否在白名单中。 +func (w TextAssetWhitelist) LookupByMime(mimeType string) bool { + mime := strings.ToLower(strings.TrimSpace(mimeType)) + if mime == "" { + return false + } + _, ok := w.mimeSet[mime] + return ok +} + +// Extensions 返回当前白名单包含的全部扩展名(只读快照)。 +func (w TextAssetWhitelist) Extensions() []string { + if len(w.extensionToMime) == 0 { + return nil + } + out := make([]string, 0, len(w.extensionToMime)) + for ext := range w.extensionToMime { + out = append(out, ext) + } + return out +} + +// TextAssetPolicy 描述 session 文本类 asset 的存储与读取策略。 +type TextAssetPolicy struct { + // Whitelist 是文本 asset 的扩展名/MIME 白名单。 + Whitelist TextAssetWhitelist + // MaxTextAssetBytes 是单个文本 asset 的字节上限。 + MaxTextAssetBytes int64 + // MaxTextAssetChars 
是单个文本 asset 在 UTF-8 解码后允许保留的最大字符数。 + MaxTextAssetChars int +} + +// DefaultTextAssetPolicy 返回 session 文本 asset 策略的默认值。 +func DefaultTextAssetPolicy() TextAssetPolicy { + return TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: DefaultMaxTextAssetBytes, + MaxTextAssetChars: DefaultMaxTextAssetChars, + } +} + +// NormalizeTextAssetPolicy 归一化文本 asset 策略并施加硬上限兜底。 +// 零值或负值回填默认值;超过硬上限的值会被截到硬上限。 +func NormalizeTextAssetPolicy(policy TextAssetPolicy) TextAssetPolicy { + normalized := policy + if normalized.Whitelist.IsEmpty() { + normalized.Whitelist = DefaultTextAssetWhitelist() + } + if normalized.MaxTextAssetBytes <= 0 { + normalized.MaxTextAssetBytes = DefaultMaxTextAssetBytes + } + if normalized.MaxTextAssetBytes > MaxTextAssetBytesHardLimit { + normalized.MaxTextAssetBytes = MaxTextAssetBytesHardLimit + } + if normalized.MaxTextAssetChars <= 0 { + normalized.MaxTextAssetChars = DefaultMaxTextAssetChars + } + if normalized.MaxTextAssetChars > MaxTextAssetCharsHardLimit { + normalized.MaxTextAssetChars = MaxTextAssetCharsHardLimit + } + return normalized +} + +// SanitizeTextAssetFileName 把原始文件名清洗为"只用于上下文标签"的安全字符串。 +// 处理项:去路径分隔符、控制字符、引号、反引号;裁剪到 MaxTextAssetFileNameBytes;空值返回 fallback。 +// 注意:返回值**只作为文本内容里的展示标签**,禁止被当作文件系统路径或受信字符串使用。 +func SanitizeTextAssetFileName(raw string, fallback string) string { + cleaned := strings.TrimSpace(raw) + if cleaned == "" { + cleaned = strings.TrimSpace(fallback) + } + if cleaned == "" { + return "" + } + // 仅取 basename(按 '/' 或 '\\' 切分最后一段),避免目录穿越/绝对路径泄漏到标签。 + // 使用纯字符串处理,不依赖 filepath.Base —— 后者在不同 OS 上对控制字符和末尾分隔符行为不一致。 + if idx := strings.LastIndexAny(cleaned, "/\\"); idx >= 0 { + cleaned = cleaned[idx+1:] + } + // 替换路径分隔符与控制字符为下划线(控制字符单独走一遍确保万无一失)。 + var b strings.Builder + b.Grow(len(cleaned)) + for _, r := range cleaned { + switch { + case r == '/' || r == '\\': + b.WriteByte('_') + case r < 0x20 || r == 0x7f: + b.WriteByte('_') + case r == '"' || r == '\'' || r == '`': + b.WriteByte('_') + default: 
+ b.WriteRune(r) + } + } + out := b.String() + if out == "" || out == "." || out == ".." { + return "" + } + if len(out) > MaxTextAssetFileNameBytes { + // 按 rune 逐步缩减到字节上限内,保留完整 UTF-8 字符边界, + // 避免在多字节字符(如中文)中间切断产生非法 UTF-8。 + runes := []rune(out) + cut := len(runes) + for cut > 0 && len(string(runes[:cut])) > MaxTextAssetFileNameBytes { + cut-- + } + out = string(runes[:cut]) + } + return out +} + +// String 返回当前白名单包含的全部 (扩展名 → MIME) 对,用于错误信息或调试输出。 +func (w TextAssetWhitelist) String() string { + if w.IsEmpty() { + return "empty" + } + parts := make([]string, 0, len(w.extensionToMime)) + for ext, mime := range w.extensionToMime { + parts = append(parts, fmt.Sprintf("%s→%s", ext, mime)) + } + return strings.Join(parts, ",") +} diff --git a/internal/session/assets_policy_test.go b/internal/session/assets_policy_test.go index adfebde8..213cee10 100644 --- a/internal/session/assets_policy_test.go +++ b/internal/session/assets_policy_test.go @@ -1,6 +1,10 @@ package session -import "testing" +import ( + "strings" + "testing" + "unicode/utf8" +) func TestDefaultAssetPolicy(t *testing.T) { t.Parallel() @@ -45,3 +49,179 @@ func TestNormalizeAssetPolicy(t *testing.T) { }) } } + +func TestDefaultTextAssetWhitelist(t *testing.T) { + t.Parallel() + + whitelist := DefaultTextAssetWhitelist() + cases := []struct { + ext string + fileName string + mime string + want string + }{ + {ext: "txt", fileName: "notes.txt", want: "text/plain"}, + {ext: "md", fileName: "README.md", want: "text/markdown"}, + {ext: "json", fileName: "config.json", want: "application/json"}, + {ext: "yaml", fileName: "values.yaml", want: "text/yaml"}, + {ext: "yml", fileName: "values.yml", want: "application/x-yaml"}, + {ext: "csv", fileName: "data.csv", want: "text/csv"}, + } + for _, tc := range cases { + if got := whitelist.LookupByExtension(tc.fileName); got != tc.want { + t.Errorf("LookupByExtension(%q) = %q, want %q", tc.fileName, got, tc.want) + } + if !whitelist.LookupByMime(tc.want) { + 
t.Errorf("LookupByMime(%q) = false, want true", tc.want) + } + } + // 路径无关性:basename 提取正确。 + if got := whitelist.LookupByExtension("/tmp/sub/notes.txt"); got != "text/plain" { + t.Errorf("LookupByExtension with dir prefix = %q, want text/plain", got) + } + // 命中后白名单应非空。 + if whitelist.IsEmpty() { + t.Fatalf("default text whitelist should not be empty") + } + // 不在白名单的扩展名/MIME 应返回空/否。 + if got := whitelist.LookupByExtension("page.html"); got != "" { + t.Errorf("LookupByExtension(.html) = %q, want empty", got) + } + if whitelist.LookupByMime("text/html") { + t.Errorf("LookupByMime(text/html) = true, want false") + } +} + +func TestTextAssetWhitelistWithExtensions(t *testing.T) { + t.Parallel() + + whitelist := DefaultTextAssetWhitelist().WithExtensions(map[string]string{ + "log": "text/plain", + "tsv": "text/tab-separated-values", + }) + if got := whitelist.LookupByExtension("debug.log"); got != "text/plain" { + t.Errorf("LookupByExtension(.log) = %q, want text/plain", got) + } + if !whitelist.LookupByMime("text/tab-separated-values") { + t.Errorf("LookupByMime(text/tab-separated-values) = false, want true") + } + // 默认项保留。 + if got := whitelist.LookupByExtension("data.csv"); got != "text/csv" { + t.Errorf("LookupByExtension(.csv) = %q, want text/csv (default preserved)", got) + } +} + +func TestNormalizeTextAssetPolicy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in TextAssetPolicy + wantBytes int64 + wantChars int + wantWhitelistN int + }{ + { + name: "zero values use defaults", + in: TextAssetPolicy{}, + wantBytes: DefaultMaxTextAssetBytes, + wantChars: DefaultMaxTextAssetChars, + wantWhitelistN: len(DefaultTextAssetWhitelist().Extensions()), + }, + { + name: "byte cap hard limit", + in: TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: MaxTextAssetBytesHardLimit + 1024, + MaxTextAssetChars: DefaultMaxTextAssetChars, + }, + wantBytes: MaxTextAssetBytesHardLimit, + wantChars: DefaultMaxTextAssetChars, + 
wantWhitelistN: len(DefaultTextAssetWhitelist().Extensions()), + }, + { + name: "char cap hard limit", + in: TextAssetPolicy{ + Whitelist: DefaultTextAssetWhitelist(), + MaxTextAssetBytes: 1024, + MaxTextAssetChars: MaxTextAssetCharsHardLimit + 100, + }, + wantBytes: 1024, + wantChars: MaxTextAssetCharsHardLimit, + wantWhitelistN: len(DefaultTextAssetWhitelist().Extensions()), + }, + { + name: "empty whitelist falls back to default", + in: TextAssetPolicy{ + Whitelist: NewTextAssetWhitelist(nil), + MaxTextAssetBytes: 512, + MaxTextAssetChars: 100, + }, + wantBytes: 512, + wantChars: 100, + wantWhitelistN: len(DefaultTextAssetWhitelist().Extensions()), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeTextAssetPolicy(tt.in) + if got.MaxTextAssetBytes != tt.wantBytes { + t.Errorf("MaxTextAssetBytes = %d, want %d", got.MaxTextAssetBytes, tt.wantBytes) + } + if got.MaxTextAssetChars != tt.wantChars { + t.Errorf("MaxTextAssetChars = %d, want %d", got.MaxTextAssetChars, tt.wantChars) + } + if n := len(got.Whitelist.Extensions()); n != tt.wantWhitelistN { + t.Errorf("whitelist size = %d, want %d", n, tt.wantWhitelistN) + } + }) + } +} + +func TestSanitizeTextAssetFileName(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + raw string + fallback string + want string + }{ + {name: "plain name kept", raw: "notes.md", want: "notes.md"}, + {name: "path separator collapsed", raw: "../etc/passwd", want: "passwd"}, + {name: "backslash collapsed", raw: `..\..\evil.md`, want: "evil.md"}, + {name: "absolute path", raw: "/var/log/app.log", want: "app.log"}, + {name: "control chars replaced", raw: "name\x00\x01.md", want: "name__.md"}, + {name: "quotes replaced", raw: "evil\"name`.md", want: "evil_name_.md"}, + {name: "empty falls back", raw: "", fallback: "fallback.txt", want: "fallback.txt"}, + {name: "all invalid empty", raw: "////", want: ""}, + {name: "dot dot returns empty", raw: "..", want: ""}, + {name: "trim 
spaces", raw: " hello.txt ", want: "hello.txt"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := SanitizeTextAssetFileName(tc.raw, tc.fallback) + if got != tc.want { + t.Errorf("SanitizeTextAssetFileName(%q, %q) = %q, want %q", tc.raw, tc.fallback, got, tc.want) + } + }) + } + + // 超长输入按字节裁剪到上限。 + longName := strings.Repeat("a", MaxTextAssetFileNameBytes+50) + ".md" + got := SanitizeTextAssetFileName(longName, "") + if len(got) > MaxTextAssetFileNameBytes { + t.Errorf("sanitized name length %d exceeds limit %d", len(got), MaxTextAssetFileNameBytes) + } + + // 多字节 UTF-8(中文)长文件名截断后必须保持合法 UTF-8,不在字符中间切断。 + chineseName := strings.Repeat("设计方案报告", 20) + ".md" // 每个中文字符 3 字节,远超 128 字节上限 + chineseGot := SanitizeTextAssetFileName(chineseName, "") + if len(chineseGot) > MaxTextAssetFileNameBytes { + t.Errorf("chinese name length %d exceeds limit %d", len(chineseGot), MaxTextAssetFileNameBytes) + } + if !utf8.ValidString(chineseGot) { + t.Errorf("sanitized chinese name is not valid UTF-8: %q", chineseGot) + } +} diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index f5b9d9ba..39d7846e 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -70,6 +70,8 @@ func (e *AssetSaveError) Unwrap() error { type InputPreparer struct { store Store assetStore AssetStore + // textPolicy 控制文本类附件的解析与落盘策略;nil 时使用 DefaultTextAssetPolicy。 + textPolicy *TextAssetPolicy } type assetCleanupStore interface { @@ -92,6 +94,23 @@ func NewInputPreparer(store Store, assetStore AssetStore) *InputPreparer { } } +// SetTextAssetPolicy 注入文本类附件策略;nil 或未调用时使用 session.DefaultTextAssetPolicy。 +func (p *InputPreparer) SetTextAssetPolicy(policy TextAssetPolicy) { + if p == nil { + return + } + normalized := NormalizeTextAssetPolicy(policy) + p.textPolicy = &normalized +} + +// textAssetPolicySnapshot 返回当前生效的文本附件策略。 +func (p *InputPreparer) textAssetPolicySnapshot() TextAssetPolicy { + if p == nil || 
p.textPolicy == nil { + return DefaultTextAssetPolicy() + } + return NormalizeTextAssetPolicy(*p.textPolicy) +} + // Prepare 负责会话解析/创建、附件落盘与 parts 组装。 func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (PreparedInput, error) { if err := ctx.Err(); err != nil { @@ -130,6 +149,8 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar for index, image := range input.Images { path := strings.TrimSpace(image.Path) assetID := strings.TrimSpace(image.AssetID) + mimeType := strings.TrimSpace(image.MimeType) + isText := p.textAssetPolicySnapshot().Whitelist.LookupByMime(mimeType) if assetID != "" { if path != "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) @@ -138,10 +159,16 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar SessionID: session.ID, Index: index, Path: path, - Err: fmt.Errorf("image input cannot contain both path and asset id"), + Err: fmt.Errorf("input cannot contain both path and asset id"), } } - meta, err := p.referenceImageAsset(ctx, session.ID, assetID, image.MimeType) + var meta AssetMeta + var err error + if isText { + meta, err = p.referenceTextAsset(ctx, session.ID, assetID, mimeType) + } else { + meta, err = p.referenceImageAsset(ctx, session.ID, assetID, mimeType) + } if err != nil { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) p.cleanupSavedAssets(ctx, session.ID, savedAssets) @@ -165,9 +192,14 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar Err: fmt.Errorf("image path is empty"), } } - mimeType := strings.TrimSpace(image.MimeType) - meta, err := p.saveImageAsset(ctx, session.ID, session.Workdir, path, mimeType) + var meta AssetMeta + var err error + if isText { + meta, err = p.saveTextAsset(ctx, session.ID, session.Workdir, path, mimeType) + } else { + meta, err = p.saveImageAsset(ctx, session.ID, session.Workdir, path, mimeType) + } if err != nil { p.rollbackCreatedSession(ctx, session.ID, 
sessionCreated) p.cleanupSavedAssets(ctx, session.ID, savedAssets) @@ -247,6 +279,89 @@ func (p *InputPreparer) saveImageAsset( return meta, nil } +// saveTextAsset 按会话工作目录解析文本附件路径后落盘到 session asset 存储。 +// 与 saveImageAsset 的差异:不做图像头嗅探,依赖调用方传入的 mime 是否在文本白名单内做兜底。 +func (p *InputPreparer) saveTextAsset( + ctx context.Context, + sessionID string, + workdir string, + path string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + absolutePath, err := resolveImagePath(workdir, path) + if err != nil { + return AssetMeta{}, err + } + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + file, err := os.Open(absolutePath) + if err != nil { + return AssetMeta{}, fmt.Errorf("open text file: %w", err) + } + defer func() { + _ = file.Close() + }() + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + + // 文本类附件使用声明的 mime 落盘,不做嗅探(避免误判非图片的二进制为 image/*)。 + // 声明 mime 为空时按扩展名兜底,但仍要求命中文本白名单。 + resolvedMime := normalizeMimeType(mimeType) + if resolvedMime == "" { + resolvedMime = normalizeMimeType(mime.TypeByExtension(strings.ToLower(filepath.Ext(path)))) + } + policy := p.textAssetPolicySnapshot() + if !policy.Whitelist.LookupByMime(resolvedMime) { + return AssetMeta{}, fmt.Errorf("declared mime type %q is not in text asset whitelist", mimeType) + } + + meta, err := p.assetStore.SaveAsset(ctx, sessionID, file, resolvedMime) + if err != nil { + return AssetMeta{}, err + } + return meta, nil +} + +// referenceTextAsset 校验已保存附件属于当前会话,并返回可进入 runtime 的文本元数据。 +// 校验 mime 是否在文本白名单内,绕过 referenceImageAsset 的 image/* 限制。 +func (p *InputPreparer) referenceTextAsset( + ctx context.Context, + sessionID string, + assetID string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + if p.assetStore == nil { + return AssetMeta{}, fmt.Errorf("session: asset store is not configured") + } + normalizedAssetID := strings.TrimSpace(assetID) + if 
normalizedAssetID == "" { + return AssetMeta{}, fmt.Errorf("text asset id is empty") + } + + meta, err := p.assetStore.Stat(ctx, sessionID, normalizedAssetID) + if err != nil { + return AssetMeta{}, fmt.Errorf("stat text asset: %w", err) + } + policy := p.textAssetPolicySnapshot() + if !policy.Whitelist.LookupByMime(meta.MimeType) { + return AssetMeta{}, fmt.Errorf("asset %q has mime %q, not in text asset whitelist", normalizedAssetID, meta.MimeType) + } + declaredMime := normalizeMimeType(mimeType) + if declaredMime != "" && declaredMime != meta.MimeType { + return AssetMeta{}, fmt.Errorf("declared mime type %q mismatches saved asset %q", declaredMime, meta.MimeType) + } + return meta, nil +} + // referenceImageAsset 校验已保存附件属于当前会话,并返回可进入 provider 的图片元数据。 func (p *InputPreparer) referenceImageAsset( ctx context.Context, diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index 356449cc..1a677ad3 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -490,22 +490,40 @@ func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { } }) - t.Run("declared non image mime is rejected", func(t *testing.T) { + t.Run("declared non whitelisted mime is rejected", func(t *testing.T) { + // 非图片且非文本白名单的 mime(如 application/octet-stream)应被拒。 imagePath := filepath.Join(workdir, "declared-text.png") if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { t.Fatalf("write image: %v", err) } _, err := preparer.Prepare(context.Background(), PrepareInput{ - Text: "declared text", - Images: []PrepareImageInput{{Path: imagePath, MimeType: "text/plain"}}, + Text: "declared unknown", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "application/octet-stream"}}, DefaultWorkdir: workdir, }) if err == nil { - t.Fatalf("expected non-image mime error") + t.Fatalf("expected unsupported mime error") } if !strings.Contains(err.Error(), "is not an image") { - 
t.Fatalf("expected non-image mime error, got %v", err) + t.Fatalf("expected unsupported mime error, got %v", err) + } + }) + + t.Run("declared text mime is accepted via text path", func(t *testing.T) { + // 文本白名单内的 mime 应走文本附件路径,成功保存。 + textPath := filepath.Join(workdir, "notes.md") + if err := os.WriteFile(textPath, []byte("# title"), 0o644); err != nil { + t.Fatalf("write text: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "with text asset", + Images: []PrepareImageInput{{Path: textPath, MimeType: "text/markdown"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("expected text asset to be accepted, got %v", err) } }) diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index d214b39c..bc4bfa13 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -65,6 +65,7 @@ type SQLiteStore struct { db *sql.DB limitsMu sync.RWMutex assetPolicy AssetPolicy + textPolicy TextAssetPolicy } // SetAssetPolicy 设置会话附件大小限制;非法值会回退到默认并应用硬上限兜底。 @@ -77,6 +78,17 @@ func (s *SQLiteStore) SetAssetPolicy(policy AssetPolicy) { s.limitsMu.Unlock() } +// SetTextAssetPolicy 设置会话文本类附件的策略(白名单与大小/字符上限)。 +// 空值或非法值会回退到默认并应用硬上限兜底。 +func (s *SQLiteStore) SetTextAssetPolicy(policy TextAssetPolicy) { + if s == nil { + return + } + s.limitsMu.Lock() + s.textPolicy = NormalizeTextAssetPolicy(policy) + s.limitsMu.Unlock() +} + // assetPolicySnapshot 返回当前生效的会话附件限制。 func (s *SQLiteStore) assetPolicySnapshot() AssetPolicy { if s == nil { @@ -88,6 +100,17 @@ func (s *SQLiteStore) assetPolicySnapshot() AssetPolicy { return NormalizeAssetPolicy(policy) } +// textAssetPolicySnapshot 返回当前生效的会话文本附件策略。 +func (s *SQLiteStore) textAssetPolicySnapshot() TextAssetPolicy { + if s == nil { + return DefaultTextAssetPolicy() + } + s.limitsMu.RLock() + policy := s.textPolicy + s.limitsMu.RUnlock() + return NormalizeTextAssetPolicy(policy) +} + // Close 释放数据库连接,供测试和上层生命周期管理复用。 func (s *SQLiteStore) Close() 
error { if s == nil || s.db == nil { @@ -648,16 +671,22 @@ func (s *SQLiteStore) SaveAsset(ctx context.Context, sessionID string, r io.Read } policy := s.assetPolicySnapshot() - written, copyErr := io.Copy(tempFile, io.LimitReader(r, policy.MaxSessionAssetBytes+1)) + textPolicy := s.textAssetPolicySnapshot() + // 文本 asset 走更严格的字节上限(避免 20 MiB 文本在单轮会话中击穿上下文预算)。 + limit := policy.MaxSessionAssetBytes + if textPolicy.Whitelist.LookupByMime(meta.MimeType) { + limit = textPolicy.MaxTextAssetBytes + } + written, copyErr := io.Copy(tempFile, io.LimitReader(r, limit+1)) syncErr := tempFile.Sync() closeErr := tempFile.Close() if copyErr != nil { _ = os.Remove(tempPath) return AssetMeta{}, fmt.Errorf("session: write temp asset: %w", copyErr) } - if written > policy.MaxSessionAssetBytes { + if written > limit { _ = os.Remove(tempPath) - return AssetMeta{}, fmt.Errorf("session: asset size exceeds %d bytes", policy.MaxSessionAssetBytes) + return AssetMeta{}, fmt.Errorf("session: asset size exceeds %d bytes", limit) } if syncErr != nil { _ = os.Remove(tempPath) diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index fcbef263..0ba11f78 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -179,7 +179,7 @@ describe('ChatInput', () => { it('renders the image attachment picker but keeps mention button absent', () => { render() - expect(screen.getByRole('button', { name: /添加图片/ })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /添加附件/ })).toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) @@ -223,7 +223,7 @@ describe('ChatInput', () => { }) }) - it('blocks image selection when the selected model explicitly rejects images', async () => { + it('blocks image selection when the selected model rejects images but still allows text', async () => { mockGatewayAPI.listModels.mockResolvedValueOnce({ payload: { models: [{ @@ -238,16 +238,24 @@ 
describe('ChatInput', () => { }) render() + // 等待模型能力加载完成:按钮 title 切换为不支持图片的提示,作为同步信号。 await waitFor(() => { - expect(screen.getByRole('button', { name: /添加图片/ })).toBeDisabled() + expect(screen.getByRole('button', { name: /添加附件/ })).toHaveAttribute('title', '当前模型不支持图片,可添加文本文件') }) - const file = new File(['img'], 'a.png', { type: 'image/png' }) - const input = document.querySelector('input[type="file"]') as HTMLInputElement - fireEvent.change(input, { target: { files: [file] } }) + // 图片附件仍受模型视觉能力限制,被拒绝。 + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [new File(['img'], 'a.png', { type: 'image/png' })] } }) await waitFor(() => { expect(useComposerStore.getState().attachments).toHaveLength(0) }) + + // 文本附件不依赖模型视觉能力,仍可添加。 + fireEvent.change(input, { target: { files: [new File(['# hi'], 'note.md', { type: 'text/markdown' })] } }) + await waitFor(() => { + expect(useComposerStore.getState().attachments).toHaveLength(1) + expect(useComposerStore.getState().attachments[0].kind).toBe('text') + }) }) it('blocks sending existing image attachments when the selected model rejects images', async () => { @@ -266,8 +274,9 @@ describe('ChatInput', () => { render() await waitFor(() => { - expect(screen.getByRole('button', { name: /添加图片/ })).toBeDisabled() + expect(screen.getByRole('button', { name: /添加附件/ })).toHaveAttribute('title', '当前模型不支持图片,可添加文本文件') }) + // 直接塞入图片附件(绕过选择校验),验证发送阶段对图片附件的阻断。 useComposerStore.getState().addAttachmentFiles([new File(['img'], 'a.png', { type: 'image/png' })]) fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) @@ -278,6 +287,58 @@ describe('ChatInput', () => { }) }) + it('uploads selected text file and sends it as an image part with the real text mime', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ + session_id: 'session-1', + asset_id: 'asset-text-1', + 
mime_type: 'text/markdown', + size: 5, + }) + render() + + const file = new File(['# hi'], 'note.md', { type: 'text/markdown' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + // 文本附件以 chip 形式预览,展示文件名。 + await waitFor(() => { + expect(screen.getByText('note.md')).toBeInTheDocument() + }) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-1', file, 'workspace-b') + // 文本附件复用 type:'image' + 真实文本 mime,由 runtime 按白名单内联为 text part。 + expect(mockGatewayAPI.run).toHaveBeenCalledWith({ + session_id: 'session-1', + input_parts: [ + { type: 'image', media: { asset_id: 'asset-text-1', mime_type: 'text/markdown', file_name: 'note.md' } }, + ], + mode: 'build', + }) + }) + + expect(useChatStore.getState().messages[0]).toMatchObject({ + role: 'user', + attachments: [{ assetId: 'asset-text-1', kind: 'text', name: 'note.md' }], + }) + }) + + it('rejects unsupported file types with a clear error', async () => { + render() + + const file = new File(['%PDF-1.4'], 'doc.pdf', { type: 'application/pdf' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + await waitFor(() => { + // 未知类型应被拒绝,不进入附件列表。 + expect(useComposerStore.getState().attachments).toHaveLength(0) + }) + }) + it('deletes uploaded session assets when run fails', async () => { useSessionStore.setState({ currentSessionId: 'session-1' } as never) mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 568de781..d98065a5 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -5,13 +5,17 @@ import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' import { useUIStore } 
from '@/stores/useUIStore' import { acceptedImageMimeTypes, + acceptedTextMimeTypes, + acceptedTextExtensions, maxComposerAttachmentBytes, + maxTextAttachmentBytes, + resolveAttachmentKind, useComposerStore, type ComposerAttachment, } from '@/stores/useComposerStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { useWorkspaceStore } from '@/stores/useWorkspaceStore' -import { formatTokenCount } from '@/utils/format' +import { formatTokenCount, formatBytes } from '@/utils/format' import { useGatewayAPI } from '@/context/RuntimeProvider' import { type ModelEntry } from '@/api/protocol' import { @@ -26,7 +30,7 @@ import { import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' -import { ImagePlus, Loader2, Send, Square, X } from 'lucide-react' +import { FileText, ImagePlus, Loader2, Send, Square, X } from 'lucide-react' const slashMenuAnchorStyle: React.CSSProperties = { position: 'absolute', @@ -370,7 +374,8 @@ export default function ChatInput() { if (handled) return } - if (pendingAttachments.length > 0 && currentImageInput === 'unsupported') { + // 仅当存在图片附件且模型不支持图片时才阻断;文本附件不依赖视觉能力。 + if (pendingAttachments.some((a) => a.kind === 'image') && currentImageInput === 'unsupported') { useUIStore.getState().showToast(unsupportedImageInputMessage, 'error') return } @@ -409,6 +414,7 @@ export default function ChatInput() { name: attachment.file.name, size: meta.size, previewUrl: attachment.previewUrl, + kind: attachment.kind, }))) setText('') @@ -514,24 +520,45 @@ export default function ChatInput() { } function handleFilesSelected(files: FileList | File[]) { - if (currentImageInput === 'unsupported') { - useUIStore.getState().showToast(unsupportedImageInputMessage, 'error') - return - } + // 文本附件不依赖模型视觉能力,任何模型都可接收;仅图片附件受 currentImageInput 约束。 + const imageUnsupported = currentImageInput === 'unsupported' const accepted: File[] = [] for (const file of 
Array.from(files)) { - if (!acceptedImageMimeTypes.includes(file.type as any)) { - useUIStore.getState().showToast('Only PNG, JPEG, and WebP images are supported', 'error') - continue - } if (file.size <= 0) { useUIStore.getState().showToast('Cannot upload an empty file', 'error') continue } - if (file.size > maxComposerAttachmentBytes) { - useUIStore.getState().showToast('Image exceeds the 20 MiB limit', 'error') + const kind = resolveAttachmentKind(file) + if (kind === 'unknown') { + useUIStore.getState().showToast('Unsupported file type', 'error') continue } + if (kind === 'image') { + if (imageUnsupported) { + useUIStore.getState().showToast(unsupportedImageInputMessage, 'error') + continue + } + if (!acceptedImageMimeTypes.includes(file.type as any)) { + useUIStore.getState().showToast('Only PNG, JPEG, and WebP images are supported', 'error') + continue + } + if (file.size > maxComposerAttachmentBytes) { + useUIStore.getState().showToast('Image exceeds the 20 MiB limit', 'error') + continue + } + } else { + // 文本附件:按 MIME 或扩展名校验白名单,并施加 256 KiB 字节上限。 + const isTextMime = (acceptedTextMimeTypes as readonly string[]).includes((file.type || '').toLowerCase()) + const isTextExt = acceptedTextExtensions.some((ext) => file.name.toLowerCase().endsWith(ext)) + if (!isTextMime && !isTextExt) { + useUIStore.getState().showToast('Unsupported file type', 'error') + continue + } + if (file.size > maxTextAttachmentBytes) { + useUIStore.getState().showToast('Text file exceeds the 256 KiB limit', 'error') + continue + } + } accepted.push(file) } if (accepted.length > 0) addAttachmentFiles(accepted) @@ -608,7 +635,7 @@ export default function ChatInput() { { @@ -618,15 +645,11 @@ export default function ChatInput() { /> - {attachment.error &&
{attachment.error}
} - - ))} + {isText ? ( +
+ + + {attachment.file.name} + + {formatBytes(attachment.file.size)} +
+ ) : attachment.previewUrl ? ( + {attachment.file.name} + ) : ( +
image
+ )} + {attachment.status === 'uploading' && ( +
+ )} + + {attachment.error &&
{attachment.error}
} + + ) + })} ) } @@ -792,6 +831,38 @@ const attachmentPreviewStyles: Record = { background: 'var(--bg-secondary)', overflow: 'hidden', }, + // 文本附件 chip:自适应宽度,展示文件名+大小,与图片缩略图区分。 + textItem: { + position: 'relative', + maxWidth: 260, + height: 32, + borderRadius: 'var(--radius-md)', + border: '1px solid var(--border-primary)', + background: 'var(--bg-secondary)', + overflow: 'hidden', + display: 'flex', + alignItems: 'center', + }, + textChip: { + display: 'flex', + alignItems: 'center', + gap: 6, + padding: '0 28px 0 8px', + width: '100%', + height: '100%', + color: 'var(--text-primary)', + fontSize: 12, + }, + textName: { + overflow: 'hidden', + textOverflow: 'ellipsis', + whiteSpace: 'nowrap', + }, + textSize: { + flexShrink: 0, + color: 'var(--text-tertiary)', + fontSize: 11, + }, image: { width: '100%', height: '100%', diff --git a/web/src/components/chat/MessageItem.test.tsx b/web/src/components/chat/MessageItem.test.tsx index d9d5ef90..cdf11743 100644 --- a/web/src/components/chat/MessageItem.test.tsx +++ b/web/src/components/chat/MessageItem.test.tsx @@ -242,4 +242,69 @@ describe("MessageItem", () => { }); createObjectURL.mockRestore(); }); + + it("renders user text attachments as chips with name and size", () => { + render( + , + ); + + expect(screen.getByText("read this")).toBeInTheDocument(); + // 文本附件以 chip 展示文件名和大小。 + expect(screen.getByText("notes.md")).toBeInTheDocument(); + expect(screen.getByText("128 B")).toBeInTheDocument(); + // 文本附件不应渲染为 。 + expect(screen.queryByAltText("notes.md")).not.toBeInTheDocument(); + // 文本附件不应触发 session asset 回载。 + expect(mockFetchSessionAsset).not.toHaveBeenCalled(); + }); + + it("infers text kind from mimeType when kind is missing for backward compat", () => { + render( + , + ); + + expect(screen.getByText("data.json")).toBeInTheDocument(); + expect(screen.getByText("2.0 KB")).toBeInTheDocument(); + expect(screen.queryByAltText("data.json")).not.toBeInTheDocument(); + }); }); diff --git 
a/web/src/components/chat/MessageItem.tsx b/web/src/components/chat/MessageItem.tsx index 8b3564fe..619ba759 100644 --- a/web/src/components/chat/MessageItem.tsx +++ b/web/src/components/chat/MessageItem.tsx @@ -1,12 +1,13 @@ import { memo, useEffect, useState } from "react"; -import { type ChatMessage } from "@/stores/useChatStore"; +import { type ChatMessage, type ChatAttachment } from "@/stores/useChatStore"; import { type PlanArtifact } from "@/api/protocol"; import { useGatewayAPI } from "@/context/RuntimeProvider"; +import { formatBytes } from "@/utils/format"; import ToolCallCard from "./ToolCallCard"; import AcceptanceMessage from "./AcceptanceMessage"; import CodeBlock from "./CodeBlock"; import MarkdownContent from "./MarkdownContent"; -import { Bot, ChevronRight, ClipboardList, Info, Loader2 } from "lucide-react"; +import { Bot, ChevronRight, ClipboardList, FileText, Info, Loader2 } from "lucide-react"; interface MessageItemProps { message: ChatMessage; @@ -91,6 +92,14 @@ function UserMessage({ message }: { message: ChatMessage }) { ); } +// formatBytes 已提取到 @/utils/format,composer 与消息历史共用同一实现。 + +// isImageAttachment 判定附件是否为图片类:优先使用显式 kind,缺省时按 mimeType 前缀推断(向后兼容)。 +function isImageAttachment(attachment: ChatAttachment): boolean { + if (attachment.kind) return attachment.kind === "image"; + return (attachment.mimeType || "").toLowerCase().startsWith("image/"); +} + function UserAttachments({ message }: { message: ChatMessage }) { const gatewayAPI = useGatewayAPI(); const [loadedURLs, setLoadedURLs] = useState>({}); @@ -101,6 +110,8 @@ function UserAttachments({ message }: { message: ChatMessage }) { let cancelled = false; const created: string[] = []; attachments.forEach((attachment) => { + // 文本附件 chip 只展示文件名+大小,无需回载 blob;仅图片附件需要 fetch。 + if (!isImageAttachment(attachment)) return; if (attachment.previewUrl || !attachment.sessionId || !attachment.assetId) return; gatewayAPI.fetchSessionAsset(attachment.sessionId, attachment.assetId, 
attachment.workspaceHash) .then((blob) => { @@ -121,6 +132,18 @@ function UserAttachments({ message }: { message: ChatMessage }) { return (
{attachments.map((attachment) => { + // 文本附件以 chip 形式展示文件名和大小,与图片缩略图区分。 + if (!isImageAttachment(attachment)) { + return ( +
+ + + {attachment.name || "text file"} + + {formatBytes(attachment.size)} +
+ ); + } const src = attachment.previewUrl || loadedURLs[attachment.id] || ""; return (
@@ -426,6 +449,29 @@ const styles: Record = { fontSize: 11, fontFamily: "var(--font-mono)", }, + // 文本附件 chip:展示文件名+大小,与图片缩略图视觉区分。 + userTextChip: { + display: "flex", + alignItems: "center", + gap: 6, + padding: "6px 10px", + borderRadius: "var(--radius-md)", + border: "1px solid var(--border-primary)", + background: "var(--bg-secondary)", + color: "var(--text-primary)", + fontSize: 12, + overflow: "hidden", + }, + userTextChipName: { + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + }, + userTextChipSize: { + flexShrink: 0, + color: "var(--text-tertiary)", + fontSize: 11, + }, aiRow: { display: "flex", gap: 10, diff --git a/web/src/stores/useChatStore.test.ts b/web/src/stores/useChatStore.test.ts index 1e6f5cae..92b44335 100644 --- a/web/src/stores/useChatStore.test.ts +++ b/web/src/stores/useChatStore.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' -import { useChatStore } from './useChatStore' +import { useChatStore, createUserMessage } from './useChatStore' beforeEach(() => { if (typeof URL.revokeObjectURL !== 'function') { @@ -276,4 +276,20 @@ describe('useChatStore', () => { expect(useChatStore.getState().permissionMode).toBe('default') expect(useChatStore.getState().isCompacting).toBe(false) }) + + it('createUserMessage preserves attachment kind for text and image files', () => { + const msg = createUserMessage('hi', [ + { id: 'a1', mimeType: 'image/png', kind: 'image', name: 'a.png' }, + { id: 'a2', mimeType: 'text/markdown', kind: 'text', name: 'b.md' }, + ]) + expect(msg.attachments).toEqual([ + { id: 'a1', mimeType: 'image/png', kind: 'image', name: 'a.png' }, + { id: 'a2', mimeType: 'text/markdown', kind: 'text', name: 'b.md' }, + ]) + }) + + it('createUserMessage omits attachments when none provided', () => { + const msg = createUserMessage('hi') + expect(msg.attachments).toBeUndefined() + }) }) diff --git a/web/src/stores/useChatStore.ts b/web/src/stores/useChatStore.ts index 
c2aa652a..485d436c 100644 --- a/web/src/stores/useChatStore.ts +++ b/web/src/stores/useChatStore.ts @@ -7,6 +7,7 @@ import { type PlanArtifact, } from "@/api/protocol"; import { resetEventBridgeCursors } from "@/utils/eventBridge"; +import type { AttachmentKind } from "@/stores/useComposerStore"; export interface ChatAttachment { id: string; @@ -17,6 +18,8 @@ export interface ChatAttachment { name?: string; size?: number; previewUrl?: string; + // 附件类别,用于消息展示区分图片缩略图与文本 chip;缺省时按 mimeType 前缀推断。 + kind?: AttachmentKind; } /** 聊天消息 */ diff --git a/web/src/stores/useComposerStore.test.ts b/web/src/stores/useComposerStore.test.ts index cf7dfcda..d67f4fdc 100644 --- a/web/src/stores/useComposerStore.test.ts +++ b/web/src/stores/useComposerStore.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' -import { useComposerStore } from './useComposerStore' +import { useComposerStore, resolveAttachmentKind } from './useComposerStore' beforeEach(() => { vi.restoreAllMocks() @@ -78,4 +78,75 @@ describe('useComposerStore', () => { error: 'too large', }) }) + + it('adds text attachments with kind text and no preview URL', () => { + const createObjectURL = vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:should-not-be-called') + const file = new File(['# title'], 'notes.md', { type: 'text/markdown' }) + + useComposerStore.getState().addAttachmentFiles([file]) + + const [attachment] = useComposerStore.getState().attachments + // 文本附件不应创建 blob URL,避免无意义内存占用。 + expect(createObjectURL).not.toHaveBeenCalled() + expect(attachment).toMatchObject({ + file, + previewUrl: '', + status: 'pending', + kind: 'text', + }) + }) + + it('classifies text files by extension when browser omits MIME', () => { + // 浏览器对 .csv/.md 等扩展名常返回空 MIME,应按扩展名兜底判定为 text。 + const file = new File(['a,b\n1,2'], 'data.csv', { type: '' }) + + useComposerStore.getState().addAttachmentFiles([file]) + + const [attachment] = useComposerStore.getState().attachments + 
expect(attachment.kind).toBe('text') + expect(attachment.previewUrl).toBe('') + }) + + it('keeps image attachments as kind image with preview URL', () => { + vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:preview-1') + useComposerStore.getState().addAttachmentFiles([new File(['img'], 'a.png', { type: 'image/png' })]) + + const [attachment] = useComposerStore.getState().attachments + expect(attachment.kind).toBe('image') + expect(attachment.previewUrl).toBe('blob:preview-1') + }) + + it('does not call revokeObjectURL when removing a text attachment', () => { + const revokeObjectURL = vi.spyOn(URL, 'revokeObjectURL').mockImplementation(() => {}) + useComposerStore.getState().addAttachmentFiles([new File(['x'], 'a.md', { type: 'text/markdown' })]) + const attachmentId = useComposerStore.getState().attachments[0].id + + useComposerStore.getState().removeAttachment(attachmentId) + + expect(useComposerStore.getState().attachments).toEqual([]) + // 文本附件 previewUrl 为空,revokePreviewURL 会直接跳过。 + expect(revokeObjectURL).not.toHaveBeenCalled() + }) +}) + +describe('resolveAttachmentKind', () => { + it('classifies image MIME as image', () => { + expect(resolveAttachmentKind(new File([], 'a.png', { type: 'image/png' }))).toBe('image') + }) + + it('classifies text MIME as text', () => { + expect(resolveAttachmentKind(new File([], 'a.md', { type: 'text/markdown' }))).toBe('text') + }) + + it('classifies by extension when MIME is empty', () => { + expect(resolveAttachmentKind(new File([], 'data.csv', { type: '' }))).toBe('text') + }) + + it('returns unknown for unsupported types like PDF', () => { + expect(resolveAttachmentKind(new File([], 'doc.pdf', { type: 'application/pdf' }))).toBe('unknown') + }) + + it('returns unknown for binary with no recognized extension', () => { + expect(resolveAttachmentKind(new File([], 'archive.zip', { type: 'application/zip' }))).toBe('unknown') + }) }) diff --git a/web/src/stores/useComposerStore.ts b/web/src/stores/useComposerStore.ts 
index 6109f0a5..1949913d 100644 --- a/web/src/stores/useComposerStore.ts +++ b/web/src/stores/useComposerStore.ts @@ -1,14 +1,34 @@ import { create } from 'zustand' export const acceptedImageMimeTypes = ['image/png', 'image/jpeg', 'image/webp'] as const +// 文本附件白名单,与后端 session.DefaultTextAssetWhitelist 保持一致(txt/md/json/yaml/yml/csv)。 +export const acceptedTextMimeTypes = [ + 'text/plain', + 'text/markdown', + 'application/json', + 'text/yaml', + 'application/x-yaml', + 'text/csv', +] as const +// 浏览器对部分文本扩展名(如 .md/.csv)可能不返回 MIME,需要按扩展名兜底校验。 +export const acceptedTextExtensions = ['.txt', '.md', '.json', '.yaml', '.yml', '.csv'] as const export const maxComposerAttachmentBytes = 20 * 1024 * 1024 +// 文本附件字节上限,与后端 session.DefaultMaxTextAssetBytes(256 KiB)对齐,避免无效上传。 +export const maxTextAttachmentBytes = 256 * 1024 + +export type AttachmentKind = 'image' | 'text' +// ResolvedAttachmentKind 用于文件校验阶段的三路判定,unknown 表示既非图片也非文本。 +export type ResolvedAttachmentKind = AttachmentKind | 'unknown' export interface ComposerAttachment { id: string file: File + // 图片附件用 blob URL 做缩略图预览;文本附件不创建 blob URL,留空字符串。 previewUrl: string status: 'pending' | 'uploading' | 'uploaded' | 'error' error?: string + // 附件类别,由 createComposerAttachment 按 MIME/扩展名判定,供预览与展示分流使用。 + kind: AttachmentKind } interface ComposerState { @@ -48,14 +68,31 @@ export const useComposerStore = create((set) => ({ })) function createComposerAttachment(file: File): ComposerAttachment { + const resolved = resolveAttachmentKind(file) + // unknown 类型在 handleFilesSelected 校验阶段已被拒绝,此处仅做类型收窄(不可达分支)。 + const kind: AttachmentKind = resolved === 'text' ? 'text' : 'image' return { id: `att_${Date.now()}_${Math.random().toString(36).slice(2)}`, file, - previewUrl: createPreviewURL(file), + // 文本附件不创建 blob URL:避免无意义内存占用,文本预览由 chip(文件名+大小)承载。 + previewUrl: kind === 'image' ? 
createPreviewURL(file) : '', status: 'pending', + kind, } } +// resolveAttachmentKind 按浏览器声明的 MIME 判定附件类别;MIME 缺失时按扩展名兜底。 +// 图片 MIME 优先判定为 image,文本白名单内的扩展名/MIME 判定为 text,其余返回 unknown。 +export function resolveAttachmentKind(file: File): ResolvedAttachmentKind { + const mime = (file.type || '').toLowerCase() + if (mime.startsWith('image/')) return 'image' + if ((acceptedTextMimeTypes as readonly string[]).includes(mime)) return 'text' + // 浏览器对 .md/.csv 等扩展名常返回空 MIME 或 application/octet-stream,按扩展名兜底。 + const name = (file.name || '').toLowerCase() + if (acceptedTextExtensions.some((ext) => name.endsWith(ext))) return 'text' + return 'unknown' +} + function createPreviewURL(file: File) { if (typeof URL !== 'undefined' && typeof URL.createObjectURL === 'function') { return URL.createObjectURL(file) diff --git a/web/src/utils/format.test.ts b/web/src/utils/format.test.ts index eac2a438..aaa690ee 100644 --- a/web/src/utils/format.test.ts +++ b/web/src/utils/format.test.ts @@ -1,5 +1,5 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import { formatSessionTime, parseDateTime, relativeTime } from './format' +import { formatBytes, formatSessionTime, parseDateTime, relativeTime } from './format' describe('relativeTime', () => { beforeEach(() => { @@ -57,3 +57,27 @@ describe('formatSessionTime', () => { expect(formatSessionTime('not-a-time')).toBe('not-a-time') }) }) + +describe('formatBytes', () => { + it('returns 0 B for zero or undefined', () => { + expect(formatBytes(0)).toBe('0 B') + expect(formatBytes(undefined)).toBe('0 B') + expect(formatBytes(-1)).toBe('0 B') + }) + + it('formats bytes below 1024 as B', () => { + expect(formatBytes(1)).toBe('1 B') + expect(formatBytes(512)).toBe('512 B') + expect(formatBytes(1023)).toBe('1023 B') + }) + + it('formats bytes between 1 KiB and 1 MiB as KB', () => { + expect(formatBytes(1024)).toBe('1.0 KB') + expect(formatBytes(2048)).toBe('2.0 KB') + }) + + it('formats bytes above 1 MiB as MB', () => { + 
expect(formatBytes(1024 * 1024)).toBe('1.0 MB') + expect(formatBytes(256 * 1024)).toBe('256.0 KB') + }) +}) diff --git a/web/src/utils/format.ts b/web/src/utils/format.ts index e5eb6204..3c46e6b0 100644 --- a/web/src/utils/format.ts +++ b/web/src/utils/format.ts @@ -19,6 +19,14 @@ export function formatTokenCount(n: number): string { return String(n) } +/** 格式化字节数为人类可读字符串,用于附件大小展示 */ +export function formatBytes(bytes?: number): string { + if (!bytes || bytes <= 0) return '0 B' + if (bytes < 1024) return `${bytes} B` + if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB` + return `${(bytes / (1024 * 1024)).toFixed(1)} MB` +} + /** 截断文本 */ export function truncate(text: string, maxLen: number): string { if (text.length <= maxLen) return text