common/src/main/scala/org/apache/comet/CometConf.scala (21 additions, 8 deletions)
@@ -557,16 +557,29 @@ object CometConf extends ShimCometConf {
.doubleConf
.createWithDefault(10.0)

val COMET_EXCHANGE_SIZE_MULTIPLIER: ConfigEntry[Double] =
conf("spark.comet.shuffle.sizeInBytesMultiplier")
.category(CATEGORY_SHUFFLE)
.doc(
"Comet shuffle uses Arrow columnar format which is more compact than Spark's UnsafeRow " +
"format. This causes Spark's AQE to underestimate shuffle data sizes, potentially " +
"choosing suboptimal join strategies (e.g. broadcast instead of sort-merge). " +
"This multiplier is applied to the reported shuffle data size to compensate. " +
"Only used when spark.comet.shuffle.sizeInBytesMultiplier.dynamic is false.")
.doubleConf
.createWithDefault(2.0)

val COMET_EXCHANGE_SIZE_MULTIPLIER_DYNAMIC: ConfigEntry[Boolean] = conf(
"spark.comet.shuffle.sizeInBytesMultiplier.dynamic")
.category(CATEGORY_SHUFFLE)
.doc(
"When true, Comet estimates the size multiplier dynamically based on the shuffle " +
"output schema rather than using the static spark.comet.shuffle.sizeInBytesMultiplier " +
"value. The dynamic estimate accounts for per-column type widths to approximate " +
"how much larger the data would be in Spark's UnsafeRow format compared to Arrow " +
"columnar format.")
.booleanConf
.createWithDefault(false)
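
// Usage sketch (illustrative, not part of this diff; assumes a running SparkSession
// with Comet enabled). Both settings are ordinary Spark session configs:
//   spark.conf.set("spark.comet.shuffle.sizeInBytesMultiplier.dynamic", "true")
//   // or tune the static multiplier, which applies only when dynamic is false:
//   spark.conf.set("spark.comet.shuffle.sizeInBytesMultiplier", "3.0")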

val COMET_DEBUG_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.debug.enabled")
@@ -120,8 +120,12 @@ case class CometShuffleExchangeExec(
new CometShuffledBatchRDD(shuffleDependency, readMetrics, partitionSpecs)

override def runtimeStatistics: Statistics = {
val multiplier = if (CometConf.COMET_EXCHANGE_SIZE_MULTIPLIER_DYNAMIC.get(conf)) {
CometShuffleExchangeExec.estimateUnsafeRowMultiplier(child.output)
} else {
Math.max(CometConf.COMET_EXCHANGE_SIZE_MULTIPLIER.get(conf), 1)
}
val dataSize = metrics("dataSize").value * multiplier
val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
Statistics(dataSize.toLong, Some(rowCount))
}
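
// Worked example (illustrative numbers): for a shuffle of four LongType columns,
// estimateUnsafeRowMultiplier returns (4 + 8 + 4*8) / (4*8) = 44/32 = 1.375, so a
// measured dataSize of 100 MB is reported to AQE as ~137.5 MB.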
@@ -222,6 +226,53 @@ object CometShuffleExchangeExec
with CometTypeShim
with SQLConfHelper {

/**
* Estimates the ratio of Spark UnsafeRow size to Arrow columnar size for a given schema.
*
* UnsafeRow uses 8 bytes per field (fixed-width region) regardless of actual type width, plus
* per-row overhead for null bitset and header. Arrow columnar uses the actual type width and
* amortizes per-column overhead across the batch. This method returns the estimated ratio so
* that Arrow-reported shuffle sizes can be scaled to approximate what Spark would report.
*/
def estimateUnsafeRowMultiplier(fields: Seq[Attribute]): Double = {
val numFields = fields.size
if (numFields == 0) return 1.0

// UnsafeRow per-row: 4 bytes header + null bitset (8-byte aligned) + 8 bytes per field
val unsafeRowBytesPerRow =
4.0 + math.ceil(numFields / 64.0) * 8.0 + 8.0 * numFields

// Arrow per-value: actual type width (null bitmap overhead is negligible at typical batch sizes)
val arrowBytesPerRow = fields.map { attr =>
arrowTypeWidth(attr.dataType)
}.sum

val ratio = if (arrowBytesPerRow > 0) {
unsafeRowBytesPerRow / arrowBytesPerRow
} else {
2.0
}

Math.max(ratio, 1.0)
}
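
// In closed form (for a positive width sum):
//   multiplier = max(1, (4 + 8 * ceil(n / 64) + 8n) / sum(w_i))
// where n is the field count and w_i = arrowTypeWidth(field_i), defined below.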

private def arrowTypeWidth(dataType: DataType): Double = dataType match {
case BooleanType => 0.125 // 1 bit
case ByteType => 1.0
case ShortType => 2.0
case IntegerType | FloatType | DateType => 4.0
case LongType | DoubleType | TimestampType | TimestampNTZType => 8.0
case _: DecimalType => 16.0
// Strings/binary: Arrow uses 4-byte offset + data; UnsafeRow uses 8-byte pointer + data.
// Use 8.0 as estimated average Arrow cost (4 offset + ~4 avg data bytes).
case StringType | BinaryType => 8.0
case ArrayType(elementType, _) => arrowTypeWidth(elementType) + 4.0 // offset array
case MapType(keyType, valueType, _) =>
arrowTypeWidth(keyType) + arrowTypeWidth(valueType) + 4.0
case s: StructType => s.fields.map(f => arrowTypeWidth(f.dataType)).sum
case _ => 8.0
}
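
// Illustrative values from the match above: arrowTypeWidth(IntegerType) = 4.0;
// arrowTypeWidth(ArrayType(IntegerType)) = 4.0 + 4.0 = 8.0 (element width + offset);
// arrowTypeWidth of a StructType with two LongType fields = 8.0 + 8.0 = 16.0.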

override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
if (shuffleSupported(op).isDefined) Compatible() else Unsupported()
}
@@ -28,9 +28,11 @@ import org.scalatest.Tag
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkEnv
import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.{col, count, sum}
import org.apache.spark.sql.types._

import org.apache.comet.CometConf

@@ -499,4 +501,50 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
}
}
}

test("estimateUnsafeRowMultiplier: narrow fixed-width schema has high multiplier") {
// Schema of boolean + byte + short columns: Arrow uses ~3.125 bytes/row,
// UnsafeRow uses 4 + 8 + 24 = 36 bytes/row
val fields = Seq(
AttributeReference("a", BooleanType)(),
AttributeReference("b", ByteType)(),
AttributeReference("c", ShortType)())
val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
assert(multiplier > 5.0, s"Expected > 5.0 for narrow types, got $multiplier")
}

test("estimateUnsafeRowMultiplier: wide fixed-width schema (longs/doubles)") {
// Schema of 4 longs: Arrow uses 32 bytes/row, UnsafeRow uses 4 + 8 + 32 = 44 bytes/row
val fields = (1 to 4).map(i => AttributeReference(s"c$i", LongType)())
val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
assert(
multiplier >= 1.0 && multiplier < 2.0,
s"Expected ~1.375 (44/32) for all-long schema, got $multiplier")
}

test("estimateUnsafeRowMultiplier: string-heavy schema") {
val fields = (1 to 5).map(i => AttributeReference(s"s$i", StringType)())
val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
assert(multiplier >= 1.0, s"Expected >= 1.0 for string schema, got $multiplier")
}

test("estimateUnsafeRowMultiplier: empty schema returns 1.0") {
val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(Seq.empty)
assert(multiplier == 1.0)
}

test("estimateUnsafeRowMultiplier: mixed TPC-DS-like schema") {
// Typical dimension table: ints, strings, decimals
val fields = Seq(
AttributeReference("id", IntegerType)(),
AttributeReference("name", StringType)(),
AttributeReference("price", DecimalType(10, 2))(),
AttributeReference("qty", IntegerType)(),
AttributeReference("date", DateType)(),
AttributeReference("flag", BooleanType)())
val multiplier = CometShuffleExchangeExec.estimateUnsafeRowMultiplier(fields)
assert(
multiplier >= 1.0 && multiplier <= 4.0,
s"Expected between 1.0 and 4.0 for mixed schema, got $multiplier")
}
}