diff --git a/src/database/convert.rs b/src/database/convert.rs index 6307908..a0c9e1b 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -102,8 +102,9 @@ use crate::parsing::{ parse_binary_blob, parse_date, parse_hex_blob, parse_interval, parse_time, parse_timestamp, parse_uuid, parse_vector, }; -use crate::types::{DataType, OwnedValue, Value}; +use crate::types::{ArithmeticOp, DataType, OwnedValue, Value}; use eyre::{bail, Result, WrapErr}; +use std::borrow::Cow; use super::Database; @@ -404,6 +405,93 @@ impl Database { Self::eval_literal_with_type(expr, target_type) } + pub(crate) fn eval_expr_with_params_and_subqueries<'a>( + expr: &crate::sql::ast::Expr<'_>, + target_type: Option<&crate::records::types::DataType>, + params: Option<&'a [OwnedValue]>, + param_idx: &mut usize, + scalar_subquery_results: &'a crate::sql::context::ScalarSubqueryResults, + ) -> Result> { + use crate::sql::ast::{Expr, ParameterRef}; + + match expr { + Expr::Parameter(param_ref) => { + if let Some(params) = params { + let idx = match param_ref { + ParameterRef::Anonymous => { + let i = *param_idx; + *param_idx += 1; + i + } + ParameterRef::Positional(n) => (*n as usize).saturating_sub(1), + ParameterRef::Named(_) => { + let i = *param_idx; + *param_idx += 1; + i + } + }; + + if idx >= params.len() { + bail!( + "parameter index {} out of range (only {} parameters bound)", + idx + 1, + params.len() + ); + } + + Ok(Cow::Borrowed(¶ms[idx])) + } else { + bail!("parameter placeholder found but no parameters were bound") + } + } + Expr::Subquery(subq) => { + let key = std::ptr::from_ref(*subq) as usize; + scalar_subquery_results + .get(&key) + .map(Cow::Borrowed) + .ok_or_else(|| eyre::eyre!("scalar subquery result not found for key 0x{:x}", key)) + } + Expr::BinaryOp { left, op, right } => { + let left_val = Self::eval_expr_with_params_and_subqueries( + left, + target_type, + params, + param_idx, + scalar_subquery_results, + )?; + let right_val = Self::eval_expr_with_params_and_subqueries( + right, + target_type, + params, + param_idx, + scalar_subquery_results, + )?; + + use crate::sql::ast::BinaryOperator; + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + if let Some(aop) = arith_op { + OwnedValue::eval_arithmetic(left_val.as_ref(), aop, right_val.as_ref()) + .map(Cow::Owned) + .ok_or_else(|| { + eyre::eyre!( + "unsupported types or division by zero for {:?} in UPDATE SET", + aop + ) + }) + } else { + Self::eval_literal_with_type(expr, target_type).map(Cow::Owned) + } + } + _ => Self::eval_literal_with_type(expr, target_type).map(Cow::Owned), + } + } + pub(crate) fn parse_json_string(s: &str) -> Result { let value = Self::parse_json_to_value(s.trim())?; let bytes = Self::jsonb_value_to_bytes(&value); diff --git a/src/database/database.rs b/src/database/database.rs index ba4598f..8d70cd9 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -966,9 +966,9 @@ impl Database { for subq in subqueries { let key = std::ptr::from_ref(subq) as usize; - if !scalar_subquery_results.iter().any(|(k, _)| *k == key) { + if !scalar_subquery_results.contains_key(&key) { let result = execute_scalar_subquery(subq, catalog, file_manager, &arena)?; - scalar_subquery_results.push((key, result)); + scalar_subquery_results.insert(key, result); } } } diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index 7907428..f76d407 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -82,6 +82,42 @@ //! //! UPDATE acquires write lock on file_manager. Transaction write entries //! include undo data (old row value) for rollback support. +//! +//! ## Scalar Subquery Support +//! +//! UPDATE SET clauses support scalar subqueries in value expressions: +//! ```sql +//! UPDATE orders SET total = (SELECT SUM(amount) FROM order_items WHERE order_id = 1) +//! ``` +//! +//! ### Strategy +//! +//! 1. **Collection**: All scalar subqueries are collected from SET expressions before +//! the UPDATE loop begins +//! 2. **Pre-computation**: Subquery results are computed once and stored in a HashMap +//! keyed by AST node address (stable within the function scope) +//! 3. **Substitution**: During expression evaluation, subquery nodes are replaced with +//! their pre-computed results +//! +//! ### Streaming Aggregate Optimization +//! +//! For simple aggregate subqueries (SUM, AVG, MIN, MAX, COUNT over a single expression), +//! we bypass the full executor and stream directly through BTree cursors: +//! - Creates `StreamingAggregateState` to accumulate results +//! - Iterates through table/index cursors evaluating expressions per-row +//! - Avoids materializing all rows into memory +//! +//! This optimization applies when the subquery plan has: +//! - A single aggregate function +//! - An expression to aggregate (e.g., `SUM(quantity * price)`) +//! - No complex operators like joins or sorts +//! +//! ### Limitations +//! +//! - Subqueries are not correlated (cannot reference outer UPDATE row) +//! - Streaming optimization only applies when no WHERE filters are present +//! (filtered queries fall back to the full executor) +//! - Multiple rows from non-aggregate subqueries return only the first row use crate::btree::BTree; use crate::database::dml::mvcc_helpers::{get_user_data, wrap_record_for_update}; @@ -91,14 +127,15 @@ use crate::database::{Database, ExecuteResult}; use crate::mvcc::WriteEntry; use crate::records::RecordView; use crate::schema::table::{Constraint, IndexType}; +use crate::sql::context::ScalarSubqueryResults; use crate::sql::decoder::RecordDecoder; use crate::sql::executor::ExecutorRow; use crate::sql::predicate::CompiledPredicate; -use crate::storage::{IndexFileHeader, DEFAULT_SCHEMA}; -use crate::types::{create_record_schema, OwnedValue, Value}; +use crate::storage::{IndexFileHeader, TableFileHeader, DEFAULT_SCHEMA}; +use crate::types::{create_record_schema, ArithmeticOp, OwnedValue, Value}; use bumpalo::Bump; use eyre::{bail, Result, WrapErr}; -use hashbrown::HashSet; +use hashbrown::{HashMap, HashSet}; use smallvec::SmallVec; use std::borrow::Cow; use std::sync::atomic::Ordering; @@ -147,7 +184,8 @@ fn expr_contains_column_ref(expr: &crate::sql::ast::Expr) -> bool { Expr::InList { expr, list, .. } => { expr_contains_column_ref(expr) || list.iter().any(|e| expr_contains_column_ref(e)) } - Expr::InSubquery { .. } | Expr::Subquery(_) | Expr::Exists { .. } => true, + Expr::InSubquery { .. } | Expr::Exists { .. } => true, + Expr::Subquery(_) => false, Expr::IsDistinctFrom { left, right, .. } => { expr_contains_column_ref(left) || expr_contains_column_ref(right) } @@ -173,7 +211,704 @@ fn expr_contains_column_ref(expr: &crate::sql::ast::Expr) -> bool { } } +/// Recursively collects scalar subqueries from an expression tree. +fn collect_scalar_subqueries_from_expr<'a>( + expr: &'a crate::sql::ast::Expr<'a>, + subqueries: &mut SmallVec<[&'a crate::sql::ast::SelectStmt<'a>; 4]>, +) { + use crate::sql::ast::Expr; + match expr { + Expr::Subquery(subq) => { + subqueries.push(subq); + } + Expr::BinaryOp { left, right, .. } => { + collect_scalar_subqueries_from_expr(left, subqueries); + collect_scalar_subqueries_from_expr(right, subqueries); + } + Expr::UnaryOp { expr, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::IsNull { expr, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::InList { expr, list, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + for item in list.iter() { + collect_scalar_subqueries_from_expr(item, subqueries); + } + } + Expr::Between { + expr, low, high, .. + } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + collect_scalar_subqueries_from_expr(low, subqueries); + collect_scalar_subqueries_from_expr(high, subqueries); + } + Expr::Case { + operand, + conditions, + else_result, + } => { + if let Some(op) = operand { + collect_scalar_subqueries_from_expr(op, subqueries); + } + for cond in conditions.iter() { + collect_scalar_subqueries_from_expr(cond.condition, subqueries); + collect_scalar_subqueries_from_expr(cond.result, subqueries); + } + if let Some(else_e) = else_result { + collect_scalar_subqueries_from_expr(else_e, subqueries); + } + } + Expr::Cast { expr, .. } => { + collect_scalar_subqueries_from_expr(expr, subqueries); + } + Expr::Function(func) => { + if let crate::sql::ast::FunctionArgs::Args(args) = &func.args { + for arg in args.iter() { + collect_scalar_subqueries_from_expr(arg.value, subqueries); + } + } + } + _ => {} + } +} + +/// Finds a single aggregate function with expression in a query plan. +fn find_expression_aggregate<'a>( + op: &'a crate::sql::planner::PhysicalOperator<'a>, +) -> Option<( + &'a crate::sql::planner::AggregateFunction, + &'a crate::sql::ast::Expr<'a>, +)> { + use crate::sql::planner::PhysicalOperator; + + match op { + PhysicalOperator::HashAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + if !matches!(arg, crate::sql::ast::Expr::Column(_)) { + return Some((&agg_expr.function, arg)); + } + } + } + find_expression_aggregate(agg.input) + } + PhysicalOperator::SortedAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + if !matches!(arg, crate::sql::ast::Expr::Column(_)) { + return Some((&agg_expr.function, arg)); + } + } + } + find_expression_aggregate(agg.input) + } + PhysicalOperator::FilterExec(f) => find_expression_aggregate(f.input), + PhysicalOperator::ProjectExec(p) => find_expression_aggregate(p.input), + PhysicalOperator::SortExec(s) => find_expression_aggregate(s.input), + PhysicalOperator::LimitExec(l) => find_expression_aggregate(l.input), + _ => None, + } +} + +/// Finds ANY single aggregate function in a query plan (including simple column aggregates). +fn find_any_aggregate<'a>( + op: &'a crate::sql::planner::PhysicalOperator<'a>, +) -> Option<( + &'a crate::sql::planner::AggregateFunction, + &'a crate::sql::ast::Expr<'a>, +)> { + use crate::sql::planner::PhysicalOperator; + + match op { + PhysicalOperator::HashAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + return Some((&agg_expr.function, arg)); + } + } + find_any_aggregate(agg.input) + } + PhysicalOperator::SortedAggregate(agg) => { + for agg_expr in agg.aggregates.iter() { + if let Some(arg) = agg_expr.argument { + return Some((&agg_expr.function, arg)); + } + } + find_any_aggregate(agg.input) + } + PhysicalOperator::FilterExec(f) => find_any_aggregate(f.input), + PhysicalOperator::ProjectExec(p) => find_any_aggregate(p.input), + PhysicalOperator::SortExec(s) => find_any_aggregate(s.input), + PhysicalOperator::LimitExec(l) => find_any_aggregate(l.input), + _ => None, + } +} + +/// Evaluates a binary operation on two OwnedValues. +fn eval_binary_op(left: &OwnedValue, op: &crate::sql::ast::BinaryOperator, right: &OwnedValue) -> OwnedValue { + use crate::sql::ast::BinaryOperator; + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + arith_op + .and_then(|aop| OwnedValue::eval_arithmetic(left, aop, right)) + .unwrap_or(OwnedValue::Null) +} + +/// Evaluates an expression against a record for streaming aggregation. +fn eval_expr_for_record_streaming( + expr: &crate::sql::ast::Expr<'_>, + record: &RecordView<'_>, + column_info: &HashMap, +) -> OwnedValue { + use crate::sql::ast::{Expr, Literal}; + + match expr { + Expr::Column(col_ref) => { + let col_lower = col_ref.column.to_lowercase(); + if let Some(&(idx, data_type)) = column_info.get(&col_lower) { + OwnedValue::from_record_column(record, idx, data_type).unwrap_or(OwnedValue::Null) + } else { + OwnedValue::Null + } + } + Expr::Literal(lit) => match lit { + Literal::Integer(s) => s.parse::().map(OwnedValue::Int).unwrap_or(OwnedValue::Null), + Literal::Float(s) => s.parse::().map(OwnedValue::Float).unwrap_or(OwnedValue::Null), + Literal::String(s) => OwnedValue::Text((*s).to_string()), + Literal::Boolean(b) => OwnedValue::Int(if *b { 1 } else { 0 }), + Literal::Null => OwnedValue::Null, + _ => OwnedValue::Null, + }, + Expr::BinaryOp { left, op, right } => { + let left_val = eval_expr_for_record_streaming(left, record, column_info); + let right_val = eval_expr_for_record_streaming(right, record, column_info); + eval_binary_op(&left_val, op, &right_val) + } + Expr::UnaryOp { op, expr: inner } => { + let val = eval_expr_for_record_streaming(inner, record, column_info); + match op { + crate::sql::ast::UnaryOperator::Minus => match val { + OwnedValue::Int(i) => OwnedValue::Int(-i), + OwnedValue::Float(f) => OwnedValue::Float(-f), + _ => OwnedValue::Null, + }, + crate::sql::ast::UnaryOperator::Plus => val, + _ => OwnedValue::Null, + } + } + _ => OwnedValue::Null, + } +} + +/// Compares i64 and f64 without precision loss for large integers. +fn compare_int_float(i: i64, f: f64) -> std::cmp::Ordering { + use std::cmp::Ordering; + + if f.is_nan() { + return Ordering::Less; + } + if f == f64::INFINITY { + return Ordering::Less; + } + if f == f64::NEG_INFINITY { + return Ordering::Greater; + } + + const MAX_EXACT: i64 = 1i64 << 53; + const MIN_EXACT: i64 = -(1i64 << 53); + + if (MIN_EXACT..=MAX_EXACT).contains(&i) { + let i_as_f64 = i as f64; + i_as_f64.partial_cmp(&f).unwrap_or(Ordering::Equal) + } else if f >= (i64::MAX as f64) { + Ordering::Less + } else if f <= (i64::MIN as f64) { + Ordering::Greater + } else { + let f_truncated = f as i64; + match i.cmp(&f_truncated) { + Ordering::Equal => { + let f_frac = f - (f_truncated as f64); + if f_frac > 0.0 { + Ordering::Less + } else if f_frac < 0.0 { + Ordering::Greater + } else { + Ordering::Equal + } + } + other => other, + } + } +} + +struct StreamingAggregateState { + sum_int: i64, + sum_float: f64, + has_float: bool, + count: usize, + min_int: Option, + min_float: Option, + max_int: Option, + max_float: Option, +} + +impl StreamingAggregateState { + fn new() -> Self { + Self { + sum_int: 0, + sum_float: 0.0, + has_float: false, + count: 0, + min_int: None, + min_float: None, + max_int: None, + max_float: None, + } + } + + fn update(&mut self, value: &OwnedValue) { + match value { + OwnedValue::Int(i) => { + self.sum_int = self.sum_int.saturating_add(*i); + self.count = self.count.saturating_add(1); + self.min_int = Some(self.min_int.map_or(*i, |m| m.min(*i))); + self.max_int = Some(self.max_int.map_or(*i, |m| m.max(*i))); + } + OwnedValue::Float(f) => { + self.sum_float += f; + self.has_float = true; + self.count = self.count.saturating_add(1); + self.min_float = Some(self.min_float.map_or(*f, |m| m.min(*f))); + self.max_float = Some(self.max_float.map_or(*f, |m| m.max(*f))); + } + _ => {} + } + } + + fn finalize(&self, agg_func: &crate::sql::planner::AggregateFunction) -> OwnedValue { + use crate::sql::planner::AggregateFunction; + + match agg_func { + AggregateFunction::Sum => { + if self.has_float { + OwnedValue::Float(self.sum_float + self.sum_int as f64) + } else if self.count > 0 { + OwnedValue::Int(self.sum_int) + } else { + OwnedValue::Null + } + } + AggregateFunction::Avg => { + if self.count > 0 { + let total = self.sum_float + self.sum_int as f64; + OwnedValue::Float(total / self.count as f64) + } else { + OwnedValue::Null + } + } + AggregateFunction::Count => OwnedValue::Int(self.count as i64), + AggregateFunction::Min => { + if self.has_float { + match (self.min_int, self.min_float) { + (Some(i), Some(f)) => { + if compare_int_float(i, f).is_lt() { + OwnedValue::Int(i) + } else { + OwnedValue::Float(f) + } + } + (None, Some(f)) => OwnedValue::Float(f), + (Some(i), None) => OwnedValue::Int(i), + (None, None) => OwnedValue::Null, + } + } else { + self.min_int.map(OwnedValue::Int).unwrap_or(OwnedValue::Null) + } + } + AggregateFunction::Max => { + if self.has_float { + match (self.max_int, self.max_float) { + (Some(i), Some(f)) => { + if compare_int_float(i, f).is_gt() { + OwnedValue::Int(i) + } else { + OwnedValue::Float(f) + } + } + (None, Some(f)) => OwnedValue::Float(f), + (Some(i), None) => OwnedValue::Int(i), + (None, None) => OwnedValue::Null, + } + } else { + self.max_int.map(OwnedValue::Int).unwrap_or(OwnedValue::Null) + } + } + } + } +} + +fn plan_has_filter(op: &crate::sql::planner::physical::PhysicalOperator<'_>) -> bool { + use crate::sql::planner::physical::PhysicalOperator; + match op { + PhysicalOperator::FilterExec(_) => true, + PhysicalOperator::ProjectExec(p) => plan_has_filter(p.input), + PhysicalOperator::LimitExec(l) => plan_has_filter(l.input), + PhysicalOperator::SortExec(s) => plan_has_filter(s.input), + PhysicalOperator::TopKExec(t) => plan_has_filter(t.input), + PhysicalOperator::HashAggregate(a) => plan_has_filter(a.input), + PhysicalOperator::SortedAggregate(a) => plan_has_filter(a.input), + PhysicalOperator::WindowExec(w) => plan_has_filter(w.input), + PhysicalOperator::ScalarSubqueryExec(s) => plan_has_filter(s.subquery), + PhysicalOperator::ExistsSubqueryExec(s) => plan_has_filter(s.subquery), + PhysicalOperator::InListSubqueryExec(s) => plan_has_filter(s.subquery), + _ => false, + } +} + +fn read_root_page(storage: &crate::storage::MmapStorage) -> Result { + let page = storage.page(0)?; + TableFileHeader::from_bytes(page) + .map(|h| h.root_page()) + .wrap_err("failed to read table file header for root page") +} + +fn read_index_root_page(storage: &crate::storage::MmapStorage) -> Result { + let page = storage.page(0)?; + IndexFileHeader::from_bytes(page) + .map(|h| h.root_page()) + .wrap_err("failed to read index file header for root page") +} + impl Database { + /// Executes a scalar subquery and returns its single result value. + fn execute_scalar_subquery_for_update<'a>( + subq: &'a crate::sql::ast::SelectStmt<'a>, + catalog: &crate::schema::catalog::Catalog, + file_manager: &mut crate::storage::FileManager, + arena: &'a Bump, + ) -> Result { + use crate::btree::BTreeReader; + use crate::database::query::{find_plan_source, PlanSource}; + use crate::records::RecordView; + use crate::sql::builder::ExecutorBuilder; + use crate::sql::context::ExecutionContext; + use crate::sql::executor::{Executor, StreamingBTreeSource}; + use crate::sql::planner::{Planner, ScanRange}; + use crate::types::create_record_schema; + + let planner = Planner::new(catalog, arena); + let stmt = crate::sql::ast::Statement::Select(subq); + let subq_plan = planner.create_physical_plan(&stmt) + .wrap_err("failed to create physical plan for scalar subquery")?; + + let plan_source = find_plan_source(subq_plan.root); + + match plan_source { + Some(PlanSource::TableScan(scan)) => { + let schema_name = scan.schema.unwrap_or(DEFAULT_SCHEMA); + let table_name = scan.table; + + let table_def = catalog.resolve_table_in_schema(scan.schema, table_name) + .wrap_err_with(|| format!("failed to resolve table '{}' in scalar subquery", table_name))?; + let column_types: Vec<_> = + table_def.columns().iter().map(|c| c.data_type()).collect(); + let columns = table_def.columns(); + + let storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; + let storage = storage_arc.read(); + + let root_page = read_root_page(&storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + + let column_info: HashMap = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), (i, c.data_type()))) + .collect(); + + let column_map: Vec<(String, usize)> = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), i)) + .collect(); + + let has_filter = plan_has_filter(subq_plan.root) || scan.post_scan_filter.is_some(); + if !has_filter { + if let Some((agg_func, expr)) = find_expression_aggregate(subq_plan.root) { + let schema = create_record_schema(columns); + let table_reader = crate::btree::BTreeReader::new(&storage, root_page) + .wrap_err_with(|| format!("failed to create BTreeReader for table '{}'", table_name))?; + let mut cursor = table_reader.cursor_first()?; + let mut agg_state = StreamingAggregateState::new(); + + while cursor.valid() { + let row_data = cursor.value()?; + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + if !cursor.advance()? { + break; + } + } + + return Ok(agg_state.finalize(agg_func)); + } + } + + let source = StreamingBTreeSource::from_btree_scan_with_projections( + &storage, + root_page, + column_types, + None, + )?; + + let ctx = ExecutionContext::new(arena); + let builder = ExecutorBuilder::new(&ctx); + let mut executor = + builder.build_with_source_and_column_map(&subq_plan, source, &column_map)?; + + executor.open()?; + let result = if let Some(row) = executor.next()? { + if row.values.is_empty() { + bail!("scalar subquery returned row with no columns for table '{}'", table_name); + } + OwnedValue::from(row.values.first().unwrap()) + } else { + OwnedValue::Null + }; + executor.close()?; + Ok(result) + } + Some(PlanSource::SecondaryIndexScan(scan)) => { + let schema_name = scan.schema.unwrap_or(DEFAULT_SCHEMA); + let table_name = scan.table; + let index_name = scan.index_name; + + let table_def = scan.table_def.ok_or_else(|| { + eyre::eyre!("SecondaryIndexScan missing table_def for table '{}' in scalar subquery", table_name) + })?; + + let columns = table_def.columns(); + let schema = create_record_schema(columns); + + let column_info: HashMap = table_def + .columns() + .iter() + .enumerate() + .map(|(i, c)| (c.name().to_lowercase(), (i, c.data_type()))) + .collect(); + + let row_id_suffix_len = if scan.is_unique_index { 0 } else { 8 }; + + let has_filter = plan_has_filter(subq_plan.root); + if !has_filter { + if let Some((agg_func, expr)) = find_any_aggregate(subq_plan.root) { + let mut agg_state = StreamingAggregateState::new(); + + let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) + .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; + let index_storage = index_storage_arc.read(); + let index_root = read_index_root_page(&index_storage) + .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; + let index_reader = BTreeReader::new(&index_storage, index_root)?; + + let table_storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; + let table_storage = table_storage_arc.read(); + let table_root = read_root_page(&table_storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + let table_reader = BTreeReader::new(&table_storage, table_root)?; + + let process_index_cursor = |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into() + .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; + Ok(Some(row_key)) + } else { + Ok(None) + } + }; + + match &scan.key_range { + Some(ScanRange::PrefixScan { prefix }) => { + let mut cursor = index_reader.cursor_seek(prefix)?; + while cursor.valid() { + let index_key = cursor.key()?; + if !index_key.starts_with(prefix) { + break; + } + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + } + } + if !cursor.advance()? { + break; + } + } + } + Some(ScanRange::RangeScan { start, end }) => { + let mut cursor = if let Some(start_key) = start { + index_reader.cursor_seek(start_key)? + } else { + index_reader.cursor_first()? + }; + while cursor.valid() { + let index_key = cursor.key()?; + if let Some(end_key) = end { + if index_key >= *end_key { + break; + } + } + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + } + } + if !cursor.advance()? { + break; + } + } + } + Some(ScanRange::FullScan) | None => { + let mut cursor = index_reader.cursor_first()?; + while cursor.valid() { + if let Some(row_key) = process_index_cursor(&mut cursor)? { + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let expr_value = eval_expr_for_record_streaming(expr, &record, &column_info); + agg_state.update(&expr_value); + } + } + if !cursor.advance()? { + break; + } + } + } + } + + return Ok(agg_state.finalize(agg_func)); + } + } + + let index_storage_arc = file_manager.index_data(schema_name, table_name, index_name) + .wrap_err_with(|| format!("failed to open index '{}' for table '{}'", index_name, table_name))?; + let index_storage = index_storage_arc.read(); + let index_root = read_index_root_page(&index_storage) + .wrap_err_with(|| format!("failed to read root page for index '{}'", index_name))?; + let index_reader = BTreeReader::new(&index_storage, index_root)?; + + let table_storage_arc = file_manager.table_data(schema_name, table_name) + .wrap_err_with(|| format!("failed to open table data for '{}' in scalar subquery", table_name))?; + let table_storage = table_storage_arc.read(); + let table_root = read_root_page(&table_storage) + .wrap_err_with(|| format!("failed to read root page for table '{}'", table_name))?; + let table_reader = BTreeReader::new(&table_storage, table_root)?; + + let extract_and_lookup_first_row = + |cursor: &mut crate::btree::Cursor<'_, crate::storage::MmapStorage>| -> Result> { + while cursor.valid() { + let index_key = cursor.key()?; + let row_id_bytes = if scan.is_unique_index { + cursor.value()? + } else { + &index_key[index_key.len().saturating_sub(row_id_suffix_len)..] + }; + if row_id_bytes.len() == 8 { + let row_key: [u8; 8] = row_id_bytes.try_into() + .map_err(|_| eyre::eyre!("invalid row key in index '{}': expected 8 bytes", index_name))?; + if let Some(row_data) = table_reader.get(&row_key)? { + let user_data = crate::database::dml::mvcc_helpers::get_user_data(row_data); + let record = RecordView::new(user_data, &schema)?; + let first_col_type = columns.first().map(|c| c.data_type()) + .ok_or_else(|| eyre::eyre!("table '{}' has no columns", table_name))?; + let col_value = OwnedValue::from_record_column(&record, 0, first_col_type) + .wrap_err_with(|| eyre::eyre!( + "failed to read column 0 from record in table '{}' via index '{}'", + table_name, index_name + ))?; + return Ok(Some(col_value)); + } + } + if !cursor.advance()? { + break; + } + } + Ok(None) + }; + + let result = match &scan.key_range { + Some(ScanRange::PrefixScan { prefix }) => { + let mut cursor = index_reader.cursor_seek(prefix)?; + if cursor.valid() { + let index_key = cursor.key()?; + if index_key.starts_with(prefix) { + extract_and_lookup_first_row(&mut cursor)? + } else { + None + } + } else { + None + } + } + Some(ScanRange::RangeScan { start, end }) => { + let mut cursor = if let Some(start_key) = start { + index_reader.cursor_seek(start_key)? + } else { + index_reader.cursor_first()? + }; + if cursor.valid() { + if let Some(end_key) = end { + let index_key = cursor.key()?; + if index_key < *end_key { + extract_and_lookup_first_row(&mut cursor)? + } else { + None + } + } else { + extract_and_lookup_first_row(&mut cursor)? + } + } else { + None + } + } + Some(ScanRange::FullScan) | None => { + let mut cursor = index_reader.cursor_first()?; + extract_and_lookup_first_row(&mut cursor)? + } + }; + + Ok(result.unwrap_or(OwnedValue::Null)) + } + _ => Ok(OwnedValue::Null), + } + } + pub(crate) fn execute_update( &self, update: &crate::sql::ast::UpdateStmt<'_>, @@ -325,11 +1060,10 @@ impl Database { }) .collect(); - drop(catalog_guard); - let schema = create_record_schema(&columns); if let Some(from_tables) = from_tables_data { + drop(catalog_guard); return self.execute_update_with_from( update, arena, @@ -355,6 +1089,33 @@ impl Database { }) .collect(); + let mut scalar_subquery_results: ScalarSubqueryResults = ScalarSubqueryResults::new(); + { + let mut subqueries: SmallVec<[&crate::sql::ast::SelectStmt<'_>; 4]> = SmallVec::new(); + for (_, value_expr) in &assignment_indices { + collect_scalar_subqueries_from_expr(value_expr, &mut subqueries); + } + + if !subqueries.is_empty() { + let mut fm_guard = self.shared.file_manager.write(); + let fm = fm_guard.as_mut().unwrap(); + for subq in subqueries { + // SAFETY: Using AST node address as HashMap key is safe because: + // 1. The AST is arena-allocated and lives for the duration of this function + // 2. We only read from scalar_subquery_results within this same scope + // 3. The same subquery AST node will have the same address when looked up + let key = std::ptr::from_ref(subq) as usize; + if !scalar_subquery_results.contains_key(&key) { + let result = + Self::execute_scalar_subquery_for_update(subq, catalog, fm, arena)?; + scalar_subquery_results.insert(key, result); + } + } + } + } + + drop(catalog_guard); + let mut file_manager_guard = self.shared.file_manager.write(); let file_manager = file_manager_guard.as_mut().unwrap(); let storage_arc = file_manager.table_data_mut(schema_name, table_name)?; @@ -371,7 +1132,7 @@ impl Database { }; let btree = BTree::new(&mut *storage, root_page)?; - let mut pk_lookup_info: Option<(Vec, OwnedValue)> = 'pk_analysis: { + let mut pk_lookup_info: Option<(SmallVec<[u8; 16]>, OwnedValue)> = 'pk_analysis: { if let Some(crate::sql::ast::Expr::BinaryOp { left, op: crate::sql::ast::BinaryOperator::Eq, @@ -442,12 +1203,12 @@ impl Database { let index_btree = BTree::new(&mut *index_storage, index_root_page)?; - let mut index_key = Vec::new(); + let mut index_key: SmallVec<[u8; 16]> = SmallVec::new(); Self::encode_value_as_key(&val, &mut index_key); if let Some(handle) = index_btree.search(&index_key)? { - let row_key = index_btree.get_value(&handle)?.to_vec(); - break 'pk_analysis Some((row_key, val)); + let row_key_slice = index_btree.get_value(&handle)?; + break 'pk_analysis Some((SmallVec::from_slice(row_key_slice), val)); } } } @@ -468,8 +1229,8 @@ impl Database { Vec<(usize, OwnedValue)>, )> = Vec::new(); - let mut precomputed_assignments: Vec<(usize, OwnedValue)> = Vec::new(); - let mut deferred_assignments: Vec<(usize, usize)> = Vec::new(); + let mut precomputed_assignments: SmallVec<[(usize, OwnedValue); 8]> = SmallVec::new(); + let mut deferred_assignments: SmallVec<[(usize, usize); 4]> = SmallVec::new(); let set_param_count: usize; { let mut param_idx = 0; @@ -478,8 +1239,14 @@ impl Database { deferred_assignments.push((*col_idx, assign_idx)); param_idx += count_params_in_expr(value_expr); } else { - let val = - Self::eval_expr_with_params(value_expr, column_types.get(*col_idx), Some(params), &mut param_idx)?; + let val = Self::eval_expr_with_params_and_subqueries( + value_expr, + column_types.get(*col_idx), + Some(params), + &mut param_idx, + &scalar_subquery_results, + )? + .into_owned(); precomputed_assignments.push((*col_idx, val)); } } @@ -531,6 +1298,10 @@ impl Database { let modified_col_indices: HashSet = assignment_indices.iter().map(|(idx, _)| *idx).collect(); + let needs_old_row_for_secondary_index = secondary_indexes + .iter() + .any(|(_, col_indices)| col_indices.iter().any(|idx| modified_col_indices.contains(idx))); + let unique_col_indices: Vec = columns .iter() .enumerate() @@ -652,7 +1423,7 @@ impl Database { txn.add_write_entry_with_undo( WriteEntry { table_id: table_id as u32, - key: target_key.clone(), + key: target_key.to_vec(), page_id: 0, offset: 0, undo_page_id: None, @@ -725,15 +1496,19 @@ impl Database { if should_update { let old_value = value.to_vec(); - let old_row_values = row_values.clone(); + let old_row_values = if needs_old_row_for_secondary_index { + row_values.clone() + } else { + Vec::new() + }; let mut old_toast_values: Vec<(usize, OwnedValue)> = Vec::new(); for (col_idx, val) in &precomputed_assignments { - if let OwnedValue::ToastPointer(_) = &row_values[*col_idx] { - old_toast_values.push((*col_idx, row_values[*col_idx].clone())); + let old = std::mem::replace(&mut row_values[*col_idx], val.clone()); + if let OwnedValue::ToastPointer(_) = old { + old_toast_values.push((*col_idx, old)); } - row_values[*col_idx] = val.clone(); } if !deferred_assignments.is_empty() { @@ -753,10 +1528,10 @@ impl Database { } for (col_idx, new_val) in deferred_values_buf.drain(..) { - if let OwnedValue::ToastPointer(_) = &row_values[col_idx] { - old_toast_values.push((col_idx, row_values[col_idx].clone())); + let old = std::mem::replace(&mut row_values[col_idx], new_val); + if let OwnedValue::ToastPointer(_) = old { + old_toast_values.push((col_idx, old)); } - row_values[col_idx] = new_val; } } @@ -1720,68 +2495,27 @@ impl Database { let left_val = self.eval_expr_with_row(left, row, column_map)?; let right_val = self.eval_expr_with_row(right, row, column_map)?; - match op { - BinaryOperator::Plus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a + b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a + b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 + b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a + *b as f64)) - } - _ => bail!("unsupported types for addition"), - }, - BinaryOperator::Minus => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a - b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a - b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 - b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a - *b as f64)) - } - _ => bail!("unsupported types for subtraction"), - }, - BinaryOperator::Multiply => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) => Ok(OwnedValue::Int(a * b)), - (OwnedValue::Float(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(a * b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) => { - Ok(OwnedValue::Float(*a as f64 * b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) => { - Ok(OwnedValue::Float(a * *b as f64)) - } - _ => bail!("unsupported types for multiplication"), - }, - BinaryOperator::Divide => match (&left_val, &right_val) { - (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Int(a / b)) - } - (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(a / b)) - } - (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => { - Ok(OwnedValue::Float(*a as f64 / b)) - } - (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => { - Ok(OwnedValue::Float(a / *b as f64)) - } - _ => bail!("division by zero or unsupported types"), - }, - BinaryOperator::Concat => match (&left_val, &right_val) { - (OwnedValue::Text(a), OwnedValue::Text(b)) => { - Ok(OwnedValue::Text(format!("{}{}", a, b))) - } - _ => bail!("unsupported types for concatenation"), - }, - _ => bail!("unsupported binary operator in UPDATE...FROM SET expression"), + let arith_op = match op { + BinaryOperator::Plus => Some(ArithmeticOp::Plus), + BinaryOperator::Minus => Some(ArithmeticOp::Minus), + BinaryOperator::Multiply => Some(ArithmeticOp::Multiply), + BinaryOperator::Divide => Some(ArithmeticOp::Divide), + _ => None, + }; + if let Some(aop) = arith_op { + OwnedValue::eval_arithmetic(&left_val, aop, &right_val).ok_or_else(|| { + eyre::eyre!("unsupported types or division by zero for {:?}", aop) + }) + } else { + match op { + BinaryOperator::Concat => match (&left_val, &right_val) { + (OwnedValue::Text(a), OwnedValue::Text(b)) => { + Ok(OwnedValue::Text(format!("{}{}", a, b))) + } + _ => bail!("unsupported types for concatenation"), + }, + _ => bail!("unsupported binary operator in UPDATE...FROM SET expression"), + } } } Expr::UnaryOp { op, expr: inner } => { diff --git a/src/sql/context.rs b/src/sql/context.rs index d359fb0..d544fa2 100644 --- a/src/sql/context.rs +++ b/src/sql/context.rs @@ -1,10 +1,10 @@ use bumpalo::Bump; use crate::memory::MemoryBudget; use crate::types::OwnedValue; -use smallvec::SmallVec; +use hashbrown::HashMap; use std::sync::Arc; -pub type ScalarSubqueryResults = SmallVec<[(usize, OwnedValue); 4]>; +pub type ScalarSubqueryResults = HashMap; pub struct ExecutionContext<'a> { pub arena: &'a Bump, diff --git a/src/sql/predicate.rs b/src/sql/predicate.rs index f3d654e..fcee07e 100644 --- a/src/sql/predicate.rs +++ b/src/sql/predicate.rs @@ -286,9 +286,8 @@ impl<'a> CompiledPredicate<'a> { Expr::Subquery(subq) => { let key = std::ptr::from_ref(*subq) as usize; self.scalar_subquery_results - .iter() - .find(|(k, _)| *k == key) - .map(|(_, v)| self.owned_value_to_value(v)) + .get(&key) + .map(|v| self.owned_value_to_value(v)) } _ => None, } diff --git a/src/types/mod.rs b/src/types/mod.rs index 38b3aca..83bd102 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -41,6 +41,6 @@ mod value; pub use column::{range_flags, ColumnDef, DecimalView, Range}; pub use data_type::{DataType, TypeAffinity}; pub use owned_value::{ - create_column_map, create_record_schema, owned_values_to_values, OwnedValue, + create_column_map, create_record_schema, owned_values_to_values, ArithmeticOp, OwnedValue, }; pub use value::Value; diff --git a/src/types/owned_value.rs b/src/types/owned_value.rs index 2ffb687..ae67cca 100644 --- a/src/types/owned_value.rs +++ b/src/types/owned_value.rs @@ -591,6 +591,51 @@ impl OwnedValue { JsonbValue::Object(view) => OwnedValue::Jsonb(view.data().to_vec()), }) } + + pub fn eval_arithmetic( + left: &OwnedValue, + op: ArithmeticOp, + right: &OwnedValue, + ) -> Option { + match op { + ArithmeticOp::Plus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a + b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a + b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 + b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a + *b as f64)), + _ => None, + }, + ArithmeticOp::Minus => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a - b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a - b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 - b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a - *b as f64)), + _ => None, + }, + ArithmeticOp::Multiply => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) => Some(OwnedValue::Int(a * b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(a * b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) => Some(OwnedValue::Float(*a as f64 * b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) => Some(OwnedValue::Float(a * *b as f64)), + _ => None, + }, + ArithmeticOp::Divide => match (left, right) { + (OwnedValue::Int(a), OwnedValue::Int(b)) if *b != 0 => Some(OwnedValue::Int(a / b)), + (OwnedValue::Float(a), OwnedValue::Float(b)) if *b != 0.0 => Some(OwnedValue::Float(a / b)), + (OwnedValue::Int(a), OwnedValue::Float(b)) if *b != 0.0 => Some(OwnedValue::Float(*a as f64 / b)), + (OwnedValue::Float(a), OwnedValue::Int(b)) if *b != 0 => Some(OwnedValue::Float(a / *b as f64)), + _ => None, + }, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArithmeticOp { + Plus, + Minus, + Multiply, + Divide, } mod hex { diff --git a/tests/regression_smoke_test.rs b/tests/regression_smoke_test.rs index 86303c9..b354cc5 100644 --- a/tests/regression_smoke_test.rs +++ b/tests/regression_smoke_test.rs @@ -1377,6 +1377,196 @@ mod real_world_scenario { } } + #[test] + fn update_with_subquery_sum_computes_total_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, total REAL)") + .unwrap(); + db.execute("CREATE TABLE order_items (order_id INTEGER, amount REAL)") + .unwrap(); + + db.execute("INSERT INTO orders (id, total) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 25.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 75.0)") + .unwrap(); + + db.execute( + "UPDATE orders SET total = (SELECT SUM(amount) FROM order_items) WHERE id = 1", + ) + .unwrap(); + + let rows = db.query("SELECT total FROM orders WHERE id = 1").unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 100.0).abs() < 0.01, "Expected 100.0, got {}", v), + other => panic!("Expected Float, got {:?}", other), + } + } + + #[test] + fn update_with_subquery_avg_computes_average_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE stats (id INTEGER PRIMARY KEY, avg_value REAL)") + .unwrap(); + db.execute("CREATE TABLE values_table (value REAL)") + .unwrap(); + + db.execute("INSERT INTO stats (id, avg_value) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (10.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (20.0)") + .unwrap(); + db.execute("INSERT INTO values_table (value) VALUES (30.0)") + .unwrap(); + + db.execute( + "UPDATE stats SET avg_value = (SELECT AVG(value) FROM values_table) WHERE id = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT avg_value FROM stats WHERE id = 1") + .unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 20.0).abs() < 0.01, "Expected 20.0, got {}", v), + other => panic!("Expected Float, got {:?}", other), + } + } + + #[test] + fn update_with_multiple_subqueries_in_set_clause() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE summary (id INTEGER PRIMARY KEY, min_val REAL, max_val REAL, sum_val REAL)") + .unwrap(); + db.execute("CREATE TABLE data_points (value REAL)") + .unwrap(); + + db.execute("INSERT INTO summary (id, min_val, max_val, sum_val) VALUES (1, 0, 0, 0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (10.0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (20.0)") + .unwrap(); + db.execute("INSERT INTO data_points (value) VALUES (30.0)") + .unwrap(); + + db.execute( + "UPDATE summary SET + min_val = (SELECT MIN(value) FROM data_points), + max_val = (SELECT MAX(value) FROM data_points), + sum_val = (SELECT SUM(value) FROM data_points) + WHERE id = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT min_val, max_val, sum_val FROM summary WHERE id = 1") + .unwrap(); + + match &rows[0].values[0] { + OwnedValue::Float(v) => assert!((*v - 10.0).abs() < 0.01, "Expected min 10.0, got {}", v), + other => panic!("Expected Float for min_val, got {:?}", other), + } + match &rows[0].values[1] { + OwnedValue::Float(v) => assert!((*v - 30.0).abs() < 0.01, "Expected max 30.0, got {}", v), + other => panic!("Expected Float for max_val, got {:?}", other), + } + match &rows[0].values[2] { + OwnedValue::Float(v) => assert!((*v - 60.0).abs() < 0.01, "Expected sum 60.0, got {}", v), + other => panic!("Expected Float for sum_val, got {:?}", other), + } + } + + #[test] + fn update_with_subquery_where_filter_applied_correctly() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, total REAL)") + .unwrap(); + db.execute("CREATE TABLE order_items (order_id INTEGER, amount REAL)") + .unwrap(); + + db.execute("INSERT INTO orders (id, total) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO orders (id, total) VALUES (2, 0.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 10.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (1, 15.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (2, 100.0)") + .unwrap(); + db.execute("INSERT INTO order_items (order_id, amount) VALUES (2, 200.0)") + .unwrap(); + + db.execute( + "UPDATE orders SET total = (SELECT SUM(amount) FROM order_items WHERE order_id = 1) WHERE id = 1", + ) + .unwrap(); + + let rows = db.query("SELECT total FROM orders WHERE id = 1").unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => { + assert!( + (*v - 25.0).abs() < 0.01, + "Expected 25.0 (sum of order_id=1 items only), got {}. WHERE filter may not be applied.", + v + ); + } + other => panic!("Expected Float, got {:?}", other), + } + } + + #[test] + fn update_with_subquery_secondary_index_where_filter_applied() { + let (db, _dir) = create_test_db(); + + db.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, category INTEGER, price REAL)") + .unwrap(); + db.execute("CREATE TABLE summary (category INTEGER PRIMARY KEY, total_price REAL)") + .unwrap(); + db.execute("CREATE INDEX idx_products_category ON products(category)") + .unwrap(); + + db.execute("INSERT INTO products (id, category, price) VALUES (1, 1, 10.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (2, 1, 20.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (3, 2, 100.0)") + .unwrap(); + db.execute("INSERT INTO products (id, category, price) VALUES (4, 2, 200.0)") + .unwrap(); + + db.execute("INSERT INTO summary (category, total_price) VALUES (1, 0.0)") + .unwrap(); + db.execute("INSERT INTO summary (category, total_price) VALUES (2, 0.0)") + .unwrap(); + + db.execute( + "UPDATE summary SET total_price = (SELECT SUM(price) FROM products WHERE category = 1) WHERE category = 1", + ) + .unwrap(); + + let rows = db + .query("SELECT total_price FROM summary WHERE category = 1") + .unwrap(); + match &rows[0].values[0] { + OwnedValue::Float(v) => { + assert!( + (*v - 30.0).abs() < 0.01, + "Expected 30.0 (sum of category 1: 10+20), got {}. SecondaryIndexScan WHERE filter may not be applied.", + v + ); + } + other => panic!("Expected Float, got {:?}", other), + } + } + #[test] fn blog_platform_scenario() { let (db, _dir) = create_test_db();