Skip to content

Commit

Permalink
[SPARK-33036][SQL] Refactor RewriteCorrelatedScalarSubquery code to r…
Browse files Browse the repository at this point in the history
…eplace exprIds in a bottom-up manner

### What changes were proposed in this pull request?

This PR intends to refactor code in `RewriteCorrelatedScalarSubquery` for replacing `ExprId`s in a bottom-up manner instead of doing in a top-down one.

This PR comes from the talk with cloud-fan in #29585 (comment).

### Why are the changes needed?

To improve code.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #29913 from maropu/RefactorRewriteCorrelatedScalarSubquery.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
maropu committed Oct 7, 2020
1 parent 72da6f8 commit 94d648d
Showing 1 changed file with 51 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
* for the subqueries and rewrite references in the given `expression`.
* This method returns extracted subqueries and the corresponding `exprId`s and these values
* will be used later in `constructLeftJoins` for building the child plan that
* returns subquery output with the `exprId`s.
* the given collector. The expression is rewritten and returned.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
subqueries: ArrayBuffer[ScalarSubquery]): E = {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
val newExprId = NamedExpression.newExprId
subqueries += s -> newExprId
s.plan.output.head.withExprId(newExprId)
subqueries += s
s.plan.output.head
}
newExpression.asInstanceOf[E]
}
Expand Down Expand Up @@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
* This method returns the child plan and an attribute mapping
* for the updated `ExprId`s of subqueries. If the non-empty mapping returned,
* this rule will rewrite subquery references in a parent plan based on it.
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
Expand All @@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val subqueryResultExpr =
Alias(If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)()
subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = newExprId),
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
Expand All @@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = newExprId)
origOutput.name)()

subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))

Project(
currentChild.output :+ caseExpr,
Expand All @@ -587,6 +589,20 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
}
}
(newChild, AttributeMap(subqueryAttrMapping.toSeq))
}

private def updateAttrs[E <: Expression](
exprs: Seq[E],
attrMap: AttributeMap[Attribute]): Seq[E] = {
if (attrMap.nonEmpty) {
val newExprs = exprs.map { _.transform {
case a: AttributeReference => attrMap.getOrElse(a, a)
}}
newExprs.asInstanceOf[Seq[E]]
} else {
exprs
}
}

/**
Expand All @@ -595,36 +611,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
val newAgg = Aggregate(newGrouping, newExprs, newChild)
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
val newProj = Project(newExprs, newChild)
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newCondition = updateAttrs(Seq(rewriteCondition), subqueryAttrMapping).head
val newProj = Project(f.output, Filter(newCondition, newChild))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
Expand Down

0 comments on commit 94d648d

Please sign in to comment.