Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,201 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Rand}
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryComparison, Divide,
DoubleLiteral, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan,
LessThanOrEqual, Literal, Multiply, Rand, Subtract}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, EXPRESSION_WITH_RANDOM_SEED, LITERAL}
import org.apache.spark.sql.catalyst.trees.TreePattern.EXPRESSION_WITH_RANDOM_SEED

/**
* Rand() generates a random column with i.i.d. uniformly distributed values in [0, 1), so a
* binary comparison between Rand() and a double literal can be replaced by a constant whenever
* the literal lies outside the range of values Rand() can produce. The same applies when
* Rand() is a direct operand of +, -, * or / together with a literal.
*
* 1. Converts the binary comparison to a true literal when the comparison must always be true.
* 2. Converts the binary comparison to a false literal when the comparison must always be false.
*/
object OptimizeRand extends Rule[LogicalPlan] {
/**
 * Rewrites comparisons between a Rand()-based expression and a literal, replacing them with
 * a boolean literal when the outcome cannot depend on the random value.
 *
 * Traversal is pruned to subtrees tagged with EXPRESSION_WITH_RANDOM_SEED.
 *
 * NOTE(review): an earlier revision pruned on EXPRESSION_WITH_RANDOM_SEED, LITERAL and
 * BINARY_COMPARISON together. Every case below still needs a comparison and a literal, so
 * the tighter pruning may be worth restoring -- confirm.
 *
 * @param plan the logical plan to optimize
 * @return the plan with foldable Rand() comparisons replaced by literals
 */
def apply(plan: LogicalPlan): LogicalPlan =
  plan.transformAllExpressionsWithPruning(
    _.containsAnyPattern(EXPRESSION_WITH_RANDOM_SEED), ruleId) {
    // `literal = rand()`: equality is symmetric, so only the operands are exchanged.
    case EqualTo(lit @ DoubleLiteral(_), rnd: Rand) =>
      eliminateRand(EqualTo(rnd, lit))
    // `literal <op> rand()`: mirror the comparison so Rand() ends up on the left.
    // EqualTo never reaches this case because the case above matches it first.
    case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
      eliminateRand(swapComparison(op))
    case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
      eliminateRand(op)
    // Rand() appears as a direct operand of +, -, * or / on either side.
    case op: BinaryComparison
        if isDirectRandChild(op.left) || isDirectRandChild(op.right) =>
      optimizeArithmetic(op)
  }

/**
 * Tells whether `expr` is Rand() itself, or an Add / Subtract / Multiply / Divide that has
 * Rand() as one of its two immediate operands. Deeper nesting is not inspected.
 */
private def isDirectRandChild(expr: Expression): Boolean = {
  def eitherIsRand(lhs: Expression, rhs: Expression): Boolean = (lhs, rhs) match {
    case (_: Rand, _) | (_, _: Rand) => true
    case _ => false
  }

  expr match {
    case _: Rand => true
    case sum: Add => eitherIsRand(sum.left, sum.right)
    case diff: Subtract => eitherIsRand(diff.left, diff.right)
    case product: Multiply => eitherIsRand(product.left, product.right)
    case quotient: Divide => eitherIsRand(quotient.left, quotient.right)
    case _ => false
  }
}

/**
 * Tells whether Rand() occurs anywhere in `expr`, looking through any nesting of Add,
 * Subtract, Multiply and Divide. Any other expression type stops the search and counts as
 * "no Rand()".
 */
private def hasRand(expr: Expression): Boolean = {
  def inEither(lhs: Expression, rhs: Expression): Boolean = hasRand(lhs) || hasRand(rhs)

  expr match {
    case _: Rand => true
    case sum: Add => inEither(sum.left, sum.right)
    case diff: Subtract => inEither(diff.left, diff.right)
    case product: Multiply => inEither(product.left, product.right)
    case quotient: Divide => inEither(quotient.left, quotient.right)
    case _ => false
  }
}

/**
 * Exchanges the two operands of an ordering comparison and mirrors the operator so the
 * meaning is preserved, e.g. "a < b" becomes "b > a". Comparisons other than <, <=, > and >=
 * are returned unchanged.
 */
private def swapComparison(comparison: BinaryComparison): BinaryComparison =
  comparison match {
    case lhs LessThan rhs => GreaterThan(rhs, lhs)
    case lhs LessThanOrEqual rhs => GreaterThanOrEqual(rhs, lhs)
    case lhs GreaterThan rhs => LessThan(rhs, lhs)
    case lhs GreaterThanOrEqual rhs => LessThanOrEqual(rhs, lhs)
    case untouched => untouched
  }

/**
 * Folds a comparison of the form `rand() <op> doubleLiteral` into a boolean literal when the
 * literal lies outside what Rand() can produce, i.e. outside [0, 1).
 *
 * @param op comparison with Rand() on the left and a double literal on the right
 * @return TrueLiteral or FalseLiteral when the result is fixed, otherwise `op` unchanged;
 *         comparisons of any other shape are also returned unchanged
 */
private def eliminateRand(op: BinaryComparison): Expression = {
  // Picks the folded literal, or keeps the comparison when neither outcome is certain.
  def decide(alwaysTrue: Boolean, alwaysFalse: Boolean): Expression =
    if (alwaysTrue) TrueLiteral else if (alwaysFalse) FalseLiteral else op

  op match {
    case GreaterThan(_: Rand, DoubleLiteral(v)) =>
      decide(alwaysTrue = v < 0.0, alwaysFalse = v >= 1.0)
    case GreaterThanOrEqual(_: Rand, DoubleLiteral(v)) =>
      decide(alwaysTrue = v <= 0.0, alwaysFalse = v >= 1.0)
    case LessThan(_: Rand, DoubleLiteral(v)) =>
      decide(alwaysTrue = v >= 1.0, alwaysFalse = v <= 0.0)
    case LessThanOrEqual(_: Rand, DoubleLiteral(v)) =>
      decide(alwaysTrue = v >= 1.0, alwaysFalse = v < 0.0)
    // Equality can only be ruled out, never guaranteed.
    case EqualTo(_: Rand, DoubleLiteral(v)) =>
      decide(alwaysTrue = false, alwaysFalse = v < 0.0 || v >= 1.0)
    case other => other
  }
}

/**
 * Reads the value carried by a literal as a Double.
 *
 * Any Literal whose value is a boxed java.lang.Number is accepted: a boxed Double yields its
 * own value and other boxed numbers are widened with doubleValue(). This single case covers
 * everything the former DoubleLiteral / Double / java.lang.Double cases matched.
 *
 * NOTE(review): the literal's declared data type is not inspected, so any literal whose
 * internal value happens to be a java.lang.Number is accepted -- confirm that is intended.
 *
 * @param lit expression expected to be a numeric literal
 * @return the literal's value, or None when `lit` is not a Literal holding a number
 */
private def extractDouble(lit: Expression): Option[Double] = lit match {
  case Literal(num: java.lang.Number, _) => Some(num.doubleValue())
  case _ => None
}

/**
 * Linear form `coeff * rand() + offset` used to describe an arithmetic expression over Rand().
 *
 * @param coeff factor applied to rand()
 * @param offset constant added after scaling
 */
case class RandExpr(coeff: Double, offset: Double)

/**
 * Tries to rewrite `expr` as the linear form `coeff * rand() + offset`.
 *
 * Supported shapes, where `lit` is anything `extractDouble` accepts and `e` is itself a
 * supported shape:
 *  - `rand()`
 *  - `rand() * lit` and `lit * rand()` (Rand() must be a direct operand of the product)
 *  - `e + lit` and `lit + e`
 *  - `e - lit`
 *  - `e / lit` with `lit != 0.0`
 *
 * Everything else yields None, including `lit - e`, `lit / e` and `(rand() + 1) * 2`.
 *
 * NOTE(review): `hasRand` also reports a Rand() nested arbitrarily deep under these
 * operators, so a positive `hasRand` does not imply that this method returns a result.
 *
 * @param expr expression to decompose
 * @return the coefficient and offset, or None when `expr` has an unsupported shape
 */
private def extractRandCoeffOffset(expr: Expression): Option[RandExpr] = expr match {
  case _: Rand => Some(RandExpr(1.0, 0.0))

  case product: Multiply =>
    (product.left, product.right) match {
      case (_: Rand, factor) => extractDouble(factor).map(RandExpr(_, 0.0))
      case (factor, _: Rand) => extractDouble(factor).map(RandExpr(_, 0.0))
      case _ => None
    }

  case sum: Add =>
    // Addition commutes: try "e + lit" first and fall back to "lit + e".
    val randOnLeft = for {
      inner <- extractRandCoeffOffset(sum.left)
      shift <- extractDouble(sum.right)
    } yield RandExpr(inner.coeff, inner.offset + shift)
    randOnLeft.orElse {
      for {
        inner <- extractRandCoeffOffset(sum.right)
        shift <- extractDouble(sum.left)
      } yield RandExpr(inner.coeff, inner.offset + shift)
    }

  case diff: Subtract =>
    for {
      inner <- extractRandCoeffOffset(diff.left)
      shift <- extractDouble(diff.right)
    } yield RandExpr(inner.coeff, inner.offset - shift)

  case quotient: Divide =>
    for {
      inner <- extractRandCoeffOffset(quotient.left)
      // A zero denominator is left alone rather than producing an infinite coefficient.
      denom <- extractDouble(quotient.right) if denom != 0.0
    } yield RandExpr(inner.coeff / denom, inner.offset / denom)

  case _ => None
}

/**
 * Folds a comparison between an arithmetic expression over Rand() and a numeric literal.
 *
 * The Rand() side is decomposed into `coeff * rand() + offset` and handed to
 * `optimizeWithCoeffOffset`. `<rand expr> <op> <literal>` is tried first; if that does not
 * apply, `<literal> <op> <rand expr>` is tried with the operator mirrored.
 *
 * No separate "contains Rand()" check is needed: `extractRandCoeffOffset` only returns a
 * result for expressions that bottom out in Rand().
 *
 * @param op the comparison to optimize
 * @return a boolean literal when the result is fixed, otherwise `op` unchanged
 */
private def optimizeArithmetic(op: BinaryComparison): Expression = {
  // Operator tag understood by optimizeWithCoeffOffset; None for unsupported comparisons.
  def opName(cmp: BinaryComparison): Option[String] = cmp match {
    case _: LessThan => Some("LT")
    case _: GreaterThan => Some("GT")
    case _: LessThanOrEqual => Some("LTE")
    case _: GreaterThanOrEqual => Some("GTE")
    case _: EqualTo => Some("EQ")
    case _ => None
  }

  // `normalized` is the comparison as it reads with the Rand() side on the left.
  def tryFold(
      randSide: Expression,
      litSide: Expression,
      normalized: => BinaryComparison): Option[Expression] =
    for {
      litVal <- extractDouble(litSide)
      name <- opName(normalized)
      randExpr <- extractRandCoeffOffset(randSide)
    } yield optimizeWithCoeffOffset(randExpr.coeff, randExpr.offset, litVal, name, op)

  tryFold(op.left, op.right, op)
    .orElse(tryFold(op.right, op.left, swapComparison(op)))
    .getOrElse(op)
}

/**
 * Decides `coeff * rand() + offset <op> value` for rand() in [0, 1).
 *
 * For a non-zero coefficient the comparison is reduced to `rand() <op'> t` with
 * `t = (value - offset) / coeff`, where `<op'>` is `<op>` for a positive coefficient and its
 * mirror image for a negative one. The result is fixed whenever `t` falls outside the range
 * that rand() can take.
 *
 * Nothing is folded when `coeff`, `offset` or `value` is NaN, or when `coeff` is infinite.
 * Spark SQL orders NaN above every other double, which the primitive comparisons used here
 * do not model, and `infinity * rand()` is NaN when rand() returns 0.0.
 *
 * NOTE(review): `t` is derived as if the arithmetic were exact, while the expression is
 * evaluated at runtime in double precision. For example (1 - 2^-53) + 1.0 rounds to 2.0, so
 * `rand() + 1 < 2` may not hold if rand() can return 1 - 2^-53 -- confirm that the boundary
 * cases are acceptable.
 *
 * @param coeff factor applied to rand()
 * @param offset constant added after scaling
 * @param value literal the expression is compared against
 * @param op one of "GT", "GTE", "LT", "LTE", "EQ"; any other tag leaves the input unchanged
 * @param original expression returned when the comparison cannot be decided
 * @return TrueLiteral, FalseLiteral, or `original`
 */
private def optimizeWithCoeffOffset(coeff: Double, offset: Double,
    value: Double, op: String, original: Expression): Expression = {
  def constant(holds: Boolean): Expression = if (holds) TrueLiteral else FalseLiteral

  def decide(alwaysTrue: Boolean, alwaysFalse: Boolean): Expression =
    if (alwaysTrue) TrueLiteral else if (alwaysFalse) FalseLiteral else original

  val notFoldable =
    java.lang.Double.isNaN(coeff) || java.lang.Double.isInfinite(coeff) ||
      java.lang.Double.isNaN(offset) || java.lang.Double.isNaN(value)

  if (notFoldable) {
    original
  } else if (coeff == 0.0) {
    // rand() has no influence: compare the constant offset directly.
    op match {
      case "GT" => constant(offset > value)
      case "GTE" => constant(offset >= value)
      case "LT" => constant(offset < value)
      case "LTE" => constant(offset <= value)
      case "EQ" => constant(offset == value)
      case _ => original
    }
  } else {
    val t = (value - offset) / coeff
    if (coeff > 0.0) {
      // Same direction: coeff * rand() + offset <op> value  <=>  rand() <op> t.
      op match {
        case "GT" => decide(alwaysTrue = t < 0.0, alwaysFalse = t >= 1.0)
        // The strict upper bound is kept as is; t == 1.0 is left unfolded.
        case "GTE" => decide(alwaysTrue = t <= 0.0, alwaysFalse = t > 1.0)
        case "LT" => decide(alwaysTrue = t >= 1.0, alwaysFalse = t <= 0.0)
        case "LTE" => decide(alwaysTrue = t >= 1.0, alwaysFalse = t < 0.0)
        case "EQ" => decide(alwaysTrue = false, alwaysFalse = t < 0.0 || t >= 1.0)
        case _ => original
      }
    } else {
      // Dividing by a negative coefficient mirrors the comparison.
      op match {
        // The strict upper bound is kept as is; t == 1.0 is left unfolded.
        case "GT" => decide(alwaysTrue = t > 1.0, alwaysFalse = t <= 0.0)
        case "GTE" => decide(alwaysTrue = t >= 1.0, alwaysFalse = t < 0.0)
        case "LT" => decide(alwaysTrue = t < 0.0, alwaysFalse = t >= 1.0)
        // The strict upper bound is kept as is; t == 1.0 is left unfolded.
        case "LTE" => decide(alwaysTrue = t <= 0.0, alwaysFalse = t > 1.0)
        case "EQ" => decide(alwaysTrue = false, alwaysFalse = t < 0.0 || t >= 1.0)
        case _ => original
      }
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class OptimizeRandSuite extends PlanTest {
// testRelation restricted to rows where a is 1, 3 or 5, exposed under the subquery alias "x".
val x = testRelation.where($"a".attr.in(1, 3, 5)).subquery("x")
// Double-valued literals used as operands in the tests of this suite.
val literal0d = Literal(0d)
val literal1d = Literal(1d)
val literal2d = Literal(2d)
val literal3d = Literal(3d)
val literal6d = Literal(6d)
val literalHalf = Literal(0.5)
val negativeLiteral1d = Literal(-1d)
// Built with rand(5); presumably a Rand expression seeded with 5 -- confirm against the DSL.
val rand5 = rand(5)
Expand Down Expand Up @@ -173,4 +176,95 @@ class OptimizeRandSuite extends PlanTest {
}
}

test("Optimize arithmetic expressions with rand") {
  // Each projection pairs a comparison over rand() with the literal it must fold to.
  val cases = Seq(
    // rand() * 3 < 3 is always true
    (rand5 * literal3d < literal3d).as("flag") -> TrueLiteral,
    // rand() + 1 < 2 is always true
    (rand5 + literal1d < literal2d).as("flag") -> TrueLiteral,
    // rand() - 1 < 0 is always true
    (rand5 - literal1d < literal0d).as("flag") -> TrueLiteral,
    // rand() / 2 < 1 is always true
    (rand5 / literal2d < literal1d).as("flag") -> TrueLiteral,
    // rand() * 2 > 3 is always false
    (rand5 * literal2d > literal3d).as("flag") -> FalseLiteral)

  cases.foreach { case (projection, folded) =>
    val optimized = Optimize.execute(testRelation.select(projection).analyze)
    val expected = testRelation.select(Alias(folded, "flag")()).analyze
    comparePlans(optimized, expected)
  }
}

test("Optimize equality comparison with rand") {
  // rand() == 0.5 stays as it is: 0.5 lies inside [0, 1).
  val undecidable = testRelation.select((rand5 === literalHalf).as("flag")).analyze
  comparePlans(Optimize.execute(undecidable), undecidable)

  // The literal lies outside [0, 1), on either side of the comparison: always false.
  val alwaysFalse = Seq(
    (rand5 === literal2d).as("flag"),
    (rand5 === negativeLiteral1d).as("flag"),
    (literal2d === rand5).as("flag"),
    (negativeLiteral1d === rand5).as("flag"))

  alwaysFalse.foreach { projection =>
    val optimized = Optimize.execute(testRelation.select(projection).analyze)
    val expected = testRelation.select(Alias(FalseLiteral, "flag")()).analyze
    comparePlans(optimized, expected)
  }
}

test("Benchmark: rand optimization performance benefit") {
  val iterations = 1000

  def analyzedPlan() =
    testRelation.select((rand5 * literal3d < literal3d).as("flag")).analyze

  // Runs `body` `iterations` times and returns the total wall-clock time in nanoseconds.
  def elapsedNanos[T](body: => T): Long = {
    val start = System.nanoTime()
    (0 until iterations).foreach(_ => body)
    System.nanoTime() - start
  }

  // The only deterministic property of this test: the comparison is folded to true.
  comparePlans(
    Optimize.execute(analyzedPlan()),
    testRelation.select(Alias(TrueLiteral, "flag")()).analyze)

  // The second measurement does everything the first one does plus the optimizer run, so
  // the two numbers describe the cost of optimizing, not a speed-up. They are logged for
  // information only and are never asserted on.
  val analysisNanos = elapsedNanos(analyzedPlan())
  val analysisAndOptimizeNanos = elapsedNanos(Optimize.execute(analyzedPlan()))

  val msg = s"Plan construction over $iterations iterations: " +
    s"analysis only ${analysisNanos / 1000000}ms, " +
    s"analysis + optimization ${analysisAndOptimizeNanos / 1000000}ms"
  // scalastyle:off println
  println(msg)
  // scalastyle:on println
}

}