Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 153 additions & 6 deletions dozer-sql/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ use crate::aggregation::factory::AggregationProcessorFactory;
use crate::builder::PipelineError::InvalidQuery;
use crate::errors::PipelineError;
use crate::selection::factory::SelectionProcessorFactory;
use crate::selection::in_subquery::{
InSubqueryProcessorFactory, LEFT_IN_SUBQUERY_PORT, RIGHT_IN_SUBQUERY_PORT,
};
use dozer_core::app::AppPipeline;
use dozer_core::node::PortHandle;
use dozer_core::DEFAULT_PORT_HANDLE;
use dozer_sql_expression::builder::{ExpressionBuilder, NameOrAlias};
use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier, TableFactor};
use dozer_sql_expression::sqlparser::ast::{
BinaryOperator, Expr as SqlExpr, SelectItem, SetOperator, SetQuantifier, TableFactor,
};
use dozer_types::models::udf_config::UdfConfig;

use dozer_sql_expression::sqlparser::{
Expand Down Expand Up @@ -121,6 +126,11 @@ struct TableInfo {
override_name: Option<String>,
}

/// One `expr IN (subquery)` predicate pulled out of a selection expression.
///
/// Built from the `expr` and `subquery` fields of `SqlExpr::InSubquery` so the
/// two halves can be wired into the pipeline separately.
struct InSubquerySelection {
// Expression on the left-hand side of `IN`; handed to the IN-subquery processor factory.
expr: Box<SqlExpr>,
// Query on the right-hand side of `IN`; compiled into its own sub-pipeline.
subquery: Box<Query>,
}

fn query_to_pipeline(
table_info: TableInfo,
query: Query,
Expand Down Expand Up @@ -236,12 +246,16 @@ fn query_to_pipeline(

fn select_to_pipeline(
table_info: TableInfo,
select: Select,
mut select: Select,
pipeline: &mut AppPipeline,
query_ctx: &mut QueryContext,
pipeline_idx: usize,
is_top_select: bool,
) -> Result<String, PipelineError> {
let (in_subquery_selections, residual_selection) =
split_in_subquery_selection(select.selection.take())?;
select.selection = residual_selection;

// FROM clause
let Some(from) = select.from.into_iter().next() else {
return Err(PipelineError::UnsupportedSqlError(
Expand All @@ -258,6 +272,8 @@ fn select_to_pipeline(

let gen_selection_name = format!("select--{}", query_ctx.get_next_processor_id());
let (gen_product_name, product_output_port) = output_node;
let mut upstream_node = gen_product_name;
let mut upstream_port = product_output_port;

for (source_name, processor_name, processor_port) in input_nodes {
if let Some(table_info) = query_ctx
Expand All @@ -276,6 +292,57 @@ fn select_to_pipeline(
}
}

for in_subquery_selection in in_subquery_selections {
let subquery_name = format!("in_subquery_{}", query_ctx.get_next_processor_id());
query_to_pipeline(
TableInfo {
name: NameOrAlias(subquery_name.clone(), None),
override_name: None,
},
*in_subquery_selection.subquery,
pipeline,
query_ctx,
pipeline_idx,
false,
)?;

let subquery_output = query_ctx
.pipeline_map
.get(&(pipeline_idx, subquery_name.clone()))
.cloned()
.ok_or_else(|| {
PipelineError::InvalidQuery(format!("Invalid IN subquery {subquery_name}"))
})?;

let in_subquery_node = format!("in_subquery--{}", query_ctx.get_next_processor_id());
if !query_ctx.processors_list.insert(in_subquery_node.clone()) {
return Err(PipelineError::ProcessorAlreadyExists(in_subquery_node));
}

let in_subquery = InSubqueryProcessorFactory::new(
in_subquery_node.clone(),
*in_subquery_selection.expr,
query_ctx.udfs.clone(),
query_ctx.runtime.clone(),
);
pipeline.add_processor(Box::new(in_subquery), in_subquery_node.clone());
pipeline.connect_nodes(
upstream_node,
upstream_port,
in_subquery_node.clone(),
LEFT_IN_SUBQUERY_PORT,
);
pipeline.connect_nodes(
subquery_output.node,
subquery_output.port,
in_subquery_node.clone(),
RIGHT_IN_SUBQUERY_PORT,
);

upstream_node = in_subquery_node;
upstream_port = DEFAULT_PORT_HANDLE;
}

let aggregation = AggregationProcessorFactory::new(
gen_agg_name.clone(),
select.projection,
Expand Down Expand Up @@ -304,8 +371,8 @@ fn select_to_pipeline(
pipeline.add_processor(Box::new(selection), gen_selection_name.clone());

pipeline.connect_nodes(
gen_product_name,
product_output_port,
upstream_node,
upstream_port,
gen_selection_name.clone(),
DEFAULT_PORT_HANDLE,
);
Expand All @@ -318,8 +385,8 @@ fn select_to_pipeline(
);
} else {
pipeline.connect_nodes(
gen_product_name,
product_output_port,
upstream_node,
upstream_port,
gen_agg_name.clone(),
DEFAULT_PORT_HANDLE,
);
Expand Down Expand Up @@ -360,6 +427,86 @@ fn select_to_pipeline(
Ok(gen_agg_name)
}

/// Separates `IN (subquery)` predicates from the rest of an optional selection.
///
/// Returns the extracted IN-subquery predicates together with whatever part of
/// the selection is left over. When there is no selection at all, the result
/// is an empty list and no residual expression.
///
/// # Errors
///
/// Propagates any error reported by `extract_in_subquery_selection`.
fn split_in_subquery_selection(
    selection: Option<SqlExpr>,
) -> Result<(Vec<InSubquerySelection>, Option<SqlExpr>), PipelineError> {
    let Some(predicate) = selection else {
        return Ok((vec![], None));
    };
    extract_in_subquery_selection(predicate)
}

/// Recursively pulls `IN (subquery)` predicates out of a selection expression.
///
/// Only `AND` conjunctions and parenthesised (nested) expressions are walked;
/// any other expression is returned untouched as the residual predicate.
///
/// Returns the extracted IN-subquery predicates (left operand first) and the
/// residual expression, if anything remains.
///
/// # Errors
///
/// * `PipelineError::InvalidQuery` for a negated (`NOT IN`) subquery.
/// * Whatever `validate_in_subquery_projection` reports for the subquery.
fn extract_in_subquery_selection(
    selection: SqlExpr,
) -> Result<(Vec<InSubquerySelection>, Option<SqlExpr>), PipelineError> {
    match selection {
        SqlExpr::InSubquery { negated: true, .. } => Err(PipelineError::InvalidQuery(
            "NOT IN subqueries are not supported".to_string(),
        )),
        SqlExpr::InSubquery {
            expr,
            subquery,
            negated: false,
        } => {
            validate_in_subquery_projection(&subquery)?;
            // The whole predicate is consumed, so nothing is left as residual.
            Ok((vec![InSubquerySelection { expr, subquery }], None))
        }
        SqlExpr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            // Both operands are processed left-to-right; the first error wins.
            let (mut extracted, lhs_residual) = extract_in_subquery_selection(*left)?;
            let (rhs_extracted, rhs_residual) = extract_in_subquery_selection(*right)?;
            extracted.extend(rhs_extracted);
            Ok((extracted, combine_residual_and(lhs_residual, rhs_residual)))
        }
        SqlExpr::Nested(inner) => {
            let (extracted, residual) = extract_in_subquery_selection(*inner)?;
            // Keep the parentheses around whatever survives inside them.
            let residual = residual.map(|remaining| SqlExpr::Nested(Box::new(remaining)));
            Ok((extracted, residual))
        }
        untouched => Ok((vec![], Some(untouched))),
    }
}

/// Joins two optional residual predicates with `AND`.
///
/// If only one side is present it is returned as-is; if neither is present the
/// result is `None`. When both are present they become the left and right
/// operands of a new `AND` binary expression.
fn combine_residual_and(left: Option<SqlExpr>, right: Option<SqlExpr>) -> Option<SqlExpr> {
    let Some(lhs) = left else {
        // Nothing on the left: the right side (possibly `None`) is the answer.
        return right;
    };
    let Some(rhs) = right else {
        return Some(lhs);
    };
    Some(SqlExpr::BinaryOp {
        left: Box::new(lhs),
        op: BinaryOperator::And,
        right: Box::new(rhs),
    })
}

/// Checks that the query of an `IN (subquery)` predicate projects a single,
/// non-wildcard item.
///
/// Only plain `SELECT` bodies are inspected, looking through nested query
/// bodies; every other kind of set expression is accepted without checks.
///
/// # Errors
///
/// Returns `PipelineError::InvalidQuery` when the `SELECT` projects a number
/// of items other than one, or when its single item is a (qualified) wildcard.
fn validate_in_subquery_projection(query: &Query) -> Result<(), PipelineError> {
    let select = match query.body.as_ref() {
        SetExpr::Select(select) => select,
        SetExpr::Query(inner) => return validate_in_subquery_projection(inner),
        _ => return Ok(()),
    };

    if select.projection.len() != 1 {
        return Err(PipelineError::InvalidQuery(
            "IN subquery must return exactly one column".to_string(),
        ));
    }

    let is_wildcard = matches!(
        select.projection.first(),
        Some(SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _))
    );
    if is_wildcard {
        return Err(PipelineError::InvalidQuery(
            "IN subquery must project a single expression or column".to_string(),
        ));
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn set_to_pipeline(
table_info: TableInfo,
Expand Down
94 changes: 94 additions & 0 deletions dozer-sql/src/builder/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
use super::statement_to_pipeline;
use crate::{errors::PipelineError, tests::utils::create_test_runtime};
use dozer_core::app::AppPipeline;

/// Test helper: runs `statement_to_pipeline` for `sql` against a fresh
/// default-flag pipeline, with a test runtime and no extra arguments set.
fn build_pipeline(sql: &str) -> Result<super::QueryContext, PipelineError> {
    let runtime = create_test_runtime();
    let mut pipeline = AppPipeline::new_with_default_flags();
    statement_to_pipeline(sql, &mut pipeline, None, vec![], runtime)
}

#[test]
#[should_panic]
fn disallow_zero_outgoing_ndes() {
Expand Down Expand Up @@ -145,6 +157,88 @@ fn test_correct_into_clause() {
assert!(result.is_ok());
}

/// A `WHERE` clause mixing an ordinary predicate with an IN-subquery must
/// still build, registering the output table and both source tables.
#[test]
fn test_in_subquery_where_clause_keeps_residual_predicates() {
    let sql = r#"
SELECT users.CustomerID
INTO matched_customers
FROM users
WHERE users.Spending > 10
AND users.CustomerID IN (
SELECT allowed.CustomerID FROM allowed
)
"#;

    let context = build_pipeline(sql).unwrap();

    assert!(context.output_tables_map.contains_key("matched_customers"));
    for source in ["users", "allowed"] {
        assert!(context.used_sources.contains(&source.to_string()));
    }
}

/// Two IN-subqueries joined by `AND` must both be wired in: the output table
/// and all three source tables have to be registered.
#[test]
fn test_multiple_in_subqueries_can_be_chained() {
    let sql = r#"
SELECT users.CustomerID
INTO matched_customers
FROM users
WHERE users.CustomerID IN (
SELECT allowed.CustomerID FROM allowed
)
AND users.Country IN (
SELECT allowed_countries.Country FROM allowed_countries
)
"#;

    let context = build_pipeline(sql).unwrap();

    assert!(context.output_tables_map.contains_key("matched_customers"));
    for source in ["users", "allowed", "allowed_countries"] {
        assert!(context.used_sources.contains(&source.to_string()));
    }
}

/// `NOT IN (subquery)` is unsupported and must surface as an `InvalidQuery`
/// error mentioning "NOT IN subqueries".
#[test]
fn test_not_in_subquery_is_rejected() {
    let sql = r#"
SELECT users.CustomerID
INTO matched_customers
FROM users
WHERE users.CustomerID NOT IN (
SELECT blocked.CustomerID FROM blocked
)
"#;

    let rejected = match build_pipeline(sql) {
        Err(PipelineError::InvalidQuery(message)) => message.contains("NOT IN subqueries"),
        _ => false,
    };

    assert!(rejected);
}

/// An IN-subquery projecting more than one column must be rejected with an
/// `InvalidQuery` error mentioning "exactly one column".
#[test]
fn test_in_subquery_rejects_multi_column_projection() {
    let sql = r#"
SELECT users.CustomerID
INTO matched_customers
FROM users
WHERE users.CustomerID IN (
SELECT allowed.CustomerID, allowed.Country FROM allowed
)
"#;

    let rejected = match build_pipeline(sql) {
        Err(PipelineError::InvalidQuery(message)) => message.contains("exactly one column"),
        _ => false,
    };

    assert!(rejected);
}

#[test]
fn test_missing_into_in_nested_from_clause() {
let sql = r#"SELECT a FROM (SELECT a from b)"#;
Expand Down
Loading