Skip to content

Commit

Permalink
Add ColumnPruning in injectBloomFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang Liu committed Apr 3, 2022
1 parent deac8f9 commit 35eb5d2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
}
val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
val alias = Alias(aggExp, "bloomFilter")()
val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan))
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
val filter = BloomFilterMightContain(bloomFilterSubquery,
new XxHash64(Seq(filterApplicationSideExp)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
planEnabled = sql(query).queryExecution.optimizedPlan
checkAnswer(sql(query), expectedAnswer)
if (shouldReplace) {
assert(!columnPruningTakesEffect(planEnabled))
assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
} else {
assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
Expand Down Expand Up @@ -288,6 +289,20 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
numMightContains
}

/**
 * Reports whether `plan` contains a `Filter` whose condition holds a scalar subquery
 * that the `ColumnPruning` rule would still change.
 *
 * @param plan the logical plan to inspect
 * @return true if applying `ColumnPruning` to the plan of at least one such scalar
 *         subquery produces a plan that is not `fastEquals` to the input
 */
def columnPruningTakesEffect(plan: LogicalPlan): Boolean = {
  // True when one more ColumnPruning pass yields a plan that differs from its input.
  def stillPrunable(subqueryPlan: LogicalPlan): Boolean = {
    val pruned = org.apache.spark.sql.catalyst.optimizer.ColumnPruning(subqueryPlan)
    !pruned.fastEquals(subqueryPlan)
  }

  // Only Filter nodes are inspected; every other node is rejected by the predicate.
  plan.find {
    case Filter(condition, _) =>
      condition.find {
        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery =>
          stillPrunable(subquery.plan)
        case _ => false
      }.isDefined
    case _ => false
  }.isDefined
}

/**
 * Runs `checkWithAndWithoutFeatureEnabled` for `query` with `testSemiJoin = true` and
 * `shouldReplace = true`.
 *
 * @param query the SQL text handed to the check
 */
def assertRewroteSemiJoin(query: String): Unit =
  checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
Expand Down

0 comments on commit 35eb5d2

Please sign in to comment.