Skip to content

Commit

Permalink
Add API for getting all registered Spark scalar function names (#11196)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #11196

Differential Revision: D64052727
  • Loading branch information
kagamiori authored and facebook-github-bot committed Oct 8, 2024
1 parent 3ce3fb1 commit 280154f
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 154 deletions.
10 changes: 10 additions & 0 deletions velox/expression/SpecialFormRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ SpecialFormRegistry::getSpecialForm(const std::string& name) const {
return specialForm;
}

std::vector<std::string> SpecialFormRegistry::getSpecialFormNames() const {
std::vector<std::string> names;
registry_.withRLock([&](const auto& map) {
for (const auto& [name, _] : map) {
names.push_back(name);
}
});
return names;
}

const SpecialFormRegistry& specialFormRegistry() {
return specialFormRegistryInternal();
}
Expand Down
2 changes: 2 additions & 0 deletions velox/expression/SpecialFormRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class SpecialFormRegistry {
FunctionCallToSpecialForm* FOLLY_NULLABLE
getSpecialForm(const std::string& name) const;

std::vector<std::string> getSpecialFormNames() const;

private:
folly::Synchronized<RegistryType> registry_;
};
Expand Down
16 changes: 16 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,5 +520,21 @@ void registerFunctions(const std::string& prefix) {
Varchar>({prefix + "mask"});
}

std::vector<std::string> listFunctionNames() {
std::vector<std::string> names =
exec::specialFormRegistry().getSpecialFormNames();

const auto& simpleFunctions = exec::simpleFunctions().getFunctionNames();
names.insert(names.end(), simpleFunctions.begin(), simpleFunctions.end());

exec::vectorFunctionFactories().withRLock([&](const auto& map) {
for (const auto& [name, _] : map) {
names.push_back(name);
}
});

return names;
}

} // namespace sparksql
} // namespace facebook::velox::functions
3 changes: 3 additions & 0 deletions velox/functions/sparksql/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
#pragma once

#include <string>
#include <vector>

namespace facebook::velox::functions::sparksql {

void registerFunctions(const std::string& prefix);

std::vector<std::string> listFunctionNames();

} // namespace facebook::velox::functions::sparksql
11 changes: 11 additions & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ target_link_libraries(
GTest::gtest_main
GTest::gmock
gflags::gflags)

add_executable(velox_spark_function_registry_test RegisterTest.cpp)

add_test(velox_spark_function_registry_test velox_spark_function_registry_test)

target_link_libraries(
velox_spark_function_registry_test
velox_expression
velox_functions_spark
GTest::gtest
GTest::gtest_main)
62 changes: 62 additions & 0 deletions velox/functions/sparksql/tests/RegisterTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/Register.h"

#include <gtest/gtest.h>

#include "velox/expression/CastExpr.h"
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/tests/RegistryTestUtil.h"

namespace facebook::velox::functions::sparksql::test {

VELOX_DECLARE_VECTOR_FUNCTION(
udf_vector_func_one,
VectorFuncOne::signatures(),
std::make_unique<VectorFuncOne>());

class RegisterTest : public testing::Test {
public:
RegisterTest() {
registerFunction<FuncTwo, int64_t, double, double>(
{"func_two_double", "Func_Two_Double_Alias"});
registerFunction<FuncTwo, int64_t, int64_t, int64_t>({"func_two_bigint"});

VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_one, "vector_func_one");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_vector_func_one, "Vector_Func_One_Alias");

exec::registerFunctionCallToSpecialForm(
"cast", std::make_unique<exec::CastCallToSpecialForm>());
}
};

TEST_F(RegisterTest, listFunctionNames) {
auto names = listFunctionNames();
EXPECT_EQ(names.size(), 6);
std::sort(names.begin(), names.end());

EXPECT_EQ(names[0], "cast");
EXPECT_EQ(names[1], "func_two_bigint");
EXPECT_EQ(names[2], "func_two_double");
EXPECT_EQ(names[3], "func_two_double_alias");
EXPECT_EQ(names[4], "vector_func_one");
EXPECT_EQ(names[5], "vector_func_one_alias");
}

} // namespace facebook::velox::functions::sparksql::test
155 changes: 1 addition & 154 deletions velox/functions/tests/FunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,166 +26,13 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/tests/RegistryTestUtil.h"
#include "velox/type/Type.h"

namespace facebook::velox {

namespace {

template <typename T>
struct FuncOne {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Set func_one as non-deterministic.
static constexpr bool is_deterministic = false;

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<velox::Varchar>& /* arg1 */) {
return true;
}
};

template <typename T>
struct FuncTwo {
template <typename T1, typename T2>
FOLLY_ALWAYS_INLINE bool callNullable(
int64_t& /* result */,
const T1* /* arg1 */,
const T2* /* arg2 */) {
return true;
}
};

template <typename T>
struct FuncThree {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
ArrayWriter<int64_t>& /* result */,
const ArrayVal<int64_t>& /* arg1 */) {
return true;
}
};

template <typename T>
struct FuncFour {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<velox::Varchar>& /* arg1 */) {
return true;
}
};

template <typename T>
struct FuncFive {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 5;
return true;
}
};

// FuncSix has the same signature as FuncFive. It's used to test overwrite
// during registration.
template <typename T>
struct FuncSix {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 6;
return true;
}
};

template <typename T>
struct VariadicFunc {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<Variadic<velox::Varchar>>& /* arg1 */) {
return true;
}
};

class VectorFuncOne : public velox::exec::VectorFunction {
public:
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// varchar -> bigint
return {velox::exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("varchar")
.build()};
}
};

class VectorFuncTwo : public velox::exec::VectorFunction {
public:
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// array(varchar) -> array(bigint)
return {velox::exec::FunctionSignatureBuilder()
.returnType("array(bigint)")
.argumentType("array(varchar)")
.build()};
}
};

class VectorFuncThree : public velox::exec::VectorFunction {
public:
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// ... -> opaque
return {velox::exec::FunctionSignatureBuilder()
.returnType("opaque")
.argumentType("any")
.build()};
}
};

class VectorFuncFour : public velox::exec::VectorFunction {
public:
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// map(K,V) -> array(K)
return {velox::exec::FunctionSignatureBuilder()
.knownTypeVariable("K")
.typeVariable("V")
.returnType("array(K)")
.argumentType("map(K,V)")
.build()};
}
};

VELOX_DECLARE_VECTOR_FUNCTION(
udf_vector_func_one,
VectorFuncOne::signatures(),
Expand Down
Loading

0 comments on commit 280154f

Please sign in to comment.