diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 3997ae3eba3..41f782a0005 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -398,24 +398,25 @@ def setup_tables(spark): conf=confs) with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path)) -@allow_non_gpu(*delta_meta_allow, delta_write_fallback_allow) +@allow_non_gpu(*delta_meta_allow) @delta_lake @ignore_order @pytest.mark.parametrize("ts_write", ["INT96", "TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"], ids=idfn) @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") -def test_delta_write_legacy_timestamp_fallback(spark_tmp_path, ts_write): - gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) +def test_delta_write_legacy_timestamp(spark_tmp_path, ts_write): + gen = TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), + end=datetime(2000, 1, 1, tzinfo=timezone.utc)).with_special_case( + datetime(1000, 1, 1, tzinfo=timezone.utc), weight=10.0) data_path = spark_tmp_path + "/DELTA_DATA" all_confs = copy_and_update(delta_writes_enabled_conf, { "spark.sql.legacy.parquet.datetimeRebaseModeInWrite": "LEGACY", "spark.sql.legacy.parquet.int96RebaseModeInWrite": "LEGACY", "spark.sql.legacy.parquet.outputTimestampType": ts_write }) - assert_gpu_fallback_write( + assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("delta").save(path), lambda spark, path: spark.read.format("delta").load(path), data_path, - delta_write_fallback_check, conf=all_confs) @allow_non_gpu(*delta_meta_allow, delta_write_fallback_allow) diff --git a/integration_tests/src/main/python/hive_write_test.py b/integration_tests/src/main/python/hive_write_test.py index cf79b996514..d7de6f1084e 100644 --- a/integration_tests/src/main/python/hive_write_test.py +++ b/integration_tests/src/main/python/hive_write_test.py @@ -85,8 +85,6 @@ def do_write(spark, table_name): @pytest.mark.skipif(not is_hive_available(), reason="Hive is missing") @pytest.mark.parametrize("gens", [_basic_gens], ids=idfn) @pytest.mark.parametrize("storage_with_confs", [ - ("PARQUET", {"spark.sql.legacy.parquet.datetimeRebaseModeInWrite": "LEGACY", - "spark.sql.legacy.parquet.int96RebaseModeInWrite": "LEGACY"}), ("PARQUET", {"parquet.encryption.footer.key": "k1", "parquet.encryption.column.keys": "k2:a"}), ("PARQUET", {"spark.sql.parquet.compression.codec": "gzip"}), @@ -183,4 +181,4 @@ def do_test(spark): jvm = spark_jvm() jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertContainsAnsiCast(cpu_df._jdf) jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertContainsAnsiCast(gpu_df._jdf) - assert_equal(from_cpu, from_gpu) \ No newline at end of file + assert_equal(from_cpu, from_gpu) diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 6c8b8e4dca5..aaba55c0bc5 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -70,8 +70,11 @@ MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)] -parquet_datetime_gen_simple = [DateGen(end=date(3000, 1, 1)), - TimestampGen(end=datetime(3000, 1, 1, tzinfo=timezone.utc))] +parquet_datetime_gen_simple = [DateGen(start=date(1, 1, 
1), end=date(2000, 1, 1)) + .with_special_case(date(1000, 1, 1), weight=10.0), + TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), + end=datetime(2000, 1, 1, tzinfo=timezone.utc)) + .with_special_case(datetime(1000, 1, 1, tzinfo=timezone.utc), weight=10.0)] parquet_datetime_in_struct_gen = [StructGen([['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_datetime_gen_simple)]), StructGen([['child0', StructGen([['child' + str(ind), sub_gen] for ind, sub_gen in enumerate(parquet_datetime_gen_simple)])]])] parquet_datetime_in_array_gen = [ArrayGen(sub_gen, max_length=10) for sub_gen in parquet_datetime_gen_simple + parquet_datetime_in_struct_gen] + [ @@ -309,22 +312,6 @@ def test_ts_write_twice_fails_exception(spark_tmp_path, spark_tmp_table_factory) with_gpu_session( lambda spark : writeParquetNoOverwriteCatchException(spark, unary_op_df(spark, gen), data_path, table_name)) -@allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') -@pytest.mark.parametrize('ts_write', parquet_ts_write_options) -@pytest.mark.parametrize('ts_rebase', ['LEGACY']) -def test_parquet_write_legacy_fallback(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory): - gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) - data_path = spark_tmp_path + '/PARQUET_DATA' - all_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase, - 'spark.sql.parquet.outputTimestampType': ts_write} - assert_gpu_fallback_write( - lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), - lambda spark, path: spark.read.parquet(path), - data_path, - 'DataWritingCommandExec', - conf=all_confs) - @allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') @pytest.mark.parametrize('write_options', [{"parquet.encryption.footer.key": "k1"}, {"parquet.encryption.column.keys": "k2:a"}, @@ -470,41 +457,17 @@ def generate_map_with_empty_validity(spark, path): lambda spark, path: spark.read.parquet(path), data_path) -@pytest.mark.parametrize('ts_write_data_gen', [('INT96', TimestampGen()), - ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc))), - ('TIMESTAMP_MILLIS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))]) -@pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"]) -@pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"]) -@pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"]) -@pytest.mark.parametrize('int96_rebase_read', ["EXCEPTION", "CORRECTED"]) -def test_timestamp_roundtrip_no_legacy_rebase(spark_tmp_path, ts_write_data_gen, - date_time_rebase_read, date_time_rebase_write, - int96_rebase_read, int96_rebase_write): - ts_write, gen = ts_write_data_gen - data_path = spark_tmp_path + '/PARQUET_DATA' - all_confs = {'spark.sql.parquet.outputTimestampType': ts_write} - all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase_write, - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase_write}) - all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInRead': date_time_rebase_read, - 'spark.sql.legacy.parquet.int96RebaseModeInRead': int96_rebase_read}) - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: unary_op_df(spark, 
gen).coalesce(1).write.parquet(path), - lambda spark, path: spark.read.parquet(path), - data_path, - conf=all_confs) - -# This should be merged to `test_timestamp_roundtrip_no_legacy_rebase` above when -# we have rebase for int96 supported. -@pytest.mark.parametrize('ts_write', ['TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']) @pytest.mark.parametrize('data_gen', parquet_nested_datetime_gen, ids=idfn) -def test_datetime_roundtrip_with_legacy_rebase(spark_tmp_path, ts_write, data_gen): +@pytest.mark.parametrize('ts_write', parquet_ts_write_options) +@pytest.mark.parametrize('ts_rebase_write', ['CORRECTED', 'LEGACY']) +@pytest.mark.parametrize('ts_rebase_read', ['CORRECTED', 'LEGACY']) +def test_datetime_roundtrip_with_legacy_rebase(spark_tmp_path, data_gen, ts_write, ts_rebase_write, ts_rebase_read): data_path = spark_tmp_path + '/PARQUET_DATA' all_confs = {'spark.sql.parquet.outputTimestampType': ts_write, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED', - # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInWrite' : 'CORRECTED', - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED'} + 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write, + 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read, + 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read} assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path), lambda spark, path: spark.read.parquet(path), @@ -776,27 +739,12 @@ def read_table(spark, path): func(create_table, read_table, data_path, conf) # Test to avoid regression on a known bug in Spark. For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693 -def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path): - +@pytest.mark.parametrize('ts_rebase', ['LEGACY', 'CORRECTED']) +def test_hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase): def func_test(create_table, read_table, data_path, conf): assert_gpu_and_cpu_writes_are_equal_collect(create_table, read_table, data_path, conf=conf) assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.parquet(data_path + '/CPU')) - - hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'CORRECTED', func_test) - -# Test to avoid regression on a known bug in Spark. 
For details please visit https://github.com/NVIDIA/spark-rapids/issues/8693 -@allow_non_gpu('DataWritingCommandExec', 'WriteFilesExec') -def test_hive_timestamp_value_fallback(spark_tmp_table_factory, spark_tmp_path): - - def func_test(create_table, read_table, data_path, conf): - assert_gpu_fallback_write( - create_table, - read_table, - data_path, - ['DataWritingCommandExec'], - conf) - - hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, 'LEGACY', func_test) + hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func_test) @ignore_order @pytest.mark.skipif(is_before_spark_340(), reason="`spark.sql.optimizer.plannedWrite.enabled` is only supported in Spark 340+") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index a15cbab3b89..7e845491ec0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -113,38 +113,28 @@ object GpuParquetFileFormat { val schemaHasTimestamps = schema.exists { field => TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType]) } - if (schemaHasTimestamps) { - if(!isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) { - meta.willNotWorkOnGpu(s"Output timestamp type " + - s"${sqlConf.parquetOutputTimestampType} is not supported") - } + if (schemaHasTimestamps && + !isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) { + meta.willNotWorkOnGpu(s"Output timestamp type " + + s"${sqlConf.parquetOutputTimestampType} is not supported") } - DateTimeRebaseMode.fromName(SparkShimImpl.int96ParquetRebaseWrite(sqlConf)) match { - case DateTimeRebaseException | DateTimeRebaseCorrected => // Good - case DateTimeRebaseLegacy => - if (schemaHasTimestamps) { - meta.willNotWorkOnGpu("LEGACY rebase mode for int96 timestamps is not supported") - } - // This should never be reached out, since invalid mode is handled in - // `DateTimeRebaseMode.fromName`. - case other => meta.willNotWorkOnGpu( - DateTimeRebaseUtils.invalidRebaseModeMessage(other.getClass.getName)) + val schemaHasDates = schema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType]) } - - DateTimeRebaseMode.fromName(SparkShimImpl.parquetRebaseWrite(sqlConf)) match { - case DateTimeRebaseException | DateTimeRebaseCorrected => // Good - case DateTimeRebaseLegacy => - if (!TypeChecks.areTimestampsSupported()) { - meta.willNotWorkOnGpu("Only UTC timezone is supported in LEGACY rebase mode. " + - s"Current timezone settings: (JVM : ${ZoneId.systemDefault()}, " + - s"session: ${SQLConf.get.sessionLocalTimeZone}). " + - " Set both of the timezones to UTC to enable LEGACY rebase support.") - } - // This should never be reached out, since invalid mode is handled in - // `DateTimeRebaseMode.fromName`. 
- case other => meta.willNotWorkOnGpu( - DateTimeRebaseUtils.invalidRebaseModeMessage(other.getClass.getName)) + if (schemaHasDates || schemaHasTimestamps) { + val int96RebaseMode = DateTimeRebaseMode.fromName( + SparkShimImpl.int96ParquetRebaseWrite(sqlConf)) + val dateTimeRebaseMode = DateTimeRebaseMode.fromName( + SparkShimImpl.parquetRebaseWrite(sqlConf)) + + if ((int96RebaseMode == DateTimeRebaseLegacy || dateTimeRebaseMode == DateTimeRebaseLegacy) + && !TypeChecks.areTimestampsSupported()) { + meta.willNotWorkOnGpu("Only UTC timezone is supported in LEGACY rebase mode. " + + s"Current timezone settings: (JVM : ${ZoneId.systemDefault()}, " + + s"session: ${SQLConf.get.sessionLocalTimeZone}). " + + " Set both of the timezones to UTC to enable LEGACY rebase support.") + } } if (meta.canThisBeReplaced) {
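
Note on the change above: with the legacy-rebase fallback tests removed, the GPU write path is now exercised with datetimeRebaseModeInWrite / int96RebaseModeInWrite set to LEGACY, provided both the JVM and session time zones are UTC (see the willNotWorkOnGpu message added in GpuParquetFileFormat.scala). Below is a minimal standalone sketch, outside the integration-test harness, of the write-side configuration these consolidated tests exercise; the config keys and rebase values come from the diff, while the SparkSession setup, sample data, and output path are illustrative assumptions only.

    from datetime import datetime, timezone

    from pyspark.sql import SparkSession

    # Hypothetical standalone session; the integration tests set these confs per test instead.
    spark = (SparkSession.builder
             .appName("legacy-rebase-write-sketch")
             .config("spark.sql.session.timeZone", "UTC")  # LEGACY rebase on GPU requires UTC time zones
             .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "LEGACY")
             .config("spark.sql.legacy.parquet.int96RebaseModeInWrite", "LEGACY")
             .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
             .getOrCreate())

    # Pre-Gregorian values are the interesting case for LEGACY rebase.
    df = spark.createDataFrame(
        [(datetime(1000, 1, 1, tzinfo=timezone.utc),),
         (datetime(1582, 10, 4, tzinfo=timezone.utc),)],
        "ts timestamp")
    df.coalesce(1).write.mode("overwrite").parquet("/tmp/legacy_rebase_parquet")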