From a454360378ba872f76969ca2800165baec211e76 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:06:09 -0600 Subject: [PATCH 01/16] feat: add optimized PyArrow UDF execution (CometPythonMapInArrowExec) When Comet operators produce Arrow columnar data and the next operator is a Python UDF (mapInArrow/mapInPandas), Spark currently inserts an unnecessary ColumnarToRow transition. The Python runner then converts those rows back to Arrow to send to Python, creating a wasteful Arrow->Row->Arrow round-trip. This adds CometPythonMapInArrowExec which: - Accepts columnar input directly from Comet operators - Uses lightweight batch.rowIterator() instead of UnsafeProjection - Keeps the Python output as ColumnarBatch (no output row conversion) The optimization is detected in EliminateRedundantTransitions and controlled by spark.comet.exec.pythonMapInArrow.enabled (default: true). --- .../scala/org/apache/comet/CometConf.scala | 10 + .../rules/EliminateRedundantTransitions.scala | 42 +++- .../sql/comet/CometPythonMapInArrowExec.scala | 143 ++++++++++++++ .../resources/pyspark/test_pyarrow_udf.py | 183 ++++++++++++++++++ .../exec/CometPythonMapInArrowSuite.scala | 68 +++++++ 5 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala create mode 100644 spark/src/test/resources/pyspark/test_pyarrow_udf.py create mode 100644 spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index d3f51dfbe2..a06cd896ec 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -314,6 +314,16 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_PYTHON_MAP_IN_ARROW_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.pythonMapInArrow.enabled") + .category(CATEGORY_EXEC) + .doc( + "Whether to enable optimized execution of PyArrow UDFs (mapInArrow/mapInPandas). " + + "When enabled, Comet passes Arrow columnar data directly to Python UDFs without " + + "the intermediate Arrow-to-Row-to-Arrow conversion that Spark normally performs.") + .booleanConf + .createWithDefault(true) + val COMET_TRACING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.tracing.enabled") .category(CATEGORY_TUNING) .doc(s"Enable fine-grained tracing of events and memory usage. 
$TRACING_GUIDE.") diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7402a83248..272ef76484 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -20,13 +20,15 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometPythonMapInArrowExec, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.execution.python.{MapInPandasExec, PythonMapInArrowExec} import org.apache.comet.CometConf @@ -98,6 +100,32 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa case CometNativeColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child + // Replace MapInBatchExec (PythonMapInArrowExec / MapInPandasExec) that has a + // ColumnarToRow child with CometPythonMapInArrowExec to avoid the unnecessary + // Arrow->Row->Arrow round-trip. + case p: PythonMapInArrowExec if CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.get() => + extractColumnarChild(p.child) + .map { columnarChild => + CometPythonMapInArrowExec( + p.func, + p.output, + columnarChild, + p.isBarrier, + p.func.asInstanceOf[PythonUDF].evalType) + } + .getOrElse(p) + case p: MapInPandasExec if CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.get() => + extractColumnarChild(p.child) + .map { columnarChild => + CometPythonMapInArrowExec( + p.func, + p.output, + columnarChild, + p.isBarrier, + p.func.asInstanceOf[PythonUDF].evalType) + } + .getOrElse(p) + // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the // shuffle takes row-based input. case s @ CometShuffleExchangeExec( @@ -130,6 +158,18 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa } } + /** + * If the given plan is a ColumnarToRow transition wrapping a columnar child, returns that + * columnar child. Used to detect and eliminate unnecessary transitions before Python UDF + * operators. + */ + private def extractColumnarChild(plan: SparkPlan): Option[SparkPlan] = plan match { + case ColumnarToRowExec(child) if child.supportsColumnar => Some(child) + case CometColumnarToRowExec(child) => Some(child) + case CometNativeColumnarToRowExec(child) => Some(child) + case _ => None + } + /** * Creates an appropriate columnar to row transition operator. 
* diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala new file mode 100644 index 0000000000..84b3c31113 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.JavaConverters._ + +import org.apache.spark.{ContextAwareIterator, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, PythonSQLMetrics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +/** + * An optimized version of Spark's MapInBatchExec (PythonMapInArrowExec / MapInPandasExec) that + * accepts columnar input directly from Comet operators, avoiding unnecessary Arrow -> Row -> + * Arrow conversions. + * + * Normal Spark flow: CometNativeExec (Arrow) -> ColumnarToRow -> PythonMapInArrowExec + * (internally: rows -> Arrow -> Python -> Arrow -> rows) + * + * Optimized flow: CometNativeExec (Arrow) -> CometPythonMapInArrowExec (batch.rowIterator() -> + * Arrow -> Python -> Arrow columnar output) + * + * This eliminates: + * 1. The UnsafeProjection in ColumnarToRow (expensive copy) 2. 
The output Arrow->Row conversion + * (keeps Python output as ColumnarBatch) + */ +case class CometPythonMapInArrowExec( + func: Expression, + output: Seq[Attribute], + child: SparkPlan, + isBarrier: Boolean, + pythonEvalType: Int) + extends UnaryExecNode + with PythonSQLMetrics { + + override def supportsColumnar: Boolean = true + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows")) ++ + pythonMetrics + + override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputBatches = longMetric("numOutputBatches") + val numInputRows = longMetric("numInputRows") + + val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) + val pythonFunction = func.asInstanceOf[PythonUDF].func + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + val localOutput = output + val localChildSchema = child.schema + val batchSize = conf.arrowMaxRecordsPerBatch + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val largeVarTypes = conf.arrowUseLargeVarTypes + val localPythonEvalType = pythonEvalType + val localPythonMetrics = pythonMetrics + val jobArtifactUUID = + org.apache.spark.JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + val inputRDD = child.executeColumnar() + + inputRDD.mapPartitionsInternal { batches => + val context = TaskContext.get() + val argOffsets = Array(Array(0)) + + // Convert columnar batches to rows using lightweight rowIterator + // (avoids UnsafeProjection copy that ColumnarToRow would do) + val rowIter = batches.flatMap { batch => + numInputRows += batch.numRows() + batch.rowIterator().asScala + } + + val contextAwareIterator = new ContextAwareIterator(context, rowIter) + + // Wrap rows as a struct, matching MapInBatchEvaluatorFactory behavior + val wrappedIter = contextAwareIterator.map(InternalRow(_)) + + val batchIter = + if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, + localPythonEvalType, + argOffsets, + org.apache.spark.sql.types + .StructType(Array(org.apache.spark.sql.types.StructField("struct", localChildSchema))), + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + localPythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) + + columnarBatchIter.map { batch => + // Python returns a StructType column; flatten to individual columns + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = localOutput.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + numOutputRows += flattenedBatch.numRows() + numOutputBatches += 1 + flattenedBatch + } + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): CometPythonMapInArrowExec = + copy(child = newChild) +} diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py new file mode 100644 index 
0000000000..04b83fe66b --- /dev/null +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Integration test for CometPythonMapInArrowExec. + +This test verifies that Comet's optimized PyArrow UDF execution works correctly +by checking: +1. The plan uses CometPythonMapInArrowExec instead of PythonMapInArrow + ColumnarToRow +2. The UDF produces correct results +3. Performance improvement by eliminating unnecessary Arrow->Row->Arrow conversions + +Usage: + # Build Comet first: make release + # Then run with PySpark: + spark-submit --jars spark/target/comet-spark-spark3.5_2.12-*.jar \ + --conf spark.plugins=org.apache.comet.CometSparkSessionExtensions \ + --conf spark.comet.enabled=true \ + --conf spark.comet.exec.enabled=true \ + --conf spark.comet.exec.pythonMapInArrow.enabled=true \ + --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ + --conf spark.memory.offHeap.enabled=true \ + --conf spark.memory.offHeap.size=2g \ + spark/src/test/resources/pyspark/test_pyarrow_udf.py +""" + +import sys +import pyarrow as pa +from pyspark.sql import SparkSession +from pyspark.sql import types as T + + +def test_map_in_arrow_basic(): + """Test basic mapInArrow with Comet optimization.""" + spark = SparkSession.builder.getOrCreate() + + # Create test data + data = [(i, float(i * 1.5), f"name_{i}") for i in range(100)] + df = spark.createDataFrame(data, ["id", "value", "name"]) + + # Write to parquet so CometScan can read it + df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data") + test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data") + + # Define a PyArrow UDF that doubles the value column + def double_value(batch: pa.RecordBatch) -> pa.RecordBatch: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] * 2 + return pa.RecordBatch.from_pandas(pdf) + + output_schema = T.StructType([ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("name", T.StringType()), + ]) + + # Apply mapInArrow + result_df = test_df.mapInArrow(double_value, output_schema) + + # Check the explain plan + print("=" * 60) + print("PHYSICAL PLAN:") + print("=" * 60) + result_df.explain(mode="extended") + print("=" * 60) + + plan_str = result_df.queryExecution.executedPlan.toString() + print(f"\nPlan string:\n{plan_str}\n") + + # Verify CometPythonMapInArrowExec is in the plan (if Comet is active) + if "CometPythonMapInArrowExec" in plan_str: + print("SUCCESS: CometPythonMapInArrowExec is in the plan!") + elif "CometScan" in plan_str and "ColumnarToRow" in plan_str: + print("WARNING: CometScan present but still using ColumnarToRow before Python UDF") + elif "CometScan" not in plan_str: + 
print("INFO: Comet is not active for this query (CometScan not found)") + else: + print("INFO: Plan does not contain CometPythonMapInArrowExec") + + # Verify correctness + result = result_df.orderBy("id").collect() + expected_first = data[0] + actual_first = result[0] + + assert actual_first["id"] == expected_first[0], \ + f"ID mismatch: {actual_first['id']} != {expected_first[0]}" + assert abs(actual_first["value"] - expected_first[1] * 2) < 0.001, \ + f"Value mismatch: {actual_first['value']} != {expected_first[1] * 2}" + assert actual_first["name"] == expected_first[2], \ + f"Name mismatch: {actual_first['name']} != {expected_first[2]}" + + print(f"\nFirst row: {actual_first}") + print(f"Expected value (doubled): {expected_first[1] * 2}") + print("CORRECTNESS: PASSED") + + # Verify all rows + for i, row in enumerate(result): + expected_val = data[i][1] * 2 + assert abs(row["value"] - expected_val) < 0.001, \ + f"Row {i}: expected value {expected_val}, got {row['value']}" + + print(f"All {len(result)} rows verified correctly.") + return True + + +def test_map_in_arrow_type_change(): + """Test mapInArrow that changes the schema.""" + spark = SparkSession.builder.getOrCreate() + + data = [(i, float(i)) for i in range(50)] + df = spark.createDataFrame(data, ["id", "value"]) + df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data2") + test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data2") + + def add_computed_column(batch: pa.RecordBatch) -> pa.RecordBatch: + pdf = batch.to_pandas() + pdf["squared"] = pdf["value"] ** 2 + pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") + return pa.RecordBatch.from_pandas(pdf) + + output_schema = T.StructType([ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + T.StructField("label", T.StringType()), + ]) + + result_df = test_df.mapInArrow(add_computed_column, output_schema) + result = result_df.orderBy("id").collect() + + assert len(result) == 50 + for i, row in enumerate(result): + assert abs(row["squared"] - float(i) ** 2) < 0.001 + assert row["label"] == f"item_{i}" + + print("test_map_in_arrow_type_change: PASSED") + return True + + +if __name__ == "__main__": + print("Running PyArrow UDF integration tests for Comet...") + print() + + tests = [ + ("test_map_in_arrow_basic", test_map_in_arrow_basic), + ("test_map_in_arrow_type_change", test_map_in_arrow_type_change), + ] + + passed = 0 + failed = 0 + for name, test_fn in tests: + print(f"\n{'=' * 60}") + print(f"Running: {name}") + print(f"{'=' * 60}") + try: + test_fn() + passed += 1 + except Exception as e: + print(f"FAILED: {e}") + import traceback + traceback.print_exc() + failed += 1 + + print(f"\n{'=' * 60}") + print(f"Results: {passed} passed, {failed} failed") + print(f"{'=' * 60}") + + sys.exit(0 if failed == 0 else 1) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala new file mode 100644 index 0000000000..94145cea2b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet.CometPythonMapInArrowExec +import org.apache.spark.sql.execution.ColumnarToRowExec +import org.apache.spark.sql.execution.python.PythonMapInArrowExec + +import org.apache.comet.CometConf + +class CometPythonMapInArrowSuite extends CometTestBase { + + test("plan with CometScan has columnar support for Python UDF optimization") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.key -> "true") { + withParquetTable( + (1 to 10).map(i => (i.toDouble, s"str_$i")), + "testTable", + withDictionary = false) { + val df = spark.sql("SELECT * FROM testTable") + val plan = df.queryExecution.executedPlan + val cometScans = plan.collect { case s if s.supportsColumnar => s } + assert(cometScans.nonEmpty, "Expected columnar operators that can feed Python UDFs") + } + } + } + + test("config disables Python map in arrow optimization") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.key -> "false") { + withParquetTable( + (1 to 10).map(i => (i.toDouble, s"str_$i")), + "testTable", + withDictionary = false) { + val df = spark.sql("SELECT * FROM testTable") + val plan = df.queryExecution.executedPlan + // With the feature disabled, no CometPythonMapInArrowExec should appear + val cometPythonExecs = + plan.collect { case e: CometPythonMapInArrowExec => e } + assert( + cometPythonExecs.isEmpty, + "CometPythonMapInArrowExec should not appear when disabled") + } + } + } +} From 84aec8406f093abff96ca916ca9c4602065f9019 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:11:52 -0600 Subject: [PATCH 02/16] docs: add PyArrow UDF acceleration user guide page Documents the CometPythonMapInArrowExec optimization, including supported APIs, configuration, usage example, and how to verify the optimization is active in query plans. 
---
 docs/source/user-guide/latest/index.rst | 1 +
 docs/source/user-guide/latest/pyarrow-udfs.md | 132 ++++++++++++++
 .../resources/pyspark/test_pyarrow_udf.py | 3 +-
 3 files changed, 135 insertions(+), 1 deletion(-)
 create mode 100644 docs/source/user-guide/latest/pyarrow-udfs.md

diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst
index 480ec4f702..c96dea7750 100644
--- a/docs/source/user-guide/latest/index.rst
+++ b/docs/source/user-guide/latest/index.rst
@@ -38,5 +38,6 @@ Comet $COMET_VERSION User Guide
 Understanding Comet Plans
 Tuning Guide
 Metrics Guide
+ PyArrow UDF Acceleration
 Iceberg Guide
 Kubernetes Guide
diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md
new file mode 100644
index 0000000000..71701960cd
--- /dev/null
+++ b/docs/source/user-guide/latest/pyarrow-udfs.md
@@ -0,0 +1,132 @@
+
+
+# PyArrow UDF Acceleration
+
+Comet can accelerate Python UDFs that use PyArrow-backed batch processing, such as `mapInArrow` and `mapInPandas`.
+These APIs are commonly used for ML inference, feature engineering, and data transformation workloads.
+
+## Background
+
+Spark's `mapInArrow` and `mapInPandas` APIs allow users to apply Python functions that operate on Arrow
+RecordBatches or Pandas DataFrames. Under the hood, Spark communicates with the Python worker process
+using the Arrow IPC format.
+
+Without this optimization, the execution path for these UDFs involves unnecessary data conversions:
+
+1. Comet reads data in Arrow columnar format (via CometScan)
+2. Spark inserts a ColumnarToRow transition (converts Arrow to UnsafeRow)
+3. The Python runner converts those rows back to Arrow to send to Python
+4. Python executes the UDF on Arrow batches
+5. Results are returned as Arrow and then converted back to rows
+
+Steps 2 and 3 are redundant since the data starts and ends in Arrow format.
+
+## How Comet Optimizes This
+
+When enabled, Comet detects `PythonMapInArrowExec` and `MapInPandasExec` operators in the physical plan
+and replaces them with `CometPythonMapInArrowExec`, which:
+
+- Reads Arrow columnar batches directly from the upstream Comet operator
+- Feeds them to the Python runner without the expensive UnsafeProjection copy
+- Keeps the Python output in columnar format for downstream operators
+
+This eliminates the ColumnarToRow transition and the output row conversion, reducing CPU overhead
+and memory allocations.
+
+## Configuration
+
+The optimization is controlled by:
+
+```
+spark.comet.exec.pythonMapInArrow.enabled=true (default)
+```
+
+It is enabled by default when Comet execution is active.
+
+## Supported APIs
+
+| PySpark API | Spark Plan Node | Supported |
+|-------------|-----------------|-----------|
+| `df.mapInArrow(func, schema)` | `PythonMapInArrowExec` | Yes |
+| `df.mapInPandas(func, schema)` | `MapInPandasExec` | Yes |
+| `@pandas_udf` (scalar) | `ArrowEvalPythonExec` | Not yet |
+| `df.applyInPandas(func, schema)` | `FlatMapGroupsInPandasExec` | Not yet |
+
+## Example
+
+```python
+import pyarrow as pa
+from pyspark.sql import SparkSession, types as T
+
+spark = SparkSession.builder \
+    .config("spark.plugins", "org.apache.spark.CometPlugin") \
+    .config("spark.comet.enabled", "true") \
+    .config("spark.comet.exec.enabled", "true") \
+    .config("spark.comet.exec.pythonMapInArrow.enabled", "true") \
+    .config("spark.memory.offHeap.enabled", "true") \
+    .config("spark.memory.offHeap.size", "2g") \
+    .getOrCreate()
+
+df = spark.read.parquet("data.parquet")
+
+def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
+    # Your transformation logic here
+    table = batch.to_pandas()
+    table["new_col"] = table["value"] * 2
+    return pa.RecordBatch.from_pandas(table)
+
+output_schema = T.StructType([
+    T.StructField("value", T.DoubleType()),
+    T.StructField("new_col", T.DoubleType()),
+])
+
+result = df.mapInArrow(transform, output_schema)
+```
+
+## Verifying the Optimization
+
+Use `explain()` to verify that `CometPythonMapInArrowExec` appears in your plan:
+
+```python
+result.explain(mode="extended")
+```
+
+You should see:
+```
+CometPythonMapInArrowExec ...
++- CometNativeExec ...
+   +- CometScan ...
+```
+
+Instead of the unoptimized plan:
+```
+PythonMapInArrow ...
++- ColumnarToRow
+   +- CometNativeExec ...
+      +- CometScan ...
+```
+
+## Limitations
+
+- The optimization currently applies only to `mapInArrow` and `mapInPandas`. Scalar pandas UDFs
+  (`@pandas_udf`) and grouped operations (`applyInPandas`) are not yet supported.
+- The internal row-to-Arrow conversion inside the Python runner is still present in this version.
+  A future optimization will write Arrow batches directly to the Python IPC stream, achieving
+  near zero-copy data transfer.
diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py
index 04b83fe66b..1993f29f9f 100644
--- a/spark/src/test/resources/pyspark/test_pyarrow_udf.py
+++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py
@@ -26,10 +26,11 @@
 3. Performance improvement by eliminating unnecessary Arrow->Row->Arrow conversions
 
 Usage:
+    # Requires Python 3.11 or 3.12 (PySpark 3.5 does not support 3.13+)
     # Build Comet first: make release
     # Then run with PySpark:
     spark-submit --jars spark/target/comet-spark-spark3.5_2.12-*.jar \
-        --conf spark.plugins=org.apache.comet.CometSparkSessionExtensions \
+        --conf spark.plugins=org.apache.spark.CometPlugin \
         --conf spark.comet.enabled=true \
         --conf spark.comet.exec.enabled=true \
         --conf spark.comet.exec.pythonMapInArrow.enabled=true \

From af98fbba92faed24484ae32504218821b4eb59d7 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 5 May 2026 18:38:38 -0600
Subject: [PATCH 03/16] fix(test): correct PyArrow UDF integration test signatures and assertions

Fix three issues that prevented test_pyarrow_udf.py from running:

1. mapInArrow callbacks must accept Iterator[pa.RecordBatch] and yield
   batches. The previous single-batch signatures crashed with
   "'map' object has no attribute 'to_pandas'".
2. PySpark DataFrame has no `queryExecution` attribute. Use
   `_jdf.queryExecution().executedPlan().toString()` instead.
3.
Replace soft plan-string heuristics with assertions that fail loudly if the optimization regresses. Match on `CometPythonMapInArrow` (no `Exec` suffix in the plan toString) and assert no `ColumnarToRow` transition is present. --- .../resources/pyspark/test_pyarrow_udf.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py index 1993f29f9f..6acac6a912 100644 --- a/spark/src/test/resources/pyspark/test_pyarrow_udf.py +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -58,11 +58,13 @@ def test_map_in_arrow_basic(): df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data") test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data") - # Define a PyArrow UDF that doubles the value column - def double_value(batch: pa.RecordBatch) -> pa.RecordBatch: - pdf = batch.to_pandas() - pdf["value"] = pdf["value"] * 2 - return pa.RecordBatch.from_pandas(pdf) + # Define a PyArrow UDF that doubles the value column. + # mapInArrow callbacks receive an iterator of RecordBatches and must yield batches. + def double_value(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] * 2 + yield pa.RecordBatch.from_pandas(pdf) output_schema = T.StructType([ T.StructField("id", T.LongType()), @@ -80,18 +82,17 @@ def double_value(batch: pa.RecordBatch) -> pa.RecordBatch: result_df.explain(mode="extended") print("=" * 60) - plan_str = result_df.queryExecution.executedPlan.toString() + plan_str = result_df._jdf.queryExecution().executedPlan().toString() print(f"\nPlan string:\n{plan_str}\n") - # Verify CometPythonMapInArrowExec is in the plan (if Comet is active) - if "CometPythonMapInArrowExec" in plan_str: - print("SUCCESS: CometPythonMapInArrowExec is in the plan!") - elif "CometScan" in plan_str and "ColumnarToRow" in plan_str: - print("WARNING: CometScan present but still using ColumnarToRow before Python UDF") - elif "CometScan" not in plan_str: - print("INFO: Comet is not active for this query (CometScan not found)") - else: - print("INFO: Plan does not contain CometPythonMapInArrowExec") + # Verify the optimized Comet operator is in the plan. The toString form is + # "CometPythonMapInArrow" (no Exec suffix) and the upstream scan prints as + # "CometNativeScan". 
+ assert "CometPythonMapInArrow" in plan_str, \ + f"CometPythonMapInArrow missing from plan:\n{plan_str}" + assert "ColumnarToRow" not in plan_str, \ + f"Unexpected ColumnarToRow in optimized plan:\n{plan_str}" + print("SUCCESS: CometPythonMapInArrow is in the plan with no ColumnarToRow transition.") # Verify correctness result = result_df.orderBy("id").collect() @@ -128,11 +129,12 @@ def test_map_in_arrow_type_change(): df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data2") test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data2") - def add_computed_column(batch: pa.RecordBatch) -> pa.RecordBatch: - pdf = batch.to_pandas() - pdf["squared"] = pdf["value"] ** 2 - pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") - return pa.RecordBatch.from_pandas(pdf) + def add_computed_column(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["squared"] = pdf["value"] ** 2 + pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") + yield pa.RecordBatch.from_pandas(pdf) output_schema = T.StructType([ T.StructField("id", T.LongType()), From f29cb2f53f5437edcfc906129a8ca3253fb0b0ea Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:43:55 -0600 Subject: [PATCH 04/16] test: convert PyArrow UDF script to pytest and add CI coverage - Rewrite test_pyarrow_udf.py as a pytest module. A session-scoped SparkSession fixture builds the Comet-enabled session once and a parametrized `accelerated` fixture toggles spark.comet.exec.pythonMapInArrow.enabled per test, so each case runs under both the optimized and fallback paths and asserts the expected plan operator (`CometPythonMapInArrow` vs vanilla `PythonMapInArrow`). The jar is auto-discovered from spark/target by matching the installed pyspark version, or taken from the COMET_JAR env var. - Add a dedicated `PyArrow UDF Tests` workflow that builds Comet against Spark 3.5 / Scala 2.12, installs pyspark/pyarrow/pandas/pytest, and runs the new pytest module. - Add CometPythonMapInArrowSuite to the `exec` suite list in both pr_build_linux.yml and pr_build_macos.yml so the JVM-side suite is exercised on every PR. 
--- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + .github/workflows/pyarrow_udf_test.yml | 96 ++++++ .../resources/pyspark/test_pyarrow_udf.py | 299 +++++++++--------- 4 files changed, 256 insertions(+), 141 deletions(-) create mode 100644 .github/workflows/pyarrow_udf_test.yml diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index b0f09bc43b..b62a000f6c 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -354,6 +354,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.comet.exec.CometPythonMapInArrowSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index c743d1888a..fe972818e6 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -193,6 +193,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.comet.exec.CometPythonMapInArrowSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml new file mode 100644 index 0000000000..0779f092a4 --- /dev/null +++ b/.github/workflows/pyarrow_udf_test.yml @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: PyArrow UDF Tests + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + push: + branches: + - main + paths: + - "spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala" + - "spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala" + - "spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala" + - "spark/src/test/resources/pyspark/test_pyarrow_udf.py" + - ".github/workflows/pyarrow_udf_test.yml" + - "native/**" + pull_request: + paths: + - "spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala" + - "spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala" + - "spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala" + - "spark/src/test/resources/pyspark/test_pyarrow_udf.py" + - ".github/workflows/pyarrow_udf_test.yml" + - "native/**" + workflow_dispatch: + +env: + RUST_VERSION: stable + RUST_BACKTRACE: 1 + RUSTFLAGS: "-Clink-arg=-fuse-ld=bfd" + +jobs: + pyarrow-udf: + name: PyArrow UDF (Spark 3.5, JDK 17, Python 3.11) + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + JAVA_TOOL_OPTIONS: "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED --add-exports=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED" + steps: + - uses: actions/checkout@v6 + + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.RUST_VERSION }} + jdk-version: 17 + + - name: Cache Maven dependencies + uses: actions/cache@v5 + with: + path: | + ~/.m2/repository + /root/.m2/repository + key: ${{ runner.os }}-java-maven-${{ hashFiles('**/pom.xml') }}-pyarrow-udf + restore-keys: | + ${{ runner.os }}-java-maven- + + - name: Build Comet (release, Spark 3.5 / Scala 2.12) + run: | + cd native && cargo build --release + cd .. && ./mvnw -B -Prelease install -DskipTests -Pspark-3.5 -Pscala-2.12 + + - name: Install Python 3.11 and pip + run: | + apt-get update + apt-get install -y --no-install-recommends python3.11 python3.11-venv python3-pip + python3.11 -m venv /tmp/venv + /tmp/venv/bin/pip install --upgrade pip + /tmp/venv/bin/pip install "pyspark==3.5.8" "pyarrow>=14" pandas pytest + + - name: Run PyArrow UDF pytest + run: | + jar=$(ls "$PWD"/spark/target/comet-spark-spark3.5_2.12-*-SNAPSHOT.jar \ + | grep -v sources | grep -v tests | head -n1) + echo "Using $jar" + COMET_JAR="$jar" /tmp/venv/bin/python -m pytest -v \ + spark/src/test/resources/pyspark/test_pyarrow_udf.py diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py index 6acac6a912..462f4efdc6 100644 --- a/spark/src/test/resources/pyspark/test_pyarrow_udf.py +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -17,117 +17,165 @@ # under the License. """ -Integration test for CometPythonMapInArrowExec. +Pytest-driven integration tests for Comet's PyArrow UDF acceleration. -This test verifies that Comet's optimized PyArrow UDF execution works correctly -by checking: -1. The plan uses CometPythonMapInArrowExec instead of PythonMapInArrow + ColumnarToRow -2. The UDF produces correct results -3. 
Performance improvement by eliminating unnecessary Arrow->Row->Arrow conversions +Each test runs against two execution paths: + - "accelerated": spark.comet.exec.pythonMapInArrow.enabled=true + (plan should contain CometPythonMapInArrow and no ColumnarToRow) + - "fallback": spark.comet.exec.pythonMapInArrow.enabled=false + (plan should contain vanilla PythonMapInArrow) Usage: - # Requires Python 3.11 or 3.12 (PySpark 3.5 does not support 3.13+) - # Build Comet first: make release - # Then run with PySpark: - spark-submit --jars spark/target/comet-spark-spark3.5_2.12-*.jar \ - --conf spark.plugins=org.apache.spark.CometPlugin \ - --conf spark.comet.enabled=true \ - --conf spark.comet.exec.enabled=true \ - --conf spark.comet.exec.pythonMapInArrow.enabled=true \ - --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ - --conf spark.memory.offHeap.enabled=true \ - --conf spark.memory.offHeap.size=2g \ - spark/src/test/resources/pyspark/test_pyarrow_udf.py -""" + # Build Comet first: + make release -import sys -import pyarrow as pa -from pyspark.sql import SparkSession -from pyspark.sql import types as T + # Then either let the test discover the jar from spark/target, or pass it + # explicitly via COMET_JAR: + export COMET_JAR=$PWD/spark/target/comet-spark-spark3.5_2.12-0.16.0-SNAPSHOT.jar + pip install pyspark==3.5.8 pyarrow pandas pytest + pytest -v spark/src/test/resources/pyspark/test_pyarrow_udf.py +""" -def test_map_in_arrow_basic(): - """Test basic mapInArrow with Comet optimization.""" - spark = SparkSession.builder.getOrCreate() +import glob +import os - # Create test data +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession, types as T + + +REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") +) + + +def _resolve_comet_jar() -> str: + explicit = os.environ.get("COMET_JAR") + if explicit: + if any(ch in explicit for ch in "*?["): + matches = sorted(glob.glob(explicit)) + if not matches: + raise FileNotFoundError( + f"COMET_JAR pattern matched nothing: {explicit}" + ) + return matches[-1] + return explicit + + # Pick the jar that matches the installed pyspark major.minor version. The + # Comet jars are published per Spark version (e.g., comet-spark-spark3.5_2.12-*.jar); + # using the wrong one yields ClassNotFoundException on Scala stdlib classes. + import pyspark + + major_minor = ".".join(pyspark.__version__.split(".")[:2]) + spark_tag = f"spark{major_minor}" + scala_tag = "_2.12" if major_minor.startswith("3.") else "_2.13" + pattern = os.path.join( + REPO_ROOT, + f"spark/target/comet-spark-{spark_tag}{scala_tag}-*-SNAPSHOT.jar", + ) + candidates = [ + m + for m in sorted(glob.glob(pattern)) + if "sources" not in os.path.basename(m) and "tests" not in os.path.basename(m) + ] + if not candidates: + raise FileNotFoundError( + "Comet jar not found. Set COMET_JAR or run `make release`. " + f"Looked under {pattern}." + ) + return candidates[-1] + + +@pytest.fixture(scope="session") +def spark(): + jar = _resolve_comet_jar() + # PYSPARK_SUBMIT_ARGS is consumed when pyspark launches its JVM. Setting + # --jars puts the Comet jar on both driver and executor classpaths so the + # CometPlugin can be loaded. 
+ os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + session = ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-tests") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "2g") + .getOrCreate() + ) + try: + yield session + finally: + session.stop() + + +@pytest.fixture(params=[True, False], ids=["accelerated", "fallback"]) +def accelerated(request, spark) -> bool: + spark.conf.set( + "spark.comet.exec.pythonMapInArrow.enabled", + "true" if request.param else "false", + ) + return request.param + + +def _executed_plan(df) -> str: + return df._jdf.queryExecution().executedPlan().toString() + + +def _assert_plan_matches_mode(plan: str, accelerated: bool) -> None: + if accelerated: + assert "CometPythonMapInArrow" in plan, ( + f"expected CometPythonMapInArrow in accelerated plan, got:\n{plan}" + ) + assert "ColumnarToRow" not in plan, ( + f"unexpected ColumnarToRow in accelerated plan:\n{plan}" + ) + else: + assert "CometPythonMapInArrow" not in plan, ( + f"unexpected CometPythonMapInArrow in fallback plan:\n{plan}" + ) + assert "PythonMapInArrow" in plan, ( + f"expected PythonMapInArrow in fallback plan, got:\n{plan}" + ) + + +def test_map_in_arrow_doubles_value(spark, tmp_path, accelerated): data = [(i, float(i * 1.5), f"name_{i}") for i in range(100)] - df = spark.createDataFrame(data, ["id", "value", "name"]) + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value", "name"]).write.parquet(src) - # Write to parquet so CometScan can read it - df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data") - test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data") - - # Define a PyArrow UDF that doubles the value column. - # mapInArrow callbacks receive an iterator of RecordBatches and must yield batches. def double_value(iterator): for batch in iterator: pdf = batch.to_pandas() pdf["value"] = pdf["value"] * 2 yield pa.RecordBatch.from_pandas(pdf) - output_schema = T.StructType([ - T.StructField("id", T.LongType()), - T.StructField("value", T.DoubleType()), - T.StructField("name", T.StringType()), - ]) - - # Apply mapInArrow - result_df = test_df.mapInArrow(double_value, output_schema) - - # Check the explain plan - print("=" * 60) - print("PHYSICAL PLAN:") - print("=" * 60) - result_df.explain(mode="extended") - print("=" * 60) - - plan_str = result_df._jdf.queryExecution().executedPlan().toString() - print(f"\nPlan string:\n{plan_str}\n") - - # Verify the optimized Comet operator is in the plan. The toString form is - # "CometPythonMapInArrow" (no Exec suffix) and the upstream scan prints as - # "CometNativeScan". 
- assert "CometPythonMapInArrow" in plan_str, \ - f"CometPythonMapInArrow missing from plan:\n{plan_str}" - assert "ColumnarToRow" not in plan_str, \ - f"Unexpected ColumnarToRow in optimized plan:\n{plan_str}" - print("SUCCESS: CometPythonMapInArrow is in the plan with no ColumnarToRow transition.") - - # Verify correctness - result = result_df.orderBy("id").collect() - expected_first = data[0] - actual_first = result[0] - - assert actual_first["id"] == expected_first[0], \ - f"ID mismatch: {actual_first['id']} != {expected_first[0]}" - assert abs(actual_first["value"] - expected_first[1] * 2) < 0.001, \ - f"Value mismatch: {actual_first['value']} != {expected_first[1] * 2}" - assert actual_first["name"] == expected_first[2], \ - f"Name mismatch: {actual_first['name']} != {expected_first[2]}" - - print(f"\nFirst row: {actual_first}") - print(f"Expected value (doubled): {expected_first[1] * 2}") - print("CORRECTNESS: PASSED") - - # Verify all rows - for i, row in enumerate(result): - expected_val = data[i][1] * 2 - assert abs(row["value"] - expected_val) < 0.001, \ - f"Row {i}: expected value {expected_val}, got {row['value']}" - - print(f"All {len(result)} rows verified correctly.") - return True - - -def test_map_in_arrow_type_change(): - """Test mapInArrow that changes the schema.""" - spark = SparkSession.builder.getOrCreate() + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("name", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(double_value, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + assert row["name"] == original[2] + +def test_map_in_arrow_changes_schema(spark, tmp_path, accelerated): data = [(i, float(i)) for i in range(50)] - df = spark.createDataFrame(data, ["id", "value"]) - df.write.mode("overwrite").parquet("/tmp/comet_pyarrow_test_data2") - test_df = spark.read.parquet("/tmp/comet_pyarrow_test_data2") + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) def add_computed_column(iterator): for batch in iterator: @@ -136,51 +184,20 @@ def add_computed_column(iterator): pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") yield pa.RecordBatch.from_pandas(pdf) - output_schema = T.StructType([ - T.StructField("id", T.LongType()), - T.StructField("value", T.DoubleType()), - T.StructField("squared", T.DoubleType()), - T.StructField("label", T.StringType()), - ]) - - result_df = test_df.mapInArrow(add_computed_column, output_schema) - result = result_df.orderBy("id").collect() - - assert len(result) == 50 - for i, row in enumerate(result): - assert abs(row["squared"] - float(i) ** 2) < 0.001 + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + T.StructField("label", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(add_computed_column, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 assert row["label"] == f"item_{i}" - - print("test_map_in_arrow_type_change: PASSED") - return True - - -if 
__name__ == "__main__": - print("Running PyArrow UDF integration tests for Comet...") - print() - - tests = [ - ("test_map_in_arrow_basic", test_map_in_arrow_basic), - ("test_map_in_arrow_type_change", test_map_in_arrow_type_change), - ] - - passed = 0 - failed = 0 - for name, test_fn in tests: - print(f"\n{'=' * 60}") - print(f"Running: {name}") - print(f"{'=' * 60}") - try: - test_fn() - passed += 1 - except Exception as e: - print(f"FAILED: {e}") - import traceback - traceback.print_exc() - failed += 1 - - print(f"\n{'=' * 60}") - print(f"Results: {passed} passed, {failed} failed") - print(f"{'=' * 60}") - - sys.exit(0 if failed == 0 else 1) From f7515397e4aada8fc956552b9042d3ce00ceb039 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:43:57 -0600 Subject: [PATCH 05/16] docs: run prettier on pyarrow-udfs user guide page --- docs/source/user-guide/latest/pyarrow-udfs.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md index 71701960cd..2d555cedc4 100644 --- a/docs/source/user-guide/latest/pyarrow-udfs.md +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -62,12 +62,12 @@ It is enabled by default when Comet execution is active. ## Supported APIs -| PySpark API | Spark Plan Node | Supported | -|-------------|-----------------|-----------| -| `df.mapInArrow(func, schema)` | `PythonMapInArrowExec` | Yes | -| `df.mapInPandas(func, schema)` | `MapInPandasExec` | Yes | -| `@pandas_udf` (scalar) | `ArrowEvalPythonExec` | Not yet | -| `df.applyInPandas(func, schema)` | `FlatMapGroupsInPandasExec` | Not yet | +| PySpark API | Spark Plan Node | Supported | +| -------------------------------- | --------------------------- | --------- | +| `df.mapInArrow(func, schema)` | `PythonMapInArrowExec` | Yes | +| `df.mapInPandas(func, schema)` | `MapInPandasExec` | Yes | +| `@pandas_udf` (scalar) | `ArrowEvalPythonExec` | Not yet | +| `df.applyInPandas(func, schema)` | `FlatMapGroupsInPandasExec` | Not yet | ## Example @@ -109,6 +109,7 @@ result.explain(mode="extended") ``` You should see: + ``` CometPythonMapInArrowExec ... +- CometNativeExec ... @@ -116,6 +117,7 @@ CometPythonMapInArrowExec ... ``` Instead of the unoptimized plan: + ``` PythonMapInArrow ... 
+- ColumnarToRow From b14fbfb58adaf3b9219e8f06171f450ef7fd1deb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:46:57 -0600 Subject: [PATCH 06/16] style: apply spotless formatting --- .../org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala | 1 - .../org/apache/comet/exec/CometPythonMapInArrowSuite.scala | 2 -- 2 files changed, 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala index 84b3c31113..223153d7d8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, PythonSQLMetrics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** diff --git a/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala index 94145cea2b..7b1e17c4ed 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometPythonMapInArrowSuite.scala @@ -21,8 +21,6 @@ package org.apache.comet.exec import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.CometPythonMapInArrowExec -import org.apache.spark.sql.execution.ColumnarToRowExec -import org.apache.spark.sql.execution.python.PythonMapInArrowExec import org.apache.comet.CometConf From ca0bbbf50892860e7e103af8c016163d9d4310ef Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:46:57 -0600 Subject: [PATCH 07/16] ci: broaden pyarrow_udf_test triggers to match pr_build_linux Replace the narrow paths allowlist with the same paths-ignore list used by pr_build_linux.yml so the workflow runs on any source change that could affect Comet's PyArrow UDF execution path, not just the few files explicitly named. 
--- .github/workflows/pyarrow_udf_test.yml | 34 +++++++++++++++----------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml index 0779f092a4..46c5fbe079 100644 --- a/.github/workflows/pyarrow_udf_test.yml +++ b/.github/workflows/pyarrow_udf_test.yml @@ -25,21 +25,27 @@ on: push: branches: - main - paths: - - "spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala" - - "spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala" - - "spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala" - - "spark/src/test/resources/pyspark/test_pyarrow_udf.py" - - ".github/workflows/pyarrow_udf_test.yml" - - "native/**" + paths-ignore: + - "benchmarks/**" + - "doc/**" + - "docs/**" + - "**.md" + - "dev/changelog/*.md" + - "native/core/benches/**" + - "native/spark-expr/benches/**" + - "spark/src/test/scala/org/apache/spark/sql/benchmark/**" + - "spark/src/main/scala/org/apache/comet/GenerateDocs.scala" pull_request: - paths: - - "spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala" - - "spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala" - - "spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala" - - "spark/src/test/resources/pyspark/test_pyarrow_udf.py" - - ".github/workflows/pyarrow_udf_test.yml" - - "native/**" + paths-ignore: + - "benchmarks/**" + - "doc/**" + - "docs/**" + - "**.md" + - "dev/changelog/*.md" + - "native/core/benches/**" + - "native/spark-expr/benches/**" + - "spark/src/test/scala/org/apache/spark/sql/benchmark/**" + - "spark/src/main/scala/org/apache/comet/GenerateDocs.scala" workflow_dispatch: env: From 55c28c32a187cce9bdf6b49a2b4113e845ed1d44 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:47:48 -0600 Subject: [PATCH 08/16] ci: restrict GITHUB_TOKEN to contents:read in pyarrow_udf_test --- .github/workflows/pyarrow_udf_test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml index 46c5fbe079..0740842413 100644 --- a/.github/workflows/pyarrow_udf_test.yml +++ b/.github/workflows/pyarrow_udf_test.yml @@ -48,6 +48,9 @@ on: - "spark/src/main/scala/org/apache/comet/GenerateDocs.scala" workflow_dispatch: +permissions: + contents: read + env: RUST_VERSION: stable RUST_BACKTRACE: 1 From 05b1e7afd38437c9eb72309ac2f4f5f764a97adc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 22:19:01 -0600 Subject: [PATCH 09/16] fix: shim CometPythonMapInArrowExec for cross-version Spark builds The PR's `CometPythonMapInArrowExec` and `EliminateRedundantTransitions` rule directly reference Spark 3.5 APIs that differ across supported Spark versions: the `ArrowPythonRunner` constructor (4 distinct signatures across 3.4/3.5/4.0/4.1+/4.2), `arrowUseLargeVarTypes`, `JobArtifactSet`, `MapInBatchExec.isBarrier`, and the `PythonMapInArrowExec` type itself (renamed to `MapInArrowExec` in 4.0+). This breaks compile on every profile other than 3.5. 
Introduce a per-version `ShimCometPythonMapInArrow` trait under `org.apache.spark.sql.comet.shims` (placed in the spark namespace so it can reach `private[spark]` members) that: * matches the Spark-version-specific MapInArrow / MapInPandas exec types and exposes their `(func, output, child, isBarrier, evalType)` tuple, * constructs the right `ArrowPythonRunner` for the version, * hides `arrowUseLargeVarTypes` / `JobArtifactSet` / `getPythonRunnerConfMap` behind helper methods. Spark 3.4 lacks the prerequisite APIs (no `isBarrier`, no `JobArtifactSet`, no `arrowUseLargeVarTypes`), so its shim returns `None` from the matchers and the optimization is a no-op there. --- .../rules/EliminateRedundantTransitions.scala | 41 ++++----- .../sql/comet/CometPythonMapInArrowExec.scala | 32 +++---- .../shims/ShimCometPythonMapInArrow.scala | 68 +++++++++++++++ .../shims/ShimCometPythonMapInArrow.scala | 84 ++++++++++++++++++ .../shims/ShimCometPythonMapInArrow.scala | 86 ++++++++++++++++++ .../shims/ShimCometPythonMapInArrow.scala | 87 +++++++++++++++++++ .../shims/ShimCometPythonMapInArrow.scala | 86 ++++++++++++++++++ 7 files changed, 446 insertions(+), 38 deletions(-) create mode 100644 spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala create mode 100644 spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala create mode 100644 spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala create mode 100644 spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala create mode 100644 spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 272ef76484..e7218ab935 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -20,15 +20,14 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometPythonMapInArrowExec, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.shims.ShimCometPythonMapInArrow import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec -import org.apache.spark.sql.execution.python.{MapInPandasExec, PythonMapInArrowExec} import org.apache.comet.CometConf @@ -53,7 +52,9 @@ import org.apache.comet.CometConf // various reasons) or Spark requests row-based output such as a `collect` call. Spark will adds // another `ColumnarToRowExec` on top of `CometSparkToColumnarExec`. In this case, the pair could // be removed. 
-case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] { +case class EliminateRedundantTransitions(session: SparkSession) + extends Rule[SparkPlan] + with ShimCometPythonMapInArrow { private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() @@ -100,29 +101,23 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa case CometNativeColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child - // Replace MapInBatchExec (PythonMapInArrowExec / MapInPandasExec) that has a - // ColumnarToRow child with CometPythonMapInArrowExec to avoid the unnecessary - // Arrow->Row->Arrow round-trip. - case p: PythonMapInArrowExec if CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.get() => - extractColumnarChild(p.child) + // Replace MapInBatchExec (PythonMapInArrowExec / MapInArrowExec / MapInPandasExec) that has + // a ColumnarToRow child with CometPythonMapInArrowExec to avoid the unnecessary + // Arrow->Row->Arrow round-trip. The matchers are version-shimmed: Spark 3.4 returns None + // (it lacks the required APIs) and Spark 4.1+ matches the renamed `MapInArrowExec`. + case p: SparkPlan + if CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.get() && + matchMapInArrow(p).orElse(matchMapInPandas(p)).isDefined => + val (mapFunc, mapOutput, mapChild, mapIsBarrier, mapEvalType) = + matchMapInArrow(p).orElse(matchMapInPandas(p)).get + extractColumnarChild(mapChild) .map { columnarChild => CometPythonMapInArrowExec( - p.func, - p.output, + mapFunc, + mapOutput, columnarChild, - p.isBarrier, - p.func.asInstanceOf[PythonUDF].evalType) - } - .getOrElse(p) - case p: MapInPandasExec if CometConf.COMET_PYTHON_MAP_IN_ARROW_ENABLED.get() => - extractColumnarChild(p.child) - .map { columnarChild => - CometPythonMapInArrowExec( - p.func, - p.output, - columnarChild, - p.isBarrier, - p.func.asInstanceOf[PythonUDF].evalType) + mapIsBarrier, + mapEvalType) } .getOrElse(p) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala index 223153d7d8..9b3e820023 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala @@ -22,15 +22,16 @@ package org.apache.spark.sql.comet import scala.collection.JavaConverters._ import org.apache.spark.{ContextAwareIterator, TaskContext} -import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.shims.ShimCometPythonMapInArrow import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, PythonSQLMetrics} +import org.apache.spark.sql.execution.python.{BatchIterator, PythonSQLMetrics} +import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** @@ -55,7 +56,8 @@ case class CometPythonMapInArrowExec( isBarrier: Boolean, pythonEvalType: Int) extends 
UnaryExecNode - with PythonSQLMetrics { + with PythonSQLMetrics + with ShimCometPythonMapInArrow { override def supportsColumnar: Boolean = true @@ -78,18 +80,16 @@ case class CometPythonMapInArrowExec( val numOutputBatches = longMetric("numOutputBatches") val numInputRows = longMetric("numInputRows") - val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) - val pythonFunction = func.asInstanceOf[PythonUDF].func - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + val pythonUDF = func.asInstanceOf[PythonUDF] val localOutput = output val localChildSchema = child.schema val batchSize = conf.arrowMaxRecordsPerBatch val sessionLocalTimeZone = conf.sessionLocalTimeZone - val largeVarTypes = conf.arrowUseLargeVarTypes + val useLargeVarTypes = largeVarTypes(conf) + val pythonRunnerConf = getPythonRunnerConfMap(conf) val localPythonEvalType = pythonEvalType val localPythonMetrics = pythonMetrics - val jobArtifactUUID = - org.apache.spark.JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + val jobArtifactUUID = currentJobArtifactUUID() val inputRDD = child.executeColumnar() @@ -112,17 +112,19 @@ case class CometPythonMapInArrowExec( val batchIter = if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, + val columnarBatchIter = computeArrowPython( + pythonUDF, localPythonEvalType, argOffsets, - org.apache.spark.sql.types - .StructType(Array(org.apache.spark.sql.types.StructField("struct", localChildSchema))), + StructType(Array(StructField("struct", localChildSchema))), sessionLocalTimeZone, - largeVarTypes, + useLargeVarTypes, pythonRunnerConf, localPythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) + jobArtifactUUID, + batchIter, + context.partitionId(), + context) columnarBatchIter.map { batch => // Python returns a StructType column; flatten to individual columns diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala new file mode 100644 index 0000000000..30736d99b3 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Spark 3.4 shim for the PyArrow UDF acceleration support. + * + * Spark 3.4 lacks several APIs that the optimization relies on (`isBarrier` on `MapInBatchExec`, + * `arrowUseLargeVarTypes`, `JobArtifactSet`, the modern `ArrowPythonRunner` constructor), so the + * matchers return `None` and the runner factory throws. The optimization is effectively a no-op + * on Spark 3.4. + */ +trait ShimCometPythonMapInArrow { + + protected def matchMapInArrow( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = None + + protected def matchMapInPandas( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = None + + protected def currentJobArtifactUUID(): Option[String] = None + + protected def largeVarTypes(conf: SQLConf): Boolean = false + + protected def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = Map.empty + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + throw new UnsupportedOperationException( + "CometPythonMapInArrowExec is not supported on Spark 3.4") +} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala new file mode 100644 index 0000000000..f7c8221d9e --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInPandasExec, PythonMapInArrowExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometPythonMapInArrow { + + protected def matchMapInArrow( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: PythonMapInArrowExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def matchMapInPandas( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInPandasExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def currentJobArtifactUUID(): Option[String] = + JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + protected def largeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes + + protected def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = + ArrowPythonRunner.getPythonRunnerConfMap(conf) + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonUDF.func))) + new ArrowPythonRunner( + chainedFunc, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala new file mode 100644 index 0000000000..78935f54c5 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInArrowExec, MapInPandasExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometPythonMapInArrow { + + protected def matchMapInArrow( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInArrowExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def matchMapInPandas( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInPandasExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def currentJobArtifactUUID(): Option[String] = + JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + protected def largeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes + + protected def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = + ArrowPythonRunner.getPythonRunnerConfMap(conf) + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val chainedFunc = + Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)) + new ArrowPythonRunner( + chainedFunc, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + None).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala new file mode 100644 index 0000000000..f7f775b1fa --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInArrowExec, MapInPandasExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometPythonMapInArrow { + + protected def matchMapInArrow( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInArrowExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def matchMapInPandas( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInPandasExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def currentJobArtifactUUID(): Option[String] = + JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + protected def largeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes + + protected def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = + ArrowPythonRunner.getPythonRunnerConfMap(conf) + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val chainedFunc = + Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)) + new ArrowPythonRunner( + chainedFunc, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + None, + None).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala new file mode 100644 index 0000000000..78935f54c5 --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometPythonMapInArrow.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInArrowExec, MapInPandasExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometPythonMapInArrow { + + protected def matchMapInArrow( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInArrowExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def matchMapInPandas( + plan: SparkPlan): Option[(Expression, Seq[Attribute], SparkPlan, Boolean, Int)] = + plan match { + case p: MapInPandasExec => + Some((p.func, p.output, p.child, p.isBarrier, p.func.asInstanceOf[PythonUDF].evalType)) + case _ => None + } + + protected def currentJobArtifactUUID(): Option[String] = + JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + protected def largeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes + + protected def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = + ArrowPythonRunner.getPythonRunnerConfMap(conf) + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val chainedFunc = + Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)) + new ArrowPythonRunner( + chainedFunc, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + None).compute(batchIter, partitionId, context) + } +} From 66eb246d3cb9f04b6b25878a02a32f6a2007b669 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 22:19:07 -0600 Subject: [PATCH 10/16] ci: switch pyarrow_udf_test container to rust:bookworm The default `amd64/rust` image is Debian 13 (trixie), where the system `python3` is 3.13 and there is no `python3.11` apt package. The workflow installed `python3.11` explicitly, which fails on trixie with `Unable to locate package python3.11`. Switching to `rust:bookworm` gives a Debian 12 base where `python3` is 3.11, matching the job name and pyspark 3.5.x's supported runtime. --- .github/workflows/pyarrow_udf_test.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml index 0740842413..622ee59fd0 100644 --- a/.github/workflows/pyarrow_udf_test.yml +++ b/.github/workflows/pyarrow_udf_test.yml @@ -61,7 +61,10 @@ jobs: name: PyArrow UDF (Spark 3.5, JDK 17, Python 3.11) runs-on: ubuntu-latest container: - image: amd64/rust + # Pinned to the Debian 12 (bookworm) base so the system `python3` is 3.11. 
The default + # `amd64/rust` image is Debian 13 (trixie) which ships Python 3.13 and no python3.11 apt + # package, breaking `apt-get install python3.11`. + image: rust:bookworm env: JAVA_TOOL_OPTIONS: "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED --add-exports=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED" steps: @@ -91,8 +94,8 @@ jobs: - name: Install Python 3.11 and pip run: | apt-get update - apt-get install -y --no-install-recommends python3.11 python3.11-venv python3-pip - python3.11 -m venv /tmp/venv + apt-get install -y --no-install-recommends python3 python3-venv python3-pip + python3 -m venv /tmp/venv /tmp/venv/bin/pip install --upgrade pip /tmp/venv/bin/pip install "pyspark==3.5.8" "pyarrow>=14" pandas pytest From ec6fa783ed9bb9495dfd709159f5c10cdf37a60b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 06:23:52 -0600 Subject: [PATCH 11/16] ci: set PYSPARK_PYTHON to venv python for pyarrow_udf_test Spark launches Python workers in fresh subprocesses that look up python3 on PATH. Without PYSPARK_PYTHON, workers use the system python (no pyarrow installed) and UDF execution fails with ModuleNotFoundError. Point both PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON at /tmp/venv/bin/python so workers inherit the same interpreter that pytest uses. --- .github/workflows/pyarrow_udf_test.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml index 622ee59fd0..e8018889cc 100644 --- a/.github/workflows/pyarrow_udf_test.yml +++ b/.github/workflows/pyarrow_udf_test.yml @@ -100,6 +100,13 @@ jobs: /tmp/venv/bin/pip install "pyspark==3.5.8" "pyarrow>=14" pandas pytest - name: Run PyArrow UDF pytest + env: + # Spark launches Python workers in a fresh subprocess and looks up `python3` + # on PATH unless PYSPARK_PYTHON is set. Without this, workers use the system + # python which has no pyarrow installed and UDF execution fails with + # ModuleNotFoundError. + PYSPARK_PYTHON: /tmp/venv/bin/python + PYSPARK_DRIVER_PYTHON: /tmp/venv/bin/python run: | jar=$(ls "$PWD"/spark/target/comet-spark-spark3.5_2.12-*-SNAPSHOT.jar \ | grep -v sources | grep -v tests | head -n1) From 1de2c2f815607115c03cc3075200ec4bc28d8223 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 07:44:54 -0600 Subject: [PATCH 12/16] feat: default-disable PyArrow UDF optimization while experimental Flip spark.comet.exec.pythonMapInArrow.enabled default from true to false and prefix the config doc with "Experimental:" so the default matches the "[experimental]" label on the feature. Update the user guide to instruct users to opt in explicitly. --- common/src/main/scala/org/apache/comet/CometConf.scala | 10 ++++++---- docs/source/user-guide/latest/pyarrow-udfs.md | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index a06cd896ec..675e872b6e 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -318,11 +318,13 @@ object CometConf extends ShimCometConf { conf("spark.comet.exec.pythonMapInArrow.enabled") .category(CATEGORY_EXEC) .doc( - "Whether to enable optimized execution of PyArrow UDFs (mapInArrow/mapInPandas). 
" + - "When enabled, Comet passes Arrow columnar data directly to Python UDFs without " + - "the intermediate Arrow-to-Row-to-Arrow conversion that Spark normally performs.") + "Experimental: whether to enable optimized execution of PyArrow UDFs " + + "(mapInArrow/mapInPandas). When enabled, Comet passes Arrow columnar data " + + "directly to Python UDFs without the intermediate Arrow-to-Row-to-Arrow " + + "conversion that Spark normally performs. Disabled by default while the " + + "feature stabilizes.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val COMET_TRACING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.tracing.enabled") .category(CATEGORY_TUNING) diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md index 2d555cedc4..374948c039 100644 --- a/docs/source/user-guide/latest/pyarrow-udfs.md +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -52,13 +52,13 @@ and memory allocations. ## Configuration -The optimization is controlled by: +The optimization is experimental and disabled by default. Enable it with: ``` -spark.comet.exec.pythonMapInArrow.enabled=true (default) +spark.comet.exec.pythonMapInArrow.enabled=true ``` -It is enabled by default when Comet execution is active. +The default is `false` while the feature stabilizes. ## Supported APIs From 3f68cbeb56f6be4a3235b73113630d9b9a928249 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 07:45:02 -0600 Subject: [PATCH 13/16] test: expand PyArrow UDF pytest coverage Add coverage for cases that the original pytest module did not exercise: - mapInPandas (claimed supported, previously zero coverage) - Null preservation across long and string columns via Arrow passthrough - Empty input from a CometScan via filter pushdown - Python exception propagation (sentinel must surface in driver-side error) - DecimalType(18,6), DateType, TimestampType round-trip with nulls - ArrayType and nested StructType, including null arrays/structs and arrays containing null elements - repartition between scan and UDF (correctness only; the optimization itself does not fire across a vanilla Exchange and is documented as such in the test) Generalize _assert_plan_matches_mode to take the vanilla node name so the fallback assertion can match either PythonMapInArrow or MapInPandas. 
--- .../resources/pyspark/test_pyarrow_udf.py | 280 +++++++++++++++++- 1 file changed, 277 insertions(+), 3 deletions(-) diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py index 462f4efdc6..b62db73be1 100644 --- a/spark/src/test/resources/pyspark/test_pyarrow_udf.py +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -37,8 +37,10 @@ pytest -v spark/src/test/resources/pyspark/test_pyarrow_udf.py """ +import datetime as dt import glob import os +from decimal import Decimal import pyarrow as pa import pytest @@ -125,7 +127,9 @@ def _executed_plan(df) -> str: return df._jdf.queryExecution().executedPlan().toString() -def _assert_plan_matches_mode(plan: str, accelerated: bool) -> None: +def _assert_plan_matches_mode( + plan: str, accelerated: bool, vanilla_node: str = "PythonMapInArrow" +) -> None: if accelerated: assert "CometPythonMapInArrow" in plan, ( f"expected CometPythonMapInArrow in accelerated plan, got:\n{plan}" @@ -137,8 +141,8 @@ def _assert_plan_matches_mode(plan: str, accelerated: bool) -> None: assert "CometPythonMapInArrow" not in plan, ( f"unexpected CometPythonMapInArrow in fallback plan:\n{plan}" ) - assert "PythonMapInArrow" in plan, ( - f"expected PythonMapInArrow in fallback plan, got:\n{plan}" + assert vanilla_node in plan, ( + f"expected {vanilla_node} in fallback plan, got:\n{plan}" ) @@ -201,3 +205,273 @@ def add_computed_column(iterator): for i, row in enumerate(rows): assert abs(row["squared"] - float(i) ** 2) < 1e-6 assert row["label"] == f"item_{i}" + + +def test_map_in_pandas_doubles_value(spark, tmp_path, accelerated): + data = [(i, float(i * 1.5)) for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def double_value(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["value"] = pdf["value"] * 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(double_value, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + + +def test_map_in_pandas_changes_schema(spark, tmp_path, accelerated): + data = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def add_squared(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["squared"] = pdf["value"] ** 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(add_squared, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 + + +def test_map_in_arrow_preserves_nulls(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("name", T.StringType()), + ] + ) + rows = [ + (1, "a"), + (2, None), + (None, "c"), + (None, None), + (5, "e"), + ] + src = 
str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + # Pure Arrow passthrough so nulls survive without a pandas roundtrip + # (pandas would coerce null longs to NaN floats). + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["name"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_empty_input(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + src = str(tmp_path / "src.parquet") + spark.createDataFrame([(1, 1.0), (2, 2.0)], schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + # Filter all rows out so the operator sees an empty stream from CometScan. + result_df = ( + spark.read.parquet(src).where("id < 0").mapInArrow(passthrough, schema_in) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + assert result_df.count() == 0 + + +def test_map_in_arrow_python_exception_propagates(spark, tmp_path, accelerated): + schema_in = T.StructType([T.StructField("id", T.LongType())]) + data = [(i,) for i in range(10)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, schema_in).write.parquet(src) + + sentinel = "boom-from-pyarrow-udf" + + def boom(iterator): + for _batch in iterator: + raise ValueError(sentinel) + # Unreachable, but mapInArrow requires the callable to be a generator. + yield # pragma: no cover + + result_df = spark.read.parquet(src).mapInArrow(boom, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + with pytest.raises(Exception) as exc_info: + result_df.collect() + assert sentinel in str(exc_info.value), ( + f"expected sentinel {sentinel!r} in exception, got: {exc_info.value}" + ) + + +def test_map_in_arrow_decimal_type(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("amount", T.DecimalType(18, 6)), + ] + ) + rows = [ + (1, Decimal("123.456789")), + (2, Decimal("0.000001")), + (3, Decimal("-99999999.999999")), + (4, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["amount"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_date_and_timestamp(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("d", T.DateType()), + T.StructField("ts", T.TimestampType()), + ] + ) + rows = [ + (1, dt.date(2024, 1, 1), dt.datetime(2024, 1, 1, 12, 30, 45)), + (2, dt.date(1999, 12, 31), dt.datetime(2000, 6, 15, 0, 0, 0)), + (3, None, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["d"], r["ts"]) for r in result_df.collect()} + assert out == set(rows) + + +def 
test_map_in_arrow_array_and_struct(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("nums", T.ArrayType(T.IntegerType())), + T.StructField( + "addr", + T.StructType( + [ + T.StructField("city", T.StringType()), + T.StructField("zip", T.IntegerType()), + ] + ), + ), + ] + ) + rows = [ + (1, [1, 2, 3], ("Berlin", 10115)), + (2, [], ("NYC", 10001)), + (3, None, None), + (4, [None, 5], ("Tokyo", None)), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _normalize(row): + nums = tuple(row["nums"]) if row["nums"] is not None else None + addr = row["addr"] + addr_tuple = (addr["city"], addr["zip"]) if addr is not None else None + return (row["id"], nums, addr_tuple) + + out = {_normalize(r) for r in result_df.collect()} + expected = { + (r[0], tuple(r[1]) if r[1] is not None else None, r[2]) for r in rows + } + assert out == expected + + +def test_map_in_arrow_after_shuffle(spark, tmp_path, accelerated): + """ + Verifies correctness when a shuffle sits between the Comet scan and the + Python UDF. Without `spark.shuffle.manager` configured at session startup + the shuffle stays a vanilla `Exchange`, which is not columnar, so the + optimization does not fire across it today. This test does not assert on + the plan; it only ensures the path produces correct results in both modes + so a future change that wires Comet shuffle into the optimization does + not silently break correctness. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src) + .repartition(4, "id") + .mapInArrow(passthrough, schema_in) + ) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) From e2ca2d2d91e5a10a829bd3793cac31727c27f6d4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 07:45:52 -0600 Subject: [PATCH 14/16] docs: document PyArrow UDF limitations and AQE explain quirk Expand the user guide with the limitations a user should know before enabling the experimental optimization: - The remaining row-to-Arrow round-trip inside the Python runner is documented more precisely (the input goes through ColumnarBatch.rowIterator to feed ArrowPythonRunner, which re-encodes to Arrow IPC). - A vanilla Spark Exchange between the Comet scan and the UDF prevents the optimization from firing. Users must configure Comet's native shuffle manager at session startup to keep the data columnar. - Spark 3.4 lacks the prerequisite APIs and the feature is a no-op there. - isBarrier is captured by the operator constructor but not yet propagated to the Python runner. Also explain the AQE display quirk: with AQE on and a shuffle present, the pre-execution plan shows the unoptimized form because the rule only sees the materialized subplan after stage execution. Running an action and re-inspecting explain() reveals the optimized plan. 
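For reference, a minimal sketch of the workflow the guide now describes (`df`, `transform`, and `output_schema` stand in for whatever the user already has):

    result = df.mapInArrow(transform, output_schema)
    result.explain()   # with AQE and a shuffle: still shows PythonMapInArrow
    result.collect()   # materialize the query stages
    result.explain()   # now shows CometPythonMapInArrow if the optimization fired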
--- docs/source/user-guide/latest/pyarrow-udfs.md | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md index 374948c039..08a731e5de 100644 --- a/docs/source/user-guide/latest/pyarrow-udfs.md +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -102,7 +102,7 @@ result = df.mapInArrow(transform, output_schema) ## Verifying the Optimization -Use `explain()` to verify that `CometPythonMapInArrowExec` appears in your plan: +Use `explain()` to verify that `CometPythonMapInArrow` appears in your plan: ```python result.explain(mode="extended") @@ -111,7 +111,7 @@ result.explain(mode="extended") You should see: ``` -CometPythonMapInArrowExec ... +CometPythonMapInArrow ... +- CometNativeExec ... +- CometScan ... ``` @@ -125,10 +125,39 @@ PythonMapInArrow ... +- CometScan ... ``` +When AQE is enabled (the Spark default) and the query contains a shuffle, the +optimization is applied during stage materialization. Calling `explain()` before +running an action will show the unoptimized plan: + +``` +AdaptiveSparkPlan isFinalPlan=false ++- PythonMapInArrow ... + +- CometExchange ... +``` + +To see the optimized plan, run an action first (for example `result.collect()` or +`result.cache(); result.count()`) and then call `explain()`. The post-execution +plan shows the materialized stages and includes `CometPythonMapInArrow` if the +optimization fired. + ## Limitations - The optimization currently applies only to `mapInArrow` and `mapInPandas`. Scalar pandas UDFs (`@pandas_udf`) and grouped operations (`applyInPandas`) are not yet supported. - The internal row-to-Arrow conversion inside the Python runner is still present in this version. - A future optimization will write Arrow batches directly to the Python IPC stream, achieving - near zero-copy data transfer. + Comet currently routes columnar input through `ColumnarBatch.rowIterator()` so that the existing + `ArrowPythonRunner` can re-encode the rows back to Arrow IPC. A future optimization will write + Arrow batches directly to the Python IPC stream, eliminating the remaining round-trip and + achieving near zero-copy data transfer. +- The optimization requires Arrow data on the input side. If a shuffle sits between the upstream + Comet operator and the Python UDF, you need Comet's native shuffle for the optimization to + apply. Set `spark.shuffle.manager` to + `org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager` and enable + `spark.comet.exec.shuffle.enabled=true` at session startup. With a vanilla Spark `Exchange` + in the plan the data leaves the shuffle as rows and the optimization cannot fire. +- Spark 3.4 lacks several APIs the optimization depends on (`MapInBatchExec.isBarrier`, + `arrowUseLargeVarTypes`, `JobArtifactSet`, the modern `ArrowPythonRunner` constructor). On + Spark 3.4 the feature is a no-op even when enabled. Spark 3.5+ is required. +- The `isBarrier` flag on `mapInArrow` / `mapInPandas` is currently captured but not propagated + through to the Python runner. If your job depends on barrier-execution semantics, leave the + optimization disabled until this is fixed. 
From f4b5c3274cc45400fe5e7102b72e2fea96ed8496 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 07:57:26 -0600 Subject: [PATCH 15/16] bench: add Python end-to-end benchmark for PyArrow UDF acceleration Standalone Python script that times df.mapInArrow(passthrough).count() and the equivalent mapInPandas query with the optimization toggled on and off. Numbers are wall-clock seconds, so they include the Python worker, Arrow IPC, and downstream count() costs. That is the right unit for a feature whose user surface is Python: it shows what fraction of end-to-end time the optimization shaves off, not just the JVM-side delta in isolation. Three workloads exercise the dimension where the optimization helps most: - narrow primitives (long, int, double) - mixed with strings (variable-length encoding) - wide rows (50 columns, projection cost scales with column count) Local smoke run with 200k rows shows 1.17x to 1.45x speedup across mapInArrow and mapInPandas, narrow/wide schemas. The script is configurable via BENCHMARK_ROWS / BENCHMARK_WARMUP / BENCHMARK_ITERS env vars for users who want longer or shorter runs. --- .../pyspark/benchmark_pyarrow_udf.py | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py diff --git a/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py new file mode 100644 index 0000000000..8a3b4333c4 --- /dev/null +++ b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +End-to-end wall-clock benchmark for Comet's PyArrow UDF acceleration. + +Times `df.mapInArrow(passthrough, schema).count()` and the equivalent +`mapInPandas` query with `spark.comet.exec.pythonMapInArrow.enabled` set +to false (vanilla Spark path) and true (Comet's optimized path). Both +modes run the same Python worker, so the measured delta covers what the +optimization actually changes for users: + + * vanilla: CometScan -> ColumnarToRow + UnsafeProjection -> ArrowPythonRunner + * optimized: CometScan -> rowIterator -> ArrowPythonRunner (same runner; + no UnsafeProjection, output kept as ColumnarBatch) + +Results are wall-clock seconds, so they include Python interpreter, +Arrow IPC, and downstream count() costs. That's intentional: the +optimization's user-visible value is what fraction of end-to-end time +it shaves off, not the JVM-side delta in isolation. 
+ +Usage: + # Build Comet (release for representative numbers): + make release + + pip install pyspark==3.5.8 pyarrow pandas + + python3 spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py + +Override defaults via environment variables: + COMET_JAR=/path/to/comet.jar path to the Comet jar + BENCHMARK_ROWS=2000000 rows per run + BENCHMARK_WARMUP=2 warmup iterations per case + BENCHMARK_ITERS=5 measured iterations per case +""" + +import contextlib +import glob +import os +import statistics +import tempfile +import time + +from pyspark.sql import SparkSession + + +REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") +) + + +def _resolve_comet_jar() -> str: + explicit = os.environ.get("COMET_JAR") + if explicit: + return explicit + import pyspark + + major_minor = ".".join(pyspark.__version__.split(".")[:2]) + spark_tag = f"spark{major_minor}" + scala_tag = "_2.12" if major_minor.startswith("3.") else "_2.13" + pattern = os.path.join( + REPO_ROOT, + f"spark/target/comet-spark-{spark_tag}{scala_tag}-*-SNAPSHOT.jar", + ) + candidates = [ + m + for m in sorted(glob.glob(pattern)) + if "sources" not in os.path.basename(m) and "tests" not in os.path.basename(m) + ] + if not candidates: + raise FileNotFoundError( + "Comet jar not found. Set COMET_JAR or run `make release`. " + f"Looked under {pattern}." + ) + return candidates[-1] + + +def _build_spark() -> SparkSession: + jar = _resolve_comet_jar() + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + return ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-benchmark") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "4g") + .config("spark.driver.memory", "4g") + # Pin AQE off so the explain output and plan structure are stable + # across iterations. AQE doesn't change the optimization's behavior; + # it just makes plan inspection harder. 
+ .config("spark.sql.adaptive.enabled", "false") + .getOrCreate() + ) + + +def _passthrough_arrow(iterator): + for batch in iterator: + yield batch + + +def _passthrough_pandas(iterator): + for pdf in iterator: + yield pdf + + +def _narrow_primitives(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + ) + + +def _mixed_with_strings(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + "concat('row_', cast(id as string)) as id_str", + "cast(id % 2 as boolean) as id_bool", + ) + + +def _wide_rows(spark: SparkSession, n: int): + types = ["int", "long", "double"] + cols = [ + f"cast(id + {i} as {types[i % len(types)]}) as col_{i}" for i in range(50) + ] + return spark.range(n).selectExpr(*cols) + + +WORKLOADS = [ + ("narrow primitives", _narrow_primitives), + ("mixed with strings", _mixed_with_strings), + ("wide rows (50 cols)", _wide_rows), +] + + +@contextlib.contextmanager +def _temp_parquet(spark: SparkSession, build_df, n: int): + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "src.parquet") + build_df(spark, n).write.parquet(path) + yield path + + +def _time_run(spark: SparkSession, parquet_path: str, accelerate: bool, api: str) -> float: + spark.conf.set( + "spark.comet.exec.pythonMapInArrow.enabled", + "true" if accelerate else "false", + ) + df = spark.read.parquet(parquet_path) + schema = df.schema + if api == "mapInArrow": + df = df.mapInArrow(_passthrough_arrow, schema) + else: + df = df.mapInPandas(_passthrough_pandas, schema) + t0 = time.perf_counter() + df.count() + return time.perf_counter() - t0 + + +def main() -> None: + rows = int(os.environ.get("BENCHMARK_ROWS", 1024 * 1024)) + warmup = int(os.environ.get("BENCHMARK_WARMUP", 2)) + iters = int(os.environ.get("BENCHMARK_ITERS", 5)) + + spark = _build_spark() + spark.sparkContext.setLogLevel("WARN") + + print(f"\nrows per run: {rows:,}") + print(f"warmup iters: {warmup}, measured iters: {iters}") + print(f"jar: {_resolve_comet_jar()}\n") + + header = " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + "api", "mode", "min (s)", "median (s)", "max (s)", "rows/s", "speedup" + ) + print(header) + print(" " + "-" * (len(header) - 2)) + + for name, build_df in WORKLOADS: + print(f"\n=== {name} ===") + with _temp_parquet(spark, build_df, rows) as parquet_path: + for api in ("mapInArrow", "mapInPandas"): + samples_by_mode = {} + for mode, accelerate in (("vanilla", False), ("optimized", True)): + for _ in range(warmup): + _time_run(spark, parquet_path, accelerate, api) + samples = [ + _time_run(spark, parquet_path, accelerate, api) + for _ in range(iters) + ] + samples_by_mode[mode] = samples + median = statistics.median(samples) + speedup = "" + if mode == "optimized": + speedup = "{:.2f}x".format( + statistics.median(samples_by_mode["vanilla"]) / median + ) + print( + " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + api, + mode, + "{:.3f}".format(min(samples)), + "{:.3f}".format(median), + "{:.3f}".format(max(samples)), + "{:,.0f}".format(rows / median), + speedup, + ) + ) + + spark.stop() + + +if __name__ == "__main__": + main() From 3822ed7d90368a0746a577a681d2bf56efccb087 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 08:17:49 -0600 Subject: [PATCH 16/16] fix: propagate isBarrier through CometPythonMapInArrowExec The operator captured isBarrier in its constructor but always 
called inputRDD.mapPartitionsInternal, dropping the barrier execution mode semantics that mapInArrow(..., barrier=True) requests. Stages running under the optimization lost gang scheduling and the BarrierTaskContext APIs the UDF expects. Branch on isBarrier and route through inputRDD.barrier().mapPartitions in the barrier case, matching what Spark's MapInBatchExec.doExecute does. Add a pytest case that calls BarrierTaskContext.get() inside the UDF, which raises if the task is not running in a barrier stage; runs in both vanilla and optimized modes. Drop the isBarrier limitation note from the user guide. --- docs/source/user-guide/latest/pyarrow-udfs.md | 3 -- .../sql/comet/CometPythonMapInArrowExec.scala | 15 ++++++-- .../resources/pyspark/test_pyarrow_udf.py | 38 +++++++++++++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md index 08a731e5de..6a95fbac0d 100644 --- a/docs/source/user-guide/latest/pyarrow-udfs.md +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -158,6 +158,3 @@ optimization fired. - Spark 3.4 lacks several APIs the optimization depends on (`MapInBatchExec.isBarrier`, `arrowUseLargeVarTypes`, `JobArtifactSet`, the modern `ArrowPythonRunner` constructor). On Spark 3.4 the feature is a no-op even when enabled. Spark 3.5+ is required. -- The `isBarrier` flag on `mapInArrow` / `mapInPandas` is currently captured but not propagated - through to the Python runner. If your job depends on barrier-execution semantics, leave the - optimization disabled until this is fixed. diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala index 9b3e820023..68e27b9355 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometPythonMapInArrowExec.scala @@ -93,12 +93,13 @@ case class CometPythonMapInArrowExec( val inputRDD = child.executeColumnar() - inputRDD.mapPartitionsInternal { batches => + // Run on every partition. Identical to what MapInBatchExec does, except the input + // is columnar; we intentionally avoid the UnsafeProjection copy that ColumnarToRow + // would do. + def processPartition(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { val context = TaskContext.get() val argOffsets = Array(Array(0)) - // Convert columnar batches to rows using lightweight rowIterator - // (avoids UnsafeProjection copy that ColumnarToRow would do) val rowIter = batches.flatMap { batch => numInputRows += batch.numRows() batch.rowIterator().asScala @@ -137,6 +138,14 @@ case class CometPythonMapInArrowExec( flattenedBatch } } + + // Preserve isBarrier semantics: when set, run inside a barrier stage so all tasks + // are gang-scheduled and BarrierTaskContext.barrier() works inside the UDF. 
+ if (isBarrier) { + inputRDD.barrier().mapPartitions(processPartition) + } else { + inputRDD.mapPartitionsInternal(processPartition) + } } override protected def withNewChildInternal(newChild: SparkPlan): CometPythonMapInArrowExec = diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py index b62db73be1..ea72436841 100644 --- a/spark/src/test/resources/pyspark/test_pyarrow_udf.py +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -475,3 +475,41 @@ def passthrough(iterator): out = sorted((r["id"], r["value"]) for r in result_df.collect()) assert out == sorted(rows) + + +def test_map_in_arrow_barrier_mode(spark, tmp_path, accelerated): + """ + `mapInArrow(..., barrier=True)` runs the stage in barrier execution mode + (gang scheduling, all-or-nothing failure semantics, BarrierTaskContext + available inside the UDF). The optimization captures isBarrier in the + operator constructor and must propagate it through to RDD.barrier(); + otherwise the runtime context the UDF sees changes when the optimization + fires and any code calling BarrierTaskContext APIs breaks. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(20)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def assert_barrier_context(iterator): + from pyspark import BarrierTaskContext + + # Will raise if the task is not running inside a barrier stage. + BarrierTaskContext.get() + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src).mapInArrow( + assert_barrier_context, schema_in, barrier=True + ) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows)