diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 94289193485..e8414230b1b 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -15,7 +15,7 @@
 import pytest
 from pyspark.sql.functions import broadcast
 from pyspark.sql.types import *
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
 from conftest import is_databricks_runtime, is_emr_runtime
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
@@ -361,7 +361,7 @@
 def test_right_broadcast_nested_loop_join_with_ast_condition(data_gen, join_type, batch_size):
     def do_join(spark):
         left, right = create_df(spark, data_gen, 50, 25)
-        # This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294 
+        # This test is impacted by https://github.com/NVIDIA/spark-rapids/issues/294
         # if the sizes are large enough to have both 0.0 and -0.0 show up 500 and 250
         # but these take a long time to verify so we run with smaller numbers by default
         # that do not expose the error
@@ -651,7 +651,7 @@
     def do_join(spark):
         if (cache_side == 'cache_left'):
             # Try to force the shuffle to be split between CPU and GPU for the join
-            # by default if the operation after the shuffle is not on the GPU then 
+            # by default if the operation after the shuffle is not on the GPU then
             # don't do a GPU shuffle, so do something simple after the repartition
             # to make sure that the GPU shuffle is used.
             left = left.repartition('a').selectExpr('b + 1 as b', 'a').cache()
@@ -659,7 +659,7 @@
         else:
             #cache_right
             # Try to force the shuffle to be split between CPU and GPU for the join
-            # by default if the operation after the shuffle is not on the GPU then 
+            # by default if the operation after the shuffle is not on the GPU then
            # don't do a GPU shuffle, so do something simple after the repartition
             # to make sure that the GPU shuffle is used.
             right = right.repartition('r_a').selectExpr('c + 1 as c', 'r_a').cache()
@@ -785,3 +785,37 @@ def do_join(spark):
         return spark.sql("select a.* from {} a, {} b where a.name=b.name".format(
             resultdf_name, resultdf_name))
     assert_gpu_and_cpu_are_equal_collect(do_join)
+
+# ExistenceJoin occurs in the context of existential subqueries (which are normally rewritten to
+# a SemiJoin) if there is an additional condition that may qualify left records even though
+# they have no join partner records on the right.
+#
+# Such a query is rewritten roughly as a LeftOuter join with an additional Boolean column
+# "exists", which feeds into a filter "exists OR someOtherPredicate".
+# If the extra condition is ANDed in instead, the result is a subset of a SemiJoin,
+# and the optimizer won't use ExistenceJoin.
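+#
+# For example, the query issued by test_existence_join below,
+#   select * from l where l._1 < 0 OR l._1 in (select * from r),
+# is expected to plan as a join with join type ExistenceJoin(exists#N) feeding a
+# Filter over (l._1 < 0 OR exists#N), which is the plan shape the capture regex asserts.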
+@ignore_order(local=True)
+@pytest.mark.parametrize(
+    "allowFallback", [
+        pytest.param('true',
+            marks=pytest.mark.allow_non_gpu('SortMergeJoinExec')),
+        pytest.param('false',
+            marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/589"))
+    ], ids=idfn
+)
+def test_existence_join(allowFallback, spark_tmp_table_factory):
+    leftTable = spark_tmp_table_factory.get()
+    rightTable = spark_tmp_table_factory.get()
+    def do_join(spark):
+        # create partially overlapping ranges to get a mix of exists=true and exists=false
+        spark.createDataFrame([v] for v in range(2, 10)).createOrReplaceTempView(leftTable)
+        spark.createDataFrame([v] for v in range(0, 8)).createOrReplaceTempView(rightTable)
+        res = spark.sql((
+            "select * "
+            "from {} as l "
+            "where l._1 < 0 "
+            " OR l._1 in (select * from {} as r)"
+        ).format(leftTable, rightTable))
+        return res
+    assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, r".+Join ExistenceJoin\(exists#[0-9]+\).+")
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
index 93ce85ce3e1..16ff8eb02f2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -21,7 +21,9 @@
 import java.util.Properties
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.{Map => MutableMap}
 import scala.util.Try
+import scala.util.matching.Regex
 
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
@@ -391,25 +393,33 @@
     executedPlan.expressions.exists(didFallBack(_, fallbackCpuClass))
   }
 
-  private def containsExpression(exp: Expression, className: String): Boolean = exp.find {
+  private def containsExpression(exp: Expression, className: String,
+      regexMap: MutableMap[String, Regex] // regex memoization
+  ): Boolean = exp.find {
     case e if PlanUtils.getBaseNameFromClass(e.getClass.getName) == className => true
-    case e: ExecSubqueryExpression => containsPlan(e.plan, className)
+    case e: ExecSubqueryExpression => containsPlan(e.plan, className, regexMap)
     case _ => false
   }.nonEmpty
 
-  private def containsPlan(plan: SparkPlan, className: String): Boolean = plan.find {
+  private def containsPlan(plan: SparkPlan, className: String,
+      regexMap: MutableMap[String, Regex] = MutableMap.empty // regex memoization
+  ): Boolean = plan.find {
     case p if PlanUtils.sameClass(p, className) => true
     case p: AdaptiveSparkPlanExec =>
-      containsPlan(p.executedPlan, className)
+      containsPlan(p.executedPlan, className, regexMap)
     case p: QueryStageExec =>
-      containsPlan(p.plan, className)
+      containsPlan(p.plan, className, regexMap)
     case p: ReusedSubqueryExec =>
-      containsPlan(p.child, className)
+      containsPlan(p.child, className, regexMap)
    case p: ReusedExchangeExec =>
-      containsPlan(p.child, className)
-    case p =>
-      p.expressions.exists(containsExpression(_, className))
+      containsPlan(p.child, className, regexMap)
+    case p if p.expressions.exists(containsExpression(_, className, regexMap)) =>
+      true
+    case p: SparkPlan =>
+      regexMap.getOrElseUpdate(className, className.r)
+        .findFirstIn(p.simpleStringWithNodeId())
+        .nonEmpty
   }.nonEmpty
 }
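The new fallthrough case in containsPlan treats className as a regular expression and matches it
against each plan node's string form, memoizing compiled patterns in regexMap so a pattern such as
".+Join ExistenceJoin\(exists#[0-9]+\).+" is compiled at most once per plan traversal. The
following standalone Scala sketch shows the same getOrElseUpdate memoization pattern in isolation;
the object and method names and the sample node string are made up for illustration and are not
part of the patch.

    import scala.collection.mutable.{Map => MutableMap}
    import scala.util.matching.Regex

    object RegexMemoSketch {
      // Compiled patterns are cached so each pattern string is compiled at most once.
      private val regexMap: MutableMap[String, Regex] = MutableMap.empty

      // True if `pattern`, treated as a regex, matches anywhere in `nodeString`.
      def planNodeMatches(pattern: String, nodeString: String): Boolean =
        regexMap.getOrElseUpdate(pattern, pattern.r)
          .findFirstIn(nodeString)
          .nonEmpty

      def main(args: Array[String]): Unit = {
        // Hypothetical node string standing in for SparkPlan.simpleStringWithNodeId() output
        val node = "SortMergeJoin ExistenceJoin(exists#14), [id=#21]"
        println(planNodeMatches(raw".+Join ExistenceJoin\(exists#[0-9]+\).+", node)) // prints true
      }
    }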