Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.rules._
* +- Aggregate [__qid],
* [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1,
* max_by(struct(right.*), expr, k) AS _matches]
* +- Join LeftOuter
* +- Join Inner // or LeftOuter for `LEFT OUTER NEAREST BY`
* :- Project [left.*, uuid() AS __qid]
* : +- left
* +- right
Expand Down Expand Up @@ -79,18 +79,18 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] {
val taggedLeft = Project(left.output :+ qidAlias, left)
val qidAttr = qidAlias.toAttribute

// 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER
// (rather than INNER) preserves left rows even when `right` is empty, so that a
// `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side
// columns after the aggregate + inline below. When `right` is non-empty every left
// row already has right-row pairings, so LEFT OUTER and INNER are equivalent.
// 2. Join the tagged left with right (no join condition), using the user's join type.
// For `LEFT OUTER`, left rows with no right-side match are preserved with `NULL`
// right-side columns through the aggregate + inline below; for `INNER`, such rows
// are dropped. When `right` is non-empty every left row already has right-row
// pairings, so `LEFT OUTER` and `INNER` are equivalent in that case.
//
// This synthetic join is an unconditioned cross-product, so `NEAREST BY` queries
// are subject to `CheckCartesianProducts` and will be rejected when the user has
// set `spark.sql.crossJoin.enabled = false`. That is intentional: if the user has
// opted out of cross-products, the NEAREST BY rewrite -- which is itself a bounded
// cross-product today -- should not silently bypass that choice.
val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)

val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) {
val rankingAlias = Alias(rankingExpression, "__ranking__")()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project}
import org.apache.spark.sql.types.IntegerType

Expand All @@ -41,10 +41,10 @@ class RewriteNearestByJoinSuite extends PlanTest {
numResults: Int,
ranking: org.apache.spark.sql.catalyst.expressions.Expression,
reverse: Boolean,
outer: Boolean) = {
joinType: JoinType) = {
val qidAlias = Alias(Uuid(Some(0L)), "__qid")()
val taggedLeft = Project(left.output :+ qidAlias, left)
val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE)

val rightStruct = CreateStruct(right.output)
val topKAgg = MaxMinByK(
Expand All @@ -66,7 +66,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val generate = Generate(
Inline(matchesAlias.toAttribute),
unrequiredChildIndex = Seq(aggregate.output.indexOf(matchesAlias.toAttribute)),
outer = outer,
outer = joinType == LeftOuter,
qualifier = None,
generatorOutput = generatorOutput,
child = aggregate)
Expand All @@ -89,7 +89,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 5,
ranking = left.output(0) + right.output(0),
reverse = false, outer = false)
reverse = false, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -106,7 +106,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 3,
ranking = left.output(0) - right.output(0),
reverse = true, outer = false)
reverse = true, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -123,7 +123,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 1,
ranking = left.output(0) + right.output(0),
reverse = false, outer = true)
reverse = false, joinType = LeftOuter)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -140,11 +140,38 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 2,
ranking = left.output(0) - right.output(0),
reverse = true, outer = true)
reverse = true, joinType = LeftOuter)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}

test("synthetic Join uses the user's joinType") {
// Locks in that the rewrite's synthetic Join carries the user's `joinType`
// (Inner or LeftOuter).
val left = LocalRelation($"a".int, $"b".int)
val right = LocalRelation($"x".int, $"y".int)
Seq(Inner, LeftOuter).foreach { joinType =>
val query = NearestByJoin(
left, right, joinType, approx = true, numResults = 1,
rankingExpression = left.output(0) + right.output(0),
direction = NearestBySimilarity)

val rewritten = RewriteNearestByJoin(query.analyze)
val syntheticJoin = rewritten.collect { case j: Join => j }
assert(syntheticJoin.size == 1,
s"expected exactly one synthetic Join in the rewritten plan, got ${syntheticJoin.size}")
assert(syntheticJoin.head.joinType == joinType,
s"expected synthetic Join to use $joinType, got ${syntheticJoin.head.joinType}")

val generate = rewritten.collect { case g: Generate => g }
assert(generate.size == 1,
s"expected exactly one Generate in the rewritten plan, got ${generate.size}")
val expectedOuter = joinType == LeftOuter
assert(generate.head.outer == expectedOuter,
s"expected Generate.outer == $expectedOuter for $joinType, got ${generate.head.outer}")
}
}

test("EXACT (approx = false) produces the same rewrite as APPROX") {
// Locks in the current invariant that APPROX and EXACT lower through the same
// brute-force rewrite. If a future change diverges them (e.g. an APPROX-only
Expand All @@ -160,7 +187,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 5,
ranking = left.output(0) + right.output(0),
reverse = false, outer = false)
reverse = false, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -177,7 +204,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, 1,
ranking = left.output(0) + right.output(0),
reverse = false, outer = false)
reverse = false, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -194,7 +221,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
left, right, NearestByJoin.MaxNumResults,
ranking = left.output(0) + right.output(0),
reverse = false, outer = false)
reverse = false, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand All @@ -214,7 +241,7 @@ class RewriteNearestByJoinSuite extends PlanTest {
val expected = expectedRewrite(
t, tDup, 1,
ranking = t.output(0) + tDup.output(0),
reverse = false, outer = false)
reverse = false, joinType = Inner)

comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,27 @@ Project [user_id#x, product#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
-- !query analysis
Project [user_id#x, product#x]
+- NearestByJoin Inner, true, 1, -abs((score#x - pscore#x)), NearestBySimilarity
:- SubqueryAlias u
: +- SubqueryAlias spark_catalog.default.users
: +- View (`spark_catalog`.`default`.`users`, [user_id#x, score#x])
: +- Project [cast(col1#x as int) AS user_id#x, cast(col2#x as decimal(3,1)) AS score#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias p
+- Project [product#x, pscore#x]
+- Filter false
+- SubqueryAlias spark_catalog.default.products
+- View (`spark_catalog`.`default`.`products`, [product#x, pscore#x])
+- Project [cast(col1#x as string) AS product#x, cast(col2#x as decimal(3,1)) AS pscore#x]
+- LocalRelation [col1#x, col2#x]


-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ SELECT u.user_id, p.product
FROM users u LEFT OUTER JOIN (SELECT * FROM products WHERE false) p
APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);

-- INNER JOIN with NEAREST BY, empty right side
SELECT u.user_id, p.product
FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore);

-- Explicit INNER keyword
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ struct<user_id:int,product:string>
3 NULL


-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN (SELECT * FROM products WHERE false) p
APPROX NEAREST 1 BY SIMILARITY -abs(u.score - p.pscore)
-- !query schema
struct<user_id:int,product:string>
-- !query output



-- !query
SELECT u.user_id, p.product
FROM users u INNER JOIN products p
Expand Down Expand Up @@ -286,12 +296,12 @@ AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_max_by(named_struct(product, product#x, pscore, pscore#x), __ranking__#x, 1, false, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
+- Project [user_id#x, __qid#x, product#x, pscore#x, (rand(0) + cast(pscore#x as double)) AS __ranking__#x]
+- BroadcastNestedLoopJoin BuildRight, LeftOuter
:- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
+- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
+- Project [col1#x AS product#x, col2#x AS pscore#x]
+- LocalTableScan [col1#x, col2#x]
+- BroadcastNestedLoopJoin BuildLeft, Inner
:- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
: +- Project [col1#x AS user_id#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
+- Project [col1#x AS product#x, col2#x AS pscore#x]
+- LocalTableScan [col1#x, col2#x]


-- !query
Expand All @@ -313,7 +323,7 @@ AdaptiveSparkPlan isFinalPlan=false
+- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
+- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
+- BroadcastNestedLoopJoin BuildRight, LeftOuter
+- BroadcastNestedLoopJoin BuildRight, Inner
:- Filter (user_id#x > 1)
: +- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
Expand Down Expand Up @@ -342,7 +352,7 @@ AdaptiveSparkPlan isFinalPlan=false
+- Exchange hashpartitioning(__qid#x, 4), ENSURE_REQUIREMENTS, [plan_id=x]
+- SortAggregate(key=[__qid#x], functions=[partial_first(user_id#x, false), partial_min_by(named_struct(product, product#x, pscore, pscore#x), abs((score#x - pscore#x)), 2, true, 0, 0)])
+- Sort [__qid#x ASC NULLS FIRST], false, 0
+- BroadcastNestedLoopJoin BuildRight, LeftOuter
+- BroadcastNestedLoopJoin BuildRight, Inner
:- Project [col1#x AS user_id#x, col2#x AS score#x, uuid(Some(x)) AS __qid#x]
: +- LocalTableScan [col1#x, col2#x]
+- BroadcastExchange IdentityBroadcastMode, [plan_id=x]
Expand Down