diff --git a/api_validation/pom.xml b/api_validation/pom.xml index 0062e481c4e..47f9604ff76 100644 --- a/api_validation/pom.xml +++ b/api_validation/pom.xml @@ -46,6 +46,12 @@ ${spark311.version} + + spark320 + + ${spark320.version} + + diff --git a/docs/additional-functionality/rapids-shuffle.md b/docs/additional-functionality/rapids-shuffle.md index 41c8cdbd746..85cf8bc02e2 100644 --- a/docs/additional-functionality/rapids-shuffle.md +++ b/docs/additional-functionality/rapids-shuffle.md @@ -258,6 +258,7 @@ In this section, we are using a docker container built using the sample dockerfi | 3.0.1 EMR | com.nvidia.spark.rapids.spark301emr.RapidsShuffleManager | | 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager | | 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager | + | 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager | 2. Recommended settings for UCX 1.9.0+ ```shell diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index 8550e38cd59..b7925a7b582 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -28,36 +28,6 @@ rapids-4-spark-integration-tests_2.12 0.5.0-SNAPSHOT - - ${spark300.version} - - - - spark301dbtests - - ${spark301db.version} - - - - spark301tests - - ${spark301.version} - - - - spark302tests - - ${spark302.version} - - - - spark311tests - - ${spark311.version} - - - - org.slf4j diff --git a/pom.xml b/pom.xml index bf84eba3ebb..53118bd5dc8 100644 --- a/pom.xml +++ b/pom.xml @@ -131,11 +131,39 @@ true + + + spark301dbtests + + ${spark301db.version} + + spark301tests + + ${spark301.version} + + + + spark302tests + + ${spark302.version} + spark311tests + + ${spark311.version} + + + tests-spark310+ + + + + spark320tests + + ${spark320.version} + tests-spark310+ @@ -146,6 +174,7 @@ 1.8 1.8 ${spark300.version} + ${spark300.version} cuda10-1 0.19-SNAPSHOT 2.12 @@ -177,6 +206,7 @@ 3.0.1-databricks 3.0.2-SNAPSHOT 3.1.1-SNAPSHOT + 3.2.0-SNAPSHOT 3.6.0 4.3.0 3.2.0 diff --git a/shims/aggregator/pom.xml b/shims/aggregator/pom.xml index ac020ae4e28..d365ea89d08 100644 --- a/shims/aggregator/pom.xml +++ b/shims/aggregator/pom.xml @@ -62,6 +62,12 @@ true + + com.nvidia + rapids-4-spark-shims-spark320_${scala.binary.version} + ${project.version} + compile + com.nvidia rapids-4-spark-shims-spark311_${scala.binary.version} diff --git a/shims/pom.xml b/shims/pom.xml index f14361abc55..d32c1e610f3 100644 --- a/shims/pom.xml +++ b/shims/pom.xml @@ -47,6 +47,7 @@ spark302 spark311 + spark320 @@ -71,6 +72,11 @@ ${cuda.version} provided + + org.scalatest + scalatest_${scala.binary.version} + test + @@ -78,6 +84,10 @@ net.alchim31.maven scala-maven-plugin + + org.scalatest + scalatest-maven-plugin + diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index 180accf98ae..96cbbbebda4 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -32,11 +32,13 @@ import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec} @@ -44,7 +46,7 @@ import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, Fil import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.execution.python.WindowInPandasExec @@ -584,4 +586,18 @@ class Spark300Shims extends SparkShims { } recurse(plan, predicate, new ListBuffer[SparkPlan]()) } + + override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = { + case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e + } + + /** dropped by SPARK-34234 */ + override def attachTreeIfSupported[TreeType <: TreeNode[_], A]( + tree: TreeType, + msg: String)( + f: => A + ): A = { + attachTree(tree, msg)(f) + } } diff --git a/shims/spark301db/pom.xml b/shims/spark301db/pom.xml index ef41d7e6f1d..2364123e5b7 100644 --- a/shims/spark301db/pom.xml +++ b/shims/spark301db/pom.xml @@ -73,7 +73,6 @@ 1.10.1 - 3.0.1-databricks 0.15.1 diff --git a/shims/spark320/pom.xml b/shims/spark320/pom.xml new file mode 100644 index 00000000000..fd893e82315 --- /dev/null +++ b/shims/spark320/pom.xml @@ -0,0 +1,92 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-shims_2.12 + 0.5.0-SNAPSHOT + ../pom.xml + + com.nvidia + rapids-4-spark-shims-spark320_2.12 + RAPIDS Accelerator for Apache Spark SQL Plugin Spark 3.2.0 Shim + The RAPIDS SQL plugin for Apache Spark 3.2.0 Shim + 0.5.0-SNAPSHOT + + + + + + + maven-antrun-plugin + + + dependency + generate-resources + + + + + + + + + + + + + run + + + + + + org.scalastyle + scalastyle-maven-plugin + + + + + + + ${project.build.directory}/extra-resources + + + src/main/resources + + + + + + + com.nvidia + rapids-4-spark-shims-spark311_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark320.version} + provided + + + diff --git a/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider b/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider new file mode 100644 index 00000000000..f6e343b6bfe --- /dev/null +++ b/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider @@ -0,0 +1 @@ +com.nvidia.spark.rapids.shims.spark320.SparkShimServiceProvider diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala 
b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala new file mode 100644 index 00000000000..ceddf82f741 --- /dev/null +++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.spark320 + +import com.nvidia.spark.rapids.ShimVersion +import com.nvidia.spark.rapids.shims.spark311.Spark311Shims +import com.nvidia.spark.rapids.spark320.RapidsShuffleManager + +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + +class Spark320Shims extends Spark311Shims { + + override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION320 + + override def getRapidsShuffleManagerClass: String = { + classOf[RapidsShuffleManager].getCanonicalName + } + + /** + * Case class ShuffleQueryStageExec holds an additional field shuffleOrigin + * affecting the unapply method signature + */ + override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = { + case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _) => e + } + + /** dropped by SPARK-34234 */ + override def attachTreeIfSupported[TreeType <: TreeNode[_], A]( + tree: TreeType, + msg: String)( + f: => A + ): A = { + identity(f) + } +} diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala new file mode 100644 index 00000000000..f451f0e8679 --- /dev/null +++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.spark320 + +import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} + +object SparkShimServiceProvider { + val VERSION320 = SparkShimVersion(3, 2, 0) + val VERSIONNAMES: Seq[String] = Seq(VERSION320) + .flatMap(v => Seq(s"$v", s"$v-SNAPSHOT")) +} + +class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + + def matchesVersion(version: String): Boolean = { + SparkShimServiceProvider.VERSIONNAMES.contains(version) + } + + def buildShim: SparkShims = { + new Spark320Shims() + } +} diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala new file mode 100644 index 00000000000..4c6b0551db0 --- /dev/null +++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.spark320 + +import org.apache.spark.SparkConf +import org.apache.spark.sql.rapids.shims.spark311.RapidsShuffleInternalManager + +/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */ +sealed class RapidsShuffleManager( + conf: SparkConf, + isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) { +} diff --git a/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala new file mode 100644 index 00000000000..bdc363c986d --- /dev/null +++ b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.spark320; + +import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} + +import org.scalatest.FunSuite; + +class Spark320ShimsSuite extends FunSuite { + val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim + test("spark shims version") { + assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 2, 0)) + } + + test("shuffle manager class") { + assert(sparkShims.getRapidsShuffleManagerClass === + classOf[com.nvidia.spark.rapids.spark320.RapidsShuffleManager].getCanonicalName) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 46fc9c794c4..ac48d7740a4 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -107,23 +107,17 @@ public static synchronized void debug(String name, HostColumnVector hostCol) { hexString(hostCol.getUTF8(i))); } } - } else if (DType.INT32.equals(type)) { - for (int i = 0; i < hostCol.getRowCount(); i++) { - if (hostCol.isNull(i)) { - System.err.println(i + " NULL"); - } else { - System.err.println(i + " " + hostCol.getInt(i)); - } - } - } else if (DType.INT8.equals(type)) { - for (int i = 0; i < hostCol.getRowCount(); i++) { - if (hostCol.isNull(i)) { - System.err.println(i + " NULL"); - } else { - System.err.println(i + " " + hostCol.getByte(i)); - } - } - } else if (DType.BOOL8.equals(type)) { + } else if (DType.INT32.equals(type) + || DType.INT8.equals(type) + || DType.INT16.equals(type) + || DType.INT64.equals(type) + || DType.TIMESTAMP_DAYS.equals(type) + || DType.TIMESTAMP_SECONDS.equals(type) + || DType.TIMESTAMP_MICROSECONDS.equals(type) + || DType.TIMESTAMP_MILLISECONDS.equals(type) + || DType.TIMESTAMP_NANOSECONDS.equals(type)) { + debugInteger(hostCol, type); + } else if (DType.BOOL8.equals(type)) { for (int i = 0; i < hostCol.getRowCount(); i++) { if (hostCol.isNull(i)) { System.err.println(i + " NULL"); @@ -131,20 +125,39 @@ public static synchronized void debug(String name, HostColumnVector hostCol) { System.err.println(i + " " + hostCol.getBoolean(i)); } } - } else if (DType.TIMESTAMP_MICROSECONDS.equals(type) || - DType.INT64.equals(type)) { - for (int i = 0; i < hostCol.getRowCount(); i++) { - if (hostCol.isNull(i)) { - System.err.println(i + " NULL"); - } else { - System.err.println(i + " " + hostCol.getLong(i)); - } - } } else { System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT"); } } + private static void debugInteger(HostColumnVector hostCol, DType intType) { + for (int i = 0; i < hostCol.getRowCount(); i++) { + if (hostCol.isNull(i)) { + System.err.println(i + " NULL"); + } else { + final int sizeInBytes = intType.getSizeInBytes(); + final Object value; + switch (sizeInBytes) { + case Byte.BYTES: + value = hostCol.getByte(i); + break; + case Short.BYTES: + value = hostCol.getShort(i); + break; + case Integer.BYTES: + value = hostCol.getInt(i); + break; + case Long.BYTES: + value = hostCol.getLong(i); + break; + default: + throw new IllegalArgumentException("INFEASIBLE: Unsupported integer-like type " + intType); + } + System.err.println(i + " " + value); + } + } + } + private static HostColumnVector.DataType convertFrom(DataType spark, boolean nullable) { if (spark instanceof ArrayType) { ArrayType arrayType = (ArrayType) spark; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 2e3501462e9..37a093c746b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1857,7 +1857,7 @@ object GpuOverrides { } } - override def convertToGpu(child: Expression): GpuExpression = GpuSum(child) + override def convertToGpu(child: Expression): GpuExpression = GpuSum(child, a.dataType) }), expr[Average]( "Average aggregate operator", diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 07ee8a42786..1ba4a055956 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -30,11 +30,12 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExprId, Nul import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase} @@ -191,4 +192,13 @@ trait SparkShims { def shouldFailDivByZero(): Boolean def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] + + def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] + + /** dropped by SPARK-34234 */ + def attachTreeIfSupported[TreeType <: TreeNode[_], A]( + tree: TreeType, + msg: String = "")( + f: => A + ): A } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 32da9614133..0fcd3b4bfce 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf -import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector} +import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector, DType} import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -192,10 +192,29 @@ class CudfCount(ref: Expression) extends CudfAggregate(ref) { } class CudfSum(ref: Expression) extends CudfAggregate(ref) { + // Up to 3.1.1, analyzed plan widened the input column type before applying + // aggregation. 
Thus even though we did not explicitly pass the output column type + // we did not run into integer overflow issues: + // + // == Analyzed Logical Plan == + // sum(shorts): bigint + // Aggregate [sum(cast(shorts#77 as bigint)) AS sum(shorts)#94L] + // + // In Spark's main branch (3.2.0-SNAPSHOT as of this comment), analyzed logical plan + // no longer applies the cast to the input column such that the output column type has to + // be passed explicitly into aggregation + // + // == Analyzed Logical Plan == + // sum(shorts): bigint + // Aggregate [sum(shorts#33) AS sum(shorts)#50L] + // + @transient val rapidsSumType: DType = GpuColumnVector.getNonNestedRapidsType(ref.dataType) + override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = - (col: cudf.ColumnVector) => col.sum - override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = - (col: cudf.ColumnVector) => col.sum + (col: cudf.ColumnVector) => col.sum(rapidsSumType) + + override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = updateReductionAggregate + override lazy val updateAggregate: Aggregation = Aggregation.sum() override lazy val mergeAggregate: Aggregation = Aggregation.sum() override def toString(): String = "CudfSum" @@ -329,12 +348,8 @@ case class GpuMax(child: Expression) extends GpuDeclarativeAggregate Aggregation.max().onColumn(inputs.head._2) } -case class GpuSum(child: Expression) +case class GpuSum(child: Expression, resultType: DataType) extends GpuDeclarativeAggregate with ImplicitCastInputTypes with GpuAggregateWindowFunction { - private lazy val resultType = child.dataType match { - case _: DoubleType => DoubleType - case _ => LongType - } private lazy val cudfSum = AttributeReference("sum", resultType)() diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala index d43c433c577..d865d52d3b8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala @@ -26,8 +26,7 @@ import org.apache.spark.{MapOutputStatistics, ShuffleDependency} import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -142,13 +141,14 @@ abstract class GpuShuffleExchangeExecBase( protected override def doExecute(): RDD[InternalRow] = throw new IllegalStateException(s"Row-based execution should not occur for $this") - override def doExecuteColumnar(): RDD[ColumnarBatch] = attachTree(this, "execute") { - // Returns the same ShuffleRowRDD if this plan is used by multiple plans. - if (cachedShuffleRDD == null) { - cachedShuffleRDD = new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics) + override def doExecuteColumnar(): RDD[ColumnarBatch] = ShimLoader.getSparkShims + .attachTreeIfSupported(this, "execute") { + // Returns the same ShuffleRowRDD if this plan is used by multiple plans. 
+ if (cachedShuffleRDD == null) { + cachedShuffleRDD = new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics) + } + cachedShuffleRDD } - cachedShuffleRDD - } } object GpuShuffleExchangeExec { diff --git a/tests-spark310+/pom.xml b/tests-spark310+/pom.xml index 82e651bb044..e05a532377f 100644 --- a/tests-spark310+/pom.xml +++ b/tests-spark310+/pom.xml @@ -28,10 +28,6 @@ rapids-4-spark-tests-next-spark_2.12 0.5.0-SNAPSHOT - - ${spark311.version} - - org.apache.spark diff --git a/tests/pom.xml b/tests/pom.xml index 43a836a0905..0bd1229e09a 100644 --- a/tests/pom.xml +++ b/tests/pom.xml @@ -30,36 +30,6 @@ RAPIDS plugin for Apache Spark integration tests 0.5.0-SNAPSHOT - - ${spark300.version} - - - - spark301dbtests - - ${spark301db.version} - - - - spark301tests - - ${spark301.version} - - - - spark302tests - - ${spark302.version} - - - - spark311tests - - ${spark311.version} - - - - org.slf4j diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala index 4e0a50b5849..5de8e891450 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkConf import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan} -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, ShuffleQueryStageExec} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -94,10 +94,7 @@ class AdaptiveQueryExecSuite } private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { - collectWithSubqueries(plan) { - case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e - case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e - } + collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn) } test("skewed inner join optimization") {
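
Note on the `reusedExchangeExecPfn` abstraction introduced above: Spark 3.2.0 adds a field to the query-stage case classes, which changes the arity of their `unapply`, so the pattern match that finds `ReusedExchangeExec` nodes can no longer be written once for all supported versions. The PR therefore moves that match behind a partial function owned by each shim. The sketch below illustrates only the dispatch pattern; `Node`, `Reused`, `StageV1` and `StageV2` are hypothetical stand-ins for the Spark classes so the example stays self-contained and runnable without a Spark dependency.

```scala
// Minimal sketch of the version-dispatch pattern behind reusedExchangeExecPfn.
// StageV1/StageV2 are hypothetical stand-ins for the query-stage case classes,
// whose unapply arity differs between Spark 3.1.x and 3.2.0.
sealed trait Node
final case class Reused(name: String) extends Node
final case class StageV1(id: Int, child: Node) extends Node                  // pre-3.2.0 shape
final case class StageV2(id: Int, child: Node, origin: String) extends Node  // 3.2.0 shape

trait Shims {
  // Each shim owns the pattern match that depends on the case-class arity.
  def reusedPfn: PartialFunction[Node, Reused]
}

object OldShims extends Shims {
  override def reusedPfn: PartialFunction[Node, Reused] = {
    case StageV1(_, r: Reused) => r
  }
}

object NewShims extends Shims {
  override def reusedPfn: PartialFunction[Node, Reused] = {
    case StageV2(_, r: Reused, _) => r
  }
}

object ReusedExchangeDemo {
  // Version-agnostic caller: collects reused nodes without knowing the arity.
  def findReused(plan: Seq[Node], shims: Shims): Seq[Reused] =
    plan.collect(shims.reusedPfn)

  def main(args: Array[String]): Unit = {
    val oldPlan = Seq(StageV1(1, Reused("bcast")), Reused("ignored"))
    val newPlan = Seq(StageV2(1, Reused("bcast"), "ENSURE_REQUIREMENTS"))
    println(findReused(oldPlan, OldShims)) // List(Reused(bcast))
    println(findReused(newPlan, NewShims)) // List(Reused(bcast))
  }
}
```

This is the same shape the test change above relies on: `AdaptiveQueryExecSuite` now calls `collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn)` instead of matching the case classes directly.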
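
Similarly, `attachTreeIfSupported` exists because SPARK-34234 removed `org.apache.spark.sql.catalyst.errors.attachTree`: the pre-3.2.0 shims delegate to the real helper, while the 3.2.0 shim just evaluates the body, and `GpuShuffleExchangeExecBase.doExecuteColumnar` goes through the shim instead of calling `attachTree` directly. A simplified sketch of that by-name wrapper pattern, with no Spark types and a stand-in for the removed helper:

```scala
// Sketch of the by-name wrapper behind attachTreeIfSupported (simplified, no Spark types).
// attachTreeCompat stands in for the helper dropped by SPARK-34234.
object AttachTreeDemo {
  // Pre-3.2.0 behavior: wrap failures with extra context about the plan node.
  def attachTreeCompat[A](treeDescription: String, msg: String)(f: => A): A =
    try f catch {
      case e: Exception =>
        throw new RuntimeException(s"$msg failed for plan: $treeDescription", e)
    }

  // 3.2.0 behavior: the helper no longer exists, so the shim just evaluates the body.
  def attachTreeNoop[A](treeDescription: String, msg: String)(f: => A): A = f

  def main(args: Array[String]): Unit = {
    println(attachTreeCompat("GpuShuffleExchangeExec", "execute") { 40 + 2 }) // 42
    println(attachTreeNoop("GpuShuffleExchangeExec", "execute") { 40 + 2 })   // 42
  }
}
```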
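
Finally, the `GpuSum`/`CudfSum` change: up to Spark 3.1.1 the analyzed plan cast the sum's input to the widened result type, while the 3.2.0-SNAPSHOT plan no longer inserts that cast, so the PR threads Spark's result type through `GpuSum(child, resultType)` and passes it to cudf via `col.sum(rapidsSumType)`. The pure-Scala sketch below is only an illustration of why that explicit widening matters; it does not touch cudf or the plugin classes.

```scala
// Why the sum's result type must be widened explicitly: accumulating at the
// input width overflows, accumulating at the plan's bigint result type does not.
object SumWideningDemo {
  def main(args: Array[String]): Unit = {
    val shorts: Seq[Short] = Seq.fill(10)(20000.toShort)

    // Accumulating at the input width wraps around for SHORT.
    val narrow: Short = shorts.foldLeft(0.toShort)((acc, x) => (acc + x).toShort)

    // Accumulating at the widened result type (bigint/LongType in the plan) is correct.
    val wide: Long = shorts.foldLeft(0L)((acc, x) => acc + x)

    println(s"narrow (overflowed): $narrow") // 3392
    println(s"wide (correct):      $wide")   // 200000
  }
}
```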