diff --git a/dozer-sql/src/builder/mod.rs b/dozer-sql/src/builder/mod.rs index 3ccf3ed356..d3ff485e85 100644 --- a/dozer-sql/src/builder/mod.rs +++ b/dozer-sql/src/builder/mod.rs @@ -6,7 +6,10 @@ use dozer_core::app::AppPipeline; use dozer_core::node::PortHandle; use dozer_core::DEFAULT_PORT_HANDLE; use dozer_sql_expression::builder::{ExpressionBuilder, NameOrAlias}; -use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier, TableFactor}; +use dozer_sql_expression::sqlparser::ast::{ + BinaryOperator, Expr as SqlExpr, Ident, Join, JoinConstraint, JoinOperator, SelectItem, + SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins, +}; use dozer_types::models::udf_config::UdfConfig; use dozer_sql_expression::sqlparser::{ @@ -21,7 +24,9 @@ use tokio::runtime::Runtime; use super::errors::UnsupportedSqlError; -use super::product::set::set_factory::SetProcessorFactory; +use super::product::set::set_factory::{DedupProcessorFactory, SetProcessorFactory}; + +const IN_SUBQUERY_ALIAS_PREFIX: &str = "__dozer_in_subquery_"; #[derive(Debug, Clone)] pub struct OutputNodeInfo { @@ -236,19 +241,23 @@ fn query_to_pipeline( fn select_to_pipeline( table_info: TableInfo, - select: Select, + mut select: Select, pipeline: &mut AppPipeline, query_ctx: &mut QueryContext, pipeline_idx: usize, is_top_select: bool, ) -> Result { // FROM clause - let Some(from) = select.from.into_iter().next() else { + let Some(mut from) = select.from.into_iter().next() else { return Err(PipelineError::UnsupportedSqlError( UnsupportedSqlError::FromCommaSyntax, )); }; + if let Some(selection) = select.selection.take() { + select.selection = rewrite_in_subquery_selection_as_join(selection, &mut from, query_ctx)?; + } + let connection_info = from::insert_from_to_pipeline(from, pipeline, pipeline_idx, query_ctx)?; let input_nodes = connection_info.input_nodes; @@ -360,6 +369,117 @@ fn select_to_pipeline( Ok(gen_agg_name) } +fn rewrite_in_subquery_selection_as_join( + selection: SqlExpr, + from: &mut TableWithJoins, + query_ctx: &mut QueryContext, +) -> Result, PipelineError> { + match selection { + SqlExpr::BinaryOp { + left, + op: BinaryOperator::And, + right, + } => { + let left = rewrite_in_subquery_selection_as_join(*left, from, query_ctx)?; + let right = rewrite_in_subquery_selection_as_join(*right, from, query_ctx)?; + + match (left, right) { + (Some(left), Some(right)) => Ok(Some(SqlExpr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::And, + right: Box::new(right), + })), + (Some(selection), None) | (None, Some(selection)) => Ok(Some(selection)), + (None, None) => Ok(None), + } + } + SqlExpr::InSubquery { + expr, + subquery, + negated: false, + } => { + let subquery_field = prepare_in_subquery_for_membership(&subquery)?; + let expr = qualify_unqualified_outer_identifier(expr, from)?; + let alias = format!( + "{}{}", + IN_SUBQUERY_ALIAS_PREFIX, + query_ctx.get_next_processor_id() + ); + let join_constraint = SqlExpr::BinaryOp { + left: expr, + op: BinaryOperator::Eq, + right: Box::new(SqlExpr::CompoundIdentifier(vec![ + Ident::new(alias.clone()), + subquery_field, + ])), + }; + + from.joins.push(Join { + relation: TableFactor::Derived { + lateral: false, + subquery, + alias: Some(TableAlias { + name: Ident::new(alias), + columns: vec![], + }), + }, + join_operator: JoinOperator::Inner(JoinConstraint::On(join_constraint)), + }); + + Ok(None) + } + other => Ok(Some(other)), + } +} + +fn qualify_unqualified_outer_identifier( + expr: Box, + from: &TableWithJoins, +) -> Result, PipelineError> { + let SqlExpr::Identifier(ident) = *expr else { + return Ok(expr); + }; + + let source_name_or_alias = common::get_name_or_alias(&from.relation)?; + let qualifier = source_name_or_alias.1.unwrap_or(source_name_or_alias.0); + + Ok(Box::new(SqlExpr::CompoundIdentifier(vec![ + Ident::new(qualifier), + ident, + ]))) +} + +fn prepare_in_subquery_for_membership(query: &Query) -> Result { + let select = match query.body.as_ref() { + SetExpr::Select(select) => select, + SetExpr::Query(query) => return prepare_in_subquery_for_membership(query), + _ => { + return Err(PipelineError::InvalidQuery( + "IN subquery must project exactly one column".to_string(), + )) + } + }; + + let [projection] = select.projection.as_slice() else { + return Err(PipelineError::InvalidQuery( + "IN subquery must project exactly one column".to_string(), + )); + }; + + match projection { + SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => Ok(ident.clone()), + SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(idents)) => { + idents.last().cloned().ok_or_else(|| { + PipelineError::InvalidQuery("IN subquery projection must name a column".to_string()) + }) + } + SelectItem::ExprWithAlias { alias, .. } => Ok(alias.clone()), + _ => Err(PipelineError::InvalidQuery( + "IN subquery projection must name a column".to_string(), + )), + } +} + #[allow(clippy::too_many_arguments)] fn set_to_pipeline( table_info: TableInfo, @@ -530,6 +650,10 @@ fn get_from_source( let alias_name = alias.as_ref().map(|alias_ident| { ExpressionBuilder::fullname_from_ident(&[alias_ident.name.clone()]) }); + let should_dedup_output = alias_name + .as_ref() + .map(|name| name.starts_with(IN_SUBQUERY_ALIAS_PREFIX)) + .unwrap_or(false); let is_top_select = false; //inside FROM clause, so not top select let name_or = NameOrAlias(name, alias_name); query_to_pipeline( @@ -543,6 +667,9 @@ fn get_from_source( pipeline_idx, is_top_select, )?; + if should_dedup_output { + insert_dedup_processor_to_pipeline(&name_or.0, pipeline, query_ctx, pipeline_idx)?; + } Ok(name_or) } @@ -552,6 +679,57 @@ fn get_from_source( } } +fn insert_dedup_processor_to_pipeline( + table_name: &str, + pipeline: &mut AppPipeline, + query_ctx: &mut QueryContext, + pipeline_idx: usize, +) -> Result<(), PipelineError> { + let output_node = query_ctx + .pipeline_map + .get(&(pipeline_idx, table_name.to_string())) + .cloned() + .ok_or_else(|| { + PipelineError::InvalidQuery(format!( + "Unable to deduplicate IN subquery source {table_name}" + )) + })?; + + let dedup_processor_name = format!("dedup--{}", query_ctx.get_next_processor_id()); + if !query_ctx + .processors_list + .insert(dedup_processor_name.clone()) + { + return Err(PipelineError::ProcessorAlreadyExists(dedup_processor_name)); + } + + let dedup_processor = DedupProcessorFactory::new( + dedup_processor_name.clone(), + pipeline + .flags() + .enable_probabilistic_optimizations + .in_sets + .unwrap_or(false), + ); + pipeline.add_processor(Box::new(dedup_processor), dedup_processor_name.clone()); + pipeline.connect_nodes( + output_node.node, + output_node.port, + dedup_processor_name.clone(), + DEFAULT_PORT_HANDLE, + ); + + query_ctx.pipeline_map.insert( + (pipeline_idx, table_name.to_string()), + OutputNodeInfo { + node: dedup_processor_name, + port: DEFAULT_PORT_HANDLE, + }, + ); + + Ok(()) +} + #[derive(Clone, Debug)] struct ConnectionInfo { input_nodes: Vec<(String, String, PortHandle)>, diff --git a/dozer-sql/src/product/set/set_factory.rs b/dozer-sql/src/product/set/set_factory.rs index 6f4fd442f4..285f650568 100644 --- a/dozer-sql/src/product/set/set_factory.rs +++ b/dozer-sql/src/product/set/set_factory.rs @@ -23,6 +23,12 @@ pub struct SetProcessorFactory { enable_probabilistic_optimizations: bool, } +#[derive(Debug)] +pub struct DedupProcessorFactory { + id: String, + enable_probabilistic_optimizations: bool, +} + impl SetProcessorFactory { /// Creates a new [`FromProcessorFactory`]. pub fn new( @@ -38,6 +44,15 @@ impl SetProcessorFactory { } } +impl DedupProcessorFactory { + pub fn new(id: String, enable_probabilistic_optimizations: bool) -> Self { + Self { + id, + enable_probabilistic_optimizations, + } + } +} + #[async_trait] impl ProcessorFactory for SetProcessorFactory { fn id(&self) -> String { @@ -87,6 +102,52 @@ impl ProcessorFactory for SetProcessorFactory { } } +#[async_trait] +impl ProcessorFactory for DedupProcessorFactory { + fn id(&self) -> String { + self.id.clone() + } + + fn type_name(&self) -> String { + "Dedup".to_string() + } + + fn get_input_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + fn get_output_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + async fn get_output_schema( + &self, + _output_port: &PortHandle, + input_schemas: &HashMap, + ) -> Result { + input_schemas + .get(&DEFAULT_PORT_HANDLE) + .cloned() + .ok_or_else(|| PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE).into()) + } + + async fn build( + &self, + _input_schemas: HashMap, + _output_schemas: HashMap, + _event_hub: EventHub, + ) -> Result, BoxedError> { + Ok(Box::new(SetProcessor::new( + self.id.clone(), + SetOperation { + op: SetOperator::Union, + quantifier: SetQuantifier::None, + }, + self.enable_probabilistic_optimizations, + )?)) + } +} + fn validate_set_operation_input_schemas( input_schemas: &HashMap, ) -> Result, PipelineError> { diff --git a/dozer-sql/src/tests/builder_test.rs b/dozer-sql/src/tests/builder_test.rs index 4d2b398255..2ec5416aba 100644 --- a/dozer-sql/src/tests/builder_test.rs +++ b/dozer-sql/src/tests/builder_test.rs @@ -21,6 +21,7 @@ use tokio::sync::mpsc::Sender; use std::collections::HashMap; use std::future::pending; +use std::sync::{Arc, Mutex}; use crate::builder::statement_to_pipeline; use crate::tests::utils::create_test_runtime; @@ -219,6 +220,293 @@ impl Sink for TestSink { } } +#[derive(Debug)] +pub struct ScriptedSourceFactory { + output_ports: Vec, + operations: Vec<(PortHandle, Operation)>, +} + +impl ScriptedSourceFactory { + pub fn new(output_ports: Vec, operations: Vec<(PortHandle, Operation)>) -> Self { + Self { + output_ports, + operations, + } + } +} + +impl SourceFactory for ScriptedSourceFactory { + fn get_output_ports(&self) -> Vec { + self.output_ports + .iter() + .map(|e| OutputPortDef::new(*e, OutputPortType::Stateless)) + .collect() + } + + fn get_output_schema(&self, port: &PortHandle) -> Result { + let table_name = if *port == 1 { "allowed" } else { "users" }; + Ok(Schema::default() + .field( + FieldDefinition::new( + String::from("CustomerID"), + FieldType::Int, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("Country"), + FieldType::String, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("Spending"), + FieldType::Float, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .field( + FieldDefinition::new( + String::from("timestamp"), + FieldType::Timestamp, + false, + SourceDefinition::Table { + connection: "mem".to_string(), + name: table_name.to_string(), + }, + ), + false, + ) + .clone()) + } + + fn get_output_port_name(&self, port: &PortHandle) -> String { + format!("port_{}", port) + } + + fn build( + &self, + _output_schemas: HashMap, + _event_hub: EventHub, + _state: Option>, + ) -> Result, BoxedError> { + Ok(Box::new(ScriptedSource { + operations: self.operations.clone(), + })) + } +} + +#[derive(Debug)] +pub struct ScriptedSource { + operations: Vec<(PortHandle, Operation)>, +} + +#[async_trait] +impl Source for ScriptedSource { + async fn serialize_state(&self) -> Result, BoxedError> { + Ok(vec![]) + } + + async fn start( + &mut self, + sender: Sender<(PortHandle, IngestionMessage)>, + _last_checkpoint: Option, + ) -> Result<(), BoxedError> { + for (port, op) in self.operations.clone() { + sender + .send(( + port, + IngestionMessage::OperationEvent { + table_index: port as usize, + op, + id: None, + }, + )) + .await + .unwrap(); + } + Ok(()) + } +} + +#[derive(Debug)] +pub struct CollectingSinkFactory { + operations: Arc>>, +} + +impl CollectingSinkFactory { + pub fn new(operations: Arc>>) -> Self { + Self { operations } + } +} + +#[async_trait] +impl SinkFactory for CollectingSinkFactory { + fn get_input_ports(&self) -> Vec { + vec![DEFAULT_PORT_HANDLE] + } + + fn get_input_port_name(&self, _port: &PortHandle) -> String { + "test".to_string() + } + + async fn build( + &self, + _input_schemas: HashMap, + _event_hub: EventHub, + ) -> Result, BoxedError> { + Ok(Box::new(CollectingSink { + operations: self.operations.clone(), + })) + } + + fn prepare(&self, _input_schemas: HashMap) -> Result<(), BoxedError> { + Ok(()) + } + + fn type_name(&self) -> String { + "test".to_string() + } +} + +#[derive(Debug)] +pub struct CollectingSink { + operations: Arc>>, +} + +impl Sink for CollectingSink { + fn process(&mut self, op: TableOperation) -> Result<(), BoxedError> { + self.operations.lock().unwrap().push(op); + Ok(()) + } + + fn commit(&mut self, _epoch_details: &Epoch) -> Result<(), BoxedError> { + Ok(()) + } + + fn on_source_snapshotting_started( + &mut self, + _connection_name: String, + ) -> Result<(), BoxedError> { + Ok(()) + } + + fn on_source_snapshotting_done( + &mut self, + _connection_name: String, + _id: Option, + ) -> Result<(), BoxedError> { + Ok(()) + } + + fn set_source_state(&mut self, _source_state: &[u8]) -> Result<(), BoxedError> { + Ok(()) + } + + fn get_source_state(&mut self) -> Result>, BoxedError> { + Ok(None) + } + + fn get_latest_op_id(&mut self) -> Result, BoxedError> { + Ok(None) + } +} + +fn scripted_record(customer_id: i64, country: &str, spending: f64, timestamp: &str) -> Record { + Record::new(vec![ + Field::Int(customer_id), + Field::String(country.to_string()), + Field::Float(OrderedFloat(spending)), + Field::Timestamp(DateTime::parse_from_rfc3339(timestamp).unwrap()), + ]) +} + +fn scripted_insert(record: Record) -> Operation { + Operation::Insert { new: record } +} + +fn execute_scripted_query(sql: &str, operations: Vec<(PortHandle, Operation)>) -> Vec> { + let mut pipeline = AppPipeline::new_with_default_flags(); + let runtime = create_test_runtime(); + let context = statement_to_pipeline(sql, &mut pipeline, None, vec![], runtime.clone()).unwrap(); + + let table_info = context.output_tables_map.get("results").unwrap(); + let output_operations = Arc::new(Mutex::new(vec![])); + + let mut asm = AppSourceManager::new(); + asm.add( + Box::new(ScriptedSourceFactory::new( + vec![DEFAULT_PORT_HANDLE, 1], + operations, + )), + AppSourceMappings::new( + "mem".to_string(), + vec![ + ("users".to_string(), DEFAULT_PORT_HANDLE), + ("allowed".to_string(), 1), + ] + .into_iter() + .collect(), + ), + ) + .unwrap(); + + pipeline.add_sink( + Box::new(CollectingSinkFactory::new(output_operations.clone())), + "sink".to_string(), + ); + pipeline.connect_nodes( + table_info.node.clone(), + table_info.port, + "sink".to_string(), + DEFAULT_PORT_HANDLE, + ); + + let mut app = App::new(asm); + app.add_pipeline(pipeline); + + let dag = app.into_dag().unwrap(); + let runtime_clone = runtime.clone(); + let handle = runtime.block_on(async move { + DagExecutor::new(dag, Default::default()) + .await + .unwrap() + .start(pending::<()>(), Default::default(), runtime_clone) + .await + .unwrap() + }); + handle.join().unwrap(); + + let rows = output_operations + .lock() + .unwrap() + .iter() + .filter_map(|op| match &op.op { + Operation::Insert { new } => Some(new.values.clone()), + _ => None, + }) + .collect::>(); + + rows +} + #[test] fn test_pipeline_builder() { let mut pipeline = AppPipeline::new_with_default_flags(); @@ -280,3 +568,135 @@ fn test_pipeline_builder() { let elapsed = now.elapsed(); debug!("Elapsed: {:.2?}", elapsed); } + +#[test] +fn test_in_subquery_filters_stream_with_inner_select_membership() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(8, "France", 7.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_emits_when_inner_membership_arrives_later() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_qualifies_unqualified_outer_identifier() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +} + +#[test] +fn test_in_subquery_keeps_additional_where_predicates() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.Spending > 6 \ + AND users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_insert(scripted_record(8, "Allowed", 0.0, "2020-01-01T00:00:01Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(8, "France", 7.0, "2020-01-01T00:14:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(8)]]); +} + +#[test] +fn test_in_subquery_uses_membership_semantics_for_duplicate_inner_rows() { + let rows = execute_scripted_query( + "SELECT users.CustomerID \ + INTO results \ + FROM users \ + WHERE users.CustomerID IN (SELECT allowed.CustomerID FROM allowed)", + vec![ + ( + 1, + scripted_insert(scripted_record(7, "Allowed", 0.0, "2020-01-01T00:00:00Z")), + ), + ( + 1, + scripted_insert(scripted_record( + 7, + "Allowed duplicate", + 1.0, + "2020-01-01T00:00:01Z", + )), + ), + ( + DEFAULT_PORT_HANDLE, + scripted_insert(scripted_record(7, "Italy", 5.5, "2020-01-01T00:13:00Z")), + ), + ], + ); + + assert_eq!(rows, vec![vec![Field::Int(7)]]); +}