From be9f42603c256d744c036139cb778d6e748d99e3 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 27 Jul 2020 13:01:50 -0500 Subject: [PATCH] Add in checks for Parquet LEGACY date/time rebase (#435) --- .../src/main/python/parquet_test.py | 12 +-- .../scala/com/nvidia/spark/RebaseHelper.scala | 82 +++++++++++++++++++ .../spark/rapids/ColumnarOutputWriter.scala | 19 ++--- .../spark/rapids/GpuParquetFileFormat.scala | 35 +++++++- .../nvidia/spark/rapids/GpuParquetScan.scala | 61 ++++++++++++-- .../sql/rapids/execution/TrampolineUtil.scala | 9 +- .../spark/rapids/ParquetWriterSuite.scala | 32 +++++++- .../rapids/SparkQueryCompareTestSuite.scala | 55 +++++++++++++ 8 files changed, 278 insertions(+), 27 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index 7e6509ffd09..bac8cd11d29 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -30,7 +30,7 @@ def read_parquet_sql(data_path): parquet_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, - TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], + TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))] @pytest.mark.parametrize('parquet_gens', parquet_gens_list, ids=idfn) @@ -80,7 +80,7 @@ def test_compress_read_round_trip(spark_tmp_path, compress): string_gen, date_gen, # Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with # timestamp_gen - TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))] + TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] @pytest.mark.parametrize('parquet_gen', parquet_pred_push_gens, ids=idfn) @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql]) @@ -102,7 +102,7 @@ def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func): def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase): # Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with # timestamp_gen - gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) + gen = TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)) data_path = spark_tmp_path + '/PARQUET_DATA' with_cpu_session( lambda spark : unary_op_df(spark, gen).write.parquet(data_path), @@ -113,7 +113,7 @@ def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase): parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), - TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], + TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')), pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))] @@ -132,7 +132,7 @@ def test_simple_partitioned_read(spark_tmp_path): # we should go with a more standard set of generators parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), - TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))] + TimestampGen(start=datetime(1900, 1, 1, 
tzinfo=timezone.utc))] gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0' with_cpu_session( @@ -153,7 +153,7 @@ def test_read_merge_schema(spark_tmp_path): # we should go with a more standard set of generators parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), - TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))] + TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))] first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0' with_cpu_session( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala new file mode 100644 index 00000000000..95a869c4e8c --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark + +import ai.rapids.cudf.{ColumnVector, DType, Scalar} +import com.nvidia.spark.rapids.Arm + +import org.apache.spark.sql.catalyst.util.RebaseDateTime +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.execution.TrampolineUtil + +object RebaseHelper extends Arm { + private[this] def isDateTimeRebaseNeeded(column: ColumnVector, + startDay: Int, + startTs: Long): Boolean = { + val dtype = column.getType + if (dtype == DType.TIMESTAMP_DAYS) { + withResource(Scalar.timestampDaysFromInt(startDay)) { minGood => + withResource(column.lessThan(minGood)) { hasBad => + withResource(hasBad.any()) { a => + a.getBoolean + } + } + } + } else if (dtype.isTimestamp) { + assert(dtype == DType.TIMESTAMP_MICROSECONDS) + withResource( + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood => + withResource(column.lessThan(minGood)) { hasBad => + withResource(hasBad.any()) { a => + a.getBoolean + } + } + } + } else { + false + } + } + + def isDateTimeRebaseNeededWrite(column: ColumnVector): Boolean = + isDateTimeRebaseNeeded(column, + RebaseDateTime.lastSwitchGregorianDay, + RebaseDateTime.lastSwitchGregorianTs) + + def isDateTimeRebaseNeededRead(column: ColumnVector): Boolean = + isDateTimeRebaseNeeded(column, + RebaseDateTime.lastSwitchJulianDay, + RebaseDateTime.lastSwitchJulianTs) + + def newRebaseExceptionInRead(format: String): Exception = { + val config = if (format == "Parquet") { + SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key + } else if (format == "Avro") { + SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key + } else { + throw new IllegalStateException("unrecognized format " + format) + } + TrampolineUtil.makeSparkUpgradeException("3.0", + "reading dates before 1582-10-15 or timestamps before " + + s"1900-01-01T00:00:00Z from $format files can be ambiguous, as the files may be written by " + + "Spark 2.x or legacy versions of Hive, which 
uses a legacy hybrid calendar that is " + + "different from Spark 3.0+'s Proleptic Gregorian calendar. See more details in " + + s"SPARK-31404. The RAPIDS Accelerator does not support reading these 'LEGACY' files. To do " + + s"so you should disable $format support in the RAPIDS Accelerator " + + s"or set $config to 'CORRECTED' to read the datetime values as it is.", + null) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala index 9dcb0dcbe28..e3b73932fb3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable -import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, TableWriter} +import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, Table, TableWriter} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.hadoop.fs.{FSDataOutputStream, Path} import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -60,7 +60,7 @@ abstract class ColumnarOutputWriterFactory extends Serializable { * `org.apache.spark.sql.execution.datasources.OutputWriter`. */ abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext, - dataSchema: StructType, rangeName: String) extends HostBufferConsumer { + dataSchema: StructType, rangeName: String) extends HostBufferConsumer with Arm { val tableWriter: TableWriter val conf = context.getConfiguration @@ -130,6 +130,10 @@ abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext, } } + protected def scanTableBeforeWrite(table: Table): Unit = { + // NOOP for now, but allows a child to override this + } + /** * Writes the columnar batch and returns the time in ns taken to write * @@ -140,17 +144,12 @@ abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext, var needToCloseBatch = true try { val startTimestamp = System.nanoTime - val nvtxRange = new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE) - try { - val table = GpuColumnVector.from(batch) - try { + withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ => + withResource(GpuColumnVector.from(batch)) { table => + scanTableBeforeWrite(table) anythingWritten = true tableWriter.write(table) - } finally { - table.close() } - } finally { - nvtxRange.close() } // Batch is no longer needed, write process from here does not use GPU. 
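For context on the refactor above: withResource comes from the plugin's Arm trait and is the loan pattern that replaces the hand-written try/finally blocks around the NVTX range and the cudf Table. A minimal sketch of that pattern, under the assumption that it follows the conventional AutoCloseable shape (ArmSketch is a hypothetical stand-in, not the plugin's actual trait):

object ArmSketch {
  // Run the body, then close the resource even if the body throws; this is the
  // same guarantee the removed try/finally blocks provided by hand.
  def withResource[T <: AutoCloseable, V](resource: T)(block: T => V): V = {
    try {
      block(resource)
    } finally {
      resource.close()
    }
  }
}

The new scanTableBeforeWrite(table) hook is deliberately a no-op in the base class; the Parquet writer in the next file overrides it so that a batch containing dates or timestamps that would need a LEGACY rebase is rejected before tableWriter.write is called.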
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 2761a6f5404..6a18ec8c26a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf._ +import com.nvidia.spark.RebaseHelper import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext} import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel @@ -25,11 +26,12 @@ import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetWriteSupport} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.rapids.execution.TrampolineUtil -import org.apache.spark.sql.types.{StructType, TimestampType} +import org.apache.spark.sql.types.{DateType, StructType, TimestampType} object GpuParquetFileFormat { def tagGpuSupport( @@ -69,6 +71,21 @@ object GpuParquetFileFormat { } } + val schemaHasDates = schema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType]) + } + + sqlConf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE) match { + case "EXCEPTION" => //Good + case "CORRECTED" => //Good + case "LEGACY" => + if (schemaHasDates || schemaHasTimestamps) { + meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported rebase mode") + } + if (meta.canThisBeReplaced) { Some(new GpuParquetFileFormat) } else { @@ -101,6 +118,9 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { val conf = ContextUtil.getConfiguration(job) + val dateTimeRebaseException = + "EXCEPTION".equals(conf.get(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key)) + val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, @@ -179,7 +199,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { path: String, dataSchema: StructType, context: TaskAttemptContext): ColumnarOutputWriter = { - new GpuParquetWriter(path, dataSchema, compressionType, context) + new GpuParquetWriter(path, dataSchema, compressionType, dateTimeRebaseException, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -193,9 +213,20 @@ class GpuParquetWriter( path: String, dataSchema: StructType, compressionType: CompressionType, + dateTimeRebaseException: Boolean, context: TaskAttemptContext) extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") { + override def scanTableBeforeWrite(table: Table): Unit = { + if (dateTimeRebaseException) { + (0 until table.getNumberOfColumns).foreach { i => + if (RebaseHelper.isDateTimeRebaseNeededWrite(table.getColumn(i))) { + throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") + } + } + } + } + override val tableWriter: TableWriter = { val writeContext = new ParquetWriteSupport().init(conf) val builder = ParquetWriterOptions.builder() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index d19840cd208..d780fefcf4d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.math.max import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, NvtxColor, ParquetOptions, Table} +import com.nvidia.spark.RebaseHelper import com.nvidia.spark.rapids.GpuMetricNames._ import com.nvidia.spark.rapids.ParquetPartitionReader.CopyRange import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -56,8 +57,9 @@ import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -114,6 +116,8 @@ object GpuParquetScan { sparkSession: SparkSession, readSchema: StructType, meta: RapidsMeta[_, _, _]): Unit = { + val sqlConf = sparkSession.conf + if (!meta.conf.isParquetEnabled) { meta.willNotWorkOnGpu("Parquet input and output has been disabled. To enable set" + s"${RapidsConf.ENABLE_PARQUET} to true") @@ -130,6 +134,10 @@ object GpuParquetScan { } } + val schemaHasTimestamps = readSchema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType]) + } + // Currently timestamp conversion is not supported. // If support needs to be added then we need to follow the logic in Spark's // ParquetPartitionReaderFactory and VectorizedColumnReader which essentially @@ -139,9 +147,17 @@ object GpuParquetScan { // were written in that timezone and convert them to UTC timestamps. // Essentially this should boil down to a vector subtract of the scalar delta // between the configured timezone's delta from UTC on the timestamp data. - if (sparkSession.sessionState.conf.isParquetINT96TimestampConversion) { + if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) { meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion") } + + sqlConf.get(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key) match { + case "EXCEPTION" => // Good + case "CORRECTED" => // Good + case "LEGACY" => // Good, but it really is EXCEPTION for us... 
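+      // "LEGACY" is accepted at planning time, but the effective behavior is EXCEPTION: unless
+      // the file metadata shows the values were already written as proleptic Gregorian, the
+      // reader scans each decoded column and throws RebaseHelper.newRebaseExceptionInRead when
+      // a value would have required rebasing (see the check added in ParquetPartitionReader).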
+ case other => + meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") + } } } @@ -164,6 +180,8 @@ case class GpuParquetPartitionReaderFactory( private val debugDumpPrefix = rapidsConf.parquetDebugDumpPrefix private val maxReadBatchSizeRows = rapidsConf.maxReadBatchSizeRows private val maxReadBatchSizeBytes = rapidsConf.maxReadBatchSizeBytes + private val isCorrectedRebase = + "CORRECTED" == sqlConf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ) override def supportColumnarReads(partition: InputPartition): Boolean = true @@ -177,8 +195,10 @@ case class GpuParquetPartitionReaderFactory( ColumnarPartitionReaderWithPartitionValues.newReader(partitionedFile, reader, partitionSchema) } - private def filterClippedSchema(clippedSchema: MessageType, - fileSchema: MessageType, isCaseSensitive: Boolean): MessageType = { + private def filterClippedSchema( + clippedSchema: MessageType, + fileSchema: MessageType, + isCaseSensitive: Boolean): MessageType = { val fs = fileSchema.asGroupType() val types = if (isCaseSensitive) { val inFile = fs.getFields.asScala.map(_.getName).toSet @@ -201,6 +221,24 @@ case class GpuParquetPartitionReaderFactory( } } + // Copied from Spark + private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" + // Copied from Spark + private val SPARK_LEGACY_DATETIME = "org.apache.spark.legacyDateTime" + + def isCorrectedRebaseMode( + lookupFileMeta: String => String, + isCorrectedModeConfig: Boolean): Boolean = { + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and later may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
+ version >= "3.0.0" && lookupFileMeta(SPARK_LEGACY_DATETIME) == null + }.getOrElse(isCorrectedModeConfig) + } + private def buildBaseColumnarParquetReader( file: PartitionedFile): PartitionReader[ColumnarBatch] = { val conf = broadcastedConf.value.value @@ -217,6 +255,9 @@ case class GpuParquetPartitionReaderFactory( None } + val isCorrectedRebaseForThis = + isCorrectedRebaseMode(footer.getFileMetaData.getKeyValueMetaData.get, isCorrectedRebase) + val blocks = if (pushedFilters.isDefined) { // Use the ParquetFileReader to perform dictionary-level filtering ParquetInputFormat.setFilterPredicate(conf, pushedFilters.get) @@ -242,7 +283,7 @@ case class GpuParquetPartitionReaderFactory( val clippedBlocks = ParquetPartitionReader.clipBlocks(columnPaths, blocks.asScala) new ParquetPartitionReader(conf, file, filePath, clippedBlocks, clippedSchema, isCaseSensitive, readDataSchema, debugDumpPrefix, maxReadBatchSizeRows, - maxReadBatchSizeBytes, metrics) + maxReadBatchSizeBytes, metrics, isCorrectedRebaseForThis) } } @@ -274,7 +315,8 @@ class ParquetPartitionReader( debugDumpPrefix: String, maxReadBatchSizeRows: Integer, maxReadBatchSizeBytes: Long, - execMetrics: Map[String, SQLMetric]) extends PartitionReader[ColumnarBatch] with Logging + execMetrics: Map[String, SQLMetric], + isCorrectedRebaseMode: Boolean) extends PartitionReader[ColumnarBatch] with Logging with ScanWithMetrics with Arm { private var isExhausted: Boolean = false private var maxDeviceMemory: Long = 0 @@ -554,6 +596,13 @@ class ParquetPartitionReader( GpuSemaphore.acquireIfNecessary(TaskContext.get()) val table = Table.readParquet(parseOpts, dataBuffer, 0, dataSize) + if (!isCorrectedRebaseMode) { + (0 until table.getNumberOfColumns).foreach { i => + if (RebaseHelper.isDateTimeRebaseNeededRead(table.getColumn(i))) { + throw RebaseHelper.newRebaseExceptionInRead("Parquet") + } + } + } maxDeviceMemory = max(GpuColumnVector.getTotalDeviceMemoryUsed(table), maxDeviceMemory) if (readDataSchema.length < table.getNumberOfColumns) { table.close() diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index f4192725a31..03f238ba124 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.execution import org.json4s.JsonAST -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, SparkEnv, SparkUpgradeException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.InputMetrics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} @@ -62,4 +62,11 @@ object TrampolineUtil { def incInputRecordsRows(inputMetrics: InputMetrics, rows: Long): Unit = inputMetrics.incRecordsRead(rows) + + def makeSparkUpgradeException( + version: String, + message: String, + cause: Throwable): SparkUpgradeException = { + new SparkUpgradeException(version, message, cause) + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala index 2c68abdcf8d..55a2cb3596c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala @@ -22,7 +22,7 @@ import 
java.nio.charset.StandardCharsets import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetFileReader -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} /** * Tests for writing Parquet files with the GPU. @@ -85,7 +85,35 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { val tempFile = File.createTempFile("int96", "parquet") tempFile.delete() frame => { - frame.write.parquet(tempFile.getAbsolutePath) + frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath) + frame + } + } + + testExpectedGpuException( + "Old dates in EXCEPTION mode", + classOf[SparkException], + oldDatesDf, + new SparkConf().set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")) { + val tempFile = File.createTempFile("oldDates", "parquet") + tempFile.delete() + frame => { + frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath) + frame + } + } + + testExpectedGpuException( + "Old timestamps in EXCEPTION mode", + classOf[SparkException], + oldTsDf, + new SparkConf() + .set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION") + .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")) { + val tempFile = File.createTempFile("oldTimeStamp", "parquet") + tempFile.delete() + frame => { + frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath) frame } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 7a7798ec2a0..bc921edd3c7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -754,6 +754,41 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { compareResults(sort, maxFloatDiff, fromCpu, fromGpu) } + def testExpectedGpuException[T <: Throwable]( + testName: String, + exceptionClass: Class[T], + df: SparkSession => DataFrame, + conf: SparkConf = new SparkConf(), + repart: Integer = 1, + sort: Boolean = false, + maxFloatDiff: Double = 0.0, + incompat: Boolean = false, + execsAllowedNonGpu: Seq[String] = Seq.empty, + sortBeforeRepart: Boolean = false)(fun: DataFrame => DataFrame): Unit = { + val (testConf, qualifiedTestName) = + setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu, + maxFloatDiff, sortBeforeRepart) + + test(qualifiedTestName) { + val t = Try({ + val fromGpu = withGpuSparkSession( session => { + var data = df(session) + if (repart > 0) { + // repartition the data so it is turned into a projection, + // not folded into the table scan exec + data = data.repartition(repart) + } + fun(data).collect() + }, testConf) + }) + t match { + case Failure(e) if e.getClass == exceptionClass => // Good + case Failure(e) => throw e + case _ => fail("Expected an exception") + } + } + } + def testExpectedExceptionStartsWith[T <: Throwable]( testName: String, exceptionClass: Class[T], @@ -1569,6 +1604,26 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { ).toDF("strings", "ints") } + def oldDatesDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq( + new Date(-141427L * 24 * 60 * 60 * 1000), + new Date(-150000L * 24 * 60 * 60 * 1000), + Date.valueOf("1582-10-15"), + Date.valueOf("1582-10-13") + ).toDF("dates") + } + + def oldTsDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq( + new Timestamp(-141427L * 24 * 60 * 60 * 1000), + new 
Timestamp(-150000L * 24 * 60 * 60 * 1000), + Timestamp.valueOf("1582-10-15 00:01:01"), + Timestamp.valueOf("1582-10-13 12:03:12") + ).toDF("times") + } + def utf8RepeatedDf(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ var utf8Chars = (0 until 256 /*65536*/).map(i => (i.toChar.toString, i))