diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index ddd6acdb41ba..7cd9c757eac6 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/MergeJoin.h" +#include #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -47,7 +48,9 @@ bool MergeJoin::isSupported(core::JoinType joinType) { case core::JoinType::kLeft: case core::JoinType::kRight: case core::JoinType::kLeftSemiFilter: + case core::JoinType::kLeftSemiProject: case core::JoinType::kRightSemiFilter: + case core::JoinType::kRightSemiProject: case core::JoinType::kAnti: return true; @@ -81,7 +84,8 @@ void MergeJoin::initialize() { } } - if (joinNode_->isRightSemiFilterJoin()) { + if (joinNode_->isRightSemiFilterJoin() || + joinNode_->isRightSemiProjectJoin()) { VELOX_USER_CHECK( leftProjections_.empty(), "The left side projections should be empty for right semi join"); @@ -95,7 +99,7 @@ void MergeJoin::initialize() { } } - if (joinNode_->isLeftSemiFilterJoin()) { + if (joinNode_->isLeftSemiFilterJoin() || joinNode_->isLeftSemiProjectJoin()) { VELOX_USER_CHECK( rightProjections_.empty(), "The right side projections should be empty for left semi join"); @@ -105,10 +109,13 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin()) { + joinNode_->isRightJoin() || joinNode_->isLeftSemiProjectJoin() || + joinNode_->isRightSemiProjectJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } - } else if (joinNode_->isAntiJoin()) { + } else if ( + joinNode_->isAntiJoin() || joinNode_->isLeftSemiProjectJoin() || + joinNode_->isRightSemiProjectJoin()) { // Anti join needs to track the left side rows that have no match on the // right. joinTracker_ = JoinTracker(outputBatchSize_, pool()); @@ -288,7 +295,9 @@ void copyRow( void MergeJoin::addOutputRowForLeftJoin( const RowVectorPtr& left, vector_size_t leftIndex) { - VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_)); + VELOX_USER_CHECK( + isLeftJoin(joinType_) || isAntiJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)); rawLeftIndices_[outputSize_] = leftIndex; for (const auto& projection : rightProjections_) { @@ -307,7 +316,7 @@ void MergeJoin::addOutputRowForLeftJoin( void MergeJoin::addOutputRowForRightJoin( const RowVectorPtr& right, vector_size_t rightIndex) { - VELOX_USER_CHECK(isRightJoin(joinType_)); + VELOX_USER_CHECK(isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)); rawRightIndices_[outputSize_] = rightIndex; for (const auto& projection : leftProjections_) { @@ -360,7 +369,7 @@ void MergeJoin::addOutputRow( copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); if (joinTracker_) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { // Record right-side row with a match on the left-side. joinTracker_->addMatch(right, rightIndex, outputSize_); } else { @@ -372,12 +381,18 @@ void MergeJoin::addOutputRow( // Anti join needs to track the left side rows that have no match on the // right. - if (isAntiJoin(joinType_)) { + if (isAntiJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) { VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. joinTracker_->addMatch(left, leftIndex, outputSize_); } + if (isRightSemiProjectJoin(joinType_)) { + VELOX_CHECK(joinTracker_); + // Record right-side row with a match on the left-side. + joinTracker_->addMatch(right, rightIndex, outputSize_); + } + ++outputSize_; } @@ -451,6 +466,12 @@ bool MergeJoin::prepareOutput( isRightFlattened_ = false; } currentRight_ = right; + if (isRightSemiProjectJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) { + localColumns[outputType_->size() - 1] = BaseVector::create( + outputType_->childAt(outputType_->size() - 1), + outputBatchSize_, + operatorCtx_->pool()); + } output_ = std::make_shared( operatorCtx_->pool(), @@ -458,6 +479,7 @@ bool MergeJoin::prepareOutput( nullptr, outputBatchSize_, std::move(localColumns)); + outputSize_ = 0; if (filterInput_ != nullptr) { @@ -555,7 +577,9 @@ bool MergeJoin::addToOutput() { // or data structures that short-circuit the join process once a match // is found. if (isLeftSemiFilterJoin(joinType_) || - isRightSemiFilterJoin(joinType_)) { + isRightSemiFilterJoin(joinType_) || + isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { // LeftSemiFilter produce each row from the left at most once. // RightSemiFilter produce each row from the right at most once. rightEnd = rightStart + 1; @@ -638,6 +662,26 @@ RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) { return wrap(numPassed, indices, output); } +RowVectorPtr MergeJoin::filterOutputForSemiProject(const RowVectorPtr& output) { + auto numRows = output->size(); + const auto& filterRows = joinTracker_->matchingRows(numRows); + + auto lastChildren = output->children().back(); + auto flatMatch = lastChildren->as>(); + flatMatch->resize(numRows); + auto rawValues = flatMatch->mutableRawValues(); + + for (auto i = 0; i < numRows; i++) { + if (filterRows.isValid(i)) { + bits::setBit(rawValues, i, true); + } else { + bits::setBit(rawValues, i, false); + } + } + + return output; +} + RowVectorPtr MergeJoin::getOutput() { // Make sure to have is-blocked or needs-input as true if returning null // output. Otherwise, Driver assumes the operator is finished. @@ -649,6 +693,7 @@ RowVectorPtr MergeJoin::getOutput() { for (;;) { auto output = doGetOutput(); + if (output != nullptr && output->size() > 0) { if (filter_) { output = applyFilter(output); @@ -663,6 +708,10 @@ RowVectorPtr MergeJoin::getOutput() { continue; } else if (isAntiJoin(joinType_)) { return filterOutputForAntiJoin(output); + } else if ( + isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + return filterOutputForSemiProject(output); } else { return output; } @@ -768,7 +817,8 @@ RowVectorPtr MergeJoin::doGetOutput() { } if (!input_ || !rightInput_) { - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + if (isLeftJoin(joinType_) || isAntiJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)) { if (input_ && noMoreRightInput_) { // If output_ is currently wrapping a different buffer, return it // first. @@ -795,7 +845,7 @@ RowVectorPtr MergeJoin::doGetOutput() { output_->resize(outputSize_); return std::move(output_); } - } else if (isRightJoin(joinType_)) { + } else if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { if (rightInput_ && noMoreInput_) { // If output_ is currently wrapping a different buffer, return it // first. @@ -844,7 +894,8 @@ RowVectorPtr MergeJoin::doGetOutput() { for (;;) { // Catch up input_ with rightInput_. while (compareResult < 0) { - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + if (isLeftJoin(joinType_) || isAntiJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(input_, nullptr)) { @@ -871,7 +922,7 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up rightInput_ with input_. while (compareResult > 0) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(nullptr, rightInput_)) { @@ -972,18 +1023,37 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!filterRows.hasSelections()) { // No matches in the output, no need to evaluate the filter. - return output; + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + return filterOutputForSemiProject(output); + } else { + return output; + } } evaluateFilter(filterRows); + FlatVector* flatMatch{nullptr}; + uint64_t* rawValues; + + if (isLeftSemiProjectJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { + flatMatch = output->children().back()->as>(); + flatMatch->resize(numRows); + rawValues = flatMatch->mutableRawValues(); + } + // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. auto onMiss = [&](auto row) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; - if (!isRightJoin(joinType_)) { + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, row, false); + } + + if (!isRightJoin(joinType_) && !isRightSemiProjectJoin(joinType_)) { for (auto& projection : rightProjections_) { auto target = output->childAt(projection.outputChannel); target->setNull(row, true); @@ -1010,6 +1080,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } } else { if (passed) { + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, i, true); + } rawIndices[numPassed++] = i; } } @@ -1017,6 +1091,11 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // This row doesn't have a match on the right side. Keep it // unconditionally. rawIndices[numPassed++] = i; + + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, i, false); + } } } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index 08021c70f8a0..490516c4d672 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -241,6 +241,8 @@ class MergeJoin : public Operator { /// rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); + RowVectorPtr filterOutputForSemiProject(const RowVectorPtr& output); + /// As we populate the results of the join, we track whether a given /// output row is a result of a match between left and right sides or a miss. /// We use JoinTracker::addMatch and addMiss methods for that. diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index a91e62ca7b17..c19193106dbf 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -649,6 +649,60 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, semiJoinProjection) { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 2, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 2, 2, 7, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql); + }; + + testSemiJoin( + "", + "SELECT t.t0, EXISTS (SELECT * FROM u WHERE t.t0 = u.u0) FROM t", + {"t0", "match"}, + core::JoinType::kLeftSemiProject); + testSemiJoin( + "", + "SELECT u0, u0 IN (SELECT * FROM t where t.t0 = u.u0) FROM u", + {"u0", "match"}, + core::JoinType::kRightSemiProject); + + testSemiJoin( + "t0 > 1", + "SELECT t.t0, EXISTS (SELECT * FROM u WHERE t0 = u0 and t.t0 > 1) FROM t", + {"t0", "match"}, + core::JoinType::kLeftSemiProject); + testSemiJoin( + "u0 > 1", + "SELECT u0, u0 IN (SELECT * FROM t where t0 = u0 and u0 > 1) FROM u", + {"u0", "match"}, + core::JoinType::kRightSemiProject); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"}, diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index a1db6fb23544..5dcc569a74b2 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -1388,7 +1388,23 @@ PlanBuilder& PlanBuilder::mergeJoin( if (!filter.empty()) { filterExpr = parseExpr(filter, resultType, options_, pool_); } - auto outputType = extract(resultType, outputLayout); + RowTypePtr outputType; + if (isLeftSemiProjectJoin(joinType) || isRightSemiProjectJoin(joinType)) { + std::vector names = outputLayout; + + // Last column in 'outputLayout' must be a boolean 'match'. + std::vector types; + types.reserve(outputLayout.size()); + for (auto i = 0; i < outputLayout.size() - 1; ++i) { + types.emplace_back(resultType->findChild(outputLayout[i])); + } + types.emplace_back(BOOLEAN()); + + outputType = ROW(std::move(names), std::move(types)); + } else { + outputType = extract(resultType, outputLayout); + } + auto leftKeyFields = fields(leftType, leftKeys); auto rightKeyFields = fields(rightType, rightKeys);