Skip to content

Commit

Permalink
Support semi projection join type in smj
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Jul 12, 2024
1 parent 69538a6 commit 245606e
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 16 deletions.
109 changes: 94 additions & 15 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "velox/exec/MergeJoin.h"
#include <iostream>
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/FieldReference.h"
Expand Down Expand Up @@ -47,7 +48,9 @@ bool MergeJoin::isSupported(core::JoinType joinType) {
case core::JoinType::kLeft:
case core::JoinType::kRight:
case core::JoinType::kLeftSemiFilter:
case core::JoinType::kLeftSemiProject:
case core::JoinType::kRightSemiFilter:
case core::JoinType::kRightSemiProject:
case core::JoinType::kAnti:
return true;

Expand Down Expand Up @@ -81,7 +84,8 @@ void MergeJoin::initialize() {
}
}

if (joinNode_->isRightSemiFilterJoin()) {
if (joinNode_->isRightSemiFilterJoin() ||
joinNode_->isRightSemiProjectJoin()) {
VELOX_USER_CHECK(
leftProjections_.empty(),
"The left side projections should be empty for right semi join");
Expand All @@ -95,7 +99,7 @@ void MergeJoin::initialize() {
}
}

if (joinNode_->isLeftSemiFilterJoin()) {
if (joinNode_->isLeftSemiFilterJoin() || joinNode_->isLeftSemiProjectJoin()) {
VELOX_USER_CHECK(
rightProjections_.empty(),
"The right side projections should be empty for left semi join");
Expand All @@ -105,10 +109,13 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin()) {
joinNode_->isRightJoin() || joinNode_->isLeftSemiProjectJoin() ||
joinNode_->isRightSemiProjectJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
} else if (
joinNode_->isAntiJoin() || joinNode_->isLeftSemiProjectJoin() ||
joinNode_->isRightSemiProjectJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right.
joinTracker_ = JoinTracker(outputBatchSize_, pool());
Expand Down Expand Up @@ -288,7 +295,9 @@ void copyRow(
void MergeJoin::addOutputRowForLeftJoin(
const RowVectorPtr& left,
vector_size_t leftIndex) {
VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_));
VELOX_USER_CHECK(
isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_));
rawLeftIndices_[outputSize_] = leftIndex;

for (const auto& projection : rightProjections_) {
Expand All @@ -307,7 +316,7 @@ void MergeJoin::addOutputRowForLeftJoin(
void MergeJoin::addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex) {
VELOX_USER_CHECK(isRightJoin(joinType_));
VELOX_USER_CHECK(isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_));
rawRightIndices_[outputSize_] = rightIndex;

for (const auto& projection : leftProjections_) {
Expand Down Expand Up @@ -360,7 +369,7 @@ void MergeJoin::addOutputRow(
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

if (joinTracker_) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
} else {
Expand All @@ -372,12 +381,18 @@ void MergeJoin::addOutputRow(

// Anti join needs to track the left side rows that have no match on the
// right.
if (isAntiJoin(joinType_)) {
if (isAntiJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
}

if (isRightSemiProjectJoin(joinType_)) {
VELOX_CHECK(joinTracker_);
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
}

++outputSize_;
}

Expand Down Expand Up @@ -451,13 +466,20 @@ bool MergeJoin::prepareOutput(
isRightFlattened_ = false;
}
currentRight_ = right;
if (isRightSemiProjectJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
localColumns[outputType_->size() - 1] = BaseVector::create(
outputType_->childAt(outputType_->size() - 1),
outputBatchSize_,
operatorCtx_->pool());
}

output_ = std::make_shared<RowVector>(
operatorCtx_->pool(),
outputType_,
nullptr,
outputBatchSize_,
std::move(localColumns));

outputSize_ = 0;

if (filterInput_ != nullptr) {
Expand Down Expand Up @@ -555,7 +577,9 @@ bool MergeJoin::addToOutput() {
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_) ||
isRightSemiFilterJoin(joinType_)) {
isRightSemiFilterJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
// LeftSemiFilter produce each row from the left at most once.
// RightSemiFilter produce each row from the right at most once.
rightEnd = rightStart + 1;
Expand Down Expand Up @@ -638,6 +662,26 @@ RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) {
return wrap(numPassed, indices, output);
}

RowVectorPtr MergeJoin::filterOutputForSemiProject(const RowVectorPtr& output) {
auto numRows = output->size();
const auto& filterRows = joinTracker_->matchingRows(numRows);

auto lastChildren = output->children().back();
auto flatMatch = lastChildren->as<FlatVector<bool>>();
flatMatch->resize(numRows);
auto rawValues = flatMatch->mutableRawValues<uint64_t>();

for (auto i = 0; i < numRows; i++) {
if (filterRows.isValid(i)) {
bits::setBit(rawValues, i, true);
} else {
bits::setBit(rawValues, i, false);
}
}

return output;
}

RowVectorPtr MergeJoin::getOutput() {
// Make sure to have is-blocked or needs-input as true if returning null
// output. Otherwise, Driver assumes the operator is finished.
Expand All @@ -649,6 +693,7 @@ RowVectorPtr MergeJoin::getOutput() {

for (;;) {
auto output = doGetOutput();

if (output != nullptr && output->size() > 0) {
if (filter_) {
output = applyFilter(output);
Expand All @@ -663,6 +708,10 @@ RowVectorPtr MergeJoin::getOutput() {
continue;
} else if (isAntiJoin(joinType_)) {
return filterOutputForAntiJoin(output);
} else if (
isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
return filterOutputForSemiProject(output);
} else {
return output;
}
Expand Down Expand Up @@ -768,7 +817,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

if (!input_ || !rightInput_) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_)) {
if (input_ && noMoreRightInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
Expand All @@ -795,7 +845,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
output_->resize(outputSize_);
return std::move(output_);
}
} else if (isRightJoin(joinType_)) {
} else if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
if (rightInput_ && noMoreInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
Expand Down Expand Up @@ -844,7 +894,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
for (;;) {
// Catch up input_ with rightInput_.
while (compareResult < 0) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
Expand All @@ -871,7 +922,7 @@ RowVectorPtr MergeJoin::doGetOutput() {

// Catch up rightInput_ with input_.
while (compareResult > 0) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
Expand Down Expand Up @@ -972,18 +1023,37 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (!filterRows.hasSelections()) {
// No matches in the output, no need to evaluate the filter.
return output;
if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
return filterOutputForSemiProject(output);
} else {
return output;
}
}

evaluateFilter(filterRows);

FlatVector<bool>* flatMatch{nullptr};
uint64_t* rawValues;

if (isLeftSemiProjectJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
flatMatch = output->children().back()->as<FlatVector<bool>>();
flatMatch->resize(numRows);
rawValues = flatMatch->mutableRawValues<uint64_t>();
}

// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (!isRightJoin(joinType_)) {
if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, row, false);
}

if (!isRightJoin(joinType_) && !isRightSemiProjectJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
Expand All @@ -1010,13 +1080,22 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
} else {
if (passed) {
if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, i, true);
}
rawIndices[numPassed++] = i;
}
}
} else {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;

if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, i, false);
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ class MergeJoin : public Operator {
/// rows from the left side that have a match on the right.
RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output);

RowVectorPtr filterOutputForSemiProject(const RowVectorPtr& output);

/// As we populate the results of the join, we track whether a given
/// output row is a result of a match between left and right sides or a miss.
/// We use JoinTracker::addMatch and addMiss methods for that.
Expand Down
54 changes: 54 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,60 @@ TEST_F(MergeJoinTest, semiJoin) {
core::JoinType::kRightSemiFilter);
}

TEST_F(MergeJoinTest, semiJoinProjection) {
auto left = makeRowVector(
{"t0"}, {makeNullableFlatVector<int64_t>({1, 2, 2, 6, std::nullopt})});

auto right = makeRowVector(
{"u0"},
{makeNullableFlatVector<int64_t>(
{1, 2, 2, 7, std::nullopt, std::nullopt})});

createDuckDbTable("t", {left});
createDuckDbTable("u", {right});

auto testSemiJoin = [&](const std::string& filter,
const std::string& sql,
const std::vector<std::string>& outputLayout,
core::JoinType joinType) {
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.values({left})
.mergeJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator).values({right}).planNode(),
filter,
outputLayout,
joinType)
.planNode();
AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql);
};

testSemiJoin(
"",
"SELECT t.t0, EXISTS (SELECT * FROM u WHERE t.t0 = u.u0) FROM t",
{"t0", "match"},
core::JoinType::kLeftSemiProject);
testSemiJoin(
"",
"SELECT u0, u0 IN (SELECT * FROM t where t.t0 = u.u0) FROM u",
{"u0", "match"},
core::JoinType::kRightSemiProject);

testSemiJoin(
"t0 > 1",
"SELECT t.t0, EXISTS (SELECT * FROM u WHERE t0 = u0 and t.t0 > 1) FROM t",
{"t0", "match"},
core::JoinType::kLeftSemiProject);
testSemiJoin(
"u0 > 1",
"SELECT u0, u0 IN (SELECT * FROM t where t0 = u0 and u0 > 1) FROM u",
{"u0", "match"},
core::JoinType::kRightSemiProject);
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down
18 changes: 17 additions & 1 deletion velox/exec/tests/utils/PlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,23 @@ PlanBuilder& PlanBuilder::mergeJoin(
if (!filter.empty()) {
filterExpr = parseExpr(filter, resultType, options_, pool_);
}
auto outputType = extract(resultType, outputLayout);
RowTypePtr outputType;
if (isLeftSemiProjectJoin(joinType) || isRightSemiProjectJoin(joinType)) {
std::vector<std::string> names = outputLayout;

// Last column in 'outputLayout' must be a boolean 'match'.
std::vector<TypePtr> types;
types.reserve(outputLayout.size());
for (auto i = 0; i < outputLayout.size() - 1; ++i) {
types.emplace_back(resultType->findChild(outputLayout[i]));
}
types.emplace_back(BOOLEAN());

outputType = ROW(std::move(names), std::move(types));
} else {
outputType = extract(resultType, outputLayout);
}

auto leftKeyFields = fields(leftType, leftKeys);
auto rightKeyFields = fields(rightType, rightKeys);

Expand Down

0 comments on commit 245606e

Please sign in to comment.