From ff5339575fde0fc65fe607ef7ae8d5f73cacdb1a Mon Sep 17 00:00:00 2001
From: sperlingxx
Date: Wed, 4 Nov 2020 20:24:17 +0800
Subject: [PATCH] Introduce DecimalType and decimal scalar

Signed-off-by: sperlingxx
---
 .../ParquetCachedBatchSerializer.scala        |  2 +-
 .../nvidia/spark/rapids/GpuColumnVector.java  | 13 ++++
 .../com/nvidia/spark/rapids/GpuOrcScan.scala  | 13 +++-
 .../nvidia/spark/rapids/GpuOverrides.scala    |  1 +
 .../nvidia/spark/rapids/GpuParquetScan.scala  | 12 +++-
 .../com/nvidia/spark/rapids/literals.scala    | 18 +++++
 .../spark/rapids/unit/DecimalUnitTest.scala   | 71 +++++++++++++++++++
 7 files changed, 126 insertions(+), 4 deletions(-)
 create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala

diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
index c607a440bb45..7e2abb7a744a 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
@@ -267,7 +267,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
   }

   def isSupportedByCudf(schema: Seq[Attribute]): Boolean = {
-    schema.forall(a => GpuColumnVector.isSupportedType(a.dataType))
+    schema.forall(a => GpuParquetScanBase.isSupportedType(a.dataType))
   }

   /**
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index 49ed06f848a9..d08aadb5f85e 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -158,6 +158,15 @@ private static final DType toRapidsOrNull(DataType type) {
       return DType.TIMESTAMP_MICROSECONDS;
     } else if (type instanceof StringType) {
       return DType.STRING;
+    } else if (type instanceof DecimalType) {
+      DecimalType decType = (DecimalType) type;
+      if (decType.precision() <= DType.DECIMAL32_MAX_PRECISION) {
+        return DType.create(DType.DTypeEnum.DECIMAL32, -decType.scale());
+      } else if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) {
+        return DType.create(DType.DTypeEnum.DECIMAL64, -decType.scale());
+      } else {
+        return null;
+      }
     }
     return null;
   }
@@ -201,6 +210,10 @@ static final DataType getSparkType(DType type) {
       return DataTypes.TimestampType;
     case STRING:
       return DataTypes.StringType;
+    case DECIMAL32:
+      return new DecimalType(DType.DECIMAL32_MAX_PRECISION, -type.getScale());
+    case DECIMAL64:
+      return new DecimalType(DType.DECIMAL64_MAX_PRECISION, -type.getScale());
     default:
       throw new IllegalArgumentException(type + " is not supported by spark yet.");
     }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index 4a6fc01310e0..42a37fa35564 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.OrcFilters
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration

@@ -113,11 +113,20 @@ object GpuOrcScanBase {
       meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
     }
     schema.foreach { field =>
-      if (!GpuColumnVector.isSupportedType(field.dataType)) {
+      if (!isSupportedType(field.dataType)) {
         meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
       }
     }
   }
+  // We need this specialized type check method because reading and writing
+  // ORC data with decimal columns is not supported by cuDF yet.
+  def isSupportedType(dataType: DataType): Boolean = {
+    GpuColumnVector.isSupportedType(dataType) match {
+      case false => false
+      case true if dataType.isInstanceOf[DecimalType] => false
+      case _ => true
+    }
+  }
 }

 case class GpuOrcPartitionReaderFactory(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 6410b2d297f4..bd34739f6b4b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -450,6 +450,7 @@ object GpuOverrides {
     case DateType => true
     case TimestampType => ZoneId.systemDefault().normalized() == GpuOverrides.UTC_TIMEZONE_ID
     case StringType => true
+    case dt: DecimalType if dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION => true
     case _ => false
   }

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 9e024b4633c9..b23afa34924e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -62,7 +62,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.InputFileUtils
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{MapType, StringType, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DecimalType, MapType, StringType, StructType, TimestampType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration

@@ -175,6 +175,16 @@ object GpuParquetScanBase {
         meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
     }
   }
+
+  // We need this specialized type check method because reading and writing
+  // Parquet data with decimal columns is not supported by cuDF yet.
+  def isSupportedType(dataType: DataType): Boolean = {
+    GpuColumnVector.isSupportedType(dataType) match {
+      case false => false
+      case true if dataType.isInstanceOf[DecimalType] => false
+      case _ => true
+    }
+  }
 }

 /**
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
index 15eaf7a6bcf9..61ae70eef892 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
@@ -65,6 +65,7 @@ object GpuScalar {
     case DType.TIMESTAMP_DAYS => v.getInt
     case DType.TIMESTAMP_MICROSECONDS => v.getLong
     case DType.STRING => v.getJavaString
+    case dt: DType if dt.isDecimalType => v.getBigDecimal
     case t => throw new IllegalStateException(s"$t is not a supported rapids scalar type yet")
   }

@@ -88,12 +89,29 @@ object GpuScalar {
     case b: Boolean => Scalar.fromBool(b)
     case s: String => Scalar.fromString(s)
     case s: UTF8String => Scalar.fromString(s.toString)
+    case dec: BigDecimal => Scalar.fromBigDecimal(dec.bigDecimal)
     case _ =>
       throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
   }

   def from(v: Any, t: DataType): Scalar = v match {
     case _ if v == null => Scalar.fromNull(GpuColumnVector.getRapidsType(t))
+    case _ if t.isInstanceOf[DecimalType] =>
+      var bigDec = v match {
+        case vv: BigDecimal => vv.bigDecimal
+        case vv: Double => BigDecimal(vv).bigDecimal
+        case vv: Float => BigDecimal(vv).bigDecimal
+        case vv: String => BigDecimal(vv).bigDecimal
+        case vv: Long => BigDecimal(vv).bigDecimal
+        case vv: Int => BigDecimal(vv).bigDecimal
+        case vv => throw new IllegalStateException(
+          s"${vv.getClass} '${vv}' is not supported as a scalar yet")
+      }
+      bigDec = bigDec.setScale(t.asInstanceOf[DecimalType].scale)
+      if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) {
+        throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t")
+      }
+      Scalar.fromBigDecimal(bigDec)
     case l: Long => t match {
       case LongType => Scalar.fromLong(l)
       case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
new file mode 100644
index 000000000000..22e32e993c9d
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.unit
+
+import java.math.{BigDecimal => BigDec}
+
+import scala.util.Random
+
+import ai.rapids.cudf.DType
+import com.nvidia.spark.rapids.{GpuScalar, GpuUnitTests}
+import org.scalatest.Matchers
+
+import org.apache.spark.sql.types.DecimalType
+
+class DecimalUnitTest extends GpuUnitTests with Matchers {
+  Random.setSeed(1234L)
+
+  private val dec32Data = Array.fill[BigDecimal](10)(
+    BigDecimal(Random.nextInt() / 10, Random.nextInt(5)))
+  private val dec64Data = Array.fill[BigDecimal](10)(
+    BigDecimal(Random.nextLong() / 1000, Random.nextInt(10)))
+
+  test("test decimal as scalar") {
+    Array(dec32Data, dec64Data).flatten.foreach { dec =>
+      // test GpuScalar.from(v: Any)
+      withResource(GpuScalar.from(dec)) { s =>
+        s.getType.getScale shouldEqual -dec.scale
+        GpuScalar.extract(s).asInstanceOf[BigDec] shouldEqual dec.bigDecimal
+      }
+      // test GpuScalar.from(v: Any, t: DataType)
+      val dt = DecimalType(DType.DECIMAL64_MAX_PRECISION, dec.scale)
+      val dbl = dec.doubleValue()
+      withResource(GpuScalar.from(dbl, dt)) { s =>
+        s.getType.getScale shouldEqual -dt.scale
+        GpuScalar.extract(s).asInstanceOf[BigDec].doubleValue() shouldEqual dbl
+      }
+      val str = dec.toString()
+      withResource(GpuScalar.from(str, dt)) { s =>
+        s.getType.getScale shouldEqual -dt.scale
+        GpuScalar.extract(s).asInstanceOf[BigDec].toString shouldEqual str
+      }
+      val long = dec.longValue()
+      withResource(GpuScalar.from(long, DecimalType(DType.DECIMAL64_MAX_PRECISION, 0))) { s =>
+        s.getType.getScale shouldEqual 0
+        GpuScalar.extract(s).asInstanceOf[BigDec].longValue() shouldEqual long
+      }
+    }
+    // test exception throwing
+    assertThrows[IllegalStateException] {
+      withResource(GpuScalar.from(true, DecimalType(10, 1))) { _ => }
+    }
+    assertThrows[IllegalArgumentException] {
+      val bigDec = BigDecimal(Long.MaxValue / 100, 0)
+      withResource(GpuScalar.from(bigDec, DecimalType(15, 1))) { _ => }
+    }
+  }
+}
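
Note (illustration only, not part of the patch): Spark's DecimalType records scale as the number of fractional digits, while cuDF's DType records the exponent of the scaling factor, which is why every conversion above negates the scale. Below is a minimal Scala sketch of the intended round trip through the new GpuScalar paths, assuming the patch is applied; the object name DecimalScalarSketch is hypothetical, and the DECIMAL32 expectation relies on Scalar.fromBigDecimal choosing the narrowest decimal type that fits, as the precision bounds in toRapidsOrNull suggest.

// Hypothetical sketch: round-trip one decimal value through GpuScalar.
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuScalar
import org.apache.spark.sql.types.DecimalType

object DecimalScalarSketch {
  def main(args: Array[String]): Unit = {
    // DecimalType(9, 2): at most 9 total digits, 2 of them fractional.
    val dt = DecimalType(9, 2)
    val s = GpuScalar.from(BigDecimal("1234567.89"), dt)
    try {
      // The patch negates Spark's scale when building the cuDF scalar ...
      assert(s.getType.getScale == -dt.scale)
      // ... and precision 9 is within DECIMAL32_MAX_PRECISION, so the value
      // should be backed by a 32-bit decimal (assumption, see note above).
      assert(s.getType.getTypeId == DType.DTypeEnum.DECIMAL32)
      // GpuScalar.extract returns the underlying java.math.BigDecimal.
      assert(GpuScalar.extract(s) == BigDecimal("1234567.89").bigDecimal)
    } finally {
      s.close()
    }
  }
}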