From 8e835e1912e95a9410b48172ee60d9f4157b643b Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 26 Jul 2024 10:17:12 -0500 Subject: [PATCH] Fix ArrayIndexOutOfBoundsException on join counts with constant join keys (#11244) * Fix ArrayIndexOutOfBoundsException on join counts with constant join keys Signed-off-by: Jason Lowe * Handle GpuAlias --------- Signed-off-by: Jason Lowe --- integration_tests/src/main/python/join_test.py | 14 +++++++++++--- .../GpuBroadcastNestedLoopJoinExecBase.scala | 14 +++++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 35e837cb436..250159e1bb3 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -14,10 +14,10 @@ import pytest from _pytest.mark.structures import ParameterSet -from pyspark.sql.functions import array_contains, broadcast, col +from pyspark.sql.functions import array_contains, broadcast, col, lit from pyspark.sql.types import * -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture -from conftest import is_databricks_runtime, is_emr_runtime, is_not_utc +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture +from conftest import is_emr_runtime from data_gen import * from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime @@ -164,6 +164,14 @@ def do_join(spark): return left.join(right.hint("broadcast"), left.a == right.r_a, join_type) assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.adaptive.enabled': 'true'}) +@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn) +def test_broadcast_hash_join_constant_keys(join_type): + def do_join(spark): + left = spark.range(10).withColumn("s", lit(1)) + right = spark.range(10000).withColumn("r_s", lit(1)) + return left.join(right.hint("broadcast"), left.s == right.r_s, join_type) + assert_gpu_and_cpu_row_counts_equal(do_join, conf={'spark.sql.adaptive.enabled': 'true'}) + # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84 # After 3.1.0 is the min spark version we can drop this diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala index 47bcec60674..7ddc969d919 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala @@ -558,6 +558,16 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( broadcastExchange.executeColumnarBroadcast[Any]() } + private def isUnconditionalJoin(condition: Option[GpuExpression]): Boolean = { + condition.forall { + case GpuLiteral(true, BooleanType) => + // Spark can generate a degenerate conditional join when the join keys are constants + output.isEmpty + case GpuAlias(e: GpuExpression, _) => isUnconditionalJoin(Some(e)) + case _ => false + } + } + override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = { // Determine which table will be first in the join and bind the references accordingly // so the AST column references match the appropriate table. @@ -583,7 +593,9 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( if (useTrueCondition) Some(GpuLiteral(true)) else None } - if (joinCondition.isEmpty) { + // Sometimes Spark specifies a true condition for a row-count-only join. + // This can happen when the join keys are detected to be constant. + if (isUnconditionalJoin(joinCondition)) { doUnconditionalJoin(broadcastRelation) } else { doConditionalJoin(broadcastRelation, joinCondition, numFirstTableColumns)