From 22e5938aefc784f50218a86e013e4c2247271072 Mon Sep 17 00:00:00 2001
From: Thang Long VU
Date: Fri, 2 Feb 2024 19:55:34 +0800
Subject: [PATCH] [SPARK-46946][SQL] Supporting broadcast of multiple filtering keys in DynamicPruning

### What changes were proposed in this pull request?

This PR extends `DynamicPruningSubquery` to support broadcasting multiple filtering keys (instead of a single key, as before). Most of the PR simply generalises the single-key code paths to handle a sequence of keys.

**Note:** We do not actually use multiple filtering keys in `DynamicPruningSubquery` in this PR; the change is made to make it easier to support DPP for Null Safe Equality or for multiple Equality predicates in the future.

In a Null Safe Equality JOIN, the JOIN condition `a <=> b` is transformed to `Coalesce(key1, Literal(key1.dataType)) = Coalesce(key2, Literal(key2.dataType)) AND IsNull(key1) = IsNull(key2)`. To achieve the highest pruning efficiency, we broadcast the two keys `Coalesce(key, Literal(key.dataType))` and `IsNull(key)` and use them to prune the other side at the same time.

Previously, `DynamicPruningSubquery` carried a single broadcasting key and we only supported DPP for a single `EqualTo` JOIN predicate; this PR extends the subquery to carry multiple broadcasting keys.

Please note that DPP is still not supported for multiple JOIN predicates. Put another way, at the moment we do not insert a DPP Filter for multiple JOIN predicates at the same time; we only potentially insert a DPP Filter for a given Equality JOIN predicate.

### Why are the changes needed?

To make it easier to support DPP for Null Safe Equality or for multiple Equality predicates in the future.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44988 from longvu-db/multiple-broadcast-filtering-keys.
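For illustration of the rewrite described above, here is a minimal, self-contained sketch (not part of this patch's changes) of how the two filtering keys of a Null Safe Equality condition could be built with the Catalyst expression API. The object name `NullSafeEqualityRewriteSketch`, the helper `filteringKeys`, and the choice of `Literal.default` for the NULL replacement are illustrative assumptions, not the exact code Spark uses:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

object NullSafeEqualityRewriteSketch {
  // The two filtering keys derived from one side of `a <=> b`.
  def filteringKeys(key: Expression): Seq[Expression] = Seq(
    // Replace NULL with the type's default value so the coalesced values stay comparable.
    Coalesce(Seq(key, Literal.default(key.dataType))),
    IsNull(key))

  // a <=> b  ==>  coalesce(a, d) = coalesce(b, d) AND isnull(a) = isnull(b)
  def rewrite(a: Expression, b: Expression): Expression = {
    val Seq(coalesceA, isNullA) = filteringKeys(a)
    val Seq(coalesceB, isNullB) = filteringKeys(b)
    And(EqualTo(coalesceA, coalesceB), EqualTo(isNullA, isNullB))
  }

  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    println(rewrite(a, b).sql)
  }
}
```

Broadcasting both `Coalesce(key, ...)` and `IsNull(key)` from the build side is what would require `DynamicPruningSubquery` to carry a sequence of broadcast key indices rather than a single index.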
Authored-by: Thang Long VU
Signed-off-by: Wenchen Fan
---
 .../catalyst/expressions/DynamicPruning.scala | 12 +--
 .../DynamicPruningSubquerySuite.scala         | 89 +++++++++++++++++++
 .../SubqueryAdaptiveBroadcastExec.scala       |  2 +-
 .../sql/execution/SubqueryBroadcastExec.scala | 37 ++++----
 .../PlanAdaptiveDynamicPruningFilters.scala   |  8 +-
 .../adaptive/PlanAdaptiveSubqueries.scala     |  4 +-
 .../dynamicpruning/PartitionPruning.scala     | 15 ++--
 .../PlanDynamicPruningFilters.scala           |  9 +-
 .../sql/DynamicPartitionPruningSuite.scala    |  2 +-
 9 files changed, 138 insertions(+), 40 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index ec6925eaa9842..cc24a982d5d85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -37,13 +37,13 @@ trait DynamicPruning extends Predicate
  *   beneficial and so it should be executed even if it cannot reuse the results of the
  *   broadcast through ReuseExchange; otherwise, it will use the filter only if it
  *   can reuse the results of the broadcast through ReuseExchange
- * @param broadcastKeyIndex the index of the filtering key collected from the broadcast
+ * @param broadcastKeyIndices the indices of the filtering keys collected from the broadcast
  */
 case class DynamicPruningSubquery(
     pruningKey: Expression,
     buildQuery: LogicalPlan,
     buildKeys: Seq[Expression],
-    broadcastKeyIndex: Int,
+    broadcastKeyIndices: Seq[Int],
     onlyInBroadcast: Boolean,
     exprId: ExprId = NamedExpression.newExprId,
     hint: Option[HintInfo] = None)
@@ -67,10 +67,12 @@ case class DynamicPruningSubquery(
       buildQuery.resolved &&
       buildKeys.nonEmpty &&
       buildKeys.forall(_.resolved) &&
-      broadcastKeyIndex >= 0 &&
-      broadcastKeyIndex < buildKeys.size &&
+      broadcastKeyIndices.forall(idx => idx >= 0 && idx < buildKeys.size) &&
       buildKeys.forall(_.references.subsetOf(buildQuery.outputSet)) &&
-      pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
+      // DynamicPruningSubquery should only have a single broadcasting key since
+      // there is no usage of multiple broadcasting keys at the moment.
+      broadcastKeyIndices.size == 1 &&
+      child.dataType == buildKeys(broadcastKeyIndices.head).dataType
   }

   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala
new file mode 100644
index 0000000000000..9d7d756019bdb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types.IntegerType
+
+class DynamicPruningSubquerySuite extends SparkFunSuite {
+  private val pruningKeyExpression = Literal(1)
+
+  private val validDynamicPruningSubquery = DynamicPruningSubquery(
+    pruningKey = pruningKeyExpression,
+    buildQuery = Project(Seq(AttributeReference("id", IntegerType)()),
+      LocalRelation(AttributeReference("id", IntegerType)())),
+    buildKeys = Seq(pruningKeyExpression),
+    broadcastKeyIndices = Seq(0),
+    onlyInBroadcast = false
+  )
+
+  test("pruningKey data type matches single buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2023)))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("pruningKey data type is a Struct and matches with Struct buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
+        buildKeys = Seq(CreateStruct(Seq(Literal(2), Literal.TrueLiteral))))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("multiple buildKeys but only one broadcastKeyIndex") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(0), Literal(2), Literal(0), Literal(9)),
+        broadcastKeyIndices = Seq(1))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("pruningKey data type does not match the single buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery.copy(
+      pruningKey = Literal.TrueLiteral,
+      buildKeys = Seq(Literal(2013)))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("pruningKey data type is a Struct but mismatch with Struct buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
+        buildKeys = Seq(CreateStruct(Seq(Literal.TrueLiteral, Literal(2)))))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("DynamicPruningSubquery should only have a single broadcasting key") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2025), Literal(2), Literal(1809)),
+        broadcastKeyIndices = Seq(0, 2))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("duplicates in broadcastKeyIndices should not be allowed") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2)),
+        broadcastKeyIndices = Seq(0, 0))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("broadcastKeyIndex out of bounds") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(broadcastKeyIndices = Seq(1))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
index e7092ee91d766..555f4f41d3cd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
  */
 case class SubqueryAdaptiveBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     onlyInBroadcast: Boolean,
     @transient buildPlan: LogicalPlan,
     buildKeys: Seq[Expression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index 05657fe62e8e5..9e7c1193c8ae0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -34,32 +34,34 @@ import org.apache.spark.util.ThreadUtils

 /**
  * Physical plan for a custom subquery that collects and transforms the broadcast key values.
- * This subquery retrieves the partition key from the broadcast results based on the type of
- * [[HashedRelation]] returned. If the key is packed inside a Long, we extract it through
+ * This subquery retrieves the partition keys from the broadcast results based on the type of
+ * [[HashedRelation]] returned. If a key is packed inside a Long, we extract it through
  * bitwise operations, otherwise we return it from the appropriate index of the [[UnsafeRow]].
  *
- * @param index the index of the join key in the list of keys from the build side
+ * @param indices the indices of the join keys in the list of keys from the build side
  * @param buildKeys the join keys from the build side of the join used
  * @param child the BroadcastExchange or the AdaptiveSparkPlan with BroadcastQueryStageExec
  *              from the build side of the join
  */
 case class SubqueryBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     buildKeys: Seq[Expression],
     child: SparkPlan) extends BaseSubqueryExec with UnaryExecNode {

   // `SubqueryBroadcastExec` is only used with `InSubqueryExec`. No one would reference this output,
   // so the exprId doesn't matter here. But it's important to correctly report the output length, so
-  // that `InSubqueryExec` can know it's the single-column execution mode, not multi-column.
+  // that `InSubqueryExec` can know whether it's the single-column or multi-column execution mode.
   override def output: Seq[Attribute] = {
-    val key = buildKeys(index)
-    val name = key match {
-      case n: NamedExpression => n.name
-      case Cast(n: NamedExpression, _, _, _) => n.name
-      case _ => "key"
+    indices.map { idx =>
+      val key = buildKeys(idx)
+      val name = key match {
+        case n: NamedExpression => n.name
+        case Cast(n: NamedExpression, _, _, _) => n.name
+        case _ => s"key_$idx"
+      }
+      AttributeReference(name, key.dataType, key.nullable)()
     }
-    Seq(AttributeReference(name, key.dataType, key.nullable)())
   }

   override lazy val metrics = Map(
@@ -69,7 +71,7 @@ case class SubqueryBroadcastExec(

   override def doCanonicalize(): SparkPlan = {
     val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
-    SubqueryBroadcastExec("dpp", index, keys, child.canonicalized)
+    SubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)
   }

   @transient
@@ -84,14 +86,15 @@ case class SubqueryBroadcastExec(
       val beforeCollect = System.nanoTime()

       val broadcastRelation = child.executeBroadcast[HashedRelation]().value
-      val (iter, expr) = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
-        (broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
+      val exprs = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
+        indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
       } else {
-        (broadcastRelation.keys(),
-          BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+        indices.map { idx =>
+          BoundReference(idx, buildKeys(idx).dataType, buildKeys(idx).nullable) }
       }
-      val proj = UnsafeProjection.create(expr)
+      val proj = UnsafeProjection.create(exprs)
+      val iter = broadcastRelation.keys()
       val keyIter = iter.map(proj).map(_.copy())

       val rows = if (broadcastRelation.keyIsUnique) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
index 9a780c11eefab..3d35abff3c538 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
@@ -39,7 +39,7 @@ case class PlanAdaptiveDynamicPruningFilters(
     plan.transformAllExpressionsWithPruning(
       _.containsAllPatterns(DYNAMIC_PRUNING_EXPRESSION, IN_SUBQUERY_EXEC)) {
       case DynamicPruningExpression(InSubqueryExec(
-          value, SubqueryAdaptiveBroadcastExec(name, index, onlyInBroadcast, buildPlan, buildKeys,
+          value, SubqueryAdaptiveBroadcastExec(name, indices, onlyInBroadcast, buildPlan, buildKeys,
           adaptivePlan: AdaptiveSparkPlanExec), exprId, _, _, _)) =>
         val packedKeys = BindReferences.bindReferences(
           HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
@@ -61,14 +61,14 @@ case class PlanAdaptiveDynamicPruningFilters(
           val newAdaptivePlan = adaptivePlan.copy(inputPlan = exchange)
           val broadcastValues = SubqueryBroadcastExec(
-            name, index, buildKeys, newAdaptivePlan)
+            name, indices, buildKeys, newAdaptivePlan)

           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
         } else if (onlyInBroadcast) {
           DynamicPruningExpression(Literal.TrueLiteral)
         } else {
           // we need to apply an aggregate on the buildPlan in order to be column pruned
-          val alias = Alias(buildKeys(index), buildKeys(index).toString)()
-          val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+          val aliases = indices.map(idx => Alias(buildKeys(idx), buildKeys(idx).toString)())
+          val aggregate = Aggregate(aliases, aliases, buildPlan)

           val session = adaptivePlan.context.session
           val sparkPlan = QueryExecution.prepareExecutedPlan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 7816fbd52c0a2..df4d895867586 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -47,9 +47,9 @@ case class PlanAdaptiveSubqueries(
         val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
         InSubqueryExec(expr, subquery, exprId, isDynamicPruning = false)
       case expressions.DynamicPruningSubquery(value, buildPlan,
-          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+          buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
         val name = s"dynamicpruning#${exprId.id}"
-        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
+        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndices, onlyInBroadcast,
           buildPlan, buildKeys, subqueryMap(exprId.id))
         DynamicPruningExpression(InSubqueryExec(value, subquery, exprId))
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 4e52137b74271..ef22c0ab44e4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -103,13 +103,16 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
   private def insertPredicate(
       pruningKey: Expression,
       pruningPlan: LogicalPlan,
-      filteringKey: Expression,
+      filteringKeys: Seq[Expression],
       filteringPlan: LogicalPlan,
       joinKeys: Seq[Expression],
       partScan: LogicalPlan): LogicalPlan = {
     val reuseEnabled = conf.exchangeReuseEnabled
-    val index = joinKeys.indexOf(filteringKey)
-    lazy val hasBenefit = pruningHasBenefit(pruningKey, partScan, filteringKey, filteringPlan)
+    require(filteringKeys.size == 1, "DPP Filters should only have a single broadcasting key " +
+      "since there is no usage of multiple broadcasting keys at the moment.")
+    val indices = Seq(joinKeys.indexOf(filteringKeys.head))
+    lazy val hasBenefit = pruningHasBenefit(
+      pruningKey, partScan, filteringKeys.head, filteringPlan)
     if (reuseEnabled || hasBenefit) {
       // insert a DynamicPruning wrapper to identify the subquery during query planning
       Filter(
@@ -117,7 +120,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
           pruningKey,
           filteringPlan,
           joinKeys,
-          index,
+          indices,
           conf.dynamicPartitionPruningReuseBroadcastOnly || !hasBenefit),
         pruningPlan)
     } else {
@@ -255,12 +258,12 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
             var filterableScan = getFilterableTableScan(l, left)
             if (filterableScan.isDefined && canPruneLeft(joinType) &&
                 hasPartitionPruningFilter(right)) {
-              newLeft = insertPredicate(l, newLeft, r, right, rightKeys, filterableScan.get)
+              newLeft = insertPredicate(l, newLeft, Seq(r), right, rightKeys, filterableScan.get)
             } else {
               filterableScan = getFilterableTableScan(r, right)
               if (filterableScan.isDefined && canPruneRight(joinType) &&
                   hasPartitionPruningFilter(left) ) {
-                newRight = insertPredicate(r, newRight, l, left, leftKeys, filterableScan.get)
+                newRight = insertPredicate(r, newRight, Seq(l), left, leftKeys, filterableScan.get)
               }
             }
           case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index fef92edbce649..3a08b13be0134 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -51,7 +51,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp

     plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
       case DynamicPruningSubquery(
-          value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+          value, buildPlan, buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
         val sparkPlan = QueryExecution.createSparkPlan(
           sparkSession, sparkSession.sessionState.planner, buildPlan)
         // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is
@@ -73,15 +73,16 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
           val name = s"dynamicpruning#${exprId.id}"
           // place the broadcast adaptor for reusing the broadcast results on the probe side
           val broadcastValues =
-            SubqueryBroadcastExec(name, broadcastKeyIndex, buildKeys, exchange)
+            SubqueryBroadcastExec(name, broadcastKeyIndices, buildKeys, exchange)
           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
         } else if (onlyInBroadcast) {
           // it is not worthwhile to execute the query, so we fall-back to a true literal
           DynamicPruningExpression(Literal.TrueLiteral)
         } else {
           // we need to apply an aggregate on the buildPlan in order to be column pruned
-          val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
-          val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+          val aliases = broadcastKeyIndices.map(idx =>
+            Alias(buildKeys(idx), buildKeys(idx).toString)())
+          val aggregate = Aggregate(aliases, aliases, buildPlan)
           DynamicPruningExpression(expressions.InSubquery(
             Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 50dcb9d718978..2c24cc7d570ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -246,7 +246,7 @@ abstract class DynamicPartitionPruningSuiteBase

     val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
       case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
-        b.index
+        b.indices.map(idx => b.buildKeys(idx))
     }
     assert(buf.distinct.size == n)
   }