Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 182 additions & 4 deletions dozer-sql/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use dozer_core::app::AppPipeline;
use dozer_core::node::PortHandle;
use dozer_core::DEFAULT_PORT_HANDLE;
use dozer_sql_expression::builder::{ExpressionBuilder, NameOrAlias};
use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier, TableFactor};
use dozer_sql_expression::sqlparser::ast::{
BinaryOperator, Expr as SqlExpr, Ident, Join, JoinConstraint, JoinOperator, SelectItem,
SetOperator, SetQuantifier, TableAlias, TableFactor, TableWithJoins,
};
use dozer_types::models::udf_config::UdfConfig;

use dozer_sql_expression::sqlparser::{
Expand All @@ -21,7 +24,9 @@ use tokio::runtime::Runtime;

use super::errors::UnsupportedSqlError;

use super::product::set::set_factory::SetProcessorFactory;
use super::product::set::set_factory::{DedupProcessorFactory, SetProcessorFactory};

const IN_SUBQUERY_ALIAS_PREFIX: &str = "__dozer_in_subquery_";

#[derive(Debug, Clone)]
pub struct OutputNodeInfo {
Expand Down Expand Up @@ -236,19 +241,23 @@ fn query_to_pipeline(

fn select_to_pipeline(
table_info: TableInfo,
select: Select,
mut select: Select,
pipeline: &mut AppPipeline,
query_ctx: &mut QueryContext,
pipeline_idx: usize,
is_top_select: bool,
) -> Result<String, PipelineError> {
// FROM clause
let Some(from) = select.from.into_iter().next() else {
let Some(mut from) = select.from.into_iter().next() else {
return Err(PipelineError::UnsupportedSqlError(
UnsupportedSqlError::FromCommaSyntax,
));
};

if let Some(selection) = select.selection.take() {
select.selection = rewrite_in_subquery_selection_as_join(selection, &mut from, query_ctx)?;
}

let connection_info = from::insert_from_to_pipeline(from, pipeline, pipeline_idx, query_ctx)?;

let input_nodes = connection_info.input_nodes;
Expand Down Expand Up @@ -360,6 +369,117 @@ fn select_to_pipeline(
Ok(gen_agg_name)
}

/// Rewrites non-negated `expr IN (subquery)` predicates of a `WHERE` clause
/// into inner joins appended to `from`.
///
/// The selection is only descended through `AND` conjunctions and
/// parenthesised (`Nested`) groups, i.e. positions where every operand must
/// hold for a row to pass. Each `IN (subquery)` found there is removed from the
/// selection and replaced by
/// `INNER JOIN (subquery) AS <alias> ON expr = <alias>.<column>`, where
/// `<alias>` starts with `IN_SUBQUERY_ALIAS_PREFIX` followed by the next
/// processor id taken from `query_ctx`. Any other expression (including
/// `NOT IN` and anything below an `OR`) is returned unchanged.
///
/// # Returns
/// `Ok(Some(expr))` with the predicate that must still be evaluated as a
/// filter, or `Ok(None)` when the whole selection was turned into joins.
///
/// # Errors
/// Propagates the errors of `prepare_in_subquery_for_membership` and
/// `qualify_unqualified_outer_identifier`.
fn rewrite_in_subquery_selection_as_join(
    selection: SqlExpr,
    from: &mut TableWithJoins,
    query_ctx: &mut QueryContext,
) -> Result<Option<SqlExpr>, PipelineError> {
    match selection {
        SqlExpr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            // Left is rewritten before right so joins are appended in source order.
            let left = rewrite_in_subquery_selection_as_join(*left, from, query_ctx)?;
            let right = rewrite_in_subquery_selection_as_join(*right, from, query_ctx)?;

            Ok(match (left, right) {
                (Some(left), Some(right)) => Some(SqlExpr::BinaryOp {
                    left: Box::new(left),
                    op: BinaryOperator::And,
                    right: Box::new(right),
                }),
                // At most one side is left: the conjunction collapses to it.
                (left, None) => left,
                (None, right) => right,
            })
        }
        // Parentheses do not change the conjunctive position of their content,
        // so `(a IN (...))` and `(a IN (...) AND b)` are rewritten as well.
        // Whatever remains is wrapped again to keep the original grouping.
        SqlExpr::Nested(inner) => {
            let remaining = rewrite_in_subquery_selection_as_join(*inner, from, query_ctx)?;
            Ok(remaining.map(|expr| SqlExpr::Nested(Box::new(expr))))
        }
        SqlExpr::InSubquery {
            expr,
            subquery,
            negated: false,
        } => {
            let subquery_field = prepare_in_subquery_for_membership(&subquery)?;
            let expr = qualify_unqualified_outer_identifier(expr, from)?;
            let alias = format!(
                "{}{}",
                IN_SUBQUERY_ALIAS_PREFIX,
                query_ctx.get_next_processor_id()
            );
            // ON <outer expr> = <alias>.<subquery column>
            let join_constraint = SqlExpr::BinaryOp {
                left: expr,
                op: BinaryOperator::Eq,
                right: Box::new(SqlExpr::CompoundIdentifier(vec![
                    Ident::new(alias.clone()),
                    subquery_field,
                ])),
            };

            from.joins.push(Join {
                relation: TableFactor::Derived {
                    lateral: false,
                    subquery,
                    alias: Some(TableAlias {
                        name: Ident::new(alias),
                        columns: vec![],
                    }),
                },
                join_operator: JoinOperator::Inner(JoinConstraint::On(join_constraint)),
            });

            // The predicate is now fully expressed by the join.
            Ok(None)
        }
        other => Ok(Some(other)),
    }
}

/// Turns a bare column identifier into `<source>.<column>`, using the first
/// relation of `from` as the source.
///
/// Expressions that are not a plain `Identifier` are handed back untouched.
/// The qualifier is the second component returned by
/// `common::get_name_or_alias` when it is present and the first component
/// otherwise (presumably alias and name, in that order of preference).
///
/// NOTE(review): only `from.relation` is consulted, never `from.joins`, so an
/// unqualified column that belongs to a joined table would be qualified with
/// the first relation — confirm this is acceptable for the supported queries.
///
/// # Errors
/// Propagates the error of `common::get_name_or_alias`.
fn qualify_unqualified_outer_identifier(
    expr: Box<SqlExpr>,
    from: &TableWithJoins,
) -> Result<Box<SqlExpr>, PipelineError> {
    let column = match *expr {
        SqlExpr::Identifier(column) => column,
        _ => return Ok(expr),
    };

    let source = common::get_name_or_alias(&from.relation)?;
    let qualifier = match source.1 {
        Some(alias) => alias,
        None => source.0,
    };

    let qualified = SqlExpr::CompoundIdentifier(vec![Ident::new(qualifier), column]);
    Ok(Box::new(qualified))
}

/// Checks that `query` can be used as the right-hand side of a rewritten `IN`
/// predicate and returns the name of its single output column.
///
/// The query body must be a `SELECT` (a query body that merely wraps another
/// query is unwrapped recursively) with exactly one projection item. The
/// returned identifier is:
/// - the identifier itself for `SELECT col`,
/// - the last part for `SELECT a.b.col`,
/// - the alias for `SELECT <expr> AS alias`.
///
/// # Errors
/// Returns `PipelineError::InvalidQuery` when the body is not a `SELECT`, when
/// the projection does not contain exactly one item, or when that item has no
/// usable column name (unaliased expression, wildcard, ...).
fn prepare_in_subquery_for_membership(query: &Query) -> Result<Ident, PipelineError> {
    let select = match query.body.as_ref() {
        SetExpr::Select(select) => select,
        SetExpr::Query(query) => return prepare_in_subquery_for_membership(query),
        _ => {
            // Set operations, VALUES, etc. may well project a single column;
            // what is unsupported is the shape of the body, so say that.
            return Err(PipelineError::InvalidQuery(
                "IN subquery must be a plain SELECT statement".to_string(),
            ));
        }
    };

    let [projection] = select.projection.as_slice() else {
        return Err(PipelineError::InvalidQuery(
            "IN subquery must project exactly one column".to_string(),
        ));
    };

    match projection {
        SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => Ok(ident.clone()),
        SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(idents)) => {
            // The join constraint refers to the column through the generated
            // alias, so only the trailing column part is needed.
            idents.last().cloned().ok_or_else(|| {
                PipelineError::InvalidQuery("IN subquery projection must name a column".to_string())
            })
        }
        SelectItem::ExprWithAlias { alias, .. } => Ok(alias.clone()),
        _ => Err(PipelineError::InvalidQuery(
            "IN subquery projection must name a column".to_string(),
        )),
    }
}

#[allow(clippy::too_many_arguments)]
fn set_to_pipeline(
table_info: TableInfo,
Expand Down Expand Up @@ -530,6 +650,10 @@ fn get_from_source(
let alias_name = alias.as_ref().map(|alias_ident| {
ExpressionBuilder::fullname_from_ident(&[alias_ident.name.clone()])
});
let should_dedup_output = alias_name
.as_ref()
.map(|name| name.starts_with(IN_SUBQUERY_ALIAS_PREFIX))
.unwrap_or(false);
let is_top_select = false; //inside FROM clause, so not top select
let name_or = NameOrAlias(name, alias_name);
query_to_pipeline(
Expand All @@ -543,6 +667,9 @@ fn get_from_source(
pipeline_idx,
is_top_select,
)?;
if should_dedup_output {
insert_dedup_processor_to_pipeline(&name_or.0, pipeline, query_ctx, pipeline_idx)?;
}

Ok(name_or)
}
Expand All @@ -552,6 +679,57 @@ fn get_from_source(
}
}

/// Appends a dedup processor behind the node currently registered as the
/// output of `table_name` and registers the dedup processor as that table's
/// new output.
///
/// # Errors
/// - `PipelineError::InvalidQuery` when `query_ctx.pipeline_map` has no entry
///   for `(pipeline_idx, table_name)`.
/// - `PipelineError::ProcessorAlreadyExists` when the generated processor name
///   is already present in `query_ctx.processors_list`.
fn insert_dedup_processor_to_pipeline(
    table_name: &str,
    pipeline: &mut AppPipeline,
    query_ctx: &mut QueryContext,
    pipeline_idx: usize,
) -> Result<(), PipelineError> {
    // The same key is used to look up the current output and to replace it.
    let map_key = (pipeline_idx, table_name.to_string());
    let Some(upstream) = query_ctx.pipeline_map.get(&map_key).cloned() else {
        return Err(PipelineError::InvalidQuery(format!(
            "Unable to deduplicate IN subquery source {table_name}"
        )));
    };

    let processor_name = format!("dedup--{}", query_ctx.get_next_processor_id());
    let is_new_name = query_ctx.processors_list.insert(processor_name.clone());
    if !is_new_name {
        return Err(PipelineError::ProcessorAlreadyExists(processor_name));
    }

    let probabilistic = pipeline
        .flags()
        .enable_probabilistic_optimizations
        .in_sets
        .unwrap_or(false);
    let factory = DedupProcessorFactory::new(processor_name.clone(), probabilistic);
    pipeline.add_processor(Box::new(factory), processor_name.clone());

    // upstream output -> dedup input
    pipeline.connect_nodes(
        upstream.node,
        upstream.port,
        processor_name.clone(),
        DEFAULT_PORT_HANDLE,
    );

    // Consumers of `table_name` now read from the dedup processor.
    let dedup_output = OutputNodeInfo {
        node: processor_name,
        port: DEFAULT_PORT_HANDLE,
    };
    query_ctx.pipeline_map.insert(map_key, dedup_output);

    Ok(())
}

#[derive(Clone, Debug)]
struct ConnectionInfo {
input_nodes: Vec<(String, String, PortHandle)>,
Expand Down
61 changes: 61 additions & 0 deletions dozer-sql/src/product/set/set_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ pub struct SetProcessorFactory {
enable_probabilistic_optimizations: bool,
}

/// Factory for a single-input, single-output processor used to deduplicate
/// the rows of an upstream node.
#[derive(Debug)]
pub struct DedupProcessorFactory {
/// Identifier reported by `id()` and handed to the processor that is built.
id: String,
/// Forwarded unchanged to `SetProcessor::new` when the processor is built.
enable_probabilistic_optimizations: bool,
}

impl SetProcessorFactory {
/// Creates a new [`SetProcessorFactory`].
pub fn new(
Expand All @@ -38,6 +44,15 @@ impl SetProcessorFactory {
}
}

impl DedupProcessorFactory {
pub fn new(id: String, enable_probabilistic_optimizations: bool) -> Self {
Self {
id,
enable_probabilistic_optimizations,
}
}
}

#[async_trait]
impl ProcessorFactory for SetProcessorFactory {
fn id(&self) -> String {
Expand Down Expand Up @@ -87,6 +102,52 @@ impl ProcessorFactory for SetProcessorFactory {
}
}

#[async_trait]
impl ProcessorFactory for DedupProcessorFactory {
    /// Returns a copy of the id this factory was created with.
    fn id(&self) -> String {
        self.id.clone()
    }

    /// Returns the processor type name, `"Dedup"`.
    fn type_name(&self) -> String {
        String::from("Dedup")
    }

    /// The processor reads from the default port only.
    fn get_input_ports(&self) -> Vec<PortHandle> {
        vec![DEFAULT_PORT_HANDLE]
    }

    /// The processor writes to the default port only.
    fn get_output_ports(&self) -> Vec<PortHandle> {
        vec![DEFAULT_PORT_HANDLE]
    }

    /// The output schema is a clone of the schema on the default input port.
    ///
    /// # Errors
    /// Returns `PipelineError::InvalidPortHandle` when `input_schemas` has no
    /// entry for the default port.
    async fn get_output_schema(
        &self,
        _output_port: &PortHandle,
        input_schemas: &HashMap<PortHandle, Schema>,
    ) -> Result<Schema, BoxedError> {
        match input_schemas.get(&DEFAULT_PORT_HANDLE) {
            Some(schema) => Ok(schema.clone()),
            None => Err(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE).into()),
        }
    }

    /// Builds the processor as a `SetProcessor` configured with a `UNION`
    /// operation and no quantifier.
    ///
    /// NOTE(review): deduplication presumably relies on `SetProcessor`
    /// treating an unquantified `UNION` as distinct — confirm against
    /// `SetProcessor`.
    ///
    /// # Errors
    /// Propagates the error of `SetProcessor::new`.
    async fn build(
        &self,
        _input_schemas: HashMap<PortHandle, Schema>,
        _output_schemas: HashMap<PortHandle, Schema>,
        _event_hub: EventHub,
    ) -> Result<Box<dyn Processor>, BoxedError> {
        let operation = SetOperation {
            op: SetOperator::Union,
            quantifier: SetQuantifier::None,
        };
        let processor = SetProcessor::new(
            self.id.clone(),
            operation,
            self.enable_probabilistic_optimizations,
        )?;
        Ok(Box::new(processor))
    }
}

fn validate_set_operation_input_schemas(
input_schemas: &HashMap<PortHandle, Schema>,
) -> Result<Vec<FieldDefinition>, PipelineError> {
Expand Down
Loading