diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala
index db8a28ef3e8..1f3d51e030a 100644
--- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala
+++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala
@@ -327,6 +327,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       conf: SQLConf): RDD[CachedBatch] = {
 
     val rapidsConf = new RapidsConf(conf)
+    val useCompression = conf.useCompression
     val bytesAllowedPerBatch = getBytesAllowedPerBatch(conf)
     val (schemaWithUnambiguousNames, _) = getSupportedSchemaFromUnsupported(schema)
     val structSchema = schemaWithUnambiguousNames.toStructType
@@ -349,7 +350,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       } else {
         withResource(putOnGpuIfNeeded(batch)) { gpuCB =>
           compressColumnarBatchWithParquet(gpuCB, structSchema, schema.toStructType,
-            bytesAllowedPerBatch)
+            bytesAllowedPerBatch, useCompression)
         }
       }
     })
@@ -367,7 +368,8 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       oldGpuCB: ColumnarBatch,
       schema: StructType,
       origSchema: StructType,
-      bytesAllowedPerBatch: Long): List[ParquetCachedBatch] = {
+      bytesAllowedPerBatch: Long,
+      useCompression: Boolean): List[ParquetCachedBatch] = {
     val estimatedRowSize = scala.Range(0, oldGpuCB.numCols()).map { idx =>
       oldGpuCB.column(idx).asInstanceOf[GpuColumnVector]
         .getBase.getDeviceMemorySize / oldGpuCB.numRows()
@@ -418,7 +420,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
 
           for (i <- splitVectors.head.indices) {
             withResource(makeTableForIndex(i)) { table =>
-              val buffer = writeTableToCachedBatch(table, schema)
+              val buffer = writeTableToCachedBatch(table, schema, useCompression)
               buffers += ParquetCachedBatch(buffer)
             }
           }
@@ -427,7 +429,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       }
     } else {
       withResource(GpuColumnVector.from(gpuCB)) { table =>
-        val buffer = writeTableToCachedBatch(table, schema)
+        val buffer = writeTableToCachedBatch(table, schema, useCompression)
         buffers += ParquetCachedBatch(buffer)
       }
     }
@@ -435,13 +437,22 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
     }
   }
 
+  def getParquetWriterOptions(
+      useCompression: Boolean,
+      schema: StructType): ParquetWriterOptions = {
+    val compressionType = if (useCompression) CompressionType.SNAPPY else CompressionType.NONE
+    SchemaUtils
+        .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
+        .withCompressionType(compressionType)
+        .withStatisticsFrequency(StatisticsFrequency.ROWGROUP).build()
+  }
+
   private def writeTableToCachedBatch(
       table: Table,
-      schema: StructType): ParquetBufferConsumer = {
+      schema: StructType,
+      useCompression: Boolean): ParquetBufferConsumer = {
     val buffer = new ParquetBufferConsumer(table.getRowCount.toInt)
-    val opts = SchemaUtils
-        .writerOptionsFromSchema(ParquetWriterOptions.builder(), schema, writeInt96 = false)
-        .withStatisticsFrequency(StatisticsFrequency.ROWGROUP).build()
+    val opts = getParquetWriterOptions(useCompression, schema)
     withResource(Table.writeParquetChunked(opts, buffer)) { writer =>
       writer.write(table)
     }
@@ -1454,6 +1465,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       conf: SQLConf): RDD[CachedBatch] = {
 
     val rapidsConf = new RapidsConf(conf)
+    val useCompression = conf.useCompression
     val bytesAllowedPerBatch = getBytesAllowedPerBatch(conf)
     val (schemaWithUnambiguousNames, _) = getSupportedSchemaFromUnsupported(schema)
     if (rapidsConf.isSqlEnabled && rapidsConf.isSqlExecuteOnGPU &&
@@ -1465,7 +1477,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi
       })
       columnarBatchRdd.flatMap(cb => {
         withResource(cb)(cb => compressColumnarBatchWithParquet(cb, structSchema,
-          schema.toStructType, bytesAllowedPerBatch))
+          schema.toStructType, bytesAllowedPerBatch, useCompression))
       })
     } else {
       val broadcastedConf = SparkSession.active.sparkContext.broadcast(conf.getAllConfs)
diff --git a/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala b/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
index 636d4e0086c..a875263c482 100644
--- a/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
+++ b/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims
 
 import scala.collection.mutable
 
-import ai.rapids.cudf.{ColumnVector, DType, Table, TableWriter}
+import ai.rapids.cudf.{ColumnVector, CompressionType, DType, Table, TableWriter}
 import com.nvidia.spark.rapids._
 import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
 import org.mockito.ArgumentMatchers._
@@ -82,6 +82,20 @@ class Spark310ParquetWriterSuite extends SparkQueryCompareTestSuite {
     testColumnarBatchToCachedBatchIterator(cb, schema)
   }
 
+  test("test useCompression conf is honored") {
+    val ser = new ParquetCachedBatchSerializer()
+    val schema = new StructType().add("value", "string")
+    List(false, true).foreach { comp =>
+      val opts = ser.getParquetWriterOptions(comp, schema)
+      assert(
+        (if (comp) {
+          CompressionType.SNAPPY
+        } else {
+          CompressionType.NONE
+        }) == opts.getCompressionType)
+    }
+  }
+
   val ROWS = 3 * 1024 * 1024
 
   private def getCudfAndGpuVectors(onHost: Boolean = false): (ColumnVector, GpuColumnVector)= {
@@ -149,7 +163,8 @@ class Spark310ParquetWriterSuite extends SparkQueryCompareTestSuite {
        Array(StructField("empty", ByteType, false),
          StructField("empty", ByteType, false),
          StructField("empty", ByteType, false)))
-      ser.compressColumnarBatchWithParquet(cb, dummySchema, dummySchema, BYTES_ALLOWED_PER_BATCH)
+      ser.compressColumnarBatchWithParquet(cb, dummySchema, dummySchema,
+        BYTES_ALLOWED_PER_BATCH, false)
       theTableMock.close()
     }
   }