diff --git a/dozer-sql/expression/src/execution.rs b/dozer-sql/expression/src/execution.rs index 1c2891b234..1249edaead 100644 --- a/dozer-sql/expression/src/execution.rs +++ b/dozer-sql/expression/src/execution.rs @@ -436,7 +436,7 @@ impl Expression { negated: _, } => Ok(ExpressionType::new( FieldType::Boolean, - false, + true, SourceDefinition::Dynamic, false, )), diff --git a/dozer-sql/expression/src/in_list.rs b/dozer-sql/expression/src/in_list.rs index d538214d65..b1e3cdc215 100644 --- a/dozer-sql/expression/src/in_list.rs +++ b/dozer-sql/expression/src/in_list.rs @@ -12,14 +12,28 @@ pub(crate) fn evaluate_in_list( record: &Record, ) -> Result { let field = expr.evaluate(record, schema)?; + if field == Field::Null { + return Ok(Field::Null); + } + let mut result = false; + let mut saw_null = false; for item in list { let item = item.evaluate(record, schema)?; + if item == Field::Null { + saw_null = true; + continue; + } + if field == item { result = true; break; } } + if !result && saw_null { + return Ok(Field::Null); + } + // Negate the result if the IN list was negated. if negated { result = !result; diff --git a/dozer-sql/src/builder/mod.rs b/dozer-sql/src/builder/mod.rs index 3ccf3ed356..445df420f2 100644 --- a/dozer-sql/src/builder/mod.rs +++ b/dozer-sql/src/builder/mod.rs @@ -6,7 +6,10 @@ use dozer_core::app::AppPipeline; use dozer_core::node::PortHandle; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_sql_expression::builder::{ExpressionBuilder, NameOrAlias}; -use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier, TableFactor}; +use dozer_sql_expression::sqlparser::ast::{ + BinaryOperator, Expr as SqlExpr, Ident, Join, JoinConstraint, JoinOperator, SelectItem, + SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins, +}; use dozer_types::models::udf_config::UdfConfig; use dozer_sql_expression::sqlparser::{ @@ -21,7 +24,9 @@ use tokio::runtime::Runtime; use super::errors::UnsupportedSqlError; -use super::product::set::set_factory::SetProcessorFactory; +use super::product::set::set_factory::{DedupProcessorFactory, SetProcessorFactory}; + +const IN_SUBQUERY_ALIAS_PREFIX: &str = "__dozer_in_subquery_"; #[derive(Debug, Clone)] pub struct OutputNodeInfo { @@ -236,19 +241,23 @@ fn query_to_pipeline( fn select_to_pipeline( table_info: TableInfo, - select: Select, + mut select: Select, pipeline: &mut AppPipeline, query_ctx: &mut QueryContext, pipeline_idx: usize, is_top_select: bool, ) -> Result { // FROM clause - let Some(from) = select.from.into_iter().next() else { + let Some(mut from) = select.from.into_iter().next() else { return Err(PipelineError::UnsupportedSqlError( UnsupportedSqlError::FromCommaSyntax, )); }; + if let Some(selection) = select.selection.take() { + select.selection = rewrite_in_subquery_selection_as_join(selection, &mut from, query_ctx)?; + } + let connection_info = from::insert_from_to_pipeline(from, pipeline, pipeline_idx, query_ctx)?; let input_nodes = connection_info.input_nodes; @@ -360,6 +369,123 @@ fn select_to_pipeline( Ok(gen_agg_name) } +fn rewrite_in_subquery_selection_as_join( + selection: SqlExpr, + from: &mut TableWithJoins, + query_ctx: &mut QueryContext, +) -> Result, PipelineError> { + match selection { + SqlExpr::BinaryOp { + left, + op: BinaryOperator::And, + right, + } => { + let left = rewrite_in_subquery_selection_as_join(*left, from, query_ctx)?; + let right = rewrite_in_subquery_selection_as_join(*right, from, query_ctx)?; + + match (left, right) { + (Some(left), Some(right)) => Ok(Some(SqlExpr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::And, + right: Box::new(right), + })), + (Some(selection), None) | (None, Some(selection)) => Ok(Some(selection)), + (None, None) => Ok(None), + } + } + SqlExpr::Nested(expr) => rewrite_in_subquery_selection_as_join(*expr, from, query_ctx) + .map(|selection| selection.map(|expr| SqlExpr::Nested(Box::new(expr)))), + SqlExpr::InSubquery { + expr, + subquery, + negated: false, + } => { + let subquery_field = in_subquery_projection_field(&subquery)?; + let expr = qualify_unqualified_outer_identifier(expr, from)?; + let alias = format!( + "{}{}", + IN_SUBQUERY_ALIAS_PREFIX, + query_ctx.get_next_processor_id() + ); + let join_constraint = SqlExpr::BinaryOp { + left: expr, + op: BinaryOperator::Eq, + right: Box::new(SqlExpr::CompoundIdentifier(vec![ + Ident::new(alias.clone()), + subquery_field, + ])), + }; + + from.joins.push(Join { + relation: TableFactor::Derived { + lateral: false, + subquery, + alias: Some(TableAlias { + name: Ident::new(alias), + columns: vec![], + }), + }, + join_operator: JoinOperator::Inner(JoinConstraint::On(join_constraint)), + }); + + Ok(None) + } + SqlExpr::InSubquery { negated: true, .. } => Err(PipelineError::InvalidQuery( + "NOT IN subqueries are not supported".to_string(), + )), + other => Ok(Some(other)), + } +} + +fn qualify_unqualified_outer_identifier( + expr: Box, + from: &TableWithJoins, +) -> Result, PipelineError> { + match *expr { + SqlExpr::Identifier(ident) => { + let source_name_or_alias = common::get_name_or_alias(&from.relation)?; + let qualifier = source_name_or_alias.1.unwrap_or(source_name_or_alias.0); + + Ok(Box::new(SqlExpr::CompoundIdentifier(vec![ + Ident::new(qualifier), + ident, + ]))) + } + other => Ok(Box::new(other)), + } +} + +fn in_subquery_projection_field(query: &Query) -> Result { + let select = match query.body.as_ref() { + SetExpr::Select(select) => select, + SetExpr::Query(query) => return in_subquery_projection_field(query), + _ => { + return Err(PipelineError::InvalidQuery( + "IN subquery must project exactly one column".to_string(), + )) + } + }; + + let [projection] = select.projection.as_slice() else { + return Err(PipelineError::InvalidQuery( + "IN subquery must project exactly one column".to_string(), + )); + }; + + match projection { + SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => Ok(ident.clone()), + SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(idents)) => { + idents.last().cloned().ok_or_else(|| { + PipelineError::InvalidQuery("IN subquery projection must name a column".to_string()) + }) + } + SelectItem::ExprWithAlias { alias, .. } => Ok(alias.clone()), + _ => Err(PipelineError::InvalidQuery( + "IN subquery projection must name a column".to_string(), + )), + } +} + #[allow(clippy::too_many_arguments)] fn set_to_pipeline( table_info: TableInfo, @@ -530,6 +656,10 @@ fn get_from_source( let alias_name = alias.as_ref().map(|alias_ident| { ExpressionBuilder::fullname_from_ident(&[alias_ident.name.clone()]) }); + let should_dedup_output = alias_name + .as_ref() + .map(|name| name.starts_with(IN_SUBQUERY_ALIAS_PREFIX)) + .unwrap_or(false); let is_top_select = false; //inside FROM clause, so not top select let name_or = NameOrAlias(name, alias_name); query_to_pipeline( @@ -543,6 +673,9 @@ fn get_from_source( pipeline_idx, is_top_select, )?; + if should_dedup_output { + insert_dedup_processor_to_pipeline(&name_or.0, pipeline, query_ctx, pipeline_idx)?; + } Ok(name_or) } @@ -552,6 +685,57 @@ fn get_from_source( } } +fn insert_dedup_processor_to_pipeline( + table_name: &str, + pipeline: &mut AppPipeline, + query_ctx: &mut QueryContext, + pipeline_idx: usize, +) -> Result<(), PipelineError> { + let output_node = query_ctx + .pipeline_map + .get(&(pipeline_idx, table_name.to_string())) + .cloned() + .ok_or_else(|| { + PipelineError::InvalidQuery(format!( + "Unable to deduplicate IN subquery source {table_name}" + )) + })?; + + let dedup_processor_name = format!("dedup--{}", query_ctx.get_next_processor_id()); + if !query_ctx + .processors_list + .insert(dedup_processor_name.clone()) + { + return Err(PipelineError::ProcessorAlreadyExists(dedup_processor_name)); + } + + let dedup_processor = DedupProcessorFactory::new( + dedup_processor_name.clone(), + pipeline + .flags() + .enable_probabilistic_optimizations + .in_sets + .unwrap_or(false), + ); + pipeline.add_processor(Box::new(dedup_processor), dedup_processor_name.clone()); + pipeline.connect_nodes( + output_node.node, + output_node.port, + dedup_processor_name.clone(), + DEFAULT_PORT_HANDLE, + ); + + query_ctx.pipeline_map.insert( + (pipeline_idx, table_name.to_string()), + OutputNodeInfo { + node: dedup_processor_name, + port: DEFAULT_PORT_HANDLE, + }, + ); + + Ok(()) +} + #[derive(Clone, Debug)] struct ConnectionInfo { input_nodes: Vec<(String, String, PortHandle)>, diff --git a/dozer-sql/src/expression/tests/in_list.rs b/dozer-sql/src/expression/tests/in_list.rs index 0296f2210c..ff76dadef8 100644 --- a/dozer-sql/src/expression/tests/in_list.rs +++ b/dozer-sql/src/expression/tests/in_list.rs @@ -84,3 +84,21 @@ fn test_not_in_list() { ); assert_eq!(f, Field::Boolean(false)); } + +#[test] +fn test_in_list_with_nulls() { + let f = run_fct("SELECT 42 IN (NULL)", Schema::default(), vec![]); + assert_eq!(f, Field::Null); + + let f = run_fct("SELECT 42 IN (1, NULL)", Schema::default(), vec![]); + assert_eq!(f, Field::Null); + + let f = run_fct("SELECT 42 IN (1, NULL, 42)", Schema::default(), vec![]); + assert_eq!(f, Field::Boolean(true)); + + let f = run_fct("SELECT 42 NOT IN (1, NULL)", Schema::default(), vec![]); + assert_eq!(f, Field::Null); + + let f = run_fct("SELECT NULL IN (1, 2, 3)", Schema::default(), vec![]); + assert_eq!(f, Field::Null); +} diff --git a/dozer-sql/src/product/set/set_factory.rs b/dozer-sql/src/product/set/set_factory.rs index 6f4fd442f4..285f650568 100644 --- a/dozer-sql/src/product/set/set_factory.rs +++ b/dozer-sql/src/product/set/set_factory.rs @@ -23,6 +23,12 @@ pub struct SetProcessorFactory { enable_probabilistic_optimizations: bool, } +#[derive(Debug)] +pub struct DedupProcessorFactory { + id: String, + enable_probabilistic_optimizations: bool, +} + impl SetProcessorFactory { /// Creates a new [`FromProcessorFactory`]. pub fn new( @@ -38,6 +44,15 @@ impl SetProcessorFactory { } } +impl DedupProcessorFactory { + pub fn new(id: String, enable_probabilistic_optimizations: bool) -> Self { + Self { + id, + enable_probabilistic_optimizations, + } + } +} + #[async_trait] impl ProcessorFactory for SetProcessorFactory { fn id(&self) -> String { @@ -87,6 +102,52 @@ impl ProcessorFactory for SetProcessorFactory { } } +#[async_trait] +impl ProcessorFactory for DedupProcessorFactory { + fn id(&self) -> String { + self.id.clone() + } + + fn type_name(&self) -> String { + "Dedup".to_string() + } + + fn get_input_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + fn get_output_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + async fn get_output_schema( + &self, + _output_port: &PortHandle, + input_schemas: &HashMap, + ) -> Result { + input_schemas + .get(&DEFAULT_PORT_HANDLE) + .cloned() + .ok_or_else(|| PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE).into()) + } + + async fn build( + &self, + _input_schemas: HashMap, + _output_schemas: HashMap, + _event_hub: EventHub, + ) -> Result, BoxedError> { + Ok(Box::new(SetProcessor::new( + self.id.clone(), + SetOperation { + op: SetOperator::Union, + quantifier: SetQuantifier::None, + }, + self.enable_probabilistic_optimizations, + )?)) + } +} + fn validate_set_operation_input_schemas( input_schemas: &HashMap, ) -> Result, PipelineError> { diff --git a/dozer-sql/src/tests/builder_test.rs b/dozer-sql/src/tests/builder_test.rs index 4d2b398255..540f6eace0 100644 --- a/dozer-sql/src/tests/builder_test.rs +++ b/dozer-sql/src/tests/builder_test.rs @@ -21,6 +21,7 @@ use tokio::sync::mpsc::Sender; use std::collections::HashMap; use std::future::pending; +use std::sync::{Arc, Mutex}; use crate::builder::statement_to_pipeline; use crate::tests::utils::create_test_runtime; @@ -219,6 +220,369 @@ impl Sink for TestSink { } } +#[derive(Debug)] +pub struct ScriptedSourceFactory { + output_ports: Vec, + operations: Vec<(PortHandle, Operation)>, +} + +impl ScriptedSourceFactory { + pub fn new(output_ports: Vec, operations: Vec<(PortHandle, Operation)>) -> Self { + Self { + output_ports, + operations, + } + } +} + +impl SourceFactory for ScriptedSourceFactory { + fn get_output_ports(&self) -> Vec { + self.output_ports + .iter() + .map(|e| OutputPortDef::new(*e, OutputPortType::Stateless)) + .collect() + } + + fn get_output_schema(&self, port: &PortHandle) -> Result { + let table_name = if *port == 1 { "allowed" } else { "users" }; + Ok(Schema::default() + .field( + FieldDefinition::new( + String::from("CustomerID"), + FieldType::Int, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("Country"), + FieldType::String, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("Spending"), + FieldType::Float, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("timestamp"), + FieldType::Timestamp, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .clone()) + } + + fn get_output_port_name(&self, port: &PortHandle) -> String { + format!("port_{}", port) + } + + fn build( + &self, + _output_schemas: HashMap, + _event_hub: EventHub, + _state: Option>, + ) -> Result, BoxedError> { + Ok(Box::new(ScriptedSource { + operations: self.operations.clone(), + })) + } +} + +#[derive(Debug)] +pub struct ScriptedSource { + operations: Vec<(PortHandle, Operation)>, +} + +#[async_trait] +impl Source for ScriptedSource { + async fn serialize_state(&self) -> Result, BoxedError> { + Ok(vec![]) + } + + async fn start( + &mut self, + sender: Sender<(PortHandle, IngestionMessage)>, + _last_checkpoint: Option, + ) -> Result<(), BoxedError> { + for (port, op) in self.operations.clone() { + sender + .send(( + port, + IngestionMessage::OperationEvent { + table_index: port as usize, + op, + id: None, + }, + )) + .await + .unwrap(); + } + Ok(()) + } +} + +#[derive(Debug)] +pub struct CollectingSinkFactory { + operations: Arc>>, +} + +impl CollectingSinkFactory { + pub fn new(operations: Arc>>) -> Self { + Self { operations } + } +} + +#[async_trait] +impl SinkFactory for CollectingSinkFactory { + fn get_input_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + fn get_input_port_name(&self, _port: &PortHandle) -> String { + "test".to_string() + } + + async fn build( + &self, + _input_schemas: HashMap, + _event_hub: EventHub, + ) -> Result, BoxedError> { + Ok(Box::new(CollectingSink { + operations: self.operations.clone(), + })) + } + + fn prepare(&self, _input_schemas: HashMap) -> Result<(), BoxedError> { + Ok(()) + } + + fn type_name(&self) -> String { + "test".to_string() + } +} + +#[derive(Debug)] +pub struct CollectingSink { + operations: Arc>>, +} + +impl Sink for CollectingSink { + fn process(&mut self, op: TableOperation) -> Result<(), BoxedError> { + self.operations.lock().unwrap().push(op); + Ok(()) + } + + fn commit(&mut self, _epoch_details: &Epoch) -> Result<(), BoxedError> { + Ok(()) + } + + fn on_source_snapshotting_started( + &mut self, + _connection_name: String, + ) -> Result<(), BoxedError> { + Ok(()) + } + + fn on_source_snapshotting_done( + &mut self, + _connection_name: String, + _id: Option, + ) -> Result<(), BoxedError> { + Ok(()) + } + + fn set_source_state(&mut self, _source_state: &[u8]) -> Result<(), BoxedError> { + Ok(()) + } + + fn get_source_state(&mut self) -> Result>, BoxedError> { + Ok(None) + } + + fn get_latest_op_id(&mut self) -> Result, BoxedError> { + Ok(None) + } +} + +fn scripted_record(customer_id: i64, country: &str, spending: f64, timestamp: &str) -> Record { + Record::new(vec![ + Field::Int(customer_id), + Field::String(country.to_string()), + Field::Float(OrderedFloat(spending)), + Field::Timestamp(DateTime::parse_from_rfc3339(timestamp).unwrap()), + ]) +} + +fn scripted_insert(record: Record) -> Operation { + Operation::Insert { new: record } +} + +fn scripted_update(old: Record, new: Record) -> Operation { + Operation::Update { old, new } +} + +fn scripted_delete(record: Record) -> Operation { + Operation::Delete { old: record } +} + +fn execute_scripted_query_operations( + sql: &str, + operations: Vec<(PortHandle, Operation)>, +) -> Vec { + let mut pipeline = AppPipeline::new_with_default_flags(); + let runtime = create_test_runtime(); + let context = statement_to_pipeline(sql, &mut pipeline, None, vec![], runtime.clone()).unwrap(); + + let table_info = context.output_tables_map.get("results").unwrap(); + let output_operations = Arc::new(Mutex::new(vec![])); + + let mut asm = AppSourceManager::new(); + asm.add( + Box::new(ScriptedSourceFactory::new( + vec![DEFAULT_PORT_HANDLE, 1], + operations, + )), + AppSourceMappings::new( + "mem".to_string(), + vec![ + ("users".to_string(), DEFAULT_PORT_HANDLE), + ("allowed".to_string(), 1), + ] + .into_iter() + .collect(), + ), + ) + .unwrap(); + + pipeline.add_sink( + Box::new(CollectingSinkFactory::new(output_operations.clone())), + "sink".to_string(), + ); + pipeline.connect_nodes( + table_info.node.clone(), + table_info.port, + "sink".to_string(), + DEFAULT_PORT_HANDLE, + ); + + let mut app = App::new(asm); + app.add_pipeline(pipeline); + + let dag = app.into_dag().unwrap(); + let runtime_clone = runtime.clone(); + let handle = runtime.block_on(async move { + DagExecutor::new(dag, Default::default()) + .await + .unwrap() + .start(pending::<()>(), Default::default(), runtime_clone) + .await + .unwrap() + }); + handle.join().unwrap(); + + output_operations + .lock() + .unwrap() + .iter() + .map(|op| op.op.clone()) + .collect() +} + +fn execute_scripted_query(sql: &str, operations: Vec<(PortHandle, Operation)>) -> Vec> { + execute_scripted_query_operations(sql, operations) + .into_iter() + .filter_map(|op| match op { + Operation::Insert { new } => Some(new.values), + _ => None, + }) + .collect() +} + +#[test] +fn test_static_in_list_applies_streaming_update_delete_semantics() { + let operations = execute_scripted_query_operations( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (7, 9)", + vec![ + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(8, "France", 7.0, "2020-01-01T00:10:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_update( + scripted_record(8, "France", 7.0, "2020-01-01T00:10:00Z"), + scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z"), + ), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_update( + scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z"), + scripted_record(9, "Spain", 9.0, "2020-01-01T00:14:00Z"), + ), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_update( + scripted_record(9, "Spain", 9.0, "2020-01-01T00:14:00Z"), + scripted_record(8, "France", 7.0, "2020-01-01T00:15:00Z"), + ), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_delete(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ], + ); + + assert_eq!( + operations, + vec![ + Operation::Insert { + new: Record::new(vec![Field::Int(7)]), + }, + Operation::Update { + old: Record::new(vec![Field::Int(7)]), + new: Record::new(vec![Field::Int(9)]), + }, + Operation::Delete { + old: Record::new(vec![Field::Int(9)]), + }, + Operation::Delete { + old: Record::new(vec![Field::Int(7)]), + }, + ] + ); +} + #[test] fn test_pipeline_builder() { let mut pipeline = AppPipeline::new_with_default_flags(); @@ -280,3 +644,235 @@ fn test_pipeline_builder() { let elapsed = now.elapsed(); debug!("Elapsed: {:.2?}", elapsed); } + +#[test] +fn test_in_subquery_where_clause_builds_pipeline() { + let mut pipeline = AppPipeline::new_with_default_flags(); + let runtime = create_test_runtime(); + let context = statement_to_pipeline( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + &mut pipeline, + None, + vec![], + runtime, + ) + .unwrap(); + + assert!(context.output_tables_map.contains_key("results")); + assert!(context.used_sources.contains(&"users".to_string())); + assert!(context.used_sources.contains(&"allowed".to_string())); +} + +#[test] +fn test_in_subquery_keeps_additional_where_predicates() { + let mut pipeline = AppPipeline::new_with_default_flags(); + let runtime = create_test_runtime(); + let context = statement_to_pipeline( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.Spending > 10 \ + AND users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + &mut pipeline, + None, + vec![], + runtime, + ) + .unwrap(); + + assert!(context.output_tables_map.contains_key("results")); + assert!(context.used_sources.contains(&"users".to_string())); + assert!(context.used_sources.contains(&"allowed".to_string())); +} + +#[test] +fn test_in_subquery_rejects_multi_column_projection() { + let mut pipeline = AppPipeline::new_with_default_flags(); + let runtime = create_test_runtime(); + let result = statement_to_pipeline( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID, allowed.Country FROM allowed)", + &mut pipeline, + None, + vec![], + runtime, + ); + + assert!(result.is_err()); +} + +#[test] +fn test_in_subquery_filters_stream_with_inner_select_membership() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(8, "France", 7.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_emits_when_inner_membership_arrives_later() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_qualifies_unqualified_outer_identifier() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_retains_remaining_where_predicates() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.Spending > 6 \ + AND users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(8, "Allowed", 0.0, "2020-01-01T00:00:01Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(8, "France", 7.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(8)]]); +} + +#[test] +fn test_in_subquery_uses_membership_semantics_for_duplicate_inner_rows() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:01Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_keeps_membership_until_last_duplicate_inner_row_is_deleted() { + let operations = execute_scripted_query_operations( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + 1, + scripted_delete(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_delete(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ], + ); + + assert_eq!( + operations, + vec![ + Operation::Insert { + new: Record::new(vec![Field::Int(7)]), + }, + Operation::Delete { + old: Record::new(vec![Field::Int(7)]), + }, + ] + ); +} diff --git a/dozer-tests/src/sql_tests/full/in_clause.test b/dozer-tests/src/sql_tests/full/in_clause.test new file mode 100644 index 0000000000..fb333c42ce --- /dev/null +++ b/dozer-tests/src/sql_tests/full/in_clause.test @@ -0,0 +1,91 @@ +control sortmode rowsort + + +statement ok +CREATE TABLE actor( + actor_id integer NOT NULL, + first_name text NOT NULL +) + +statement ok +CREATE TABLE allowed_actor( + actor_id integer NOT NULL +) + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (1, 'penelope'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (2, 'jack'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (3, 'angelina'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (4, 'tom'); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (1); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (3); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (3); + +query I +select actor_id from actor where actor_id in (1, 3) +---- +1 +3 + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) +---- +1 penelope +3 angelina + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (2); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) and first_name != 'jack' +---- +1 penelope +3 angelina + +statement ok +DELETE FROM allowed_actor WHERE rowid = (SELECT min(rowid) FROM allowed_actor WHERE actor_id = 3); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) +---- +1 penelope +2 jack +3 angelina + +statement ok +DELETE FROM allowed_actor WHERE actor_id = 3; + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) +---- +1 penelope +2 jack + +statement ok +UPDATE actor SET actor_id = 30 WHERE actor_id = 2; + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) +---- +1 penelope + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (30); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) +---- +1 penelope +30 jack diff --git a/dozer-tests/src/sql_tests/prototype/in_clause.test b/dozer-tests/src/sql_tests/prototype/in_clause.test new file mode 100644 index 0000000000..1acdf92527 --- /dev/null +++ b/dozer-tests/src/sql_tests/prototype/in_clause.test @@ -0,0 +1,70 @@ +control sortmode rowsort + + +statement ok +CREATE TABLE actor( + actor_id integer NOT NULL, + first_name text NOT NULL +) + +statement ok +CREATE TABLE allowed_actor( + actor_id integer NOT NULL +) + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (1, 'penelope'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (2, 'jack'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (3, 'angelina'); + +statement ok +INSERT INTO actor(actor_id, first_name) VALUES (4, 'tom'); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (1); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (3); + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (3); + +query I +select actor_id from actor where actor_id in (1, 3) + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (2); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) and first_name != 'jack' + +statement ok +DELETE FROM allowed_actor WHERE rowid = (SELECT min(rowid) FROM allowed_actor WHERE actor_id = 3); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) + +statement ok +DELETE FROM allowed_actor WHERE actor_id = 3; + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) + +statement ok +UPDATE actor SET actor_id = 30 WHERE actor_id = 2; + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor) + +statement ok +INSERT INTO allowed_actor(actor_id) VALUES (30); + +query IT +select actor_id, first_name from actor where actor_id in (select actor_id from allowed_actor)