From a2bd3218aa648ab79a3b3a3a6da28f2e68be8e3a Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 16:54:34 +0800 Subject: [PATCH 1/9] fix(dml): support subqueries in UPDATE SET clause - Add execute_scalar_subquery_for_update function to evaluate scalar subqueries during UPDATE operations - Add eval_expr_with_params_and_subqueries to convert.rs to handle subquery expressions when evaluating assignment values - Add helper functions for expression-based aggregates in subqueries: - find_expression_aggregate: detects aggregates with expression arguments - eval_expr_for_row: evaluates expressions against materialized rows - compute_aggregate_manually: computes aggregates when executor can't handle expression arguments (e.g., SUM(quantity * unit_price)) - Handle both TableScan and SecondaryIndexScan plan sources for subqueries - Pre-compute scalar subqueries before the main UPDATE loop to avoid repeated execution Fixes #23 Co-Authored-By: Claude Opus 4.5 --- src/database/convert.rs | 120 ++++++++ src/database/dml/update.rs | 594 ++++++++++++++++++++++++++++++++++++- 2 files changed, 709 insertions(+), 5 deletions(-) diff --git a/src/database/convert.rs b/src/database/convert.rs index 6307908..b5011c7 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -404,6 +404,126 @@ impl Database { Self::eval_literal_with_type(expr, target_type) } + pub(crate) fn eval_expr_with_params_and_subqueries( + expr: &crate::sql::ast::Expr<'_>, + target_type: Option<&crate::records::types::DataType>, + params: Option<&[OwnedValue]>, + param_idx: &mut usize, + scalar_subquery_results: &crate::sql::context::ScalarSubqueryResults, + ) -> Result { + use crate::sql::ast::{Expr, ParameterRef}; + + match expr { + Expr::Parameter(param_ref) => { + if let Some(params) = params { + let idx = match param_ref { + ParameterRef::Anonymous => { + let i = *param_idx; + *param_idx += 1; + i + } + ParameterRef::Positional(n) => (*n as usize).saturating_sub(1), + 
ParameterRef::Named(_) => { + let i = *param_idx; + *param_idx += 1; + i + } + }; + + if idx >= params.len() { + bail!( + "parameter index {} out of range (only {} parameters bound)", + idx + 1, + params.len() + ); + } + + Ok(params[idx].clone()) + } else { + bail!("parameter placeholder found but no parameters were bound") + } + } + Expr::Subquery(subq) => { + let key = std::ptr::from_ref(*subq) as usize; + scalar_subquery_results + .iter() + .find(|(k, _)| *k == key) + .map(|(_, v)| v.clone()) + .ok_or_else(|| eyre::eyre!("scalar subquery result not found in pre-computed results")) + } + Expr::BinaryOp { left, op, right } => { + let left_val = Self::eval_expr_with_params_and_subqueries( + left, + target_type, + params, + param_idx, + scalar_subquery_results, + )?; + let right_val = Self::eval_expr_with_params_and_subqueries( + right, + target_type, + params, + param_idx, + scalar_subquery_results, + )?; + + use crate::sql::ast::BinaryOperator; + match op { + BinaryOperator::Plus => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a + b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a + b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => { + Ok(OwnedValue::Float(*a as f64 + b)) + } + (OwnedValue::Float(a), OwnedValue::Int(b)) => { + Ok(OwnedValue::Float(a + *b as f64)) + } + _ => bail!("unsupported types for addition in UPDATE SET"), + }, + BinaryOperator::Minus => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a - b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a - b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => { + Ok(OwnedValue::Float(*a as f64 - b)) + } + (OwnedValue::Float(a), OwnedValue::Int(b)) => { + Ok(OwnedValue::Float(a - *b as f64)) + } + _ => bail!("unsupported types for subtraction in UPDATE SET"), + }, + BinaryOperator::Multiply => match (&left_val, &right_val) { + (OwnedValue::Int(a), 
OwnedValue::Int(b)) => Ok(OwnedValue::Int(a * b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a * b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => { + Ok(OwnedValue::Float(*a as f64 * b)) + } + (OwnedValue::Float(a), OwnedValue::Int(b)) => { + Ok(OwnedValue::Float(a * *b as f64)) + } + _ => bail!("unsupported types for multiplication in UPDATE SET"), + }, + BinaryOperator::Divide => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => { + Ok(OwnedValue::Int(a / b)) + } + (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { + Ok(OwnedValue::Float(a / b)) + } + (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { + Ok(OwnedValue::Float(*a as f64 / b)) + } + (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { + Ok(OwnedValue::Float(a / *b as f64)) + } + _ => bail!("division by zero or unsupported types"), + }, + _ => Self::eval_literal_with_type(expr, target_type), + } + } + _ => Self::eval_literal_with_type(expr, target_type), + } + } + pub(crate) fn parse_json_string(s: &str) -> Result { let value = Self::parse_json_to_value(s.trim())?; let bytes = Self::jsonb_value_to_bytes(&value); diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 7907428..2d21396 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -91,6 +91,7 @@ use crate::database::{Database, ExecuteResult}; use crate::mvcc::WriteEntry; use crate::records::RecordView; use crate::schema::table::{Constraint, IndexType}; +use crate::sql::context::ScalarSubqueryResults; use crate::sql::decoder::RecordDecoder; use crate::sql::executor::ExecutorRow; use crate::sql::predicate::CompiledPredicate; @@ -147,7 +148,8 @@ fn expr_contains_column_ref(expr: &crate::sql::ast::Expr) -> bool { Expr::InList { expr, list, .. } => { expr_contains_column_ref(expr) || list.iter().any(|e| expr_contains_column_ref(e)) } - Expr::InSubquery { .. } | Expr::Subquery(_) | Expr::Exists { .. 
} => true, + Expr::InSubquery { .. } | Expr::Exists { .. } => true, + Expr::Subquery(_) => false, Expr::IsDistinctFrom { left, right, .. } => { expr_contains_column_ref(left) || expr_contains_column_ref(right) } @@ -173,7 +175,562 @@ fn expr_contains_column_ref(expr: &crate::sql::ast::Expr) -> bool { } } +fn collect_scalar_subqueries_from_expr<'a>( + expr: &'a crate::sql::ast::Expr<'a>, + subqueries: &mut SmallVec<[&'a crate::sql::ast::SelectStmt<'a>; 4]>, +) { + use crate::sql::ast::Expr; + match expr { + Expr::Subquery(subq) => { + subqueries.push(subq); + } + Expr::BinaryOp { left, right, .. } => { + collect_scalar_subqueries_from_expr(left, subqueries); + collect_scalar_subqueries_from_expr(right, subqueries); + } + Expr::UnaryOp { expr, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::IsNull { expr, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::InList { expr, list, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + for item in list.iter() { + collect_scalar_subqueries_from_expr(item, subqueries); + } + } + Expr::Between { + expr, low, high, .. + } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + collect_scalar_subqueries_from_expr(low, subqueries); + collect_scalar_subqueries_from_expr(high, subqueries); + } + Expr::Case { + operand, + conditions, + else_result, + } => { + if let Some(op) = operand { + collect_scalar_subqueries_from_expr(op, subqueries); + } + for cond in conditions.iter() { + collect_scalar_subqueries_from_expr(cond.condition, subqueries); + collect_scalar_subqueries_from_expr(cond.result, subqueries); + } + if let Some(else_e) = else_result { + collect_scalar_subqueries_from_expr(else_e, subqueries); + } + } + Expr::Cast { expr, .. 
} => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::Function(func) => { + if let crate::sql::ast::FunctionArgs::Args(args) = &func.args { + for arg in args.iter() { + collect_scalar_subqueries_from_expr(arg.value, subqueries); + } + } + } + _ => {} + } +} + +fn find_expression_aggregate<'a>( + op: &'a crate::sql::planner::PhysicalOperator<'a>, +) -> Option<( + &'a crate::sql::planner::AggregateFunction, + &'a crate::sql::ast::Expr<'a>, +)> { + use crate::sql::planner::PhysicalOperator; + + match op { + PhysicalOperator::HashAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + if !matches!(arg, crate::sql::ast::Expr::Column(_)) { + return Some((&agg_expr.function, arg)); + } + } + } + find_expression_aggregate(agg.input) + } + PhysicalOperator::SortedAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + if !matches!(arg, crate::sql::ast::Expr::Column(_)) { + return Some((&agg_expr.function, arg)); + } + } + } + find_expression_aggregate(agg.input) + } + PhysicalOperator::FilterExec(f) => find_expression_aggregate(f.input), + PhysicalOperator::ProjectExec(p) => find_expression_aggregate(p.input), + PhysicalOperator::SortExec(s) => find_expression_aggregate(s.input), + PhysicalOperator::LimitExec(l) => find_expression_aggregate(l.input), + _ => None, + } +} + +fn eval_expr_for_row( + expr: &crate::sql::ast::Expr<'_>, + row: &[OwnedValue], + column_map: &[(String, usize)], +) -> OwnedValue { + use crate::sql::ast::{BinaryOperator, Expr, Literal}; + + match expr { + Expr::Column(col_ref) => { + let col_name = col_ref.column.to_lowercase(); + let idx = column_map + .iter() + .find(|(name, _)| name == &col_name) + .map(|(_, idx)| *idx); + if let Some(i) = idx { + row.get(i).cloned().unwrap_or(OwnedValue::Null) + } else { + OwnedValue::Null + } + } + Expr::Literal(lit) => match lit { + Literal::Integer(s) => 
s.parse::().map(OwnedValue::Int).unwrap_or(OwnedValue::Null), + Literal::Float(s) => s.parse::().map(OwnedValue::Float).unwrap_or(OwnedValue::Null), + Literal::String(s) => OwnedValue::Text((*s).to_string()), + Literal::Boolean(b) => OwnedValue::Int(if *b { 1 } else { 0 }), + Literal::Null => OwnedValue::Null, + _ => OwnedValue::Null, + }, + Expr::BinaryOp { left, op, right } => { + let left_val = eval_expr_for_row(left, row, column_map); + let right_val = eval_expr_for_row(right, row, column_map); + match op { + BinaryOperator::Plus => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a + b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a + b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 + b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a + *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Minus => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a - b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a - b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 - b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a - *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Multiply => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a * b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a * b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 * b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a * *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Divide => match (&left_val, &right_val) { + (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Int(a / b), + (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { + OwnedValue::Float(a / b) + } + (OwnedValue::Int(a), OwnedValue::Float(b)) 
if *b != 0.0 => { + OwnedValue::Float(*a as f64 / b) + } + (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { + OwnedValue::Float(a / *b as f64) + } + _ => OwnedValue::Null, + }, + _ => OwnedValue::Null, + } + } + Expr::UnaryOp { op, expr } => { + let val = eval_expr_for_row(expr, row, column_map); + match op { + crate::sql::ast::UnaryOperator::Minus => match val { + OwnedValue::Int(i) => OwnedValue::Int(-i), + OwnedValue::Float(f) => OwnedValue::Float(-f), + _ => OwnedValue::Null, + }, + crate::sql::ast::UnaryOperator::Plus => val, + _ => OwnedValue::Null, + } + } + _ => OwnedValue::Null, + } +} + +fn compute_aggregate_manually( + agg_func: &crate::sql::planner::AggregateFunction, + expr: &crate::sql::ast::Expr<'_>, + rows: &[Vec], + column_map: &[(String, usize)], +) -> OwnedValue { + use crate::sql::planner::AggregateFunction; + + match agg_func { + AggregateFunction::Sum => { + let mut sum_int: i64 = 0; + let mut sum_float: f64 = 0.0; + let mut has_float = false; + + for row in rows { + let val = eval_expr_for_row(expr, row, column_map); + match val { + OwnedValue::Int(i) => sum_int += i, + OwnedValue::Float(f) => { + sum_float += f; + has_float = true; + } + _ => {} + } + } + + if has_float { + OwnedValue::Float(sum_float + sum_int as f64) + } else { + OwnedValue::Int(sum_int) + } + } + AggregateFunction::Avg => { + let mut sum: f64 = 0.0; + let mut count: usize = 0; + + for row in rows { + let val = eval_expr_for_row(expr, row, column_map); + match val { + OwnedValue::Int(i) => { + sum += i as f64; + count += 1; + } + OwnedValue::Float(f) => { + sum += f; + count += 1; + } + _ => {} + } + } + + if count > 0 { + OwnedValue::Float(sum / count as f64) + } else { + OwnedValue::Null + } + } + AggregateFunction::Count => OwnedValue::Int(rows.len() as i64), + AggregateFunction::Min => { + let mut min_val: Option = None; + + for row in rows { + let val = eval_expr_for_row(expr, row, column_map); + match (&min_val, &val) { + (None, v) if !matches!(v, 
OwnedValue::Null) => min_val = Some(val), + (Some(OwnedValue::Int(m)), OwnedValue::Int(i)) if *i < *m => { + min_val = Some(val) + } + (Some(OwnedValue::Float(m)), OwnedValue::Float(f)) if *f < *m => { + min_val = Some(val) + } + _ => {} + } + } + + min_val.unwrap_or(OwnedValue::Null) + } + AggregateFunction::Max => { + let mut max_val: Option = None; + + for row in rows { + let val = eval_expr_for_row(expr, row, column_map); + match (&max_val, &val) { + (None, v) if !matches!(v, OwnedValue::Null) => max_val = Some(val), + (Some(OwnedValue::Int(m)), OwnedValue::Int(i)) if *i > *m => { + max_val = Some(val) + } + (Some(OwnedValue::Float(m)), OwnedValue::Float(f)) if *f > *m => { + max_val = Some(val) + } + _ => {} + } + } + + max_val.unwrap_or(OwnedValue::Null) + } + } +} + impl Database { + fn execute_scalar_subquery_for_update<'a>( + subq: &'a crate::sql::ast::SelectStmt<'a>, + catalog: &crate::schema::catalog::Catalog, + file_manager: &mut crate::storage::FileManager, + arena: &'a Bump, + ) -> Result { + use crate::btree::BTreeReader; + use crate::database::query::{find_plan_source, PlanSource}; + use crate::records::RecordView; + use crate::sql::builder::ExecutorBuilder; + use crate::sql::context::ExecutionContext; + use crate::sql::executor::{Executor, MaterializedRowSource, StreamingBTreeSource}; + use crate::sql::planner::{Planner, ScanRange}; + use crate::types::create_record_schema; + + let planner = Planner::new(catalog, arena); + let stmt = crate::sql::ast::Statement::Select(subq); + let subq_plan = planner.create_physical_plan(&stmt)?; + + let plan_source = find_plan_source(subq_plan.root); + + fn read_root_page(storage: &crate::storage::MmapStorage) -> Result { + use crate::storage::TableFileHeader; + let page = storage.page(0)?; + Ok(TableFileHeader::from_bytes(page)?.root_page()) + } + + fn read_index_root_page(storage: &crate::storage::MmapStorage) -> Result { + use crate::storage::IndexFileHeader; + let page = storage.page(0)?; + 
Ok(IndexFileHeader::from_bytes(page)?.root_page()) + } + + match plan_source { + Some(PlanSource::TableScan(scan)) => { + let schema_name = scan.schema.unwrap_or(DEFAULT_SCHEMA); + let table_name = scan.table; + + let table_def = catalog.resolve_table_in_schema(scan.schema, table_name)?; + let column_types: Vec<_> = + table_def.columns().iter().map(|c| c.data_type()).collect(); + let columns = table_def.columns(); + + let storage_arc = file_manager.table_data(schema_name, table_name)?; + let storage = storage_arc.read(); + + let root_page = read_root_page(&storage)?; + + let column_map: Vec<(String, usize)> = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), i)) + .collect(); + + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let schema = create_record_schema(columns); + let table_reader = crate::btree::BTreeReader::new(&storage, root_page)?; + let mut cursor = table_reader.cursor_first()?; + let mut materialized_rows: Vec> = Vec::new(); + + while cursor.valid() { + let row_data = cursor.value()?; + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let row_values = OwnedValue::extract_row_from_record(&record, columns)?; + materialized_rows.push(row_values); + if !cursor.advance()? { + break; + } + } + + let result = + compute_aggregate_manually(agg_func, expr, &materialized_rows, &column_map); + return Ok(result); + } + + let source = StreamingBTreeSource::from_btree_scan_with_projections( + &storage, + root_page, + column_types, + None, + )?; + + let ctx = ExecutionContext::new(arena); + let builder = ExecutorBuilder::new(&ctx); + let mut executor = + builder.build_with_source_and_column_map(&subq_plan, source, &column_map)?; + + executor.open()?; + let result = if let Some(row) = executor.next()? 
{ + row.values + .first() + .map(OwnedValue::from) + .unwrap_or(OwnedValue::Null) + } else { + OwnedValue::Null + }; + executor.close()?; + Ok(result) + } + Some(PlanSource::SecondaryIndexScan(scan)) => { + let schema_name = scan.schema.unwrap_or(DEFAULT_SCHEMA); + let table_name = scan.table; + let index_name = scan.index_name; + + let table_def = scan.table_def.ok_or_else(|| { + eyre::eyre!("SecondaryIndexScan missing table_def for {}", table_name) + })?; + + let columns = table_def.columns(); + let schema = create_record_schema(columns); + + let row_id_suffix_len = if scan.is_unique_index { 0 } else { 8 }; + + let row_keys: Vec<[u8; 8]> = { + let index_storage_arc = + file_manager.index_data(schema_name, table_name, index_name)?; + let index_storage = index_storage_arc.read(); + + let root_page = read_index_root_page(&index_storage)?; + let index_reader = BTreeReader::new(&index_storage, root_page)?; + + let mut keys = Vec::new(); + + match &scan.key_range { + Some(ScanRange::PrefixScan { prefix }) => { + let mut cursor = index_reader.cursor_seek(prefix)?; + if cursor.valid() { + loop { + let index_key = cursor.key()?; + if !index_key.starts_with(prefix) { + break; + } + + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); + keys.push(row_key); + } + + if !cursor.advance()? { + break; + } + } + } + } + Some(ScanRange::RangeScan { start, end }) => { + let mut cursor = if let Some(start_key) = start { + index_reader.cursor_seek(start_key)? + } else { + index_reader.cursor_first()? + }; + + if cursor.valid() { + loop { + let index_key = cursor.key()?; + if let Some(end_key) = end { + if index_key >= *end_key { + break; + } + } + + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] 
+ }; + + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); + keys.push(row_key); + } + + if !cursor.advance()? { + break; + } + } + } + } + Some(ScanRange::FullScan) | None => { + let mut cursor = index_reader.cursor_first()?; + if cursor.valid() { + loop { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); + keys.push(row_key); + } + + if !cursor.advance()? { + break; + } + } + } + } + } + keys + }; + + let column_map: Vec<(String, usize)> = table_def + .columns() + .iter() + .enumerate() + .map(|(idx, col)| (col.name().to_lowercase(), idx)) + .collect(); + + let mut materialized_rows: Vec> = + Vec::with_capacity(row_keys.len()); + + { + let table_storage_arc = file_manager.table_data(schema_name, table_name)?; + let table_storage = table_storage_arc.read(); + + let root_page = read_root_page(&table_storage)?; + let table_reader = BTreeReader::new(&table_storage, root_page)?; + + for row_key in &row_keys { + if let Some(row_data) = table_reader.get(row_key)? 
{ + let user_data = + crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let row_values = OwnedValue::extract_row_from_record(&record, columns)?; + materialized_rows.push(row_values); + } + } + } + + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let result = + compute_aggregate_manually(agg_func, expr, &materialized_rows, &column_map); + return Ok(result); + } + + let materialized_source = MaterializedRowSource::new(materialized_rows); + + let ctx = ExecutionContext::new(arena); + let builder = ExecutorBuilder::new(&ctx); + + let mut executor = builder.build_with_source_and_column_map( + &subq_plan, + materialized_source, + &column_map, + )?; + + executor.open()?; + let result = if let Some(row) = executor.next()? { + row.values + .first() + .map(OwnedValue::from) + .unwrap_or(OwnedValue::Null) + } else { + OwnedValue::Null + }; + executor.close()?; + Ok(result) + } + _ => Ok(OwnedValue::Null), + } + } + pub(crate) fn execute_update( &self, update: &crate::sql::ast::UpdateStmt<'_>, @@ -325,11 +882,10 @@ impl Database { }) .collect(); - drop(catalog_guard); - let schema = create_record_schema(&columns); if let Some(from_tables) = from_tables_data { + drop(catalog_guard); return self.execute_update_with_from( update, arena, @@ -355,6 +911,29 @@ impl Database { }) .collect(); + let mut scalar_subquery_results: ScalarSubqueryResults = ScalarSubqueryResults::new(); + { + let mut subqueries: SmallVec<[&crate::sql::ast::SelectStmt<'_>; 4]> = SmallVec::new(); + for (_, value_expr) in &assignment_indices { + collect_scalar_subqueries_from_expr(value_expr, &mut subqueries); + } + + if !subqueries.is_empty() { + let mut fm_guard = self.shared.file_manager.write(); + let fm = fm_guard.as_mut().unwrap(); + for subq in subqueries { + let key = std::ptr::from_ref(subq) as usize; + if !scalar_subquery_results.iter().any(|(k, _)| *k == key) { + let result = + 
Self::execute_scalar_subquery_for_update(subq, catalog, fm, arena)?; + scalar_subquery_results.push((key, result)); + } + } + } + } + + drop(catalog_guard); + let mut file_manager_guard = self.shared.file_manager.write(); let file_manager = file_manager_guard.as_mut().unwrap(); let storage_arc = file_manager.table_data_mut(schema_name, table_name)?; @@ -478,8 +1057,13 @@ impl Database { deferred_assignments.push((*col_idx, assign_idx)); param_idx += count_params_in_expr(value_expr); } else { - let val = - Self::eval_expr_with_params(value_expr, column_types.get(*col_idx), Some(params), &mut param_idx)?; + let val = Self::eval_expr_with_params_and_subqueries( + value_expr, + column_types.get(*col_idx), + Some(params), + &mut param_idx, + &scalar_subquery_results, + )?; precomputed_assignments.push((*col_idx, val)); } } From c1fdab9f5683eb306f771f6319079c07ab22bda4 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 20:52:39 +0800 Subject: [PATCH 2/9] fix(dml): address code review issues for UPDATE subquery support - Replace row materialization with streaming aggregation using StreamingAggregateState to compute aggregates without storing all rows - Add eval_expr_for_record_streaming with HashMap for O(1) column lookups with proper type information - Remove duplicate inner read_root_page/read_index_root_page functions that shadowed outer helpers - Add proper error context per ERRORS.md with wrap_err_with for all operations including table resolution, storage access, and root page reads - Replace .unwrap() calls with proper error handling via try_into().map_err() - Extract eval_binary_op shared helper function for binary operations - Consolidate read_root_page and read_index_root_page helpers at module level Addresses code review feedback on PR #40. 
Co-Authored-By: Claude Opus 4.5 --- src/database/dml/update.rs | 578 ++++++++++++++++++++++--------------- 1 file changed, 340 insertions(+), 238 deletions(-) diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 2d21396..23714a5 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -95,11 +95,11 @@ use crate::sql::context::ScalarSubqueryResults; use crate::sql::decoder::RecordDecoder; use crate::sql::executor::ExecutorRow; use crate::sql::predicate::CompiledPredicate; -use crate::storage::{IndexFileHeader, DEFAULT_SCHEMA}; +use crate::storage::{IndexFileHeader, TableFileHeader, DEFAULT_SCHEMA}; use crate::types::{create_record_schema, OwnedValue, Value}; use bumpalo::Bump; use eyre::{bail, Result, WrapErr}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; use smallvec::SmallVec; use std::borrow::Cow; use std::sync::atomic::Ordering; @@ -274,22 +274,53 @@ fn find_expression_aggregate<'a>( } } -fn eval_expr_for_row( +fn eval_binary_op(left: &OwnedValue, op: &crate::sql::ast::BinaryOperator, right: &OwnedValue) -> OwnedValue { + use crate::sql::ast::BinaryOperator; + match op { + BinaryOperator::Plus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a + b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a + b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 + b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a + *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Minus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a - b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a - b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 - b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a - *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Multiply => match (left, right) { + (OwnedValue::Int(a), 
OwnedValue::Int(b)) => OwnedValue::Int(a * b), + (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a * b), + (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 * b), + (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a * *b as f64), + _ => OwnedValue::Null, + }, + BinaryOperator::Divide => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Int(a / b), + (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => OwnedValue::Float(a / b), + (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => OwnedValue::Float(*a as f64 / b), + (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Float(a / *b as f64), + _ => OwnedValue::Null, + }, + _ => OwnedValue::Null, + } +} + +fn eval_expr_for_record_streaming( expr: &crate::sql::ast::Expr<'_>, - row: &[OwnedValue], - column_map: &[(String, usize)], + record: &RecordView<'_>, + column_info: &HashMap, ) -> OwnedValue { - use crate::sql::ast::{BinaryOperator, Expr, Literal}; + use crate::sql::ast::{Expr, Literal}; match expr { Expr::Column(col_ref) => { - let col_name = col_ref.column.to_lowercase(); - let idx = column_map - .iter() - .find(|(name, _)| name == &col_name) - .map(|(_, idx)| *idx); - if let Some(i) = idx { - row.get(i).cloned().unwrap_or(OwnedValue::Null) + let col_lower = col_ref.column.to_lowercase(); + if let Some(&(idx, data_type)) = column_info.get(&col_lower) { + OwnedValue::from_record_column(record, idx, data_type).unwrap_or(OwnedValue::Null) } else { OwnedValue::Null } @@ -303,48 +334,12 @@ fn eval_expr_for_row( _ => OwnedValue::Null, }, Expr::BinaryOp { left, op, right } => { - let left_val = eval_expr_for_row(left, row, column_map); - let right_val = eval_expr_for_row(right, row, column_map); - match op { - BinaryOperator::Plus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a + b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => 
OwnedValue::Float(a + b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 + b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a + *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Minus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a - b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a - b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 - b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a - *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Multiply => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a * b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a * b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 * b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a * *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Divide => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Int(a / b), - (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { - OwnedValue::Float(a / b) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { - OwnedValue::Float(*a as f64 / b) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { - OwnedValue::Float(a / *b as f64) - } - _ => OwnedValue::Null, - }, - _ => OwnedValue::Null, - } + let left_val = eval_expr_for_record_streaming(left, record, column_info); + let right_val = eval_expr_for_record_streaming(right, record, column_info); + eval_binary_op(&left_val, op, &right_val) } - Expr::UnaryOp { op, expr } => { - let val = eval_expr_for_row(expr, row, column_map); + Expr::UnaryOp { op, expr: inner } => { + let val = eval_expr_for_record_streaming(inner, record, column_info); match op { crate::sql::ast::UnaryOperator::Minus => match val { 
OwnedValue::Int(i) => OwnedValue::Int(-i), @@ -359,105 +354,126 @@ fn eval_expr_for_row( } } -fn compute_aggregate_manually( - agg_func: &crate::sql::planner::AggregateFunction, - expr: &crate::sql::ast::Expr<'_>, - rows: &[Vec], - column_map: &[(String, usize)], -) -> OwnedValue { - use crate::sql::planner::AggregateFunction; - - match agg_func { - AggregateFunction::Sum => { - let mut sum_int: i64 = 0; - let mut sum_float: f64 = 0.0; - let mut has_float = false; - - for row in rows { - let val = eval_expr_for_row(expr, row, column_map); - match val { - OwnedValue::Int(i) => sum_int += i, - OwnedValue::Float(f) => { - sum_float += f; - has_float = true; - } - _ => {} - } - } +struct StreamingAggregateState { + sum_int: i64, + sum_float: f64, + has_float: bool, + count: usize, + min_int: Option, + min_float: Option, + max_int: Option, + max_float: Option, +} - if has_float { - OwnedValue::Float(sum_float + sum_int as f64) - } else { - OwnedValue::Int(sum_int) +impl StreamingAggregateState { + fn new() -> Self { + Self { + sum_int: 0, + sum_float: 0.0, + has_float: false, + count: 0, + min_int: None, + min_float: None, + max_int: None, + max_float: None, + } + } + + fn update(&mut self, value: &OwnedValue) { + match value { + OwnedValue::Int(i) => { + self.sum_int += i; + self.count += 1; + self.min_int = Some(self.min_int.map(|m| m.min(*i)).unwrap_or(*i)); + self.max_int = Some(self.max_int.map(|m| m.max(*i)).unwrap_or(*i)); + } + OwnedValue::Float(f) => { + self.sum_float += f; + self.has_float = true; + self.count += 1; + self.min_float = Some(self.min_float.map(|m| m.min(*f)).unwrap_or(*f)); + self.max_float = Some(self.max_float.map(|m| m.max(*f)).unwrap_or(*f)); } + _ => {} } - AggregateFunction::Avg => { - let mut sum: f64 = 0.0; - let mut count: usize = 0; - - for row in rows { - let val = eval_expr_for_row(expr, row, column_map); - match val { - OwnedValue::Int(i) => { - sum += i as f64; - count += 1; - } - OwnedValue::Float(f) => { - sum += f; - count += 
1; - } - _ => {} + } + + fn finalize(&self, agg_func: &crate::sql::planner::AggregateFunction) -> OwnedValue { + use crate::sql::planner::AggregateFunction; + + match agg_func { + AggregateFunction::Sum => { + if self.has_float { + OwnedValue::Float(self.sum_float + self.sum_int as f64) + } else if self.count > 0 { + OwnedValue::Int(self.sum_int) + } else { + OwnedValue::Null } } - - if count > 0 { - OwnedValue::Float(sum / count as f64) - } else { - OwnedValue::Null + AggregateFunction::Avg => { + if self.count > 0 { + let total = self.sum_float + self.sum_int as f64; + OwnedValue::Float(total / self.count as f64) + } else { + OwnedValue::Null + } } - } - AggregateFunction::Count => OwnedValue::Int(rows.len() as i64), - AggregateFunction::Min => { - let mut min_val: Option = None; - - for row in rows { - let val = eval_expr_for_row(expr, row, column_map); - match (&min_val, &val) { - (None, v) if !matches!(v, OwnedValue::Null) => min_val = Some(val), - (Some(OwnedValue::Int(m)), OwnedValue::Int(i)) if *i < *m => { - min_val = Some(val) - } - (Some(OwnedValue::Float(m)), OwnedValue::Float(f)) if *f < *m => { - min_val = Some(val) + AggregateFunction::Count => OwnedValue::Int(self.count as i64), + AggregateFunction::Min => { + if self.has_float { + match (self.min_int, self.min_float) { + (Some(i), Some(f)) => { + if (i as f64) < f { + OwnedValue::Int(i) + } else { + OwnedValue::Float(f) + } + } + (None, Some(f)) => OwnedValue::Float(f), + (Some(i), None) => OwnedValue::Int(i), + (None, None) => OwnedValue::Null, } - _ => {} + } else { + self.min_int.map(OwnedValue::Int).unwrap_or(OwnedValue::Null) } } - - min_val.unwrap_or(OwnedValue::Null) - } - AggregateFunction::Max => { - let mut max_val: Option = None; - - for row in rows { - let val = eval_expr_for_row(expr, row, column_map); - match (&max_val, &val) { - (None, v) if !matches!(v, OwnedValue::Null) => max_val = Some(val), - (Some(OwnedValue::Int(m)), OwnedValue::Int(i)) if *i > *m => { - max_val = Some(val) - 
} - (Some(OwnedValue::Float(m)), OwnedValue::Float(f)) if *f > *m => { - max_val = Some(val) + AggregateFunction::Max => { + if self.has_float { + match (self.max_int, self.max_float) { + (Some(i), Some(f)) => { + if (i as f64) > f { + OwnedValue::Int(i) + } else { + OwnedValue::Float(f) + } + } + (None, Some(f)) => OwnedValue::Float(f), + (Some(i), None) => OwnedValue::Int(i), + (None, None) => OwnedValue::Null, } - _ => {} + } else { + self.max_int.map(OwnedValue::Int).unwrap_or(OwnedValue::Null) } } - - max_val.unwrap_or(OwnedValue::Null) } } } +fn read_root_page(storage: &crate::storage::MmapStorage) -> Result { + let page = storage.page(0)?; + TableFileHeader::from_bytes(page) + .map(|h| h.root_page()) + .wrap_err("failed to read table file header for root page") +} + +fn read_index_root_page(storage: &crate::storage::MmapStorage) -> Result { + let page = storage.page(0)?; + IndexFileHeader::from_bytes(page) + .map(|h| h.root_page()) + .wrap_err("failed to read index file header for root page") +} + impl Database { fn execute_scalar_subquery_for_update<'a>( subq: &'a crate::sql::ast::SelectStmt<'a>, @@ -476,36 +492,35 @@ impl Database { let planner = Planner::new(catalog, arena); let stmt = crate::sql::ast::Statement::Select(subq); - let subq_plan = planner.create_physical_plan(&stmt)?; + let subq_plan = planner.create_physical_plan(&stmt) + .wrap_err("failed to create physical plan for scalar subquery")?; let plan_source = find_plan_source(subq_plan.root); - fn read_root_page(storage: &crate::storage::MmapStorage) -> Result { - use crate::storage::TableFileHeader; - let page = storage.page(0)?; - Ok(TableFileHeader::from_bytes(page)?.root_page()) - } - - fn read_index_root_page(storage: &crate::storage::MmapStorage) -> Result { - use crate::storage::IndexFileHeader; - let page = storage.page(0)?; - Ok(IndexFileHeader::from_bytes(page)?.root_page()) - } - match plan_source { Some(PlanSource::TableScan(scan)) => { let schema_name = 
scan.schema.unwrap_or(DEFAULT_SCHEMA); let table_name = scan.table; - let table_def = catalog.resolve_table_in_schema(scan.schema, table_name)?; + let table_def = catalog.resolve_table_in_schema(scan.schema, table_name) + .wrap_err_with(|| format!("failed to resolve table '{}' in scalar subquery", table_name))?; let column_types: Vec<_> = table_def.columns().iter().map(|c| c.data_type()).collect(); let columns = table_def.columns(); - let storage_arc = file_manager.table_data(schema_name, table_name)?; + let storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; let storage = storage_arc.read(); - let root_page = read_root_page(&storage)?; + let root_page = read_root_page(&storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + + let column_info: HashMap = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), (i, c.data_type()))) + .collect(); let column_map: Vec<(String, usize)> = table_def .columns() @@ -516,24 +531,23 @@ impl Database { if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { let schema = create_record_schema(columns); - let table_reader = crate::btree::BTreeReader::new(&storage, root_page)?; + let table_reader = crate::btree::BTreeReader::new(&storage, root_page) + .wrap_err_with(|| format!("failed to create BTreeReader for table '{}'", table_name))?; let mut cursor = table_reader.cursor_first()?; - let mut materialized_rows: Vec> = Vec::new(); + let mut agg_state = StreamingAggregateState::new(); while cursor.valid() { let row_data = cursor.value()?; let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); let record = RecordView::new(user_data, &schema)?; - let row_values = OwnedValue::extract_row_from_record(&record, columns)?; - materialized_rows.push(row_values); + let expr_value = eval_expr_for_record_streaming(expr, 
&record, &column_info); + agg_state.update(&expr_value); if !cursor.advance()? { break; } } - let result = - compute_aggregate_manually(agg_func, expr, &materialized_rows, &column_map); - return Ok(result); + return Ok(agg_state.finalize(agg_func)); } let source = StreamingBTreeSource::from_btree_scan_with_projections( @@ -566,49 +580,80 @@ impl Database { let index_name = scan.index_name; let table_def = scan.table_def.ok_or_else(|| { - eyre::eyre!("SecondaryIndexScan missing table_def for {}", table_name) + eyre::eyre!("SecondaryIndexScan missing table_def for table '{}' in scalar subquery", table_name) })?; let columns = table_def.columns(); let schema = create_record_schema(columns); + let column_info: HashMap = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), (i, c.data_type()))) + .collect(); + + let column_map: Vec<(String, usize)> = table_def + .columns() + .iter() + .enumerate() + .map(|(idx, col)| (col.name().to_lowercase(), idx)) + .collect(); + let row_id_suffix_len = if scan.is_unique_index { 0 } else { 8 }; - let row_keys: Vec<[u8; 8]> = { - let index_storage_arc = - file_manager.index_data(schema_name, table_name, index_name)?; - let index_storage = index_storage_arc.read(); + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let mut agg_state = StreamingAggregateState::new(); - let root_page = read_index_root_page(&index_storage)?; - let index_reader = BTreeReader::new(&index_storage, root_page)?; + let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) + .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; + let index_storage = index_storage_arc.read(); + let index_root = read_index_root_page(&index_storage) + .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; + let index_reader = BTreeReader::new(&index_storage, index_root)?; - let mut keys = Vec::new(); + let 
table_storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; + let table_storage = table_storage_arc.read(); + let table_root = read_root_page(&table_storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + let table_reader = BTreeReader::new(&table_storage, table_root)?; + + let process_index_cursor = |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into() + .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; + Ok(Some(row_key)) + } else { + Ok(None) + } + }; match &scan.key_range { Some(ScanRange::PrefixScan { prefix }) => { let mut cursor = index_reader.cursor_seek(prefix)?; - if cursor.valid() { - loop { - let index_key = cursor.key()?; - if !index_key.starts_with(prefix) { - break; - } - - let row_id_bytes = if scan.is_unique_index { - cursor.value()? - } else { - &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] - }; - - if row_id_bytes.len() == 8 { - let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); - keys.push(row_key); - } - - if !cursor.advance()? { - break; + while cursor.valid() { + let index_key = cursor.key()?; + if !index_key.starts_with(prefix) { + break; + } + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); } } + if !cursor.advance()? 
{ + break; + } } } Some(ScanRange::RangeScan { start, end }) => { @@ -617,74 +662,137 @@ impl Database { } else { index_reader.cursor_first()? }; - - if cursor.valid() { - loop { - let index_key = cursor.key()?; - if let Some(end_key) = end { - if index_key >= *end_key { - break; - } - } - - let row_id_bytes = if scan.is_unique_index { - cursor.value()? - } else { - &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] - }; - - if row_id_bytes.len() == 8 { - let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); - keys.push(row_key); - } - - if !cursor.advance()? { + while cursor.valid() { + let index_key = cursor.key()?; + if let Some(end_key) = end { + if index_key >= *end_key { break; } } + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + } + } + if !cursor.advance()? { + break; + } } } Some(ScanRange::FullScan) | None => { let mut cursor = index_reader.cursor_first()?; - if cursor.valid() { - loop { - let index_key = cursor.key()?; - let row_id_bytes = if scan.is_unique_index { - cursor.value()? - } else { - &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] - }; - - if row_id_bytes.len() == 8 { - let row_key: [u8; 8] = row_id_bytes.try_into().unwrap(); - keys.push(row_key); + while cursor.valid() { + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); } + } + if !cursor.advance()? 
{ + break; + } + } + } + } + + return Ok(agg_state.finalize(agg_func)); + } - if !cursor.advance()? { + let row_keys: Vec<[u8; 8]> = { + let index_storage_arc = + file_manager.index_data(schema_name, table_name, index_name) + .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; + let index_storage = index_storage_arc.read(); + + let root_page = read_index_root_page(&index_storage) + .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; + let index_reader = BTreeReader::new(&index_storage, root_page)?; + + let mut keys = Vec::new(); + + let extract_row_key = |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into() + .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; + Ok(Some(row_key)) + } else { + Ok(None) + } + }; + + match &scan.key_range { + Some(ScanRange::PrefixScan { prefix }) => { + let mut cursor = index_reader.cursor_seek(prefix)?; + while cursor.valid() { + let index_key = cursor.key()?; + if !index_key.starts_with(prefix) { + break; + } + if let Some(row_key) = extract_row_key(&mut cursor)? { + keys.push(row_key); + } + if !cursor.advance()? { + break; + } + } + } + Some(ScanRange::RangeScan { start, end }) => { + let mut cursor = if let Some(start_key) = start { + index_reader.cursor_seek(start_key)? + } else { + index_reader.cursor_first()? + }; + while cursor.valid() { + let index_key = cursor.key()?; + if let Some(end_key) = end { + if index_key >= *end_key { break; } } + if let Some(row_key) = extract_row_key(&mut cursor)? { + keys.push(row_key); + } + if !cursor.advance()? 
{ + break; + } + } + } + Some(ScanRange::FullScan) | None => { + let mut cursor = index_reader.cursor_first()?; + while cursor.valid() { + if let Some(row_key) = extract_row_key(&mut cursor)? { + keys.push(row_key); + } + if !cursor.advance()? { + break; + } } } } keys }; - let column_map: Vec<(String, usize)> = table_def - .columns() - .iter() - .enumerate() - .map(|(idx, col)| (col.name().to_lowercase(), idx)) - .collect(); - let mut materialized_rows: Vec> = Vec::with_capacity(row_keys.len()); { - let table_storage_arc = file_manager.table_data(schema_name, table_name)?; + let table_storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; let table_storage = table_storage_arc.read(); - let root_page = read_root_page(&table_storage)?; + let root_page = read_root_page(&table_storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; let table_reader = BTreeReader::new(&table_storage, root_page)?; for row_key in &row_keys { @@ -698,12 +806,6 @@ impl Database { } } - if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { - let result = - compute_aggregate_manually(agg_func, expr, &materialized_rows, &column_map); - return Ok(result); - } - let materialized_source = MaterializedRowSource::new(materialized_rows); let ctx = ExecutionContext::new(arena); From 7eafd8ce34937efcf33e4022b1b8d6eae44073f6 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 21:23:29 +0800 Subject: [PATCH 3/9] fix(update): address second code review round MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix integer overflow risk using saturating_add in StreamingAggregateState - Fix mixed int/float comparison with proper compare_int_float function that handles values outside f64's exact range (±2^53) - Fix allocation violations in SecondaryIndexScan non-aggregate path by 
returning first matching row directly instead of materializing all rows - Change ScalarSubqueryResults from SmallVec to HashMap for O(1) lookup - Add documentation for new functions per STYLE.md requirements - Remove unused imports and variables Co-Authored-By: Claude Opus 4.5 --- src/database/convert.rs | 7 +- src/database/database.rs | 4 +- src/database/dml/update.rs | 253 ++++++++++++++++++------------------- src/sql/context.rs | 4 +- src/sql/predicate.rs | 5 +- 5 files changed, 135 insertions(+), 138 deletions(-) diff --git a/src/database/convert.rs b/src/database/convert.rs index b5011c7..ecb52af 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -446,10 +446,9 @@ impl Database { Expr::Subquery(subq) => { let key = std::ptr::from_ref(*subq) as usize; scalar_subquery_results - .iter() - .find(|(k, _)| *k == key) - .map(|(_, v)| v.clone()) - .ok_or_else(|| eyre::eyre!("scalar subquery result not found in pre-computed results")) + .get(&key) + .cloned() + .ok_or_else(|| eyre::eyre!("scalar subquery result not found for key 0x{:x}", key)) } Expr::BinaryOp { left, op, right } => { let left_val = Self::eval_expr_with_params_and_subqueries( diff --git a/src/database/database.rs b/src/database/database.rs index ba4598f..8d70cd9 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -966,9 +966,9 @@ impl Database { for subq in subqueries { let key = std::ptr::from_ref(subq) as usize; - if !scalar_subquery_results.iter().any(|(k, _)| *k == key) { + if !scalar_subquery_results.contains_key(&key) { let result = execute_scalar_subquery(subq, catalog, file_manager, &arena)?; - scalar_subquery_results.push((key, result)); + scalar_subquery_results.insert(key, result); } } } diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 23714a5..bc67730 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -175,6 +175,7 @@ fn expr_contains_column_ref(expr: &crate::sql::ast::Expr) -> bool { } } +/// 
Recursively collects scalar subqueries from an expression tree. fn collect_scalar_subqueries_from_expr<'a>( expr: &'a crate::sql::ast::Expr<'a>, subqueries: &mut SmallVec<[&'a crate::sql::ast::SelectStmt<'a>; 4]>, @@ -237,6 +238,7 @@ fn collect_scalar_subqueries_from_expr<'a>( } } +/// Finds a single aggregate function with expression in a query plan. fn find_expression_aggregate<'a>( op: &'a crate::sql::planner::PhysicalOperator<'a>, ) -> Option<( @@ -274,6 +276,7 @@ fn find_expression_aggregate<'a>( } } +/// Evaluates a binary operation on two OwnedValues. fn eval_binary_op(left: &OwnedValue, op: &crate::sql::ast::BinaryOperator, right: &OwnedValue) -> OwnedValue { use crate::sql::ast::BinaryOperator; match op { @@ -309,6 +312,7 @@ fn eval_binary_op(left: &OwnedValue, op: &crate::sql::ast::BinaryOperator, right } } +/// Evaluates an expression against a record for streaming aggregation. fn eval_expr_for_record_streaming( expr: &crate::sql::ast::Expr<'_>, record: &RecordView<'_>, @@ -354,6 +358,48 @@ fn eval_expr_for_record_streaming( } } +/// Compares i64 and f64 without precision loss for large integers. 
+fn compare_int_float(i: i64, f: f64) -> std::cmp::Ordering { + use std::cmp::Ordering; + + if f.is_nan() { + return Ordering::Less; + } + if f == f64::INFINITY { + return Ordering::Less; + } + if f == f64::NEG_INFINITY { + return Ordering::Greater; + } + + const MAX_EXACT: i64 = 1i64 << 53; + const MIN_EXACT: i64 = -(1i64 << 53); + + if (MIN_EXACT..=MAX_EXACT).contains(&i) { + let i_as_f64 = i as f64; + i_as_f64.partial_cmp(&f).unwrap_or(Ordering::Equal) + } else if f >= (i64::MAX as f64) { + Ordering::Less + } else if f <= (i64::MIN as f64) { + Ordering::Greater + } else { + let f_truncated = f as i64; + match i.cmp(&f_truncated) { + Ordering::Equal => { + let f_frac = f - (f_truncated as f64); + if f_frac > 0.0 { + Ordering::Less + } else if f_frac < 0.0 { + Ordering::Greater + } else { + Ordering::Equal + } + } + other => other, + } + } +} + struct StreamingAggregateState { sum_int: i64, sum_float: f64, @@ -382,15 +428,15 @@ impl StreamingAggregateState { fn update(&mut self, value: &OwnedValue) { match value { OwnedValue::Int(i) => { - self.sum_int += i; - self.count += 1; + self.sum_int = self.sum_int.saturating_add(*i); + self.count = self.count.saturating_add(1); self.min_int = Some(self.min_int.map(|m| m.min(*i)).unwrap_or(*i)); self.max_int = Some(self.max_int.map(|m| m.max(*i)).unwrap_or(*i)); } OwnedValue::Float(f) => { self.sum_float += f; self.has_float = true; - self.count += 1; + self.count = self.count.saturating_add(1); self.min_float = Some(self.min_float.map(|m| m.min(*f)).unwrap_or(*f)); self.max_float = Some(self.max_float.map(|m| m.max(*f)).unwrap_or(*f)); } @@ -424,7 +470,7 @@ impl StreamingAggregateState { if self.has_float { match (self.min_int, self.min_float) { (Some(i), Some(f)) => { - if (i as f64) < f { + if compare_int_float(i, f).is_lt() { OwnedValue::Int(i) } else { OwnedValue::Float(f) @@ -442,7 +488,7 @@ impl StreamingAggregateState { if self.has_float { match (self.max_int, self.max_float) { (Some(i), Some(f)) => { - if (i as 
f64) > f { + if compare_int_float(i, f).is_gt() { OwnedValue::Int(i) } else { OwnedValue::Float(f) @@ -475,6 +521,7 @@ fn read_index_root_page(storage: &crate::storage::MmapStorage) -> Result { } impl Database { + /// Executes a scalar subquery and returns its single result value. fn execute_scalar_subquery_for_update<'a>( subq: &'a crate::sql::ast::SelectStmt<'a>, catalog: &crate::schema::catalog::Catalog, @@ -486,7 +533,7 @@ impl Database { use crate::records::RecordView; use crate::sql::builder::ExecutorBuilder; use crate::sql::context::ExecutionContext; - use crate::sql::executor::{Executor, MaterializedRowSource, StreamingBTreeSource}; + use crate::sql::executor::{Executor, StreamingBTreeSource}; use crate::sql::planner::{Planner, ScanRange}; use crate::types::create_record_schema; @@ -593,13 +640,6 @@ impl Database { .map(|(i, c)| (c.name().to_lowercase(), (i, c.data_type()))) .collect(); - let column_map: Vec<(String, usize)> = table_def - .columns() - .iter() - .enumerate() - .map(|(idx, col)| (col.name().to_lowercase(), idx)) - .collect(); - let row_id_suffix_len = if scan.is_unique_index { 0 } else { 8 }; if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { @@ -703,131 +743,90 @@ impl Database { return Ok(agg_state.finalize(agg_func)); } - let row_keys: Vec<[u8; 8]> = { - let index_storage_arc = - file_manager.index_data(schema_name, table_name, index_name) - .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; - let index_storage = index_storage_arc.read(); - - let root_page = read_index_root_page(&index_storage) - .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; - let index_reader = BTreeReader::new(&index_storage, root_page)?; - - let mut keys = Vec::new(); + let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) + .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; + let 
index_storage = index_storage_arc.read(); + let index_root = read_index_root_page(&index_storage) + .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; + let index_reader = BTreeReader::new(&index_storage, index_root)?; - let extract_row_key = |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { - let index_key = cursor.key()?; - let row_id_bytes = if scan.is_unique_index { - cursor.value()? - } else { - &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] - }; - if row_id_bytes.len() == 8 { - let row_key: [u8; 8] = row_id_bytes.try_into() - .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; - Ok(Some(row_key)) - } else { - Ok(None) - } - }; - - match &scan.key_range { - Some(ScanRange::PrefixScan { prefix }) => { - let mut cursor = index_reader.cursor_seek(prefix)?; - while cursor.valid() { - let index_key = cursor.key()?; - if !index_key.starts_with(prefix) { - break; - } - if let Some(row_key) = extract_row_key(&mut cursor)? { - keys.push(row_key); - } - if !cursor.advance()? { - break; + let table_storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; + let table_storage = table_storage_arc.read(); + let table_root = read_root_page(&table_storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + let table_reader = BTreeReader::new(&table_storage, table_root)?; + + let extract_and_lookup_first_row = + |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { + while cursor.valid() { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] 
+ }; + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into() + .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let first_col_type = columns.first().map(|c| c.data_type()) + .ok_or_else(|| eyre::eyre!("table '{}' has no columns", table_name))?; + return Ok(Some(OwnedValue::from_record_column(&record, 0, first_col_type) + .unwrap_or(OwnedValue::Null))); } } + if !cursor.advance()? { + break; + } } - Some(ScanRange::RangeScan { start, end }) => { - let mut cursor = if let Some(start_key) = start { - index_reader.cursor_seek(start_key)? + Ok(None) + }; + + let result = match &scan.key_range { + Some(ScanRange::PrefixScan { prefix }) => { + let mut cursor = index_reader.cursor_seek(prefix)?; + if cursor.valid() { + let index_key = cursor.key()?; + if index_key.starts_with(prefix) { + extract_and_lookup_first_row(&mut cursor)? } else { - index_reader.cursor_first()? - }; - while cursor.valid() { - let index_key = cursor.key()?; - if let Some(end_key) = end { - if index_key >= *end_key { - break; - } - } - if let Some(row_key) = extract_row_key(&mut cursor)? { - keys.push(row_key); - } - if !cursor.advance()? { - break; - } + None } + } else { + None } - Some(ScanRange::FullScan) | None => { - let mut cursor = index_reader.cursor_first()?; - while cursor.valid() { - if let Some(row_key) = extract_row_key(&mut cursor)? { - keys.push(row_key); - } - if !cursor.advance()? { - break; + } + Some(ScanRange::RangeScan { start, end }) => { + let mut cursor = if let Some(start_key) = start { + index_reader.cursor_seek(start_key)? + } else { + index_reader.cursor_first()? 
+ }; + if cursor.valid() { + if let Some(end_key) = end { + let index_key = cursor.key()?; + if index_key < *end_key { + extract_and_lookup_first_row(&mut cursor)? + } else { + None } + } else { + extract_and_lookup_first_row(&mut cursor)? } + } else { + None } } - keys - }; - - let mut materialized_rows: Vec> = - Vec::with_capacity(row_keys.len()); - - { - let table_storage_arc = file_manager.table_data(schema_name, table_name) - .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; - let table_storage = table_storage_arc.read(); - - let root_page = read_root_page(&table_storage) - .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; - let table_reader = BTreeReader::new(&table_storage, root_page)?; - - for row_key in &row_keys { - if let Some(row_data) = table_reader.get(row_key)? { - let user_data = - crate::database::dml::mvcc_helpers::get_user_data(row_data); - let record = RecordView::new(user_data, &schema)?; - let row_values = OwnedValue::extract_row_from_record(&record, columns)?; - materialized_rows.push(row_values); - } + Some(ScanRange::FullScan) | None => { + let mut cursor = index_reader.cursor_first()?; + extract_and_lookup_first_row(&mut cursor)? } - } - - let materialized_source = MaterializedRowSource::new(materialized_rows); - - let ctx = ExecutionContext::new(arena); - let builder = ExecutorBuilder::new(&ctx); - - let mut executor = builder.build_with_source_and_column_map( - &subq_plan, - materialized_source, - &column_map, - )?; - - executor.open()?; - let result = if let Some(row) = executor.next()? 
{ - row.values - .first() - .map(OwnedValue::from) - .unwrap_or(OwnedValue::Null) - } else { - OwnedValue::Null }; - executor.close()?; - Ok(result) + + Ok(result.unwrap_or(OwnedValue::Null)) } _ => Ok(OwnedValue::Null), } @@ -1025,10 +1024,10 @@ impl Database { let fm = fm_guard.as_mut().unwrap(); for subq in subqueries { let key = std::ptr::from_ref(subq) as usize; - if !scalar_subquery_results.iter().any(|(k, _)| *k == key) { + if !scalar_subquery_results.contains_key(&key) { let result = Self::execute_scalar_subquery_for_update(subq, catalog, fm, arena)?; - scalar_subquery_results.push((key, result)); + scalar_subquery_results.insert(key, result); } } } diff --git a/src/sql/context.rs b/src/sql/context.rs index d359fb0..d544fa2 100644 --- a/src/sql/context.rs +++ b/src/sql/context.rs @@ -1,10 +1,10 @@ use bumpalo::Bump; use crate::memory::MemoryBudget; use crate::types::OwnedValue; -use smallvec::SmallVec; +use hashbrown::HashMap; use std::sync::Arc; -pub type ScalarSubqueryResults = SmallVec<[(usize, OwnedValue); 4]>; +pub type ScalarSubqueryResults = HashMap; pub struct ExecutionContext<'a> { pub arena: &'a Bump, diff --git a/src/sql/predicate.rs b/src/sql/predicate.rs index f3d654e..fcee07e 100644 --- a/src/sql/predicate.rs +++ b/src/sql/predicate.rs @@ -286,9 +286,8 @@ impl<'a> CompiledPredicate<'a> { Expr::Subquery(subq) => { let key = std::ptr::from_ref(*subq) as usize; self.scalar_subquery_results - .iter() - .find(|(k, _)| *k == key) - .map(|(_, v)| self.owned_value_to_value(v)) + .get(&key) + .map(|v| self.owned_value_to_value(v)) } _ => None, } From 79b5282ca6f1ea946a0b45299d56c1aabd124a38 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 22:06:05 +0800 Subject: [PATCH 4/9] docs(update): add module documentation and safety comments - Add comprehensive module documentation for scalar subquery support (~35 lines) - Document streaming aggregate optimization strategy and limitations - Add SAFETY comment explaining AST node address 
stability for HashMap keys - Add edge case tests: SUM, AVG, and multiple subqueries in SET clause Co-Authored-By: Claude Opus 4.5 --- src/database/dml/update.rs | 39 ++++++++++++ tests/regression_smoke_test.rs | 105 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index bc67730..e6e8126 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -82,6 +82,41 @@ //! //! UPDATE acquires write lock on file_manager. Transaction write entries //! include undo data (old row value) for rollback support. +//! +//! ## Scalar Subquery Support +//! +//! UPDATE SET clauses support scalar subqueries in value expressions: +//! ```sql +//! UPDATE orders SET total = (SELECT SUM(amount) FROM order_items WHERE order_id = 1) +//! ``` +//! +//! ### Strategy +//! +//! 1. **Collection**: All scalar subqueries are collected from SET expressions before +//! the UPDATE loop begins +//! 2. **Pre-computation**: Subquery results are computed once and stored in a HashMap +//! keyed by AST node address (stable within the function scope) +//! 3. **Substitution**: During expression evaluation, subquery nodes are replaced with +//! their pre-computed results +//! +//! ### Streaming Aggregate Optimization +//! +//! For simple aggregate subqueries (SUM, AVG, MIN, MAX, COUNT over a single expression), +//! we bypass the full executor and stream directly through BTree cursors: +//! - Creates `StreamingAggregateState` to accumulate results +//! - Iterates through table/index cursors evaluating expressions per-row +//! - Avoids materializing all rows into memory +//! +//! This optimization applies when the subquery plan has: +//! - A single aggregate function +//! - An expression to aggregate (e.g., `SUM(quantity * price)`) +//! - No complex operators like joins or sorts +//! +//! ### Limitations +//! +//! - Subqueries are not correlated (cannot reference outer UPDATE row) +//! 
- Streaming optimization bypasses WHERE filters (full table scan for aggregates) +//! - Multiple rows from non-aggregate subqueries return only the first row use crate::btree::BTree; use crate::database::dml::mvcc_helpers::{get_user_data, wrap_record_for_update}; @@ -1023,6 +1058,10 @@ impl Database { let mut fm_guard = self.shared.file_manager.write(); let fm = fm_guard.as_mut().unwrap(); for subq in subqueries { + // SAFETY: Using AST node address as HashMap key is safe because: + // 1. The AST is arena-allocated and lives for the duration of this function + // 2. We only read from scalar_subquery_results within this same scope + // 3. The same subquery AST node will have the same address when looked up let key = std::ptr::from_ref(subq) as usize; if !scalar_subquery_results.contains_key(&key) { let result = diff --git a/tests/regression_smoke_test.rs b/tests/regression_smoke_test.rs index 86303c9..885d9fc 100644 --- a/tests/regression_smoke_test.rs +++ b/tests/regression_smoke_test.rs @@ -1377,6 +1377,111 @@ mod real_world_scenario { } } + #[test] + fn update_with_subquery_sum_computes_total_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, total REAL)") + .unwrap(); + db.execute("CREATE TABLE order_items (order_id INTEGER, amount REAL)") + .unwrap(); + + db.execute("INSERT INTO orders (id, total) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 25.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 75.0)") + .unwrap(); + + db.execute( + "UPDATE orders SET total = (SELECT SUM(amount) FROM order_items) WHERE id = 1", + ) + .unwrap(); + + let rows = db.query("SELECT total FROM orders WHERE id = 1").unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 100.0).abs() < 0.01, "Expected 100.0, got {}", v), + other => panic!("Expected Float, got {:?}", other), + } + } + + #[test] + fn 
update_with_subquery_avg_computes_average_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE stats (id INTEGER PRIMARY KEY, avg_value REAL)") + .unwrap(); + db.execute("CREATE TABLE values_table (value REAL)") + .unwrap(); + + db.execute("INSERT INTO stats (id, avg_value) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (10.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (20.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (30.0)") + .unwrap(); + + db.execute( + "UPDATE stats SET avg_value = (SELECT AVG(value) FROM values_table) WHERE id = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT avg_value FROM stats WHERE id = 1") + .unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 20.0).abs() < 0.01, "Expected 20.0, got {}", v), + other => panic!("Expected Float, got {:?}", other), + } + } + + #[test] + fn update_with_multiple_subqueries_in_set_clause() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE summary (id INTEGER PRIMARY KEY, min_val REAL, max_val REAL, sum_val REAL)") + .unwrap(); + db.execute("CREATE TABLE data_points (value REAL)") + .unwrap(); + + db.execute("INSERT INTO summary (id, min_val, max_val, sum_val) VALUES (1, 0, 0, 0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (10.0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (20.0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (30.0)") + .unwrap(); + + db.execute( + "UPDATE summary SET + min_val = (SELECT MIN(value) FROM data_points), + max_val = (SELECT MAX(value) FROM data_points), + sum_val = (SELECT SUM(value) FROM data_points) + WHERE id = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT min_val, max_val, sum_val FROM summary WHERE id = 1") + .unwrap(); + + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 10.0).abs() < 0.01, "Expected min 10.0, got 
{}", v), + other => panic!("Expected Float for min_val, got {:?}", other), + } + match &rows[0].values[1] { + OwnedValue::Float(v) => assert!((*v - 30.0).abs() < 0.01, "Expected max 30.0, got {}", v), + other => panic!("Expected Float for max_val, got {:?}", other), + } + match &rows[0].values[2] { + OwnedValue::Float(v) => assert!((*v - 60.0).abs() < 0.01, "Expected sum 60.0, got {}", v), + other => panic!("Expected Float for sum_val, got {:?}", other), + } + } + #[test] fn blog_platform_scenario() { let (db, _dir) = create_test_db(); From 604220b4175b1bda47fefefe81827b17fc8493c8 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 22:33:52 +0800 Subject: [PATCH 5/9] fix(update): disable streaming optimization when filters present CRITICAL BUG FIX: Streaming aggregate optimization was bypassing WHERE filters in subqueries, causing incorrect results. For example: UPDATE orders SET total = (SELECT SUM(amount) FROM items WHERE order_id=1) would sum ALL items instead of just order_id=1. Fix: - Add `plan_has_filter()` helper to detect FilterExec in query plan - Disable streaming optimization when filter is present (TableScan or SecondaryIndexScan paths), falling back to full executor which applies filters correctly - Check both `plan_has_filter(subq_plan.root)` and `scan.post_scan_filter` Other changes: - Simplify min/max tracking pattern using `map_or` instead of `map().unwrap_or()` - Add test `update_with_subquery_where_filter_applied_correctly` - Update documentation to reflect correct behavior Co-Authored-By: Claude Opus 4.5 --- src/database/dml/update.rs | 71 +++++++++++++++++++++++----------- tests/regression_smoke_test.rs | 40 +++++++++++++++++++ 2 files changed, 88 insertions(+), 23 deletions(-) diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index e6e8126..bd6f953 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -115,7 +115,8 @@ //! ### Limitations //! //! 
- Subqueries are not correlated (cannot reference outer UPDATE row) -//! - Streaming optimization bypasses WHERE filters (full table scan for aggregates) +//! - Streaming optimization only applies when no WHERE filters are present +//! (filtered queries fall back to the full executor) //! - Multiple rows from non-aggregate subqueries return only the first row use crate::btree::BTree; @@ -465,15 +466,15 @@ impl StreamingAggregateState { OwnedValue::Int(i) => { self.sum_int = self.sum_int.saturating_add(*i); self.count = self.count.saturating_add(1); - self.min_int = Some(self.min_int.map(|m| m.min(*i)).unwrap_or(*i)); - self.max_int = Some(self.max_int.map(|m| m.max(*i)).unwrap_or(*i)); + self.min_int = Some(self.min_int.map_or(*i, |m| m.min(*i))); + self.max_int = Some(self.max_int.map_or(*i, |m| m.max(*i))); } OwnedValue::Float(f) => { self.sum_float += f; self.has_float = true; self.count = self.count.saturating_add(1); - self.min_float = Some(self.min_float.map(|m| m.min(*f)).unwrap_or(*f)); - self.max_float = Some(self.max_float.map(|m| m.max(*f)).unwrap_or(*f)); + self.min_float = Some(self.min_float.map_or(*f, |m| m.min(*f))); + self.max_float = Some(self.max_float.map_or(*f, |m| m.max(*f))); } _ => {} } @@ -541,6 +542,24 @@ impl StreamingAggregateState { } } +fn plan_has_filter(op: &crate::sql::planner::physical::PhysicalOperator<'_>) -> bool { + use crate::sql::planner::physical::PhysicalOperator; + match op { + PhysicalOperator::FilterExec(_) => true, + PhysicalOperator::ProjectExec(p) => plan_has_filter(p.input), + PhysicalOperator::LimitExec(l) => plan_has_filter(l.input), + PhysicalOperator::SortExec(s) => plan_has_filter(s.input), + PhysicalOperator::TopKExec(t) => plan_has_filter(t.input), + PhysicalOperator::HashAggregate(a) => plan_has_filter(a.input), + PhysicalOperator::SortedAggregate(a) => plan_has_filter(a.input), + PhysicalOperator::WindowExec(w) => plan_has_filter(w.input), + PhysicalOperator::ScalarSubqueryExec(s) => 
plan_has_filter(s.subquery), + PhysicalOperator::ExistsSubqueryExec(s) => plan_has_filter(s.subquery), + PhysicalOperator::InListSubqueryExec(s) => plan_has_filter(s.subquery), + _ => false, + } +} + fn read_root_page(storage: &crate::storage::MmapStorage) -> Result { let page = storage.page(0)?; TableFileHeader::from_bytes(page) @@ -611,25 +630,28 @@ impl Database { .map(|(i, c)| (c.name().to_lowercase(), i)) .collect(); - if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { - let schema = create_record_schema(columns); - let table_reader = crate::btree::BTreeReader::new(&storage, root_page) - .wrap_err_with(|| format!("failed to create BTreeReader for table '{}'", table_name))?; - let mut cursor = table_reader.cursor_first()?; - let mut agg_state = StreamingAggregateState::new(); + let has_filter = plan_has_filter(subq_plan.root) || scan.post_scan_filter.is_some(); + if !has_filter { + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let schema = create_record_schema(columns); + let table_reader = crate::btree::BTreeReader::new(&storage, root_page) + .wrap_err_with(|| format!("failed to create BTreeReader for table '{}'", table_name))?; + let mut cursor = table_reader.cursor_first()?; + let mut agg_state = StreamingAggregateState::new(); - while cursor.valid() { - let row_data = cursor.value()?; - let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); - let record = RecordView::new(user_data, &schema)?; - let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); - agg_state.update(&expr_value); - if !cursor.advance()? { - break; + while cursor.valid() { + let row_data = cursor.value()?; + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + if !cursor.advance()? 
{ + break; + } } - } - return Ok(agg_state.finalize(agg_func)); + return Ok(agg_state.finalize(agg_func)); + } } let source = StreamingBTreeSource::from_btree_scan_with_projections( @@ -677,7 +699,9 @@ impl Database { let row_id_suffix_len = if scan.is_unique_index { 0 } else { 8 }; - if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let has_filter = plan_has_filter(subq_plan.root); + if !has_filter { + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { let mut agg_state = StreamingAggregateState::new(); let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) @@ -775,7 +799,8 @@ impl Database { } } - return Ok(agg_state.finalize(agg_func)); + return Ok(agg_state.finalize(agg_func)); + } } let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) diff --git a/tests/regression_smoke_test.rs b/tests/regression_smoke_test.rs index 885d9fc..4120f64 100644 --- a/tests/regression_smoke_test.rs +++ b/tests/regression_smoke_test.rs @@ -1482,6 +1482,46 @@ mod real_world_scenario { } } + #[test] + fn update_with_subquery_where_filter_applied_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, total REAL)") + .unwrap(); + db.execute("CREATE TABLE order_items (order_id INTEGER, amount REAL)") + .unwrap(); + + db.execute("INSERT INTO orders (id, total) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO orders (id, total) VALUES (2, 0.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 10.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 15.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (2, 100.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (2, 200.0)") + .unwrap(); + + db.execute( + "UPDATE orders SET total = (SELECT SUM(amount) FROM order_items WHERE order_id = 1) 
WHERE id = 1", + ) + .unwrap(); + + let rows = db.query("SELECT total FROM orders WHERE id = 1").unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => { + assert!( + (*v - 25.0).abs() < 0.01, + "Expected 25.0 (sum of order_id=1 items only), got {}. WHERE filter may not be applied.", + v + ); + } + other => panic!("Expected Float, got {:?}", other), + } + } + #[test] fn blog_platform_scenario() { let (db, _dir) = create_test_db(); From 2d4ce4b34349faddb15c0a5bf0a7cb100eb72ccb Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 22:50:03 +0800 Subject: [PATCH 6/9] perf(convert): return Cow from eval_expr to avoid allocations Address zero-allocation requirement per MEMORY.md: - Change eval_expr_with_params_and_subqueries to return Cow<'a, OwnedValue> - For params lookup: return Cow::Borrowed(&params[idx]) - no allocation - For subquery lookup: return Cow::Borrowed(result) - no allocation - For computed values (binary ops, literals): return Cow::Owned(value) - Caller uses .into_owned() only when ownership is actually needed Also fix error handling: - Replace unwrap_or(OwnedValue::Null) with proper error propagation - Use wrap_err_with for better error context on from_record_column Co-Authored-By: Claude Opus 4.5 --- src/database/convert.rs | 57 +++++++++++++++++++------------------- src/database/dml/update.rs | 19 ++++++++----- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/database/convert.rs b/src/database/convert.rs index ecb52af..b2f60be 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -104,6 +104,7 @@ use crate::parsing::{ }; use crate::types::{DataType, OwnedValue, Value}; use eyre::{bail, Result, WrapErr}; +use std::borrow::Cow; use super::Database; @@ -404,13 +405,13 @@ impl Database { Self::eval_literal_with_type(expr, target_type) } - pub(crate) fn eval_expr_with_params_and_subqueries( + pub(crate) fn eval_expr_with_params_and_subqueries<'a>( expr: &crate::sql::ast::Expr<'_>, target_type:
Option<&crate::records::types::DataType>, - params: Option<&[OwnedValue]>, + params: Option<&'a [OwnedValue]>, param_idx: &mut usize, - scalar_subquery_results: &crate::sql::context::ScalarSubqueryResults, - ) -> Result<OwnedValue> { + scalar_subquery_results: &'a crate::sql::context::ScalarSubqueryResults, + ) -> Result<Cow<'a, OwnedValue>> { use crate::sql::ast::{Expr, ParameterRef}; match expr { @@ -438,7 +439,7 @@ impl Database { ); } - Ok(params[idx].clone()) + Ok(Cow::Borrowed(&params[idx])) } else { bail!("parameter placeholder found but no parameters were bound") } @@ -447,7 +448,7 @@ impl Database { let key = std::ptr::from_ref(*subq) as usize; scalar_subquery_results .get(&key) - .cloned() + .map(Cow::Borrowed) .ok_or_else(|| eyre::eyre!("scalar subquery result not found for key 0x{:x}", key)) } Expr::BinaryOp { left, op, right } => { @@ -468,58 +469,58 @@ impl Database { use crate::sql::ast::BinaryOperator; match op { - BinaryOperator::Plus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a + b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a + b)), + BinaryOperator::Plus => match (left_val.as_ref(), right_val.as_ref()) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a + b))), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a + b))), (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 + b)) + Ok(Cow::Owned(OwnedValue::Float(*a as f64 + b))) } (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a + *b as f64)) + Ok(Cow::Owned(OwnedValue::Float(a + *b as f64))) } _ => bail!("unsupported types for addition in UPDATE SET"), }, - BinaryOperator::Minus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a - b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a - b)), + BinaryOperator::Minus => match (left_val.as_ref(), right_val.as_ref())
{ + (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a - b))), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a - b))), (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 - b)) + Ok(Cow::Owned(OwnedValue::Float(*a as f64 - b))) } (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a - *b as f64)) + Ok(Cow::Owned(OwnedValue::Float(a - *b as f64))) } _ => bail!("unsupported types for subtraction in UPDATE SET"), }, - BinaryOperator::Multiply => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a * b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(OwnedValue::Float(a * b)), + BinaryOperator::Multiply => match (left_val.as_ref(), right_val.as_ref()) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a * b))), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a * b))), (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 * b)) + Ok(Cow::Owned(OwnedValue::Float(*a as f64 * b))) } (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a * *b as f64)) + Ok(Cow::Owned(OwnedValue::Float(a * *b as f64))) } _ => bail!("unsupported types for multiplication in UPDATE SET"), }, - BinaryOperator::Divide => match (&left_val, &right_val) { + BinaryOperator::Divide => match (left_val.as_ref(), right_val.as_ref()) { (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Int(a / b)) + Ok(Cow::Owned(OwnedValue::Int(a / b))) } (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(a / b)) + Ok(Cow::Owned(OwnedValue::Float(a / b))) } (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(*a as f64 / b)) + Ok(Cow::Owned(OwnedValue::Float(*a as f64 / b))) } (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Float(a / *b as 
f64)) + Ok(Cow::Owned(OwnedValue::Float(a / *b as f64))) } _ => bail!("division by zero or unsupported types"), }, - _ => Self::eval_literal_with_type(expr, target_type), + _ => Self::eval_literal_with_type(expr, target_type).map(Cow::Owned), } } - _ => Self::eval_literal_with_type(expr, target_type), + _ => Self::eval_literal_with_type(expr, target_type).map(Cow::Owned), } } diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index bd6f953..6417a7e 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -668,10 +668,10 @@ impl Database { executor.open()?; let result = if let Some(row) = executor.next()? { - row.values - .first() - .map(OwnedValue::from) - .unwrap_or(OwnedValue::Null) + if row.values.is_empty() { + bail!("scalar subquery returned row with no columns for table '{}'", table_name); + } + OwnedValue::from(row.values.first().unwrap()) } else { OwnedValue::Null }; @@ -834,8 +834,12 @@ impl Database { let record = RecordView::new(user_data, &schema)?; let first_col_type = columns.first().map(|c| c.data_type()) .ok_or_else(|| eyre::eyre!("table '{}' has no columns", table_name))?; - return Ok(Some(OwnedValue::from_record_column(&record, 0, first_col_type) - .unwrap_or(OwnedValue::Null))); + let col_value = OwnedValue::from_record_column(&record, 0, first_col_type) + .wrap_err_with(|| eyre::eyre!( + "failed to read column 0 from record in table '{}' via index '{}'", + table_name, index_name + ))?; + return Ok(Some(col_value)); } } if !cursor.advance()? { @@ -1228,7 +1232,8 @@ impl Database { Some(params), &mut param_idx, &scalar_subquery_results, - )?; + )? 
+ .into_owned(); precomputed_assignments.push((*col_idx, val)); } } From edd064e381ecf723c97f3fe6c11db36a548e7ff2 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 23:28:32 +0800 Subject: [PATCH 7/9] fix(update): address code review - extract arithmetic ops and fix SecondaryIndexScan - Add ArithmeticOp enum and eval_arithmetic to OwnedValue for shared arithmetic logic - Refactor eval_binary_op in update.rs to use shared eval_arithmetic - Refactor eval_expr_with_row in update.rs to use shared eval_arithmetic - Refactor convert.rs BinaryOp handling to use shared eval_arithmetic - Add find_any_aggregate to handle simple column aggregates (e.g., SUM(price)) - Change SecondaryIndexScan streaming to use find_any_aggregate instead of find_expression_aggregate - Add test for SecondaryIndexScan WHERE filter verification Co-Authored-By: Claude Opus 4.5 --- src/database/convert.rs | 70 ++++---------- src/database/dml/update.rs | 162 ++++++++++++++------------------- src/types/mod.rs | 2 +- src/types/owned_value.rs | 45 +++++++++ tests/regression_smoke_test.rs | 45 +++++++++ 5 files changed, 177 insertions(+), 147 deletions(-) diff --git a/src/database/convert.rs b/src/database/convert.rs index b2f60be..a0c9e1b 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -102,7 +102,7 @@ use crate::parsing::{ parse_binary_blob, parse_date, parse_hex_blob, parse_interval, parse_time, parse_timestamp, parse_uuid, parse_vector, }; -use crate::types::{DataType, OwnedValue, Value}; +use crate::types::{ArithmeticOp, DataType, OwnedValue, Value}; use eyre::{bail, Result, WrapErr}; use std::borrow::Cow; @@ -468,56 +468,24 @@ impl Database { )?; use crate::sql::ast::BinaryOperator; - match op { - BinaryOperator::Plus => match (left_val.as_ref(), right_val.as_ref()) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a + b))), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a + b))), - 
(OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(Cow::Owned(OwnedValue::Float(*a as f64 + b))) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(Cow::Owned(OwnedValue::Float(a + *b as f64))) - } - _ => bail!("unsupported types for addition in UPDATE SET"), - }, - BinaryOperator::Minus => match (left_val.as_ref(), right_val.as_ref()) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a - b))), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a - b))), - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(Cow::Owned(OwnedValue::Float(*a as f64 - b))) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(Cow::Owned(OwnedValue::Float(a - *b as f64))) - } - _ => bail!("unsupported types for subtraction in UPDATE SET"), - }, - BinaryOperator::Multiply => match (left_val.as_ref(), right_val.as_ref()) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(Cow::Owned(OwnedValue::Int(a * b))), - (OwnedValue::Float(a), OwnedValue::Float(b)) => Ok(Cow::Owned(OwnedValue::Float(a * b))), - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(Cow::Owned(OwnedValue::Float(*a as f64 * b))) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(Cow::Owned(OwnedValue::Float(a * *b as f64))) - } - _ => bail!("unsupported types for multiplication in UPDATE SET"), - }, - BinaryOperator::Divide => match (left_val.as_ref(), right_val.as_ref()) { - (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(Cow::Owned(OwnedValue::Int(a / b))) - } - (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(Cow::Owned(OwnedValue::Float(a / b))) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(Cow::Owned(OwnedValue::Float(*a as f64 / b))) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(Cow::Owned(OwnedValue::Float(a / *b as f64))) - } - _ => bail!("division by zero or unsupported types"), - }, - _ => Self::eval_literal_with_type(expr, 
target_type).map(Cow::Owned), + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + if let Some(aop) = arith_op { + OwnedValue::eval_arithmetic(left_val.as_ref(), aop, right_val.as_ref()) + .map(Cow::Owned) + .ok_or_else(|| { + eyre::eyre!( + "unsupported types or division by zero for {:?} in UPDATE SET", + aop + ) + }) + } else { + Self::eval_literal_with_type(expr, target_type).map(Cow::Owned) } } _ => Self::eval_literal_with_type(expr, target_type).map(Cow::Owned), diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 6417a7e..62f05cc 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -132,7 +132,7 @@ use crate::sql::decoder::RecordDecoder; use crate::sql::executor::ExecutorRow; use crate::sql::predicate::CompiledPredicate; use crate::storage::{IndexFileHeader, TableFileHeader, DEFAULT_SCHEMA}; -use crate::types::{create_record_schema, OwnedValue, Value}; +use crate::types::{create_record_schema, ArithmeticOp, OwnedValue, Value}; use bumpalo::Bump; use eyre::{bail, Result, WrapErr}; use hashbrown::{HashMap, HashSet}; @@ -312,40 +312,53 @@ fn find_expression_aggregate<'a>( } } +/// Finds ANY single aggregate function in a query plan (including simple column aggregates). 
+fn find_any_aggregate<'a>( + op: &'a crate::sql::planner::PhysicalOperator<'a>, +) -> Option<( + &'a crate::sql::planner::AggregateFunction, + &'a crate::sql::ast::Expr<'a>, +)> { + use crate::sql::planner::PhysicalOperator; + + match op { + PhysicalOperator::HashAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + return Some((&agg_expr.function, arg)); + } + } + find_any_aggregate(agg.input) + } + PhysicalOperator::SortedAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + return Some((&agg_expr.function, arg)); + } + } + find_any_aggregate(agg.input) + } + PhysicalOperator::FilterExec(f) => find_any_aggregate(f.input), + PhysicalOperator::ProjectExec(p) => find_any_aggregate(p.input), + PhysicalOperator::SortExec(s) => find_any_aggregate(s.input), + PhysicalOperator::LimitExec(l) => find_any_aggregate(l.input), + _ => None, + } +} + /// Evaluates a binary operation on two OwnedValues. 
fn eval_binary_op(left: &OwnedValue, op: &crate::sql::ast::BinaryOperator, right: &OwnedValue) -> OwnedValue { use crate::sql::ast::BinaryOperator; - match op { - BinaryOperator::Plus => match (left, right) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a + b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a + b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 + b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a + *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Minus => match (left, right) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a - b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a - b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 - b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a - *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Multiply => match (left, right) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => OwnedValue::Int(a * b), - (OwnedValue::Float(a), OwnedValue::Float(b)) => OwnedValue::Float(a * b), - (OwnedValue::Int(a), OwnedValue::Float(b)) => OwnedValue::Float(*a as f64 * b), - (OwnedValue::Float(a), OwnedValue::Int(b)) => OwnedValue::Float(a * *b as f64), - _ => OwnedValue::Null, - }, - BinaryOperator::Divide => match (left, right) { - (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Int(a / b), - (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => OwnedValue::Float(a / b), - (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => OwnedValue::Float(*a as f64 / b), - (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => OwnedValue::Float(a / *b as f64), - _ => OwnedValue::Null, - }, - _ => OwnedValue::Null, - } + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + 
BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + arith_op + .and_then(|aop| OwnedValue::eval_arithmetic(left, aop, right)) + .unwrap_or(OwnedValue::Null) } /// Evaluates an expression against a record for streaming aggregation. @@ -701,7 +714,7 @@ impl Database { let has_filter = plan_has_filter(subq_plan.root); if !has_filter { - if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + if let Some((agg_func, expr)) = find_any_aggregate(subq_plan.root) { let mut agg_state = StreamingAggregateState::new(); let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) @@ -2474,68 +2487,27 @@ impl Database { let left_val = self.eval_expr_with_row(left, row, column_map)?; let right_val = self.eval_expr_with_row(right, row, column_map)?; - match op { - BinaryOperator::Plus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a + b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a + b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 + b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a + *b as f64)) - } - _ => bail!("unsupported types for addition"), - }, - BinaryOperator::Minus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a - b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a - b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 - b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a - *b as f64)) - } - _ => bail!("unsupported types for subtraction"), - }, - BinaryOperator::Multiply => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a * b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a * b)) - } - (OwnedValue::Int(a), 
OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 * b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a * *b as f64)) - } - _ => bail!("unsupported types for multiplication"), - }, - BinaryOperator::Divide => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Int(a / b)) - } - (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(a / b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(*a as f64 / b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Float(a / *b as f64)) - } - _ => bail!("division by zero or unsupported types"), - }, - BinaryOperator::Concat => match (&left_val, &right_val) { - (OwnedValue::Text(a), OwnedValue::Text(b)) => { - Ok(OwnedValue::Text(format!("{}{}", a, b))) - } - _ => bail!("unsupported types for concatenation"), - }, - _ => bail!("unsupported binary operator in UPDATE...FROM SET expression"), + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + if let Some(aop) = arith_op { + OwnedValue::eval_arithmetic(&left_val, aop, &right_val).ok_or_else(|| { + eyre::eyre!("unsupported types or division by zero for {:?}", aop) + }) + } else { + match op { + BinaryOperator::Concat => match (&left_val, &right_val) { + (OwnedValue::Text(a), OwnedValue::Text(b)) => { + Ok(OwnedValue::Text(format!("{}{}", a, b))) + } + _ => bail!("unsupported types for concatenation"), + }, + _ => bail!("unsupported binary operator in UPDATE...FROM SET expression"), + } } } Expr::UnaryOp { op, expr: inner } => { diff --git a/src/types/mod.rs b/src/types/mod.rs index 38b3aca..83bd102 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -41,6 +41,6 @@ mod 
value; pub use column::{range_flags, ColumnDef, DecimalView, Range}; pub use data_type::{DataType, TypeAffinity}; pub use owned_value::{ - create_column_map, create_record_schema, owned_values_to_values, OwnedValue, + create_column_map, create_record_schema, owned_values_to_values, ArithmeticOp, OwnedValue, }; pub use value::Value; diff --git a/src/types/owned_value.rs b/src/types/owned_value.rs index 2ffb687..ae67cca 100644 --- a/src/types/owned_value.rs +++ b/src/types/owned_value.rs @@ -591,6 +591,51 @@ impl OwnedValue { JsonbValue::Object(view) => OwnedValue::Jsonb(view.data().to_vec()), }) } + + pub fn eval_arithmetic( + left: &OwnedValue, + op: ArithmeticOp, + right: &OwnedValue, + ) -> Option<OwnedValue> { + match op { + ArithmeticOp::Plus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a + b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a + b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 + b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a + *b as f64)), + _ => None, + }, + ArithmeticOp::Minus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a - b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a - b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 - b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a - *b as f64)), + _ => None, + }, + ArithmeticOp::Multiply => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a * b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a * b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 * b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a * *b as f64)), + _ => None, + }, + ArithmeticOp::Divide => match (left, right) { + (OwnedValue::Int(a),
OwnedValue::Int(b)) if *b != 0 => Some(OwnedValue::Int(a / b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => Some(OwnedValue::Float(a / b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => Some(OwnedValue::Float(*a as f64 / b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => Some(OwnedValue::Float(a / *b as f64)), + _ => None, + }, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArithmeticOp { + Plus, + Minus, + Multiply, + Divide, } mod hex { diff --git a/tests/regression_smoke_test.rs b/tests/regression_smoke_test.rs index 4120f64..b354cc5 100644 --- a/tests/regression_smoke_test.rs +++ b/tests/regression_smoke_test.rs @@ -1522,6 +1522,51 @@ mod real_world_scenario { } } + #[test] + fn update_with_subquery_secondary_index_where_filter_applied() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, category INTEGER, price REAL)") + .unwrap(); + db.execute("CREATE TABLE summary (category INTEGER PRIMARY KEY, total_price REAL)") + .unwrap(); + db.execute("CREATE INDEX idx_products_category ON products(category)") + .unwrap(); + + db.execute("INSERT INTO products (id, category, price) VALUES (1, 1, 10.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (2, 1, 20.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (3, 2, 100.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (4, 2, 200.0)") + .unwrap(); + + db.execute("INSERT INTO summary (category, total_price) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO summary (category, total_price) VALUES (2, 0.0)") + .unwrap(); + + db.execute( + "UPDATE summary SET total_price = (SELECT SUM(price) FROM products WHERE category = 1) WHERE category = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT total_price FROM summary WHERE category = 1") + .unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => { 
+ assert!( + (*v - 30.0).abs() < 0.01, + "Expected 30.0 (sum of category 1: 10+20), got {}. SecondaryIndexScan WHERE filter may not be applied.", + v + ); + } + other => panic!("Expected Float, got {:?}", other), + } + } + #[test] fn blog_platform_scenario() { let (db, _dir) = create_test_db(); From d472490158be3c129e6097486fb9a0a4682ea543 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Wed, 21 Jan 2026 00:27:14 +0800 Subject: [PATCH 8/9] perf(update): eliminate heap allocations in PK lookup path - Use SmallVec<[u8; 16]> for pk_lookup_info key instead of Vec - Use SmallVec<[u8; 16]> for index_key instead of Vec::new() - Use SmallVec for precomputed_assignments (inline 8) and deferred_assignments (inline 4) - Convert to Vec only when needed for WriteEntry (transaction logging) These changes eliminate heap allocations in the hot path for single-row UPDATE operations that use PK index lookup. The SmallVec inline storage of 16 bytes covers most primary key encodings (9 bytes for i64 with type prefix). Co-Authored-By: Claude Opus 4.5 --- src/database/dml/update.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 62f05cc..951fe43 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -1132,7 +1132,7 @@ impl Database { }; let btree = BTree::new(&mut *storage, root_page)?; - let mut pk_lookup_info: Option<(Vec<u8>, OwnedValue)> = 'pk_analysis: { + let mut pk_lookup_info: Option<(SmallVec<[u8; 16]>, OwnedValue)> = 'pk_analysis: { if let Some(crate::sql::ast::Expr::BinaryOp { left, op: crate::sql::ast::BinaryOperator::Eq, @@ -1203,12 +1203,12 @@ impl Database { let index_btree = BTree::new(&mut *index_storage, index_root_page)?; - let mut index_key = Vec::new(); + let mut index_key: SmallVec<[u8; 16]> = SmallVec::new(); Self::encode_value_as_key(&val, &mut index_key); if let Some(handle) = index_btree.search(&index_key)?
{ - let row_key = index_btree.get_value(&handle)?.to_vec(); - break 'pk_analysis Some((row_key, val)); + let row_key_slice = index_btree.get_value(&handle)?; + break 'pk_analysis Some((SmallVec::from_slice(row_key_slice), val)); } } } @@ -1229,8 +1229,8 @@ impl Database { Vec<(usize, OwnedValue)>, )> = Vec::new(); - let mut precomputed_assignments: Vec<(usize, OwnedValue)> = Vec::new(); - let mut deferred_assignments: Vec<(usize, usize)> = Vec::new(); + let mut precomputed_assignments: SmallVec<[(usize, OwnedValue); 8]> = SmallVec::new(); + let mut deferred_assignments: SmallVec<[(usize, usize); 4]> = SmallVec::new(); let set_param_count: usize; { let mut param_idx = 0; @@ -1419,7 +1419,7 @@ impl Database { txn.add_write_entry_with_undo( WriteEntry { table_id: table_id as u32, - key: target_key.clone(), + key: target_key.to_vec(), page_id: 0, offset: 0, undo_page_id: None, From dcba2e8dea12ba9eb3e74c7d65a4e9e302f320ea Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Wed, 21 Jan 2026 10:17:32 +0800 Subject: [PATCH 9/9] perf(update): reduce allocations in multipass update loop - Only clone old_row_values when secondary indexes need updating (adds needs_old_row_for_secondary_index flag) - Use mem::replace in toast/assignment loops to combine check + assign (eliminates separate clone for toast pointer tracking) These changes reduce per-row allocations in the multipass update path: - Tables without secondary indexes: eliminates row clone entirely - Tables with toast columns: eliminates one clone per toast assignment Co-Authored-By: Claude Opus 4.5 --- src/database/dml/update.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 951fe43..f76d407 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -1298,6 +1298,10 @@ impl Database { let modified_col_indices: HashSet<usize> = assignment_indices.iter().map(|(idx, _)| *idx).collect(); + let
needs_old_row_for_secondary_index = secondary_indexes + .iter() + .any(|(_, col_indices)| col_indices.iter().any(|idx| modified_col_indices.contains(idx))); + let unique_col_indices: Vec<usize> = columns .iter() .enumerate() @@ -1492,15 +1496,19 @@ impl Database { if should_update { let old_value = value.to_vec(); - let old_row_values = row_values.clone(); + let old_row_values = if needs_old_row_for_secondary_index { + row_values.clone() + } else { + Vec::new() + }; let mut old_toast_values: Vec<(usize, OwnedValue)> = Vec::new(); for (col_idx, val) in &precomputed_assignments { - if let OwnedValue::ToastPointer(_) = &row_values[*col_idx] { - old_toast_values.push((*col_idx, row_values[*col_idx].clone())); + let old = std::mem::replace(&mut row_values[*col_idx], val.clone()); + if let OwnedValue::ToastPointer(_) = old { + old_toast_values.push((*col_idx, old)); } - row_values[*col_idx] = val.clone(); } if !deferred_assignments.is_empty() { @@ -1520,10 +1528,10 @@ impl Database { } for (col_idx, new_val) in deferred_values_buf.drain(..) { - if let OwnedValue::ToastPointer(_) = &row_values[col_idx] { - old_toast_values.push((col_idx, row_values[col_idx].clone())); + let old = std::mem::replace(&mut row_values[col_idx], new_val); + if let OwnedValue::ToastPointer(_) = old { + old_toast_values.push((col_idx, old)); } - row_values[col_idx] = new_val; } }