[SPARK-46946][SQL] Supporting broadcast of multiple filtering keys in DynamicPruning

### What changes were proposed in this pull request?

This PR extends `DynamicPruningSubquery` to support broadcasting multiple filtering keys (instead of a single one, as before). Most of the PR simply generalises the existing single-key logic to handle a sequence of keys.

**Note:** We do not actually use a multi-key `DynamicPruningSubquery` anywhere in this PR; the change is groundwork to make supporting DPP for Null Safe Equality or for multiple Equality predicates easier in the future.

In a Null Safe Equality JOIN, the JOIN condition `a <=> b` is transformed to `Coalesce(key1, Literal(key1.dataType)) = Coalesce(key2, Literal(key2.dataType)) AND IsNull(key1) = IsNull(key2)`. To obtain the highest pruning efficiency, we broadcast the two keys `Coalesce(key, Literal(key.dataType))` and `IsNull(key)` and use both of them to prune the other side at the same time.
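For illustration, a minimal sketch of that rewrite in Catalyst terms. Assumptions of the sketch: two `IntegerType` join keys, and that the `Literal(...)` above is shorthand for `Literal.default(dataType)`, the type's default value:

```scala
// Sketch only (not part of this diff): how `key1 <=> key2` is decomposed.
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val key1 = AttributeReference("key1", IntegerType)()
val key2 = AttributeReference("key2", IntegerType)()

// key1 <=> key2 becomes:
val rewritten = And(
  EqualTo(
    Coalesce(Seq(key1, Literal.default(key1.dataType))),
    Coalesce(Seq(key2, Literal.default(key2.dataType)))),
  EqualTo(IsNull(key1), IsNull(key2)))
```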

Previously, `DynamicPruningSubquery` had a single broadcasting key and we supported DPP only for one `EqualTo` JOIN predicate; this PR extends the subquery to carry multiple broadcasting keys. Note that DPP is still not supported for multiple JOIN predicates.

Put another way: at the moment we never insert a DPP Filter for multiple JOIN predicates at the same time; we only (potentially) insert a DPP Filter for a given Equality JOIN predicate.
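As a hypothetical example of the single-predicate case that is supported today (assuming an active `SparkSession` named `spark` and a `fact_sales` table partitioned by `store_id`):

```scala
// Only the single EqualTo predicate `f.store_id = s.store_id` can receive a DPP
// filter; a query with multiple join predicates would not get a combined filter.
val pruned = spark.sql("""
  SELECT f.*
  FROM fact_sales f
  JOIN dim_stores s ON f.store_id = s.store_id
  WHERE s.country = 'DE'
""")
```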

### Why are the changes needed?

To make supporting DPP for Null Safe Equality, or for multiple Equality predicates, easier in the future.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44988 from longvu-db/multiple-broadcast-filtering-keys.

Authored-by: Thang Long VU <long.vu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
longvu-db authored and cloud-fan committed Feb 2, 2024
1 parent 362a4d4 commit 22e5938
Showing 9 changed files with 138 additions and 40 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -37,13 +37,13 @@ trait DynamicPruning extends Predicate
  * beneficial and so it should be executed even if it cannot reuse the results of the
  * broadcast through ReuseExchange; otherwise, it will use the filter only if it
  * can reuse the results of the broadcast through ReuseExchange
- * @param broadcastKeyIndex the index of the filtering key collected from the broadcast
+ * @param broadcastKeyIndices the indices of the filtering keys collected from the broadcast
  */
 case class DynamicPruningSubquery(
     pruningKey: Expression,
     buildQuery: LogicalPlan,
     buildKeys: Seq[Expression],
-    broadcastKeyIndex: Int,
+    broadcastKeyIndices: Seq[Int],
     onlyInBroadcast: Boolean,
     exprId: ExprId = NamedExpression.newExprId,
     hint: Option[HintInfo] = None)
@@ -67,10 +67,12 @@ case class DynamicPruningSubquery(
     buildQuery.resolved &&
       buildKeys.nonEmpty &&
       buildKeys.forall(_.resolved) &&
-      broadcastKeyIndex >= 0 &&
-      broadcastKeyIndex < buildKeys.size &&
+      broadcastKeyIndices.forall(idx => idx >= 0 && idx < buildKeys.size) &&
       buildKeys.forall(_.references.subsetOf(buildQuery.outputSet)) &&
-      pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
+      // DynamicPruningSubquery should only have a single broadcasting key since
+      // there is no usage of multiple broadcasting keys at the moment.
+      broadcastKeyIndices.size == 1 &&
+      child.dataType == buildKeys(broadcastKeyIndices.head).dataType
   }

   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala (new file)
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

class DynamicPruningSubquerySuite extends SparkFunSuite {
  private val pruningKeyExpression = Literal(1)

  private val validDynamicPruningSubquery = DynamicPruningSubquery(
    pruningKey = pruningKeyExpression,
    buildQuery = Project(Seq(AttributeReference("id", IntegerType)()),
      LocalRelation(AttributeReference("id", IntegerType)())),
    buildKeys = Seq(pruningKeyExpression),
    broadcastKeyIndices = Seq(0),
    onlyInBroadcast = false
  )

  test("pruningKey data type matches single buildKey") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(buildKeys = Seq(Literal(2023)))
    assert(dynamicPruningSubquery.resolved == true)
  }

  test("pruningKey data type is a Struct and matches with Struct buildKey") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
        buildKeys = Seq(CreateStruct(Seq(Literal(2), Literal.TrueLiteral))))
    assert(dynamicPruningSubquery.resolved == true)
  }

  test("multiple buildKeys but only one broadcastKeyIndex") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(buildKeys = Seq(Literal(0), Literal(2), Literal(0), Literal(9)),
        broadcastKeyIndices = Seq(1))
    assert(dynamicPruningSubquery.resolved == true)
  }

  test("pruningKey data type does not match the single buildKey") {
    val dynamicPruningSubquery = validDynamicPruningSubquery.copy(
      pruningKey = Literal.TrueLiteral,
      buildKeys = Seq(Literal(2013)))
    assert(dynamicPruningSubquery.resolved == false)
  }

  test("pruningKey data type is a Struct but mismatches with Struct buildKey") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
        buildKeys = Seq(CreateStruct(Seq(Literal.TrueLiteral, Literal(2)))))
    assert(dynamicPruningSubquery.resolved == false)
  }

  test("DynamicPruningSubquery should only have a single broadcasting key") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(buildKeys = Seq(Literal(2025), Literal(2), Literal(1809)),
        broadcastKeyIndices = Seq(0, 2))
    assert(dynamicPruningSubquery.resolved == false)
  }

  test("duplicates in broadcastKeyIndices should not be allowed") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(buildKeys = Seq(Literal(2)),
        broadcastKeyIndices = Seq(0, 0))
    assert(dynamicPruningSubquery.resolved == false)
  }

  test("broadcastKeyIndex out of bounds") {
    val dynamicPruningSubquery = validDynamicPruningSubquery
      .copy(broadcastKeyIndices = Seq(1))
    assert(dynamicPruningSubquery.resolved == false)
  }
}
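If useful, the new suite can be run on its own from a standard Spark checkout (assumption: the bundled sbt launcher) with `build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquerySuite"`.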
sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
  */
 case class SubqueryAdaptiveBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     onlyInBroadcast: Boolean,
     @transient buildPlan: LogicalPlan,
     buildKeys: Seq[Expression],
sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -34,32 +34,34 @@ import org.apache.spark.util.ThreadUtils

 /**
  * Physical plan for a custom subquery that collects and transforms the broadcast key values.
- * This subquery retrieves the partition key from the broadcast results based on the type of
- * [[HashedRelation]] returned. If the key is packed inside a Long, we extract it through
+ * This subquery retrieves the partition keys from the broadcast results based on the type of
+ * [[HashedRelation]] returned. If a key is packed inside a Long, we extract it through
  * bitwise operations, otherwise we return it from the appropriate index of the [[UnsafeRow]].
  *
- * @param index the index of the join key in the list of keys from the build side
+ * @param indices the indices of the join keys in the list of keys from the build side
  * @param buildKeys the join keys from the build side of the join used
  * @param child the BroadcastExchange or the AdaptiveSparkPlan with BroadcastQueryStageExec
  *              from the build side of the join
  */
 case class SubqueryBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     buildKeys: Seq[Expression],
     child: SparkPlan) extends BaseSubqueryExec with UnaryExecNode {

   // `SubqueryBroadcastExec` is only used with `InSubqueryExec`. No one would reference this output,
   // so the exprId doesn't matter here. But it's important to correctly report the output length, so
-  // that `InSubqueryExec` can know it's the single-column execution mode, not multi-column.
+  // that `InSubqueryExec` can know whether it's the single-column or multi-column execution mode.
   override def output: Seq[Attribute] = {
-    val key = buildKeys(index)
-    val name = key match {
-      case n: NamedExpression => n.name
-      case Cast(n: NamedExpression, _, _, _) => n.name
-      case _ => "key"
-    }
-    Seq(AttributeReference(name, key.dataType, key.nullable)())
+    indices.map { idx =>
+      val key = buildKeys(idx)
+      val name = key match {
+        case n: NamedExpression => n.name
+        case Cast(n: NamedExpression, _, _, _) => n.name
+        case _ => s"key_$idx"
+      }
+      AttributeReference(name, key.dataType, key.nullable)()
+    }
   }

override lazy val metrics = Map(
@@ -69,7 +71,7 @@ case class SubqueryBroadcastExec(

   override def doCanonicalize(): SparkPlan = {
     val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
-    SubqueryBroadcastExec("dpp", index, keys, child.canonicalized)
+    SubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)
   }

@transient
@@ -84,14 +86,15 @@ case class SubqueryBroadcastExec(
       val beforeCollect = System.nanoTime()

       val broadcastRelation = child.executeBroadcast[HashedRelation]().value
-      val (iter, expr) = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
-        (broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
+      val exprs = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
+        indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
       } else {
-        (broadcastRelation.keys(),
-          BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+        indices.map { idx =>
+          BoundReference(idx, buildKeys(idx).dataType, buildKeys(idx).nullable) }
       }

-      val proj = UnsafeProjection.create(expr)
+      val proj = UnsafeProjection.create(exprs)
+      val iter = broadcastRelation.keys()
       val keyIter = iter.map(proj).map(_.copy())

val rows = if (broadcastRelation.keyIsUnique) {
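As an aside on the packed-`Long` path mentioned in the class doc above, a minimal sketch of the idea; the two-`Int` layout here is assumed purely for illustration (Spark's actual packing lives in `HashJoin.rewriteKeyExpr` and `HashJoin.extractKeyExprAt`):

```scala
// Two Int keys packed into one Long; extraction selects a key by index with
// shifts and masks, mirroring what extractKeyExprAt does per requested index.
def pack(k0: Int, k1: Int): Long = (k0.toLong << 32) | (k1 & 0xFFFFFFFFL)

def extractAt(packed: Long, idx: Int): Int =
  if (idx == 0) (packed >>> 32).toInt else packed.toInt

assert(extractAt(pack(7, -3), 0) == 7)
assert(extractAt(pack(7, -3), 1) == -3)
```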
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
@@ -39,7 +39,7 @@ case class PlanAdaptiveDynamicPruningFilters(
     plan.transformAllExpressionsWithPruning(
       _.containsAllPatterns(DYNAMIC_PRUNING_EXPRESSION, IN_SUBQUERY_EXEC)) {
       case DynamicPruningExpression(InSubqueryExec(
-          value, SubqueryAdaptiveBroadcastExec(name, index, onlyInBroadcast, buildPlan, buildKeys,
+          value, SubqueryAdaptiveBroadcastExec(name, indices, onlyInBroadcast, buildPlan, buildKeys,
           adaptivePlan: AdaptiveSparkPlanExec), exprId, _, _, _)) =>
         val packedKeys = BindReferences.bindReferences(
           HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
@@ -61,14 +61,14 @@ case class PlanAdaptiveDynamicPruningFilters(
           val newAdaptivePlan = adaptivePlan.copy(inputPlan = exchange)

           val broadcastValues = SubqueryBroadcastExec(
-            name, index, buildKeys, newAdaptivePlan)
+            name, indices, buildKeys, newAdaptivePlan)
           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
         } else if (onlyInBroadcast) {
           DynamicPruningExpression(Literal.TrueLiteral)
         } else {
           // we need to apply an aggregate on the buildPlan in order to be column pruned
-          val alias = Alias(buildKeys(index), buildKeys(index).toString)()
-          val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+          val aliases = indices.map(idx => Alias(buildKeys(idx), buildKeys(idx).toString)())
+          val aggregate = Aggregate(aliases, aliases, buildPlan)

           val session = adaptivePlan.context.session
           val sparkPlan = QueryExecution.prepareExecutedPlan(
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -47,9 +47,9 @@ case class PlanAdaptiveSubqueries(
         val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
         InSubqueryExec(expr, subquery, exprId, isDynamicPruning = false)
       case expressions.DynamicPruningSubquery(value, buildPlan,
-          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+          buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
         val name = s"dynamicpruning#${exprId.id}"
-        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
+        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndices, onlyInBroadcast,
           buildPlan, buildKeys, subqueryMap(exprId.id))
         DynamicPruningExpression(InSubqueryExec(value, subquery, exprId))
     }
sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -103,21 +103,24 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper
   private def insertPredicate(
       pruningKey: Expression,
       pruningPlan: LogicalPlan,
-      filteringKey: Expression,
+      filteringKeys: Seq[Expression],
       filteringPlan: LogicalPlan,
       joinKeys: Seq[Expression],
       partScan: LogicalPlan): LogicalPlan = {
     val reuseEnabled = conf.exchangeReuseEnabled
-    val index = joinKeys.indexOf(filteringKey)
-    lazy val hasBenefit = pruningHasBenefit(pruningKey, partScan, filteringKey, filteringPlan)
+    require(filteringKeys.size == 1, "DPP Filters should only have a single broadcasting key " +
+      "since there is no usage of multiple broadcasting keys at the moment.")
+    val indices = Seq(joinKeys.indexOf(filteringKeys.head))
+    lazy val hasBenefit = pruningHasBenefit(
+      pruningKey, partScan, filteringKeys.head, filteringPlan)
     if (reuseEnabled || hasBenefit) {
       // insert a DynamicPruning wrapper to identify the subquery during query planning
       Filter(
         DynamicPruningSubquery(
           pruningKey,
           filteringPlan,
           joinKeys,
-          index,
+          indices,
           conf.dynamicPartitionPruningReuseBroadcastOnly || !hasBenefit),
         pruningPlan)
     } else {
@@ -255,12 +258,12 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper
       var filterableScan = getFilterableTableScan(l, left)
       if (filterableScan.isDefined && canPruneLeft(joinType) &&
           hasPartitionPruningFilter(right)) {
-        newLeft = insertPredicate(l, newLeft, r, right, rightKeys, filterableScan.get)
+        newLeft = insertPredicate(l, newLeft, Seq(r), right, rightKeys, filterableScan.get)
       } else {
         filterableScan = getFilterableTableScan(r, right)
         if (filterableScan.isDefined && canPruneRight(joinType) &&
             hasPartitionPruningFilter(left)) {
-          newRight = insertPredicate(r, newRight, l, left, leftKeys, filterableScan.get)
+          newRight = insertPredicate(r, newRight, Seq(l), left, leftKeys, filterableScan.get)
         }
       }
     case _ =>
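For orientation, a sketch of the plan shape that `insertPredicate` produces when it fires for the left side; `shapeAfterInsert` is a hypothetical helper, assuming `spark-catalyst` on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningSubquery, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

// The pruning side gets wrapped in a Filter over a DynamicPruningSubquery whose
// broadcastKeyIndices (currently always of size 1) point into joinKeys.
def shapeAfterInsert(
    pruningKey: Expression,
    pruningPlan: LogicalPlan,
    filteringKey: Expression,
    filteringPlan: LogicalPlan,
    joinKeys: Seq[Expression]): LogicalPlan =
  Filter(
    DynamicPruningSubquery(
      pruningKey,
      filteringPlan,
      joinKeys,
      Seq(joinKeys.indexOf(filteringKey)),
      onlyInBroadcast = false),
    pruningPlan)
```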
sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -51,7 +51,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[SparkPlan]

     plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
       case DynamicPruningSubquery(
-          value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+          value, buildPlan, buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
         val sparkPlan = QueryExecution.createSparkPlan(
           sparkSession, sparkSession.sessionState.planner, buildPlan)
         // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is
@@ -73,15 +73,16 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[SparkPlan]
         val name = s"dynamicpruning#${exprId.id}"
         // place the broadcast adaptor for reusing the broadcast results on the probe side
         val broadcastValues =
-          SubqueryBroadcastExec(name, broadcastKeyIndex, buildKeys, exchange)
+          SubqueryBroadcastExec(name, broadcastKeyIndices, buildKeys, exchange)
         DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
       } else if (onlyInBroadcast) {
         // it is not worthwhile to execute the query, so we fall-back to a true literal
         DynamicPruningExpression(Literal.TrueLiteral)
       } else {
         // we need to apply an aggregate on the buildPlan in order to be column pruned
-        val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
-        val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+        val aliases = broadcastKeyIndices.map(idx =>
+          Alias(buildKeys(idx), buildKeys(idx).toString)())
+        val aggregate = Aggregate(aliases, aliases, buildPlan)
         DynamicPruningExpression(expressions.InSubquery(
           Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
       }
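The non-broadcast fallback above conceptually rewrites the DPP filter into `pruningKey IN (SELECT buildKey FROM buildPlan GROUP BY buildKey)`. A small sketch of the aggregate it builds; `distinctBuildKeys` is a hypothetical helper, same classpath assumption as before:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}

// Grouping and output are the same aliased key expressions, so the aggregate
// de-duplicates the build keys while keeping buildPlan column-pruned.
def distinctBuildKeys(
    buildPlan: LogicalPlan,
    buildKeys: Seq[Expression],
    broadcastKeyIndices: Seq[Int]): Aggregate = {
  val aliases = broadcastKeyIndices.map(idx =>
    Alias(buildKeys(idx), buildKeys(idx).toString)())
  Aggregate(aliases, aliases, buildPlan)
}
```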
sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -246,7 +246,7 @@ abstract class DynamicPartitionPruningSuiteBase

     val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
       case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
-        b.index
+        b.indices.map(idx => b.buildKeys(idx))
     }
     assert(buf.distinct.size == n)
   }
