KE-29673 add segment prune function for bloom runtime filter
Fix min/max for UTF8String collection.

Validate the runtime filter, if needed, when the broadcast join is valid.
zgzzbws authored and hellozepp committed Aug 10, 2023
1 parent d476a8c commit e37ed33
Showing 6 changed files with 401 additions and 59 deletions.
@@ -0,0 +1,119 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}

/**
 * An internal function that returns the result of the aggregate operations (min, max and
 * bloom filter) computed for `structTypeExpression`; the min and max results are used to
 * prune KE segments. This design is therefore only available for KE; the related issue is
 * KE-29673.
 * As with the `BloomFilterMightContain` expression, this expression requires that
 * `structTypeExpression` is either a constant value or an uncorrelated scalar sub-query.
 *
 * @param structTypeExpression the struct-typed expression holding the aggregate results.
 * @param valueExpression the application-side target column expression.
 * @param applicationSideAttrRef the attribute reference for `valueExpression`; this parameter
 *                               is used to construct `rangeRow` iff `valueExpression` has been
 *                               transformed into a non-AttributeReference type.
 */
case class BloomAndRangeFilterExpression(
    structTypeExpression: Expression,
    valueExpression: Expression,
    applicationSideAttrRef: AttributeReference)
  extends BinaryExpression with BloomRuntimeFilterHelper {

  val MIN_INDEX = 0
  val MAX_INDEX = 1
  val BINARY_INDEX = 2

  override def nullable: Boolean = true
  override def left: Expression = structTypeExpression
  override def right: Expression = valueExpression
  override def prettyName: String = "bloom_and_range_filter"
  override def dataType: DataType = BooleanType
  def decoratedRight: Expression = new XxHash64(Seq(right))

  override def checkInputDataTypes(): TypeCheckResult = {
    left.dataType match {
      case StructType(_) =>
        structTypeExpression match {
          case e: Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
          case subquery: PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
            TypeCheckResult.TypeCheckSuccess
          case _ =>
            TypeCheckResult.TypeCheckFailure(
              s"The bloom and range filter binary input to $prettyName " +
                "should be either a constant value or a scalar sub-query expression")
        }
      case _ => TypeCheckResult.TypeCheckFailure(
        s"Input to function $prettyName should be a StructType, " +
          s"which includes aggregate operations for min, max and bloom filter, " +
          s"but it's a [${left.dataType.catalogString}]")
    }
  }

  override protected def withNewChildrenInternal(
      newStructTypeExpression: Expression,
      newValueExpression: Expression): BloomAndRangeFilterExpression =
    copy(structTypeExpression = newStructTypeExpression, valueExpression = newValueExpression)

  @transient private lazy val subQueryRowResult = {
    structTypeExpression.eval().asInstanceOf[UnsafeRow]
  }

  @transient lazy val rangeRow: Seq[Expression] = {
    val structFields = left.dataType.asInstanceOf[StructType].fields
    val minDataType = structFields(MIN_INDEX).dataType
    val min = subQueryRowResult.get(MIN_INDEX, minDataType)
    val maxDataType = structFields(MAX_INDEX).dataType
    val max = subQueryRowResult.get(MAX_INDEX, maxDataType)
    if (min != null && max != null) {
      val attrRef = valueExpression match {
        case reference: AttributeReference =>
          reference
        case _ =>
          applicationSideAttrRef
      }
      val gteExpress = GreaterThanOrEqual(attrRef, Literal(convertToScala(min, minDataType)))
      val lteExpress = LessThanOrEqual(attrRef, Literal(convertToScala(max, maxDataType)))
      Seq(gteExpress, lteExpress)
    } else {
      Seq()
    }
  }

  @transient private lazy val bloomFilter = {
    val bytes = subQueryRowResult.getBinary(BINARY_INDEX)
    if (bytes == null) null else deserialize(bytes)
  }

  override def eval(input: InternalRow): Any = {
    internalEval(input, bloomFilter, decoratedRight)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    internalDoGenCode(ctx, ev, bloomFilter, decoratedRight, dataType)
  }

}
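For intuition, here is a minimal standalone sketch of the check this expression encodes: the probed value must fall inside [min, max] (the part KE uses to prune segments whose ranges cannot match) and then pass the bloom membership test. This is an illustration only; the object and names below are hypothetical, and the real expression probes the filter with xxhash64(value) rather than the raw value.

import org.apache.spark.util.sketch.BloomFilter

object RangeAndBloomCheckSketch {
  // Mirrors the struct layout (min, max, bloomFilter) produced on the creation side.
  final case class CreationSideAgg(min: Long, max: Long, bloom: BloomFilter)

  // The real expression probes with xxhash64(value); raw longs are used here for brevity.
  def mightMatch(agg: CreationSideAgg, value: Long): Boolean =
    value >= agg.min && value <= agg.max && agg.bloom.mightContainLong(value)

  def main(args: Array[String]): Unit = {
    val bf = BloomFilter.create(1000L)
    Seq(10L, 20L, 30L).foreach(bf.putLong)
    val agg = CreationSideAgg(min = 10L, max = 30L, bloom = bf)
    println(mightMatch(agg, 20L)) // likely true (bloom filters can yield false positives)
    println(mightMatch(agg, 99L)) // false: 99 is outside [10, 30], no bloom probe needed
  }
}

The range test is cheap and short-circuits before the bloom probe, which is what makes the min/max pair usable for segment pruning on the application side.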
@@ -17,15 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.ByteArrayInputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.BloomFilter

/**
 * An internal scalar function that returns the membership check result (either true or false)
@@ -40,7 +36,7 @@ import org.apache.spark.util.sketch.BloomFilter
 */
case class BloomFilterMightContain(
    bloomFilterExpression: Expression,
    valueExpression: Expression) extends BinaryExpression {
    valueExpression: Expression) extends BinaryExpression with BloomRuntimeFilterHelper {

  override def nullable: Boolean = true
  override def left: Expression = bloomFilterExpression
@@ -82,35 +78,11 @@ case class BloomFilterMightContain(
  }

  override def eval(input: InternalRow): Any = {
    if (bloomFilter == null) {
      null
    } else {
      val value = valueExpression.eval(input)
      if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long])
    }
    internalEval(input, bloomFilter, valueExpression)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (bloomFilter == null) {
      ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType))
    } else {
      val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName)
      val valueEval = valueExpression.genCode(ctx)
      ev.copy(code = code"""
        ${valueEval.code}
        boolean ${ev.isNull} = ${valueEval.isNull};
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (!${ev.isNull}) {
          ${ev.value} = $bf.mightContainLong((Long)${valueEval.value});
        }""")
    }
  }

  final def deserialize(bytes: Array[Byte]): BloomFilter = {
    val in = new ByteArrayInputStream(bytes)
    val bloomFilter = BloomFilter.readFrom(in)
    in.close()
    bloomFilter
    internalDoGenCode(ctx, ev, bloomFilter, valueExpression, dataType)
  }

}
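After this change, eval and doGenCode delegate to the shared helper trait added in the next file. The tri-valued null semantics they preserve are spelled out in this hedged standalone sketch; the object name is hypothetical:

import org.apache.spark.util.sketch.BloomFilter

object MightContainSemanticsSketch {
  // Same shape as internalEval: null filter -> null, null value -> null,
  // otherwise a (possibly false-positive) membership check.
  def mightContain(bf: BloomFilter, value: java.lang.Long): Any =
    if (bf == null || value == null) null else bf.mightContainLong(value)

  def main(args: Array[String]): Unit = {
    val bf = BloomFilter.create(100L)
    bf.putLong(7L)
    println(mightContain(null, 7L)) // null: no filter could be built
    println(mightContain(bf, null)) // null: SQL NULL propagates
    println(mightContain(bf, 7L))   // true: bloom filters have no false negatives
  }
}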
@@ -0,0 +1,64 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import java.io.ByteArrayInputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.sketch.BloomFilter

trait BloomRuntimeFilterHelper {

  def internalEval(input: InternalRow, bloomFilter: BloomFilter,
      evalExpression: Expression): Any = {
    if (bloomFilter == null) {
      null
    } else {
      val value = evalExpression.eval(input)
      if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long])
    }
  }

  def internalDoGenCode(ctx: CodegenContext, ev: ExprCode,
      bloomFilter: BloomFilter, evalExpression: Expression, dataType: DataType): ExprCode = {
    if (bloomFilter == null) {
      ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType))
    } else {
      val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName)
      val valueEval = evalExpression.genCode(ctx)
      ev.copy(code = code"""
        ${valueEval.code}
        boolean ${ev.isNull} = ${valueEval.isNull};
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (!${ev.isNull}) {
          ${ev.value} = $bf.mightContainLong((Long)${valueEval.value});
        }""")
    }
  }

  def deserialize(bytes: Array[Byte]): BloomFilter = {
    val in = new ByteArrayInputStream(bytes)
    val bloomFilter = BloomFilter.readFrom(in)
    in.close()
    bloomFilter
  }

}
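deserialize reads the byte format produced by BloomFilter.writeTo; readFrom and writeTo are the serialization entry points on org.apache.spark.util.sketch.BloomFilter. A minimal round-trip sketch follows; the object name is hypothetical:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.util.sketch.BloomFilter

object BloomFilterRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val original = BloomFilter.create(1000L)
    original.putLong(42L)

    // Serialize to the byte format stored in the subquery result's binary field.
    val out = new ByteArrayOutputStream()
    original.writeTo(out)

    // deserialize(bytes) above performs exactly this readFrom call.
    val restored = BloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray))
    assert(restored.mightContainLong(42L)) // membership survives the round trip
  }
}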
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions
import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete}
import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
@@ -28,8 +30,8 @@ import org.apache.spark.sql.types._

/**
 * Insert a filter on one side of the join if the other side has a selective predicate.
 * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something
 * else in the future.
 * The filter could be an IN subquery (converted to a semi join), a bloom filter,
 * a bloom and range filter, or something else in the future.
 */
object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {

@@ -47,8 +49,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
      filterApplicationSidePlan: LogicalPlan,
      filterCreationSideExp: Expression,
      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
    require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled)
    if (conf.runtimeFilterBloomFilterEnabled) {
    require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled
      || conf.runtimeFilterBloomFilterWithSegmentPruneEnabled)
    if (conf.runtimeFilterBloomFilterEnabled
        || conf.runtimeFilterBloomFilterWithSegmentPruneEnabled) {
      injectBloomFilter(
        filterApplicationSideExp,
        filterApplicationSidePlan,
@@ -82,14 +86,30 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
    } else {
      new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
    }
    val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
    val alias = Alias(aggExp, "bloomFilter")()
    val aggregate =
      ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
    val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
    val filter = BloomFilterMightContain(bloomFilterSubquery,
      new XxHash64(Seq(filterApplicationSideExp)))
    Filter(filter, filterApplicationSidePlan)
    val bloomAggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
    if (conf.runtimeFilterBloomFilterWithSegmentPruneEnabled) {
      val minAggExp = expressions.min(filterCreationSideExp)
      val maxAggExp = expressions.max(filterCreationSideExp)
      val aggStruct = CreateStruct(Seq(
        minAggExp.as("min"),
        maxAggExp.as("max"),
        bloomAggExp.as("bloomFilter")))
      val aggStructAlias = Alias(aggStruct, "columnAgg")()
      val aggregate = Aggregate(Nil, Seq(aggStructAlias), filterCreationSidePlan)
      val aggregatePlan = ConstantFolding(ColumnPruning(aggregate))
      val bloomFilterSubquery = ScalarSubquery(aggregatePlan, Nil)
      val filter = BloomAndRangeFilterExpression(bloomFilterSubquery, filterApplicationSideExp,
        filterApplicationSideExp.asInstanceOf[AttributeReference])
      Filter(filter, filterApplicationSidePlan)
    } else {
      val alias = Alias(bloomAggExp, "bloomFilter")()
      val aggregate =
        ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
      val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
      val filter = BloomFilterMightContain(bloomFilterSubquery,
        new XxHash64(Seq(filterApplicationSideExp)))
      Filter(filter, filterApplicationSidePlan)
    }
  }
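The segment-prune branch above collapses the creation side into a single row holding a (min, max, bloomFilter) struct aliased as columnAgg. Below is a hedged DataFrame-level sketch of that shape; count stands in for the internal BloomFilterAggregate expression, which is not exposed as a public SQL function here, and all names are hypothetical:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, max, min, struct}

object CreationSideAggSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val dim = Seq(10L, 20L, 30L).toDF("k")
    // Mirrors Aggregate(Nil, Seq(Alias(CreateStruct(min, max, bloom), "columnAgg")), plan).
    val agg = dim.agg(struct(
      min($"k").as("min"),
      max($"k").as("max"),
      count($"k").as("bloomFilterPlaceholder")).as("columnAgg"))
    agg.show(truncate = false) // one row: {10, 30, 3}
    spark.stop()
  }
}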

  private def injectInSubqueryFilter(
@@ -185,14 +205,23 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
      hint: JoinHint): Boolean = {
    findExpressionAndTrackLineageDown(filterApplicationSideExp,
      filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) &&
      (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) ||
        probablyHasShuffle(filterApplicationSide)) &&
      satisfyJoinConditionRequirement(filterApplicationSide, filterCreationSide, hint) &&
      satisfyByteSizeRequirement(filterApplicationSide)
  }

  private def satisfyJoinConditionRequirement(
      filterApplicationSide: LogicalPlan,
      filterCreationSide: LogicalPlan,
      hint: JoinHint): Boolean = {
    conf.runtimeFilterBroadcastJoinConditionIgnored ||
      (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) ||
        probablyHasShuffle(filterApplicationSide))
  }

  def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression,
      rightKey: Expression): Boolean = {
    if (conf.runtimeFilterBloomFilterEnabled) {
    if (conf.runtimeFilterBloomFilterEnabled
        || conf.runtimeFilterBloomFilterWithSegmentPruneEnabled) {
      hasBloomFilter(left, right, leftKey, rightKey)
    } else {
      hasInSubquery(left, right, leftKey, rightKey)
@@ -229,6 +258,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
        splitConjunctivePredicates(condition).exists {
          case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _))
            if valueExpression.fastEquals(key) => true
          case BloomAndRangeFilterExpression(_, valueExpression, _)
            if valueExpression.fastEquals(key) => true
          case _ => false
        }
      case _ => false
@@ -286,7 +317,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case s: Subquery if s.correlated => plan
    case _ if !conf.runtimeFilterSemiJoinReductionEnabled &&
      !conf.runtimeFilterBloomFilterEnabled => plan
      !conf.runtimeFilterBloomFilterEnabled &&
      !conf.runtimeFilterBloomFilterWithSegmentPruneEnabled => plan
    case _ => tryInjectRuntimeFilter(plan)
  }
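Taken together: apply bails out only when all three flags are off, injectFilter routes to the bloom path when either bloom flag is set, and the segment-prune flag then selects the bloom-and-range variant. A hedged sketch of that decision order; the case class mirrors the conf methods used above (these are not real SQLConf keys), and all names are hypothetical:

object RuntimeFilterChoiceSketch {
  final case class RuntimeFilterConf(
      semiJoinReductionEnabled: Boolean,
      bloomFilterEnabled: Boolean,
      bloomFilterWithSegmentPruneEnabled: Boolean)

  sealed trait Choice
  case object NoFilter extends Choice
  case object BloomOnly extends Choice     // BloomFilterMightContain
  case object BloomAndRange extends Choice // BloomAndRangeFilterExpression
  case object InSubquery extends Choice    // semi-join reduction

  // Mirrors injectFilter/injectBloomFilter: the segment-prune flag wins inside
  // the bloom branch; semi-join reduction is the remaining fallback.
  def choose(conf: RuntimeFilterConf): Choice =
    if (conf.bloomFilterWithSegmentPruneEnabled) BloomAndRange
    else if (conf.bloomFilterEnabled) BloomOnly
    else if (conf.semiJoinReductionEnabled) InSubquery
    else NoFilter

  def main(args: Array[String]): Unit = {
    println(choose(RuntimeFilterConf(false, false, true))) // BloomAndRange
    println(choose(RuntimeFilterConf(true, true, false)))  // BloomOnly
    println(choose(RuntimeFilterConf(true, false, false))) // InSubquery
  }
}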
