Skip to content

Commit

Permalink
Qualification tool: Parse expressions in Aggregates and Sort execs. (#…
Browse files Browse the repository at this point in the history
…6042)

* Add parsers for Aggregate and Sort execs

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* addressed review comments

Signed-off-by: Niranjan Artal <nartal@nvidia.com>
  • Loading branch information
nartal1 authored Jul 22, 2022
1 parent 1b27375 commit f2d6157
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ case class HashAggregateExecParser(
val accumId = node.metrics.find(
_.name == "time in aggregation build total").map(_.accumulatorId)
val maxDuration = SQLPlanParser.getTotalDuration(accumId, app)
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("HashAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ case class ObjectHashAggregateExecParser(
val accumId = node.metrics.find(
_.name == "time in aggregation build total").map(_.accumulatorId)
val maxDuration = SQLPlanParser.getTotalDuration(accumId, app)
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("ObjectHashAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,62 @@ object SQLPlanParser extends Logging {
case _ => // NO OP
}
}
parsedExpressions.toArray
parsedExpressions.distinct.toArray
}

/**
 * Extracts the distinct function names found in the "functions=[...]" section of an
 * aggregate node description. Used for SortAggregateExec, HashAggregateExec and
 * ObjectHashAggregateExec.
 *
 * Example input:
 * (key=[num#83], functions=[partial_collect_list(letter#84, 0, 0), partial_count(letter#84)])
 *
 * @param exprStr description string of the aggregate node
 * @return distinct function names in order of first appearance; empty when the
 *         description has no matching "functions=[...]" section
 */
def parseAggregateExpressions(exprStr: String): Array[String] = {
  val functionsSection = """functions=\[([\w#, +*\\\-\.<>=\`\(\)]+\])""".r
  functionsSection.findFirstMatchIn(exprStr).map { found =>
    // Drop the "functions=" label and every "partial_" prefix so that only the plain
    // function calls remain.
    val withoutPrefixes = found.matched.replaceAll("functions=", "").replaceAll("partial_", "")
    // Split right after each "),", then strip the surrounding square brackets.
    // Result looks like: Array(collect_list(letter#84, 0, 0),, count(letter#84))
    val calls = withoutPrefixes.split("(?<=\\),)").map { piece =>
      piece.trim.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")
    }
    val functionCall = """(\w+)\(.*\)""".r
    // Fragments for which getFunctionName finds no name are skipped.
    calls.map(call => getFunctionName(functionCall, call))
      .collect { case Some(name) => name }
      .distinct
  }.getOrElse(Array.empty[String])
}

/**
 * Extracts the distinct function names used in the sort order of a Sort node description.
 *
 * Example input:
 * Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0
 *
 * @param exprStr description string of the Sort node
 * @return distinct function names in order of first appearance; empty when no
 *         bracketed sort-order section matches
 */
def parseSortExpressions(exprStr: String): Array[String] = {
  val bracketedSection = """\[([\w#, \(\)]+\])""".r
  bracketedSection.findFirstMatchIn(exprStr).map { found =>
    // A sort key may wrap its column in a function. Every key ends with FIRST or LAST, so
    // split right after "FIRST," / "LAST," and then strip the surrounding square brackets.
    // Result looks like: Array(round(num#7, 0) ASC NULLS FIRST,, letter#8 DESC NULLS LAST)
    val sortKeys = found.matched.split("(?<=FIRST,)|(?<=LAST,)").map { piece =>
      piece.trim.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")
    }
    val functionCall = """(\w+)\(.*\)""".r
    // Keys for which getFunctionName finds no name (e.g. bare columns) are skipped.
    sortKeys.map(key => getFunctionName(functionCall, key))
      .collect { case Some(name) => name }
      .distinct
  }.getOrElse(Array.empty[String])
}

def parseFilterExpressions(exprStr: String): Array[String] = {
Expand Down Expand Up @@ -362,6 +417,6 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.toArray
parsedExpressions.distinct.toArray
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ case class SortAggregateExecParser(
override def parse: ExecInfo = {
// SortAggregate doesn't have duration
val duration = None
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("SortAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ case class SortExecParser(
override def parse: ExecInfo = {
// Sort doesn't have duration
val duration = None
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("Sort ", "")
val expressions = SQLPlanParser.parseSortExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ class PluginTypeChecker extends Logging {
}

def isExprSupported(expr: String): Boolean = {
val exprLowercase = expr.toLowerCase
// Remove underscores from the name. Example: collect_list => collectlist.
// collect_list is an alias for the CollectList aggregate function.
val exprLowercase = expr.toLowerCase.replace("_","")
if (supportedExprs.contains(exprLowercase)) {
val exprSupported = supportedExprs.getOrElse(exprLowercase, "NS")
if (exprSupported == "S") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, TrampolineUtil}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, explode, hex, sum}
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, sum}
import org.apache.spark.sql.rapids.tool.ToolUtils
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -558,6 +558,56 @@ class SQLPlanParserSuite extends FunSuite with BeforeAndAfterEach with Logging {
}
}

// Runs a groupBy with collect_list and count, parses the resulting SQL plan and checks the
// SortAggregate execs reported by the plan parser.
test("Expressions supported in SortAggregateExec") {
TrampolineUtil.withTempDir { eventLogDir =>
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark =>
import spark.implicits._
// Turn off ObjectHashAggregateExec; presumably this makes Spark plan the collect_list
// aggregation as SortAggregate -- TODO confirm.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")
val df1 = Seq((1, "a"), (1, "aa"), (1, "a"), (2, "b"),
(2, "b"), (3, "c"), (3, "c")).toDF("num", "letter")
// NOTE(review): no action is called on this DataFrame here; presumably generateEventLog
// executes the returned DataFrame -- confirm against the helper.
df1.groupBy("num").agg(collect_list("letter").as("collected_letters"),
count("letter").as("letter_count"))
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
assert(app.sqlPlans.size == 1)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app)
}
val execInfo = getAllExecsFromPlan(parsedPlans.toSeq)
val sortAggregate = execInfo.filter(_.exec == "SortAggregate")
// Presumably asserts that exactly 2 SortAggregate execs exist and that all are
// supported -- confirm against assertSizeAndSupported.
assertSizeAndSupported(2, sortAggregate)
}
}

// Writes a small parquet table, runs three sorts over it (one sorting on round(num)) and
// checks the execs whose name contains "Sort" in the parsed plans.
test("Expressions supported in SortExec") {
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
// NOTE(review): the name "ProjectExprsSupported" looks copied from the ProjectExec test;
// it is a runtime string, so it is left unchanged here.
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir,
"ProjectExprsSupported") { spark =>
import spark.implicits._
val df1 = Seq((1.7, "a"), (1.6, "aa"), (1.1, "b"), (2.5, "a"), (2.2, "b"),
(3.2, "a"), (10.6, "c")).toDF("num", "letter")
df1.write.parquet(s"$parquetoutputLoc/testsortExec")
val df2 = spark.read.parquet(s"$parquetoutputLoc/testsortExec")
df2.sort("num").collect
df2.orderBy("num").collect
// NOTE(review): no action is called on this last DataFrame; presumably generateEventLog
// executes the returned DataFrame -- confirm against the helper.
df2.select(round(col("num")), col("letter")).sort(round(col("num")), col("letter").desc)
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
// Presumably 4 plans = the parquet write, the two collects and the returned DataFrame
// run by generateEventLog -- TODO confirm.
assert(app.sqlPlans.size == 4)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app)
}
val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
val sortExec = allExecInfo.filter(_.exec.contains("Sort"))
assert(sortExec.size == 3)
// Presumably 5.2 is the expected speedup factor for Sort, and the helper re-checks the
// size of 3 -- confirm against assertSizeAndSupported.
assertSizeAndSupported(3, sortExec, 5.2)
}
}
}

test("Expressions supported in ProjectExec") {
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
Expand Down

0 comments on commit f2d6157

Please sign in to comment.