diff --git a/api_validation/pom.xml b/api_validation/pom.xml
index 0062e481c4e..47f9604ff76 100644
--- a/api_validation/pom.xml
+++ b/api_validation/pom.xml
@@ -46,6 +46,12 @@
             <properties>
                 <spark.version>${spark311.version}</spark.version>
             </properties>
         </profile>
+        <profile>
+            <id>spark320</id>
+            <properties>
+                <spark.version>${spark320.version}</spark.version>
+            </properties>
+        </profile>
     </profiles>
diff --git a/docs/additional-functionality/rapids-shuffle.md b/docs/additional-functionality/rapids-shuffle.md
index 41c8cdbd746..85cf8bc02e2 100644
--- a/docs/additional-functionality/rapids-shuffle.md
+++ b/docs/additional-functionality/rapids-shuffle.md
@@ -258,6 +258,7 @@ In this section, we are using a docker container built using the sample dockerfi
| 3.0.1 EMR | com.nvidia.spark.rapids.spark301emr.RapidsShuffleManager |
| 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager |
| 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager |
+ | 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager |
2. Recommended settings for UCX 1.9.0+
```shell
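
For reference, a minimal Scala sketch of selecting the new 3.2.0 shuffle manager from application code. The app name is hypothetical, and this assumes the plugin jars are deployed to the cluster as described earlier in the guide:

```scala
import org.apache.spark.sql.SparkSession

// Pick the RAPIDS SQL plugin and the shuffle manager class that matches the
// Spark version at runtime; spark.shuffle.manager must be set before the
// SparkContext starts.
val spark = SparkSession.builder()
  .appName("rapids-shuffle-sketch") // hypothetical name
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.shuffle.manager",
    "com.nvidia.spark.rapids.spark320.RapidsShuffleManager")
  .getOrCreate()
```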
diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index 8550e38cd59..b7925a7b582 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -28,36 +28,6 @@
     <artifactId>rapids-4-spark-integration-tests_2.12</artifactId>
     <version>0.5.0-SNAPSHOT</version>
 
-    <properties>
-        <spark.test.version>${spark300.version}</spark.test.version>
-    </properties>
-    <profiles>
-        <profile>
-            <id>spark301dbtests</id>
-            <properties>
-                <spark.test.version>${spark301db.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark301tests</id>
-            <properties>
-                <spark.test.version>${spark301.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark302tests</id>
-            <properties>
-                <spark.test.version>${spark302.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark311tests</id>
-            <properties>
-                <spark.test.version>${spark311.version}</spark.test.version>
-            </properties>
-        </profile>
-    </profiles>
-
     <dependencies>
         <dependency>
             <groupId>org.slf4j</groupId>
diff --git a/pom.xml b/pom.xml
index bf84eba3ebb..53118bd5dc8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -131,11 +131,39 @@
                 <activeByDefault>true</activeByDefault>
             </activation>
         </profile>
+        <profile>
+            <id>spark301dbtests</id>
+            <properties>
+                <spark.test.version>${spark301db.version}</spark.test.version>
+            </properties>
+        </profile>
         <profile>
             <id>spark301tests</id>
+            <properties>
+                <spark.test.version>${spark301.version}</spark.test.version>
+            </properties>
+        </profile>
+        <profile>
+            <id>spark302tests</id>
+            <properties>
+                <spark.test.version>${spark302.version}</spark.test.version>
+            </properties>
         </profile>
         <profile>
             <id>spark311tests</id>
+            <properties>
+                <spark.test.version>${spark311.version}</spark.test.version>
+            </properties>
+            <modules>
+                <module>tests-spark310+</module>
+            </modules>
+        </profile>
+        <profile>
+            <id>spark320tests</id>
+            <properties>
+                <spark.test.version>${spark320.version}</spark.test.version>
+            </properties>
             <modules>
                 <module>tests-spark310+</module>
             </modules>
         </profile>
@@ -146,6 +174,7 @@
         <maven.compiler.source>1.8</maven.compiler.source>
         <maven.compiler.target>1.8</maven.compiler.target>
         <spark.version>${spark300.version}</spark.version>
+        <spark.test.version>${spark300.version}</spark.test.version>
         <cuda.version>cuda10-1</cuda.version>
         <cudf.version>0.19-SNAPSHOT</cudf.version>
         <scala.binary.version>2.12</scala.binary.version>
@@ -177,6 +206,7 @@
         <spark301db.version>3.0.1-databricks</spark301db.version>
         <spark302.version>3.0.2-SNAPSHOT</spark302.version>
         <spark311.version>3.1.1-SNAPSHOT</spark311.version>
+        <spark320.version>3.2.0-SNAPSHOT</spark320.version>
         <mockito.version>3.6.0</mockito.version>
         <scala.plugin.version>4.3.0</scala.plugin.version>
         <maven.jar.plugin.version>3.2.0</maven.jar.plugin.version>
diff --git a/shims/aggregator/pom.xml b/shims/aggregator/pom.xml
index ac020ae4e28..d365ea89d08 100644
--- a/shims/aggregator/pom.xml
+++ b/shims/aggregator/pom.xml
@@ -62,6 +62,12 @@
                 <activeByDefault>true</activeByDefault>
             </activation>
             <dependencies>
+                <dependency>
+                    <groupId>com.nvidia</groupId>
+                    <artifactId>rapids-4-spark-shims-spark320_${scala.binary.version}</artifactId>
+                    <version>${project.version}</version>
+                    <scope>compile</scope>
+                </dependency>
                 <dependency>
                     <groupId>com.nvidia</groupId>
                     <artifactId>rapids-4-spark-shims-spark311_${scala.binary.version}</artifactId>
diff --git a/shims/pom.xml b/shims/pom.xml
index f14361abc55..d32c1e610f3 100644
--- a/shims/pom.xml
+++ b/shims/pom.xml
@@ -47,6 +47,7 @@
         <module>spark302</module>
         <module>spark311</module>
+        <module>spark320</module>
     </modules>
@@ -71,6 +72,11 @@
             <classifier>${cuda.version}</classifier>
             <scope>provided</scope>
         </dependency>
+        <dependency>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest_${scala.binary.version}</artifactId>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
@@ -78,6 +84,10 @@
             <plugin>
                 <groupId>net.alchim31.maven</groupId>
                 <artifactId>scala-maven-plugin</artifactId>
             </plugin>
+            <plugin>
+                <groupId>org.scalatest</groupId>
+                <artifactId>scalatest-maven-plugin</artifactId>
+            </plugin>
diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 180accf98ae..96cbbbebda4 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -32,11 +32,13 @@ import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
@@ -44,7 +46,7 @@ import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, Fil
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.WindowInPandasExec
@@ -584,4 +586,18 @@ class Spark300Shims extends SparkShims {
}
recurse(plan, predicate, new ListBuffer[SparkPlan]())
}
+
+ override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = {
+ case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
+ case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
+ }
+
+  /** attachTree was dropped by SPARK-34234 */
+ override def attachTreeIfSupported[TreeType <: TreeNode[_], A](
+ tree: TreeType,
+ msg: String)(
+ f: => A
+ ): A = {
+ attachTree(tree, msg)(f)
+ }
}
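
A sketch of how common code can consume the shim-provided partial function without hard-coding version-specific `unapply` signatures (`findReused` is an illustrative helper, not part of this PR; `TreeNode.collect` applies a partial function only to the nodes where it is defined):

```scala
import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

// The version-specific extractors stay inside the shim; callers only see a
// PartialFunction[SparkPlan, ReusedExchangeExec].
def findReused(plan: SparkPlan): Seq[ReusedExchangeExec] =
  plan.collect(ShimLoader.getSparkShims.reusedExchangeExecPfn)
```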
diff --git a/shims/spark301db/pom.xml b/shims/spark301db/pom.xml
index ef41d7e6f1d..2364123e5b7 100644
--- a/shims/spark301db/pom.xml
+++ b/shims/spark301db/pom.xml
@@ -73,7 +73,6 @@
         <parquet.version>1.10.1</parquet.version>
-        <spark301db.version>3.0.1-databricks</spark301db.version>
         <arrow.version>0.15.1</arrow.version>
diff --git a/shims/spark320/pom.xml b/shims/spark320/pom.xml
new file mode 100644
index 00000000000..fd893e82315
--- /dev/null
+++ b/shims/spark320/pom.xml
@@ -0,0 +1,92 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Copyright (c) 2021, NVIDIA CORPORATION.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-shims_2.12</artifactId>
+        <version>0.5.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <groupId>com.nvidia</groupId>
+    <artifactId>rapids-4-spark-shims-spark320_2.12</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark SQL Plugin Spark 3.2.0 Shim</name>
+    <description>The RAPIDS SQL plugin for Apache Spark 3.2.0 Shim</description>
+    <version>0.5.0-SNAPSHOT</version>
+
+    <build>
+        <plugins>
+            <plugin>
+                <artifactId>maven-antrun-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>dependency</id>
+                        <phase>generate-resources</phase>
+                        <configuration>
+                            <target>
+                            </target>
+                        </configuration>
+                        <goals>
+                            <goal>run</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.scalastyle</groupId>
+                <artifactId>scalastyle-maven-plugin</artifactId>
+            </plugin>
+        </plugins>
+        <resources>
+            <resource>
+                <directory>${project.build.directory}/extra-resources</directory>
+            </resource>
+            <resource>
+                <directory>src/main/resources</directory>
+            </resource>
+        </resources>
+    </build>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-shims-spark311_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark320.version}</version>
+            <scope>provided</scope>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider b/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
new file mode 100644
index 00000000000..f6e343b6bfe
--- /dev/null
+++ b/shims/spark320/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
@@ -0,0 +1 @@
+com.nvidia.spark.rapids.shims.spark320.SparkShimServiceProvider
diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala
new file mode 100644
index 00000000000..ceddf82f741
--- /dev/null
+++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark320
+
+import com.nvidia.spark.rapids.ShimVersion
+import com.nvidia.spark.rapids.shims.spark311.Spark311Shims
+import com.nvidia.spark.rapids.spark320.RapidsShuffleManager
+
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+
+class Spark320Shims extends Spark311Shims {
+
+ override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION320
+
+ override def getRapidsShuffleManagerClass: String = {
+ classOf[RapidsShuffleManager].getCanonicalName
+ }
+
+ /**
+   * In Spark 3.2.0 the ShuffleQueryStageExec case class gains an additional
+   * field, shuffleOrigin, which changes the unapply method signature.
+ */
+ override def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec] = {
+ case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
+ case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _) => e
+ }
+
+  /** attachTree was dropped by SPARK-34234 */
+ override def attachTreeIfSupported[TreeType <: TreeNode[_], A](
+ tree: TreeType,
+ msg: String)(
+ f: => A
+ ): A = {
+ identity(f)
+ }
+}
diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala
new file mode 100644
index 00000000000..f451f0e8679
--- /dev/null
+++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark320
+
+import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion}
+
+object SparkShimServiceProvider {
+ val VERSION320 = SparkShimVersion(3, 2, 0)
+ val VERSIONNAMES: Seq[String] = Seq(VERSION320)
+ .flatMap(v => Seq(s"$v", s"$v-SNAPSHOT"))
+}
+
+class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {
+
+ def matchesVersion(version: String): Boolean = {
+ SparkShimServiceProvider.VERSIONNAMES.contains(version)
+ }
+
+ def buildShim: SparkShims = {
+ new Spark320Shims()
+ }
+}
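
Assuming `SparkShimVersion`'s `toString` renders as `major.minor.patch`, `VERSIONNAMES` expands to `Seq("3.2.0", "3.2.0-SNAPSHOT")`, so the provider matches exactly those two version strings:

```scala
// Illustrative check of the matching logic above.
val provider = new SparkShimServiceProvider()
assert(provider.matchesVersion("3.2.0"))
assert(provider.matchesVersion("3.2.0-SNAPSHOT"))
assert(!provider.matchesVersion("3.1.1"))
```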
diff --git a/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala
new file mode 100644
index 00000000000..4c6b0551db0
--- /dev/null
+++ b/shims/spark320/src/main/scala/com/nvidia/spark/rapids/spark320/RapidsShuffleManager.scala
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.spark320
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.rapids.shims.spark311.RapidsShuffleInternalManager
+
+/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
+sealed class RapidsShuffleManager(
+ conf: SparkConf,
+ isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) {
+}
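
The `(SparkConf, Boolean)` constructor shape matters because Spark instantiates the configured shuffle manager reflectively. A simplified sketch of what Spark core does at startup (variable names are illustrative, not Spark's actual code):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.shuffle.manager",
  "com.nvidia.spark.rapids.spark320.RapidsShuffleManager")

// SparkEnv resolves the class by name and invokes the (SparkConf, Boolean)
// constructor, where the flag indicates whether this is the driver side.
val mgrClass = Class.forName(conf.get("spark.shuffle.manager"))
val ctor = mgrClass.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
val manager = ctor.newInstance(conf, java.lang.Boolean.TRUE)
```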
diff --git a/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
new file mode 100644
index 00000000000..bdc363c986d
--- /dev/null
+++ b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark320
+
+import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion}
+
+import org.scalatest.FunSuite
+
+class Spark320ShimsSuite extends FunSuite {
+ val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim
+ test("spark shims version") {
+ assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 2, 0))
+ }
+
+ test("shuffle manager class") {
+ assert(sparkShims.getRapidsShuffleManagerClass ===
+ classOf[com.nvidia.spark.rapids.spark320.RapidsShuffleManager].getCanonicalName)
+ }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index 46fc9c794c4..ac48d7740a4 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -107,23 +107,17 @@ public static synchronized void debug(String name, HostColumnVector hostCol) {
hexString(hostCol.getUTF8(i)));
}
}
- } else if (DType.INT32.equals(type)) {
- for (int i = 0; i < hostCol.getRowCount(); i++) {
- if (hostCol.isNull(i)) {
- System.err.println(i + " NULL");
- } else {
- System.err.println(i + " " + hostCol.getInt(i));
- }
- }
- } else if (DType.INT8.equals(type)) {
- for (int i = 0; i < hostCol.getRowCount(); i++) {
- if (hostCol.isNull(i)) {
- System.err.println(i + " NULL");
- } else {
- System.err.println(i + " " + hostCol.getByte(i));
- }
- }
- } else if (DType.BOOL8.equals(type)) {
+ } else if (DType.INT32.equals(type)
+ || DType.INT8.equals(type)
+ || DType.INT16.equals(type)
+ || DType.INT64.equals(type)
+ || DType.TIMESTAMP_DAYS.equals(type)
+ || DType.TIMESTAMP_SECONDS.equals(type)
+ || DType.TIMESTAMP_MICROSECONDS.equals(type)
+ || DType.TIMESTAMP_MILLISECONDS.equals(type)
+ || DType.TIMESTAMP_NANOSECONDS.equals(type)) {
+ debugInteger(hostCol, type);
+ } else if (DType.BOOL8.equals(type)) {
for (int i = 0; i < hostCol.getRowCount(); i++) {
if (hostCol.isNull(i)) {
System.err.println(i + " NULL");
@@ -131,20 +125,39 @@ public static synchronized void debug(String name, HostColumnVector hostCol) {
System.err.println(i + " " + hostCol.getBoolean(i));
}
}
- } else if (DType.TIMESTAMP_MICROSECONDS.equals(type) ||
- DType.INT64.equals(type)) {
- for (int i = 0; i < hostCol.getRowCount(); i++) {
- if (hostCol.isNull(i)) {
- System.err.println(i + " NULL");
- } else {
- System.err.println(i + " " + hostCol.getLong(i));
- }
- }
} else {
System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT");
}
}
+ private static void debugInteger(HostColumnVector hostCol, DType intType) {
+ for (int i = 0; i < hostCol.getRowCount(); i++) {
+ if (hostCol.isNull(i)) {
+ System.err.println(i + " NULL");
+ } else {
+ final int sizeInBytes = intType.getSizeInBytes();
+ final Object value;
+ switch (sizeInBytes) {
+ case Byte.BYTES:
+ value = hostCol.getByte(i);
+ break;
+ case Short.BYTES:
+ value = hostCol.getShort(i);
+ break;
+ case Integer.BYTES:
+ value = hostCol.getInt(i);
+ break;
+ case Long.BYTES:
+ value = hostCol.getLong(i);
+ break;
+ default:
+ throw new IllegalArgumentException("INFEASIBLE: Unsupported integer-like type " + intType);
+ }
+ System.err.println(i + " " + value);
+ }
+ }
+ }
+
private static HostColumnVector.DataType convertFrom(DataType spark, boolean nullable) {
if (spark instanceof ArrayType) {
ArrayType arrayType = (ArrayType) spark;
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 2e3501462e9..37a093c746b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1857,7 +1857,7 @@ object GpuOverrides {
}
}
- override def convertToGpu(child: Expression): GpuExpression = GpuSum(child)
+ override def convertToGpu(child: Expression): GpuExpression = GpuSum(child, a.dataType)
}),
expr[Average](
"Average aggregate operator",
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index 07ee8a42786..1ba4a055956 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -30,11 +30,12 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExprId, Nul
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
@@ -191,4 +192,13 @@ trait SparkShims {
def shouldFailDivByZero(): Boolean
def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan]
+
+ def reusedExchangeExecPfn: PartialFunction[SparkPlan, ReusedExchangeExec]
+
+  /** attachTree was dropped by SPARK-34234 */
+ def attachTreeIfSupported[TreeType <: TreeNode[_], A](
+ tree: TreeType,
+ msg: String = "")(
+ f: => A
+ ): A
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index 32da9614133..0fcd3b4bfce 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.rapids
import ai.rapids.cudf
-import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector}
+import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector, DType}
import com.nvidia.spark.rapids._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -192,10 +192,29 @@ class CudfCount(ref: Expression) extends CudfAggregate(ref) {
}
class CudfSum(ref: Expression) extends CudfAggregate(ref) {
+  // Up to Spark 3.1.1, the analyzed plan widened the input column type before
+  // applying the aggregation. Thus, even though we did not explicitly pass the
+  // output column type, we did not run into integer overflow issues:
+ //
+ // == Analyzed Logical Plan ==
+ // sum(shorts): bigint
+ // Aggregate [sum(cast(shorts#77 as bigint)) AS sum(shorts)#94L]
+ //
+  // In Spark's main branch (3.2.0-SNAPSHOT as of this comment), the analyzed
+  // logical plan no longer casts the input column, so the output column type
+  // has to be passed explicitly to the aggregation:
+ //
+ // == Analyzed Logical Plan ==
+ // sum(shorts): bigint
+ // Aggregate [sum(shorts#33) AS sum(shorts)#50L]
+ //
+ @transient val rapidsSumType: DType = GpuColumnVector.getNonNestedRapidsType(ref.dataType)
+
override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar =
- (col: cudf.ColumnVector) => col.sum
- override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
- (col: cudf.ColumnVector) => col.sum
+ (col: cudf.ColumnVector) => col.sum(rapidsSumType)
+
+ override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = updateReductionAggregate
+
override lazy val updateAggregate: Aggregation = Aggregation.sum()
override lazy val mergeAggregate: Aggregation = Aggregation.sum()
override def toString(): String = "CudfSum"
@@ -329,12 +348,8 @@ case class GpuMax(child: Expression) extends GpuDeclarativeAggregate
Aggregation.max().onColumn(inputs.head._2)
}
-case class GpuSum(child: Expression)
+case class GpuSum(child: Expression, resultType: DataType)
extends GpuDeclarativeAggregate with ImplicitCastInputTypes with GpuAggregateWindowFunction {
- private lazy val resultType = child.dataType match {
- case _: DoubleType => DoubleType
- case _ => LongType
- }
private lazy val cudfSum = AttributeReference("sum", resultType)()
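
A small sketch of the behavior difference motivating this change, assuming a live `SparkSession` named `spark`. The result type is `bigint` on both versions; what differs is whether the analyzed plan widens the input column or the aggregation itself must produce the `LongType` result:

```scala
import spark.implicits._

// On 3.1.1 the analyzed plan shows sum(cast(shorts#.. as bigint)); on
// 3.2.0-SNAPSHOT it shows sum(shorts#..), so the cudf sum has to be told to
// accumulate into a 64-bit column to avoid overflowing the narrow input type.
val df = Seq[Short](Short.MaxValue, 1).toDF("shorts")
df.selectExpr("sum(shorts)").explain(extended = true)
df.selectExpr("sum(shorts)").show()
```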
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
index d43c433c577..d865d52d3b8 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
@@ -26,8 +26,7 @@ import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
@@ -142,13 +141,14 @@ abstract class GpuShuffleExchangeExecBase(
protected override def doExecute(): RDD[InternalRow] =
throw new IllegalStateException(s"Row-based execution should not occur for $this")
- override def doExecuteColumnar(): RDD[ColumnarBatch] = attachTree(this, "execute") {
- // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
- if (cachedShuffleRDD == null) {
- cachedShuffleRDD = new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics)
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = ShimLoader.getSparkShims
+ .attachTreeIfSupported(this, "execute") {
+ // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
+ if (cachedShuffleRDD == null) {
+ cachedShuffleRDD = new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics)
+ }
+ cachedShuffleRDD
}
- cachedShuffleRDD
- }
}
object GpuShuffleExchangeExec {
diff --git a/tests-spark310+/pom.xml b/tests-spark310+/pom.xml
index 82e651bb044..e05a532377f 100644
--- a/tests-spark310+/pom.xml
+++ b/tests-spark310+/pom.xml
@@ -28,10 +28,6 @@
     <artifactId>rapids-4-spark-tests-next-spark_2.12</artifactId>
     <version>0.5.0-SNAPSHOT</version>
 
-    <properties>
-        <spark.test.version>${spark311.version}</spark.test.version>
-    </properties>
-
     <dependencies>
         <dependency>
             <groupId>org.apache.spark</groupId>
diff --git a/tests/pom.xml b/tests/pom.xml
index 43a836a0905..0bd1229e09a 100644
--- a/tests/pom.xml
+++ b/tests/pom.xml
@@ -30,36 +30,6 @@
     <description>RAPIDS plugin for Apache Spark integration tests</description>
     <version>0.5.0-SNAPSHOT</version>
 
-    <properties>
-        <spark.test.version>${spark300.version}</spark.test.version>
-    </properties>
-    <profiles>
-        <profile>
-            <id>spark301dbtests</id>
-            <properties>
-                <spark.test.version>${spark301db.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark301tests</id>
-            <properties>
-                <spark.test.version>${spark301.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark302tests</id>
-            <properties>
-                <spark.test.version>${spark302.version}</spark.test.version>
-            </properties>
-        </profile>
-        <profile>
-            <id>spark311tests</id>
-            <properties>
-                <spark.test.version>${spark311.version}</spark.test.version>
-            </properties>
-        </profile>
-    </profiles>
-
     <dependencies>
         <dependency>
             <groupId>org.slf4j</groupId>
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index 4e0a50b5849..5de8e891450 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -94,10 +94,7 @@ class AdaptiveQueryExecSuite
}
private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
- collectWithSubqueries(plan) {
- case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
- case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
- }
+ collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn)
}
test("skewed inner join optimization") {