From a669168332a904a5b6d66da98d9b104e99088344 Mon Sep 17 00:00:00 2001 From: Wei He Date: Tue, 8 Oct 2024 14:20:27 -0700 Subject: [PATCH] Add API for getting all registered Spark scalar function names (#11196) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11196 Differential Revision: D64052727 --- velox/expression/SpecialFormRegistry.cpp | 10 + velox/expression/SpecialFormRegistry.h | 2 + velox/functions/sparksql/Register.cpp | 16 ++ velox/functions/sparksql/Register.h | 3 + velox/functions/sparksql/tests/CMakeLists.txt | 13 ++ .../functions/sparksql/tests/RegisterTest.cpp | 63 +++++++ .../functions/tests/FunctionRegistryTest.cpp | 155 +-------------- velox/functions/tests/RegistryTestUtil.h | 178 ++++++++++++++++++ 8 files changed, 286 insertions(+), 154 deletions(-) create mode 100644 velox/functions/sparksql/tests/RegisterTest.cpp create mode 100644 velox/functions/tests/RegistryTestUtil.h diff --git a/velox/expression/SpecialFormRegistry.cpp b/velox/expression/SpecialFormRegistry.cpp index 3957981905b3..e4eae457c6f7 100644 --- a/velox/expression/SpecialFormRegistry.cpp +++ b/velox/expression/SpecialFormRegistry.cpp @@ -52,6 +52,16 @@ SpecialFormRegistry::getSpecialForm(const std::string& name) const { return specialForm; } +std::vector SpecialFormRegistry::getSpecialFormNames() const { + std::vector names; + registry_.withRLock([&](const auto& map) { + for (const auto& [name, _] : map) { + names.push_back(name); + } + }); + return names; +} + const SpecialFormRegistry& specialFormRegistry() { return specialFormRegistryInternal(); } diff --git a/velox/expression/SpecialFormRegistry.h b/velox/expression/SpecialFormRegistry.h index 6f92a48ba32d..167d37981750 100644 --- a/velox/expression/SpecialFormRegistry.h +++ b/velox/expression/SpecialFormRegistry.h @@ -41,6 +41,8 @@ class SpecialFormRegistry { FunctionCallToSpecialForm* FOLLY_NULLABLE getSpecialForm(const std::string& name) const; + std::vector getSpecialFormNames() const; + 
private: folly::Synchronized registry_; }; diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 3172f7dc572e..5481964debf2 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -520,5 +520,21 @@ void registerFunctions(const std::string& prefix) { Varchar>({prefix + "mask"}); } +std::vector listFunctionNames() { + std::vector names = + exec::specialFormRegistry().getSpecialFormNames(); + + const auto& simpleFunctions = exec::simpleFunctions().getFunctionNames(); + names.insert(names.end(), simpleFunctions.begin(), simpleFunctions.end()); + + exec::vectorFunctionFactories().withRLock([&](const auto& map) { + for (const auto& [name, _] : map) { + names.push_back(name); + } + }); + + return names; +} + } // namespace sparksql } // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/Register.h b/velox/functions/sparksql/Register.h index 6f25e7d49796..98db3b3b2af7 100644 --- a/velox/functions/sparksql/Register.h +++ b/velox/functions/sparksql/Register.h @@ -16,9 +16,12 @@ #pragma once #include +#include namespace facebook::velox::functions::sparksql { void registerFunctions(const std::string& prefix); +std::vector listFunctionNames(); + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index d0858bdf2842..8f33424558ce 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -71,3 +71,16 @@ target_link_libraries( GTest::gtest_main GTest::gmock gflags::gflags) + +add_executable(velox_spark_function_registry_test RegisterTest.cpp) + +add_test(velox_spark_function_registry_test velox_spark_function_registry_test) + +target_link_libraries( + velox_spark_function_registry_test + velox_expression + velox_functions_spark + GTest::gtest + GTest::gtest_main + GTest::gmock + gflags::gflags) diff --git 
a/velox/functions/sparksql/tests/RegisterTest.cpp b/velox/functions/sparksql/tests/RegisterTest.cpp new file mode 100644 index 000000000000..2174979735cc --- /dev/null +++ b/velox/functions/sparksql/tests/RegisterTest.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/functions/sparksql/Register.h" + +#include "velox/expression/CastExpr.h" +#include "velox/expression/SpecialFormRegistry.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/tests/RegistryTestUtil.h" + +namespace facebook::velox::functions::sparksql::test { + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_vector_func_one, + VectorFuncOne::signatures(), + std::make_unique()); + +class RegisterTest : public testing::Test { + public: + RegisterTest() { + registerFunction( + {"func_two_double", "Func_Two_Double_Alias"}); + registerFunction({"func_two_bigint"}); + + VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_one, "vector_func_one"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_vector_func_one, "Vector_Func_One_Alias"); + + exec::registerFunctionCallToSpecialForm( + "cast", std::make_unique()); + } +}; + +TEST_F(RegisterTest, listFunctionNames) { + auto names = listFunctionNames(); + ASSERT_EQ(names.size(), 6); // Fatal: names[0..5] are read below. + std::sort(names.begin(), names.end()); + + EXPECT_EQ(names[0], "cast"); + EXPECT_EQ(names[1], "func_two_bigint"); + EXPECT_EQ(names[2], "func_two_double"); + 
EXPECT_EQ(names[3], "func_two_double_alias"); + EXPECT_EQ(names[4], "vector_func_one"); + EXPECT_EQ(names[5], "vector_func_one_alias"); +} + +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/tests/FunctionRegistryTest.cpp b/velox/functions/tests/FunctionRegistryTest.cpp index 6c10e18dc92d..0456c77c6fd7 100644 --- a/velox/functions/tests/FunctionRegistryTest.cpp +++ b/velox/functions/tests/FunctionRegistryTest.cpp @@ -26,166 +26,13 @@ #include "velox/functions/Registerer.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/tests/RegistryTestUtil.h" #include "velox/type/Type.h" namespace facebook::velox { namespace { -template -struct FuncOne { - VELOX_DEFINE_FUNCTION_TYPES(T); - - // Set func_one as non-deterministic. - static constexpr bool is_deterministic = false; - - FOLLY_ALWAYS_INLINE bool call( - out_type& /* result */, - const arg_type& /* arg1 */) { - return true; - } -}; - -template -struct FuncTwo { - template - FOLLY_ALWAYS_INLINE bool callNullable( - int64_t& /* result */, - const T1* /* arg1 */, - const T2* /* arg2 */) { - return true; - } -}; - -template -struct FuncThree { - VELOX_DEFINE_FUNCTION_TYPES(T); - - FOLLY_ALWAYS_INLINE bool call( - ArrayWriter& /* result */, - const ArrayVal& /* arg1 */) { - return true; - } -}; - -template -struct FuncFour { - VELOX_DEFINE_FUNCTION_TYPES(T); - - FOLLY_ALWAYS_INLINE bool call( - out_type& /* result */, - const arg_type& /* arg1 */) { - return true; - } -}; - -template -struct FuncFive { - FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) { - result = 5; - return true; - } -}; - -// FuncSix has the same signature as FuncFive. It's used to test overwrite -// during registration. 
-template -struct FuncSix { - FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) { - result = 6; - return true; - } -}; - -template -struct VariadicFunc { - VELOX_DEFINE_FUNCTION_TYPES(T); - - FOLLY_ALWAYS_INLINE bool call( - out_type& /* result */, - const arg_type>& /* arg1 */) { - return true; - } -}; - -class VectorFuncOne : public velox::exec::VectorFunction { - public: - void apply( - const velox::SelectivityVector& /* rows */, - std::vector& /* args */, - const TypePtr& /* outputType */, - velox::exec::EvalCtx& /* context */, - velox::VectorPtr& /* result */) const override {} - - static std::vector> - signatures() { - // varchar -> bigint - return {velox::exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("varchar") - .build()}; - } -}; - -class VectorFuncTwo : public velox::exec::VectorFunction { - public: - void apply( - const velox::SelectivityVector& /* rows */, - std::vector& /* args */, - const TypePtr& /* outputType */, - velox::exec::EvalCtx& /* context */, - velox::VectorPtr& /* result */) const override {} - - static std::vector> - signatures() { - // array(varchar) -> array(bigint) - return {velox::exec::FunctionSignatureBuilder() - .returnType("array(bigint)") - .argumentType("array(varchar)") - .build()}; - } -}; - -class VectorFuncThree : public velox::exec::VectorFunction { - public: - void apply( - const velox::SelectivityVector& /* rows */, - std::vector& /* args */, - const TypePtr& /* outputType */, - velox::exec::EvalCtx& /* context */, - velox::VectorPtr& /* result */) const override {} - - static std::vector> - signatures() { - // ... 
-> opaque - return {velox::exec::FunctionSignatureBuilder() - .returnType("opaque") - .argumentType("any") - .build()}; - } -}; - -class VectorFuncFour : public velox::exec::VectorFunction { - public: - void apply( - const velox::SelectivityVector& /* rows */, - std::vector& /* args */, - const TypePtr& /* outputType */, - velox::exec::EvalCtx& /* context */, - velox::VectorPtr& /* result */) const override {} - - static std::vector> - signatures() { - // map(K,V) -> array(K) - return {velox::exec::FunctionSignatureBuilder() - .knownTypeVariable("K") - .typeVariable("V") - .returnType("array(K)") - .argumentType("map(K,V)") - .build()}; - } -}; - VELOX_DECLARE_VECTOR_FUNCTION( udf_vector_func_one, VectorFuncOne::signatures(), diff --git a/velox/functions/tests/RegistryTestUtil.h b/velox/functions/tests/RegistryTestUtil.h new file mode 100644 index 000000000000..60e1bf5929d9 --- /dev/null +++ b/velox/functions/tests/RegistryTestUtil.h @@ -0,0 +1,178 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/FunctionSignature.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/Macros.h" + +namespace facebook::velox { + +template +struct FuncOne { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Set func_one as non-deterministic. 
+ static constexpr bool is_deterministic = false; + + FOLLY_ALWAYS_INLINE bool call( + out_type& /* result */, + const arg_type& /* arg1 */) { + return true; + } +}; + +template +struct FuncTwo { + template + FOLLY_ALWAYS_INLINE bool callNullable( + int64_t& /* result */, + const T1* /* arg1 */, + const T2* /* arg2 */) { + return true; + } +}; + +template +struct FuncThree { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + ArrayWriter& /* result */, + const ArrayVal& /* arg1 */) { + return true; + } +}; + +template +struct FuncFour { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& /* result */, + const arg_type& /* arg1 */) { + return true; + } +}; + +template +struct FuncFive { + FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) { + result = 5; + return true; + } +}; + +// FuncSix has the same signature as FuncFive. It's used to test overwrite +// during registration. +template +struct FuncSix { + FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) { + result = 6; + return true; + } +}; + +template +struct VariadicFunc { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& /* result */, + const arg_type>& /* arg1 */) { + return true; + } +}; + +class VectorFuncOne : public velox::exec::VectorFunction { + public: + void apply( + const velox::SelectivityVector& /* rows */, + std::vector& /* args */, + const TypePtr& /* outputType */, + velox::exec::EvalCtx& /* context */, + velox::VectorPtr& /* result */) const override {} + + static std::vector> + signatures() { + // varchar -> bigint + return {velox::exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("varchar") + .build()}; + } +}; + +class VectorFuncTwo : public velox::exec::VectorFunction { + public: + void apply( + const velox::SelectivityVector& /* rows */, + std::vector& /* args */, + const TypePtr& /* outputType */, + velox::exec::EvalCtx& /* context */, + 
velox::VectorPtr& /* result */) const override {} + + static std::vector> + signatures() { + // array(varchar) -> array(bigint) + return {velox::exec::FunctionSignatureBuilder() + .returnType("array(bigint)") + .argumentType("array(varchar)") + .build()}; + } +}; + +class VectorFuncThree : public velox::exec::VectorFunction { + public: + void apply( + const velox::SelectivityVector& /* rows */, + std::vector& /* args */, + const TypePtr& /* outputType */, + velox::exec::EvalCtx& /* context */, + velox::VectorPtr& /* result */) const override {} + + static std::vector> + signatures() { + // ... -> opaque + return {velox::exec::FunctionSignatureBuilder() + .returnType("opaque") + .argumentType("any") + .build()}; + } +}; + +class VectorFuncFour : public velox::exec::VectorFunction { + public: + void apply( + const velox::SelectivityVector& /* rows */, + std::vector& /* args */, + const TypePtr& /* outputType */, + velox::exec::EvalCtx& /* context */, + velox::VectorPtr& /* result */) const override {} + + static std::vector> + signatures() { + // map(K,V) -> array(K) + return {velox::exec::FunctionSignatureBuilder() + .knownTypeVariable("K") + .typeVariable("V") + .returnType("array(K)") + .argumentType("map(K,V)") + .build()}; + } +}; + +} // namespace facebook::velox