Add json reader support (#4485)
Signed-off-by: Bobby Wang <wbo4958@gmail.com>
wbo4958 authored Jan 20, 2022
1 parent f89b19e commit 7eb6097
Showing 22 changed files with 839 additions and 53 deletions.
51 changes: 50 additions & 1 deletion docs/compatibility.md
@@ -434,6 +434,54 @@ The plugin supports reading `uncompressed`, `snappy` and `gzip` Parquet files an
fall back to the CPU when reading an unsupported compression format, and will error out in that
case.

## JSON

Reading JSON is a very experimental feature that is expected to have some issues, so it is disabled
by default. If you would like to test it, enable both `spark.rapids.sql.format.json.enabled` and
`spark.rapids.sql.format.json.read.enabled`.
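
For example, enabling both flags from a Scala session might look like the sketch below (they can
also be passed as `--conf` options when launching the job):

``` scala
// Enable the experimental GPU JSON reader (both flags default to false).
spark.conf.set("spark.rapids.sql.format.json.enabled", "true")
spark.conf.set("spark.rapids.sql.format.json.read.enabled", "true")
```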

Currently, the GPU-accelerated JSON reader doesn't support column pruning, which will likely make
this feature difficult to use or even to test. You must either specify the full schema or let Spark
infer the schema from the JSON file. For example:

Suppose we have a `people.json` file with the following content:

``` console
{"name":"Michael"}
{"name":"Andy", "age":30}
{"name":"Justin", "age":19}
```

Both of the following approaches will work:

- Inferring the schema

``` scala
val df = spark.read.json("people.json")
```

- Specifying the full schema

``` scala
val schema = StructType(Seq(StructField("name", StringType), StructField("age", IntegerType)))
val df = spark.read.schema(schema).json("people.json")
```
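
With the explicit schema above, the result should look roughly like the following (illustrative
output from a small local run; exact formatting may differ slightly):

``` console
scala> df.show()
+-------+----+
|   name| age|
+-------+----+
|Michael|null|
|   Andy|  30|
| Justin|  19|
+-------+----+
```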

However, the following code will not work in the current version, because it specifies only a
subset of the columns:

``` scala
val schema = StructType(Seq(StructField("name", StringType)))
val df = spark.read.schema(schema).json("people.json")
```

### JSON Supported Types

Nested types (array, map, and struct) are not supported yet in the current version.
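
For example, a schema like the following sketch (the `scores` array column is hypothetical and not
part of `people.json`) contains a nested type, so the read would not be GPU accelerated:

``` scala
import org.apache.spark.sql.types._

// ArrayType (like MapType and StructType) is not yet supported by the GPU JSON reader,
// so a scan with this schema falls back to the CPU.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("scores", ArrayType(IntegerType))))
val df = spark.read.schema(schema).json("people.json")
```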

### JSON Floating Point

The JSON reader has the same floating point issues as the CSV reader. Please refer to the
[CSV Floating Point](#csv-floating-point) section.

## LIKE

If a null char '\0' is in a string that is being matched by a regular expression, `LIKE` sees it as
@@ -840,4 +888,5 @@ Seq(0L, Long.MaxValue).toDF("val")
```

But this is not something that can be done generically and requires inner knowledge about
what can trigger a side effect.
what can trigger a side effect.

3 changes: 3 additions & 0 deletions docs/configs.md
@@ -82,6 +82,8 @@ Name | Description | Default Value
<a name="sql.fast.sample"></a>spark.rapids.sql.fast.sample|Option to turn on fast sample. If enable it is inconsistent with CPU sample because of GPU sample algorithm is inconsistent with CPU.|false
<a name="sql.format.csv.enabled"></a>spark.rapids.sql.format.csv.enabled|When set to false disables all csv input and output acceleration. (only input is currently supported anyways)|true
<a name="sql.format.csv.read.enabled"></a>spark.rapids.sql.format.csv.read.enabled|When set to false disables csv input acceleration|true
<a name="sql.format.json.enabled"></a>spark.rapids.sql.format.json.enabled|When set to true enables all json input and output acceleration. (only input is currently supported anyways)|false
<a name="sql.format.json.read.enabled"></a>spark.rapids.sql.format.json.read.enabled|When set to true enables json input acceleration|false
<a name="sql.format.orc.enabled"></a>spark.rapids.sql.format.orc.enabled|When set to false disables all orc input and output acceleration|true
<a name="sql.format.orc.multiThreadedRead.maxNumFilesParallel"></a>spark.rapids.sql.format.orc.multiThreadedRead.maxNumFilesParallel|A limit on the maximum number of files per task processed in parallel on the CPU side before the file is sent to the GPU. This affects the amount of host memory used when reading the files in parallel. Used with MULTITHREADED reader, see spark.rapids.sql.format.orc.reader.type|2147483647
<a name="sql.format.orc.multiThreadedRead.numThreads"></a>spark.rapids.sql.format.orc.multiThreadedRead.numThreads|The maximum number of threads, on the executor, to use for reading small orc files in parallel. This can not be changed at runtime after the executor has started. Used with MULTITHREADED reader, see spark.rapids.sql.format.orc.reader.type.|20
@@ -384,6 +386,7 @@ Name | Description | Default Value | Notes
Name | Description | Default Value | Notes
-----|-------------|---------------|------------------
<a name="sql.input.CSVScan"></a>spark.rapids.sql.input.CSVScan|CSV parsing|true|None|
<a name="sql.input.JsonScan"></a>spark.rapids.sql.input.JsonScan|Json parsing|true|None|
<a name="sql.input.OrcScan"></a>spark.rapids.sql.input.OrcScan|ORC parsing|true|None|
<a name="sql.input.ParquetScan"></a>spark.rapids.sql.input.ParquetScan|Parquet parsing|true|None|

43 changes: 43 additions & 0 deletions docs/supported_ops.md
@@ -17543,6 +17543,49 @@ dates or timestamps, or for a lack of type coercion support.
<td> </td>
</tr>
<tr>
<th rowSpan="2">JSON</th>
<th>Read</th>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<th>Write</th>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<th rowSpan="2">ORC</th>
<th>Read</th>
<td>S</td>
35 changes: 35 additions & 0 deletions integration_tests/src/main/python/get_json_test.py
@@ -0,0 +1,35 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from pyspark.sql.types import *

def mk_json_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
    r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
    r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
    r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
def test_get_json_object(json_str_pattern):
    gen = mk_json_str_gen(json_str_pattern)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
            'get_json_object(a,"$.a")',
            'get_json_object(a, "$.owner")',
            'get_json_object(a, "$.store.fruit[0]")'),
        conf={'spark.sql.parser.escapedStringLiterals': 'true'})
136 changes: 121 additions & 15 deletions integration_tests/src/main/python/json_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,20 +16,126 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from pyspark.sql.types import *
from src.main.python.marks import approximate_float, allow_non_gpu

def mk_json_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
from src.main.python.spark_session import with_cpu_session

@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
    r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
    r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
    r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
def test_get_json_object(json_str_pattern):
    gen = mk_json_str_gen(json_str_pattern)
json_supported_gens = [
    # Spark does not escape '\r' or '\n' even though it uses it to mark end of record
    # This would require multiLine reads to work correctly, so we avoid these chars
    StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
    StringGen('[aAbB ]{0,10}'),
    byte_gen, short_gen, int_gen, long_gen, boolean_gen,
    # Once https://github.com/NVIDIA/spark-rapids/issues/125 and https://github.com/NVIDIA/spark-rapids/issues/124
    # are fixed we should not have to special case float values any more.
    pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/125')),
    pytest.param(FloatGen(no_nans=True), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/124')),
    pytest.param(float_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/125')),
    DoubleGen(no_nans=True)
]

_enable_all_types_conf = {
    'spark.rapids.sql.format.json.enabled': 'true',
    'spark.rapids.sql.format.json.read.enabled': 'true'}

@approximate_float
@pytest.mark.parametrize('data_gen', [
    StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
    StringGen('[aAbB ]{0,10}'),
    byte_gen, short_gen, int_gen, long_gen, boolean_gen,], ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
@allow_non_gpu('FileSourceScanExec')
def test_json_infer_schema_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
    gen = StructGen([('a', data_gen)], nullable=False)
    data_path = spark_tmp_path + '/JSON_DATA'
    updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write.json(data_path))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.read.json(data_path),
        conf=updated_conf)

@approximate_float
@pytest.mark.parametrize('data_gen', json_supported_gens, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
def test_json_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
    gen = StructGen([('a', data_gen)], nullable=False)
    data_path = spark_tmp_path + '/JSON_DATA'
    schema = gen.data_type
    updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write.json(data_path))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.read.schema(schema).json(data_path),
        conf=updated_conf)

@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
def test_json_input_meta(spark_tmp_path, v1_enabled_list):
    gen = StructGen([('a', long_gen), ('b', long_gen), ('c', long_gen)], nullable=False)
    first_data_path = spark_tmp_path + '/JSON_DATA/key=0'
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write.json(first_data_path))
    second_data_path = spark_tmp_path + '/JSON_DATA/key=1'
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write.json(second_data_path))
    data_path = spark_tmp_path + '/JSON_DATA'
    updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.read.schema(gen.data_type)
            .json(data_path)
            .filter(f.col('b') > 0)
            .selectExpr('b',
                'input_file_name()',
                'input_file_block_start()',
                'input_file_block_length()'),
        conf=updated_conf)

json_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
    'MM-yyyy', 'MM/yyyy', 'MM-dd-yyyy', 'MM/dd/yyyy']
@pytest.mark.parametrize('date_format', json_supported_date_formats, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
def test_json_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list):
    gen = StructGen([('a', DateGen())], nullable=False)
    data_path = spark_tmp_path + '/JSON_DATA'
    schema = gen.data_type
    updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write\
            .option('dateFormat', date_format)\
            .json(data_path))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.read\
            .schema(schema)\
            .option('dateFormat', date_format)\
            .json(data_path),
        conf=updated_conf)

json_supported_ts_parts = ['', # Just the date
    "'T'HH:mm:ss.SSSXXX",
    "'T'HH:mm:ss[.SSS][XXX]",
    "'T'HH:mm:ss.SSS",
    "'T'HH:mm:ss[.SSS]",
    "'T'HH:mm:ss",
    "'T'HH:mm[:ss]",
    "'T'HH:mm"]

@pytest.mark.parametrize('ts_part', json_supported_ts_parts)
@pytest.mark.parametrize('date_format', json_supported_date_formats)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
def test_json_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list):
    full_format = date_format + ts_part
    data_gen = TimestampGen()
    gen = StructGen([('a', data_gen)], nullable=False)
    data_path = spark_tmp_path + '/JSON_DATA'
    schema = gen.data_type
    with_cpu_session(
        lambda spark : gen_df(spark, gen).write\
            .option('timestampFormat', full_format)\
            .json(data_path))
    updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
            'get_json_object(a,"$.a")',
            'get_json_object(a, "$.owner")',
            'get_json_object(a, "$.store.fruit[0]")'),
        conf={'spark.sql.parser.escapedStringLiterals': 'true'})
        lambda spark : spark.read\
            .schema(schema)\
            .option('timestampFormat', full_format)\
            .json(data_path),
        conf=updated_conf)
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,20 +18,13 @@ package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.csv._
import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark30Xuntil33XFileOptionsShims
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2._

trait Spark30Xuntil33XShims extends SparkShims {
  def dateFormatInRead(csvOpts: CSVOptions): Option[String] = {
    Option(csvOpts.dateFormat)
  }

  def timestampFormatInRead(csvOpts: CSVOptions): Option[String] = {
    Option(csvOpts.timestampFormat)
  }
trait Spark30Xuntil33XShims extends Spark30Xuntil33XFileOptionsShims {

  def neverReplaceShowCurrentNamespaceCommand: ExecRule[_ <: SparkPlan] = {
    GpuOverrides.neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation")
  }
}
}
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.json.rapids.shims.v2

import com.nvidia.spark.rapids.SparkShims

import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.json.JSONOptions

trait Spark30Xuntil33XFileOptionsShims extends SparkShims {

  def dateFormatInRead(fileOptions: Serializable): Option[String] = {
    fileOptions match {
      case csvOpts: CSVOptions => Option(csvOpts.dateFormat)
      case jsonOpts: JSONOptions => Option(jsonOpts.dateFormat)
      case _ => throw new RuntimeException("Wrong file options.")
    }
  }

  def timestampFormatInRead(fileOptions: Serializable): Option[String] = {
    fileOptions match {
      case csvOpts: CSVOptions => Option(csvOpts.timestampFormat)
      case jsonOpts: JSONOptions => Option(jsonOpts.timestampFormat)
      case _ => throw new RuntimeException("Wrong file options.")
    }
  }

}