Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simdjson based get_json_object Spark function (5179) #447

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions velox/docs/functions/spark/json.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ JSON Functions

.. spark:function:: get_json_object(json, path) -> varchar

Extracts a json object from path::
Extracts a json object from ``path``. Returns NULL if it finds json string
is malformed. ::

SELECT get_json_object('{"a":"b"}', '$.a'); -- b
SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b'
SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}'
3 changes: 2 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ target_link_libraries(
velox_functions_spark_specialforms
velox_is_null_functions
velox_functions_util
Folly::folly)
Folly::folly
simdjson)

set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)
Expand Down
4 changes: 2 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/DateTimeFunctions.h"
#include "velox/functions/prestosql/JsonFunctions.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/ArrayMinMaxFunction.h"
#include "velox/functions/sparksql/ArraySort.h"
Expand All @@ -35,6 +34,7 @@
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
#include "velox/functions/sparksql/SIMDJsonFunctions.h"
#include "velox/functions/sparksql/Size.h"
#include "velox/functions/sparksql/String.h"
#include "velox/functions/sparksql/UnscaledValueFunction.h"
Expand Down Expand Up @@ -124,7 +124,7 @@ void registerFunctions(const std::string& prefix) {
// Register size functions
registerSize(prefix + "size");

registerFunction<JsonExtractScalarFunction, Varchar, Varchar, Varchar>(
registerFunction<SIMDGetJsonObjectFunction, Varchar, Varchar, Varchar>(
{prefix + "get_json_object"});

// Register string functions.
Expand Down
182 changes: 182 additions & 0 deletions velox/functions/sparksql/SIMDJsonFunctions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/prestosql/SIMDJsonFunctions.h"

using namespace simdjson;

namespace facebook::velox::functions::sparksql {

template <typename T>
struct SIMDGetJsonObjectFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
std::optional<std::string> formattedJsonPath_;

// ASCII input always produces ASCII result.
static constexpr bool is_default_ascii_behavior = true;

// Makes a conversion from spark's json path, e.g., converts
// "$.a.b" to "/a/b".
FOLLY_ALWAYS_INLINE std::string getFormattedJsonPath(
const arg_type<Varchar>& jsonPath) {
char formattedJsonPath[jsonPath.size() + 1];
int j = 0;
for (int i = 0; i < jsonPath.size(); i++) {
if (jsonPath.data()[i] == '$' || jsonPath.data()[i] == ']' ||
jsonPath.data()[i] == '\'') {
continue;
} else if (jsonPath.data()[i] == '[' || jsonPath.data()[i] == '.') {
formattedJsonPath[j] = '/';
j++;
} else {
formattedJsonPath[j] = jsonPath.data()[i];
j++;
}
}
formattedJsonPath[j] = '\0';
return std::string(formattedJsonPath, j + 1);
}

FOLLY_ALWAYS_INLINE void initialize(
const core::QueryConfig& config,
const arg_type<Varchar>* /*json*/,
const arg_type<Varchar>* jsonPath) {
if (jsonPath != nullptr) {
formattedJsonPath_ = getFormattedJsonPath(*jsonPath);
}
}

FOLLY_ALWAYS_INLINE simdjson::error_code extractStringResult(
simdjson_result<ondemand::value> rawResult,
out_type<Varchar>& result) {
simdjson::error_code error;
std::stringstream ss;
switch (rawResult.type()) {
// For number and bool types, we need to explicitly get the value
// for specific types instead of using `ss << rawResult`. Thus, we
// can make simdjson's internal parsing position moved and then we
// can check the validity of ending character.
case ondemand::json_type::number: {
switch (rawResult.get_number_type()) {
case ondemand::number_type::unsigned_integer: {
uint64_t numberResult;
error = rawResult.get_uint64().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
case ondemand::number_type::signed_integer: {
int64_t numberResult;
error = rawResult.get_int64().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
case ondemand::number_type::floating_point_number: {
double numberResult;
error = rawResult.get_double().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
default:
VELOX_UNREACHABLE();
}
}
case ondemand::json_type::boolean: {
bool boolResult;
error = rawResult.get_bool().get(boolResult);
if (!error) {
result.append(boolResult ? "true" : "false");
}
return error;
}
case ondemand::json_type::string: {
std::string_view stringResult;
error = rawResult.get_string().get(stringResult);
result.append(stringResult);
return error;
}
case ondemand::json_type::object: {
// For nested case, e.g., for "{"my": {"hello": 10}}", "$.my" will
// return an object type.
ss << rawResult;
result.append(ss.str());
return SUCCESS;
}
case ondemand::json_type::array: {
ss << rawResult;
result.append(ss.str());
return SUCCESS;
}
default: {
return UNSUPPORTED_ARCHITECTURE;
}
}
}

// This is a simple validation by checking whether the obtained result is
// followed by valid char. Because ondemand parsing we are using ignores json
// format validation for characters following the current parsing position.
bool isValidEndingCharacter(const char* currentPos) {
char endingChar = *currentPos;
if (endingChar == ',' || endingChar == '}' || endingChar == ']') {
return true;
}
if (endingChar == ' ' || endingChar == '\r' || endingChar == '\n' ||
endingChar == '\t') {
// These chars can be prior to a valid ending char.
return isValidEndingCharacter(currentPos++);
}
return false;
}

FOLLY_ALWAYS_INLINE bool call(
out_type<Varchar>& result,
const arg_type<Varchar>& json,
const arg_type<Varchar>& jsonPath) {
ParserContext ctx(json.data(), json.size());
try {
ctx.parseDocument();
simdjson_result<ondemand::value> rawResult =
formattedJsonPath_.has_value()
? ctx.jsonDoc.at_pointer(formattedJsonPath_.value().data())
: ctx.jsonDoc.at_pointer(getFormattedJsonPath(jsonPath).data());
// Field not found.
if (rawResult.error() == NO_SUCH_FIELD) {
return false;
}
auto error = extractStringResult(rawResult, result);
if (error) {
return false;
}
} catch (simdjson_error& e) {
return false;
}

const char* currentPos;
ctx.jsonDoc.current_location().get(currentPos);
return isValidEndingCharacter(currentPos);
}
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_executable(
ElementAtTest.cpp
HashTest.cpp
InTest.cpp
JsonFunctionsTest.cpp
LeastGreatestTest.cpp
MapTest.cpp
MightContainTest.cpp
Expand Down
105 changes: 105 additions & 0 deletions velox/functions/sparksql/tests/JsonFunctionsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/type/Type.h"

#include <stdint.h>

namespace facebook::velox::functions::sparksql::test {
namespace {

class JsonFunctionTest : public SparkFunctionBaseTest {
protected:
std::optional<std::string> getJsonObject(
const std::optional<std::string>& json,
const std::optional<std::string>& jsonPath) {
return evaluateOnce<std::string>("get_json_object(c0, c1)", json, jsonPath);
}
};

TEST_F(JsonFunctionTest, getJsonObject) {
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"hello": 3.5})", "$.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"hello": 292222730})", "$.hello"), "292222730");
EXPECT_EQ(getJsonObject(R"({"hello": -292222730})", "$.hello"), "-292222730");
EXPECT_EQ(getJsonObject(R"({"my": {"hello": 3.5}})", "$.my.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"my": {"hello": true}})", "$.my.hello"), "true");
EXPECT_EQ(getJsonObject(R"({"hello": ""})", "$.hello"), "");
EXPECT_EQ(
getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.age"),
"5");
EXPECT_EQ(
getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.id"),
"001");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])",
"$[0]['my']['param']['age']"),
"5");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])",
"$[0].my.param.age"),
"5");

// Json object as result.
EXPECT_EQ(
getJsonObject(
R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})",
"$.my.param"),
R"({"name": "Alice", "age": "5", "id": "001"})");
EXPECT_EQ(
getJsonObject(
R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})",
"$['my']['param']"),
R"({"name": "Alice", "age": "5", "id": "001"})");

// Array as result.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other"),
R"(["v1", "v2"])");
// Array element as result.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other[0]"),
"v1");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other[1]"),
"v2");

// Field not found.
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hi"), std::nullopt);
// Illegal json.
EXPECT_EQ(getJsonObject(R"({"hello"-3.5})", "$.hello"), std::nullopt);
// Illegal json path.
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$hello"), std::nullopt);
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$."), std::nullopt);
// Invalid ending character.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"quoted""}}}, {"other": ["v1", "v2"]}])",
"$[0].my.param.name"),
std::nullopt);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test
Loading