diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuHashJoin.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuHashJoin.scala index 30bfa47353f..27d10c5c0b3 100644 --- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuHashJoin.scala +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuHashJoin.scala @@ -19,13 +19,10 @@ import ai.rapids.cudf.{NvtxColor, Table} import com.nvidia.spark.rapids._ import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.execution.joins.HashJoin +import org.apache.spark.sql.execution.joins.HashJoinWithoutCodegen import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -45,9 +42,7 @@ object GpuHashJoin { } } -trait GpuHashJoin extends GpuExec with HashJoin { - - override def supportCodegen: Boolean = false +trait GpuHashJoin extends GpuExec with HashJoinWithoutCodegen { override def output: Seq[Attribute] = { joinType match { @@ -247,13 +242,4 @@ trait GpuHashJoin extends GpuExec with HashJoin { joinedTable.close() } } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - throw new UnsupportedOperationException("inputRDDs is used by codegen which we don't support") - } - - protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = { - throw new UnsupportedOperationException( - "prepareRelation is used by codegen which we don't support") - } } diff --git a/shims/spark310/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinWithoutCodegen.scala b/shims/spark310/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinWithoutCodegen.scala new file mode 100644 index 00000000000..7a3b97755e6 --- /dev/null +++ b/shims/spark310/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinWithoutCodegen.scala @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext + +/** + * This trait is used to implement a hash join that does not support codegen. + * It is in the org.apache.spark.sql.execution.joins package to access the + * `HashedRelationInfo` class which is private to that package but needed by + * any class implementing `HashJoin`. + */ +trait HashJoinWithoutCodegen extends HashJoin { + override def supportCodegen: Boolean = false + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + throw new UnsupportedOperationException("inputRDDs is used by codegen which we don't support") + } + + protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = { + throw new UnsupportedOperationException( + "prepareRelation is used by codegen which is not supported for this join") + } +}