diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index 844040d4ffd..3137a2dfabd 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -291,7 +291,8 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, mt_opt):
 @pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn)
 @pytest.mark.parametrize('mt_opt', ["true", "false"])
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list):
+@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"])
+def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list, ts_type):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
@@ -299,10 +300,23 @@ def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list)
         lambda spark, path: spark.read.parquet(path),
         data_path,
         conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
-              'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
+              'spark.sql.parquet.outputTimestampType': ts_type,
               'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
               'spark.sql.sources.useV1SourceList': v1_enabled_list})
 
+@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
+@pytest.mark.parametrize('ts_rebase', ['CORRECTED'])
+@ignore_order
+def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
+    gen = TimestampGen()
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+        lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
+        lambda spark, path: spark.read.parquet(path),
+        data_path,
+        conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
+              'spark.sql.parquet.outputTimestampType': ts_type})
+
 parquet_part_write_gens = [
     byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     # Some file systems have issues with UTF8 strings so to help the test pass even there
@@ -314,7 +328,8 @@ def test_write_round_trip(spark_tmp_path, parquet_gens, mt_opt, v1_enabled_list)
 @pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn)
 @pytest.mark.parametrize('mt_opt', ["true", "false"])
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list):
+@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS'])
+def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_list, ts_type):
     gen_list = [('a', RepeatSeqGen(parquet_gen, 10)),
             ('b', parquet_gen)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
@@ -323,7 +338,7 @@ def test_part_write_round_trip(spark_tmp_path, parquet_gen, mt_opt, v1_enabled_l
         lambda spark, path: spark.read.parquet(path),
         data_path,
         conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
-              'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS',
+              'spark.sql.parquet.outputTimestampType': ts_type,
               'spark.rapids.sql.format.parquet.multiThreadedRead.enabled': mt_opt,
               'spark.sql.sources.useV1SourceList': v1_enabled_list})
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index 6a18ec8c26a..8b7482ed72d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -30,8 +30,10 @@ import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetWriteSupport}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
+import org.apache.spark.sql.rapids.ColumnarWriteTaskStatsTracker
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
-import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataTypes, DateType, StructType, TimestampType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 object GpuParquetFileFormat {
   def tagGpuSupport(
@@ -64,10 +66,9 @@ object GpuParquetFileFormat {
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
     }
     if (schemaHasTimestamps) {
-      // TODO: Could support TIMESTAMP_MILLIS by performing cast on all timestamp input columns
-      sqlConf.parquetOutputTimestampType match {
-        case ParquetOutputTimestampType.TIMESTAMP_MICROS =>
-        case t => meta.willNotWorkOnGpu(s"Output timestamp type $t is not supported")
+      if(!isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) {
+        meta.willNotWorkOnGpu(s"Output timestamp type " +
+          s"${sqlConf.parquetOutputTimestampType} is not supported")
       }
     }
 
@@ -100,6 +101,15 @@ object GpuParquetFileFormat {
       case _ => None
     }
   }
+
+  def isOutputTimestampTypeSupported(
+      outputTimestampType: ParquetOutputTimestampType.Value): Boolean = {
+    outputTimestampType match {
+      case ParquetOutputTimestampType.TIMESTAMP_MICROS |
+           ParquetOutputTimestampType.TIMESTAMP_MILLIS => true
+      case _ => false
+    }
+  }
 }
 
 class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
@@ -160,7 +170,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
       sparkSession.sessionState.conf.writeLegacyParquetFormat.toString)
 
     val outputTimestampType = sparkSession.sessionState.conf.parquetOutputTimestampType
-    if (outputTimestampType != ParquetOutputTimestampType.TIMESTAMP_MICROS) {
+    if(!GpuParquetFileFormat.isOutputTimestampTypeSupported(outputTimestampType)) {
       val hasTimestamps = dataSchema.exists { field =>
         TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
       }
@@ -217,6 +227,8 @@ class GpuParquetWriter(
     context: TaskAttemptContext)
   extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") {
 
+  val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)
+
   override def scanTableBeforeWrite(table: Table): Unit = {
     if (dateTimeRebaseException) {
       (0 until table.getNumberOfColumns).foreach { i =>
@@ -227,6 +239,32 @@ class GpuParquetWriter(
     }
   }
 
+  /**
+   * Persists a columnar batch. Invoked on the executor side. When writing to dynamically
+   * partitioned tables, dynamic partition columns are not included in columns to be written.
+   * NOTE: It is the writer's responsibility to close the batch.
+   */
+  override def write(batch: ColumnarBatch,
+      statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
+    val newBatch =
+      if (outputTimestampType == ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
+        new ColumnarBatch(GpuColumnVector.extractColumns(batch).map {
+          cv => {
+            cv.dataType() match {
+              case DataTypes.TimestampType => new GpuColumnVector(DataTypes.TimestampType,
+                withResource(cv.getBase()) { v =>
+                  v.castTo(DType.TIMESTAMP_MILLISECONDS)
+                })
+              case _ => cv
+            }
+          }
+        })
+      } else {
+        batch
+      }
+
+    super.write(newBatch, statsTrackers)
+  }
 
 override val tableWriter: TableWriter = {
     val writeContext = new ParquetWriteSupport().init(conf)
     val builder = ParquetWriterOptions.builder()
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
index 55a2cb3596c..2efdab41ba9 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
@@ -103,6 +103,21 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
     }
   }
 
+  testExpectedGpuException(
+    "Old timestamps millis in EXCEPTION mode",
+    classOf[SparkException],
+    oldTsDf,
+    new SparkConf()
+      .set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")
+      .set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")) {
+    val tempFile = File.createTempFile("oldTimeStamp", "parquet")
+    tempFile.delete()
+    frame => {
+      frame.write.mode("overwrite").parquet(tempFile.getAbsolutePath)
+      frame
+    }
+  }
+
   testExpectedGpuException(
     "Old timestamps in EXCEPTION mode",
     classOf[SparkException],