diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 35d0189f64651..a69cda25ef4f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -85,7 +85,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       }
     val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
     val alias = Alias(aggExp, "bloomFilter")()
-    val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan))
+    val aggregate =
+      ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
     val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
     val filter = BloomFilterMightContain(bloomFilterSubquery,
       new XxHash64(Seq(filterApplicationSideExp)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index 0da3667382c16..097a18cabd58c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -255,6 +255,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
         planEnabled = sql(query).queryExecution.optimizedPlan
         checkAnswer(sql(query), expectedAnswer)
         if (shouldReplace) {
+          assert(!columnPruningTakesEffect(planEnabled))
           assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
         } else {
           assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
@@ -288,6 +289,20 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     numMightContains
   }
 
+  def columnPruningTakesEffect(plan: LogicalPlan): Boolean = {
+    def takesEffect(plan: LogicalPlan): Boolean = {
+      val result = org.apache.spark.sql.catalyst.optimizer.ColumnPruning.apply(plan)
+      !result.fastEquals(plan)
+    }
+
+    plan.collectFirst {
+      case Filter(condition, _) if condition.collectFirst {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+          if takesEffect(subquery.plan) => true
+      }.nonEmpty => true
+    }.nonEmpty
+  }
+
   def assertRewroteSemiJoin(query: String): Unit = {
     checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
   }
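
Aside (not part of the patch): a minimal, self-contained sketch of what the added ColumnPruning call buys. Applying ColumnPruning to the creation-side Aggregate inserts a Project beneath the global aggregate, so the scalar subquery only reads the column fed to the bloom filter rather than every column of filterCreationSidePlan. The object name and the three-column relation below are hypothetical stand-ins; the sketch assumes Catalyst's expression/plan DSL and compiles against the spark-catalyst module.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object ColumnPruningSketch extends App {
  // Hypothetical three-column relation standing in for filterCreationSidePlan.
  val a = $"a".int
  val relation = LocalRelation(a, $"b".int, $"c".int)

  // A global aggregate over a single column, mirroring
  // Aggregate(Nil, Seq(alias), filterCreationSidePlan) in the rule.
  val aggregate = relation.groupBy()(count(a).as("cnt")).analyze

  // ColumnPruning pushes a Project(a) below the Aggregate, so b and c
  // are never materialized by the creation-side scan:
  //   Aggregate [count(a) AS cnt]
  //   +- Project [a]
  //      +- LocalRelation [a, b, c]
  println(ColumnPruning(aggregate).treeString)
}

The test helper added above uses the same idea in reverse: if re-running ColumnPruning on an injected scalar subquery's plan still changes it, pruning was not applied at injection time, which is exactly what assert(!columnPruningTakesEffect(planEnabled)) guards against.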