Skip to content

Commit

Permalink
verify shuffle of decimal type data (#1193)
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
  • Loading branch information
sperlingxx authored Dec 1, 2020
1 parent d3d3456 commit 84b5540
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
package com.nvidia.spark.rapids

import java.io.File
import java.math.RoundingMode

import ai.rapids.cudf.{DType, Table}
import org.scalatest.FunSuite

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuPartitioningSuite extends FunSuite with Arm {
Expand All @@ -33,8 +34,11 @@ class GpuPartitioningSuite extends FunSuite with Arm {
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1, 1, 1, 1, 1, 1, 1)
.column("five", "two", null, null, "one", "one", "one", "one", "one", "one")
.column(5.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
.decimal64Column(-3, RoundingMode.UNNECESSARY ,
5.1, null, 3.3, 4.4e2, 0, -2.1e-1, 1.111, 2.345, null, 1.23e3)
.build()) { table =>
GpuColumnVector.from(table, Array(IntegerType, StringType, DoubleType))
GpuColumnVector.from(table, Array(IntegerType, StringType, DoubleType,
DecimalType(DType.DECIMAL64_MAX_PRECISION, 3)))
}
}

Expand All @@ -53,8 +57,14 @@ class GpuPartitioningSuite extends FunSuite with Arm {
val expectedColumns = GpuColumnVector.extractBases(expected)
val actualColumns = GpuColumnVector.extractBases(expected)
expectedColumns.zip(actualColumns).foreach { case (expected, actual) =>
withResource(expected.equalToNullAware(actual)) { compareVector =>
withResource(compareVector.all(DType.BOOL8)) { compareResult =>
// FIXME: For decimal types, NULL_EQUALS has not been supported in cuDF yet
val cpVec = if (expected.getType.isDecimalType) {
expected.equalTo(actual)
} else {
expected.equalToNullAware(actual)
}
withResource(cpVec) { compareVector =>
withResource(compareVector.all()) { compareResult =>
assert(compareResult.getBoolean)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

package com.nvidia.spark.rapids

import java.math.RoundingMode

import ai.rapids.cudf.Table
import org.scalatest.FunSuite

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuSinglePartitioningSuite extends FunSuite with Arm {
Expand All @@ -30,8 +32,11 @@ class GpuSinglePartitioningSuite extends FunSuite with Arm {
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1, 1, 1, 1, 1, 1, 1)
.column("five", "two", null, null, "one", "one", "one", "one", "one", "one")
.column(5.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
.decimal64Column(-3, RoundingMode.UNNECESSARY ,
5.1, null, 3.3, 4.4e2, 0, -2.1e-1, 1.111, 2.345, null, 1.23e3)
.build()) { table =>
GpuColumnVector.from(table, Array(IntegerType, StringType, DoubleType))
GpuColumnVector.from(table, Array(IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 3)))
}
}

Expand Down
18 changes: 11 additions & 7 deletions tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package com.nvidia.spark.rapids

import java.math.RoundingMode

import ai.rapids.cudf.{BufferType, ContiguousTable, DeviceMemoryBuffer, Table}
import com.nvidia.spark.rapids.format.{CodecType, ColumnMeta}
import org.scalatest.FunSuite

import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StringType, StructType}
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class MetaUtilsSuite extends FunSuite with Arm {
Expand All @@ -29,6 +31,7 @@ class MetaUtilsSuite extends FunSuite with Arm {
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1)
.column("five", "two", null, null)
.column(5.0, 2.0, 3.0, 1.0)
.decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123)
.build()) { table =>
table.contiguousSplit()(0)
}
Expand Down Expand Up @@ -109,12 +112,12 @@ class MetaUtilsSuite extends FunSuite with Arm {
}

test("buildDegenerateTableMeta no rows") {
val schema = StructType.fromDDL("a INT, b STRING, c DOUBLE")
val schema = StructType.fromDDL("a INT, b STRING, c DOUBLE, d DECIMAL(15, 5)")
withResource(GpuColumnVector.emptyBatch(schema)) { batch =>
val meta = MetaUtils.buildDegenerateTableMeta(batch)
assertResult(null)(meta.bufferMeta)
assertResult(0)(meta.rowCount)
assertResult(3)(meta.columnMetasLength)
assertResult(4)(meta.columnMetasLength)
(0 until meta.columnMetasLength).foreach { i =>
val columnMeta = meta.columnMetas(i)
assertResult(0)(columnMeta.nullCount)
Expand All @@ -130,7 +133,7 @@ class MetaUtilsSuite extends FunSuite with Arm {
}

test("buildDegenerateTableMeta no rows compressed table") {
val schema = StructType.fromDDL("a INT, b STRING, c DOUBLE")
val schema = StructType.fromDDL("a INT, b STRING, c DOUBLE, d DECIMAL(15, 5)")
withResource(GpuColumnVector.emptyBatch(schema)) { uncompressedBatch =>
val uncompressedMeta = MetaUtils.buildDegenerateTableMeta(uncompressedBatch)
withResource(DeviceMemoryBuffer.allocate(0)) { buffer =>
Expand All @@ -140,7 +143,7 @@ class MetaUtilsSuite extends FunSuite with Arm {
val meta = MetaUtils.buildDegenerateTableMeta(batch)
assertResult(null)(meta.bufferMeta)
assertResult(0)(meta.rowCount)
assertResult(3)(meta.columnMetasLength)
assertResult(4)(meta.columnMetasLength)
(0 until meta.columnMetasLength).foreach { i =>
val columnMeta = meta.columnMetas(i)
assertResult(0)(columnMeta.nullCount)
Expand All @@ -163,9 +166,10 @@ class MetaUtilsSuite extends FunSuite with Arm {
val table = contigTable.getTable
val origBuffer = contigTable.getBuffer
val meta = MetaUtils.buildTableMeta(10, table, origBuffer)
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
withResource(origBuffer.sliceWithCopy(0, origBuffer.getLength)) { buffer =>
withResource(MetaUtils.getBatchFromMeta(buffer, meta,
Array[DataType](IntegerType, StringType, DoubleType))) { batch =>
withResource(MetaUtils.getBatchFromMeta(buffer, meta, sparkTypes)) { batch =>
assertResult(table.getRowCount)(batch.numRows)
assertResult(table.getNumberOfColumns)(batch.numCols)
(0 until table.getNumberOfColumns).foreach { i =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.io.File
import java.math.RoundingMode

import scala.collection.mutable.ArrayBuffer

Expand All @@ -29,14 +30,15 @@ import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType}

class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
private def buildContiguousTable(): ContiguousTable = {
withResource(new Table.TestBuilder()
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1)
.column("five", "two", null, null)
.column(5.0, 2.0, 3.0, 1.0)
.decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123)
.build()) { table =>
table.contiguousSplit()(0)
}
Expand Down Expand Up @@ -106,7 +108,8 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {

test("get column batch") {
val catalog = new RapidsBufferCatalog
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType)
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
val bufferId = MockRapidsBufferId(7)
withResource(buildContiguousTable()) { ct =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.io.File
import java.math.RoundingMode

import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer, Table}
import org.mockito.ArgumentMatchers
Expand All @@ -25,7 +26,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType}

class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm with MockitoSugar {
val TEST_FILES_ROOT: File = TestUtils.getTempDir(this.getClass.getSimpleName)
Expand All @@ -43,6 +44,7 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1)
.column("five", "two", null, null)
.column(5.0, 2.0, 3.0, 1.0)
.decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123)
.build()) { table =>
table.contiguousSplit()(0)
}
Expand Down Expand Up @@ -83,7 +85,8 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit
}

test("get columnar batch") {
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType)
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false)
val bufferPath = bufferId.getDiskPath(null)
assert(!bufferPath.exists)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.io.File
import java.math.RoundingMode

import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, Table}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand All @@ -26,14 +27,15 @@ import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType}

class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
private def buildContiguousTable(): ContiguousTable = {
withResource(new Table.TestBuilder()
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1)
.column("five", "two", null, null)
.column(5.0, 2.0, 3.0, 1.0)
.decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123)
.build()) { table =>
table.contiguousSplit()(0)
}
Expand Down Expand Up @@ -119,7 +121,8 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
}

test("get memory buffer") {
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType)
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
val bufferId = MockRapidsBufferId(7)
val spillPriority = -10
val hostStoreMaxSize = 1L * 1024 * 1024
Expand Down
2 changes: 2 additions & 0 deletions tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ object TestUtils extends Assertions with Arm {
case DType.FLOAT32 => assertResult(e.getFloat(i))(a.getFloat(i))
case DType.FLOAT64 => assertResult(e.getDouble(i))(a.getDouble(i))
case DType.STRING => assertResult(e.getJavaString(i))(a.getJavaString(i))
case dt if dt.isDecimalType && dt.isBackedByLong =>
assertResult(e.getBigDecimal(i))(a.getBigDecimal(i))
case _ => throw new UnsupportedOperationException("not implemented yet")
}
}
Expand Down

0 comments on commit 84b5540

Please sign in to comment.