Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule

Expand Down Expand Up @@ -53,7 +54,7 @@ object NormalizeCTEIds extends Rule[LogicalPlan] {
private def canonicalizeCTE(
plan: LogicalPlan,
defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
plan.transformDownWithSubqueries {
val normalizedPlan = plan match {
// For nested WithCTE, if defIdToNewId doesn't contain the cteId,
// the ref doesn't belong to the current WithCTE.
case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
Expand All @@ -62,6 +63,17 @@ object NormalizeCTEIds extends Rule[LogicalPlan] {
unionLoop.copy(id = defIdToNewId(unionLoop.id))
case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
case other => other
}

normalizedPlan
.withNewChildren(normalizedPlan.children.map {
case withCTE: WithCTE => withCTE
case child => canonicalizeCTE(child, defIdToNewId)
})
.transformExpressionsDown {
case subqueryExpression: SubqueryExpression =>
subqueryExpression.withNewPlan(canonicalizeCTE(subqueryExpression.plan, defIdToNewId))
}
}
}
39 changes: 39 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,45 @@ abstract class CTEInlineSuiteBase
}
}

test("SPARK-56921: plan normalization handles nested CTEs under union") {
  withTempView("input", "common") {
    // Base data: (a, b, value) triples registered as the "input" view.
    val inputRows = Seq((1, 1, 10), (1, 2, 20), (2, 1, 30))
    inputRows.toDF("a", "b", "value").createOrReplaceTempView("input")

    // The "common" view is itself defined with a CTE and is then read from
    // inside the CTEs of both queries below.
    val commonQuery =
      """with cte_common as (
        | select a, b, sum(value) as value
        | from input
        | group by a, b
        |)
        |select * from cte_common
        |""".stripMargin
    sql(commonQuery).createOrReplaceTempView("common")

    // Aggregates "common" per `a`.
    val aggregatedByA = sql(
      """with cte_a as (
        | select a, sum(value) as value
        | from common
        | group by a
        |)
        |select a as id, value from cte_a
        |""".stripMargin)

    // Aggregates "common" per `b`.
    val aggregatedByB = sql(
      """with cte_b as (
        | select b, sum(value) as value
        | from common
        | group by b
        |)
        |select b as id, value from cte_b
        |""".stripMargin)

    val unioned = aggregatedByA.union(aggregatedByB)

    // Access the normalized plan of the union. The value is not used; it is
    // evaluated only so that a failure during plan normalization fails the test.
    unioned.queryExecution.normalized

    // Per `a`: 1 -> 10 + 20 = 30, 2 -> 30. Per `b`: 1 -> 10 + 30 = 40, 2 -> 20.
    val expectedRows = Seq(Row(1, 30), Row(2, 30), Row(1, 40), Row(2, 20))
    checkAnswer(unioned, expectedRows)
  }
}

test("SPARK-36447: invalid nested CTEs") {
withTempView("t") {
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
Expand Down