diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7cf8994f744..790eb254516 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,5 @@
 # Change log
-Generated on 2021-03-02
+Generated on 2021-03-16

 ## Release 0.4
@@ -104,6 +104,12 @@ Generated on 2021-03-02
 ### PRs
 |||
 |:---|:---|
+|[#1910](https://github.com/NVIDIA/spark-rapids/pull/1910)|Make hash partitioning match CPU|
+|[#1927](https://github.com/NVIDIA/spark-rapids/pull/1927)|Change cuDF dependency to 0.18.1|
+|[#1934](https://github.com/NVIDIA/spark-rapids/pull/1934)|Update documentation to use cudf version 0.18.1|
+|[#1871](https://github.com/NVIDIA/spark-rapids/pull/1871)|Disable coalesce batch spilling to avoid cudf contiguous_split bug|
+|[#1849](https://github.com/NVIDIA/spark-rapids/pull/1849)|Update changelog for 0.4|
+|[#1744](https://github.com/NVIDIA/spark-rapids/pull/1744)|Fix NullPointerException on null partition insert|
 |[#1842](https://github.com/NVIDIA/spark-rapids/pull/1842)|Update to note support for 3.0.2|
 |[#1832](https://github.com/NVIDIA/spark-rapids/pull/1832)|Spark 3.1.1 shim no longer a snapshot shim|
 |[#1831](https://github.com/NVIDIA/spark-rapids/pull/1831)|Spark 3.0.2 shim no longer a snapshot shim|
diff --git a/api_validation/README.md b/api_validation/README.md
index 25bee6d650c..c6667ef0dca 100644
--- a/api_validation/README.md
+++ b/api_validation/README.md
@@ -17,7 +17,7 @@ It requires cudf, rapids-4-spark and spark jars.

 ```
 cd api_validation
-// To run validation script on all version of Spark(3.0.0, 3.0.1 and 3.1.0-SNAPSHOT)
+// To run the validation script on all versions of Spark (3.0.0, 3.0.1 and 3.1.1)
 sh auditAllVersions.sh

 // To run script on particular version we can use profile(spark300, spark301 and spark311)
diff --git a/docs/additional-functionality/cache-serializer.md b/docs/additional-functionality/cache-serializer.md
index 4a91597ef6f..08f4c695a91 100644
--- a/docs/additional-functionality/cache-serializer.md
+++ b/docs/additional-functionality/cache-serializer.md
@@ -15,7 +15,7 @@ nav_order: 2
   utilize disk space to spill over. To read more about what storage levels are available look
   at `StorageLevel.scala` in Spark.

-  Starting in Spark 3.1.0 users can add their own cache serializer, if they desire, by
+  Starting in Spark 3.1.1 users can add their own cache serializer, if they desire, by
   setting the `spark.sql.cache.serializer` configuration.
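For context on the configuration this hunk documents: `spark.sql.cache.serializer` is a static SQL conf, so it has to be supplied when the application is launched. A minimal sketch, assuming the serializer class name the plugin ships for the Spark 3.1.1 shim (verify the exact value against the released plugin docs):

```scala
import org.apache.spark.sql.SparkSession

// Static SQL confs like spark.sql.cache.serializer must be set before the
// SparkSession exists; they cannot be changed for the life of the application.
val spark = SparkSession.builder()
  .appName("gpu-cached-batch-serializer")
  .config("spark.sql.cache.serializer",
    // Assumed class name for the 3.1.1 shim; check the plugin documentation.
    "com.nvidia.spark.rapids.shims.spark311.ParquetCachedBatchSerializer")
  .getOrCreate()
```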
   This is a static configuration that is set once for the duration of a Spark
   application which means that you can only set the conf before starting a Spark
   application and cannot be changed for that application's Spark
diff --git a/docs/demo/Databricks/generate-init-script.ipynb b/docs/demo/Databricks/generate-init-script.ipynb
index ea79022702f..5672f3cdb54 100644
--- a/docs/demo/Databricks/generate-init-script.ipynb
+++ b/docs/demo/Databricks/generate-init-script.ipynb
@@ -1 +1 @@
-{"cells":[{"cell_type":"code","source":["dbutils.fs.mkdirs(\"dbfs:/databricks/init_scripts/\")\n \ndbutils.fs.put(\"/databricks/init_scripts/init.sh\",\"\"\"\n#!/bin/bash\nsudo wget -O /databricks/jars/rapids-4-spark_2.12-0.4.0.jar https://oss.sonatype.org/content/repositories/staging/com/nvidia/rapids-4-spark_2.12/0.4.0/rapids-4-spark_2.12-0.4.0.jar\nsudo wget -O /databricks/jars/cudf-0.18-cuda10-1.jar https://oss.sonatype.org/content/repositories/staging/ai/rapids/cudf/0.18/cudf-0.18-cuda10-1.jar\"\"\", True)"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["%sh\ncd ../../dbfs/databricks/init_scripts\npwd\nls -ltr\ncat init.sh"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":3}],"metadata":{"name":"generate-init-script","notebookId":2645746662301564},"nbformat":4,"nbformat_minor":0}
+{"cells":[{"cell_type":"code","source":["dbutils.fs.mkdirs(\"dbfs:/databricks/init_scripts/\")\n \ndbutils.fs.put(\"/databricks/init_scripts/init.sh\",\"\"\"\n#!/bin/bash\nsudo wget -O /databricks/jars/rapids-4-spark_2.12-0.4.0.jar https://oss.sonatype.org/content/repositories/staging/com/nvidia/rapids-4-spark_2.12/0.4.0/rapids-4-spark_2.12-0.4.0.jar\nsudo wget -O /databricks/jars/cudf-0.18.1-cuda10-1.jar https://oss.sonatype.org/content/repositories/staging/ai/rapids/cudf/0.18.1/cudf-0.18.1-cuda10-1.jar\"\"\", True)"],"metadata":{},"outputs":[],"execution_count":1},{"cell_type":"code","source":["%sh\ncd ../../dbfs/databricks/init_scripts\npwd\nls -ltr\ncat init.sh"],"metadata":{},"outputs":[],"execution_count":2},{"cell_type":"code","source":[""],"metadata":{},"outputs":[],"execution_count":3}],"metadata":{"name":"generate-init-script","notebookId":2645746662301564},"nbformat":4,"nbformat_minor":0}
diff --git a/docs/download.md b/docs/download.md
index a0b55f76aa2..9170a8ba0e5 100644
--- a/docs/download.md
+++ b/docs/download.md
@@ -17,10 +17,10 @@ RAPIDS Accelerator For Apache Spark.
 See the [getting-started guide](https://nvi
 ## Release v0.4.0
 ### Download v0.4.0
 * Download [RAPIDS Accelerator For Apache Spark v0.4.0](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/0.4.0/rapids-4-spark_2.12-0.4.0.jar)
-* Download RAPIDS cuDF 0.18 for your system:
-  * [For CUDA 11.0 & NVIDIA driver 450.36+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18/cudf-0.18-cuda11.jar)
-  * [For CUDA 10.2 & NVIDIA driver 440.33+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18/cudf-0.18-cuda10-2.jar)
-  * [For CUDA 10.1 & NVIDIA driver 418.87+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18/cudf-0.18-cuda10-1.jar)
+* Download RAPIDS cuDF 0.18.1 for your system:
+  * [For CUDA 11.0 & NVIDIA driver 450.36+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18.1/cudf-0.18.1-cuda11.jar)
+  * [For CUDA 10.2 & NVIDIA driver 440.33+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18.1/cudf-0.18.1-cuda10-2.jar)
+  * [For CUDA 10.1 & NVIDIA driver 418.87+](https://repo1.maven.org/maven2/ai/rapids/cudf/0.18.1/cudf-0.18.1-cuda10-1.jar)

 ### Requirements
 Hardware Requirements:
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 72e245727ce..841fe7dd616 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -9719,19 +9719,19 @@ Accelerator support is described below.
 S
 S
 S
+NS
+NS
+NS
+NS
 S
+S*
 S
-
-
-S
-
-S
-
-
-
-
-
-
+NS
+NS
+NS
+NS
+NS
+NS
 result
@@ -9764,17 +9764,17 @@ Accelerator support is described below.
 NS
 NS
 NS
-
-
 NS
-
 NS
-
-
-
-
-
-
+NS
+NS
+NS
+NS
+NS
+NS
+NS
+NS
+NS
 result
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index ac1bc35d033..0abb3b72597 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -276,6 +276,21 @@ def start(self, rand):
         self._start(rand, self._loop_values)
         self._vals = [self._child.gen() for _ in range(0, self._length)]

+class SetValuesGen(DataGen):
+    """A set of values that are randomly selected"""
+    def __init__(self, data_type, data):
+        super().__init__(data_type, nullable=False)
+        self.nullable = any(x is None for x in data)
+        self._vals = data
+
+    def __repr__(self):
+        return super().__repr__() + '(' + str(self._vals) + ')'
+
+    def start(self, rand):
+        data = self._vals
+        length = len(data)
+        self._start(rand, lambda : data[rand.randrange(0, length)])
+
 FLOAT_MIN = -3.4028235E38
 FLOAT_MAX = 3.4028235E38
 NEG_FLOAT_NAN_MIN_VALUE = struct.unpack('f', struct.pack('I', 0xffffffff))[0]
diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index a904c489be4..d68539a231e 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -201,3 +201,40 @@ def do_join(spark):
         return testurls.join(resolved, "Url", "inner")
     assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.autoBroadcastJoinThreshold': '-1'})

+@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
+@pytest.mark.parametrize('cache_side', ['cache_left', 'cache_right'], ids=idfn)
+@pytest.mark.parametrize('cpu_side', ['cache', 'not_cache'], ids=idfn)
+@ignore_order
+def test_half_cache_join(join_type, cache_side, cpu_side):
+    left_gen = [('a', SetValuesGen(LongType(), range(500))), ('b', IntegerGen())]
+    right_gen = [('r_a', SetValuesGen(LongType(), range(500))), ('c', LongGen())]
+    def do_join(spark):
+        # Try to force the shuffle to be split between the CPU and GPU for the join:
+        # pin the shuffle that feeds the cache to the CPU or the GPU, depending on
+        # how the test is configured, when we repartition and cache the data
+        spark.conf.set('spark.rapids.sql.exec.ShuffleExchangeExec', cpu_side != 'cache')
+        left = gen_df(spark, left_gen, length=500)
+        right = gen_df(spark, right_gen, length=500)
+
+        if (cache_side == 'cache_left'):
+            # Try to force the shuffle to be split between CPU and GPU for the join.
+            # By default, if the operation after the shuffle is not on the GPU we
+            # don't do a GPU shuffle, so do something simple after the repartition
+            # to make sure that the GPU shuffle is used.
+            left = left.repartition('a').selectExpr('b + 1 as b', 'a').cache()
+            left.count() # populate the cache
+        else:
+            # cache_right
+            # Try to force the shuffle to be split between CPU and GPU for the join.
+            # By default, if the operation after the shuffle is not on the GPU we
+            # don't do a GPU shuffle, so do something simple after the repartition
+            # to make sure that the GPU shuffle is used.
+            right = right.repartition('r_a').selectExpr('c + 1 as c', 'r_a').cache()
+            right.count() # populate the cache
+        # Now turn it back so the other half of the shuffle will be on the opposite side
+        spark.conf.set('spark.rapids.sql.exec.ShuffleExchangeExec', cpu_side == 'cache')
+        return left.join(right, left.a == right.r_a, join_type)
+
+    # Spark does not know the size of an RDD input, so it will not do a broadcast join
+    # unless we tell it to; setting the threshold here is just to be safe.
+    assert_gpu_and_cpu_are_equal_collect(do_join, {'spark.sql.autoBroadcastJoinThreshold': '1'})
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index d3c89684f45..902ec7e9a75 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect
 from data_gen import *
 from marks import ignore_order
+import pyspark.sql.functions as f

 @pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
 def test_union(data_gen):
@@ -45,3 +46,39 @@ def test_repartion_df(num_parts, length):
     assert_gpu_and_cpu_are_equal_collect(
             lambda spark : gen_df(spark, gen_list, length=length).repartition(num_parts),
             conf = allow_negative_scale_of_decimal_conf)
+
+@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
+@pytest.mark.parametrize('num_parts', [1, 2, 10, 17, 19, 32], ids=idfn)
+@pytest.mark.parametrize('gen', [
+    ([('a', boolean_gen)], ['a']),
+    ([('a', byte_gen)], ['a']),
+    ([('a', short_gen)], ['a']),
+    ([('a', int_gen)], ['a']),
+    ([('a', long_gen)], ['a']),
+    pytest.param(([('a', float_gen)], ['a']), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1914')),
+    pytest.param(([('a', double_gen)], ['a']), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1914')),
+    ([('a', decimal_gen_default)], ['a']),
+    ([('a', decimal_gen_neg_scale)], ['a']),
+    ([('a', decimal_gen_scale_precision)], ['a']),
+    ([('a', decimal_gen_same_scale_precision)], ['a']),
+    ([('a', decimal_gen_64bit)], ['a']),
+    ([('a', string_gen)], ['a']),
+    ([('a', null_gen)], ['a']),
+    ([('a', byte_gen)], [f.col('a') - 5]),
+    ([('a', long_gen)], [f.col('a') + 15]),
+    ([('a', byte_gen), ('b', boolean_gen)], ['a', 'b']),
+    ([('a', short_gen), ('b', string_gen)], ['a', 'b']),
+    ([('a', int_gen), ('b', byte_gen)], ['a', 'b']),
+    ([('a', long_gen), ('b', null_gen)], ['a', 'b']),
+    ([('a', byte_gen), ('b', boolean_gen), ('c', short_gen)], ['a', 'b', 'c']),
+    ([('a', short_gen), ('b', string_gen), ('c', int_gen)], ['a', 'b', 'c']),
+    ([('a', decimal_gen_default), ('b', decimal_gen_64bit), ('c', decimal_gen_scale_precision)], ['a', 'b', 'c']),
+    ], ids=idfn)
+def test_hash_repartition_exact(gen, num_parts):
+    data_gen = gen[0]
+    part_on = gen[1]
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : gen_df(spark, data_gen)\
+            .repartition(num_parts, *part_on)\
+            .selectExpr('spark_partition_id() as id', '*', 'hash(*)', 'pmod(hash(*),{})'.format(num_parts)),
+        conf = allow_negative_scale_of_decimal_conf)
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
index 4656e8c2a81..13816c19b2c 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
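The test_hash_repartition_exact test added above asserts that GPU hash repartitioning places every row exactly where the CPU would: in partition pmod(hash(keys), num_parts), where hash() is Spark's murmur3 with seed 42. A sketch of the same invariant written against the public DataFrame API (the column name 'a' and the partition count are illustrative):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// Checks that rows land in partition pmod(hash(a), n) after repartition(n, a).
// With this PR the GPU produces the same mapping as the CPU.
def checkHashPartitioning(df: DataFrame, n: Int): Unit = {
  val checked = df.repartition(n, col("a"))
    .select(
      spark_partition_id().as("actual"),
      pmod(hash(col("a")), lit(n)).as("expected"))
  assert(checked.filter(col("actual") =!= col("expected")).isEmpty)
}
```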
@@ -101,65 +101,4 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
     mixedDfWithNulls, mixedDfWithNulls, sortBeforeRepart = true) {
     (A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
   }
-
-  test("fixUpJoinConsistencyIfNeeded AQE on") {
-    // this test is only valid in Spark 3.0.1 and later due to AQE supporting the plugin
-    val isValidTestForSparkVersion = ShimLoader.getSparkShims.getSparkShimVersion match {
-      case SparkShimVersion(3, 0, 0) => false
-      case DatabricksShimVersion(3, 0, 0) => false
-      case _ => true
-    }
-    assume(isValidTestForSparkVersion)
-    testFixUpJoinConsistencyIfNeeded(true)
-  }
-
-  test("fixUpJoinConsistencyIfNeeded AQE off") {
-    testFixUpJoinConsistencyIfNeeded(false)
-  }
-
-  private def testFixUpJoinConsistencyIfNeeded(aqe: Boolean) {
-
-    val conf = shuffledJoinConf.clone()
-      .set("spark.sql.adaptive.enabled", String.valueOf(aqe))
-      .set("spark.rapids.sql.test.allowedNonGpu",
-        "BroadcastHashJoinExec,SortMergeJoinExec,SortExec,Upper")
-      .set("spark.rapids.sql.incompatibleOps.enabled", "false") // force UPPER onto CPU
-
-    withGpuSparkSession(spark => {
-      import spark.implicits._
-
-      def createStringDF(name: String, upper: Boolean = false): DataFrame = {
-        val countryNames = (0 until 1000).map(i => s"country_$i")
-        if (upper) {
-          countryNames.map(_.toUpperCase).toDF(name)
-        } else {
-          countryNames.toDF(name)
-        }
-      }
-
-      val left = createStringDF("c1")
-        .join(createStringDF("c2"), col("c1") === col("c2"))
-
-      val right = createStringDF("c3")
-        .join(createStringDF("c4"), col("c3") === col("c4"))
-
-      val join = left.join(right, upper(col("c1")) === col("c4"))
-
-      // call collect so that we get the final executed plan when AQE is on
-      join.collect()
-
-      val shuffleExec = TestUtils
-        .findOperator(join.queryExecution.executedPlan, _.isInstanceOf[ShuffleExchangeExec])
-        .get
-
-      val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
-      val reasons = shuffleExec.getTagValue(gpuSupportedTag).getOrElse(Set.empty)
-      assert(reasons.contains(
-        "other exchanges that feed the same join are on the CPU, and GPU " +
-          "hashing is not consistent with the CPU version"))
-
-    }, conf)
-
-  }
-
 }
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/UnaryOperatorsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/UnaryOperatorsSuite.scala
index a352fa91297..40a85b4574b 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/UnaryOperatorsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/UnaryOperatorsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
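The fixUpJoinConsistencyIfNeeded tests deleted above existed to force every exchange feeding a join onto the same device, because CPU and GPU hashing used to disagree. With GPU hash partitioning now matching the CPU (PR #1910), that guard is unnecessary. A sketch of the property that makes it safe, using Spark's catalyst hash (Murmur3HashFunction is Spark-internal API; illustrative only, not plugin code):

```scala
import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction
import org.apache.spark.sql.types.LongType

// A key now maps to the same shuffle partition whether the exchange ran on
// the CPU or the GPU, so exchanges feeding one join can be split across devices.
val key = 12345L
val numParts = 200
val h = Murmur3HashFunction.hash(key, LongType, 42).toInt
val partition = java.lang.Math.floorMod(h, numParts) // pmod semantics
```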
@@ -37,6 +37,6 @@ class UnaryOperatorsSuite extends SparkQueryCompareTestSuite {
   }

   testSparkResultsAreEqual("Test murmur3", mixedDfWithNulls) {
-    frame => frame.selectExpr("hash(longs, doubles, 1, null, 'stock string', ints, strings)")
+    frame => frame.selectExpr("hash(longs, 1, null, 'stock string', ints, strings)")
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
index 225edcefd7c..ebb764270bb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@ package com.nvidia.spark.rapids

 import scala.collection.mutable.ArrayBuffer

-import ai.rapids.cudf.{ColumnVector, NvtxColor, NvtxRange, Table}
-import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import ai.rapids.cudf.{ColumnVector, DType, NvtxColor, NvtxRange, Table}

 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashClusteredDistribution}
+import org.apache.spark.sql.rapids.GpuMurmur3Hash
 import org.apache.spark.sql.types.{DataType, IntegerType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -47,96 +47,67 @@ case class GpuHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     }
   }

-  def getGpuKeyColumns(batch: ColumnarBatch) : Array[GpuColumnVector] = {
-    expressions.map(_.columnarEval(batch)
-      .asInstanceOf[GpuColumnVector]).toArray
-  }
-
-  def getGpuDataColumns(batch: ColumnarBatch) : Array[GpuColumnVector] =
-    GpuColumnVector.extractColumns(batch)
-
-  def insertDedupe(
-      indexesOut: Array[Int],
-      colsIn: Array[GpuColumnVector],
-      dedupedData: ArrayBuffer[ColumnVector]): Unit = {
-    indexesOut.indices.foreach { i =>
-      val b = colsIn(i).getBase
-      val idx = dedupedData.indexOf(b)
-      if (idx < 0) {
-        indexesOut(i) = dedupedData.size
-        dedupedData += b
-      } else {
-        indexesOut(i) = idx
-      }
-    }
-  }
-
-  def dedupe(keyCols: Array[GpuColumnVector], dataCols: Array[GpuColumnVector]):
-    (Array[Int], Array[Int], Table) = {
-    val base = new ArrayBuffer[ColumnVector](keyCols.length + dataCols.length)
-    val keys = new Array[Int](keyCols.length)
-    val data = new Array[Int](dataCols.length)
-
-    insertDedupe(keys, keyCols, base)
-    insertDedupe(data, dataCols, base)
-
-    (keys, data, new Table(base: _*))
-  }
-
-  def partitionInternal(batch: ColumnarBatch): (Array[Int], Array[GpuColumnVector]) = {
-    var gpuKeyColumns : Array[GpuColumnVector] = null
-    var gpuDataColumns : Array[GpuColumnVector] = null
-    try {
-      gpuKeyColumns = getGpuKeyColumns(batch)
-      gpuDataColumns = getGpuDataColumns(batch)
-      val sparkTypes = gpuDataColumns.map(_.dataType())
-
-      val (keys, dataIndexes, table) = dedupe(gpuKeyColumns, gpuDataColumns)
-      // Don't need the batch any more table has all we need in it.
-      // gpuDataColumns did not increment the reference count when we got them,
-      // so don't close them.
-      gpuDataColumns = null
-      gpuKeyColumns.foreach(_.close())
-      gpuKeyColumns = null
-      batch.close()
-
-      val partedTable = table.onColumns(keys: _*).hashPartition(numPartitions)
-      table.close()
-      val parts = partedTable.getPartitions
-      val columns = dataIndexes.zip(sparkTypes).map { case (idx, sparkType) =>
-        GpuColumnVector.from(partedTable.getColumn(idx).incRefCount(), sparkType)
-      }
-      partedTable.close()
-      (parts, columns)
-    } finally {
-      if (gpuDataColumns != null) {
-        gpuDataColumns.safeClose()
-      }
-      if (gpuKeyColumns != null) {
-        gpuKeyColumns.safeClose()
-      }
-    }
-  }
-
   override def columnarEval(batch: ColumnarBatch): Any = {
     // We are doing this here because the cudf partition command is at this level
-    val totalRange = new NvtxRange("Hash partition", NvtxColor.PURPLE)
-    try {
-      val numRows = batch.numRows
-      val (partitionIndexes, partitionColumns) = {
-        val partitionRange = new NvtxRange("partition", NvtxColor.BLUE)
-        try {
-          partitionInternal(batch)
-        } finally {
-          partitionRange.close()
+    val numRows = batch.numRows
+    withResource(new NvtxRange("Hash partition", NvtxColor.PURPLE)) { _ =>
+      val sortedTable = withResource(batch) { batch =>
+        val parts = withResource(new NvtxRange("Calculate part", NvtxColor.CYAN)) { _ =>
+          withResource(GpuMurmur3Hash.compute(batch, expressions)) { hash =>
+            withResource(GpuScalar.from(numPartitions, IntegerType)) { partsLit =>
+              hash.pmod(partsLit, DType.INT32)
+            }
+          }
+        }
+        withResource(new NvtxRange("sort by part", NvtxColor.DARK_GREEN)) { _ =>
+          withResource(parts) { parts =>
+            val allColumns = new ArrayBuffer[ColumnVector](batch.numCols() + 1)
+            allColumns += parts
+            allColumns ++= GpuColumnVector.extractBases(batch)
+            withResource(new Table(allColumns: _*)) { fullTable =>
+              fullTable.orderBy(Table.asc(0))
+            }
+          }
+        }
+      }
+      val (partitionIndexes, partitionColumns) = withResource(sortedTable) { sortedTable =>
+        val cutoffs = withResource(new Table(sortedTable.getColumn(0))) { justPartitions =>
+          val partsTable = withResource(GpuScalar.from(0, IntegerType)) { zeroLit =>
+            withResource(ColumnVector.sequence(zeroLit, numPartitions)) { partsColumn =>
+              new Table(partsColumn)
+            }
+          }
+          withResource(partsTable) { partsTable =>
+            justPartitions.upperBound(Array(false), partsTable, Array(false))
+          }
+        }
+        val partitionIndexes = withResource(cutoffs) { cutoffs =>
+          val buffer = new ArrayBuffer[Int](numPartitions)
+          // The first index is always 0
+          buffer += 0
+          withResource(cutoffs.copyToHost()) { hostCutoffs =>
+            (0 until numPartitions).foreach { i =>
+              buffer += hostCutoffs.getInt(i)
+            }
+          }
+          buffer.toArray
         }
+        val dataTypes = GpuColumnVector.extractTypes(batch)
+        closeOnExcept(new ArrayBuffer[GpuColumnVector]()) { partitionColumns =>
+          (1 until sortedTable.getNumberOfColumns).foreach { index =>
+            partitionColumns +=
+              GpuColumnVector.from(sortedTable.getColumn(index).incRefCount(),
+                dataTypes(index - 1))
+          }
+
+          (partitionIndexes, partitionColumns.toArray)
+        }
+      }
+      val ret = withResource(partitionColumns) { partitionColumns =>
+        sliceInternalGpuOrCpu(numRows, partitionIndexes, partitionColumns)
       }
-      val ret = sliceInternalGpuOrCpu(numRows, partitionIndexes, partitionColumns)
-      partitionColumns.safeClose() // Close the partition columns we copied them as a part of the slice
       ret.zipWithIndex.filter(_._1 != null)
-    } finally {
-      totalRange.close()
     }
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 906aaf48d46..691aa93ac8a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2281,14 +2281,15 @@
       "Murmur3 hash operator",
       ExprChecks.projectNotLambda(TypeSig.INT, TypeSig.INT,
         repeatingParamCheck = Some(RepeatingParamCheck("input",
+          // Floating point values don't work because -0.0 is not hashed properly
           TypeSig.BOOLEAN + TypeSig.BYTE + TypeSig.SHORT + TypeSig.INT + TypeSig.LONG +
-          TypeSig.FLOAT + TypeSig.DOUBLE + TypeSig.STRING + TypeSig.NULL,
-          TypeSig.BOOLEAN + TypeSig.BYTE + TypeSig.SHORT + TypeSig.INT + TypeSig.LONG +
-          TypeSig.FLOAT + TypeSig.DOUBLE + TypeSig.STRING + TypeSig.NULL))),
+          TypeSig.STRING + TypeSig.NULL + TypeSig.DECIMAL,
+          TypeSig.all))),
       (a, conf, p, r) => new ExprMeta[Murmur3Hash](a, conf, p, r) {
         override val childExprs: Seq[BaseExprMeta[_]] = a.children
           .map(GpuOverrides.wrapExpr(_, conf, Some(this)))
-        def convertToGpu(): GpuExpression = GpuMurmur3Hash(childExprs.map(_.convertToGpu()))
+        def convertToGpu(): GpuExpression =
+          GpuMurmur3Hash(childExprs.map(_.convertToGpu()), a.seed)
       }),
     expr[Contains](
       "Contains",
@@ -2440,6 +2441,18 @@
         override val childExprs: Seq[BaseExprMeta[_]] =
           hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))

+        override def tagPartForGpu(): Unit = {
+          // This needs to match what murmur3 supports.
+          // TODO In 0.5 we should make the checks self documenting, and look more like what
+          // SparkPlan and Expression support
+          // https://github.com/NVIDIA/spark-rapids/issues/1915
+          val sig = TypeSig.BOOLEAN + TypeSig.BYTE + TypeSig.SHORT + TypeSig.INT + TypeSig.LONG +
+            TypeSig.STRING + TypeSig.NULL + TypeSig.DECIMAL
+          hp.children.foreach { child =>
+            sig.tagExprParam(this, child, "hash_key")
+          }
+        }
+
         override def convertToGpu(): GpuPartitioning =
           GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
       }),
@@ -2454,7 +2467,7 @@
             val gpuOrdering = childExprs.map(_.convertToGpu()).asInstanceOf[Seq[SortOrder]]
             GpuRangePartitioning(gpuOrdering, rp.numPartitions)
           } else {
-            GpuSinglePartitioning(childExprs.map(_.convertToGpu()))
+            GpuSinglePartitioning
           }
         }
       }),
@@ -2468,11 +2481,7 @@
     part[SinglePartition.type](
       "Single partitioning",
       (sp, conf, p, r) => new PartMeta[SinglePartition.type](sp, conf, p, r) {
-        override val childExprs: Seq[ExprMeta[_]] = Seq.empty[ExprMeta[_]]
-
-        override def convertToGpu(): GpuPartitioning = {
-          GpuSinglePartitioning(childExprs.map(_.convertToGpu()))
-        }
+        override def convertToGpu(): GpuPartitioning = GpuSinglePartitioning
       })
   ).map(r => (r.getClassFor.asSubclass(classOf[Partitioning]), r)).toMap

@@ -2581,7 +2590,7 @@
       GpuTopN(takeExec.limit,
         so,
         projectList.map(_.convertToGpu().asInstanceOf[NamedExpression]),
-        ShimLoader.getSparkShims.getGpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty),
+        ShimLoader.getSparkShims.getGpuShuffleExchangeExec(GpuSinglePartitioning,
           GpuTopN(takeExec.limit,
             so,
             takeExec.child.output,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
index cea112c3c5b..c204769de5c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,11 @@
 package com.nvidia.spark.rapids

 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution}
 import org.apache.spark.sql.types.{DataType, IntegerType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

-case class GpuSinglePartitioning(expressions: Seq[Expression])
-  extends GpuExpression with GpuPartitioning {
+case object GpuSinglePartitioning extends GpuExpression with GpuPartitioning {
   /**
    * Returns the result of evaluating this expression on the entire `ColumnarBatch`.
    * The result of calling this may be a single [[GpuColumnVector]] or a scalar value.
@@ -37,7 +37,7 @@ case class GpuSinglePartitioning(expressions: Seq[Expression])
     if (batch.numCols == 0) {
       Array(batch).zipWithIndex
     } else {
-      try {
+      withResource(batch) { batch =>
         // Nothing needs to be sliced but a contiguous table is needed for GPU shuffle which
         // slice will produce.
         val sliced = sliceInternalGpuOrCpu(
           batch.numRows,
           Array(0),
           GpuColumnVector.extractColumns(batch))
         sliced.zipWithIndex
-      } finally {
-        batch.close()
       }
     }
   }
@@ -57,5 +55,10 @@ case class GpuSinglePartitioning(expressions: Seq[Expression])

   override val numPartitions: Int = 1

-  override def children: Seq[Expression] = expressions
+  override def children: Seq[Expression] = Seq.empty
+
+  override def satisfies0(required: Distribution): Boolean = required match {
+    case _: BroadcastDistribution => false
+    case _ => true
+  }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index c7ae798fbaf..fda3f2bd72c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -25,14 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommand
-import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.window.WindowExecBase
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType

 trait DataFromReplacementRule {
@@ -536,93 +533,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
     wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
   }

-  private def findShuffleExchanges(): Seq[Either[
-    SparkPlanMeta[QueryStageExec],
-    SparkPlanMeta[ShuffleExchangeExec]]] = wrapped match {
-    case _: ShuffleQueryStageExec =>
-      Left(this.asInstanceOf[SparkPlanMeta[QueryStageExec]]) :: Nil
-    case _: ShuffleExchangeExec =>
-      Right(this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]]) :: Nil
-    case bkj: BroadcastHashJoinExec => ShimLoader.getSparkShims.getBuildSide(bkj) match {
-      case GpuBuildLeft => childPlans(1).findShuffleExchanges()
-      case GpuBuildRight => childPlans(0).findShuffleExchanges()
-    }
-    case _ => childPlans.flatMap(_.findShuffleExchanges())
-  }
-
-  private def findBucketedReads(): Seq[Boolean] = wrapped match {
-    case f: FileSourceScanExec =>
-      if (f.bucketedScan) {
-        true :: Nil
-      } else {
-        false :: Nil
-      }
-    case _: ShuffleExchangeExec =>
-      // if we find a shuffle before a scan then it doesn't matter if its
-      // a bucketed read
-      false :: Nil
-    case _ =>
-      childPlans.flatMap(_.findBucketedReads())
-  }
-
-  private def makeShuffleConsistent(): Unit = {
-    // during query execution when AQE is enabled, the plan could consist of a mixture of
-    // ShuffleExchangeExec nodes for exchanges that have not started executing yet, and
-    // ShuffleQueryStageExec nodes for exchanges that have already started executing. This code
-    // attempts to tag ShuffleExchangeExec nodes for CPU if other exchanges (either
-    // ShuffleExchangeExec or ShuffleQueryStageExec nodes) were also tagged for CPU.
-    val shuffleExchanges = findShuffleExchanges()
-    // if any of the table reads are bucketed then we can't do the shuffle on the
-    // GPU because the hashing is different between the CPU and GPU
-    val bucketedReads = findBucketedReads().exists(_ == true)
-
-    def canThisBeReplaced(plan: Either[
-      SparkPlanMeta[QueryStageExec],
-      SparkPlanMeta[ShuffleExchangeExec]]): Boolean = {
-      plan match {
-        case Left(qs) => qs.wrapped.plan match {
-          case _: GpuExec => true
-          case ReusedExchangeExec(_, _: GpuExec) => true
-          case _ => false
-        }
-        case Right(e) => e.canThisBeReplaced
-      }
-    }
-
-    // if we are reading from a bucketed table or if we can't convert all exchanges to GPU
-    // then we need to make sure that all of them run on the CPU instead
-    if (bucketedReads || !shuffleExchanges.forall(canThisBeReplaced)) {
-      val errMsg = if (bucketedReads) {
-        "can't support shuffle on the GPU when doing a join that reads directly from a " +
-          "bucketed table!"
-      } else {
-        "other exchanges that feed the same join are on the CPU, and GPU hashing is " +
-          "not consistent with the CPU version"
-      }
-      // tag any exchanges that have not been converted to query stages yet
-      shuffleExchanges.filter(_.isRight).foreach(_.right.get.willNotWorkOnGpu(errMsg))
-      // verify that no query stages already got converted to GPU
-      if (shuffleExchanges.filter(_.isLeft).exists(canThisBeReplaced)) {
-        throw new IllegalStateException("Join needs to run on CPU but at least one input " +
-          "query stage ran on GPU")
-      }
-    }
-  }
-
   def getReasonsNotToReplaceEntirePlan: Seq[String] = {
     val childReasons = childPlans.flatMap(_.getReasonsNotToReplaceEntirePlan)
     entirePlanExcludedReasons ++ childReasons
   }

-  private def fixUpJoinConsistencyIfNeeded(): Unit = {
-    childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
-    wrapped match {
-      case _: ShuffledHashJoinExec => makeShuffleConsistent()
-      case _: SortMergeJoinExec => makeShuffleConsistent()
-      case _ => ()
-    }
-  }
-
   // For adaptive execution we have to ensure we mark everything properly
   // the first time through and that has to match what happens when AQE
   // splits things up and does the subquery analysis at the shuffle boundaries.
@@ -659,19 +574,12 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
   // have to be very careful to avoid loops in the rules.
   // RULES:
-  // 1) BroadcastHashJoin can disable the Broadcast directly feeding it, if the join itself cannot
-  //    be translated for some reason. This is okay because it is the joins immediate parent, so
-  //    it can keep everything consistent.
-  // 2) For ShuffledHashJoin and SortMergeJoin we need to verify that all of the exchanges
+  // 1) For ShuffledHashJoin and SortMergeJoin we need to verify that all of the exchanges
   //    feeding them are either all on the GPU or all on the CPU, because the hashing is not
   //    consistent between the two implementations. This is okay because it is only impacting
   //    shuffled exchanges. So broadcast exchanges are not impacted which could have an impact on
   //    BroadcastHashJoin, and shuffled exchanges are not used to disable anything downstream.
-  // 3) If a shuffled exchange is not columnar on at least one side don't do it. This must happen
-  //    before the join consistency or we risk running into issues with disabling one exchange that
-  //    would make a join inconsistent
   fixUpExchangeOverhead()
-  fixUpJoinConsistencyIfNeeded()
 }

 override final def tagSelfForGpu(): Unit = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
index 93ae20c8ab0..4e1da31570a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
@@ -135,7 +135,7 @@ class GpuCollectLimitMeta(

   override def convertToGpu(): GpuExec =
     GpuGlobalLimitExec(collectLimit.limit,
-      ShimLoader.getSparkShims.getGpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty),
+      ShimLoader.getSparkShims.getGpuShuffleExchangeExec(GpuSinglePartitioning,
        GpuLocalLimitExec(collectLimit.limit, childPlans.head.convertIfNeeded())))

 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala
index c99bc8681fe..16b7026b441 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
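The hunk that follows rewrites GpuMurmur3Hash around a reusable compute method that normalizes its inputs before hashing. Restating those normalization rules as a plain Scala sketch (grounded in the comments in the hunk; illustrative, not plugin code):

```scala
// Every NaN bit pattern hashes like the canonical NaN, while -0.0 is NOT
// normalized. Per issue #1914 cudf currently mishandles -0.0, which is why
// float/double inputs stay off the GPU hash for now.
def normalizeDouble(d: Double): Double = if (d.isNaN) Double.NaN else d
def normalizeFloat(f: Float): Float = if (f.isNaN) Float.NaN else f

// Bytes and shorts are widened to ints, and decimals with precision <= 18
// (64-bit decimals) are hashed as their unscaled long value.
```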
@@ -16,16 +16,12 @@ package org.apache.spark.sql.rapids

-import scala.collection.mutable.ArrayBuffer
-
-import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, Scalar}
-import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar, GpuUnaryExpression}
-import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType}
+import com.nvidia.spark.rapids.{Arm, GpuCast, GpuColumnVector, GpuExpression, GpuIf, GpuIsNan, GpuLiteral, GpuProjectExec, GpuUnaryExpression, GpuUnscaledValue}

 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.unsafe.types.UTF8String

 case class GpuMd5(child: Expression)
   extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerant {
@@ -40,31 +36,42 @@ case class GpuMd5(child: Expression)
   }
 }

-case class GpuMurmur3Hash(child: Seq[Expression]) extends GpuExpression {
+object GpuMurmur3Hash extends Arm {
+  def compute(batch: ColumnarBatch, boundExpr: Seq[Expression], seed: Int = 42): ColumnVector = {
+    val newExprs = boundExpr.map { expr =>
+      expr.dataType match {
+        case ByteType | ShortType =>
+          GpuCast(expr, IntegerType)
+        case DoubleType =>
+          // We have to normalize the NaNs, but not the zeros;
+          // however the current cudf code does the wrong thing for -0.0
+          // https://github.com/NVIDIA/spark-rapids/issues/1914
+          GpuIf(GpuIsNan(expr), GpuLiteral(Double.NaN, DoubleType), expr)
+        case FloatType =>
+          // We have to normalize the NaNs, but not the zeros;
+          // however the current cudf code does the wrong thing for -0.0
+          // https://github.com/NVIDIA/spark-rapids/issues/1914
+          GpuIf(GpuIsNan(expr), GpuLiteral(Float.NaN, FloatType), expr)
+        case dt: DecimalType if dt.precision <= DType.DECIMAL64_MAX_PRECISION =>
+          // For these values it is just hashing it as a long
+          GpuUnscaledValue(expr)
+        case _ =>
+          expr
+      }
+    }
+    withResource(GpuProjectExec.project(batch, newExprs)) { args =>
+      val bases = GpuColumnVector.extractBases(args)
+      ColumnVector.spark32BitMurmurHash3(seed, bases.toArray[ColumnView])
+    }
+  }
+}
+
+case class GpuMurmur3Hash(children: Seq[Expression], seed: Int) extends GpuExpression {
   override def dataType: DataType = IntegerType

-  override def toString: String = s"hash($child)"
+  override def toString: String = s"hash($children)"
   def nullable: Boolean = children.exists(_.nullable)
-  def children: Seq[Expression] = child

-  def columnarEval(batch: ColumnarBatch): Any = {
-    val rows = batch.numRows()
-    val columns: ArrayBuffer[ColumnVector] = new ArrayBuffer[ColumnVector]()
-    try {
-      children.foreach { child => child.columnarEval(batch) match {
-        case vector: GpuColumnVector =>
-          columns += vector.getBase
-        case col => if (col != null) {
-          withResource(GpuScalar.from(col)) { scalarValue =>
-            columns += ai.rapids.cudf.ColumnVector.fromScalar(scalarValue, rows)
-          }
-        }
-      }
-      }
-      GpuColumnVector.from(
-        ColumnVector.spark32BitMurmurHash3(42, columns.toArray[ColumnView]), dataType)
-    } finally {
-      columns.safeClose()
-    }
-  }
+  def columnarEval(batch: ColumnarBatch): Any =
+    GpuColumnVector.from(GpuMurmur3Hash.compute(batch, children, seed), dataType)
 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
index 30f40f2b4cf..ee9a639b683 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
@@ -274,8 +274,8 @@ object GpuShuffleExchangeExec {
         rdd, SQLConf.get.rangeExchangeSampleSizePerPartition)
       // No need to bind arguments for the GpuRangePartitioner. The Sorter has already done it
       new GpuRangePartitioner(bounds, sorter)
-    case s: GpuSinglePartitioning =>
-      GpuBindReferences.bindReference(s, outputAttributes)
+    case GpuSinglePartitioning =>
+      GpuSinglePartitioning
     case rrp: GpuRoundRobinPartitioning =>
       GpuBindReferences.bindReference(rrp, outputAttributes)
     case _ => sys.error(s"Exchange not implemented for $newPartitioning")
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index fd74438ec77..eb589dd277c 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -121,7 +121,13 @@ class AdaptiveQueryExecSuite
         spark,
         "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
       val innerSmj = findTopLevelGpuShuffleHashJoin(innerAdaptivePlan)
-      checkSkewJoin(innerSmj, 2, 1)
+      // Spark changed how skewed joins work and now the numbers are different
+      // depending on the version being used
+      if (cmpSparkVersion(3,1,1) >= 0) {
+        checkSkewJoin(innerSmj, 2, 1)
+      } else {
+        checkSkewJoin(innerSmj, 1, 1)
+      }
     }
   }

@@ -131,7 +137,13 @@ class AdaptiveQueryExecSuite
         spark,
         "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
      val leftSmj = findTopLevelGpuShuffleHashJoin(leftAdaptivePlan)
-      checkSkewJoin(leftSmj, 2, 0)
+      // Spark changed how skewed joins work and now the numbers are different
+      // depending on the version being used
+      if (cmpSparkVersion(3,1,1) >= 0) {
+        checkSkewJoin(leftSmj, 2, 0)
+      } else {
+        checkSkewJoin(leftSmj, 1, 0)
+      }
     }
   }

@@ -455,26 +467,22 @@ class AdaptiveQueryExecSuite
   }

   /** most of the AQE tests requires Spark 3.0.1 or later */
-  private def assumeSpark301orLater = {
-    val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
-    val isValidTestForSparkVersion = sparkShimVersion match {
-      case SparkShimVersion(3, 0, 0) => false
-      case DatabricksShimVersion(3, 0, 0) => false
-      case _ => true
-    }
-    assume(isValidTestForSparkVersion, "SPARK 3.1.0 or later required")
-  }
+  private def assumeSpark301orLater =
+    assume(cmpSparkVersion(3, 0, 1) >= 0)
+
+  private def assumePriorToSpark320 =
+    assume(cmpSparkVersion(3, 2, 0) < 0)

-  private def assumePriorToSpark320 = {
+  private def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
     val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
-    val isValidTestForSparkVersion = sparkShimVersion match {
-      case ver: SparkShimVersion =>
-        (ver.major == 3 && ver.minor < 2) || ver.major < 3
-      case ver: DatabricksShimVersion =>
-        (ver.major == 3 && ver.minor < 2) || ver.major < 3
-      case _ => true
+    val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match {
+      case SparkShimVersion(a, b, c) => (a, b, c)
+      case DatabricksShimVersion(a, b, c) => (a, b, c)
+      case EMRShimVersion(a, b, c) => (a, b, c)
     }
-    assume(isValidTestForSparkVersion, "Prior to SPARK 3.2.0 required")
+    val fullVersion = ((major.toLong * 1000) + minor) * 1000 + bugfix
+    val sparkFullVersion = ((sparkMajor.toLong * 1000) + sparkMinor) * 1000 + sparkBugfix
+    sparkFullVersion.compareTo(fullVersion)
   }
   def checkSkewJoin(
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala
index 2f652c9b46a..73889400f9a 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala
@@ -45,7 +45,7 @@ class GpuSinglePartitioningSuite extends FunSuite with Arm {
       .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none")
     TestUtils.withGpuSparkSession(conf) { _ =>
       GpuShuffleEnv.init(new RapidsConf(conf))
-      val partitioner = GpuSinglePartitioning(Nil)
+      val partitioner = GpuSinglePartitioning
       withResource(buildBatch()) { batch =>
         withResource(GpuColumnVector.from(batch)) { table =>
           withResource(table.contiguousSplit()) { contigTables =>
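A closing note on the cmpSparkVersion helper introduced in AdaptiveQueryExecSuite above: it packs (major, minor, bugfix) into a single long so that versions compare numerically. A standalone sketch of the same packing, which assumes minor and bugfix stay below 1000 just as the helper does:

```scala
// Pack a three-part version into one long so comparisons are a single
// numeric compare: 3.1.1 -> 3001001, 3.0.2 -> 3000002, and so on.
def packVersion(major: Int, minor: Int, bugfix: Int): Long =
  ((major.toLong * 1000) + minor) * 1000 + bugfix

// cmpSparkVersion(3, 1, 1) >= 0 is then equivalent to
// packVersion(sparkMajor, sparkMinor, sparkBugfix) >= packVersion(3, 1, 1).
assert(packVersion(3, 1, 1) > packVersion(3, 0, 2))
assert(packVersion(3, 1, 1) < packVersion(3, 2, 0))
```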