Move GpuParquetScan/GpuOrcScan into Shim (#590)
* Move GpuParquetScan to shim

Signed-off-by: Thomas Graves <tgraves@nvidia.com>

* Move scan overrides into shim

Signed-off-by: Thomas Graves <tgraves@apache.org>

* Rename GpuParquetScan object to match

Signed-off-by: Thomas Graves <tgraves@apache.org>

* Add tests for v2 datasources

Signed-off-by: Thomas Graves <tgraves@nvidia.com>

* Move OrcScan into shims

Signed-off-by: Thomas Graves <tgraves@apache.org>

* Fixes

Signed-off-by: Thomas Graves <tgraves@nvidia.com>

* Fix imports

Signed-off-by: Thomas Graves <tgraves@apache.org>

Co-authored-by: Thomas Graves <tgraves@nvidia.com>
tgravescs authored Aug 19, 2020
1 parent 74ddc06 commit ce1f9b8
Showing 15 changed files with 479 additions and 149 deletions.
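
The two test files shown below follow the same pattern: each read test gains a v1_enabled_list parameter whose value is written into spark.sql.sources.useV1SourceList. Listing the source name ("csv" or "orc") keeps that source on the v1 FileSourceScanExec path, while an empty string forces the DataSource v2 scan that this commit moves into the shims. A minimal sketch of the pattern (a hypothetical stand-alone test, not part of the commit; helper names and import paths are assumed from the integration-test harness used in the diffs):

import pytest
# Harness imports as used elsewhere in these tests (module paths assumed):
from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import StructGen, long_gen, gen_df
from spark_session import with_cpu_session

@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_read_both_scan_paths(spark_tmp_path, v1_enabled_list):
    gen = StructGen([('a', long_gen)], nullable=False)
    data_path = spark_tmp_path + '/CSV_DATA'
    # Write the input once on the CPU, then compare CPU and GPU reads.
    with_cpu_session(lambda spark: gen_df(spark, gen).write.csv(data_path))
    # "" forces every file source onto the v2 path; "csv" keeps CSV on v1.
    conf = {'spark.sql.sources.useV1SourceList': v1_enabled_list}
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: spark.read.schema(gen.data_type).csv(data_path),
        conf=conf)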
33 changes: 23 additions & 10 deletions integration_tests/src/main/python/csv_test.py
@@ -119,9 +119,12 @@ def read_impl(spark):
('str.csv', _good_str_schema, ',', True)
])
@pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql])
def test_basic_read(std_input_path, name, schema, sep, header, read_func):
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_basic_read(std_input_path, name, schema, sep, header, read_func, v1_enabled_list):
updated_conf=_enable_ts_conf
updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list
assert_gpu_and_cpu_are_equal_collect(read_func(std_input_path + '/' + name, schema, header, sep),
conf=_enable_ts_conf)
conf=updated_conf)

csv_supported_gens = [
# Spark does not escape '\r' or '\n' even though it uses it to mark end of record
@@ -141,15 +144,18 @@ def test_basic_read(std_input_path, name, schema, sep, header, read_func):

@approximate_float
@pytest.mark.parametrize('data_gen', csv_supported_gens, ids=idfn)
def test_round_trip(spark_tmp_path, data_gen):
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
gen = StructGen([('a', data_gen)], nullable=False)
data_path = spark_tmp_path + '/CSV_DATA'
schema = gen.data_type
updated_conf=_enable_ts_conf
updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list
with_cpu_session(
lambda spark : gen_df(spark, gen).write.csv(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.schema(schema).csv(data_path),
conf=_enable_ts_conf)
conf=updated_conf)

@allow_non_gpu('FileSourceScanExec')
@pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql])
@@ -174,7 +180,8 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf):
csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
'MM-yyyy', 'MM/yyyy', 'MM-dd-yyyy', 'MM/dd/yyyy']
@pytest.mark.parametrize('date_format', csv_supported_date_formats, ids=idfn)
def test_date_formats_round_trip(spark_tmp_path, date_format):
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list):
gen = StructGen([('a', DateGen())], nullable=False)
data_path = spark_tmp_path + '/CSV_DATA'
schema = gen.data_type
@@ -186,7 +193,8 @@ def test_date_formats_round_trip(spark_tmp_path, date_format):
lambda spark : spark.read\
.schema(schema)\
.option('dateFormat', date_format)\
.csv(data_path))
.csv(data_path),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

csv_supported_ts_parts = ['', # Just the date
"'T'HH:mm:ss.SSSXXX",
@@ -199,7 +207,8 @@ def test_date_formats_round_trip(spark_tmp_path, date_format):

@pytest.mark.parametrize('ts_part', csv_supported_ts_parts)
@pytest.mark.parametrize('date_format', csv_supported_date_formats)
def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part):
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list):
full_format = date_format + ts_part
# Once https://github.com/NVIDIA/spark-rapids/issues/122 is fixed the full range should be used
data_gen = TimestampGen(start=datetime(1902, 1, 1, tzinfo=timezone.utc),
@@ -211,14 +220,17 @@ def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part):
lambda spark : gen_df(spark, gen).write\
.option('timestampFormat', full_format)\
.csv(data_path))
updated_conf=_enable_ts_conf
updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read\
.schema(schema)\
.option('timestampFormat', full_format)\
.csv(data_path),
conf=_enable_ts_conf)
conf=updated_conf)

def test_input_meta(spark_tmp_path):
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_input_meta(spark_tmp_path, v1_enabled_list):
gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False)
first_data_path = spark_tmp_path + '/CSV_DATA/key=0'
with_cpu_session(
@@ -234,4 +246,5 @@ def test_input_meta(spark_tmp_path):
.selectExpr('a',
'input_file_name()',
'input_file_block_start()',
'input_file_block_length()'))
'input_file_block_length()'),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})
36 changes: 24 additions & 12 deletions integration_tests/src/main/python/orc_test.py
@@ -29,9 +29,11 @@ def read_orc_sql(data_path):

@pytest.mark.parametrize('name', ['timestamp-date-test.orc'])
@pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql])
def test_basic_read(std_input_path, name, read_func):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_basic_read(std_input_path, name, read_func, v1_enabled_list):
assert_gpu_and_cpu_are_equal_collect(
read_func(std_input_path + '/' + name))
read_func(std_input_path + '/' + name),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

orc_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
@@ -59,13 +61,15 @@ def test_orc_fallback(spark_tmp_path, read_func, disable_conf):

@pytest.mark.parametrize('orc_gens', orc_gens_list, ids=idfn)
@pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql])
def test_read_round_trip(spark_tmp_path, orc_gens, read_func):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_read_round_trip(spark_tmp_path, orc_gens, read_func, v1_enabled_list):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
data_path = spark_tmp_path + '/ORC_DATA'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.orc(data_path))
assert_gpu_and_cpu_are_equal_collect(
read_func(data_path))
read_func(data_path),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

orc_pred_push_gens = [
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
@@ -79,30 +83,35 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func):

@pytest.mark.parametrize('orc_gen', orc_pred_push_gens, ids=idfn)
@pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql])
def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func, v1_enabled_list):
data_path = spark_tmp_path + '/ORC_DATA'
gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen)]
s0 = gen_scalar(orc_gen, force_no_nulls=True)
with_cpu_session(
lambda spark : gen_df(spark, gen_list).orderBy('a').write.orc(data_path))
rf = read_func(data_path)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: rf(spark).select(f.col('a') >= s0))
lambda spark: rf(spark).select(f.col('a') >= s0),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

orc_compress_options = ['none', 'uncompressed', 'snappy', 'zlib']
# The following need extra jars 'lzo'
# https://github.com/NVIDIA/spark-rapids/issues/143

@pytest.mark.parametrize('compress', orc_compress_options)
def test_compress_read_round_trip(spark_tmp_path, compress):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list):
data_path = spark_tmp_path + '/ORC_DATA'
with_cpu_session(
lambda spark : binary_op_df(spark, long_gen).write.orc(data_path),
conf={'spark.sql.orc.compression.codec': compress})
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.orc(data_path))
lambda spark : spark.read.orc(data_path),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

def test_simple_partitioned_read(spark_tmp_path):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list):
# Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed
# we should go with a more standard set of generators
orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
@@ -117,10 +126,12 @@ def test_simple_partitioned_read(spark_tmp_path):
lambda spark : gen_df(spark, gen_list).write.orc(second_data_path))
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.orc(data_path))
lambda spark : spark.read.orc(data_path),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/135')
def test_merge_schema_read(spark_tmp_path):
@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
def test_merge_schema_read(spark_tmp_path, v1_enabled_list):
# Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed
# we should go with a more standard set of generators
orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
@@ -136,7 +147,8 @@ def test_merge_schema_read(spark_tmp_path):
lambda spark : gen_df(spark, second_gen_list).write.orc(second_data_path))
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.option('mergeSchema', 'true').orc(data_path))
lambda spark : spark.read.option('mergeSchema', 'true').orc(data_path),
conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})

orc_write_gens_list = [
[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
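
The ORC tests apply the same toggle with "orc" in place of "csv". As a quick stand-alone illustration of the round trip they exercise (again a hedged sketch with assumed local paths, not part of the commit):

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()

# Write a small ORC dataset once.
spark.createDataFrame([Row(a=i, b=i * 2) for i in range(100)]) \
    .write.mode('overwrite').orc('/tmp/ORC_DATA')

for v1_list in ['', 'orc']:
    # '' exercises the v2 scan (the GpuOrcScan path moved into the shims by this
    # commit when the plugin is active); 'orc' keeps the v1 FileSourceScanExec path.
    spark.conf.set('spark.sql.sources.useV1SourceList', v1_list)
    print(repr(v1_list), spark.read.orc('/tmp/ORC_DATA').count())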
