diff --git a/velox/expression/ExprCompiler.cpp b/velox/expression/ExprCompiler.cpp index 94ce09fe3f20..66409506b894 100644 --- a/velox/expression/ExprCompiler.cpp +++ b/velox/expression/ExprCompiler.cpp @@ -37,6 +37,7 @@ using core::TypedExprPtr; const char* const kAnd = "and"; const char* const kOr = "or"; const char* const kRowConstructor = "row_constructor"; +const char* const kRowConstructorWithNull = "row_constructor_with_null"; struct ITypedExprHasher { size_t operator()(const ITypedExpr* expr) const { @@ -237,6 +238,26 @@ ExprPtr getRowConstructorExpr( trackCpuUsage); } +ExprPtr getRowConstructorWithNullExpr( + const core::QueryConfig& config, + const TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage) { + static auto rowConstructorVectorFunction = + vectorFunctionFactories().withRLock([&config](auto& functionMap) { + auto functionIterator = functionMap.find(exec::kRowConstructorWithNull); + return functionIterator->second.factory( + exec::kRowConstructorWithNull, {}, config); + }); + + return std::make_shared( + type, + std::move(compiledChildren), + rowConstructorVectorFunction, + "row_constructor_with_null", + trackCpuUsage); +} + ExprPtr getSpecialForm( const core::QueryConfig& config, const std::string& name, @@ -247,6 +268,10 @@ ExprPtr getSpecialForm( return getRowConstructorExpr( config, type, std::move(compiledChildren), trackCpuUsage); } + if (name == kRowConstructorWithNull) { + return getRowConstructorWithNullExpr( + config, type, std::move(compiledChildren), trackCpuUsage); + } // If we just check the output of constructSpecialForm we'll have moved // compiledChildren, and if the function isn't a special form we'll still need diff --git a/velox/functions/FunctionRegistry.cpp b/velox/functions/FunctionRegistry.cpp index 22a8ec94ca82..3087041d65f1 100644 --- a/velox/functions/FunctionRegistry.cpp +++ b/velox/functions/FunctionRegistry.cpp @@ -109,7 +109,8 @@ std::shared_ptr resolveCallableSpecialForm( const std::string& functionName, const std::vector& argTypes) { // TODO Replace with struct_pack - if (functionName == "row_constructor") { + if (functionName == "row_constructor" || + functionName == "row_constructor_with_null") { auto numInput = argTypes.size(); std::vector types(numInput); std::vector names(numInput); diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index e2e820452936..dade3728cf2a 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -46,6 +46,7 @@ add_library( Repeat.cpp Reverse.cpp RowFunction.cpp + RowFunctionWithNull.cpp Sequence.cpp SimpleComparisonMatcher.cpp Slice.cpp diff --git a/velox/functions/prestosql/RowFunctionWithNull.cpp b/velox/functions/prestosql/RowFunctionWithNull.cpp new file mode 100644 index 000000000000..fb650de6d3ca --- /dev/null +++ b/velox/functions/prestosql/RowFunctionWithNull.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" + +namespace facebook::velox::functions { +namespace { + +class RowFunctionWithNull : public exec::VectorFunction { + public: + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + auto argsCopy = args; + + BufferPtr nulls = AlignedBuffer::allocate( + bits::nbytes(rows.size()), context.pool(), 1); + auto* nullsPtr = nulls->asMutable(); + auto cntNull = 0; + rows.applyToSelected([&](vector_size_t i) { + bits::clearNull(nullsPtr, i); + if (!bits::isBitNull(nullsPtr, i)) { + for (size_t c = 0; c < argsCopy.size(); c++) { + auto arg = argsCopy[c].get(); + if (arg->mayHaveNulls() && arg->isNullAt(i)) { + // If any argument of the struct is null, set the struct as null. + bits::setNull(nullsPtr, i, true); + cntNull++; + break; + } + } + } + }); + + RowVectorPtr row = std::make_shared( + context.pool(), + outputType, + nulls, + rows.size(), + std::move(argsCopy), + cntNull /*nullCount*/); + context.moveOrCopyResult(row, rows, result); + } + + bool isDefaultNullBehavior() const override { + return false; + } +}; +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_concat_row_with_null, + std::vector>{}, + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index 61df9efbd2bb..fc114b5ddeab 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -23,6 +23,8 @@ namespace facebook::velox::functions { void registerAllSpecialFormGeneralFunctions() { VELOX_REGISTER_VECTOR_FUNCTION(udf_in, "in"); VELOX_REGISTER_VECTOR_FUNCTION(udf_concat_row, "row_constructor"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_concat_row_with_null, "row_constructor_with_null"); registerIsNullFunction("is_null"); } diff --git a/velox/functions/prestosql/tests/ScalarFunctionRegTest.cpp b/velox/functions/prestosql/tests/ScalarFunctionRegTest.cpp index 507fde7e8f0a..1797d43e8691 100644 --- a/velox/functions/prestosql/tests/ScalarFunctionRegTest.cpp +++ b/velox/functions/prestosql/tests/ScalarFunctionRegTest.cpp @@ -56,6 +56,7 @@ TEST_F(ScalarFunctionRegTest, prefix) { scalarVectorFuncMap.erase("in"); scalarVectorFuncMap.erase("row_constructor"); scalarVectorFuncMap.erase("is_null"); + scalarVectorFuncMap.erase("row_constructor_with_null"); for (const auto& entry : scalarVectorFuncMap) { EXPECT_EQ(prefix, entry.first.substr(0, prefix.size()));