[SPARK-46946][SQL] Supporting broadcast of multiple filtering keys in DynamicPruning #44988

Closed
DynamicPruning.scala
@@ -37,13 +37,13 @@ trait DynamicPruning extends Predicate
* beneficial and so it should be executed even if it cannot reuse the results of the
* broadcast through ReuseExchange; otherwise, it will use the filter only if it
* can reuse the results of the broadcast through ReuseExchange
* @param broadcastKeyIndex the index of the filtering key collected from the broadcast
* @param broadcastKeyIndices the indices of the filtering keys collected from the broadcast
*/
case class DynamicPruningSubquery(
pruningKey: Expression,
buildQuery: LogicalPlan,
buildKeys: Seq[Expression],
broadcastKeyIndex: Int,
broadcastKeyIndices: Seq[Int],
onlyInBroadcast: Boolean,
exprId: ExprId = NamedExpression.newExprId,
hint: Option[HintInfo] = None)
@@ -67,10 +67,12 @@ case class DynamicPruningSubquery(
buildQuery.resolved &&
buildKeys.nonEmpty &&
buildKeys.forall(_.resolved) &&
broadcastKeyIndex >= 0 &&
broadcastKeyIndex < buildKeys.size &&
broadcastKeyIndices.forall(idx => idx >= 0 && idx < buildKeys.size) &&
buildKeys.forall(_.references.subsetOf(buildQuery.outputSet)) &&
pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
// DynamicPruningSubquery should only have a single broadcasting key since
// there is no usage for multiple broadcasting keys at the moment.
broadcastKeyIndices.size == 1 &&
child.dataType == buildKeys(broadcastKeyIndices.head).dataType
}

final override def nodePatternsInternal(): Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
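A minimal, self-contained sketch (not part of the patch) of the bounds and single-key constraints that broadcastKeyIndices must satisfy in the resolved check above; the object and helper names here are invented for illustration.

object BroadcastKeyIndicesCheckSketch {
  // Every index must point into buildKeys, and only a single broadcast key
  // is accepted for now, which also rules out duplicate indices.
  def valid(indices: Seq[Int], numBuildKeys: Int): Boolean =
    indices.forall(idx => idx >= 0 && idx < numBuildKeys) &&
      indices.size == 1

  def main(args: Array[String]): Unit = {
    assert(valid(Seq(1), numBuildKeys = 4))      // in bounds, single key
    assert(!valid(Seq(0, 2), numBuildKeys = 3))  // multiple keys rejected
    assert(!valid(Seq(0, 0), numBuildKeys = 1))  // duplicates rejected
    assert(!valid(Seq(1), numBuildKeys = 1))     // out of bounds
  }
}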
DynamicPruningSubquerySuite.scala
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

class DynamicPruningSubquerySuite extends SparkFunSuite {
private val pruningKeyExpression = Literal(1)

private val validDynamicPruningSubquery = DynamicPruningSubquery(
pruningKey = pruningKeyExpression,
buildQuery = Project(Seq(AttributeReference("id", IntegerType)()),
LocalRelation(AttributeReference("id", IntegerType)())),
buildKeys = Seq(pruningKeyExpression),
broadcastKeyIndices = Seq(0),
onlyInBroadcast = false
)

test("pruningKey data type matches single buildKey") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(buildKeys = Seq(Literal(2023)))
assert(dynamicPruningSubquery.resolved == true)
}

test("pruningKey data type is a Struct and matches with Struct buildKey") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
buildKeys = Seq(CreateStruct(Seq(Literal(2), Literal.TrueLiteral))))
assert(dynamicPruningSubquery.resolved == true)
}

test("multiple buildKeys but only one broadcastKeyIndex") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(buildKeys = Seq(Literal(0), Literal(2), Literal(0), Literal(9)),
broadcastKeyIndices = Seq(1))
assert(dynamicPruningSubquery.resolved == true)
}

test("pruningKey data type does not match the single buildKey") {
val dynamicPruningSubquery = validDynamicPruningSubquery.copy(
pruningKey = Literal.TrueLiteral,
buildKeys = Seq(Literal(2013)))
assert(dynamicPruningSubquery.resolved == false)
}

test("pruningKey data type is a Struct but mismatch with Struct buildKey") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
buildKeys = Seq(CreateStruct(Seq(Literal.TrueLiteral, Literal(2)))))
assert(dynamicPruningSubquery.resolved == false)
}

test("DynamicPruningSubquery should only have a single broadcasting key") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(buildKeys = Seq(Literal(2025), Literal(2), Literal(1809)),
broadcastKeyIndices = Seq(0, 2))
assert(dynamicPruningSubquery.resolved == false)
}

test("duplicates in broadcastKeyIndices, and also should not be allowed") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(buildKeys = Seq(Literal(2)),
broadcastKeyIndices = Seq(0, 0))
assert(dynamicPruningSubquery.resolved == false)
}

test("broadcastKeyIndex out of bounds") {
val dynamicPruningSubquery = validDynamicPruningSubquery
.copy(broadcastKeyIndices = Seq(1))
assert(dynamicPruningSubquery.resolved == false)
}
}
SubqueryAdaptiveBroadcastExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
*/
case class SubqueryAdaptiveBroadcastExec(
name: String,
index: Int,
indices: Seq[Int],
onlyInBroadcast: Boolean,
@transient buildPlan: LogicalPlan,
buildKeys: Seq[Expression],
SubqueryBroadcastExec.scala
@@ -34,32 +34,34 @@ import org.apache.spark.util.ThreadUtils

/**
* Physical plan for a custom subquery that collects and transforms the broadcast key values.
* This subquery retrieves the partition key from the broadcast results based on the type of
* [[HashedRelation]] returned. If the key is packed inside a Long, we extract it through
* This subquery retrieves the partition keys from the broadcast results based on the type of
* [[HashedRelation]] returned. If a key is packed inside a Long, we extract it through
* bitwise operations, otherwise we return it from the appropriate index of the [[UnsafeRow]].
*
* @param index the index of the join key in the list of keys from the build side
* @param indices the indices of the join keys in the list of keys from the build side
* @param buildKeys the join keys from the build side of the join used
* @param child the BroadcastExchange or the AdaptiveSparkPlan with BroadcastQueryStageExec
* from the build side of the join
*/
case class SubqueryBroadcastExec(
name: String,
index: Int,
indices: Seq[Int],
buildKeys: Seq[Expression],
child: SparkPlan) extends BaseSubqueryExec with UnaryExecNode {

// `SubqueryBroadcastExec` is only used with `InSubqueryExec`. No one would reference this output,
// so the exprId doesn't matter here. But it's important to correctly report the output length, so
// that `InSubqueryExec` can know it's the single-column execution mode, not multi-column.
// that `InSubqueryExec` can know whether it's the single-column or multi-column execution mode.
override def output: Seq[Attribute] = {
val key = buildKeys(index)
val name = key match {
case n: NamedExpression => n.name
case Cast(n: NamedExpression, _, _, _) => n.name
case _ => "key"
indices.map { idx =>
val key = buildKeys(idx)
val name = key match {
case n: NamedExpression => n.name
case Cast(n: NamedExpression, _, _, _) => n.name
case _ => s"key_$idx"
}
AttributeReference(name, key.dataType, key.nullable)()
}
Seq(AttributeReference(name, key.dataType, key.nullable)())
}

override lazy val metrics = Map(
@@ -69,7 +71,7 @@ case class SubqueryBroadcastExec(

override def doCanonicalize(): SparkPlan = {
val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
SubqueryBroadcastExec("dpp", index, keys, child.canonicalized)
SubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)
}

@transient
@@ -84,14 +86,15 @@ case class SubqueryBroadcastExec(
val beforeCollect = System.nanoTime()

val broadcastRelation = child.executeBroadcast[HashedRelation]().value
val (iter, expr) = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
(broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
val exprs = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
} else {
(broadcastRelation.keys(),
BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
indices.map { idx =>
BoundReference(idx, buildKeys(idx).dataType, buildKeys(idx).nullable) }
}

val proj = UnsafeProjection.create(expr)
val proj = UnsafeProjection.create(exprs)
val iter = broadcastRelation.keys()
val keyIter = iter.map(proj).map(_.copy())

val rows = if (broadcastRelation.keyIsUnique) {
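The class doc above mentions keys "packed inside a Long". A standalone sketch of that idea follows, using an assumed two-key, 32-bit-per-key layout; in Spark the actual packing and unpacking are done by HashJoin.rewriteKeyExpr and HashJoin.extractKeyExprAt, so the widths and offsets below are illustrative only.

object PackedKeySketch {
  // Pack two 32-bit keys into one Long: the high half holds key0, the low half key1.
  def pack(key0: Int, key1: Int): Long =
    (key0.toLong << 32) | (key1.toLong & 0xFFFFFFFFL)

  // Recover the key at position idx (0 or 1) with shifts and masks.
  def extract(packed: Long, idx: Int): Int = idx match {
    case 0 => (packed >>> 32).toInt
    case 1 => packed.toInt
  }

  def main(args: Array[String]): Unit = {
    val packed = pack(2023, 7)
    assert(extract(packed, 0) == 2023)
    assert(extract(packed, 1) == 7)
  }
}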
PlanAdaptiveDynamicPruningFilters.scala
@@ -39,7 +39,7 @@ case class PlanAdaptiveDynamicPruningFilters(
plan.transformAllExpressionsWithPruning(
_.containsAllPatterns(DYNAMIC_PRUNING_EXPRESSION, IN_SUBQUERY_EXEC)) {
case DynamicPruningExpression(InSubqueryExec(
value, SubqueryAdaptiveBroadcastExec(name, index, onlyInBroadcast, buildPlan, buildKeys,
value, SubqueryAdaptiveBroadcastExec(name, indices, onlyInBroadcast, buildPlan, buildKeys,
adaptivePlan: AdaptiveSparkPlanExec), exprId, _, _, _)) =>
val packedKeys = BindReferences.bindReferences(
HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
@@ -61,14 +61,14 @@
val newAdaptivePlan = adaptivePlan.copy(inputPlan = exchange)

val broadcastValues = SubqueryBroadcastExec(
name, index, buildKeys, newAdaptivePlan)
name, indices, buildKeys, newAdaptivePlan)
DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
} else if (onlyInBroadcast) {
DynamicPruningExpression(Literal.TrueLiteral)
} else {
// we need to apply an aggregate on the buildPlan in order to be column pruned
val alias = Alias(buildKeys(index), buildKeys(index).toString)()
val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
val aliases = indices.map(idx => Alias(buildKeys(idx), buildKeys(idx).toString)())
val aggregate = Aggregate(aliases, aliases, buildPlan)

val session = adaptivePlan.context.session
val sparkPlan = QueryExecution.prepareExecutedPlan(
PlanAdaptiveSubqueries.scala
@@ -47,9 +47,9 @@ case class PlanAdaptiveSubqueries(
val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
InSubqueryExec(expr, subquery, exprId, isDynamicPruning = false)
case expressions.DynamicPruningSubquery(value, buildPlan,
buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
val name = s"dynamicpruning#${exprId.id}"
val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndices, onlyInBroadcast,
buildPlan, buildKeys, subqueryMap(exprId.id))
DynamicPruningExpression(InSubqueryExec(value, subquery, exprId))
}
PartitionPruning.scala
@@ -103,21 +103,24 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
private def insertPredicate(
pruningKey: Expression,
pruningPlan: LogicalPlan,
filteringKey: Expression,
filteringKeys: Seq[Expression],
filteringPlan: LogicalPlan,
joinKeys: Seq[Expression],
partScan: LogicalPlan): LogicalPlan = {
val reuseEnabled = conf.exchangeReuseEnabled
val index = joinKeys.indexOf(filteringKey)
lazy val hasBenefit = pruningHasBenefit(pruningKey, partScan, filteringKey, filteringPlan)
require(filteringKeys.size == 1, "DPP filters should only have a single broadcasting key " +
"since there is no usage for multiple broadcasting keys at the moment.")
val indices = Seq(joinKeys.indexOf(filteringKeys.head))
lazy val hasBenefit = pruningHasBenefit(
pruningKey, partScan, filteringKeys.head, filteringPlan)
if (reuseEnabled || hasBenefit) {
// insert a DynamicPruning wrapper to identify the subquery during query planning
Filter(
DynamicPruningSubquery(
pruningKey,
filteringPlan,
joinKeys,
index,
indices,
conf.dynamicPartitionPruningReuseBroadcastOnly || !hasBenefit),
pruningPlan)
} else {
@@ -255,12 +258,12 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
var filterableScan = getFilterableTableScan(l, left)
if (filterableScan.isDefined && canPruneLeft(joinType) &&
hasPartitionPruningFilter(right)) {
newLeft = insertPredicate(l, newLeft, r, right, rightKeys, filterableScan.get)
newLeft = insertPredicate(l, newLeft, Seq(r), right, rightKeys, filterableScan.get)
} else {
filterableScan = getFilterableTableScan(r, right)
if (filterableScan.isDefined && canPruneRight(joinType) &&
hasPartitionPruningFilter(left) ) {
newRight = insertPredicate(r, newRight, l, left, leftKeys, filterableScan.get)
newRight = insertPredicate(r, newRight, Seq(l), left, leftKeys, filterableScan.get)
}
}
case _ =>
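For orientation, a toy sketch of how insertPredicate above maps the filtering keys to positions among the join keys, including the single-key requirement; plain strings stand in for Catalyst expressions and all names are invented.

object FilteringKeyIndicesSketch {
  def indicesOf(filteringKeys: Seq[String], joinKeys: Seq[String]): Seq[Int] = {
    // mirrors the require(...) in insertPredicate: one broadcast key only
    require(filteringKeys.size == 1, "only a single broadcast key is supported")
    filteringKeys.map(joinKeys.indexOf)
  }

  def main(args: Array[String]): Unit = {
    // e.g. fact-side join keys (store_id, date_id) pruned on store_id
    assert(indicesOf(Seq("store_id"), Seq("store_id", "date_id")) == Seq(0))
  }
}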
PlanDynamicPruningFilters.scala
@@ -51,7 +51,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp

plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
case DynamicPruningSubquery(
value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
value, buildPlan, buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
val sparkPlan = QueryExecution.createSparkPlan(
sparkSession, sparkSession.sessionState.planner, buildPlan)
// Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is
@@ -73,15 +73,16 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
val name = s"dynamicpruning#${exprId.id}"
// place the broadcast adaptor for reusing the broadcast results on the probe side
val broadcastValues =
SubqueryBroadcastExec(name, broadcastKeyIndex, buildKeys, exchange)
SubqueryBroadcastExec(name, broadcastKeyIndices, buildKeys, exchange)
DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
} else if (onlyInBroadcast) {
// it is not worthwhile to execute the query, so we fall-back to a true literal
DynamicPruningExpression(Literal.TrueLiteral)
} else {
// we need to apply an aggregate on the buildPlan in order to be column pruned
val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
val aliases = broadcastKeyIndices.map(idx =>
Alias(buildKeys(idx), buildKeys(idx).toString)())
val aggregate = Aggregate(aliases, aliases, buildPlan)
DynamicPruningExpression(expressions.InSubquery(
Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
}
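To make the fallback branch above concrete: Aggregate(aliases, aliases, buildPlan) amounts to a DISTINCT projection of the broadcast key columns, whose values then feed the IN-subquery used for pruning. The sketch below models that with plain Scala collections only; it does not use Spark APIs and its names are invented.

object DistinctKeysSketch {
  // Rows are Seq[Any]; keyIndices pick out the broadcast key columns.
  def distinctKeys(rows: Seq[Seq[Any]], keyIndices: Seq[Int]): Seq[Seq[Any]] =
    rows.map(row => keyIndices.map(row)).distinct

  def main(args: Array[String]): Unit = {
    val buildSide = Seq(Seq(1, "US"), Seq(2, "US"), Seq(1, "US"))
    // a single broadcast key at position 0 yields the values kept by the filter
    assert(distinctKeys(buildSide, Seq(0)) == Seq(Seq(1), Seq(2)))
  }
}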
DynamicPartitionPruningSuite.scala
@@ -246,7 +246,7 @@ abstract class DynamicPartitionPruningSuiteBase

val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
b.index
b.indices.map(idx => b.buildKeys(idx))
}
assert(buf.distinct.size == n)
}