Skip to content

Commit

Permalink
add json test
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jan 10, 2022
1 parent 3d55627 commit 588fb67
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
35 changes: 35 additions & 0 deletions integration_tests/src/main/python/get_json_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from pyspark.sql.types import *

def mk_json_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
def test_get_json_object(json_str_pattern):
gen = mk_json_str_gen(json_str_pattern)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
'get_json_object(a,"$.a")',
'get_json_object(a, "$.owner")',
'get_json_object(a, "$.store.fruit[0]")'),
conf={'spark.sql.parser.escapedStringLiterals': 'true'})
40 changes: 25 additions & 15 deletions integration_tests/src/main/python/json_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,20 +16,30 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from pyspark.sql.types import *
from src.main.python.marks import approximate_float

def mk_json_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
from src.main.python.spark_session import with_cpu_session

@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
def test_get_json_object(json_str_pattern):
gen = mk_json_str_gen(json_str_pattern)
json_supported_gens = [
byte_gen, short_gen, int_gen, long_gen, boolean_gen,
# FloatGen(no_nans=True), # Test will fail
DoubleGen(no_nans=True)
]

_enable_all_types_conf = {
'spark.rapids.sql.format.json.enabled': 'true',
'spark.rapids.sql.format.json.read.enabled': 'true'}

@approximate_float
@pytest.mark.parametrize('data_gen', json_supported_gens, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
gen = StructGen([('a', data_gen)], nullable=False)
data_path = spark_tmp_path + '/JSON_DATA'
schema = gen.data_type
updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
with_cpu_session(
lambda spark : gen_df(spark, gen).write.json(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
'get_json_object(a,"$.a")',
'get_json_object(a, "$.owner")',
'get_json_object(a, "$.store.fruit[0]")'),
conf={'spark.sql.parser.escapedStringLiterals': 'true'})
lambda spark : spark.read.schema(schema).json(data_path),
conf=updated_conf)

0 comments on commit 588fb67

Please sign in to comment.