diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index b0f09bc43b..86d63a75c2 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -355,6 +355,7 @@ jobs: org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite org.apache.comet.CometNativeSuite + org.apache.comet.CometSetOpWithGroupBySuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite org.apache.spark.CometPluginsDefaultSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index c743d1888a..263317cd15 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -194,6 +194,7 @@ jobs: org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite org.apache.comet.CometNativeSuite + org.apache.comet.CometSetOpWithGroupBySuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite org.apache.spark.CometPluginsDefaultSuite diff --git a/dev/diffs/4.1.1.diff b/dev/diffs/4.1.1.diff index bc662dec7d..eb0a9d574a 100644 --- a/dev/diffs/4.1.1.diff +++ b/dev/diffs/4.1.1.diff @@ -150,50 +150,6 @@ index 4410fe50912..43bcce2a038 100644 case _ => Map[String, String]() } val childrenInfo = children.flatMap { -diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out -index 69b4001ff34..6fda691652d 100644 ---- a/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out -+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out -@@ -1,7 +1,7 @@ - -- Automatically generated by SQLQueryTestSuite - -- !query - CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES -- (1, 2), -+ (1, 2), - (1, 2), - (1, 3), - (1, 3), -@@ -11,7 +11,7 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES - AS tab1(k, v) - -- !query analysis - CreateViewCommand `tab1`, SELECT * FROM VALUES -- (1, 2), -+ (1, 2), - (1, 2), - (1, 3), - (1, 3), -@@ -26,8 +26,8 @@ CreateViewCommand `tab1`, SELECT * FROM VALUES - - -- !query - CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES -- (1, 2), -- (1, 2), -+ (1, 2), -+ (1, 2), - (2, 3), - (3, 4), - (null, null), -@@ -35,8 +35,8 @@ CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES - AS tab2(k, v) - -- !query analysis - CreateViewCommand `tab2`, SELECT * FROM VALUES -- (1, 2), -- (1, 2), -+ (1, 2), -+ (1, 2), - (2, 3), - (3, 4), - (null, null), diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql index 13bbd9d81b7..541cdfb1e04 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql @@ -211,18 +167,6 @@ index 13bbd9d81b7..541cdfb1e04 100644 CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b; -- division, remainder and pmod by 0 return NULL -diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql -index e28f0721a64..788b43c242a 100644 ---- a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql -+++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql -@@ -1,3 +1,7 @@ -+-- TODO(https://github.com/apache/datafusion-comet/issues/4122) -+-- EXCEPT ALL with GROUP BY returns incorrect results on Spark 4.1 -+--SET spark.comet.enabled = false -+ - CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES - (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1); - CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql index 7aef901da4f..f3d6e18926d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql @@ -280,32 +224,6 @@ index 35128da97fd..25b873ae859 100644 -- Positive test cases -- Create a table with some testing data. DROP TABLE IF EXISTS t1; -diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql -index 077caa5dd44..697457d4251 100644 ---- a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql -+++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql -@@ -1,5 +1,9 @@ -+-- TODO(https://github.com/apache/datafusion-comet/issues/4122) -+-- INTERSECT ALL with GROUP BY returns incorrect results on Spark 4.1 -+--SET spark.comet.enabled = false -+ - CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES -- (1, 2), -+ (1, 2), - (1, 2), - (1, 3), - (1, 3), -@@ -8,8 +12,8 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES - (null, null) - AS tab1(k, v); - CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES -- (1, 2), -- (1, 2), -+ (1, 2), -+ (1, 2), - (2, 3), - (3, 4), - (null, null), diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 41fd4de2a09..162d5a817b6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -428,30 +346,6 @@ index 21a3ce1e122..f4762ab98f0 100644 SET spark.sql.ansi.enabled = false; -- In COMPENSATION views get invalidated if the type can't cast -diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out -index 44f95f225ab..361866fc298 100644 ---- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out -+++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out -@@ -1,7 +1,7 @@ - -- Automatically generated by SQLQueryTestSuite - -- !query - CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES -- (1, 2), -+ (1, 2), - (1, 2), - (1, 3), - (1, 3), -@@ -17,8 +17,8 @@ struct<> - - -- !query - CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES -- (1, 2), -- (1, 2), -+ (1, 2), -+ (1, 2), - (2, 3), - (3, 4), - (null, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0d807aeae4d..6d7744e771b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 888df13bac..02178f9da3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1317,8 +1317,37 @@ case class CometUnionExec( children: Seq[SparkPlan]) extends CometExec { + // CometExec's default outputPartitioning delegates to `originalPlan`, which captures the + // children that were live at CometExecRule conversion time. AQE post-stage rewrites + // (coalesce, skew join, etc.) later re-parent our `children` field but do not update + // `originalPlan`, so the partitioning read from the frozen snapshot can describe a + // pre-coalesce layout with more partitions than the RDDs will actually produce. Recompute + // from current children so SPARK-52921's union-output-partitioning inference is based on + // the live plan. Safe on older Spark too: UnionExec.outputPartitioning returns + // UnknownPartitioning when UNION_OUTPUT_PARTITIONING is off (the pre-4.1 default). + // + // Only advertise SinglePartition or HashPartitioningLike — the same whitelist that Spark's + // UnionExec.comparePartitioning uses and that ShimCometUnionExec.unionRDDs honors via + // SQLPartitioningAwareUnionRDD. For anything else, report UnknownPartitioning so that the + // declared partitioning and the RDD layer always agree. + override lazy val outputPartitioning: Partitioning = { + originalPlan.withNewChildren(children).outputPartitioning match { + case p @ (SinglePartition | _: HashPartitioningLike) => p + case p => UnknownPartitioning(p.numPartitions) + } + } + override def doExecuteColumnar(): RDD[ColumnarBatch] = { - sparkContext.union(children.map(_.executeColumnar())) + // Spark 4.1's UnionExec (SPARK-52921) can report a non-trivial output partitioning when all + // children share the same hash/single partitioning, and downstream plans may skip an + // otherwise-required shuffle in response. Plain `sparkContext.union` concatenates partitions + // (so partition i of the result holds only one child's partition i), which violates that + // partitioning claim and silently corrupts aggregates layered above the union. The shim + // routes through SQLPartitioningAwareUnionRDD on 4.1+ when a known partitioning is declared. + shims.ShimCometUnionExec.unionRDDs( + sparkContext, + children.map(_.executeColumnar()), + outputPartitioning) } override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala new file mode 100644 index 0000000000..15c480b057 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.plans.physical.Partitioning + +object ShimCometUnionExec { + + /** + * Unions a sequence of RDDs while preserving the declared output partitioning. Before Spark + * 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports + * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so this shim simply + * concatenates partitions via `SparkContext.union`. The partitioning-aware path is only needed + * on Spark 4.1+ (see SPARK-52921). + */ + def unionRDDs[T: ClassTag]( + sc: SparkContext, + rdds: Seq[RDD[T]], + @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] = { + sc.union(rdds) + } +} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala new file mode 100644 index 0000000000..15c480b057 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.plans.physical.Partitioning + +object ShimCometUnionExec { + + /** + * Unions a sequence of RDDs while preserving the declared output partitioning. Before Spark + * 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports + * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so this shim simply + * concatenates partitions via `SparkContext.union`. The partitioning-aware path is only needed + * on Spark 4.1+ (see SPARK-52921). + */ + def unionRDDs[T: ClassTag]( + sc: SparkContext, + rdds: Seq[RDD[T]], + @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] = { + sc.union(rdds) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala new file mode 100644 index 0000000000..15c480b057 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.plans.physical.Partitioning + +object ShimCometUnionExec { + + /** + * Unions a sequence of RDDs while preserving the declared output partitioning. Before Spark + * 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports + * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so this shim simply + * concatenates partitions via `SparkContext.union`. The partitioning-aware path is only needed + * on Spark 4.1+ (see SPARK-52921). + */ + def unionRDDs[T: ClassTag]( + sc: SparkContext, + rdds: Seq[RDD[T]], + @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] = { + sc.union(rdds) + } +} diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala new file mode 100644 index 0000000000..902ded377b --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.{RDD, SQLPartitioningAwareUnionRDD} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, SinglePartition} + +object ShimCometUnionExec extends Logging { + + /** + * Unions a sequence of RDDs while preserving the declared output partitioning. Spark 4.1 + * introduced [[org.apache.spark.sql.internal.SQLConf.UNION_OUTPUT_PARTITIONING]] (SPARK-52921), + * which lets [[org.apache.spark.sql.execution.UnionExec]] report a non-trivial output + * partitioning when all children share the same partitioning. Downstream operators may then + * skip an otherwise-required shuffle, so the columnar Union path must honor that contract by + * routing through [[SQLPartitioningAwareUnionRDD]] rather than plain `SparkContext.union`, + * which concatenates partitions and breaks the partitioning invariant. + */ + def unionRDDs[T: ClassTag]( + sc: SparkContext, + rdds: Seq[RDD[T]], + outputPartitioning: Partitioning): RDD[T] = { + outputPartitioning match { + case SinglePartition | _: HashPartitioningLike => + val numPartitions = outputPartitioning.numPartitions + val nonEmpty = rdds.filter(_.partitions.nonEmpty) + // SQLPartitioningAwareUnionRDD indexes every child at every output partition, so any + // child whose partition count diverges from the declared numPartitions would raise + // ArrayIndexOutOfBoundsException. That would only happen if the declared partitioning + // is stale relative to the RDDs (e.g. children were coalesced by AQE but the reported + // partitioning was not). Fall back to plain concat in that case. + if (nonEmpty.isEmpty || nonEmpty.exists(_.partitions.length != numPartitions)) { + val childCounts = rdds.map(_.partitions.length).mkString(", ") + logWarning( + s"CometUnionExec: child partition counts ($childCounts) do not match " + + s"declared output partitioning numPartitions=$numPartitions; " + + "falling back to SparkContext.union concat.") + sc.union(rdds) + } else { + new SQLPartitioningAwareUnionRDD(sc, nonEmpty, numPartitions) + } + case _ => sc.union(rdds) + } + } +} diff --git a/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala new file mode 100644 index 0000000000..902ded377b --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometUnionExec.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.{RDD, SQLPartitioningAwareUnionRDD} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, SinglePartition} + +object ShimCometUnionExec extends Logging { + + /** + * Unions a sequence of RDDs while preserving the declared output partitioning. Spark 4.1 + * introduced [[org.apache.spark.sql.internal.SQLConf.UNION_OUTPUT_PARTITIONING]] (SPARK-52921), + * which lets [[org.apache.spark.sql.execution.UnionExec]] report a non-trivial output + * partitioning when all children share the same partitioning. Downstream operators may then + * skip an otherwise-required shuffle, so the columnar Union path must honor that contract by + * routing through [[SQLPartitioningAwareUnionRDD]] rather than plain `SparkContext.union`, + * which concatenates partitions and breaks the partitioning invariant. + */ + def unionRDDs[T: ClassTag]( + sc: SparkContext, + rdds: Seq[RDD[T]], + outputPartitioning: Partitioning): RDD[T] = { + outputPartitioning match { + case SinglePartition | _: HashPartitioningLike => + val numPartitions = outputPartitioning.numPartitions + val nonEmpty = rdds.filter(_.partitions.nonEmpty) + // SQLPartitioningAwareUnionRDD indexes every child at every output partition, so any + // child whose partition count diverges from the declared numPartitions would raise + // ArrayIndexOutOfBoundsException. That would only happen if the declared partitioning + // is stale relative to the RDDs (e.g. children were coalesced by AQE but the reported + // partitioning was not). Fall back to plain concat in that case. + if (nonEmpty.isEmpty || nonEmpty.exists(_.partitions.length != numPartitions)) { + val childCounts = rdds.map(_.partitions.length).mkString(", ") + logWarning( + s"CometUnionExec: child partition counts ($childCounts) do not match " + + s"declared output partitioning numPartitions=$numPartitions; " + + "falling back to SparkContext.union concat.") + sc.union(rdds) + } else { + new SQLPartitioningAwareUnionRDD(sc, nonEmpty, numPartitions) + } + case _ => sc.union(rdds) + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometSetOpWithGroupBySuite.scala b/spark/src/test/scala/org/apache/comet/CometSetOpWithGroupBySuite.scala new file mode 100644 index 0000000000..c6595ba42e --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometSetOpWithGroupBySuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.CometUnionExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf + +/** + * Regression test for issue #4122: on Spark 4.1 (SPARK-52921), EXCEPT ALL / INTERSECT ALL whose + * sides are themselves GROUP BY aggregates are lowered to a plan where the union inherits a hash + * partitioning from its shuffled children, so the downstream final aggregate skips its shuffle. + * If Comet's columnar Union concatenates partitions it breaks that partitioning invariant and the + * resulting sums/counts collapse two sides into the wrong partitions. + */ +class CometSetOpWithGroupBySuite extends CometTestBase with AdaptiveSparkPlanHelper { + + test("issue #4122: EXCEPT ALL with GROUP BY under both sides") { + withTempView("tab3", "tab4") { + sql("""CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + | (1, 2), (1, 2), (1, 3), (2, 3), (2, 2) AS tab3(k, v)""".stripMargin) + sql("""CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + | (1, 2), (2, 3), (2, 2), (2, 2), (2, 20) AS tab4(k, v)""".stripMargin) + + val df = sql("""SELECT v FROM tab3 GROUP BY v + |EXCEPT ALL + |SELECT k FROM tab4 GROUP BY k""".stripMargin) + checkSparkAnswer(df) + assertContainsCometUnion(df) + } + } + + test("issue #4122: INTERSECT ALL with GROUP BY under both sides") { + withTempView("tab1", "tab2") { + sql("""CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + | (1, 2), (1, 2), (1, 3), (1, 3), (2, 3), + | (CAST(null AS INT), CAST(null AS INT)), + | (CAST(null AS INT), CAST(null AS INT)) AS tab1(k, v)""".stripMargin) + sql("""CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + | (1, 2), (1, 2), (2, 3), (3, 4), + | (CAST(null AS INT), CAST(null AS INT)), + | (CAST(null AS INT), CAST(null AS INT)) AS tab2(k, v)""".stripMargin) + + val df = sql("""SELECT v FROM tab1 GROUP BY v + |INTERSECT ALL + |SELECT k FROM tab2 GROUP BY k""".stripMargin) + checkSparkAnswer(df) + assertContainsCometUnion(df) + } + } + + private def assertContainsCometUnion(df: DataFrame): Unit = { + val plan = df.queryExecution.executedPlan + val found = collectFirst(plan) { case u: CometUnionExec => u } + assert(found.isDefined, s"Expected CometUnionExec in plan but found none:\n$plan") + } + + test("UNION ALL with checkSparkAnswerAndOperator") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + withTempView("u1", "u2") { + sql("""CREATE TEMPORARY VIEW u1 AS SELECT * FROM VALUES + | (1, 'a'), (2, 'b'), (3, 'c') AS u1(id, name)""".stripMargin) + sql("""CREATE TEMPORARY VIEW u2 AS SELECT * FROM VALUES + | (4, 'd'), (5, 'e'), (6, 'f') AS u2(id, name)""".stripMargin) + + val df = sql("SELECT id, name FROM u1 UNION ALL SELECT id, name FROM u2") + checkSparkAnswerAndOperator(df, includeClasses = Seq(classOf[CometUnionExec])) + } + } + } + + test("UNION ALL with SinglePartition (coalesce(1) children)") { + withSQLConf( + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + withTempView("s1", "s2") { + sql("""CREATE TEMPORARY VIEW s1 AS SELECT * FROM VALUES + | (10), (20), (30) AS s1(x)""".stripMargin) + sql("""CREATE TEMPORARY VIEW s2 AS SELECT * FROM VALUES + | (40), (50) AS s2(x)""".stripMargin) + + val df = sql("""SELECT * FROM (SELECT sum(x) as total FROM s1) + |UNION ALL + |SELECT * FROM (SELECT sum(x) as total FROM s2)""".stripMargin) + checkSparkAnswer(df) + assertContainsCometUnion(df) + } + } + } +}