Support int96RebaseModeInWrite and int96RebaseModeInRead #3330

Merged
merged 26 commits into branch-21.10 from int96_rebase_mode on Sep 23, 2021
26 commits
1fc3b86
support int96 rebase mode
razajafri Aug 26, 2021
c25aab8
addressed review comments
razajafri Aug 30, 2021
5df6921
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Aug 30, 2021
2c31105
removed unnecessary changes to the SparkBaseShim
razajafri Aug 30, 2021
1798391
removed unnecessary extra line
razajafri Aug 30, 2021
1d06300
split the date and time exception checks
razajafri Aug 31, 2021
489e342
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Aug 31, 2021
45a3773
addressed review comments
razajafri Sep 1, 2021
b25422c
skip test if Spark version <3.1.1
razajafri Sep 1, 2021
a62a9ef
added more tests and addressed review comments
razajafri Sep 3, 2021
b8f1c1a
addressed the failing test
razajafri Sep 3, 2021
ac48707
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Sep 20, 2021
12a8e65
added method to return if running before 311
razajafri Sep 20, 2021
10409c2
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Sep 20, 2021
a0d2728
changed the API to return existence of a separate INT96 rebase conf
razajafri Sep 21, 2021
01cfc0b
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Sep 21, 2021
46002b9
updated the 32x shim to override the correct method
razajafri Sep 21, 2021
464e9ad
added DB support
razajafri Sep 21, 2021
85cbce2
Merge branch 'branch-21.10' into int96_rebase_mode
razajafri Sep 21, 2021
7e77a23
The default value for int96 rebase in databricks is legacy, explicitl…
razajafri Sep 22, 2021
5c49856
Merge remote-tracking branch 'origin/branch-21.10' into HEAD
razajafri Sep 22, 2021
abd2f5d
skip ANSI test for hash aggregate until we root cause the failure
razajafri Sep 22, 2021
14658b4
Adding resolution to the failed test
razajafri Sep 22, 2021
0692169
Revert "Adding resolution to the failed test"
razajafri Sep 22, 2021
b3c1c1e
Adding resolution to the failed test
razajafri Sep 22, 2021
7166a1c
Update integration_tests/src/main/python/hash_aggregate_test.py
razajafri Sep 22, 2021
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/cache_test.py
@@ -202,6 +202,9 @@ def write_read_parquet_cached(spark):
# rapids-spark doesn't support LEGACY read for parquet
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead' : 'CORRECTED',
# set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on GPU
'spark.sql.legacy.parquet.int96RebaseModeInWrite' : 'CORRECTED',
'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
'spark.sql.inMemoryColumnarStorage.enableVectorizedReader' : enable_vectorized,
'spark.sql.parquet.outputTimestampType': ts_write}

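
For reference, all four rebase confs touched in this PR can be forced on a live session the same way; a minimal sketch, assuming an existing SparkSession named spark:

# Sketch: force CORRECTED rebase modes so Parquet timestamp reads and writes
# can stay on the GPU even where the platform default (e.g. Databricks) is LEGACY.
corrected_confs = {
    'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
    'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED',
    'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED',
    'spark.sql.legacy.parquet.int96RebaseModeInRead': 'CORRECTED',
}
for key, value in corrected_confs.items():
    spark.conf.set(key, value)  # assumes a SparkSession named `spark` already exists
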
1 change: 1 addition & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -1160,6 +1160,7 @@ def do_it(spark):


@pytest.mark.parametrize('data_gen', _no_overflow_ansi_gens, ids=idfn)
@ignore_order(local=True)
def test_no_fallback_when_ansi_enabled(data_gen):
def do_it(spark):
df = gen_df(spark, [('a', data_gen), ('b', data_gen)], length=100)
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -82,6 +82,8 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
conf=rebase_write_corrected_conf)
all_confs = copy_and_update(reader_confs, {
'spark.sql.sources.useV1SourceList': v1_enabled_list,
# set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on GPU
'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
# once https://github.com/NVIDIA/spark-rapids/issues/1126 is in, we can remove the spark.sql.legacy.parquet.datetimeRebaseModeInRead config,
# which is a workaround for nested timestamp/date support
65 changes: 56 additions & 9 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -23,6 +23,7 @@
import pyspark.sql.functions as f
import pyspark.sql.utils
import random
from spark_session import is_before_spark_311

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
@@ -41,6 +42,11 @@ def limited_timestamp(nullable=True):
return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc),
nullable=nullable)

# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
def limited_int96():
return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
# we are limiting TimestampGen to avoid overflowing the INT96 value
@@ -214,24 +220,44 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_t
data_path,
conf=all_confs)

def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, ts_rebase, ts_write):
def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', ts_rebase)
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', ts_rebase) # for spark 310
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase)
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310
with pytest.raises(Exception) as e_info:
df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
assert e_info.match(r".*SparkUpgradeException.*")

# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
# TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
@pytest.mark.parametrize('ts_write_data_gen', [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))),
('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('ts_rebase', ['EXCEPTION'])
def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, ts_rebase, spark_tmp_table_factory):
@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"])
def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase):
ts_write, gen = ts_write_data_gen
data_path = spark_tmp_path + '/PARQUET_DATA'
with_gpu_session(
lambda spark : writeParquetUpgradeCatchException(spark, unary_op_df(spark, gen), data_path, spark_tmp_table_factory, ts_rebase, ts_write))
int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase
date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase
if is_before_spark_311() and ts_write == 'INT96':
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase})
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=all_confs)
else:
with_gpu_session(
lambda spark : writeParquetUpgradeCatchException(spark,
unary_op_df(spark, gen),
data_path,
spark_tmp_table_factory,
int96_rebase, date_time_rebase, ts_write))
with_cpu_session(
lambda spark: writeParquetUpgradeCatchException(spark,
unary_op_df(spark, gen), data_path,
spark_tmp_table_factory,
int96_rebase, date_time_rebase, ts_write))
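
The EXCEPTION path the helper above asserts can also be reproduced standalone; a minimal sketch, assuming a local Spark 3.1.1+ session and an arbitrary /tmp output path:

from datetime import datetime, timezone
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()
spark.conf.set('spark.sql.parquet.outputTimestampType', 'INT96')
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', 'EXCEPTION')

# A timestamp before the 1582-10-15 Gregorian cutover trips the rebase check.
df = spark.createDataFrame([(datetime(1500, 1, 1, tzinfo=timezone.utc),)], 'ts timestamp')
try:
    df.write.mode('overwrite').parquet('/tmp/int96_rebase_demo')
except Exception as e:
    print(type(e).__name__, e)  # the error message names SparkUpgradeException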

def writeParquetNoOverwriteCatchException(spark, df, data_path, table_name):
with pytest.raises(Exception) as e_info:
@@ -319,6 +345,27 @@ def generate_map_with_empty_validity(spark, path):
lambda spark, path: spark.read.parquet(path),
data_path)

@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
@pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"])
@pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"])
@pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"])
@pytest.mark.parametrize('int96_rebase_read', ["EXCEPTION", "CORRECTED"])
def test_roundtrip_with_rebase_values(spark_tmp_path, ts_write_data_gen, date_time_rebase_read,
date_time_rebase_write, int96_rebase_read, int96_rebase_write):
ts_write, gen = ts_write_data_gen
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write}
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': date_time_rebase_write,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': int96_rebase_write})
all_confs.update({'spark.sql.legacy.parquet.datetimeRebaseModeInRead': date_time_rebase_read,
'spark.sql.legacy.parquet.int96RebaseModeInRead': int96_rebase_read})

assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf=all_confs)

@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/3476')
@pytest.mark.allow_non_gpu("DataWritingCommandExec", "HiveTableScanExec")
@pytest.mark.parametrize('allow_non_empty', [True, False])
@@ -32,6 +32,22 @@ class Spark311Shims extends SparkBaseShims {
classOf[RapidsShuffleManager].getCanonicalName
}

override def int96ParquetRebaseRead(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
}

override def int96ParquetRebaseWrite(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
}

override def int96ParquetRebaseReadKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
}

override def int96ParquetRebaseWriteKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
}

override def hasCastFloatTimestampUpcast: Boolean = false

override def getParquetFilters(
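
These overrides exist because the INT96-specific confs are only present on the Spark versions this shim targets; older releases expose just the combined datetime rebase conf, which is why the tests skip via is_before_spark_311. A rough Python illustration of that gating (effective_int96_write_conf and the version tuple are illustrative, not part of this PR):

INT96_WRITE_KEY = 'spark.sql.legacy.parquet.int96RebaseModeInWrite'
DATETIME_WRITE_KEY = 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite'

def effective_int96_write_conf(version):
    """Pick the conf key that governs INT96 writes for a (major, minor, patch) tuple."""
    # 3.1.1+ has the INT96-specific conf; earlier Sparks only honor the
    # combined datetime conf, hence is_before_spark_311 in the tests above.
    return INT96_WRITE_KEY if version >= (3, 1, 1) else DATETIME_WRITE_KEY

print(effective_int96_write_conf((3, 0, 2)))  # -> the combined datetime key
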
@@ -97,6 +97,22 @@ class Spark311CDHShims extends SparkBaseShims {
sessionCatalog.createTable(newTable, ignoreIfExists = false, validateLocation = false)
}

override def int96ParquetRebaseRead(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ)
}

override def int96ParquetRebaseWrite(conf: SQLConf): String = {
conf.getConf(SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE)
}

override def int96ParquetRebaseReadKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_READ.key
}

override def int96ParquetRebaseWriteKey: String = {
SQLConf.LEGACY_PARQUET_INT96_REBASE_MODE_IN_WRITE.key
}

override def hasCastFloatTimestampUpcast: Boolean = false

override def getParquetFilters(