From 63b5f357cbf5a9ad6594d64ed593dac5bf995d7e Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Mon, 23 Nov 2020 21:18:31 -0800
Subject: [PATCH] Support for CalendarIntervalType and NullType (#1183)

* Support for CalendarIntervalType and NullType

Handle nested CalendarIntervalType and NullType

Signed-off-by: Raza Jafri

* fixed license

Co-authored-by: Raza Jafri
---
 .../spark/rapids/tests/CacheTestSuite.scala   | 142 ++++
 .../ParquetCachedBatchSerializer.scala        | 404 +++++++++++++-----
 .../com/nvidia/spark/rapids/FuzzerUtils.scala |  83 +++-
 .../rapids/SparkQueryCompareTestSuite.scala   |  15 +-
 4 files changed, 511 insertions(+), 133 deletions(-)
 create mode 100644 integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/CacheTestSuite.scala

diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/CacheTestSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/CacheTestSuite.scala
new file mode 100644
index 00000000000..963cd8fb441
--- /dev/null
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/CacheTestSuite.scala
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.tests
+
+import com.nvidia.spark.rapids.FuzzerUtils._
+import com.nvidia.spark.rapids.SparkQueryCompareTestSuite
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types.{ArrayType, ByteType, CalendarIntervalType, DataType, IntegerType, LongType, MapType, NullType, StringType}
+
+class CacheTestSuite extends SparkQueryCompareTestSuite {
+  type DfGenerator = SparkSession => DataFrame
+
+  test("interval") {
+    testCache { spark: SparkSession =>
+      val schema = createSchema(CalendarIntervalType)
+      generateDataFrame(spark, schema)
+    }
+  }
+
+  test("map(interval)") {
+    testCache(getMapWithDataDF(CalendarIntervalType))
+  }
+
+  test("struct(interval)") {
+    testCache(getIntervalStructDF(CalendarIntervalType))
+    testCache(getIntervalStructDF1(CalendarIntervalType))
+  }
+
+  test("array(interval)") {
+    testCache(getArrayDF(CalendarIntervalType))
+  }
+
+  test("array(map(integer, struct(string, byte, interval)))") {
+    testCache(getDF(CalendarIntervalType))
+  }
+
+  test("array(array(map(array(long), struct(interval))))") {
+    testCache(getMultiNestedDF(CalendarIntervalType))
+  }
+
+  test("null") {
+    testCache { spark: SparkSession =>
+      val schema = createSchema(NullType)
+      generateDataFrame(spark, schema)
+    }
+  }
+
+  test("array(null)") {
+    testCache(getArrayDF(NullType))
+  }
+
+  test("map(null)") {
+    testCache(getMapWithDataDF(NullType))
+  }
+
+  test("struct(null)") {
+    testCache(getIntervalStructDF(NullType))
+    testCache(getIntervalStructDF1(NullType))
+  }
+
+  test("array(map(integer, struct(string, byte, null)))") {
+    testCache(getDF(NullType))
+  }
+
+  test("array(array(map(array(long), struct(null))))") {
+    testCache(getMultiNestedDF(NullType))
+  }
+
+/** Helper functions */
+
+  def testCache(f: SparkSession => DataFrame): Unit = {
+    val df = withCpuSparkSession(f)
+    val regularValues = df.selectExpr("*").collect()
+    val cachedValues = df.selectExpr("*").cache().collect()
+    compare(regularValues, cachedValues)
+  }
+
+  def getArrayDF(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema = createSchema(ArrayType(dataType))
+      generateDataFrame(spark, schema)
+  }
+
+  def getMapWithDataDF(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema =
+        createSchema(StringType, ArrayType(
+          createSchema(StringType, StringType)),
+          MapType(StringType, StringType),
+          MapType(IntegerType, dataType))
+      generateDataFrame(spark, schema)
+  }
+
+  def getIntervalStructDF(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema =
+        createSchema(
+          createSchema(dataType, StringType, dataType))
+      generateDataFrame(spark, schema)
+  }
+
+  def getIntervalStructDF1(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema =
+        createSchema(createSchema(IntegerType, IntegerType), dataType)
+      generateDataFrame(spark, schema)
+  }
+
+  def getMultiNestedDF(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema =
+        createSchema(ArrayType(
+          createSchema(ArrayType(
+            createSchema(MapType(ArrayType(LongType),
+              createSchema(dataType)))))))
+      generateDataFrame(spark, schema)
+  }
+
+  def getDF(dataType: DataType): DfGenerator = {
+    spark: SparkSession =>
+      val schema =
+        createSchema(ArrayType(
+          createSchema(MapType(IntegerType,
+            createSchema(StringType, ByteType, dataType)))))
+      generateDataFrame(spark, schema)
+  }
+}
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
index 0f648147acb..f4ef68d9407 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/ParquetCachedBatchSerializer.scala
@@ -17,11 +17,12 @@
 package com.nvidia.spark.rapids.shims.spark310
 
 import java.io.{InputStream, IOException}
+import java.lang.reflect.Method
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
 import ai.rapids.cudf._
 import com.nvidia.spark.rapids._
@@ -30,7 +31,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import org.apache.commons.io.output.ByteArrayOutputStream
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapreduce.RecordWriter
-import org.apache.parquet.HadoopReadOptions
+import org.apache.parquet.{HadoopReadOptions, ParquetReadOptions}
 import org.apache.parquet.column.ParquetProperties
 import org.apache.parquet.hadoop.{CodecFactory, MemoryManager, ParquetFileReader, ParquetFileWriter, ParquetInputFormat, ParquetOutputFormat, ParquetRecordWriter, ParquetWriter}
 import org.apache.parquet.hadoop.ParquetFileWriter.Mode
@@ -43,12 +44,13 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, SpecializedGetters}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, SparkToParquetSchemaConverter, VectorizedColumnReader}
 import org.apache.spark.sql.execution.datasources.parquet.rapids.ParquetRecordMaterializer
-import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.types._
@@ -130,7 +132,7 @@ class ByteArrayInputFile(buff: Array[Byte]) extends InputFile {
 }
 
 private object ByteArrayOutputFile {
-  val BLOCK_SIZE = 32 * 1024 * 1024 // 32M
+  val BLOCK_SIZE: Int = 32 * 1024 * 1024 // 32M
 }
 
 private class ByteArrayOutputFile(stream: ByteArrayOutputStream) extends OutputFile {
@@ -167,13 +169,13 @@ private class ByteArrayOutputFile(stream: ByteArrayOutputStream) extends OutputF
 private class ParquetBufferConsumer(val numRows: Int) extends HostBufferConsumer with AutoCloseable {
   @transient private[this] val offHeapBuffers = mutable.Queue[(HostMemoryBuffer, Long)]()
-  private var buffer: Array[Byte] = null
+  private var buffer: Array[Byte] = _
 
   override def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit = {
     offHeapBuffers += Tuple2(buffer, len)
   }
 
-  def getBuffer(): Array[Byte] = {
+  def getBuffer: Array[Byte] = {
     if (buffer == null) {
       writeBuffers()
     }
@@ -190,7 +192,7 @@ private class ParquetBufferConsumer(val numRows: Int) extends HostBufferConsumer
     val toProcess = offHeapBuffers.dequeueAll(_ => true)
     // this could be problematic if the buffers are big as their cumulative length could be more
     // than Int.MAX_SIZE. We could just have a list of buffers in that case and iterate over them
-    val bytes = toProcess.unzip._2.sum
+    val bytes = toProcess.map(_._2).sum
 
     // for now assert bytes are less than Int.MaxValue
     assert(bytes <= Int.MaxValue)
@@ -211,7 +213,7 @@ private class ParquetBufferConsumer(val numRows: Int) extends HostBufferConsumer
 
 private object ParquetCachedBatch {
   def apply(parquetBuff: ParquetBufferConsumer): ParquetCachedBatch = {
-    new ParquetCachedBatch(parquetBuff.numRows, parquetBuff.getBuffer())
+    new ParquetCachedBatch(parquetBuff.numRows, parquetBuff.getBuffer)
   }
 }
@@ -225,16 +227,16 @@ case class ParquetCachedBatch(numRows: Int, buffer: Array[Byte]) extends CachedB
  */
 private case class CloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]) extends
   Iterator[ColumnarBatch] {
-  var cb: ColumnarBatch = null
+  var cb: ColumnarBatch = _
 
   private def closeCurrentBatch(): Unit = {
     if (cb != null) {
-      cb.close
+      cb.close()
       cb = null
     }
   }
 
-  TaskContext.get().addTaskCompletionListener[Unit]((tc: TaskContext) => {
+  TaskContext.get().addTaskCompletionListener[Unit]((_: TaskContext) => {
     closeCurrentBatch()
   })
@@ -276,6 +278,10 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
   def isTypeSupportedByParquet(dataType: DataType): Boolean = {
     dataType match {
       case CalendarIntervalType | NullType => false
+      case s: StructType => s.forall(field => isTypeSupportedByParquet(field.dataType))
+      case ArrayType(elementType, _) => isTypeSupportedByParquet(elementType)
+      case MapType(keyType, valueType, _) => isTypeSupportedByParquet(keyType) &&
+          isTypeSupportedByParquet(valueType)
       case _ => true
     }
   }
@@ -372,19 +378,17 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
 
     val requestedColumnNames = getColumnNames(selectedAttributes, cacheAttributes)
 
-    val cbRdd: RDD[ColumnarBatch] = input.map(batch => {
-      if (batch.isInstanceOf[ParquetCachedBatch]) {
-        val parquetCB = batch.asInstanceOf[ParquetCachedBatch]
+    val cbRdd: RDD[ColumnarBatch] = input.map {
+      case parquetCB: ParquetCachedBatch =>
         val parquetOptions = ParquetOptions.builder()
           .includeColumn(requestedColumnNames.asJavaCollection).build()
         withResource(Table.readParquet(parquetOptions, parquetCB.buffer, 0,
          parquetCB.sizeInBytes)) { table =>
          GpuColumnVector.from(table, selectedAttributes.map(_.dataType).toArray)
        }
-      } else {
+      case _ =>
         throw new IllegalStateException("I don't know how to convert this batch")
-      }
-    })
+    }
     cbRdd
   }
@@ -468,6 +472,77 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
     }
   }
 
+  private abstract class UnsupportedDataHandlerIterator extends Iterator[InternalRow] {
+
+    def handleInternalRow(schema: Seq[DataType], row: InternalRow, newRow: InternalRow): Unit
+
+    def handleInterval(data: SpecializedGetters, index: Int): Any
+
+    def handleStruct(
+        data: InternalRow,
+        origSchema: StructType,
+        supportedSchema: StructType): InternalRow = {
+      val structRow = InternalRow.fromSeq(supportedSchema)
+      handleInternalRow(origSchema.map(field => field.dataType), data, structRow)
+      structRow
+    }
+
+    def handleMap(
+        keyType: DataType,
+        valueType: DataType,
+        mapData: MapData): MapData = {
+      val keyData = mapData.keyArray()
+      val newKeyData = handleArray(keyType, keyData)
+      val valueData = mapData.valueArray()
+      val newValueData = handleArray(valueType, valueData)
+      new ArrayBasedMapData(newKeyData, newValueData)
+    }
+
+    def handleArray(
+        dataType: DataType,
+        arrayData: ArrayData): ArrayData = {
+      dataType match {
+        case s@StructType(_) =>
+          val listBuffer = new ListBuffer[InternalRow]()
+          val supportedSchema = mapping(dataType).asInstanceOf[StructType]
+          arrayData.foreach(supportedSchema, (_, data) => {
+            val structRow =
+              handleStruct(data.asInstanceOf[InternalRow], s, s)
+            listBuffer += structRow.copy()
+          })
+          new GenericArrayData(listBuffer)
+
+        case ArrayType(elementType, _) =>
+          val arrayList = new ListBuffer[Any]()
+          scala.Range(0, arrayData.numElements()).foreach { i =>
+            val subArrayData = arrayData.getArray(i)
+            arrayList.append(handleArray(elementType, subArrayData))
+          }
+          new GenericArrayData(arrayList)
+
+        case m@MapType(_, _, _) =>
+          val mapList =
+            new ListBuffer[Any]()
+          scala.Range(0, arrayData.numElements()).foreach { i =>
+            val mapData = arrayData.getMap(i)
+            mapList.append(handleMap(m.keyType, m.valueType, mapData))
+          }
+          new GenericArrayData(mapList)
+
+        case CalendarIntervalType =>
+          val citList = new ListBuffer[Any]()
+          scala.Range(0, arrayData.numElements()).foreach { i =>
+            val citRow = handleInterval(arrayData, i)
+            citList += citRow
+          }
+          new GenericArrayData(citList)
+
+        case _ =>
+          arrayData
+      }
+    }
+  }
+
   /**
    * Consumes the Iterator[CachedBatch] to return either Iterator[ColumnarBatch] or
    * Iterator[InternalRow]
@@ -479,13 +554,13 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
       sharedHadoopConf: Configuration,
       sharedConf: SQLConf) {
 
-    val origRequestedSchema = getCatalystSchema(selectedAttributes, cacheAttributes)
-    val origCacheSchema = getCatalystSchema(cacheAttributes, cacheAttributes)
-    val options = HadoopReadOptions.builder(sharedHadoopConf).build()
+    val origRequestedSchema: Seq[Attribute] = getCatalystSchema(selectedAttributes, cacheAttributes)
+    val origCacheSchema: Seq[Attribute] = getCatalystSchema(cacheAttributes, cacheAttributes)
+    val options: ParquetReadOptions = HadoopReadOptions.builder(sharedHadoopConf).build()
     /**
      * We are getting this method using reflection because its a package-private
      */
-    val readBatchMethod =
+    val readBatchMethod: Method =
       classOf[VectorizedColumnReader].getDeclaredMethod("readBatch", Integer.TYPE,
         classOf[WritableColumnVector])
     readBatchMethod.setAccessible(true)
@@ -504,7 +579,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
        */
      new Iterator[InternalRow]() {
 
-        var iter: Iterator[InternalRow] = null
+        var iter: Iterator[InternalRow] = _
 
         override def hasNext: Boolean = {
           // go over the batch and get the next non-degenerate iterator
@@ -553,7 +628,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
             cacheSchema.toStructType,
             new ParquetToSparkSchemaConverter(sharedHadoopConf), None /*convertTz*/,
             LegacyBehaviorPolicy.CORRECTED))
-          for (i <- 0 until rows.toInt) {
+          for (_ <- 0 until rows.toInt) {
             val row = recordReader.read
             unsafeRows += row.copy()
           }
@@ -564,8 +639,8 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
           val unsafeProjection =
             GenerateUnsafeProjection.generate(requestedSchema, cacheSchema)
           if (hasUnsupportedType) {
-            new Iterator[InternalRow]() {
-              val wrappedIter = iter
+            new UnsupportedDataHandlerIterator() {
+              val wrappedIter: Iterator[InternalRow] = iter
               val newRow = new GenericInternalRow(cacheSchema.length)
 
               override def hasNext: Boolean = wrappedIter.hasNext
@@ -573,35 +648,69 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
               override def next(): InternalRow = {
                 //read a row and convert it to what the caller is expecting
                 val row = wrappedIter.next()
-                var newIndex = 0
-                origCacheSchema.indices.map { index =>
-                  val attribute = origCacheSchema(index)
-                  attribute.dataType match {
-                    case CalendarIntervalType => {
-                      // create a CalendarInterval based on the next three values
-                      if (row.isNullAt(newIndex) || row.isNullAt(newIndex + 1) ||
-                          row.isNullAt(newIndex + 2)) {
+                handleInternalRow(origCacheSchema.map(attr => attr.dataType), row, newRow)
+                val unsafeProjection =
+                  GenerateUnsafeProjection.generate(origRequestedSchema, origCacheSchema)
+                unsafeProjection.apply(newRow)
+              }
+
+              override def handleInterval(
+                  data: SpecializedGetters,
+                  index: Int): CalendarInterval = {
+                if (data.isNullAt(index)) {
+                  null
+                } else {
+                  val structData = data.getStruct(index, 3)
+                  new CalendarInterval(structData.getInt(0),
+                    structData.getInt(1), structData.getLong(2))
+                }
+              }
+
+              override def handleInternalRow(
+                  schema: Seq[DataType],
+                  row: InternalRow,
+                  newRow: InternalRow): Unit = {
+                schema.indices.foreach { index =>
+                  val dataType = schema(index)
+                  if (mapping.contains(dataType) || dataType == CalendarIntervalType ||
+                      dataType == NullType) {
+                    if (row.isNullAt(index)) {
+                      newRow.setNullAt(index)
+                    } else {
+                      dataType match {
+                        case s: StructType =>
+                          val supportedSchema = mapping(dataType)
+                              .asInstanceOf[StructType]
+                          val structRow =
+                            handleStruct(row.getStruct(index, supportedSchema.size), s, s)
+                          newRow.update(index, structRow)
+
+                        case a@ArrayType(_, _) =>
+                          val arrayData = row.getArray(index)
+                          newRow.update(index, handleArray(a.elementType, arrayData))
+
+                        case MapType(keyType, valueType, _) =>
+                          val mapData = row.getMap(index)
+                          newRow.update(index, handleMap(keyType, valueType, mapData))
+
+                        case CalendarIntervalType =>
+                          val interval = handleInterval(row, index)
+                          if (interval == null) {
+                            newRow.setNullAt(index)
+                          } else {
+                            newRow.setInterval(index, interval)
+                          }
+
+                        case NullType =>
                           newRow.setNullAt(index)
-                    } else {
-                      val interval = new CalendarInterval(row.getInt(newIndex),
-                        row.getInt(newIndex + 1), row.getLong(newIndex + 2))
-                      newRow.setInterval(index, interval)
+                        case _ =>
+                          newRow.update(index, row.get(index, dataType))
                       }
-                    newIndex += 3
-                  }
-                  case NullType => {
-                    newRow.setNullAt(index)
-                    newIndex += 1
-                  }
-                  case _ => {
-                    newRow.update(index, row.get(newIndex, attribute.dataType))
-                    newIndex += 1
                     }
+                  } else {
+                    newRow.update(index, row.get(index, dataType))
                   }
                 }
-                val unsafeProjection =
-                  GenerateUnsafeProjection.generate(origRequestedSchema, origCacheSchema)
-                unsafeProjection.apply(newRow)
               }
             }
           } else {
@@ -634,7 +743,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
           val sparkSchema = parquetToSparkSchemaConverter.convert(parquetSchema)
           val sparkToParquetSchemaConverter = new SparkToParquetSchemaConverter(sharedHadoopConf)
 
-          val (cacheSchema, requestedSchema) = if (hasUnsupportedType) {
+          val (_, requestedSchema) = if (hasUnsupportedType) {
            getSupportedSchemaFromUnsupported(origCacheSchema, origRequestedSchema)
          } else {
            (origCacheSchema, origRequestedSchema)
          }
@@ -652,12 +761,12 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
        * Read the next RowGroup and read each column and return the columnarBatch
        */
       new Iterator[ColumnarBatch] {
-        val columnVectors = OffHeapColumnVector.allocateColumns(capacity,
-          requestedSchema.toStructType)
+        val columnVectors: Array[OffHeapColumnVector] =
+          OffHeapColumnVector.allocateColumns(capacity, requestedSchema.toStructType)
         val columnarBatch = new ColumnarBatch(columnVectors
           .asInstanceOf[Array[org.apache.spark.sql.vectorized.ColumnVector]])
         val missingColumns = new Array[Boolean](reqParquetSchema.getFieldCount)
-        var columnReaders: Array[VectorizedColumnReader] = null
+        var columnReaders: Array[VectorizedColumnReader] = _
         var rowsReturned: Long = 0L
         var totalRowCount: Long = 0L
@@ -746,6 +855,11 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
     }
   }
 
+  private val intervalStructType = new StructType()
+    .add("_days", IntegerType)
+    .add("_months", IntegerType)
+    .add("_ms", LongType)
+
   /**
    * This is a private helper class to return Iterator to convert InternalRow or ColumnarBatch to
    * CachedBatch. There is no type checking so if the type of T is anything besides InternalRow
@@ -778,11 +892,11 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
      */
     private class InternalRowToCachedBatchIterator extends Iterator[CachedBatch]() {
       // is there a type that spark doesn't support by default in the schema?
-      val hasUnsupportedType = cachedAttributes.exists { attribute =>
+      val hasUnsupportedType: Boolean = cachedAttributes.exists { attribute =>
         !isTypeSupportedByParquet(attribute.dataType)
       }
 
-      val newCachedAttributes =
+      val newCachedAttributes: Seq[Attribute] =
         if (hasUnsupportedType) {
           val newCachedAttributes =
             getSupportedSchemaFromUnsupported(getCatalystSchema(cachedAttributes,
@@ -792,54 +906,88 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
             ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, newCachedAttributes.toStructType.json)
           sharedHadoopConf.set(
             ParquetWriteSupport.SPARK_ROW_SCHEMA, newCachedAttributes.toStructType.json)
-          Option(newCachedAttributes)
+          newCachedAttributes
         } else {
-          Option.empty
+          cachedAttributes
        }
 
      def getIterator: Iterator[InternalRow] = {
        if (!hasUnsupportedType) {
          iter.asInstanceOf[Iterator[InternalRow]]
        } else {
-          new Iterator[InternalRow] {
+          new UnsupportedDataHandlerIterator {
 
-            val wrappedIter = iter.asInstanceOf[Iterator[InternalRow]]
-            // This row has CalendarIntervalType exploded
-            val newRow = InternalRow.fromSeq(newCachedAttributes.get)
+            val wrappedIter: Iterator[InternalRow] = iter.asInstanceOf[Iterator[InternalRow]]
 
-            override def hasNext: Boolean = iter.hasNext
+            val newRow: InternalRow = InternalRow.fromSeq(newCachedAttributes)
+
+            override def hasNext: Boolean = wrappedIter.hasNext
 
             override def next(): InternalRow = {
               val row = wrappedIter.next()
-              val rowValueAndType =
-                scala.Range(0, cachedAttributes.size).zip(cachedAttributes).map {
-                  in => (row.get(in._1, in._2.dataType), in._2.dataType)
+              handleInternalRow(cachedAttributes.map(attr => attr.dataType), row, newRow)
+              newRow
+            }
+
+            override def handleInterval(
+                data: SpecializedGetters,
+                index: Int): InternalRow = {
+              val citRow = InternalRow(IntegerType, IntegerType, LongType)
+              if (data.isNullAt(index)) {
+                null
+              } else {
+                val cit = data.getInterval(index)
+                citRow.setInt(0, cit.months)
+                citRow.setInt(1, cit.days)
+                citRow.setLong(2, cit.microseconds)
+                citRow
               }
+            }
 
-              var newIndex = 0
-              rowValueAndType.foreach { value =>
-                value match {
-                  case (_, CalendarIntervalType) => {
-                    // write exploded CalendarInterval
-                    val interval = value._1.asInstanceOf[CalendarInterval]
-                    if (interval == null) {
-                      newRow.setNullAt(newIndex)
-                      newRow.setNullAt(newIndex + 1)
-                      newRow.setNullAt(newIndex + 2)
-                    } else {
-                      newRow.update(newIndex, interval.months)
-                      newRow.update(newIndex + 1, interval.days)
-                      newRow.update(newIndex + 2, interval.microseconds)
+            override def handleInternalRow(
+                schema: Seq[DataType],
+                row: InternalRow,
+                newRow: InternalRow): Unit = {
+              schema.indices.foreach { index =>
+                val dataType = schema(index)
+                if (mapping.contains(dataType) || dataType == CalendarIntervalType ||
+                    dataType == NullType) {
+                  if (row.isNullAt(index)) {
+                    newRow.setNullAt(index)
+                  } else {
+                    dataType match {
+                      case s: StructType =>
+                        val newSchema = mapping(dataType).asInstanceOf[StructType]
+                        val structRow =
+                          handleStruct(row.getStruct(index, s.fields.length), s, newSchema)
+                        newRow.update(index, structRow)
+
+                      case ArrayType(arrayDataType, _) =>
+                        val arrayData = row.getArray(index)
+                        val newArrayData = handleArray(arrayDataType, arrayData)
+                        newRow.update(index, newArrayData)
+
+                      case MapType(keyType, valueType, _) =>
+                        val mapData = row.getMap(index)
+                        val map = handleMap(keyType, valueType, mapData)
+                        newRow.update(index, map)
+
+                      case CalendarIntervalType =>
+                        val structData: InternalRow = handleInterval(row, index)
+                        if (structData == null) {
+                          newRow.setNullAt(index)
+                        } else {
+                          newRow.update(index, structData)
+                        }
+
+                      case _ =>
+                        newRow.update(index, row.get(index, dataType))
                     }
-                    newIndex += 3
-                  }
-                  case (_, _) => {
-                    newRow.update(newIndex, value._1)
-                    newIndex += 1
                   }
+                } else {
+                  newRow.update(index, row.get(index, dataType))
                 }
               }
-              newRow
             }
           }
         }
@@ -888,39 +1036,71 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
     }
   }
 
+  val mapping = new mutable.HashMap[DataType, DataType]()
+
   private def getSupportedSchemaFromUnsupported(
       cachedAttributes: Seq[Attribute],
       requestedAttributes: Seq[Attribute] = Seq.empty): (Seq[Attribute], Seq[Attribute]) = {
+    def getSupportedDataType(dataType: DataType, nullable: Boolean = true): DataType = {
+      dataType match {
+        case CalendarIntervalType =>
+          intervalStructType
+        case NullType =>
+          ByteType
+        case s: StructType =>
+          val newStructType = StructType(
+            s.map { field =>
+              StructField(field.name,
+                getSupportedDataType(field.dataType, field.nullable), field.nullable,
+                field.metadata)
+            })
+          mapping.put(s, newStructType)
+          newStructType
+        case a@ArrayType(elementType, nullable) =>
+          val newArrayType =
+            ArrayType(getSupportedDataType(elementType, nullable), nullable)
+          mapping.put(a, newArrayType)
+          newArrayType
+        case m@MapType(keyType, valueType, nullable) =>
+          val newKeyType = getSupportedDataType(keyType, nullable)
+          val newValueType = getSupportedDataType(valueType, nullable)
+          val mapType = MapType(newKeyType, newValueType, nullable)
+          mapping.put(m, mapType)
+          mapType
+        case _ =>
+          dataType
+      }
+    }
     // we only handle CalendarIntervalType and NullType ATM
     // convert it to a supported type
-    val mapping = new mutable.HashMap[ExprId, List[Attribute]]()
-    val newCachedAttributes = cachedAttributes.flatMap {
+    val newCachedAttributes = cachedAttributes.map {
       attribute =>
-        if (attribute.dataType == DataTypes.CalendarIntervalType) {
-          val list = List(AttributeReference(attribute.name + "_cit_months",
-            DataTypes.IntegerType, attribute.nullable,
-            metadata = attribute.metadata)().asInstanceOf[Attribute],
-            AttributeReference(attribute.name + "_cit_days",
-              DataTypes.IntegerType, attribute.nullable,
-              metadata = attribute.metadata)().asInstanceOf[Attribute],
-            AttributeReference(attribute.name + "_cit_ms",
-              DataTypes.LongType, attribute.nullable,
-              metadata = attribute.metadata)().asInstanceOf[Attribute])
-          mapping.put(attribute.exprId, list)
-          list
-        } else if (attribute.dataType == DataTypes.NullType) {
-          val list = List(AttributeReference(attribute.name + "_nulltype", DataTypes.IntegerType,
-            attribute.nullable, metadata = attribute.metadata)().asInstanceOf[Attribute])
-          mapping.put(attribute.exprId, list)
-          list
-        }
-        else {
-          List(attribute)
+        attribute.dataType match {
+          case CalendarIntervalType =>
+            AttributeReference(attribute.name, intervalStructType, attribute.nullable,
+              metadata = attribute.metadata)(attribute.exprId)
+                .asInstanceOf[Attribute]
+          case NullType =>
+            AttributeReference(attribute.name, DataTypes.ByteType, nullable = true,
+              metadata = attribute.metadata)(attribute.exprId).asInstanceOf[Attribute]
+          case s: StructType =>
+            AttributeReference(attribute.name, getSupportedDataType(s, attribute.nullable),
+              attribute.nullable, attribute.metadata)(attribute.exprId)
+          case a: ArrayType =>
+            AttributeReference(attribute.name, getSupportedDataType(a, attribute.nullable),
+              attribute.nullable, attribute.metadata)(attribute.exprId)
+          case m: MapType =>
+            AttributeReference(attribute.name, getSupportedDataType(m, attribute.nullable),
+              attribute.nullable, attribute.metadata)(attribute.exprId)
+          case _ =>
+            attribute
        }
     }
-    val newRequestedAttributes = requestedAttributes.flatMap { attribute =>
-      mapping.getOrElse(attribute.exprId, List(attribute))
+
+    val newRequestedAttributes = newCachedAttributes.filter { attribute =>
+      requestedAttributes.map(_.exprId).contains(attribute.exprId)
     }
+
     (newCachedAttributes, newRequestedAttributes)
   }
@@ -1013,7 +1193,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
       predicates: Seq[Expression],
       cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = {
     //essentially a noop
-    (partId: Int, b: Iterator[CachedBatch]) => b
+    (_: Int, b: Iterator[CachedBatch]) => b
   }
 }
@@ -1052,7 +1232,7 @@ private class ParquetOutputFileFormat {
 }
 
 private object ParquetOutputFileFormat {
-  var memoryManager: MemoryManager = null
+  var memoryManager: MemoryManager = _
   val DEFAULT_MEMORY_POOL_RATIO: Float = 0.95f
   val DEFAULT_MIN_MEMORY_ALLOCATION: Long = 1 * 1024 * 1024 // 1MB
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
index 659234bd86b..b60f3fd510b 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
@@ -19,13 +19,16 @@ package com.nvidia.spark.rapids
 import java.sql.{Date, Timestamp}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 import scala.util.Random
 
 import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder
 
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructField, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.unsafe.types.CalendarInterval
 
 /**
  * Utilities for creating random inputs for unit tests.
@@ -40,6 +43,14 @@ object FuzzerUtils {
     asciiStringsOnly = false,
     maxStringLen = 64)
 
+  /**
+   * Create a schema with the specified data types.
+   */
+  def createSchema(dataTypes: DataType*): StructType = {
+    new StructType(dataTypes.zipWithIndex
+      .map(pair => StructField(s"c${pair._2}", pair._1, true)).toArray)
+  }
+
   /**
   * Create a schema with the specified data types.
   */
@@ -155,7 +166,7 @@ object FuzzerUtils {
   def generateDataFrame(
       spark: SparkSession,
       schema: StructType,
-      rowCount: Int,
+      rowCount: Int = 1024,
       options: FuzzerOptions = DEFAULT_OPTIONS,
       seed: Long = 0): DataFrame = {
     val r = new Random(seed)
@@ -166,27 +177,58 @@ object FuzzerUtils {
   /**
    * Creates a Row with random data based on the given field definitions.
    */
-  def generateRow(fields: Array[StructField], rand: Random, options: FuzzerOptions) = {
+  def generateRow(fields: Array[StructField], rand: Random, options: FuzzerOptions): Row = {
     val r = new EnhancedRandom(rand, options)
+
+    def handleDataType(dataType: DataType): Any = {
+      dataType match {
+        case DataTypes.BooleanType => r.nextBoolean()
+        case DataTypes.ByteType => r.nextByte()
+        case DataTypes.ShortType => r.nextShort()
+        case DataTypes.IntegerType => r.nextInt()
+        case DataTypes.LongType => r.nextLong()
+        case DataTypes.FloatType => r.nextFloat()
+        case DataTypes.DoubleType => r.nextDouble()
+        case DataTypes.StringType => r.nextString()
+        case DataTypes.TimestampType => r.nextTimestamp()
+        case DataTypes.DateType => r.nextDate()
+        case DataTypes.CalendarIntervalType => r.nextInterval()
+        case DataTypes.NullType => null
+        case ArrayType(elementType, _) =>
+          val list = new ListBuffer[Any]
+          //only generating 5 items in the array this can later be made configurable
+          scala.Range(0, 5).foreach { _ =>
+            list.append(handleDataType(elementType))
+          }
+          list.toList
+        case MapType(keyType, valueType, _) =>
+          val keyList = new ListBuffer[Any]
+          //only generating 5 items in the array this can later be made configurable
+          scala.Range(0, 5).foreach { _ =>
+            keyList.append(handleDataType(keyType))
+          }
+          val valueList = new ListBuffer[Any]
+          //only generating 5 items in the array this can later be made configurable
+          scala.Range(0, 5).foreach { _ =>
+            valueList.append(handleDataType(valueType))
+          }
+          val map = new mutable.HashMap[Any, Any]()
+          keyList.zip(valueList).map { values =>
+            map.put(values._1, values._2)
+          }
+          map.toMap
+        case s: StructType =>
+          generateRow(s.fields, rand, options)
+        case _ => throw new IllegalStateException(
+          s"fuzzer does not support data type $dataType")
+      }
+    }
+
     Row.fromSeq(fields.map { field =>
       if (field.nullable && r.nextFloat() < 0.2) {
         null
       } else {
-        field.dataType match {
-          case DataTypes.BooleanType => r.nextBoolean()
-          case DataTypes.ByteType => r.nextByte()
-          case DataTypes.ShortType => r.nextShort()
-          case DataTypes.IntegerType => r.nextInt()
-          case DataTypes.LongType => r.nextLong()
-          case DataTypes.FloatType => r.nextFloat()
-          case DataTypes.DoubleType => r.nextDouble()
-          case DataTypes.StringType => r.nextString()
-          case DataTypes.TimestampType => r.nextTimestamp()
-          case DataTypes.DateType => r.nextDate()
-          case DataTypes.NullType => null
-          case _ => throw new IllegalStateException(
-            s"fuzzer does not support data type ${field.dataType}")
-        }
+        handleDataType(field.dataType)
       }
     })
   }
@@ -201,6 +243,10 @@
  */
 class EnhancedRandom(r: Random, options: FuzzerOptions) {
 
+  def nextInterval(): CalendarInterval = {
+    new CalendarInterval(nextInt(), nextInt(), nextLong())
+  }
+
   def nextBoolean(): Boolean = r.nextBoolean()
 
   def nextByte(): Byte = {
@@ -305,7 +351,6 @@ class EnhancedRandom(r: Random, options: FuzzerOptions) {
     } else {
       r.nextString(r.nextInt(options.maxStringLen))
     }
-
   }
 
   private val ASCII_CHARS = "abcdefghijklmnopqrstuvwxyz"
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index a88fefe3554..4126e6d681c 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -70,7 +70,8 @@ object SparkSessionHolder extends Logging {
     TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
-    SparkSession.builder()
+
+    val builder = SparkSession.builder()
       .master("local[1]")
       .config("spark.sql.adaptive.enabled", "false")
       .config("spark.rapids.sql.enabled", "false")
@@ -80,7 +81,17 @@ object SparkSessionHolder extends Logging {
         "com.nvidia.spark.rapids.ExecutionPlanCaptureCallback")
       .config("spark.sql.warehouse.dir", sparkWarehouseDir.getAbsolutePath)
       .appName("rapids spark plugin integration tests (scala)")
-      .getOrCreate()
+
+    // comma separated config from command line
+    val commandLineVariables = System.getenv("SPARK_CONF")
+    if (commandLineVariables != null) {
+      commandLineVariables.split(",").foreach { s =>
+        val a = s.split("=")
+        builder.config(a(0), a(1))
+      }
+    }
+
+    builder.getOrCreate()
   }
 
   private def reinitSession(): Unit = {