Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 38 additions & 153 deletions crates/aionui-conversation/src/stream_persistence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use aionui_ai_agent::protocol::events::{
ErrorEventData, TipType, TipsEventData,
tool_call::{AcpToolCallSessionUpdateKind, AcpToolCallStatus, ToolCallStatus},
tool_call::{AcpToolCallStatus, ToolCallStatus},
};
use aionui_api_types::{ConversationRuntimeSummary, WebSocketMessage};
use aionui_common::{ErrorChain, normalize_keys_to_snake_case, now_ms};
Expand Down Expand Up @@ -400,63 +400,32 @@ impl StreamPersistenceAdapter {
};
let content = serde_json::to_string(data).unwrap_or_default();

let existing = self
.repo
.get_message_by_msg_id(&self.conversation_id, &data.call_id, "tool_call")
.await
.unwrap_or(None);

if let Some(existing_row) = existing {
let merged_content = Self::merge_json_content(&existing_row.content, &content);
let update = MessageRowUpdate {
content: Some(merged_content),
status: Some(Some(status.to_owned())),
hidden: None,
};
if let Err(e) = self.repo.update_message(&data.call_id, &update).await {
error!(
call_id = %data.call_id,
tool = %data.name,
status,
error = %ErrorChain(&e),
"Failed to update tool_call message"
);
} else {
debug!(
call_id = %data.call_id,
tool = %data.name,
status,
"Updated tool_call message"
);
}
let row = MessageRow {
id: data.call_id.clone(),
conversation_id: self.conversation_id.clone(),
msg_id: Some(data.call_id.clone()),
r#type: "tool_call".into(),
content,
position: Some("left".into()),
status: Some(status.to_owned()),
hidden: false,
created_at: now_ms(),
};
if let Err(e) = self.repo.upsert_message(&row).await {
error!(
call_id = %data.call_id,
tool = %data.name,
status,
error = %ErrorChain(&e),
"Failed to upsert tool_call message"
);
} else {
let row = MessageRow {
id: data.call_id.clone(),
conversation_id: self.conversation_id.clone(),
msg_id: Some(data.call_id.clone()),
r#type: "tool_call".into(),
content,
position: Some("left".into()),
status: Some(status.to_owned()),
hidden: false,
created_at: now_ms(),
};
if let Err(e) = self.repo.insert_message(&row).await {
error!(
call_id = %data.call_id,
tool = %data.name,
status,
error = %ErrorChain(&e),
"Failed to persist tool_call message"
);
} else {
debug!(
call_id = %data.call_id,
tool = %data.name,
status,
"Persisted tool_call message"
);
}
debug!(
call_id = %data.call_id,
tool = %data.name,
status,
"Upserted tool_call message"
);
}
}

Expand All @@ -481,104 +450,20 @@ impl StreamPersistenceAdapter {
normalize_keys_to_snake_case(&mut value);
let content = value.to_string();

match data.update.session_update {
AcpToolCallSessionUpdateKind::ToolCall => {
let row = MessageRow {
id: tool_call_id.clone(),
conversation_id: self.conversation_id.clone(),
msg_id: Some(tool_call_id.clone()),
r#type: "acp_tool_call".into(),
content,
position: Some("left".into()),
status: Some(status.to_owned()),
hidden: false,
created_at: now_ms(),
};
if let Err(e) = self.repo.insert_message(&row).await {
error!(error = %ErrorChain(&e), "Failed to persist acp_tool_call message");
}
}
AcpToolCallSessionUpdateKind::ToolCallUpdate => {
let merged_content = self.merge_acp_tool_call_content(tool_call_id, &value).await;
let update = MessageRowUpdate {
content: Some(merged_content.clone()),
status: Some(Some(status.to_owned())),
hidden: None,
};
if let Err(e) = self.repo.update_message(tool_call_id, &update).await {
match e {
DbError::NotFound(_) => {
warn!(
conversation_id = %self.conversation_id,
tool_call_id = %tool_call_id,
"ACP tool call update arrived before initial tool call; inserting placeholder"
);
let row = MessageRow {
id: tool_call_id.clone(),
conversation_id: self.conversation_id.clone(),
msg_id: Some(tool_call_id.clone()),
r#type: "acp_tool_call".into(),
content: merged_content,
position: Some("left".into()),
status: Some(status.to_owned()),
hidden: false,
created_at: now_ms(),
};
if let Err(insert_err) = self.repo.insert_message(&row).await {
error!(
error = %ErrorChain(&insert_err),
"Failed to insert late acp_tool_call placeholder"
);
}
}
other => {
error!(error = %ErrorChain(&other), "Failed to update acp_tool_call message");
}
}
}
}
}
}

/// Merge two JSON content strings: overlays non-null fields from `new_json`
/// onto `existing_json`, preserving fields only present in the original.
fn merge_json_content(existing_json: &str, new_json: &str) -> String {
let mut base: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default();
let new_value: serde_json::Value = serde_json::from_str(new_json).unwrap_or_default();
if let (Some(base_obj), Some(new_obj)) = (base.as_object_mut(), new_value.as_object()) {
for (key, val) in new_obj {
if !val.is_null() {
base_obj.insert(key.clone(), val.clone());
}
}
}
base.to_string()
}

async fn merge_acp_tool_call_content(&self, tool_call_id: &str, update_value: &serde_json::Value) -> String {
let existing = self
.repo
.get_message_by_msg_id(&self.conversation_id, tool_call_id, "acp_tool_call")
.await
.ok()
.flatten();

let Some(existing_row) = existing else {
return update_value.to_string();
let row = MessageRow {
id: tool_call_id.clone(),
conversation_id: self.conversation_id.clone(),
msg_id: Some(tool_call_id.clone()),
r#type: "acp_tool_call".into(),
content,
position: Some("left".into()),
status: Some(status.to_owned()),
hidden: false,
created_at: now_ms(),
};

let mut base: serde_json::Value = serde_json::from_str(&existing_row.content).unwrap_or_default();
if let (Some(base_update), Some(new_update)) = (
base.get_mut("update").and_then(|v| v.as_object_mut()),
update_value.get("update").and_then(|v| v.as_object()),
) {
for (key, val) in new_update {
if !val.is_null() {
base_update.insert(key.clone(), val.clone());
}
}
if let Err(e) = self.repo.upsert_message(&row).await {
error!(error = %ErrorChain(&e), "Failed to upsert acp_tool_call message");
}
base.to_string()
}

/// Persist a tool_group event (array of tool summaries).
Expand Down
93 changes: 93 additions & 0 deletions crates/aionui-conversation/src/stream_relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,75 @@ mod tests {
fn take_updates(&self) -> Vec<(String, aionui_db::MessageRowUpdate)> {
std::mem::take(&mut self.updates.lock().unwrap())
}

fn merged_row(existing: &MessageRow, incoming: &MessageRow) -> MessageRow {
let preserve_terminal_status = matches!(existing.status.as_deref(), Some("finish" | "error"))
&& incoming.status.as_deref() == Some("work");
let mut content = Self::merge_json_content(&existing.content, &incoming.content);
if preserve_terminal_status {
content = Self::preserve_json_status(&content, &existing.content, &existing.r#type);
}

let mut merged = existing.clone();
merged.content = content;
merged.status = if preserve_terminal_status {
existing.status.clone()
} else {
incoming.status.clone()
};
merged.hidden = incoming.hidden;
merged
}

fn merge_json_content(existing_json: &str, incoming_json: &str) -> String {
let mut existing: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default();
let incoming: serde_json::Value = serde_json::from_str(incoming_json).unwrap_or_default();
Self::merge_json_value(&mut existing, incoming);
existing.to_string()
}

fn merge_json_value(existing: &mut serde_json::Value, incoming: serde_json::Value) {
match (existing, incoming) {
(serde_json::Value::Object(existing_obj), serde_json::Value::Object(incoming_obj)) => {
for (key, value) in incoming_obj {
if !value.is_null() {
if let Some(existing_value) = existing_obj.get_mut(&key) {
Self::merge_json_value(existing_value, value);
} else {
existing_obj.insert(key, value);
}
}
}
}
(existing_value, incoming_value) => {
if !incoming_value.is_null() {
*existing_value = incoming_value;
}
}
}
}

fn preserve_json_status(merged_json: &str, existing_json: &str, msg_type: &str) -> String {
let mut merged: serde_json::Value = serde_json::from_str(merged_json).unwrap_or_default();
let existing: serde_json::Value = serde_json::from_str(existing_json).unwrap_or_default();
let status = if msg_type == "acp_tool_call" {
existing.pointer("/update/status").cloned()
} else {
existing.get("status").cloned()
};

if let Some(status) = status {
if msg_type == "acp_tool_call" {
if let Some(update) = merged.get_mut("update").and_then(|value| value.as_object_mut()) {
update.insert("status".into(), status);
}
} else if let Some(object) = merged.as_object_mut() {
object.insert("status".into(), status);
}
}

merged.to_string()
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -1881,6 +1950,30 @@ mod tests {
self.inserts.lock().unwrap().push(row.clone());
Ok(())
}
async fn upsert_message(&self, row: &MessageRow) -> Result<(), DbError> {
if self.not_found.load(Ordering::Acquire) {
return Err(DbError::NotFound(format!("Message '{}'", row.id)));
}
if self.foreign_key_failure.load(Ordering::Acquire) {
return Err(DbError::Init("FOREIGN KEY constraint failed".into()));
}

let mut inserts = self.inserts.lock().unwrap();
if let Some(existing) = inserts.iter().find(|message| message.id == row.id) {
let merged = Self::merged_row(existing, row);
self.updates.lock().unwrap().push((
row.id.clone(),
aionui_db::MessageRowUpdate {
content: Some(merged.content),
status: Some(merged.status),
hidden: Some(merged.hidden),
},
));
} else {
inserts.push(row.clone());
}
Ok(())
}
async fn update_message(&self, id: &str, updates: &aionui_db::MessageRowUpdate) -> Result<(), DbError> {
if self.not_found.load(Ordering::Acquire) {
return Err(DbError::NotFound(format!("Message '{id}' not found")));
Expand Down
Loading
Loading