Skip to content

Commit

Permalink
Add support for custom comparison in Presto's array_position UDF
Browse files Browse the repository at this point in the history
Summary:
Update Presto's array_position UDF to work with types that provide custom
comparison (both the 2 and 3 argument flavors).  We can reuse the implementations
for complex types, since that just uses the compare function provided by the Vector. 
With facebookincubator#11022 this just
invokes the Type's custom implementation.

Differential Revision: D64209619
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 10, 2024
1 parent e1b8a41 commit 30fb119
Show file tree
Hide file tree
Showing 2 changed files with 325 additions and 55 deletions.
99 changes: 63 additions & 36 deletions velox/functions/prestosql/ArrayPosition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ inline bool isPrimitiveEqual(const T& lhs, const T& rhs) {

// Find the index of the first match for primitive types.
template <
bool useCustomComparison,
TypeKind kind,
typename std::enable_if_t<TypeTraits<kind>::isPrimitiveType, int> = 0>
typename std::enable_if_t<
!useCustomComparison && TypeTraits<kind>::isPrimitiveType,
int> = 0>
void applyTypedFirstMatch(
const SelectivityVector& rows,
DecodedVector& arrayDecoded,
Expand Down Expand Up @@ -122,8 +125,11 @@ void applyTypedFirstMatch(

// Find the index of the first match for complex types.
template <
bool useCustomComparison,
TypeKind kind,
typename std::enable_if_t<!TypeTraits<kind>::isPrimitiveType, int> = 0>
typename std::enable_if_t<
useCustomComparison || !TypeTraits<kind>::isPrimitiveType,
int> = 0>
void applyTypedFirstMatch(
const SelectivityVector& rows,
DecodedVector& arrayDecoded,
Expand Down Expand Up @@ -179,8 +185,11 @@ FOLLY_ALWAYS_INLINE void getLoopBoundary(

// Find the index of the instance-th match for primitive types.
template <
bool useCustomComparison,
TypeKind kind,
typename std::enable_if_t<TypeTraits<kind>::isPrimitiveType, int> = 0>
typename std::enable_if_t<
!useCustomComparison && TypeTraits<kind>::isPrimitiveType,
int> = 0>
void applyTypedWithInstance(
const SelectivityVector& rows,
exec::EvalCtx& context,
Expand Down Expand Up @@ -304,8 +313,11 @@ void applyTypedWithInstance(

// Find the index of the instance-th match for complex types.
template <
bool useCustomComparison,
TypeKind kind,
typename std::enable_if_t<!TypeTraits<kind>::isPrimitiveType, int> = 0>
typename std::enable_if_t<
useCustomComparison || !TypeTraits<kind>::isPrimitiveType,
int> = 0>
void applyTypedWithInstance(
const SelectivityVector& rows,
exec::EvalCtx& context,
Expand Down Expand Up @@ -359,6 +371,50 @@ void applyTypedWithInstance(
});
}

template <bool useCustomComparison>
void applyInternal(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
VectorPtr& result) {
context.ensureWritable(rows, BIGINT(), result);
auto flatResult = result->asFlatVector<int64_t>();

exec::DecodedArgs decodedArgs(rows, args, context);
auto elements = decodedArgs.at(0)->base()->as<ArrayVector>()->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());

if (args.size() == 2) {
VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
applyTypedFirstMatch,
useCustomComparison,
args[1]->typeKind(),
rows,
*decodedArgs.at(0),
*elementsHolder.get(),
*decodedArgs.at(1),
*flatResult);
} else {
const auto& instanceVector = args[2];
VELOX_CHECK(instanceVector->type()->isBigint());

VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
applyTypedWithInstance,
useCustomComparison,
args[1]->typeKind(),
rows,
context,
*decodedArgs.at(0),
*elementsHolder.get(),
*decodedArgs.at(1),
*decodedArgs.at(2),
*flatResult);
}
}

class ArrayPositionFunction : public exec::VectorFunction {
public:
void apply(
Expand All @@ -373,39 +429,10 @@ class ArrayPositionFunction : public exec::VectorFunction {
VELOX_CHECK(arrayVector->type()->asArray().elementType()->kindEquals(
searchVector->type()));

context.ensureWritable(rows, BIGINT(), result);
auto flatResult = result->asFlatVector<int64_t>();

exec::DecodedArgs decodedArgs(rows, args, context);
auto elements = decodedArgs.at(0)->base()->as<ArrayVector>()->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());

if (args.size() == 2) {
VELOX_DYNAMIC_TYPE_DISPATCH(
applyTypedFirstMatch,
searchVector->typeKind(),
rows,
*decodedArgs.at(0),
*elementsHolder.get(),
*decodedArgs.at(1),
*flatResult);
if (searchVector->type()->providesCustomComparison()) {
applyInternal<true>(rows, args, context, result);
} else {
const auto& instanceVector = args[2];
VELOX_CHECK(instanceVector->type()->isBigint());

VELOX_DYNAMIC_TYPE_DISPATCH(
applyTypedWithInstance,
searchVector->typeKind(),
rows,
context,
*decodedArgs.at(0),
*elementsHolder.get(),
*decodedArgs.at(1),
*decodedArgs.at(2),
*flatResult);
applyInternal<false>(rows, args, context, result);
}
}

Expand Down
Loading

0 comments on commit 30fb119

Please sign in to comment.