Skip to content

Commit

Permalink
Add in tests for Maps and extend map support where possible (NVIDIA#1148)

Browse files Browse the repository at this point in the history

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Nov 18, 2020
1 parent 71e780b commit e872bf1
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 22 deletions.
4 changes: 4 additions & 0 deletions integration_tests/src/main/python/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def _assert_equal(cpu, gpu, float_check, path):
_assert_equal(sub_cpu, sub_gpu, float_check, path + [index])

index = index + 1
elif (t is dict):
# TODO eventually we need to split this up so we can do the right thing for float/double
# values stored under the map some where, especially for NaNs
assert cpu == gpu, "GPU and CPU map values are different at {}".format(path)
elif (t is int):
assert cpu == gpu, "GPU and CPU int values are different at {}".format(path)
elif (t is float):
Expand Down
8 changes: 4 additions & 4 deletions integration_tests/src/main/python/cmp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,24 @@ def test_isnan(data_gen):
lambda spark : unary_op_df(spark, data_gen).select(
f.isnan(f.col('a'))))

@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_dropna_any(data_gen):
    """dropna() with the default how='any': CPU and GPU results must match."""
    def build_df(spark):
        return binary_op_df(spark, data_gen).dropna()
    assert_gpu_and_cpu_are_equal_collect(build_df)

@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_dropna_all(data_gen):
    """dropna(how='all'): drop only rows where every column is null."""
    def build_df(spark):
        return binary_op_df(spark, data_gen).dropna(how='all')
    assert_gpu_and_cpu_are_equal_collect(build_df)

#dropna is really a filter along with a test for null, but lets do an explicit filter test too
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_filter(data_gen):
    """Explicit filter on a boolean column 'a'; columns b/c carry the data under test."""
    def build_df(spark):
        df = three_col_df(spark, BooleanGen(), data_gen, data_gen)
        return df.filter(f.col('a'))
    assert_gpu_and_cpu_are_equal_collect(build_df)

# coalesce batch happens after a filter, but only if something else happens on the GPU after that
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_filter_with_project(data_gen):
    """Filter followed by a projection, so a coalesce-batches step runs on the GPU after the filter."""
    def build_df(spark):
        filtered = two_col_df(spark, BooleanGen(), data_gen).filter(f.col('a'))
        return filtered.selectExpr('*', 'a as a2')
    assert_gpu_and_cpu_are_equal_collect(build_df)
Expand Down
38 changes: 38 additions & 0 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,35 @@ def gen_array():
return [self._child_gen.gen() for _ in range(0, length)]
self._start(rand, gen_array)

def contains_ts(self):
    # Whether this generator can produce a timestamp anywhere in its output;
    # an array does iff its element generator does, so delegate to the child.
    return self._child_gen.contains_ts()

class MapGen(DataGen):
    """Generate a Map

    Produces Python dicts typed as a Spark MapType(key, value). Map keys are
    never null (Spark requires non-null keys), so the key generator must be
    non-nullable; value nullability follows the value generator.
    """
    def __init__(self, key_gen, value_gen, min_length=0, max_length=20, nullable=True, special_cases=[]):
        # keys cannot be nullable
        assert not key_gen.nullable
        # Inclusive bounds on the number of key/value pairs requested per map.
        self._min_length = min_length
        self._max_length = max_length
        self._key_gen = key_gen
        self._value_gen = value_gen
        # NOTE(review): special_cases=[] is a mutable default argument; safe only
        # if DataGen.__init__ never mutates it — confirm against DataGen.
        super().__init__(MapType(key_gen.data_type, value_gen.data_type, valueContainsNull=value_gen.nullable), nullable=nullable, special_cases=special_cases)

    def __repr__(self):
        # Append '(key_gen,value_gen)' so parametrized test ids show the child types.
        return super().__repr__() + '(' + str(self._key_gen) + ',' + str(self._value_gen) + ')'

    def start(self, rand):
        # Seed both child generators from the same RNG, then register the
        # dict factory with the base class via _start.
        self._key_gen.start(rand)
        self._value_gen.start(rand)
        def make_dict():
            length = rand.randint(self._min_length, self._max_length)
            # NOTE(review): duplicate generated keys collapse in the dict
            # comprehension, so the actual map can be shorter than min_length
            # (likely with low-cardinality keys, e.g. booleans or RepeatSeqGen).
            return {self._key_gen.gen(): self._value_gen.gen() for idx in range(0, length)}
        self._start(rand, make_dict)

    def contains_ts(self):
        # A map can contain a timestamp through either its keys or its values.
        return self._key_gen.contains_ts() or self._value_gen.contains_ts()


def skip_if_not_utc():
    """Skip the current test unless the JVM's system time zone is UTC."""
    if is_tz_utc():
        return
    pytest.skip('The java system time zone is not set to UTC')
Expand Down Expand Up @@ -686,3 +715,12 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
StructGen([['child0', byte_gen]]),
StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]

# A basic map: short string keys ('key_0'..'key_9', never null) to nullable strings.
simple_string_to_string_map_gen = MapGen(StringGen(pattern='key_[0-9]', nullable=False),
        StringGen(), max_length=10)

# Some map gens, but not all because of nesting
map_gens_sample = [simple_string_to_string_map_gen,
        MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
        MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
        MapGen(BooleanGen(nullable=False), boolean_gen, max_length=2),
        MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
32 changes: 32 additions & 0 deletions integration_tests/src/main/python/map_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import incompat
from pyspark.sql.types import *
import pyspark.sql.functions as f

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value(data_gen):
    """Index a map column with literal keys: present keys, a null key, and a missing key."""
    lookups = [
        'a["key_0"]',
        'a["key_1"]',
        'a[null]',
        'a["key_9"]',
        'a["NOT_FOUND"]',
        'a["key_5"]']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*lookups))
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/parquet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def read_parquet_sql(data_path):
ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(byte_gen)),
StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen]]),
ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))],
ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] + map_gens_sample,
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
Expand Down
4 changes: 3 additions & 1 deletion integration_tests/src/main/python/row_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_row_conversions():
["i", timestamp_gen], ["j", date_gen], ["k", ArrayGen(byte_gen)],
["l", ArrayGen(string_gen)], ["m", ArrayGen(float_gen)],
["n", ArrayGen(boolean_gen)], ["o", ArrayGen(ArrayGen(short_gen))],
["p", StructGen([["c0", byte_gen], ["c1", ArrayGen(byte_gen)]])]]
["p", StructGen([["c0", byte_gen], ["c1", ArrayGen(byte_gen)]])],
["q", simple_string_to_string_map_gen],
["r", MapGen(BooleanGen(nullable=False), ArrayGen(boolean_gen), max_length=2)]]
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, gens).selectExpr("*", "a as a_again"))
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class Spark300Shims extends SparkShims {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowArray = true,
allowStringMaps = true,
allowMaps = true,
allowStruct = true,
allowNesting = true)

Expand Down Expand Up @@ -285,8 +285,6 @@ class Spark300Shims extends SparkShims {
a.dataFilters,
conf)
}
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
}),
GpuOverrides.scan[OrcScan](
"ORC parsing",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Spark301dbShims extends Spark301Shims {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowArray = true,
allowStringMaps = true,
allowMaps = true,
allowStruct = true,
allowNesting = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Spark310Shims extends Spark301Shims {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowArray = true,
allowStringMaps = true,
allowMaps = true,
allowStruct = true,
allowNesting = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[Alias](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand All @@ -673,7 +673,7 @@ object GpuOverrides {
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand Down Expand Up @@ -870,7 +870,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand All @@ -882,7 +882,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand Down Expand Up @@ -914,6 +914,7 @@ object GpuOverrides {

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand Down Expand Up @@ -1267,9 +1268,6 @@ object GpuOverrides {
expr[EqualTo](
"Check if the values are equal",
(a, conf, p, r) => new BinaryExprMeta[EqualTo](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuEqualTo(lhs, rhs)
}),
Expand Down Expand Up @@ -1879,7 +1877,7 @@ object GpuOverrides {
new SparkPlanMeta[ProjectExec](proj, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand All @@ -1904,7 +1902,7 @@ object GpuOverrides {

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand Down Expand Up @@ -1975,7 +1973,7 @@ object GpuOverrides {
(filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ object GpuParquetScanBase {
for (field <- readSchema) {
if (!GpuOverrides.isSupportedType(
field.dataType,
allowStringMaps = true,
allowMaps = true,
allowArray = true,
allowStruct = true,
allowNesting = true)) {
Expand Down

0 comments on commit e872bf1

Please sign in to comment.