diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
index 7fc89ecc88ba3..05513cddccb86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
@@ -35,23 +34,26 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
   val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
   val funcId_might_contain = new FunctionIdentifier("might_contain")
 
-  // Register 'bloom_filter_agg' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg,
-    new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
-    (children: Seq[Expression]) => children.size match {
-      case 1 => new BloomFilterAggregate(children.head)
-      case 2 => new BloomFilterAggregate(children.head, children(1))
-      case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
-    })
-
-  // Register 'might_contain' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_might_contain,
-    new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
-    (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'bloom_filter_agg' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) => children.size match {
+        case 1 => new BloomFilterAggregate(children.head)
+        case 2 => new BloomFilterAggregate(children.head, children(1))
+        case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+      })
+
+    // Register 'might_contain' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
 
   override def afterAll(): Unit = {
-    FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg)
-    FunctionRegistry.builtin.dropFunction(funcId_might_contain)
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    spark.sessionState.functionRegistry.dropFunction(funcId_might_contain)
     super.afterAll()
   }