From eb1efdfa28d55d7f4b85d1b96bbb8433e23a676e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 2 Nov 2020 12:16:32 -0700 Subject: [PATCH] Remove isGpuBroadcastNestedLoopJoin from shims (#1053) Signed-off-by: Andy Grove --- .../nvidia/spark/rapids/shims/spark300/Spark300Shims.scala | 7 ------- .../main/scala/com/nvidia/spark/rapids/SparkShims.scala | 1 - .../nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala | 5 +++-- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index cd90309c93d..d032f53e3b3 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -118,13 +118,6 @@ class Spark300Shims extends SparkShims { } } - override def isGpuBroadcastNestedLoopJoin(plan: SparkPlan): Boolean = { - plan match { - case _: GpuBroadcastNestedLoopJoinExecBase => true - case _ => false - } - } - override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuShuffledHashJoinExec => true diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 142253df6f3..846b4fde362 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -62,7 +62,6 @@ trait SparkShims { def getSparkShimVersion: ShimVersion def isGpuHashJoin(plan: SparkPlan): Boolean def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean - def isGpuBroadcastNestedLoopJoin(plan: SparkPlan): Boolean def isGpuShuffledHashJoin(plan: SparkPlan): Boolean def isBroadcastExchangeLike(plan: SparkPlan): Boolean def isShuffleExchangeLike(plan: SparkPlan): Boolean diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala index 5fc4ee0888d..010083ec677 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala @@ -21,6 +21,7 @@ import com.nvidia.spark.rapids.TestUtils.findOperators import org.apache.spark.SparkConf import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase class BroadcastNestedLoopJoinSuite extends SparkQueryCompareTestSuite { @@ -36,7 +37,7 @@ class BroadcastNestedLoopJoinSuite extends SparkQueryCompareTestSuite { df3.collect() val plan = df3.queryExecution.executedPlan - val nljCount = findOperators(plan, ShimLoader.getSparkShims.isGpuBroadcastNestedLoopJoin) + val nljCount = findOperators(plan, _.isInstanceOf[GpuBroadcastNestedLoopJoinExecBase]) assert(nljCount.size === 1) }, conf) } @@ -53,7 +54,7 @@ class BroadcastNestedLoopJoinSuite extends SparkQueryCompareTestSuite { df3.collect() val plan = df3.queryExecution.executedPlan - val nljCount = findOperators(plan, ShimLoader.getSparkShims.isGpuBroadcastNestedLoopJoin) + val nljCount = findOperators(plan, _.isInstanceOf[GpuBroadcastNestedLoopJoinExecBase]) ShimLoader.getSparkShims.getSparkShimVersion match { case SparkShimVersion(3, 0, 0) =>