Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for getting all registered Spark scalar function names #11196

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions velox/expression/SpecialFormRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ SpecialFormRegistry::getSpecialForm(const std::string& name) const {
return specialForm;
}

/// Returns a snapshot of the names of every special form currently held in
/// the registry. The names are copied while the registry's read lock is held,
/// so the returned vector is independent of later registrations. The order of
/// the names follows the iteration order of the underlying map.
std::vector<std::string> SpecialFormRegistry::getSpecialFormNames() const {
  return registry_.withRLock([](const auto& registered) {
    std::vector<std::string> result;
    // One allocation up front: exactly one name per registered entry.
    result.reserve(registered.size());
    for (const auto& entry : registered) {
      result.push_back(entry.first);
    }
    return result;
  });
}

/// Returns a read-only (const) reference to the registry object produced by
/// specialFormRegistryInternal().
const SpecialFormRegistry& specialFormRegistry() {
return specialFormRegistryInternal();
}
Expand Down
2 changes: 2 additions & 0 deletions velox/expression/SpecialFormRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class SpecialFormRegistry {
FunctionCallToSpecialForm* FOLLY_NULLABLE
getSpecialForm(const std::string& name) const;

std::vector<std::string> getSpecialFormNames() const;

private:
folly::Synchronized<RegistryType> registry_;
};
Expand Down
17 changes: 17 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,5 +520,22 @@ void registerFunctions(const std::string& prefix) {
Varchar>({prefix + "mask"});
}

std::vector<std::string> listFunctionNames() {
std::vector<std::string> names =
exec::specialFormRegistry().getSpecialFormNames();

const auto& simpleFunctions = exec::simpleFunctions().getFunctionNames();
names.insert(names.end(), simpleFunctions.begin(), simpleFunctions.end());

exec::vectorFunctionFactories().withRLock([&](const auto& map) {
names.reserve(names.size() + map.size());
for (const auto& [name, _] : map) {
names.push_back(name);
}
});

return names;
}

} // namespace sparksql
} // namespace facebook::velox::functions
5 changes: 5 additions & 0 deletions velox/functions/sparksql/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
#pragma once

#include <string>
#include <vector>

namespace facebook::velox::functions::sparksql {

void registerFunctions(const std::string& prefix);

/// Returns the names of all registered scalar functions, including simple
/// functions, vector functions and special forms.
std::vector<std::string> listFunctionNames();
Copy link
Contributor

@xiaoxmeng xiaoxmeng Oct 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// Return all the registered function names include simple functions, vector functions and special forms.


} // namespace facebook::velox::functions::sparksql
11 changes: 11 additions & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ target_link_libraries(
GTest::gtest_main
GTest::gmock
gflags::gflags)

# Test binary built from RegisterTest.cpp.
add_executable(velox_spark_function_registry_test RegisterTest.cpp)

# Register the binary with CTest under the same name as the target.
add_test(velox_spark_function_registry_test velox_spark_function_registry_test)

# Link against the expression and Spark function libraries plus GoogleTest
# (gtest_main supplies the test entry point).
target_link_libraries(
velox_spark_function_registry_test
velox_expression
velox_functions_spark
GTest::gtest
GTest::gtest_main)
62 changes: 62 additions & 0 deletions velox/functions/sparksql/tests/RegisterTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/Register.h"

#include <gtest/gtest.h>

#include "velox/expression/CastExpr.h"
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/tests/RegistryTestUtil.h"

namespace facebook::velox::functions::sparksql::test {

// Declares the vector function tag `udf_vector_func_one` from VectorFuncOne's
// signatures and a VectorFuncOne instance. The tag is passed to
// VELOX_REGISTER_VECTOR_FUNCTION in the RegisterTest fixture.
VELOX_DECLARE_VECTOR_FUNCTION(
udf_vector_func_one,
VectorFuncOne::signatures(),
std::make_unique<VectorFuncOne>());

// Test fixture whose constructor registers a fixed set of simple functions,
// vector functions and one special form, so tests can check what
// listFunctionNames() reports.
class RegisterTest : public testing::Test {
public:
RegisterTest() {
// Simple functions: FuncTwo under a name plus a mixed-case alias, and a
// second instantiation under its own name. The listFunctionNames test
// expects the alias to be reported in lower case.
registerFunction<FuncTwo, int64_t, double, double>(
{"func_two_double", "Func_Two_Double_Alias"});
registerFunction<FuncTwo, int64_t, int64_t, int64_t>({"func_two_bigint"});

// Vector function: the same declared tag registered under a name and a
// mixed-case alias.
VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_one, "vector_func_one");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_vector_func_one, "Vector_Func_One_Alias");

// Special form: "cast".
exec::registerFunctionCallToSpecialForm(
"cast", std::make_unique<exec::CastCallToSpecialForm>());
}
};

// Verifies that listFunctionNames() reports exactly the names registered by
// the RegisterTest fixture: one special form, three simple-function names and
// two vector-function names, with mixed-case aliases reported in lower case.
TEST_F(RegisterTest, listFunctionNames) {
  auto names = listFunctionNames();
  // Sort so the comparison does not depend on the order in which the
  // registries are walked.
  std::sort(names.begin(), names.end());

  // Compare the whole vector in one assertion. Checking the size with a
  // non-fatal EXPECT_EQ and then indexing names[0..5] would read out of range
  // whenever fewer names than expected are returned.
  const std::vector<std::string> expected = {
      "cast",
      "func_two_bigint",
      "func_two_double",
      "func_two_double_alias",
      "vector_func_one",
      "vector_func_one_alias"};
  EXPECT_EQ(names, expected);
}

} // namespace facebook::velox::functions::sparksql::test
155 changes: 1 addition & 154 deletions velox/functions/tests/FunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,166 +26,13 @@
#include "velox/functions/Registerer.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/tests/RegistryTestUtil.h"
#include "velox/type/Type.h"

namespace facebook::velox {

namespace {

// Stub simple function used for registry tests: call() takes a Varchar
// argument and a Varchar result, ignores both and returns true.
template <typename T>
struct FuncOne {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Set func_one as non-deterministic.
static constexpr bool is_deterministic = false;

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<velox::Varchar>& /* arg1 */) {
return true;
}
};

// Stub simple function with a templated callNullable(): accepts two pointer
// arguments of arbitrary types and an int64_t result, touches none of them and
// returns true.
template <typename T>
struct FuncTwo {
template <typename T1, typename T2>
FOLLY_ALWAYS_INLINE bool callNullable(
int64_t& /* result */,
const T1* /* arg1 */,
const T2* /* arg2 */) {
return true;
}
};

// Stub simple function over arrays: call() takes an ArrayVal<int64_t> argument
// and an ArrayWriter<int64_t> result, ignores both and returns true.
template <typename T>
struct FuncThree {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
ArrayWriter<int64_t>& /* result */,
const ArrayVal<int64_t>& /* arg1 */) {
return true;
}
};

// Stub simple function with the same call() shape as FuncOne (Varchar argument,
// Varchar result, returns true) but without an is_deterministic override.
template <typename T>
struct FuncFour {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<velox::Varchar>& /* arg1 */) {
return true;
}
};

// Stub simple function taking one int64_t argument: ignores the argument,
// always writes 5 to the result and returns true.
template <typename T>
struct FuncFive {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 5;
return true;
}
};

// FuncSix has the same signature as FuncFive. It's used to test overwrite
// during registration. It ignores its argument, always writes 6 to the result
// (FuncFive writes 5) and returns true, so the two can be told apart by their
// output.
template <typename T>
struct FuncSix {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 6;
return true;
}
};

// Stub simple function with a variadic argument: call() takes a
// Variadic<Varchar> argument and a Varchar result, ignores both and returns
// true.
template <typename T>
struct VariadicFunc {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool call(
out_type<velox::Varchar>& /* result */,
const arg_type<Variadic<velox::Varchar>>& /* arg1 */) {
return true;
}
};

// Stub vector function: apply() is a no-op; only the advertised signature
// (varchar -> bigint) carries information.
class VectorFuncOne : public velox::exec::VectorFunction {
public:
// Intentionally empty: no argument is read and no result is produced.
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

// Returns a single signature.
static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// varchar -> bigint
return {velox::exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("varchar")
.build()};
}
};

// Stub vector function: apply() is a no-op; only the advertised signature
// (array(varchar) -> array(bigint)) carries information.
class VectorFuncTwo : public velox::exec::VectorFunction {
public:
// Intentionally empty: no argument is read and no result is produced.
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

// Returns a single signature.
static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// array(varchar) -> array(bigint)
return {velox::exec::FunctionSignatureBuilder()
.returnType("array(bigint)")
.argumentType("array(varchar)")
.build()};
}
};

// Stub vector function: apply() is a no-op; only the advertised signature
// (any -> opaque) carries information.
class VectorFuncThree : public velox::exec::VectorFunction {
public:
// Intentionally empty: no argument is read and no result is produced.
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

// Returns a single signature with one argument of type "any".
static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// any -> opaque
return {velox::exec::FunctionSignatureBuilder()
.returnType("opaque")
.argumentType("any")
.build()};
}
};

// Stub vector function: apply() is a no-op; only the advertised generic
// signature (map(K,V) -> array(K)) carries information.
class VectorFuncFour : public velox::exec::VectorFunction {
public:
// Intentionally empty: no argument is read and no result is produced.
void apply(
const velox::SelectivityVector& /* rows */,
std::vector<velox::VectorPtr>& /* args */,
const TypePtr& /* outputType */,
velox::exec::EvalCtx& /* context */,
velox::VectorPtr& /* result */) const override {}

// Returns a single signature; K is declared with knownTypeVariable() and V
// with typeVariable().
static std::vector<std::shared_ptr<velox::exec::FunctionSignature>>
signatures() {
// map(K,V) -> array(K)
return {velox::exec::FunctionSignatureBuilder()
.knownTypeVariable("K")
.typeVariable("V")
.returnType("array(K)")
.argumentType("map(K,V)")
.build()};
}
};

VELOX_DECLARE_VECTOR_FUNCTION(
udf_vector_func_one,
VectorFuncOne::signatures(),
Expand Down
Loading
Loading