From a7093b5cbffe673a3992842527e01e547c2e5feb Mon Sep 17 00:00:00 2001 From: Deepak Majeti Date: Wed, 7 Sep 2022 11:28:19 -0700 Subject: [PATCH] Add decimal types support to VELOX_DYNAMIC_TYPE_DISPATCH_IMPL (#2307) Summary: A lack of VELOX_DYNAMIC_TYPE_DISPATCH_IMPL support for decimal types requires specialization for these types at various code blocks. This is now supported. Pull Request resolved: https://github.com/facebookincubator/velox/pull/2307 Reviewed By: Yuhta Differential Revision: D39020495 Pulled By: kevinwilfong fbshipit-source-id: 9ddc34056c01fda0fc0769bcf17e44b033f2504b --- .../prestosql/aggregates/PrestoHasher.cpp | 20 ++++++++++ .../aggregates/tests/PrestoHasherTest.cpp | 19 +++++++++- velox/type/Type.h | 16 ++++---- velox/type/Variant.cpp | 38 +++++++++++++++++++ velox/type/tests/VariantTest.cpp | 12 ++++++ 5 files changed, 96 insertions(+), 9 deletions(-) diff --git a/velox/functions/prestosql/aggregates/PrestoHasher.cpp b/velox/functions/prestosql/aggregates/PrestoHasher.cpp index b7a88f76e6b5..b89808bf3e1a 100644 --- a/velox/functions/prestosql/aggregates/PrestoHasher.cpp +++ b/velox/functions/prestosql/aggregates/PrestoHasher.cpp @@ -108,6 +108,26 @@ FOLLY_ALWAYS_INLINE void PrestoHasher::hash( }); } +template <> +FOLLY_ALWAYS_INLINE void PrestoHasher::hash( + const SelectivityVector& rows, + BufferPtr& hashes) { + applyHashFunction(rows, *vector_.get(), hashes, [&](auto row) { + return hashInteger( + vector_->valueAt(row).unscaledValue()); + }); +} + +template <> +FOLLY_ALWAYS_INLINE void PrestoHasher::hash( + const SelectivityVector& rows, + BufferPtr& hashes) { + applyHashFunction(rows, *vector_.get(), hashes, [&](auto row) { + return hashInteger( + vector_->valueAt(row).unscaledValue()); + }); +} + template <> FOLLY_ALWAYS_INLINE void PrestoHasher::hash( const SelectivityVector& rows, diff --git a/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp b/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp index f0a4b2061851..10188b07bb46 100644 --- a/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp @@ -31,8 +31,9 @@ class PrestoHasherTest : public testing::Test, template void assertHash( const std::vector>& data, - const std::vector& expected) { - auto vector = makeNullableFlatVector(data); + const std::vector& expected, + const TypePtr& type = CppToType::create()) { + auto vector = makeNullableFlatVector(data, type); assertHash(vector, expected); } @@ -174,6 +175,20 @@ TEST_F(PrestoHasherTest, date) { {Date(0), Date(1000), std::nullopt}, {0, 2343331593029422743, 0}); } +TEST_F(PrestoHasherTest, unscaledShortDecimal) { + assertHash( + {UnscaledShortDecimal(0), UnscaledShortDecimal(1000), std::nullopt}, + {0, 2343331593029422743, 0}, + DECIMAL(10, 5)); +} + +TEST_F(PrestoHasherTest, unscaledLongDecimal) { + assertHash( + {UnscaledLongDecimal(0), UnscaledLongDecimal(1000), std::nullopt}, + {0, 2343331593029422743, 0}, + DECIMAL(20, 5)); +} + TEST_F(PrestoHasherTest, doubles) { assertHash( {1.0, diff --git a/velox/type/Type.h b/velox/type/Type.h index ecb709584c0a..b1cc034b1ba6 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -1299,6 +1299,14 @@ std::shared_ptr OPAQUE() { case ::facebook::velox::TypeKind::ROW: { \ return PREFIX<::facebook::velox::TypeKind::ROW> SUFFIX(__VA_ARGS__); \ } \ + case ::facebook::velox::TypeKind::SHORT_DECIMAL: { \ + return PREFIX<::facebook::velox::TypeKind::SHORT_DECIMAL> SUFFIX( \ + __VA_ARGS__); \ + } \ + case ::facebook::velox::TypeKind::LONG_DECIMAL: { \ + return PREFIX<::facebook::velox::TypeKind::LONG_DECIMAL> SUFFIX( \ + __VA_ARGS__); \ + } \ default: \ VELOX_FAIL("not a known type kind: {}", mapTypeKindToName(typeKind)); \ } \ @@ -1313,12 +1321,6 @@ std::shared_ptr OPAQUE() { return TEMPLATE_FUNC<::facebook::velox::TypeKind::UNKNOWN>(__VA_ARGS__); \ } else if ((typeKind) == ::facebook::velox::TypeKind::OPAQUE) { \ return TEMPLATE_FUNC<::facebook::velox::TypeKind::OPAQUE>(__VA_ARGS__); \ - } else if (((typeKind) == ::facebook::velox::TypeKind::SHORT_DECIMAL)) { \ - return TEMPLATE_FUNC<::facebook::velox::TypeKind::SHORT_DECIMAL>( \ - __VA_ARGS__); \ - } else if (((typeKind) == ::facebook::velox::TypeKind::LONG_DECIMAL)) { \ - return TEMPLATE_FUNC<::facebook::velox::TypeKind::LONG_DECIMAL>( \ - __VA_ARGS__); \ } else { \ return VELOX_DYNAMIC_TYPE_DISPATCH_IMPL( \ TEMPLATE_FUNC, , typeKind, __VA_ARGS__); \ @@ -1455,7 +1457,7 @@ std::shared_ptr createType( if (children.size() != 0) { throw std::invalid_argument{ std::string(TypeTraits::name) + - " primitive type takes no childern"}; + " primitive type takes no children"}; } static_assert(TypeTraits::isPrimitiveType); return ScalarType::create(); diff --git a/velox/type/Variant.cpp b/velox/type/Variant.cpp index 88014a19b972..924affc7bc9e 100644 --- a/velox/type/Variant.cpp +++ b/velox/type/Variant.cpp @@ -77,6 +77,44 @@ struct VariantEquality { } }; +template <> +struct VariantEquality { + template + static bool equals(const variant& a, const variant& b) { + const auto lhs = a.value(); + const auto rhs = b.value(); + const auto lType = DECIMAL(lhs.precision, lhs.scale); + const auto rType = DECIMAL(rhs.precision, rhs.scale); + if (!lType->equivalent(*rType)) { + return false; + } + if (a.isNull() || b.isNull()) { + return evaluateNullEquality(a, b); + } else { + return lhs.value() == rhs.value(); + } + } +}; + +template <> +struct VariantEquality { + template + static bool equals(const variant& a, const variant& b) { + const auto lhs = a.value(); + const auto rhs = b.value(); + const auto lType = DECIMAL(lhs.precision, lhs.scale); + const auto rType = DECIMAL(rhs.precision, rhs.scale); + if (!lType->equivalent(*rType)) { + return false; + } + if (a.isNull() || b.isNull()) { + return evaluateNullEquality(a, b); + } else { + return lhs.value() == rhs.value(); + } + } +}; + // interval day time template <> struct VariantEquality { diff --git a/velox/type/tests/VariantTest.cpp b/velox/type/tests/VariantTest.cpp index 3e9e4f21af44..d16e64b0b3b7 100644 --- a/velox/type/tests/VariantTest.cpp +++ b/velox/type/tests/VariantTest.cpp @@ -111,6 +111,9 @@ TEST(VariantTest, shortDecimal) { // 0.1234 < 1.234 EXPECT_LT( u2.value(), v.value()); + EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, false)); + EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, true)); + EXPECT_FALSE(dispatchDynamicVariantEquality(v, u2, true)); } TEST(VariantTest, shortDecimalHash) { @@ -157,6 +160,9 @@ TEST(VariantTest, shortDecimalNull) { // n and n3 have same precision and scale. auto n3 = n; EXPECT_EQ(nHash, n3.hash()); + + EXPECT_TRUE(dispatchDynamicVariantEquality(n, n, true)); + EXPECT_FALSE(dispatchDynamicVariantEquality(n, n, false)); } TEST(VariantTest, longDecimal) { @@ -179,6 +185,9 @@ TEST(VariantTest, longDecimal) { // 12.3456 > 12.345 EXPECT_LT( v.value(), u2.value()); + EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, false)); + EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, true)); + EXPECT_FALSE(dispatchDynamicVariantEquality(v, u2, true)); } TEST(VariantTest, longDecimalHash) { @@ -225,6 +234,9 @@ TEST(VariantTest, longDecimalNull) { // n and n3 have same precision and scale. auto n3 = n; EXPECT_EQ(nHash, n3.hash()); + + EXPECT_TRUE(dispatchDynamicVariantEquality(n, n, true)); + EXPECT_FALSE(dispatchDynamicVariantEquality(n, n, false)); } /// Test variant::equalsWithEpsilon by summing up large 64-bit integers (> 15