Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qualification tool: Parse expressions in Aggregates and Sort execs. #6042

Merged
merged 3 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ case class HashAggregateExecParser(
val accumId = node.metrics.find(
_.name == "time in aggregation build total").map(_.accumulatorId)
val maxDuration = SQLPlanParser.getTotalDuration(accumId, app)
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("HashAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ case class ObjectHashAggregateExecParser(
val accumId = node.metrics.find(
_.name == "time in aggregation build total").map(_.accumulatorId)
val maxDuration = SQLPlanParser.getTotalDuration(accumId, app)
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("ObjectHashAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,62 @@ object SQLPlanParser extends Logging {
case _ => // NO OP
}
}
parsedExpressions.toArray
parsedExpressions.distinct.toArray
}

// This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec
def parseAggregateExpressions(exprStr: String): Array[String] = {
val parsedExpressions = ArrayBuffer[String]()
// (key=[num#83], functions=[partial_collect_list(letter#84, 0, 0), partial_count(letter#84)])
val pattern = """functions=\[([\w#, +*\\\-\.<>=\`\(\)]+\])""".r
val aggregatesString = pattern.findFirstMatchIn(exprStr)
// This is to split multiple column names in AggregateExec. Each column will be aggregating
// based on the aggregate function. Here "partial_" is removed and only function name is
// preserved. Below regex will first remove the "functions=" from the string followed by
// removing "partial_". That string is split which produces an array containing
// column names. Finally we remove the parentheses from the beginning and end to get only
// the expressions. Result will be as below.
// paranRemoved = Array(collect_list(letter#84, 0, 0),, count(letter#84))
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
if (aggregatesString.isDefined) {
val paranRemoved = aggregatesString.get.toString.replaceAll("functions=", "").
replaceAll("partial_", "").split("(?<=\\),)").map(_.trim).
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", ""))
val functionPattern = """(\w+)\(.*\)""".r
paranRemoved.foreach { case expr =>
val functionName = getFunctionName(functionPattern, expr)
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
functionName match {
case Some(func) => parsedExpressions += func
case _ => // NO OP
}
}
}
parsedExpressions.distinct.toArray
}

def parseSortExpressions(exprStr: String): Array[String] = {
val parsedExpressions = ArrayBuffer[String]()
// Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
val pattern = """\[([\w#, \(\)]+\])""".r
val sortString = pattern.findFirstMatchIn(exprStr)
// This is to split multiple column names in SortExec. Project may have a function on a column.
// The string is split on delimiter containing FIRST, OR LAST, which is the last string
// of each column in SortExec that produces an array containing
// column names. Finally we remove the parentheses from the beginning and end to get only
// the expressions. Result will be as below.
// paranRemoved = Array(round(num#7, 0) ASC NULLS FIRST,, letter#8 DESC NULLS LAST)
if (sortString.isDefined) {
val paranRemoved = sortString.get.toString.split("(?<=FIRST,)|(?<=LAST,)").
map(_.trim).map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", ""))
val functionPattern = """(\w+)\(.*\)""".r
paranRemoved.foreach { case expr =>
val functionName = getFunctionName(functionPattern, expr)
functionName match {
case Some(func) => parsedExpressions += func
case _ => // NO OP
}
}
}
parsedExpressions.distinct.toArray
}

def parseFilterExpressions(exprStr: String): Array[String] = {
Expand Down Expand Up @@ -362,6 +417,6 @@ object SQLPlanParser extends Logging {
}
}
}
parsedExpressions.toArray
parsedExpressions.distinct.toArray
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ case class SortAggregateExecParser(
override def parse: ExecInfo = {
// SortAggregate doesn't have duration
val duration = None
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("SortAggregate", "")
val expressions = SQLPlanParser.parseAggregateExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ case class SortExecParser(
override def parse: ExecInfo = {
// Sort doesn't have duration
val duration = None
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("Sort ", "")
val expressions = SQLPlanParser.parseSortExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ class PluginTypeChecker extends Logging {
}

def isExprSupported(expr: String): Boolean = {
val exprLowercase = expr.toLowerCase
// Remove _ from the string. Example: collect_list => collectlist.
// collect_list is alias for CollectList aggregate function
val exprLowercase = expr.toLowerCase.replace("_","")
if (supportedExprs.contains(exprLowercase)) {
val exprSupported = supportedExprs.getOrElse(exprLowercase, "NS")
if (exprSupported == "S") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, TrampolineUtil}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, explode, hex, sum}
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, sum}
import org.apache.spark.sql.rapids.tool.ToolUtils
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -558,6 +558,56 @@ class SQLPlanParserSuite extends FunSuite with BeforeAndAfterEach with Logging {
}
}

test("Expressions supported in SortAggregateExec") {
TrampolineUtil.withTempDir { eventLogDir =>
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark =>
import spark.implicits._
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")
val df1 = Seq((1, "a"), (1, "aa"), (1, "a"), (2, "b"),
(2, "b"), (3, "c"), (3, "c")).toDF("num", "letter")
df1.groupBy("num").agg(collect_list("letter").as("collected_letters"),
count("letter").as("letter_count"))
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
assert(app.sqlPlans.size == 1)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app)
}
val execInfo = getAllExecsFromPlan(parsedPlans.toSeq)
val sortAggregate = execInfo.filter(_.exec == "SortAggregate")
assertSizeAndSupported(2, sortAggregate)
}
}

test("Expressions supported in SortExec") {
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir,
"ProjectExprsSupported") { spark =>
import spark.implicits._
val df1 = Seq((1.7, "a"), (1.6, "aa"), (1.1, "b"), (2.5, "a"), (2.2, "b"),
(3.2, "a"), (10.6, "c")).toDF("num", "letter")
df1.write.parquet(s"$parquetoutputLoc/testsortExec")
val df2 = spark.read.parquet(s"$parquetoutputLoc/testsortExec")
df2.sort("num").collect
df2.orderBy("num").collect
df2.select(round(col("num")), col("letter")).sort(round(col("num")), col("letter").desc)
tgravescs marked this conversation as resolved.
Show resolved Hide resolved
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
assert(app.sqlPlans.size == 4)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app)
}
val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
val sortExec = allExecInfo.filter(_.exec.contains("Sort"))
assert(sortExec.size == 3)
assertSizeAndSupported(3, sortExec, 5.2)
}
}
}

test("Expressions supported in ProjectExec") {
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
Expand Down