Skip to content

Commit

Permalink
Added support for complex type (Array, Map) in NeqFunction (#5186)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5186

Similar to eqFunction,registered function of type <NeqFunction, bool, Generic<T1>, Generic<T1>>

Reviewed By: laithsakka

Differential Revision: D46540217

fbshipit-source-id: 94f05e42153a7312c83fecd483e2ebf1756107b6
  • Loading branch information
Arpit Porwal authored and facebook-github-bot committed Jun 13, 2023
1 parent 754c732 commit 0782dd9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
25 changes: 24 additions & 1 deletion velox/functions/prestosql/Comparisons.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace facebook::velox::functions {
} \
};

VELOX_GEN_BINARY_EXPR(NeqFunction, lhs != rhs, bool);
VELOX_GEN_BINARY_EXPR(LtFunction, lhs < rhs, bool);
VELOX_GEN_BINARY_EXPR(GtFunction, lhs > rhs, bool);
VELOX_GEN_BINARY_EXPR(LteFunction, lhs <= rhs, bool);
Expand Down Expand Up @@ -86,6 +85,30 @@ struct EqFunction {
}
};

template <typename T>
struct NeqFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Used for primitive inputs.
template <typename TInput>
void call(bool& out, const TInput& lhs, const TInput& rhs) {
out = (lhs != rhs);
}

// For arbitrary nested complex types. Can return null.
bool call(
bool& out,
const arg_type<Generic<T1>>& lhs,
const arg_type<Generic<T1>>& rhs) {
if (EqFunction<T>().call(out, lhs, rhs)) {
out = !out;
return true;
} else {
return false;
}
}
};

template <typename T>
struct BetweenFunction {
template <typename TInput>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void registerComparisonFunctions(const std::string& prefix) {

registerNonSimdizableScalar<NeqFunction, bool>({prefix + "neq"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_neq, prefix + "neq");
registerFunction<NeqFunction, bool, Generic<T1>, Generic<T1>>(
{prefix + "neq"});

registerNonSimdizableScalar<LtFunction, bool>({prefix + "lt"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lt, prefix + "lt");
Expand Down
30 changes: 22 additions & 8 deletions velox/functions/prestosql/tests/ComparisonsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/Udf.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
Expand Down Expand Up @@ -369,22 +370,30 @@ TEST_F(ComparisonsTest, gtLtDecimal) {
runAndCompare("c1 <= c0", longDecimalsInputs, expectedGteLte);
};

TEST_F(ComparisonsTest, eqArray) {
TEST_F(ComparisonsTest, eqNeqArray) {
auto test =
[&](const std::optional<std::vector<std::optional<int64_t>>>& array1,
const std::optional<std::vector<std::optional<int64_t>>>& array2,
std::optional<bool> expected) {
auto vector1 = vectorMaker_.arrayVectorNullable<int64_t>({array1});
auto vector2 = vectorMaker_.arrayVectorNullable<int64_t>({array2});
auto result = evaluate<SimpleVector<bool>>(
auto eqResult = evaluate<SimpleVector<bool>>(
"c0 == c1", makeRowVector({vector1, vector2}));

ASSERT_EQ(expected.has_value(), !result->isNullAt(0));
auto neqResult = evaluate<SimpleVector<bool>>(
"c0 != c1", makeRowVector({vector1, vector2}));

ASSERT_EQ(expected.has_value(), !eqResult->isNullAt(0));
ASSERT_EQ(expected.has_value(), !neqResult->isNullAt(0));
if (expected.has_value()) {
ASSERT_EQ(expected.value(), result->valueAt(0));
// equals check
ASSERT_EQ(expected.value(), eqResult->valueAt(0));
// not equal check
ASSERT_EQ(!expected.value(), neqResult->valueAt(0));
}
};

// eq and neq function test
test(std::nullopt, std::nullopt, std::nullopt);
test(std::nullopt, {{1}}, std::nullopt);
test({{1}}, std::nullopt, std::nullopt);
Expand Down Expand Up @@ -414,24 +423,29 @@ TEST_F(ComparisonsTest, eqArray) {
std::nullopt);
}

TEST_F(ComparisonsTest, eqMap) {
TEST_F(ComparisonsTest, eqNeqMap) {
using map_t =
std::optional<std::vector<std::pair<int64_t, std::optional<int64_t>>>>;
auto test =
[&](const map_t& map1, const map_t& map2, std::optional<bool> expected) {
auto vector1 = makeNullableMapVector<int64_t, int64_t>({map1});
auto vector2 = makeNullableMapVector<int64_t, int64_t>({map2});

auto result = evaluate<SimpleVector<bool>>(
auto eqResult = evaluate<SimpleVector<bool>>(
"c0 == c1", makeRowVector({vector1, vector2}));

ASSERT_EQ(expected.has_value(), !result->isNullAt(0));
auto neqResult = evaluate<SimpleVector<bool>>(
"c0 != c1", makeRowVector({vector1, vector2}));

ASSERT_EQ(expected.has_value(), !eqResult->isNullAt(0));
ASSERT_EQ(expected.has_value(), !neqResult->isNullAt(0));
if (expected.has_value()) {
ASSERT_EQ(expected.value(), result->valueAt(0));
ASSERT_EQ(expected.value(), eqResult->valueAt(0));
ASSERT_EQ(!expected.value(), neqResult->valueAt(0));
}
};

// eq and neq function test
test({{{1, 2}, {3, 4}}}, {{{1, 2}, {3, 4}}}, true);

// Elements checked in sorted order.
Expand Down

0 comments on commit 0782dd9

Please sign in to comment.