diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 53d67ec7f6..7e7aa23d1e 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -557,16 +557,29 @@ object CometConf extends ShimCometConf {
     .doubleConf
     .createWithDefault(10.0)
 
-  val COMET_EXCHANGE_SIZE_MULTIPLIER: ConfigEntry[Double] = conf(
-    "spark.comet.shuffle.sizeInBytesMultiplier")
+  val COMET_EXCHANGE_SIZE_MULTIPLIER: ConfigEntry[Double] =
+    conf("spark.comet.shuffle.sizeInBytesMultiplier")
+      .category(CATEGORY_SHUFFLE)
+      .doc(
+        "Comet shuffle uses Arrow columnar format which is more compact than Spark's UnsafeRow " +
+          "format. This causes Spark's AQE to underestimate shuffle data sizes, potentially " +
+          "choosing suboptimal join strategies (e.g. broadcast instead of sort-merge). " +
+          "This multiplier is applied to the reported shuffle data size to compensate. " +
+          "Only used when spark.comet.shuffle.sizeInBytesMultiplier.dynamic is false.")
+      .doubleConf
+      .createWithDefault(2.0)
+
+  val COMET_EXCHANGE_SIZE_MULTIPLIER_DYNAMIC: ConfigEntry[Boolean] = conf(
+    "spark.comet.shuffle.sizeInBytesMultiplier.dynamic")
     .category(CATEGORY_SHUFFLE)
     .doc(
-      "Comet reports smaller sizes for shuffle due to using Arrow's columnar memory format " +
-        "and this can result in Spark choosing a different join strategy due to the estimated " +
-        "size of the exchange being smaller. Comet will multiple sizeInBytes by this amount to " +
-        "avoid regressions in join strategy.")
-    .doubleConf
-    .createWithDefault(1.0)
+      "When true, Comet estimates the size multiplier dynamically based on the shuffle " +
+        "output schema rather than using the static spark.comet.shuffle.sizeInBytesMultiplier " +
+        "value. The dynamic estimate accounts for per-column type widths to approximate " +
+        "how much larger the data would be in Spark's UnsafeRow format compared to Arrow " +
+        "columnar format.")
+    .booleanConf
+    .createWithDefault(false)
 
   val COMET_DEBUG_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.debug.enabled")
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index cc7a63f0cb..6a66714723 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -120,8 +120,12 @@ case class CometShuffleExchangeExec(
     new CometShuffledBatchRDD(shuffleDependency, readMetrics, partitionSpecs)
 
   override def runtimeStatistics: Statistics = {
-    val dataSize =
-      metrics("dataSize").value * Math.max(CometConf.COMET_EXCHANGE_SIZE_MULTIPLIER.get(conf), 1)
+    val multiplier = if (CometConf.COMET_EXCHANGE_SIZE_MULTIPLIER_DYNAMIC.get(conf)) {
+      CometShuffleExchangeExec.estimateUnsafeRowMultiplier(child.output)
+    } else {
+      Math.max(CometConf.COMET_EXCHANGE_SIZE_MULTIPLIER.get(conf), 1)
+    }
+    val dataSize = metrics("dataSize").value * multiplier
     val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
     Statistics(dataSize.toLong, Some(rowCount))
   }
@@ -222,6 +226,53 @@ object CometShuffleExchangeExec
     with CometTypeShim
     with SQLConfHelper {
 
+  /**
+   * Estimates the ratio of Spark UnsafeRow size to Arrow columnar size for a given schema.
+   *
+   * UnsafeRow uses 8 bytes per field (fixed-width region) regardless of actual type width, plus
+   * per-row overhead for null bitset and header. Arrow columnar uses the actual type width and
+   * amortizes per-column overhead across the batch. This method returns the estimated ratio so
+   * that Arrow-reported shuffle sizes can be scaled to approximate what Spark would report.
+   */
+  def estimateUnsafeRowMultiplier(fields: Seq[Attribute]): Double = {
+    val numFields = fields.size
+    if (numFields == 0) return 1.0
+
+    // UnsafeRow per-row: 4 bytes header + null bitset (8-byte aligned) + 8 bytes per field
+    val unsafeRowBytesPerRow =
+      4.0 + math.ceil(numFields / 64.0) * 8.0 + 8.0 * numFields
+
+    // Arrow per-value: actual type width (null bitmap overhead is negligible at batch sizes)
+    val arrowBytesPerRow = fields.map { attr =>
+      arrowTypeWidth(attr.dataType)
+    }.sum
+
+    val ratio = if (arrowBytesPerRow > 0) {
+      unsafeRowBytesPerRow / arrowBytesPerRow
+    } else {
+      2.0
+    }
+
+    Math.max(ratio, 1.0)
+  }
+
+  private def arrowTypeWidth(dataType: DataType): Double = dataType match {
+    case BooleanType => 0.125 // 1 bit
+    case ByteType => 1.0
+    case ShortType => 2.0
+    case IntegerType | FloatType | DateType => 4.0
+    case LongType | DoubleType | TimestampType | TimestampNTZType => 8.0
+    case _: DecimalType => 16.0
+    // Strings/binary: Arrow uses 4-byte offset + data; UnsafeRow uses 8-byte pointer + data.
+    // Use 8.0 as estimated average Arrow cost (4 offset + ~4 avg data bytes).
+    case StringType | BinaryType => 8.0
+    case ArrayType(elementType, _) => arrowTypeWidth(elementType) + 4.0 // offset array
+    case MapType(keyType, valueType, _) =>
+      arrowTypeWidth(keyType) + arrowTypeWidth(valueType) + 4.0
+    case s: StructType => s.fields.map(f => arrowTypeWidth(f.dataType)).sum
+    case _ => 8.0
+  }
+
   override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
     if (shuffleSupported(op).isDefined) Compatible() else Unsupported()
   }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index 751bf4c1f5..2741a13e41 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -28,9 +28,11 @@ import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.{col, count, sum}
+import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf
@@ -499,4 +501,50 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
       }
     }
   }
+
+  test("estimateUnsafeRowMultiplier: narrow fixed-width schema has high multiplier") {
+    // Schema of boolean + byte + short columns: Arrow uses ~3.125 bytes/row,
+    // UnsafeRow uses 4 + 8 + 24 = 36 bytes/row
+    val fields = Seq(
+      AttributeReference("a", BooleanType)(),
+      AttributeReference("b", ByteType)(),
+      AttributeReference("c", ShortType)())
+    val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
+    assert(multiplier > 5.0, s"Expected > 5.0 for narrow types, got $multiplier")
+  }
+
+  test("estimateUnsafeRowMultiplier: wide fixed-width schema (longs/doubles)") {
+    // Schema of 4 longs: Arrow uses 32 bytes/row, UnsafeRow uses 4 + 8 + 32 = 44 bytes/row
+    val fields = (1 to 4).map(i => AttributeReference(s"c$i", LongType)())
+    val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
+    assert(
+      multiplier >= 1.0 && multiplier < 2.0,
+      s"Expected ~1.3 for all-long schema, got $multiplier")
+  }
+
+  test("estimateUnsafeRowMultiplier: string-heavy schema") {
+    val fields = (1 to 5).map(i => AttributeReference(s"s$i", StringType)())
+    val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
+    assert(multiplier >= 1.0, s"Expected >= 1.0 for string schema, got $multiplier")
+  }
+
+  test("estimateUnsafeRowMultiplier: empty schema returns 1.0") {
+    val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(Seq.empty)
+    assert(multiplier == 1.0)
+  }
+
+  test("estimateUnsafeRowMultiplier: mixed TPC-DS-like schema") {
+    // Typical dimension table: ints, strings, decimals
+    val fields = Seq(
+      AttributeReference("id", IntegerType)(),
+      AttributeReference("name", StringType)(),
+      AttributeReference("price", DecimalType(10, 2))(),
+      AttributeReference("qty", IntegerType)(),
+      AttributeReference("date", DateType)(),
+      AttributeReference("flag", BooleanType)())
+    val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
+    assert(
+      multiplier >= 1.0 && multiplier <= 4.0,
+      s"Expected between 1.0 and 4.0 for mixed schema, got $multiplier")
+  }
 }