Skip to content

Commit

Permalink
Fix nulls ordering for Range frames (facebookincubator#9271)
Browse files Browse the repository at this point in the history
Summary:
The nullsFirst flag wasn't correctly passed in column comparisons for range frames. This caused range values for nulls to match those of adjacent rows for desc nulls last and asc nulls first order clauses.

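Setting the CompareFlags also revealed a simplification: since the flags already carry the sort direction, the isAscending template parameter on the frame search functions is no longer needed.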

Fixes prestodb/presto#21889

Pull Request resolved: facebookincubator#9271

Reviewed By: Yuhta

Differential Revision: D55889369

Pulled By: kagamiori

fbshipit-source-id: 9ce7e97f49aa77f89b2e4cb66f2b4d11a2a6b59f
  • Loading branch information
aditi-pandit authored and facebook-github-bot committed Apr 16, 2024
1 parent beacf30 commit da6a3d3
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 53 deletions.
77 changes: 30 additions & 47 deletions velox/exec/WindowPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,53 +184,47 @@ std::pair<vector_size_t, vector_size_t> WindowPartition::computePeerBuffers(
// largest value < frame bound or (smallest value > frame bound for
// descending). After finding that point it does a sequential search for the
// required value.
template <bool isAscending>
vector_size_t WindowPartition::searchFrameValue(
bool firstMatch,
vector_size_t start,
vector_size_t end,
vector_size_t currentRow,
column_index_t orderByColumn,
column_index_t frameColumn) const {
column_index_t frameColumn,
const CompareFlags& flags) const {
auto current = partition_[currentRow];
bool crossedBound = false;
vector_size_t begin = start;
vector_size_t finish = end;
int compareResult;
while (finish - begin >= 2) {
auto mid = (begin + finish) / 2;
compareResult =
data_->compare(partition_[mid], current, orderByColumn, frameColumn);
if constexpr (isAscending) {
crossedBound = compareResult >= 0;
} else {
crossedBound = compareResult <= 0;
}
auto compareResult = data_->compare(
partition_[mid], current, orderByColumn, frameColumn, flags);

if (crossedBound) {
if (compareResult >= 0) {
// Search in the first half of the column.
finish = mid;
} else {
// Search in the second half of the column.
begin = mid;
}
}

return linearSearchFrameValue<isAscending>(
firstMatch, begin, end, currentRow, orderByColumn, frameColumn);
return linearSearchFrameValue(
firstMatch, begin, end, currentRow, orderByColumn, frameColumn, flags);
}

template <bool isAscending>
vector_size_t WindowPartition::linearSearchFrameValue(
bool firstMatch,
vector_size_t start,
vector_size_t end,
vector_size_t currentRow,
column_index_t orderByColumn,
column_index_t frameColumn) const {
column_index_t frameColumn,
const CompareFlags& flags) const {
auto current = partition_[currentRow];
bool crossedBound = false;
for (vector_size_t i = start; i < end; ++i) {
auto compareResult =
data_->compare(partition_[i], current, orderByColumn, frameColumn);
auto compareResult = data_->compare(
partition_[i], current, orderByColumn, frameColumn, flags);

// The bound value was found. Return if firstMatch required.
// If the last match is required, then we need to find the first row that
Expand All @@ -241,16 +235,11 @@ vector_size_t WindowPartition::linearSearchFrameValue(
}
}

if constexpr (isAscending) {
crossedBound = compareResult > 0;
} else {
crossedBound = compareResult < 0;
}
// Bound is crossed. Last match needs the previous row.
// But for first row matches, this is the first
// row that has crossed, but not equals boundary (The equal boundary case
// is covered by the condition above). So the bound matches this row itself.
if (crossedBound) {
if (compareResult > 0) {
if (firstMatch) {
return i;
} else {
Expand All @@ -264,10 +253,10 @@ vector_size_t WindowPartition::linearSearchFrameValue(
return end == numRows() ? numRows() + 1 : -1;
}

template <bool isAscending>
void WindowPartition::updateKRangeFrameBounds(
bool firstMatch,
bool isPreceding,
const CompareFlags& flags,
vector_size_t startRow,
vector_size_t numRows,
column_index_t frameColumn,
Expand Down Expand Up @@ -300,13 +289,14 @@ void WindowPartition::updateKRangeFrameBounds(
start = currentRow;
end = partition_.size();
}
rawFrameBounds[i] = searchFrameValue<isAscending>(
rawFrameBounds[i] = searchFrameValue(
firstMatch,
start,
end,
currentRow,
orderByColumn,
inputMapping_[frameColumn]);
inputMapping_[frameColumn],
flags);
}
}
}
Expand All @@ -319,26 +309,19 @@ void WindowPartition::computeKRangeFrameBounds(
vector_size_t numRows,
const vector_size_t* rawPeerBuffer,
vector_size_t* rawFrameBounds) const {
CompareFlags flags;
flags.ascending = sortKeyInfo_[0].second.isAscending();
flags.nullsFirst = sortKeyInfo_[0].second.isNullsFirst();
// Start bounds require first match. End bounds require last match.
if (sortKeyInfo_[0].second.isAscending()) {
updateKRangeFrameBounds<true>(
isStartBound,
isPreceding,
startRow,
numRows,
frameColumn,
rawPeerBuffer,
rawFrameBounds);
} else {
updateKRangeFrameBounds<false>(
isStartBound,
isPreceding,
startRow,
numRows,
frameColumn,
rawPeerBuffer,
rawFrameBounds);
}
updateKRangeFrameBounds(
isStartBound,
isPreceding,
flags,
startRow,
numRows,
frameColumn,
rawPeerBuffer,
rawFrameBounds);
}

} // namespace facebook::velox::exec
10 changes: 5 additions & 5 deletions velox/exec/WindowPartition.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,29 @@ class WindowPartition {
// Searches for 'currentRow[frameColumn]' in 'orderByColumn' of rows between
// 'start' and 'end' in the partition. 'firstMatch' specifies if first or last
// row is matched.
template <bool isAscending>
vector_size_t searchFrameValue(
bool firstMatch,
vector_size_t start,
vector_size_t end,
vector_size_t currentRow,
column_index_t orderByColumn,
column_index_t frameColumn) const;
column_index_t frameColumn,
const CompareFlags& flags) const;

template <bool isAscending>
vector_size_t linearSearchFrameValue(
bool firstMatch,
vector_size_t start,
vector_size_t end,
vector_size_t currentRow,
column_index_t orderByColumn,
column_index_t frameColumn) const;
column_index_t frameColumn,
const CompareFlags& flags) const;

// Iterates over 'numRows' and searches frame value for each row.
template <bool isAscending>
void updateKRangeFrameBounds(
bool firstMatch,
bool isPreceding,
const CompareFlags& flags,
vector_size_t startRow,
vector_size_t numRows,
column_index_t frameColumn,
Expand Down
36 changes: 35 additions & 1 deletion velox/functions/prestosql/window/tests/AggregateWindowTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ TEST_F(AggregateWindowTest, variableWidthAggregate) {

// Tests function with k RANGE PRECEDING (FOLLOWING) frames.
TEST_F(AggregateWindowTest, rangeFrames) {
for (const auto& function : kAggregateFunctions) {
auto aggregateFunctions = kAggregateFunctions;
aggregateFunctions.push_back("array_agg(c2)");
for (const auto& function : aggregateFunctions) {
// count function is skipped as DuckDB returns inconsistent results
// with Velox for rows with empty frames. Velox expects empty frames to
// return 0, but DuckDB returns null.
Expand All @@ -141,6 +143,38 @@ TEST_F(AggregateWindowTest, rangeFrames) {
}
}

// Checks k RANGE frames against every combination of sort direction and null
// placement: a null ORDER BY value must get a frame of its own and must not
// be merged into the frame of a neighbouring non-null row.
TEST_F(AggregateWindowTest, rangeNullsOrder) {
  auto keys = makeNullableFlatVector<int64_t>({1, 2, 1, std::nullopt});
  auto input = makeRowVector({keys});

  // c0 is both the ORDER BY key and the frame bound column, so the frame is
  // effectively "range between 0 preceding and 0 following": each row's frame
  // holds exactly the rows that share its value.
  const std::string frameClause = "range between c0 preceding and c0 following";

  // Both rows with value 1 see {1, 1}, the row with 2 sees {2}, and the null
  // row sees only itself.
  auto frames =
      makeNullableArrayVector<int64_t>({{1, 1}, {2}, {1, 1}, {std::nullopt}});
  auto expected = makeRowVector({keys, frames});

  // The expected output is the same for every ordering because rows are
  // matched purely by value.
  const std::vector<std::string> overClauses = {
      "order by c0 asc nulls last",
      "order by c0 asc nulls first",
      "order by c0 desc nulls last",
      "order by c0 desc nulls first",
  };
  for (const auto& overClause : overClauses) {
    WindowTestBase::testWindowFunction(
        {input}, "array_agg(c0)", overClause, frameClause, expected);
  }
}

// Test for aggregates that return NULL as the default value for empty frames
// against DuckDb.
TEST_F(AggregateWindowTest, nullEmptyResult) {
Expand Down

0 comments on commit da6a3d3

Please sign in to comment.