diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 966102ea9cc..91e4c9ae3a6 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -187,13 +187,13 @@ Accelerator supports are described below.
<td>S</td>
<td>S*</td>
<td>S</td>
-<td>NS</td>
+<td>S*</td>
<td>S</td>
<td>NS</td>
<td>NS</td>
-<td>PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)</td>
-<td>PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)</td>
-<td>PS* (missing nested DECIMAL, BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested BINARY, CALENDAR, UDT)</td>
<td>NS</td>
</tr>
<tr>
@@ -486,13 +486,13 @@ Accelerator supports are described below.
<td>S</td>
<td>S*</td>
<td>S</td>
+<td>S*</td>
<td>NS</td>
<td>NS</td>
<td>NS</td>
-<td>NS</td>
-<td>PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)</td>
-<td>PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)</td>
-<td>PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested NULL, BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested NULL, BINARY, CALENDAR, UDT)</td>
+<td>PS* (missing nested NULL, BINARY, CALENDAR, UDT)</td>
<td>NS</td>
</tr>
<tr>
@@ -16592,13 +16592,13 @@ dates or timestamps, or for a lack of type coercion support.
<td>S</td>
<td>S</td>
<td>S</td>
-<td>NS</td>
+<td>S</td>
<td></td>
<td>NS</td>
<td></td>
-<td>PS (missing nested DECIMAL, BINARY)</td>
-<td>PS (missing nested DECIMAL, BINARY)</td>
-<td>PS (missing nested DECIMAL, BINARY)</td>
+<td>PS (missing nested BINARY)</td>
+<td>PS (missing nested BINARY)</td>
+<td>PS (missing nested BINARY)</td>
</tr>
<tr>
<th>Output</th>
diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index e6553597b33..04b6fcc19c6 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -27,23 +27,30 @@ def read_parquet_df(data_path):
def read_parquet_sql(data_path):
return lambda spark : spark.sql('select * from parquet.`{}`'.format(data_path))
+
+# Override decimal_gens because decimals with a negative scale are not supported when reading Parquet
+decimal_gens = [DecimalGen(), DecimalGen(precision=7, scale=3), DecimalGen(precision=10, scale=10),
+ DecimalGen(precision=9, scale=0), DecimalGen(precision=18, scale=15)]
+
parquet_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)), ArrayGen(byte_gen),
ArrayGen(long_gen), ArrayGen(string_gen), ArrayGen(date_gen),
ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
+ ArrayGen(DecimalGen()),
ArrayGen(ArrayGen(byte_gen)),
- StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen]]),
- ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] + map_gens_sample,
- pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]
+ StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen], ['child3', DecimalGen()]]),
+ ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] +
+ map_gens_sample + decimal_gens,
+ pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]
# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
-original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
-multithreaded_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'MULTITHREADED'}
-coalesce_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'COALESCING'}
+original_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
+multithreaded_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'MULTITHREADED'}
+coalesce_parquet_file_reader_conf = {'spark.rapids.sql.format.parquet.reader.type': 'COALESCING'}
reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
- coalesce_parquet_file_reader_conf]
+ coalesce_parquet_file_reader_conf]
@pytest.mark.parametrize('parquet_gens', parquet_gens_list, ids=idfn)
@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@@ -66,9 +73,9 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@pytest.mark.parametrize('disable_conf', ['spark.rapids.sql.format.parquet.enabled', 'spark.rapids.sql.format.parquet.read.enabled'])
def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
- data_gens =[string_gen,
- byte_gen, short_gen, int_gen, long_gen, boolean_gen]
-
+ data_gens = [string_gen,
+ byte_gen, short_gen, int_gen, long_gen, boolean_gen] + decimal_gens
+
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)]
gen = StructGen(gen_list, nullable=False)
data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -103,8 +110,8 @@ def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list, rea
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
string_gen, date_gen,
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
- # timestamp_gen
- TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+ # timestamp_gen
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
@pytest.mark.parametrize('parquet_gen', parquet_pred_push_gens, ids=idfn)
@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@@ -193,11 +200,27 @@ def test_ts_read_fails_datetime_legacy(gen, spark_tmp_path, ts_write, ts_rebase,
lambda spark : readParquetCatchException(spark, data_path),
conf=all_confs)
+
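+# 'spark.sql.parquet.writeLegacyFormat' stores decimals as FIXED_LEN_BYTE_ARRAY rather than
+# INT32/INT64, so this exercises a different decimal decode path in the GPU reader.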
+@pytest.mark.parametrize('parquet_gens', [decimal_gens], ids=idfn)
+@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
+@pytest.mark.parametrize('reader_confs', reader_opt_confs)
+@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
+def test_decimal_read_legacy(spark_tmp_path, parquet_gens, read_func, reader_confs, v1_enabled_list):
+ gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
+ data_path = spark_tmp_path + '/PARQUET_DATA'
+ with_cpu_session(
+ lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
+ conf={'spark.sql.parquet.writeLegacyFormat': 'true'})
+ all_confs = reader_confs.copy()
+ all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
+ assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs)
+
+
parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
- string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))],
- pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
- pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]
+ string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens,
+ pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
+ pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]
@pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn)
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@@ -221,7 +244,7 @@ def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list, reader_confs):
# we should go with a more standard set of generators
parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
@@ -291,7 +314,7 @@ def test_read_merge_schema(spark_tmp_path, v1_enabled_list, reader_confs):
# we should go with a more standard set of generators
parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
@@ -316,7 +339,7 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, reader_con
# we should go with a more standard set of generators
parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] + decimal_gens
first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
@@ -399,15 +422,15 @@ def test_small_file_memory(spark_tmp_path, v1_enabled_list):
_nested_pruning_schemas = [
- ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+ ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
[["a", StructGen([["c_1", StringGen()]])]]),
- ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+ ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
[["a", StructGen([["c_2", LongGen()]])]]),
- ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+ ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
[["a", StructGen([["c_3", ShortGen()]])]]),
- ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+ ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
[["a", StructGen([["c_1", StringGen()], ["c_3", ShortGen()]])]]),
- ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
+ ([["a", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
[["a", StructGen([["c_3", ShortGen()], ["c_2", LongGen()], ["c_1", StringGen()]])]]),
([["ar", ArrayGen(StructGen([["str_1", StringGen()],["str_2", StringGen()]]))]],
[["ar", ArrayGen(StructGen([["str_2", StringGen()]]))]])
diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index d5c7a718ec0..902f76d7ea2 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -147,8 +147,9 @@ class Spark300Shims extends SparkShims {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
- TypeSig.ARRAY).nested(), TypeSig.all),
+ TypeSig.ARRAY + TypeSig.DECIMAL).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
+
// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
index 7cd54d0d55a..1164319b4f5 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
@@ -134,7 +134,7 @@ class Spark310Shims extends Spark301Shims {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
- TypeSig.ARRAY).nested(), TypeSig.all),
+ TypeSig.ARRAY + TypeSig.DECIMAL).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
index 01687c4a5cf..0b096ddf781 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{DateType, DecimalType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration
@@ -229,6 +229,10 @@ object GpuCSVScan {
}
}
// TODO parsedOptions.emptyValueInRead
+
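+    // GPU CSV parsing cannot produce decimal columns yet, so any DecimalType in the read
+    // schema forces the scan back onto the CPU.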
+ if (readSchema.exists(_.dataType.isInstanceOf[DecimalType])) {
+ meta.willNotWorkOnGpu("DecimalType is not supported")
+ }
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index e4a0f10359a..cdb14c06aa4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -115,7 +115,7 @@ object GpuOrcScanBase {
meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
}
schema.foreach { field =>
- if (!GpuColumnVector.isNonNestedSupportedType(field.dataType)) {
+ if (!GpuOverrides.isSupportedType(field.dataType)) {
meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 3ac15e43902..88a2cb513b8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python._
@@ -2201,7 +2202,8 @@ object GpuOverrides {
exec[BatchScanExec](
"The backend for most file input",
ExecChecks(
- (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY).nested(),
+ (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
+ TypeSig.DECIMAL).nested(),
TypeSig.all),
(p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
override val childScans: scala.Seq[ScanMeta[_]] =
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 746f49cb4d2..9737ff70d2e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -139,7 +139,8 @@ object GpuParquetScanBase {
allowMaps = true,
allowArray = true,
allowStruct = true,
- allowNesting = true)) {
+ allowNesting = true,
+ allowDecimal = meta.conf.decimalTypeEnabled)) {
meta.willNotWorkOnGpu(s"GpuParquetScan does not support fields of type ${field.dataType}")
}
}
@@ -197,6 +198,33 @@ object GpuParquetScanBase {
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}
}
+
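+  // Parquet can store decimals with precision <= 9 as INT32, which cuDF decodes into
+  // DECIMAL32 columns; the plugin only operates on DECIMAL64, so such columns must be
+  // widened before they are wrapped into GpuColumnVectors.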
+ private[rapids] def convertDecimal32Columns(t: Table): Table = {
+ val containDecimal32Column = (0 until t.getNumberOfColumns).exists { i =>
+ t.getColumn(i).getType.getTypeId == DType.DTypeEnum.DECIMAL32
+ }
+    // return the input table directly if it contains no DECIMAL32 columns
+ if (!containDecimal32Column) return t
+
+ val columns = new Array[ColumnVector](t.getNumberOfColumns)
+ try {
+ RebaseHelper.withResource(t) { _ =>
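+        // the input table is released by withResource after each column has either been
+        // cast or had its reference count bumped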
+ (0 until t.getNumberOfColumns).foreach { i =>
+ t.getColumn(i).getType match {
+ case tpe if tpe.getTypeId == DType.DTypeEnum.DECIMAL32 =>
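+            // widen the storage to DECIMAL64, carrying the scale over unchanged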
+ columns(i) = t.getColumn(i).castTo(
+ DType.create(DType.DTypeEnum.DECIMAL64, tpe.getScale))
+ case _ =>
+ columns(i) = t.getColumn(i).incRefCount()
+ }
+ }
+ }
+ new Table(columns: _*)
+ } finally {
+ // clean temporary column vectors
+ columns.safeClose()
+ }
+ }
}
/**
@@ -657,13 +685,16 @@ abstract class FileParquetPartitionReaderBase(
inputTable: Table,
filePath: String,
clippedSchema: MessageType): Table = {
- if (readDataSchema.length > inputTable.getNumberOfColumns) {
+ // Convert Decimal32 columns to Decimal64, because spark-rapids only supports Decimal64.
+ val inTable = GpuParquetScanBase.convertDecimal32Columns(inputTable)
+
+ if (readDataSchema.length > inTable.getNumberOfColumns) {
// Spark+Parquet schema evolution is relatively simple with only adding/removing columns
// No type casting or anything like that
val clippedGroups = clippedSchema.asGroupType()
val newColumns = new Array[ColumnVector](readDataSchema.length)
try {
- withResource(inputTable) { table =>
+ withResource(inTable) { table =>
var readAt = 0
(0 until readDataSchema.length).foreach(writeAt => {
val readField = readDataSchema(writeAt)
@@ -686,7 +717,7 @@ abstract class FileParquetPartitionReaderBase(
newColumns.safeClose()
}
} else {
- inputTable
+ inTable
}
}
@@ -1040,6 +1071,7 @@ class MultiFileParquetPartitionReader(
}
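+    // Strict decimal mode is assumed to make cuDF raise an error for decimal columns it
+    // cannot read back exactly, rather than decoding them as a different type.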
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+ .enableStrictDecimalType(true)
.includeColumn(readDataSchema.fieldNames:_*).build()
// about to start using the GPU
@@ -1429,6 +1461,7 @@ class MultiFileCloudParquetPartitionReader(
}
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+ .enableStrictDecimalType(true)
.includeColumn(readDataSchema.fieldNames: _*).build()
// about to start using the GPU
@@ -1564,6 +1597,7 @@ class ParquetPartitionReader(
}
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
+ .enableStrictDecimalType(true)
.includeColumn(readDataSchema.fieldNames:_*).build()
// about to start using the GPU
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
index e9ad8324307..295655972f6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
@@ -127,6 +127,7 @@ object HostColumnarToGpu {
if (cv.isNullAt(i)) {
b.appendNull()
} else {
+            // The precision here matters for CPU column vectors (such as OnHeapColumnVector),
+            // which use it to decide which underlying storage the unscaled value is read from.
b.append(cv.getDecimal(i, dt.precision, dt.scale).toUnscaledLong)
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 231f4a51021..c0badbed029 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1409,13 +1409,13 @@ object SupportedOpsDocs {
println("<td>S</td>") // DATE
println("<td>S</td>") // TIMESTAMP
println("<td>S</td>") // STRING
-println("<td>NS</td>") // DECIMAL
+println("<td>S</td>") // DECIMAL
println("<td></td>") // NULL
println("<td>NS</td>") // BINARY
println("<td></td>") // CALENDAR
-println("<td>PS (missing nested DECIMAL, BINARY)</td>") // ARRAY
-println("<td>PS (missing nested DECIMAL, BINARY)</td>") // MAP
-println("<td>PS (missing nested DECIMAL, BINARY)</td>") // STRUCT
+println("<td>PS (missing nested BINARY)</td>") // ARRAY
+println("<td>PS (missing nested BINARY)</td>") // MAP
+println("<td>PS (missing nested BINARY)</td>") // STRUCT
println("</tr>")
println("<tr>")
println("<th>Output</th>")
diff --git a/tests/src/test/resources/decimal-test-legacy.parquet b/tests/src/test/resources/decimal-test-legacy.parquet
new file mode 100644
index 00000000000..969f049848f
Binary files /dev/null and b/tests/src/test/resources/decimal-test-legacy.parquet differ
diff --git a/tests/src/test/resources/decimal-test.csv b/tests/src/test/resources/decimal-test.csv
new file mode 100644
index 00000000000..1fd4646904c
--- /dev/null
+++ b/tests/src/test/resources/decimal-test.csv
@@ -0,0 +1,100 @@
+915270249210239718,3232.792,"",771.371173049837,-431710025170174585,4.7239953E24
+50004804273312941,9263.400,0.3815050595,900.890874730220,-2697073954890740236,7.463101E-31
+189216077028719828,8094.536,0.7879423094,817.584127883600,-5133656973475552689,2.4919543E-32
+257886722013221592,7097.760,0.5551530993,263.532647425188,-3917032101531217289,5.176538E7
+223616015091255874,2021.691,0.4300789101,741.876603938996,7810001276519378488,6.673943E9
+672269487046710336,5925.644,0.5119009979,677.273069670187,"",-3.037228E-26
+565739933980371374,2082.917,0.6325578476,307.790881040462,-7858370133784586516,-1.7402571E-20
+30846051239772526,2153.584,"",428.457572476468,2224676899470531349,-3.720298E-28
+292750338676616377,913.012,0.7647953732,329.427928877243,1287002480498025462,963.7641
+863597087074282135,7771.993,0.9378547775,655.942008799151,-4594481394522420980,-5.143567E-37
+740916466628308680,3601.327,0.5253287765,293.314491660678,3663204060287939255,-7.328485E30
+543922875310806589,418.135,0.0089806121,678.725898059834,7100991336799471692,""
+907927561275971206,4357.321,0.0894175775,985.715625304095,8980271347953950594,2.036806E-11
+5365651368778670,28.443,0E-10,98.984414373371,2753461187700946189,""
+98975573402945273,5544.347,0.3277681904,780.987096259467,-2821373815087770381,""
+37654707856660412,3245.419,0.7249053714,"","",5.489915E-15
+417692898406404444,5137.551,0.0654631387,8.906002611085,-3073072854586574157,2.8559989E-5
+30008375870039084,1638.013,0.5689759921,506.427830699217,8001350543913869500,0.0010491271
+267647485487795321,711.758,0.7086510671,787.214975034435,0,-4.294404E35
+32497433138887044,1500.500,0.1632851254,111.862479713755,-2124845351861490171,-2911376.8
+563219648467491123,2533.799,0.0470860444,-999.999999999999,-6150492885085058765,-3.1956E-21
+698244577193394337,7710.613,0.1796242319,177.658429265869,-7293294899695747856,1.55815667E12
+59984767260993922,9598.853,0.9606575721,827.362102857851,7956328026836954420,-2.4386406E11
+185053572335775931,1504.547,0.5687599560,418.547722108074,-8496278226268642122,-1.1115532E-32
+406076905698949206,1584.821,0.9073638746,454.572717017033,9048728756392488113,-3.6374905E-7
+370626360548153088,6979.091,0.7359568576,847.384858615533,4440968842303227453,3.9915457E-28
+880266977576630903,3969.112,0.0536779482,32.241735406630,6086331466042394437,-2.1541503E-33
+914944620826734702,258.240,0.7421725875,413.873958949421,-7351387047220061849,-8.1657964E-36
+615190575294467984,8101.993,0.4021261498,440.941554482483,2251193521986836658,1.7728261E25
+748719415383076949,928.200,0.0701185195,307.250677745230,-255208803392541585,""
+"",8482.048,0.0875250620,889.559939012387,131095898465527538,-0.0
+541457418990604771,298.713,"",356.319195755119,8385038430064196353,-1.1713938E26
+998278945156788331,3126.578,0.0190951652,277.824073520913,5884635015085662155,2.7177546E22
+648685475062397368,955.643,0.4443410992,569.890024734869,-5773983929953240007,3.5561367E32
+128835062166543528,6830.727,"",967.668187303342,-7732708842585057603,-1.69050099E9
+979181842441473467,2627.146,0.0931873144,308.242155826852,6658806232357759343,-3.0718065E22
+373383053238313709,5880.783,0.8089926358,288.091654931223,2449340209240679641,1.8817666E-34
+369917108588895763,"",0.0784246739,419.063742180292,5469773751425577039,6.3262706E-10
+959109271119159898,5963.914,0.5131396358,"",-3786594002235441489,-1.0057181E33
+775063852272099258,0.000,0.1505213451,419.677059264876,-8824564747028895090,5.758003E-13
+495443091472185622,5536.609,0.6390268097,739.197371445563,1760002372390556825,6816057.5
+25146710634437912,3067.206,0.2161656375,215.056698028213,803782761142075607,-54089.32
+613858510238074126,957.490,0.4822075927,518.956895534893,2437026550498167082,-2.4735995E-5
+255794834605228757,899.249,0.8092193887,588.233634837773,-525199667559164571,""
+680825620982324925,5517.827,0.0065746462,193.251744967577,-5041734557262022470,-1485983.4
+714779301951357809,1711.747,0.6136524686,686.957950228997,-6839627708332633308,""
+381520952215527508,5424.226,0.2305893771,549.262462812814,-2059623267661777620,-3.7695377E-10
+669984683600664919,9136.075,0.7850895020,209.124878112725,-2710360188048040119,4.9402314E28
+182598193716752431,2503.031,0.0070287326,3.249783083502,-6633320123827272048,4.1270598E-11
+"",2596.200,0.8936980935,85.288926615679,662259464231680319,""
+40023869813864036,233.012,0.5436629243,171.799250634543,-3664974252843436884,-1.1522237E21
+710322854240541412,4019.609,"",446.549422285354,8834022966795744609,-1.00241495E-20
+841507640518767629,5834.872,0.4100687936,369.576552043052,6075221653337964625,-3.9562905E-11
+857872153237373925,4796.207,0.0756528306,575.779939808894,3893516458893827324,3.1058907E-18
+25775526817610369,548.169,0.0669838937,717.645262020503,-1,-2.9235664E-31
+25215292382715618,"",0.8464787764,632.667903196574,3221645619906578280,3.8987961E-19
+47171113226028385,8168.750,0.9055021741,35.058703401277,8776159873495597953,-3.7015653E-27
+943951645776822973,8443.481,0.3052030122,807.819184192606,-199240735297494389,2.9509407E32
+650796227970751779,6957.306,0.3522180815,680.869699320152,5264124301174436230,-8.20616E-30
+750204382327921920,"",0.9373742004,143.248267456298,-6228871387240662005,7.737559E-19
+612727546233536520,2940.034,0.5992235971,277.240099823954,9223372036854775807,-1.4193973E-5
+248544869063573270,4405.459,0.4938911635,185.838164288601,3780377771042673290,-4.794553E-20
+"",7286.797,0.0525265125,821.424386490011,8226475382227724456,-1.0
+797023704927729407,8307.766,0.3520962355,"",2370991850373366052,-9.2752095E-33
+733656328652598115,9387.470,"",292.371386489063,-772608278816658373,6.0993237
+456487310339748192,"",0.8716752095,759.109542616645,-8516846031062369157,-1.9623711E-20
+12735651342573944,463.904,0.7174409373,999.999999999999,8704401706019656847,1.6802378E27
+503142435912090290,4817.213,0.3841225840,9.127318172586,1196628311801793151,1.0
+227184013575342512,5661.152,0.3130044658,529.075843748348,-5761755906933392982,-1.2703339E27
+"",1738.648,0.1635048458,456.258113062836,1937750570385130232,-547.00464
+786474267467256790,9999.999,0.7247985984,246.846730055841,1,221027.05
+5722907039468293,1494.945,0.4833744532,521.189835443092,"",79.40845
+498893766826535453,-9999.999,0.8314235080,448.448825006248,-4757790292222279178,-3.1900585E26
+746704316885770352,7426.798,0.0697542832,933.708683277599,8123776869365306368,8367.281
+242857837293930678,72.927,0.4892817062,593.300614078262,1525346850676520527,2.4633369E-28
+267180693163965584,5567.777,0.1129542164,862.736499156329,-7943660520670333047,-2.5013383E34
+863179009301888915,"",0.5913615439,559.544347644078,-6218179828993870307,""
+628941572681954168,672.616,0.1120588881,726.724012853591,8295107456982566271,-3.601133E-26
+195085765272922078,3267.147,0.6227582353,942.184595963346,-7901787385806075056,5.208774E-15
+54795121483994580,7573.962,0.0232430822,530.957948073184,2519241821679825052,2.3091825E27
+11025911078984278,9080.678,0.9256236807,785.967159932085,3898714205093691132,-5.299004E-27
+585416995369880382,3180.974,0.6636002892,399.741020816154,8287847947794918753,-4.328721E-24
+93056155170236349,9851.699,0.4570285677,870.296415087233,7848530628017602134,1.6895455E-36
+273518414773911153,541.209,0.0030913568,794.949286358568,-9209970347163780353,9.984763E-9
+950579123902836569,2958.730,"","",-8408605963020785199,-152976.69
+294812000018063416,5947.866,0.3776552591,71.442236681325,"",1.89517168E8
+999999999999999999,4313.357,0.0052879447,534.027817103663,-1499033162028359415,3.3824432E-10
+47797026331225179,2523.889,0.0202628702,283.620687243088,9223372036854775807,-9.904269E28
+945353911554496947,9714.635,0.5445525623,577.494999825168,-4851177952341996882,5.4106984E7
+"",706.236,0.0574360283,868.681442920555,0,-9.979879E32
+362850502113699515,3648.749,0.4765617370,751.474476720116,4674971002257188808,1.8771697E32
+263209284245630819,9532.742,0.5552251822,191.831875430596,-2824608313884610115,1.6279153E-21
+553774330947428625,166.791,0.9306022422,79.427117038552,-5012912027038200187,9.0646643E10
+607563227695541003,4512.145,0.6783432188,901.316884058369,7603145021478615614,NaN
+159291016019700529,5240.183,0.7258544713,350.623491123040,-1,-3.2958715E-15
+619900339634591041,7564.722,0.1887079611,994.209902879431,-6112563531377075787,-2.2229757E-33
+831245367052238890,4229.978,0.5127881028,639.349614940198,7280508167761829381,1.8915695E21
+868372001868812448,7565.051,0.1726353517,344.897092976059,1197973862304902753,-3.4028235E38
+"",8757.507,0.8479961242,429.256072226970,-503988022213926193,1.22784955E-32
+590855156246510004,1298.748,0.3022440975,961.920179785774,-4325271223339769315,-7.483853E-22
diff --git a/tests/src/test/resources/decimal-test.orc b/tests/src/test/resources/decimal-test.orc
new file mode 100644
index 00000000000..8396738735f
Binary files /dev/null and b/tests/src/test/resources/decimal-test.orc differ
diff --git a/tests/src/test/resources/decimal-test.parquet b/tests/src/test/resources/decimal-test.parquet
new file mode 100644
index 00000000000..2d029c704be
Binary files /dev/null and b/tests/src/test/resources/decimal-test.parquet differ
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
index 14cf270bc93..0942b2c0346 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetScanSuite.scala
@@ -16,7 +16,11 @@
package com.nvidia.spark.rapids
+import java.io.File
+import java.nio.file.Files
+
import org.apache.spark.SparkConf
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
class ParquetScanSuite extends SparkQueryCompareTestSuite {
@@ -38,4 +42,18 @@ class ParquetScanSuite extends SparkQueryCompareTestSuite {
frameFromParquet("timestamp-date-test.parquet")) {
frame => frame.select(col("*"))
}
+
+ // Column schema of decimal-test.parquet is: [c_0: decimal(18, 0), c_1: decimal(7, 3),
+ // c_2: decimal(10, 10), c_3: decimal(15, 12), c_4: int64, c_5: float]
+ testSparkResultsAreEqual("Test Parquet decimal stored as INT32/64",
+ frameFromParquet("decimal-test.parquet")) {
+ frame => frame.select(col("*"))
+ }
+
+ // Column schema of decimal-test-legacy.parquet is: [c_0: decimal(18, 0), c_1: decimal(7, 3),
+ // c_2: decimal(10, 10), c_3: decimal(15, 12), c_4: int64, c_5: float]
+ testSparkResultsAreEqual("Test Parquet decimal stored as FIXED_LEN_BYTE_ARRAY",
+ frameFromParquet("decimal-test-legacy.parquet")) {
+ frame => frame.select(col("*"))
+ }
}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
index 7bd4467a6ae..f5886b58d40 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala
@@ -21,10 +21,16 @@ import java.math.RoundingMode
import scala.util.Random
import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
-import com.nvidia.spark.rapids.{GpuAlias, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, HostColumnarToGpu, RapidsConf}
+import com.nvidia.spark.rapids.{GpuAlias, GpuBatchScanExec, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, HostColumnarToGpu, RapidsConf}
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.GpuFileSourceScanExec
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StructField, StructType}
class DecimalUnitTest extends GpuUnitTests {
Random.setSeed(1234L)
@@ -260,4 +266,34 @@ class DecimalUnitTest extends GpuUnitTests {
}
}
}
+
+ test("test type checking of Scans") {
+ val conf = new SparkConf().set(RapidsConf.DECIMAL_TYPE_ENABLED.key, "true")
+ .set(RapidsConf.TEST_ALLOWED_NONGPU.key, "BatchScanExec,ColumnarToRowExec,FileSourceScanExec")
+ val decimalCsvStruct = StructType(Array(
+ StructField("c_0", DecimalType(18, 0), true),
+ StructField("c_1", DecimalType(7, 3), true),
+ StructField("c_2", DecimalType(10, 10), true),
+ StructField("c_3", DecimalType(15, 12), true),
+ StructField("c_4", LongType, true),
+ StructField("c_5", IntegerType, true)))
+
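+    // With the default V1 source list, the ORC and CSV scans are expected to fall back to
+    // the CPU (FileSourceScanExec) since they cannot read decimals on the GPU yet, while
+    // the Parquet scan should be replaced by GpuFileSourceScanExec.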
+ withGpuSparkSession((ss: SparkSession) => {
+ var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
+ rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
+ rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuFileSourceScanExec]))
+ }, conf)
+
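+    // Emptying spark.sql.sources.useV1SourceList switches to the V2 read path: the CPU
+    // fallbacks now appear as BatchScanExec and the GPU Parquet scan as GpuBatchScanExec.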
+ withGpuSparkSession((ss: SparkSession) => {
+ var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
+ rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
+ rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
+ assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuBatchScanExec]))
+ }, conf.set(SQLConf.USE_V1_SOURCE_LIST.key, ""))
+ }
}