Support split non-AST-able join condition for BroadcastNestedLoopJoin #9635

Merged (10 commits) on Nov 13, 2023
114 changes: 114 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}


object AstUtil {

/**
* Checks whether the condition can be split into non-AST sub-expressions if needed.
*
* @return true when either 1) every node in expr is AST-able, or 2) each non-AST-able
*         subtree references attributes from at most one join side, so that it can be
*         pushed down into a single child.
*/
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[Attribute],
right: Seq[Attribute]): Boolean = {
if (!expr.canSelfBeAst) {
// It needs to be split since it is not AST-able. Check itself and its children to
// ensure the push-down can be made, i.e. that it doesn't need attributes from both sides.
val exprRef = expr.wrapped.asInstanceOf[Expression]
val leftTree = exprRef.references.exists(left.contains(_))
val rightTree = exprRef.references.exists(right.contains(_))
// Can't extract a condition involving columns from both sides
!(rightTree && leftTree)
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(
canExtractNonAstConditionIfNeed(_, left, right))
}
}

/**
* Splits non-AST-able sub-expressions out of the join condition and pushes them down
* into the child plans as projections.
*
* @param condition the join condition, to be split if needed
* @param left attributes from the left child
* @param right attributes from the right child
* @param skipCheck whether to skip the split-ability check
* @return a tuple of the remaining [[Expression]] (if any), a list of [[NamedExpression]]
*         for the left child (if any), and a list of [[NamedExpression]] for the right
*         child (if any)
*/
def extractNonAstFromJoinCond(condition: Option[BaseExprMeta[_]],
left: AttributeSeq, right: AttributeSeq, skipCheck: Boolean):
(Option[Expression], List[NamedExpression], List[NamedExpression]) = {
// Choose side with smaller key size
val (childAtt, isLeft) =
if (left.attrs.size < right.attrs.size) (left, true) else (right, false)
// List of expressions to push down to the left child
val leftExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// List of expressions to push down to the right child
val rightExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// Substitution map used to replace targeted expressions based on semantic equality
val substitutionMap = mutable.HashMap.empty[GpuExpressionEquals, Expression]

// 1st step: construct 1) the left expr list; 2) the right expr list; 3) the substitutionMap.
// No need to consider common sub-expressions here since the project node uses tiered execution
condition.foreach(c =>
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs, right.attrs)) {
splitNonAstInternal(c, childAtt.attrs, leftExprs, rightExprs, substitutionMap, isLeft)
})

// 2nd step: replace the expressions pushed down to the child plans in a depth-first fashion
(condition.map(
_.convertToGpu().mapChildren(
[Review thread]

Collaborator: I think this should be fine, but I am not 100% sure that convertToGpu will be idempotent. It might be better to just convert it to the GPU up front and cache it instead of recalling it inside of splitNonAstInternal. To make that work we probably would have to zip the GPU expressions along with condition.childExprs. I don't think this is a blocker for the patch to go in. Just being overly cautious about it.

Collaborator (author): Hmm, it seems not very straightforward to eliminate the non-idempotency risk here. convertToGpu only happens when we need to replace it with a new GpuAlias node. We could probably cache those converted expressions, but in the end (L100) the root BaseExprMeta will do a real convertToGpu which converts the entire tree, so we can hardly reuse the cache here.

Collaborator: I think this is fine for now, so let's file a follow-on issue to look into it. I see a few ways to fix it, but none of them are that simple, and they may have their own hidden pitfalls.
GpuEquivalentExpressions.replaceWithSemanticCommonRef(_,
substitutionMap))), leftExprs.toList, rightExprs.toList)
}

private[this] def splitNonAstInternal(condition: BaseExprMeta[_], childAtt: Seq[Attribute],
left: ListBuffer[NamedExpression], right: ListBuffer[NamedExpression],
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression], isLeft: Boolean): Unit = {
for (child <- condition.childExprs) {
if (!child.canSelfBeAst) {
val exprRef = child.wrapped.asInstanceOf[Expression]
val gpuProj = child.convertToGpu()
val alias = substitutionMap.get(GpuExpressionEquals(gpuProj)) match {
case Some(_) => None
case None =>
if (exprRef.references.exists(childAtt.contains(_)) ^ isLeft) {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_r_${left.size}")()
right += alias
Some(alias)
} else {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_l_${left.size}")()
left += alias
Some(alias)
}
}
alias.foreach(a => substitutionMap.put(GpuExpressionEquals(gpuProj), a.toAttribute))
} else {
splitNonAstInternal(child, childAtt, left, right, substitutionMap, isLeft)
}
}
}
}
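The two-phase design of `AstUtil` (first check that every non-AST-able subtree touches only one join side, then replace each such subtree with an alias pushed into that child) can be modeled compactly. Below is a hypothetical Python sketch, with simplified names and none of the plugin's real API, just to illustrate the control flow:

```python
from dataclasses import dataclass, field

@dataclass
class Expr:
    name: str
    ast_able: bool = True
    refs: set = field(default_factory=set)       # attribute names this node uses directly
    children: list = field(default_factory=list)

    def all_refs(self):
        # All attributes referenced anywhere in this subtree
        out = set(self.refs)
        for c in self.children:
            out |= c.all_refs()
        return out

def can_extract(expr, left_attrs, right_attrs):
    # Phase 1: a non-AST-able subtree touching both sides cannot be pushed to one child
    if not expr.ast_able:
        refs = expr.all_refs()
        return not (refs & left_attrs and refs & right_attrs)
    return all(can_extract(c, left_attrs, right_attrs) for c in expr.children)

def split(expr, left_attrs, pushed_left, pushed_right):
    # Phase 2: replace each non-AST-able subtree with an alias on the side it references
    for i, child in enumerate(expr.children):
        if not child.ast_able:
            side = pushed_left if child.all_refs() <= left_attrs else pushed_right
            alias = f"_non_ast_{len(side)}"
            side.append((alias, child))
            expr.children[i] = Expr(alias, refs={alias})
        else:
            split(child, left_attrs, pushed_left, pushed_right)
```

For example, a condition like `upper(l.a) = r.b` (where `upper` stands in for any non-AST-able expression) passes the check because `upper(l.a)` only references the left side, and the split turns it into an aliased column computed in a left-side project.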
@@ -1115,6 +1115,15 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
* Checks whether this node itself can be converted to AST. It does not recursively check
* its children. It is used to check join condition AST-ability in a top-down fashion.
*/
final def canSelfBeAst: Boolean = {
tagForAst()
cannotBeAstReasons.isEmpty
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
@@ -240,6 +240,22 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
* Recursively replaces each semantically equal expression with its proxy expression from
* `substitutionMap`.
*/
def replaceWithSemanticCommonRef(
expr: Expression,
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
expr match {
case e: AttributeReference => e
case _ =>
substitutionMap.get(GpuExpressionEquals(expr)) match {
case Some(attr) => attr
case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap))
}
}
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
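The substitution walk in `replaceWithSemanticCommonRef` is top-down: attribute references pass through untouched, any subtree found in the map is swapped for its proxy attribute, and otherwise the walk recurses into the children. A small Python sketch (hypothetical, with expressions modeled as nested tuples and semantic equality simplified to structural equality):

```python
def replace_common(expr, subst):
    # expr is a nested tuple: ("attr", name) for a reference, or (op, child, ...)
    if expr[0] == "attr":
        return expr                      # attribute references are never substituted
    if expr in subst:                    # semantic equality modeled as structural equality
        return subst[expr]
    # No match at this node: recurse into the children
    return (expr[0],) + tuple(replace_common(c, subst) for c in expr[1:])
```

Because the map lookup happens before recursion, the largest matching subtree wins, which mirrors replacing a whole pushed-down expression with one alias attribute.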
@@ -28,7 +28,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, IdentityBroadcastMode, UnspecifiedDistribution}
import org.apache.spark.sql.execution.SparkPlan
@@ -50,6 +50,28 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase(

val gpuBuildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide)

private var taggedForAstCheck = false

// Avoid checking multiple times
private var isAstCond = false

/**
* Checks whether the join condition is AST-able. This covers two cases: 1) the whole join
* condition is AST-able as-is; 2) the join condition becomes AST-able after being split
* and pushed down to the child plans.
*/
protected def canJoinCondAstAble(): Boolean = {
if (!taggedForAstCheck) {
val Seq(leftPlan, rightPlan) = childPlans
conditionMeta match {
case Some(e) => isAstCond = AstUtil.canExtractNonAstConditionIfNeed(
e, leftPlan.outputAttributes, rightPlan.outputAttributes)
case None => isAstCond = true
}
taggedForAstCheck = true
}
isAstCond
}

override def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] =
JoinTypeChecks.nonEquiJoinMeta(conditionMeta)

@@ -60,7 +82,9 @@
join.joinType match {
case _: InnerLike =>
case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
conditionMeta.foreach(requireAstForGpuOn)
// First check whether the condition can be split if it is not AST-able. If it cannot,
// fall back to requireAst to report the not-work-on-GPU reason.
conditionMeta.foreach(cond => if (!canJoinCondAstAble()) requireAstForGpuOn(cond))
case _ => willNotWorkOnGpu(s"${join.joinType} currently is not supported")
}
join.joinType match {
@@ -383,12 +407,16 @@ object GpuBroadcastNestedLoopJoinExecBase {
}
}

// postBuildCondition is the post-broadcast project condition. It is used to reconstruct a
// tiered project that handles the pre-built batch. It will be removed once a code refactor
// decouples broadcast from nested loop join.
abstract class GpuBroadcastNestedLoopJoinExecBase(
left: SparkPlan,
right: SparkPlan,
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
postBuildCondition: List[NamedExpression],
targetSizeBytes: Long) extends ShimBinaryExecNode with GpuExec {

import GpuMetric._
@@ -411,7 +439,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
case GpuBuildLeft => (right, left)
}

def broadcastExchange: GpuBroadcastExchangeExecBase = buildPlan match {
def broadcastExchange: GpuBroadcastExchangeExecBase = getBroadcastPlan(buildPlan) match {
case bqse: BroadcastQueryStageExec if bqse.plan.isInstanceOf[GpuBroadcastExchangeExecBase] =>
bqse.plan.asInstanceOf[GpuBroadcastExchangeExecBase]
case bqse: BroadcastQueryStageExec if bqse.plan.isInstanceOf[ReusedExchangeExec] =>
@@ -420,6 +448,15 @@
case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuBroadcastExchangeExecBase]
}

private[this] def getBroadcastPlan(plan: SparkPlan): SparkPlan = {
plan match {
// Handles the case where a post-broadcast project exists. This happens when the join
// condition contains a non-AST expression, which results in a project right after the
// broadcast.
case plan: GpuProjectExec => plan.child
case _ => plan
}
}

override def requiredChildDistribution: Seq[Distribution] = gpuBuildSide match {
case GpuBuildLeft =>
BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
@@ -468,7 +505,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}
}

protected def makeBuiltBatch(
protected def makeBuiltBatchInternal(
relation: Any,
buildTime: GpuMetric,
buildDataSize: GpuMetric): ColumnarBatch = {
@@ -477,6 +514,24 @@
makeBroadcastBuiltBatch(broadcastRelation, buildTime, buildDataSize)
}

final def makeBuiltBatch(
relation: Any,
buildTime: GpuMetric,
buildDataSize: GpuMetric): ColumnarBatch = {
buildPlan match {
case p: GpuProjectExec =>
// Need to perform the project's columnar execution manually rather than calling the
// child's internalDoExecuteColumnar. This works around the special handling used to
// build the broadcast batch.
val proj = GpuBindReferences.bindGpuReferencesTiered(
postBuildCondition, p.child.output, true)
withResource(makeBuiltBatchInternal(relation, buildTime, buildDataSize)) {
cb => proj.project(cb)
}
case _ => makeBuiltBatchInternal(relation, buildTime, buildDataSize)
}
}

protected def computeBuildRowCount(
relation: Any,
buildTime: GpuMetric,
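The build-side handling above has two cooperating pieces: `getBroadcastPlan` unwraps an optional post-broadcast project so the exchange can still be located, while `makeBuiltBatch` re-applies that projection to the broadcast batch itself. A hypothetical Python sketch of that shape (plans as tagged tuples, batches as lists of dict rows, all names invented):

```python
def get_broadcast_plan(plan):
    # ("project", exprs, child) wraps the real broadcast exchange when a
    # post-broadcast project was inserted for non-AST expressions
    return plan[2] if plan[0] == "project" else plan

def make_built_batch(plan, batch):
    # Apply the post-broadcast projection manually to the pre-built batch,
    # instead of executing the project node as part of the plan
    if plan[0] == "project":
        exprs = plan[1]  # list of (output_name, per-row function)
        return [{name: fn(row) for name, fn in exprs} for row in batch]
    return batch
```

The projection keeps the original columns alongside the new alias columns, matching how the real code binds `postBuildCondition` (aliases plus the child's output) against the build side.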
Expand Up @@ -36,7 +36,7 @@ package org.apache.spark.sql.rapids.execution

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
@@ -58,28 +58,59 @@ class GpuBroadcastNestedLoopJoinMeta(
}
verifyBuildSideWasReplaced(buildSide)

val condition = conditionMeta.map(_.convertToGpu())
val isAstCondition = conditionMeta.forall(_.canThisBeAst)
join.joinType match {
case _: InnerLike =>
case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case RightOuter if gpuBuildSide == GpuBuildRight =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
// Cannot post-filter these types of joins
assert(isAstCondition, s"Non-AST condition in ${join.joinType}")
case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}")
}
// If AST-able, try to split if needed. Otherwise, fall back to a post-filter
val isAstCondition = canJoinCondAstAble()

    if (isAstCondition) {
      // Try to extract non-AST-able conditions from the join condition
      val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(
        conditionMeta, left.output, right.output, true)

      // Reconstruct the children with wrapped project nodes if needed.
      val leftChild =
        if (leftExpr.nonEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left
      val rightChild =
        if (rightExpr.nonEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right
      val postBroadcastCondition =
        if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output

    val joinExec = GpuBroadcastNestedLoopJoinExec(
      left, right,
      join.joinType, gpuBuildSide,
      if (isAstCondition) condition else None,
      conf.gpuTargetBatchSizeBytes)
    if (isAstCondition) {
      joinExec
      // TODO: a code refactor is needed to avoid passing postBroadcastCondition as a parameter
      // when instantiating GpuBroadcastNestedLoopJoinExec. Currently the output columnar batch
      // of the broadcast side is handled inside GpuBroadcastNestedLoopJoinExec, so a project
      // node has to be built manually for the build-side batch.
      val joinExec = GpuBroadcastNestedLoopJoinExec(
        leftChild, rightChild,
        join.joinType, gpuBuildSide,
        remains,
        postBroadcastCondition,
        conf.gpuTargetBatchSizeBytes)
if (leftExpr.isEmpty && rightExpr.isEmpty) {
joinExec
} else {
// Remove the intermediate attributes from left and right side project nodes
GpuProjectExec((left.output ++ right.output).toList, joinExec)(false)
}
} else {
join.joinType match {
case _: InnerLike =>
case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case RightOuter if gpuBuildSide == GpuBuildRight =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
// Cannot post-filter these types of joins
assert(isAstCondition, s"Non-AST condition in ${join.joinType}")
case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}")
}
val condition = conditionMeta.map(_.convertToGpu())

val joinExec = GpuBroadcastNestedLoopJoinExec(
left, right,
join.joinType, gpuBuildSide,
None,
List.empty,
conf.gpuTargetBatchSizeBytes)

// condition cannot be implemented via AST so fallback to a post-filter if necessary
condition.map {
// TODO: Restore batch coalescing logic here.
Expand All @@ -94,13 +125,13 @@ class GpuBroadcastNestedLoopJoinMeta(
}
}


case class GpuBroadcastNestedLoopJoinExec(
left: SparkPlan,
right: SparkPlan,
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
postBroadcastCondition: List[NamedExpression],
targetSizeBytes: Long) extends GpuBroadcastNestedLoopJoinExecBase(
left, right, joinType, gpuBuildSide, condition, targetSizeBytes
left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes
)
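Putting the convert() path together: non-AST expressions are pushed into child projects, the join runs on the rewritten condition, and a final project drops the intermediate alias columns. An end-to-end Python sketch of the resulting plan shape (hypothetical tuple-based plan model, not the plugin's API):

```python
def plan_join(left_cols, right_cols, left_extra, right_extra, remaining_cond):
    # Wrap a child in a project only when aliases were pushed to its side
    left_child = ("Project", left_extra + left_cols, ("Scan", left_cols)) \
        if left_extra else ("Scan", left_cols)
    right_child = ("Project", right_extra + right_cols, ("Scan", right_cols)) \
        if right_extra else ("Scan", right_cols)
    join = ("BNLJoin", remaining_cond, left_child, right_child)
    if left_extra or right_extra:
        # Remove the intermediate alias attributes from the join output
        return ("Project", left_cols + right_cols, join)
    return join
```

When nothing needs to be split, the plan is unchanged; when a left-side alias such as `_non_ast_l_0` is extracted, the join is sandwiched between the push-down project and the cleanup project, as in the converted plan above.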