From 6feef44c10ccd9d469577d5c2fb546697fc64c97 Mon Sep 17 00:00:00 2001 From: Boii Date: Wed, 17 Jun 2026 21:16:13 +0800 Subject: [PATCH] fix(conversation): upsert streaming tool calls - Persist ACP and aionrs tool-call events through an atomic message upsert - Merge JSON content on duplicate message IDs while preserving terminal status - Cover out-of-order ACP and aionrs tool-call event persistence --- .../src/stream_persistence.rs | 191 ++++-------------- .../aionui-conversation/src/stream_relay.rs | 93 +++++++++ .../tests/acp_tool_call_persistence.rs | 88 ++++++++ .../tests/stream_relay_tool_call.rs | 56 +++++ .../aionui-db/src/repository/conversation.rs | 19 ++ .../src/repository/sqlite_conversation.rs | 46 +++++ 6 files changed, 340 insertions(+), 153 deletions(-) diff --git a/crates/aionui-conversation/src/stream_persistence.rs b/crates/aionui-conversation/src/stream_persistence.rs index 738f706f6..7bcad7a9b 100644 --- a/crates/aionui-conversation/src/stream_persistence.rs +++ b/crates/aionui-conversation/src/stream_persistence.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use aionui_ai_agent::protocol::events::{ ErrorEventData, TipType, TipsEventData, - tool_call::{AcpToolCallSessionUpdateKind, AcpToolCallStatus, ToolCallStatus}, + tool_call::{AcpToolCallStatus, ToolCallStatus}, }; use aionui_api_types::{ConversationRuntimeSummary, WebSocketMessage}; use aionui_common::{ErrorChain, normalize_keys_to_snake_case, now_ms}; @@ -400,63 +400,32 @@ impl StreamPersistenceAdapter { }; let content = serde_json::to_string(data).unwrap_or_default(); - let existing = self - .repo - .get_message_by_msg_id(&self.conversation_id, &data.call_id, "tool_call") - .await - .unwrap_or(None); - - if let Some(existing_row) = existing { - let merged_content = Self::merge_json_content(&existing_row.content, &content); - let update = MessageRowUpdate { - content: Some(merged_content), - status: Some(Some(status.to_owned())), - hidden: None, - }; - if let Err(e) = self.repo.update_message(&data.call_id, &update).await { - error!( - call_id = %data.call_id, - tool = %data.name, - status, - error = %ErrorChain(&e), - "Failed to update tool_call message" - ); - } else { - debug!( - call_id = %data.call_id, - tool = %data.name, - status, - "Updated tool_call message" - ); - } + let row = MessageRow { + id: data.call_id.clone(), + conversation_id: self.conversation_id.clone(), + msg_id: Some(data.call_id.clone()), + r#type: "tool_call".into(), + content, + position: Some("left".into()), + status: Some(status.to_owned()), + hidden: false, + created_at: now_ms(), + }; + if let Err(e) = self.repo.upsert_message(&row).await { + error!( + call_id = %data.call_id, + tool = %data.name, + status, + error = %ErrorChain(&e), + "Failed to upsert tool_call message" + ); } else { - let row = MessageRow { - id: data.call_id.clone(), - conversation_id: self.conversation_id.clone(), - msg_id: Some(data.call_id.clone()), - r#type: "tool_call".into(), - content, - position: Some("left".into()), - status: Some(status.to_owned()), - hidden: false, - created_at: now_ms(), - }; - if let Err(e) = self.repo.insert_message(&row).await { - error!( - call_id = %data.call_id, - tool = %data.name, - status, - error = %ErrorChain(&e), - "Failed to persist tool_call message" - ); - } else { - debug!( - call_id = %data.call_id, - tool = %data.name, - status, - "Persisted tool_call message" - ); - } + debug!( + call_id = %data.call_id, + tool = %data.name, + status, + "Upserted tool_call message" + ); } } @@ -481,104 +450,20 @@ impl StreamPersistenceAdapter { normalize_keys_to_snake_case(&mut value); let content = value.to_string(); - match data.update.session_update { - AcpToolCallSessionUpdateKind::ToolCall => { - let row = MessageRow { - id: tool_call_id.clone(), - conversation_id: self.conversation_id.clone(), - msg_id: Some(tool_call_id.clone()), - r#type: "acp_tool_call".into(), - content, - position: Some("left".into()), - status: Some(status.to_owned()), - hidden: false, - created_at: now_ms(), - }; - if let Err(e) = self.repo.insert_message(&row).await { - error!(error = %ErrorChain(&e), "Failed to persist acp_tool_call message"); - } - } - AcpToolCallSessionUpdateKind::ToolCallUpdate => { - let merged_content = self.merge_acp_tool_call_content(tool_call_id, &value).await; - let update = MessageRowUpdate { - content: Some(merged_content.clone()), - status: Some(Some(status.to_owned())), - hidden: None, - }; - if let Err(e) = self.repo.update_message(tool_call_id, &update).await { - match e { - DbError::NotFound(_) => { - warn!( - conversation_id = %self.conversation_id, - tool_call_id = %tool_call_id, - "ACP tool call update arrived before initial tool call; inserting placeholder" - ); - let row = MessageRow { - id: tool_call_id.clone(), - conversation_id: self.conversation_id.clone(), - msg_id: Some(tool_call_id.clone()), - r#type: "acp_tool_call".into(), - content: merged_content, - position: Some("left".into()), - status: Some(status.to_owned()), - hidden: false, - created_at: now_ms(), - }; - if let Err(insert_err) = self.repo.insert_message(&row).await { - error!( - error = %ErrorChain(&insert_err), - "Failed to insert late acp_tool_call placeholder" - ); - } - } - other => { - error!(error = %ErrorChain(&other), "Failed to update acp_tool_call message"); - } - } - } - } - } - } - - /// Merge two JSON content strings: overlays non-null fields from `new_json` - /// onto `existing_json`, preserving fields only present in the original. - fn merge_json_content(existing_json: &str, new_json: &str) -> String { - let mut base: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default(); - let new_value: serde_json::Value = serde_json::from_str(new_json).unwrap_or_default(); - if let (Some(base_obj), Some(new_obj)) = (base.as_object_mut(), new_value.as_object()) { - for (key, val) in new_obj { - if !val.is_null() { - base_obj.insert(key.clone(), val.clone()); - } - } - } - base.to_string() - } - - async fn merge_acp_tool_call_content(&self, tool_call_id: &str, update_value: &serde_json::Value) -> String { - let existing = self - .repo - .get_message_by_msg_id(&self.conversation_id, tool_call_id, "acp_tool_call") - .await - .ok() - .flatten(); - - let Some(existing_row) = existing else { - return update_value.to_string(); + let row = MessageRow { + id: tool_call_id.clone(), + conversation_id: self.conversation_id.clone(), + msg_id: Some(tool_call_id.clone()), + r#type: "acp_tool_call".into(), + content, + position: Some("left".into()), + status: Some(status.to_owned()), + hidden: false, + created_at: now_ms(), }; - - let mut base: serde_json::Value = serde_json::from_str(&existing_row.content).unwrap_or_default(); - if let (Some(base_update), Some(new_update)) = ( - base.get_mut("update").and_then(|v| v.as_object_mut()), - update_value.get("update").and_then(|v| v.as_object()), - ) { - for (key, val) in new_update { - if !val.is_null() { - base_update.insert(key.clone(), val.clone()); - } - } + if let Err(e) = self.repo.upsert_message(&row).await { + error!(error = %ErrorChain(&e), "Failed to upsert acp_tool_call message"); } - base.to_string() } /// Persist a tool_group event (array of tool summaries). diff --git a/crates/aionui-conversation/src/stream_relay.rs b/crates/aionui-conversation/src/stream_relay.rs index 542933e28..0baf97cde 100644 --- a/crates/aionui-conversation/src/stream_relay.rs +++ b/crates/aionui-conversation/src/stream_relay.rs @@ -1805,6 +1805,75 @@ mod tests { fn take_updates(&self) -> Vec<(String, aionui_db::MessageRowUpdate)> { std::mem::take(&mut self.updates.lock().unwrap()) } + + fn merged_row(existing: &MessageRow, incoming: &MessageRow) -> MessageRow { + let preserve_terminal_status = matches!(existing.status.as_deref(), Some("finish" | "error")) + && incoming.status.as_deref() == Some("work"); + let mut content = Self::merge_json_content(&existing.content, &incoming.content); + if preserve_terminal_status { + content = Self::preserve_json_status(&content, &existing.content, &existing.r#type); + } + + let mut merged = existing.clone(); + merged.content = content; + merged.status = if preserve_terminal_status { + existing.status.clone() + } else { + incoming.status.clone() + }; + merged.hidden = incoming.hidden; + merged + } + + fn merge_json_content(existing_json: &str, incoming_json: &str) -> String { + let mut existing: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default(); + let incoming: serde_json::Value = serde_json::from_str(incoming_json).unwrap_or_default(); + Self::merge_json_value(&mut existing, incoming); + existing.to_string() + } + + fn merge_json_value(existing: &mut serde_json::Value, incoming: serde_json::Value) { + match (existing, incoming) { + (serde_json::Value::Object(existing_obj), serde_json::Value::Object(incoming_obj)) => { + for (key, value) in incoming_obj { + if !value.is_null() { + if let Some(existing_value) = existing_obj.get_mut(&key) { + Self::merge_json_value(existing_value, value); + } else { + existing_obj.insert(key, value); + } + } + } + } + (existing_value, incoming_value) => { + if !incoming_value.is_null() { + *existing_value = incoming_value; + } + } + } + } + + fn preserve_json_status(merged_json: &str, existing_json: &str, msg_type: &str) -> String { + let mut merged: serde_json::Value = serde_json::from_str(merged_json).unwrap_or_default(); + let existing: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default(); + let status = if msg_type == "acp_tool_call" { + existing.pointer("/update/status").cloned() + } else { + existing.get("status").cloned() + }; + + if let Some(status) = status { + if msg_type == "acp_tool_call" { + if let Some(update) = merged.get_mut("update").and_then(|value| value.as_object_mut()) { + update.insert("status".into(), status); + } + } else if let Some(object) = merged.as_object_mut() { + object.insert("status".into(), status); + } + } + + merged.to_string() + } } #[async_trait::async_trait] @@ -1881,6 +1950,30 @@ mod tests { self.inserts.lock().unwrap().push(row.clone()); Ok(()) } + async fn upsert_message(&self, row: &MessageRow) -> Result<(), DbError> { + if self.not_found.load(Ordering::Acquire) { + return Err(DbError::NotFound(format!("Message '{}'", row.id))); + } + if self.foreign_key_failure.load(Ordering::Acquire) { + return Err(DbError::Init("FOREIGN KEY constraint failed".into())); + } + + let mut inserts = self.inserts.lock().unwrap(); + if let Some(existing) = inserts.iter().find(|message| message.id == row.id) { + let merged = Self::merged_row(existing, row); + self.updates.lock().unwrap().push(( + row.id.clone(), + aionui_db::MessageRowUpdate { + content: Some(merged.content), + status: Some(merged.status), + hidden: Some(merged.hidden), + }, + )); + } else { + inserts.push(row.clone()); + } + Ok(()) + } async fn update_message(&self, id: &str, updates: &aionui_db::MessageRowUpdate) -> Result<(), DbError> { if self.not_found.load(Ordering::Acquire) { return Err(DbError::NotFound(format!("Message '{id}' not found"))); diff --git a/crates/aionui-conversation/tests/acp_tool_call_persistence.rs b/crates/aionui-conversation/tests/acp_tool_call_persistence.rs index e09347ccc..f07dc8095 100644 --- a/crates/aionui-conversation/tests/acp_tool_call_persistence.rs +++ b/crates/aionui-conversation/tests/acp_tool_call_persistence.rs @@ -78,3 +78,91 @@ async fn run_acp_tool_call_update_without_insert_creates_placeholder() { .any(|m| m.id == "atc-late" && m.r#type == "acp_tool_call") ); } + +#[tokio::test] +async fn run_acp_tool_call_late_initial_event_merges_with_update_placeholder() { + let db = init_database_memory().await.unwrap(); + let user_repo = SqliteUserRepository::new(db.pool().clone()); + let user = user_repo.create_user("user-1", "hash").await.unwrap(); + let repo = Arc::new(SqliteConversationRepository::new(db.pool().clone())); + repo.create(&ConversationRow { + id: "conv-1".into(), + user_id: user.id, + name: "test".into(), + r#type: "acp".into(), + extra: "{}".into(), + model: None, + status: Some("running".into()), + source: Some("aionui".into()), + channel_chat_id: None, + pinned: false, + pinned_at: None, + created_at: now_ms(), + updated_at: now_ms(), + }) + .await + .unwrap(); + + let bus = Arc::new(aionui_realtime::BroadcastEventBus::new(64)); + let (tx, _) = broadcast::channel(64); + let relay = StreamRelay::new( + "conv-1".into(), + "asst-1".into(), + "turn-1".into(), + "user-1".into(), + repo.clone(), + bus, + None, + ); + let rx = tx.subscribe(); + + tx.send(AgentStreamEvent::AcpToolCall(AcpToolCallEventData { + session_id: "sess-1".into(), + update: AcpToolCallUpdateData { + session_update: AcpToolCallSessionUpdateKind::ToolCallUpdate, + tool_call_id: "atc-out-of-order".into(), + status: Some(AcpToolCallStatus::Completed), + title: None, + kind: None, + raw_input: None, + raw_output: Some(json!("exit 0")), + content: None, + locations: None, + }, + meta: None, + })) + .unwrap(); + tx.send(AgentStreamEvent::AcpToolCall(AcpToolCallEventData { + session_id: "sess-1".into(), + update: AcpToolCallUpdateData { + session_update: AcpToolCallSessionUpdateKind::ToolCall, + tool_call_id: "atc-out-of-order".into(), + status: Some(AcpToolCallStatus::InProgress), + title: Some("Bash".into()), + kind: None, + raw_input: Some(json!({"command": "echo hi"})), + raw_output: None, + content: None, + locations: None, + }, + meta: None, + })) + .unwrap(); + tx.send(AgentStreamEvent::Finish(FinishEventData::default())).unwrap(); + + relay.consume(rx).await; + + let messages = repo.get_messages("conv-1", 1, 20, SortOrder::Asc).await.unwrap().items; + let msg = messages + .iter() + .find(|m| m.id == "atc-out-of-order" && m.r#type == "acp_tool_call") + .expect("acp tool call row should be persisted"); + assert_eq!(msg.status.as_deref(), Some("finish")); + + let content: serde_json::Value = serde_json::from_str(&msg.content).unwrap(); + let update = content.get("update").expect("content should include update object"); + assert_eq!(update["status"], "completed"); + assert_eq!(update["title"], "Bash"); + assert_eq!(update["raw_input"]["command"], "echo hi"); + assert_eq!(update["raw_output"], "exit 0"); +} diff --git a/crates/aionui-conversation/tests/stream_relay_tool_call.rs b/crates/aionui-conversation/tests/stream_relay_tool_call.rs index 2b9fe5901..f6b4954d0 100644 --- a/crates/aionui-conversation/tests/stream_relay_tool_call.rs +++ b/crates/aionui-conversation/tests/stream_relay_tool_call.rs @@ -107,6 +107,62 @@ async fn run_tool_call_with_empty_call_id_is_not_persisted() { ); } +#[tokio::test] +async fn run_tool_call_late_running_event_does_not_regress_completed_message() { + let (repo, _db) = setup_repo().await; + let bus = Arc::new(BroadcastEventBus::new(64)); + let (tx, _) = broadcast::channel(64); + + let relay = StreamRelay::new( + "conv-1".into(), + "asst-1".into(), + "turn-1".into(), + "system_default_user".into(), + repo.clone(), + bus, + None, + ); + + let rx = tx.subscribe(); + tx.send(AgentStreamEvent::ToolCall(ToolCallEventData { + call_id: "glob-1".into(), + name: "Glob".into(), + args: json!({"pattern": "*.rs"}), + status: ToolCallStatus::Completed, + input: None, + output: Some("src/main.rs".into()), + description: None, + })) + .unwrap(); + tx.send(AgentStreamEvent::ToolCall(ToolCallEventData { + call_id: "glob-1".into(), + name: "Glob".into(), + args: json!({"pattern": "*.rs"}), + status: ToolCallStatus::Running, + input: Some(json!({"pattern": "*.rs"})), + output: None, + description: Some("search files".into()), + })) + .unwrap(); + tx.send(AgentStreamEvent::Finish(FinishEventData::default())).unwrap(); + + relay.consume(rx).await; + + let messages = repo.get_messages("conv-1", 1, 100, SortOrder::Asc).await.unwrap(); + let msg = messages + .items + .iter() + .find(|row| row.id == "glob-1" && row.r#type == "tool_call") + .expect("tool call row should be persisted"); + assert_eq!(msg.status.as_deref(), Some("finish")); + + let content: serde_json::Value = serde_json::from_str(&msg.content).unwrap(); + assert_eq!(content["status"], "completed"); + assert_eq!(content["output"], "src/main.rs"); + assert_eq!(content["input"]["pattern"], "*.rs"); + assert_eq!(content["description"], "search files"); +} + struct ToolCallAgent { conversation_id: String, event_tx: broadcast::Sender, diff --git a/crates/aionui-db/src/repository/conversation.rs b/crates/aionui-db/src/repository/conversation.rs index 4f00f061b..6c7f242b1 100644 --- a/crates/aionui-db/src/repository/conversation.rs +++ b/crates/aionui-db/src/repository/conversation.rs @@ -96,6 +96,25 @@ pub trait IConversationRepository: Send + Sync { /// Inserts a new message row. async fn insert_message(&self, message: &MessageRow) -> Result<(), DbError>; + /// Inserts a message row, or merges mutable fields into the existing row with the same ID. + async fn upsert_message(&self, message: &MessageRow) -> Result<(), DbError> { + match self.insert_message(message).await { + Ok(()) => Ok(()), + Err(DbError::Conflict(_)) => { + self.update_message( + &message.id, + &MessageRowUpdate { + content: Some(message.content.clone()), + status: Some(message.status.clone()), + hidden: Some(message.hidden), + }, + ) + .await + } + Err(err) => Err(err), + } + } + /// Partially updates a message. Returns `DbError::NotFound` if ID is missing. async fn update_message(&self, id: &str, updates: &MessageRowUpdate) -> Result<(), DbError>; diff --git a/crates/aionui-db/src/repository/sqlite_conversation.rs b/crates/aionui-db/src/repository/sqlite_conversation.rs index f1ea2b7ea..2d06c6d1c 100644 --- a/crates/aionui-db/src/repository/sqlite_conversation.rs +++ b/crates/aionui-db/src/repository/sqlite_conversation.rs @@ -458,6 +458,52 @@ impl IConversationRepository for SqliteConversationRepository { Ok(()) } + async fn upsert_message(&self, message: &MessageRow) -> Result<(), DbError> { + sqlx::query( + "INSERT INTO messages \ + (id, conversation_id, msg_id, type, content, position, \ + status, hidden, created_at) \ + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) \ + ON CONFLICT(id) DO UPDATE SET \ + content = CASE \ + WHEN messages.status IN ('finish', 'error') AND excluded.status = 'work' THEN \ + CASE messages.type \ + WHEN 'acp_tool_call' THEN json_set( \ + json_patch(messages.content, excluded.content), \ + '$.update.status', \ + json_extract(messages.content, '$.update.status') \ + ) \ + ELSE json_set( \ + json_patch(messages.content, excluded.content), \ + '$.status', \ + json_extract(messages.content, '$.status') \ + ) \ + END \ + ELSE json_patch(messages.content, excluded.content) \ + END, \ + status = CASE \ + WHEN messages.status IN ('finish', 'error') AND excluded.status = 'work' THEN messages.status \ + ELSE excluded.status \ + END, \ + position = COALESCE(messages.position, excluded.position), \ + hidden = excluded.hidden, \ + created_at = MIN(messages.created_at, excluded.created_at)", + ) + .bind(&message.id) + .bind(&message.conversation_id) + .bind(&message.msg_id) + .bind(&message.r#type) + .bind(&message.content) + .bind(&message.position) + .bind(&message.status) + .bind(message.hidden) + .bind(message.created_at) + .execute(&self.pool) + .await?; + + Ok(()) + } + async fn update_message(&self, id: &str, updates: &MessageRowUpdate) -> Result<(), DbError> { let mut set_parts: Vec = Vec::new(); let mut binds: Vec = Vec::new();