Fixed Decimal 128 bug in ParquetCachedBatchSerializer (#4899)
* Fixed Parquet schema

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed unused import

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Mar 8, 2022
1 parent fbb2f07 commit a0aeaba
Showing 3 changed files with 12 additions and 14 deletions.
3 changes: 1 addition & 2 deletions integration_tests/src/main/python/cache_test.py
@@ -24,8 +24,7 @@
 enable_vectorized_confs = [{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "true"},
                            {"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "false"}]
 
-# cache does not work with 128-bit decimals, see https://github.com/NVIDIA/spark-rapids/issues/4826
-_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit]
+_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit]
 _cache_single_array_gens_no_null = [ArrayGen(gen) for gen in all_basic_gens_no_null + _cache_decimal_gens]
 
 decimal_struct_gen= StructGen([['child0', sub_gen] for ind, sub_gen in enumerate(_cache_decimal_gens)])
@@ -53,7 +53,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.columnar.CachedBatch
-import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, SparkToParquetSchemaConverter, VectorizedColumnReader}
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, VectorizedColumnReader}
 import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2.{ParquetRecordMaterializer, ShimVectorizedColumnReader}
 import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector}
 import org.apache.spark.sql.internal.SQLConf
@@ -889,13 +889,9 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
 val inMemCacheSparkSchema = parquetToSparkSchemaConverter.convert(inMemCacheParquetSchema)
 
 val totalRowCount = parquetFileReader.getRowGroups.asScala.map(_.getRowCount).sum
-val sparkToParquetSchemaConverter = new SparkToParquetSchemaConverter(hadoopConf)
 val inMemReqSparkSchema = StructType(selectedAttributes.toStructType.map { field =>
   inMemCacheSparkSchema.fields(inMemCacheSparkSchema.fieldIndex(field.name))
 })
-val inMemReqParquetSchema = sparkToParquetSchemaConverter.convert(inMemReqSparkSchema)
-val columnsRequested: util.List[ColumnDescriptor] = inMemReqParquetSchema.getColumns
-
 val reqSparkSchemaInCacheOrder = StructType(inMemCacheSparkSchema.filter(f =>
   inMemReqSparkSchema.fields.exists(f0 => f0.name.equals(f.name))))
 
@@ -907,23 +903,26 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
   index -> inMemReqSparkSchema.fields.indexOf(reqSparkSchemaInCacheOrder.fields(index))
 }.toMap
 
-val reqParquetSchemaInCacheOrder =
-  sparkToParquetSchemaConverter.convert(reqSparkSchemaInCacheOrder)
+val reqParquetSchemaInCacheOrder = new org.apache.parquet.schema.MessageType(
+  inMemCacheParquetSchema.getName(), reqSparkSchemaInCacheOrder.fields.map { f =>
+    inMemCacheParquetSchema.getFields().get(inMemCacheParquetSchema.getFieldIndex(f.name))
+  }:_*)
 
+val columnsRequested: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
 // reset spark schema calculated from parquet schema
 hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, inMemReqSparkSchema.json)
 hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, inMemReqSparkSchema.json)
 
 val columnsInCache: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns
 val typesInCache: util.List[Type] = reqParquetSchemaInCacheOrder.asGroupType.getFields
-val missingColumns = new Array[Boolean](inMemReqParquetSchema.getFieldCount)
+val missingColumns = new Array[Boolean](reqParquetSchemaInCacheOrder.getFieldCount)
 
 // initialize missingColumns to cover the case where requested column isn't present in the
 // cache, which should never happen but just in case it does
-val paths: util.List[Array[String]] = inMemReqParquetSchema.getPaths
+val paths: util.List[Array[String]] = reqParquetSchemaInCacheOrder.getPaths
 
-for (i <- 0 until inMemReqParquetSchema.getFieldCount) {
-  val t = inMemReqParquetSchema.getFields.get(i)
+for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) {
+  val t = reqParquetSchemaInCacheOrder.getFields.get(i)
   if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) {
     throw new UnsupportedOperationException("Complex types not supported.")
   }
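
The heart of the fix is in the hunk above: instead of round-tripping the requested Spark schema back through SparkToParquetSchemaConverter, which can pick a different physical type for a DECIMAL(38, x) column than the one the cache was written with, the requested Parquet MessageType is now assembled directly from the cached schema's own fields, so the physical types always match what the column readers find on disk. Below is a minimal, self-contained sketch of the same idea against the public parquet-mr schema API; the column names and the object wrapper are hypothetical and are not the serializer's real code.

import org.apache.parquet.schema.{LogicalTypeAnnotation, MessageType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName

object SchemaProjectionSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical cached schema: a 128-bit decimal stored as FIXED_LEN_BYTE_ARRAY(16)
    // next to an ordinary INT64 column.
    val cachedSchema: MessageType = Types.buildMessage()
      .optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(16)
      .as(LogicalTypeAnnotation.decimalType(2, 38)).named("price")
      .optional(PrimitiveTypeName.INT64).named("id")
      .named("spark_schema")

    // Columns the query asks for, already arranged in cache order.
    val requestedNames = Seq("price")

    // Reuse the cached fields verbatim so the physical type survives; converting the
    // Spark StructType back to Parquet (the old code path) may not reproduce it.
    val requestedSchema = new MessageType(
      cachedSchema.getName,
      requestedNames.map(n => cachedSchema.getFields.get(cachedSchema.getFieldIndex(n))): _*)

    // Prints a message type whose "price" field is still fixed_len_byte_array(16).
    println(requestedSchema)
  }
}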
@@ -430,7 +430,7 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
   }),
 GpuOverrides.exec[InMemoryTableScanExec](
   "Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-  ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT
+  ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT
     + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all),
   (scan, conf, p, r) => new InMemoryTableScanMeta(scan, conf, p, r)),
 GpuOverrides.exec[ArrowEvalPythonExec](
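
With the serializer no longer mangling the schema, the Spark 3.1.x shim can advertise DECIMAL_128 support for InMemoryTableScanExec. The following is a rough usage sketch of what this enables, assuming a SparkSession launched with the RAPIDS plugin and spark.sql.cache.serializer pointed at the plugin's ParquetCachedBatchSerializer; the object and method names are illustrative and not part of this diff.

import org.apache.spark.sql.SparkSession

object Decimal128CacheSketch {
  // Assumes `spark` was created with the RAPIDS plugin on the classpath and
  // --conf spark.sql.cache.serializer=com.nvidia.spark.ParquetCachedBatchSerializer
  def cacheRoundTrip(spark: SparkSession): Unit = {
    // DECIMAL(38,2) needs 128-bit decimal support end to end.
    val df = spark.range(100).selectExpr("CAST(id AS DECIMAL(38,2)) AS price")
    df.cache()
    df.count()                       // materializes the cache through the serializer
    df.filter("price > 50").show()   // reads back via InMemoryTableScanExec
  }
}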
