From a0aeaba8d6e3a71602c32e8dc383fe8ebe34ef52 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Tue, 8 Mar 2022 07:57:21 -0800
Subject: [PATCH] Fixed Decimal 128 bug in ParquetCachedBatchSerializer (#4899)

* Fixed Parquet schema

Signed-off-by: Raza Jafri

* addressed review comments

Signed-off-by: Raza Jafri

* removed unused import

Signed-off-by: Raza Jafri

Co-authored-by: Raza Jafri
---
 .../src/main/python/cache_test.py             |  3 +--
 .../v2/ParquetCachedBatchSerializer.scala     | 21 +++++++++----------
 .../spark/rapids/shims/v2/Spark31XShims.scala |  2 +-
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py
index 0849ba0b5d0..66dd85bf75c 100644
--- a/integration_tests/src/main/python/cache_test.py
+++ b/integration_tests/src/main/python/cache_test.py
@@ -24,8 +24,7 @@
 enable_vectorized_confs = [{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "true"},
                            {"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "false"}]
 
-# cache does not work with 128-bit decimals, see https://github.com/NVIDIA/spark-rapids/issues/4826
-_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit]
+_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit]
 _cache_single_array_gens_no_null = [ArrayGen(gen) for gen in all_basic_gens_no_null + _cache_decimal_gens]
 
 decimal_struct_gen= StructGen([['child0', sub_gen] for ind, sub_gen in enumerate(_cache_decimal_gens)])
diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala
index 383fe646cfc..e4fceba818c 100644
--- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala
+++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala
@@ -53,7 +53,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.columnar.CachedBatch
-import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, SparkToParquetSchemaConverter, VectorizedColumnReader}
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, VectorizedColumnReader}
 import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2.{ParquetRecordMaterializer, ShimVectorizedColumnReader}
 import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector}
 import org.apache.spark.sql.internal.SQLConf
@@ -889,13 +889,9 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
     val inMemCacheSparkSchema = parquetToSparkSchemaConverter.convert(inMemCacheParquetSchema)
 
     val totalRowCount = parquetFileReader.getRowGroups.asScala.map(_.getRowCount).sum
-    val sparkToParquetSchemaConverter = new SparkToParquetSchemaConverter(hadoopConf)
     val inMemReqSparkSchema = StructType(selectedAttributes.toStructType.map { field =>
       inMemCacheSparkSchema.fields(inMemCacheSparkSchema.fieldIndex(field.name))
     })
-    val inMemReqParquetSchema = sparkToParquetSchemaConverter.convert(inMemReqSparkSchema)
-    val columnsRequested: util.List[ColumnDescriptor] = inMemReqParquetSchema.getColumns
-
     val reqSparkSchemaInCacheOrder = StructType(inMemCacheSparkSchema.filter(f =>
       inMemReqSparkSchema.fields.exists(f0 => f0.name.equals(f.name))))
 
@@ -907,23 +903,26 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       index -> inMemReqSparkSchema.fields.indexOf(reqSparkSchemaInCacheOrder.fields(index))
     }.toMap
 
-    val reqParquetSchemaInCacheOrder =
-      sparkToParquetSchemaConverter.convert(reqSparkSchemaInCacheOrder)
+    val reqParquetSchemaInCacheOrder = new org.apache.parquet.schema.MessageType(
+      inMemCacheParquetSchema.getName(), reqSparkSchemaInCacheOrder.fields.map { f =>
+        inMemCacheParquetSchema.getFields().get(inMemCacheParquetSchema.getFieldIndex(f.name))
+      }:_*)
+    val columnsRequested: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
 
     // reset spark schema calculated from parquet schema
     hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, inMemReqSparkSchema.json)
     hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, inMemReqSparkSchema.json)
 
     val columnsInCache: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
     val typesInCache: util.List[Type] = reqParquetSchemaInCacheOrder.asGroupType.getFields
-    val missingColumns = new Array[Boolean](inMemReqParquetSchema.getFieldCount)
+    val missingColumns = new Array[Boolean](reqParquetSchemaInCacheOrder.getFieldCount)
 
     // initialize missingColumns to cover the case where requested column isn't present in the
     // cache, which should never happen but just in case it does
-    val paths: util.List[Array[String]] = inMemReqParquetSchema.getPaths
+    val paths: util.List[Array[String]] = reqParquetSchemaInCacheOrder.getPaths
 
-    for (i <- 0 until inMemReqParquetSchema.getFieldCount) {
-      val t = inMemReqParquetSchema.getFields.get(i)
+    for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) {
+      val t = reqParquetSchemaInCacheOrder.getFields.get(i)
       if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) {
         throw new UnsupportedOperationException("Complex types not supported.")
       }
diff --git a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
index b4b735c11b8..5f02f8eac8a 100644
--- a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
+++ b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
@@ -430,7 +430,7 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
       }),
     GpuOverrides.exec[InMemoryTableScanExec](
       "Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-      ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+      ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT
          + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all),
       (scan, conf, p, r) => new InMemoryTableScanMeta(scan, conf, p, r)),
     GpuOverrides.exec[ArrowEvalPythonExec](
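
Note on the core change (illustration, not part of the patch): instead of round-tripping the requested
schema through SparkToParquetSchemaConverter, the serializer now builds the requested Parquet
MessageType directly from the cached file's own field definitions, so requested columns (including
128-bit decimals) are read with exactly the Parquet types they were cached with. A minimal standalone
Scala sketch of that projection approach follows; the object name, schema text, and column names
("id", "dec") are invented for the example.

import org.apache.parquet.schema.{MessageType, MessageTypeParser}

object ProjectCachedSchema {
  // Pull each requested field out of the cached schema rather than re-deriving it from the
  // Spark schema, preserving the cached physical type (e.g. FIXED_LEN_BYTE_ARRAY for decimal 128).
  def project(cached: MessageType, requestedNames: Seq[String]): MessageType =
    new MessageType(cached.getName, requestedNames.map { name =>
      cached.getFields.get(cached.getFieldIndex(name))
    }: _*)

  def main(args: Array[String]): Unit = {
    // Hypothetical cached schema with a decimal(38,2) column stored as fixed_len_byte_array(16).
    val cached = MessageTypeParser.parseMessageType(
      """message cache {
        |  required int32 id;
        |  required fixed_len_byte_array(16) dec (DECIMAL(38,2));
        |}""".stripMargin)
    // Request only "dec"; the projected column keeps its original physical type.
    println(project(cached, Seq("dec")))
  }
}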