Fix issues with canonicalization #623

Merged 1 commit on Sep 2, 2020
@@ -26,8 +26,14 @@ case class Table(
partitionColumns: Seq[String],
schema: StructType) {

private[this] def path(basePath: String) =
basePath + "/" + name + ".dat"
private[this] def path(basePath: String, appendDat: Boolean = true) = {
val rest = if (appendDat) {
".dat"
} else {
""
}
basePath + "/" + name + rest
}

def readCSV(spark: SparkSession, basePath: String): DataFrame =
spark.read.option("delimiter", "|")
@@ -37,12 +43,20 @@
def setupCSV(spark: SparkSession, basePath: String): Unit =
readCSV(spark, basePath).createOrReplaceTempView(name)

def setupParquet(spark: SparkSession, basePath: String): Unit =
spark.read.parquet(path(basePath)).createOrReplaceTempView(name)
def setupParquet(spark: SparkSession, basePath: String, appendDat: Boolean = true): Unit =
spark.read.parquet(path(basePath, appendDat)).createOrReplaceTempView(name)

def setupOrc(spark: SparkSession, basePath: String): Unit =
spark.read.orc(path(basePath)).createOrReplaceTempView(name)

def setup(
spark: SparkSession,
basePath: String,
format: String,
appendDat: Boolean = true): Unit = {
spark.read.format(format).load(path(basePath, appendDat)).createOrReplaceTempView(name)
}

private def setupWrite(
spark: SparkSession,
inputBase: String,
@@ -127,14 +141,22 @@ object TpcdsLikeSpark {
tables.foreach(_.setupCSV(spark, basePath))
}

def setupAllParquet(spark: SparkSession, basePath: String): Unit = {
tables.foreach(_.setupParquet(spark, basePath))
def setupAllParquet(spark: SparkSession, basePath: String, appendDat: Boolean = true): Unit = {
tables.foreach(_.setupParquet(spark, basePath, appendDat))
}

def setupAllOrc(spark: SparkSession, basePath: String): Unit = {
tables.foreach(_.setupOrc(spark, basePath))
}

def setupAll(
spark: SparkSession,
basePath: String,
format: String,
appendDat: Boolean = true): Unit = {
tables.foreach(_.setup(spark, basePath, format, appendDat))
}

private val tables = Array(
Table(
"catalog_sales",
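For context, a minimal sketch of how the new `setupAll` entry point could be driven. The base path, the SparkSession construction, and the choice of Parquet are illustrative assumptions, not part of this change:

import org.apache.spark.sql.SparkSession

// Hypothetical driver snippet: registers every TPC-DS-like table as a temp view,
// reading Parquet directories that are not suffixed with ".dat".
val spark = SparkSession.builder().appName("tpcds-like-setup").getOrCreate()
TpcdsLikeSpark.setupAll(spark, "/data/tpcds", format = "parquet", appendDat = false)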
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.execution.TrampolineUtil

/**
* Rewrites an expression using rules that are guaranteed to preserve the result while attempting
* to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
* will always return the same answer given the same input (i.e. false positives should not be
* possible). However, it is possible that two canonical expressions that are not equal will in fact
* return the same answer given any input (i.e. false negatives are possible).
*
* The following rules are applied:
* - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
* - Names for [[GetStructField]] are stripped.
* - TimeZoneId for [[Cast]] and [[AnsiCast]] are stripped if `needsTimeZone` is false.
* - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
* by `hashCode`.
* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
* - Elements in [[In]] are reordered by `hashCode`.
*
* This is essentially a copy of the Spark `Canonicalize` class, but updated for GPU operators.
*/
object GpuCanonicalize {
def execute(e: Expression): Expression = {
expressionReorder(ignoreTimeZone(ignoreNamesTypes(e)))
}

/** Remove names and nullability from types, and names from `GetStructField`. */
def ignoreNamesTypes(e: Expression): Expression = e match {
case a: AttributeReference =>
AttributeReference("none", TrampolineUtil.asNullable(a.dataType))(exprId = a.exprId)
case GetStructField(child, ordinal, Some(_)) => GetStructField(child, ordinal, None)
case _ => e
}

/** Remove TimeZoneId for Cast if needsTimeZone returns false. */
def ignoreTimeZone(e: Expression): Expression = e match {
case c: CastBase if c.timeZoneId.nonEmpty && !c.needsTimeZone =>
c.withTimeZone(null)
case c: GpuCast if c.timeZoneId.nonEmpty =>
// TODO when we start to support time zones check for `&& !c.needsTimeZone`
c.withTimeZone(null)
case _ => e
}

/** Collects adjacent commutative operations. */
private def gatherCommutative(
e: Expression,
f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
case other => other :: Nil
}

/** Orders a set of commutative operations by their hash code. */
private def orderCommutative(
e: Expression,
f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
gatherCommutative(e, f).sortBy(_.hashCode())

/** Rearrange expressions that are commutative or associative. */
private def expressionReorder(e: Expression): Expression = e match {
case a: GpuAdd => orderCommutative(a, { case GpuAdd(l, r) => Seq(l, r) }).reduce(GpuAdd)
case m: GpuMultiply =>
orderCommutative(m, { case GpuMultiply(l, r) => Seq(l, r) }).reduce(GpuMultiply)
case o: GpuOr =>
orderCommutative(o, { case GpuOr(l, r) if l.deterministic && r.deterministic => Seq(l, r) })
.reduce(GpuOr)
case a: GpuAnd =>
orderCommutative(a, { case GpuAnd(l, r) if l.deterministic && r.deterministic => Seq(l, r)})
.reduce(GpuAnd)

case GpuEqualTo(l, r) if l.hashCode() > r.hashCode() => GpuEqualTo(r, l)
case GpuEqualNullSafe(l, r) if l.hashCode() > r.hashCode() => GpuEqualNullSafe(r, l)

case GpuGreaterThan(l, r) if l.hashCode() > r.hashCode() => GpuLessThan(r, l)
case GpuLessThan(l, r) if l.hashCode() > r.hashCode() => GpuGreaterThan(r, l)

case GpuGreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GpuLessThanOrEqual(r, l)
case GpuLessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GpuGreaterThanOrEqual(r, l)

// Note in the following `NOT` cases, `l.hashCode() <= r.hashCode()` holds. The reason is that
// canonicalization is conducted bottom-up -- see [[Expression.canonicalized]].
case GpuNot(GpuGreaterThan(l, r)) => GpuLessThanOrEqual(l, r)
case GpuNot(GpuLessThan(l, r)) => GpuGreaterThanOrEqual(l, r)
case GpuNot(GpuGreaterThanOrEqual(l, r)) => GpuLessThan(l, r)
case GpuNot(GpuLessThanOrEqual(l, r)) => GpuGreaterThan(l, r)

// order the list in the In operator
case GpuInSet(value, list) if list.length > 1 => GpuInSet(value, list.sortBy(_.hashCode()))

case _ => e
}
}
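As a rough illustration of what these rules buy us, a sketch only: `GpuAdd`'s package is assumed to be covered by the `org.apache.spark.sql.rapids._` wildcard import above, and its two-argument constructor matches the pattern used in `expressionReorder`.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.rapids.GpuAdd // assumed location, see the wildcard import above
import org.apache.spark.sql.types.IntegerType

// Two cosmetically different additions canonicalize to the same expression because
// expressionReorder sorts the operands of commutative operations by hashCode.
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val e1 = GpuCanonicalize.execute(GpuAdd(a, b))
val e2 = GpuCanonicalize.execute(GpuAdd(b, a))
assert(e1 == e2)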
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala (32 additions, 0 deletions)
@@ -19,6 +19,8 @@ package com.nvidia.spark.rapids
import com.nvidia.spark.rapids.GpuMetricNames._

import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

@@ -88,4 +90,34 @@ trait GpuExec extends SparkPlan with Arm {
case c: GpuExpression => c.disableCoalesceUntilInput()
case _ => false
}

/**
* Defines how the canonicalization should work for the current plan.
*/
override protected def doCanonicalize(): SparkPlan = {
val canonicalizedChildren = children.map(_.canonicalized)
var id = -1
mapExpressions {
case a: Alias =>
id += 1
// As the root of the expression, Alias will always take an arbitrary exprId. We need to
// normalize that for equality testing by assigning exprIds from 0 incrementally. The
// alias name doesn't matter and should be erased.
val normalizedChild = QueryPlan.normalizeExpressions(a.child, allAttributes)
Alias(normalizedChild, "")(ExprId(id), a.qualifier)
case a: GpuAlias =>
id += 1
// As the root of the expression, Alias will always take an arbitrary exprId. We need to
// normalize that for equality testing by assigning exprIds from 0 incrementally. The
// alias name doesn't matter and should be erased.
val normalizedChild = QueryPlan.normalizeExpressions(a.child, allAttributes)
GpuAlias(normalizedChild, "")(ExprId(id), a.qualifier)
case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
// Top level `AttributeReference` may also be used for output like `Alias`, we should
// normalize the exprId too.
id += 1
ar.withExprId(ExprId(id)).canonicalized
case other => QueryPlan.normalizeExpressions(other, allAttributes)
}.withNewChildren(canonicalizedChildren)
}
}
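The practical effect of this override is that two GPU plans that differ only in alias names and exprIds now canonicalize to the same plan, so `sameResult` (and therefore exchange and subquery reuse) works for them. A sketch, assuming an existing SparkSession named `spark` and that the plugin is enabled so the executed plans below are GPU plans:

// Hypothetical check, not part of this PR: both DataFrames describe the same aggregation,
// but Spark assigns them distinct exprIds and alias instances. With doCanonicalize above,
// the canonicalized GPU plans compare equal.
val df1 = spark.range(100).selectExpr("id % 10 AS k", "id AS v").groupBy("k").sum("v")
val df2 = spark.range(100).selectExpr("id % 10 AS k", "id AS v").groupBy("k").sum("v")
assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))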
@@ -91,6 +91,11 @@ trait GpuExpression extends Expression with Unevaluable with Arm {
* temporary value.
*/
def columnarEval(batch: ColumnarBatch): Any

override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
GpuCanonicalize.execute(withNewChildren(canonicalizedChildren))
}
}

abstract class GpuLeafExpression extends GpuExpression {
sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala (43 additions, 4 deletions)
@@ -26,11 +26,11 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, HashPartitioning, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.rapids.{CudfAggregate, GpuAggregateExpression, GpuDeclarativeAggregate}
@@ -233,6 +233,17 @@ case class GpuHashAggregateExec(
resultExpressions: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode with GpuExec with Arm {

override def verboseStringWithOperatorId(): String = {
s"""
|$formattedNodeName
|${ExplainUtils.generateFieldString("Input", child.output)}
|${ExplainUtils.generateFieldString("Keys", groupingExpressions)}
|${ExplainUtils.generateFieldString("Functions", aggregateExpressions)}
|${ExplainUtils.generateFieldString("Aggregate Attributes", aggregateAttributes)}
|${ExplainUtils.generateFieldString("Results", resultExpressions)}
|""".stripMargin
}

case class BoundExpressionsModeAggregates(boundInputReferences: Seq[GpuExpression] ,
boundFinalProjections: Option[scala.Seq[GpuExpression]],
boundResultReferences: scala.Seq[Expression] ,
@@ -834,14 +845,42 @@ case class GpuHashAggregateExec(
"concatTime"-> SQLMetrics.createNanoTimingMetric(sparkContext, "time in batch concat")
)

protected def outputExpressions: Seq[NamedExpression] = resultExpressions

//
// This section is derived (copied in most cases) from HashAggregateExec
//
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def outputPartitioning: Partitioning = child.outputPartitioning
final override def outputPartitioning: Partitioning = {
if (hasAlias) {
child.outputPartitioning match {
case h: GpuHashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
case other => other
}
} else {
child.outputPartitioning
}
}

protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined

protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map {
case a: AttributeReference => replaceAlias(a).getOrElse(a)
case other => other
}
}

protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
outputExpressions.collectFirst {
case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
a.toAttribute
}
}

// Used in de-duping and optimizer rules
override def producedAttributes: AttributeSet =
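To see why the alias-aware `outputPartitioning` matters, consider a query in which an inner aggregate renames its grouping key and a parent then groups on the new name. A sketch follows (the table `t` and its columns are hypothetical): because `replaceAliases` rewrites the child's hash partitioning in terms of the output alias, the aggregate still reports that its output is partitioned on the renamed key, and the planner does not need to insert a second shuffle.

// Hypothetical query: the inner aggregate exposes `key` under the alias `k2`.
// With the change above, HashPartitioning(key) is reported as HashPartitioning(k2),
// so the outer GROUP BY k2 can reuse the existing partitioning instead of reshuffling.
val df = spark.sql(
  """SELECT k2, COUNT(*) AS cnt
    |FROM (SELECT key AS k2, SUM(value) AS s FROM t GROUP BY key) sub
    |GROUP BY k2
    |""".stripMargin)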