Skip to content

Commit

Permalink
Add decimal types support to VELOX_DYNAMIC_TYPE_DISPATCH_IMPL (#2307)
Browse files Browse the repository at this point in the history
Summary:
A lack of VELOX_DYNAMIC_TYPE_DISPATCH_IMPL support for decimal types
requires specialization for these types at various code blocks.
This is now supported.

Pull Request resolved: #2307

Reviewed By: Yuhta

Differential Revision: D39020495

Pulled By: kevinwilfong

fbshipit-source-id: 9ddc34056c01fda0fc0769bcf17e44b033f2504b
  • Loading branch information
majetideepak authored and facebook-github-bot committed Sep 7, 2022
1 parent 0ae3c3c commit a7093b5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 9 deletions.
20 changes: 20 additions & 0 deletions velox/functions/prestosql/aggregates/PrestoHasher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ FOLLY_ALWAYS_INLINE void PrestoHasher::hash<TypeKind::DATE>(
});
}

template <>
FOLLY_ALWAYS_INLINE void PrestoHasher::hash<TypeKind::SHORT_DECIMAL>(
const SelectivityVector& rows,
BufferPtr& hashes) {
applyHashFunction(rows, *vector_.get(), hashes, [&](auto row) {
return hashInteger(
vector_->valueAt<UnscaledShortDecimal>(row).unscaledValue());
});
}

template <>
FOLLY_ALWAYS_INLINE void PrestoHasher::hash<TypeKind::LONG_DECIMAL>(
const SelectivityVector& rows,
BufferPtr& hashes) {
applyHashFunction(rows, *vector_.get(), hashes, [&](auto row) {
return hashInteger(
vector_->valueAt<UnscaledLongDecimal>(row).unscaledValue());
});
}

template <>
FOLLY_ALWAYS_INLINE void PrestoHasher::hash<TypeKind::INTERVAL_DAY_TIME>(
const SelectivityVector& rows,
Expand Down
19 changes: 17 additions & 2 deletions velox/functions/prestosql/aggregates/tests/PrestoHasherTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class PrestoHasherTest : public testing::Test,
template <typename T>
void assertHash(
const std::vector<std::optional<T>>& data,
const std::vector<int64_t>& expected) {
auto vector = makeNullableFlatVector<T>(data);
const std::vector<int64_t>& expected,
const TypePtr& type = CppToType<T>::create()) {
auto vector = makeNullableFlatVector<T>(data, type);
assertHash(vector, expected);
}

Expand Down Expand Up @@ -174,6 +175,20 @@ TEST_F(PrestoHasherTest, date) {
{Date(0), Date(1000), std::nullopt}, {0, 2343331593029422743, 0});
}

TEST_F(PrestoHasherTest, unscaledShortDecimal) {
assertHash<UnscaledShortDecimal>(
{UnscaledShortDecimal(0), UnscaledShortDecimal(1000), std::nullopt},
{0, 2343331593029422743, 0},
DECIMAL(10, 5));
}

TEST_F(PrestoHasherTest, unscaledLongDecimal) {
assertHash<UnscaledLongDecimal>(
{UnscaledLongDecimal(0), UnscaledLongDecimal(1000), std::nullopt},
{0, 2343331593029422743, 0},
DECIMAL(20, 5));
}

TEST_F(PrestoHasherTest, doubles) {
assertHash<double>(
{1.0,
Expand Down
16 changes: 9 additions & 7 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,14 @@ std::shared_ptr<const OpaqueType> OPAQUE() {
case ::facebook::velox::TypeKind::ROW: { \
return PREFIX<::facebook::velox::TypeKind::ROW> SUFFIX(__VA_ARGS__); \
} \
case ::facebook::velox::TypeKind::SHORT_DECIMAL: { \
return PREFIX<::facebook::velox::TypeKind::SHORT_DECIMAL> SUFFIX( \
__VA_ARGS__); \
} \
case ::facebook::velox::TypeKind::LONG_DECIMAL: { \
return PREFIX<::facebook::velox::TypeKind::LONG_DECIMAL> SUFFIX( \
__VA_ARGS__); \
} \
default: \
VELOX_FAIL("not a known type kind: {}", mapTypeKindToName(typeKind)); \
} \
Expand All @@ -1313,12 +1321,6 @@ std::shared_ptr<const OpaqueType> OPAQUE() {
return TEMPLATE_FUNC<::facebook::velox::TypeKind::UNKNOWN>(__VA_ARGS__); \
} else if ((typeKind) == ::facebook::velox::TypeKind::OPAQUE) { \
return TEMPLATE_FUNC<::facebook::velox::TypeKind::OPAQUE>(__VA_ARGS__); \
} else if (((typeKind) == ::facebook::velox::TypeKind::SHORT_DECIMAL)) { \
return TEMPLATE_FUNC<::facebook::velox::TypeKind::SHORT_DECIMAL>( \
__VA_ARGS__); \
} else if (((typeKind) == ::facebook::velox::TypeKind::LONG_DECIMAL)) { \
return TEMPLATE_FUNC<::facebook::velox::TypeKind::LONG_DECIMAL>( \
__VA_ARGS__); \
} else { \
return VELOX_DYNAMIC_TYPE_DISPATCH_IMPL( \
TEMPLATE_FUNC, , typeKind, __VA_ARGS__); \
Expand Down Expand Up @@ -1455,7 +1457,7 @@ std::shared_ptr<const Type> createType(
if (children.size() != 0) {
throw std::invalid_argument{
std::string(TypeTraits<KIND>::name) +
" primitive type takes no childern"};
" primitive type takes no children"};
}
static_assert(TypeTraits<KIND>::isPrimitiveType);
return ScalarType<KIND>::create();
Expand Down
38 changes: 38 additions & 0 deletions velox/type/Variant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,44 @@ struct VariantEquality<TypeKind::DATE> {
}
};

template <>
struct VariantEquality<TypeKind::SHORT_DECIMAL> {
template <bool NullEqualsNull>
static bool equals(const variant& a, const variant& b) {
const auto lhs = a.value<TypeKind::SHORT_DECIMAL>();
const auto rhs = b.value<TypeKind::SHORT_DECIMAL>();
const auto lType = DECIMAL(lhs.precision, lhs.scale);
const auto rType = DECIMAL(rhs.precision, rhs.scale);
if (!lType->equivalent(*rType)) {
return false;
}
if (a.isNull() || b.isNull()) {
return evaluateNullEquality<NullEqualsNull>(a, b);
} else {
return lhs.value() == rhs.value();
}
}
};

template <>
struct VariantEquality<TypeKind::LONG_DECIMAL> {
template <bool NullEqualsNull>
static bool equals(const variant& a, const variant& b) {
const auto lhs = a.value<TypeKind::LONG_DECIMAL>();
const auto rhs = b.value<TypeKind::LONG_DECIMAL>();
const auto lType = DECIMAL(lhs.precision, lhs.scale);
const auto rType = DECIMAL(rhs.precision, rhs.scale);
if (!lType->equivalent(*rType)) {
return false;
}
if (a.isNull() || b.isNull()) {
return evaluateNullEquality<NullEqualsNull>(a, b);
} else {
return lhs.value() == rhs.value();
}
}
};

// interval day time
template <>
struct VariantEquality<TypeKind::INTERVAL_DAY_TIME> {
Expand Down
12 changes: 12 additions & 0 deletions velox/type/tests/VariantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ TEST(VariantTest, shortDecimal) {
// 0.1234 < 1.234
EXPECT_LT(
u2.value<TypeKind::SHORT_DECIMAL>(), v.value<TypeKind::SHORT_DECIMAL>());
EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, false));
EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, true));
EXPECT_FALSE(dispatchDynamicVariantEquality(v, u2, true));
}

TEST(VariantTest, shortDecimalHash) {
Expand Down Expand Up @@ -157,6 +160,9 @@ TEST(VariantTest, shortDecimalNull) {
// n and n3 have same precision and scale.
auto n3 = n;
EXPECT_EQ(nHash, n3.hash());

EXPECT_TRUE(dispatchDynamicVariantEquality(n, n, true));
EXPECT_FALSE(dispatchDynamicVariantEquality(n, n, false));
}

TEST(VariantTest, longDecimal) {
Expand All @@ -179,6 +185,9 @@ TEST(VariantTest, longDecimal) {
// 12.3456 > 12.345
EXPECT_LT(
v.value<TypeKind::LONG_DECIMAL>(), u2.value<TypeKind::LONG_DECIMAL>());
EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, false));
EXPECT_TRUE(dispatchDynamicVariantEquality(v, v, true));
EXPECT_FALSE(dispatchDynamicVariantEquality(v, u2, true));
}

TEST(VariantTest, longDecimalHash) {
Expand Down Expand Up @@ -225,6 +234,9 @@ TEST(VariantTest, longDecimalNull) {
// n and n3 have same precision and scale.
auto n3 = n;
EXPECT_EQ(nHash, n3.hash());

EXPECT_TRUE(dispatchDynamicVariantEquality(n, n, true));
EXPECT_FALSE(dispatchDynamicVariantEquality(n, n, false));
}

/// Test variant::equalsWithEpsilon by summing up large 64-bit integers (> 15
Expand Down

0 comments on commit a7093b5

Please sign in to comment.