Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support split non-AST-able join condition for BroadcastNestedLoopJoin #9635

Merged
merged 10 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from _pytest.mark.structures import ParameterSet
from pyspark.sql.functions import broadcast, col
from pyspark.sql.functions import array_contains, broadcast, col
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
Expand Down Expand Up @@ -397,17 +397,44 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@allow_non_gpu('BroadcastExchangeExec')
revans2 marked this conversation as resolved.
Show resolved Hide resolved
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Right'], ids=idfn)
revans2 marked this conversation as resolved.
Show resolved Hide resolved
def test_broadcast_nested_loop_join_with_non_ast_condition_no_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet which is supposed to be extracted into child
# nodes
return broadcast(left).join(right, f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
revans2 marked this conversation as resolved.
Show resolved Hide resolved
assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
revans2 marked this conversation as resolved.
Show resolved Hide resolved
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet
# AST does not support cast or logarithm yet for double type
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@allow_non_gpu('BroadcastExchangeExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains_no_fallback(data_gen, join_type):
revans2 marked this conversation as resolved.
Show resolved Hide resolved
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
def do_join(spark):
left, right = create_df(spark, arr_gen, 50, 25)
# Array_contains should be pushed down since ast doesn't support it.
return broadcast(left).join(right, array_contains(col('a'), literal.cast(data_gen.data_type)))
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
Expand Down
131 changes: 131 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}


object AstUtil {

/**
* Check whether it can be split into non-ast sub-expression if needed
*
* @return true when: 1) If all ast-able in expr; 2) all non-ast-able tree nodes don't contain
* attributes from both join sides. In such case, it's not able
* to push down into single child.
*/
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[Attribute],
revans2 marked this conversation as resolved.
Show resolved Hide resolved
right: Seq[Attribute]): Boolean = {
if (!expr.canSelfBeAst) {
// Returns false directly since can't split non-ast-able root node into child
false
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(canExtractInternal(_, left, right))
}
}

private[this] def canExtractInternal(expr: BaseExprMeta[_], left: Seq[Attribute],
right: Seq[Attribute]): Boolean = {
if (!expr.canSelfBeAst) {
// It needs to be split since not ast-able. Check itself and childerns to ensure
// pushing-down can be made, which doesn't need attributions from both sides.
val exprRef = expr.wrapped.asInstanceOf[Expression]
val leftTree = exprRef.references.exists(left.contains(_))
val rightTree = exprRef.references.exists(right.contains(_))
// Can't extract a condition involving columns from both sides
!(rightTree && leftTree)
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(canExtractInternal(_, left, right))
}
}

/**
* Extract non-AST functions from join conditions and update the original join condition. Based
* on the attributes, it decides which side the split condition belongs to. The replaced
* condition is wrapped with GpuAlias with new intermediate attributes.
*
winningsix marked this conversation as resolved.
Show resolved Hide resolved
* @param condition to be split if needed
* @param left attributions from left child
* @param right attributions from right child
* @param skipCheck whether skip split-able check
* @return a tuple of [[Expression]] for remained expressions, List of [[NamedExpression]] for
* left child if any, List of [[NamedExpression]] for right child if any
*/
def extractNonAstFromJoinCond(condition: Option[BaseExprMeta[_]],
left: AttributeSeq, right: AttributeSeq, skipCheck: Boolean):
(Option[Expression], List[NamedExpression], List[NamedExpression]) = {
// Choose side with smaller key size. Use expr ID to check the side which project expr
// belonging to.
val (exprIds, isLeft) = if (left.attrs.size < right.attrs.size) {
(left.attrs.map(_.exprId), true)
} else {
(right.attrs.map(_.exprId), false)
}
// List of expression pushing down to left side child
val leftExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// List of expression pushing down to right side child
val rightExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// Substitution map used to replace targeted expressions based on semantic equality
val substitutionMap = mutable.HashMap.empty[GpuExpressionEquals, Expression]

// 1st step to construct 1) left expr list; 2) right expr list; 3) substitutionMap
// No need to consider common sub-expressions here since project node will use tiered execution
condition.foreach(c =>
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs, right.attrs)) {
splitNonAstInternal(c, exprIds, leftExprs, rightExprs, substitutionMap, isLeft)
})

// 2nd step to replace expression pushing down to child plans in depth first fashion
(condition.map(
_.convertToGpu().mapChildren(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be fine, but I am not 100% sure that convertToGpu will be idempotent. It might be better to just convert it to the GPU up front and cache it instead of recalling it inside of splitNonAstInternal. To make that work we probably would have to zip the GPU expression and along with condition.childExprs. I don't think this is a blocker for the patch to go in. Just being overly cautious about it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it seems not very straight forward to eliminate non-idempotent risk here. convertToGpu only happens when we need to replace it with a new GpuAlias node. We can probably cache those converted expression but in the end (L100), the root BaseExprMeta will do a real convertToGpu which converts the entire tree. It seems we can hardly reuse the cache here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for now, so lets file a follow on issue to look into it. I see a few ways to fix it, but none of them are that simple, and may have their own hidden pitfalls.

GpuEquivalentExpressions.replaceWithSemanticCommonRef(_,
substitutionMap))), leftExprs.toList, rightExprs.toList)
}

private[this] def splitNonAstInternal(condition: BaseExprMeta[_], childAtt: Seq[ExprId],
left: ListBuffer[NamedExpression], right: ListBuffer[NamedExpression],
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression], isLeft: Boolean): Unit = {
for (child <- condition.childExprs) {
if (!child.canSelfBeAst) {
val exprRef = child.wrapped.asInstanceOf[Expression]
val gpuProj = child.convertToGpu()
val alias = substitutionMap.get(GpuExpressionEquals(gpuProj)) match {
case Some(_) => None
case None =>
if (exprRef.references.exists(r => childAtt.contains(r.exprId)) ^ isLeft) {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_r_${left.size}")()
right += alias
Some(alias)
} else {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_l_${left.size}")()
left += alias
Some(alias)
}
}
alias.foreach(a => substitutionMap.put(GpuExpressionEquals(gpuProj), a.toAttribute))
} else {
splitNonAstInternal(child, childAtt, left, right, substitutionMap, isLeft)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,15 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
* Check whether this node itself can be converted to AST. It will not recursively check its
* children. It's used to check join condition AST-ability in top-down fashion.
*/
lazy val canSelfBeAst = {
tagForAst()
cannotBeAstReasons.isEmpty
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,22 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
* Recursively replaces semantic equal expression with its proxy expression in `substitutionMap`.
*/
def replaceWithSemanticCommonRef(
expr: Expression,
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
expr match {
case e: AttributeReference => e
case _ =>
substitutionMap.get(GpuExpressionEquals(expr)) match {
case Some(attr) => attr
case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap))
}
}
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
Expand Down
Loading