From f2d6157251dae10874e868337213c150d080048c Mon Sep 17 00:00:00 2001 From: Niranjan Artal <50492963+nartal1@users.noreply.github.com> Date: Fri, 22 Jul 2022 06:15:08 -0700 Subject: [PATCH] Qualification tool: Parse expressions in Aggregates and Sort execs. (#6042) * Add parsers for Aggregate and Sort execs Signed-off-by: Niranjan Artal * addressed review comments Signed-off-by: Niranjan Artal --- .../planparser/HashAggregateExecParser.scala | 6 +- .../ObjectHashAggregateExecParser.scala | 6 +- .../tool/planparser/SQLPlanParser.scala | 59 ++++++++++++++++++- .../planparser/SortAggregateExecParser.scala | 6 +- .../tool/planparser/SortExecParser.scala | 6 +- .../qualification/PluginTypeChecker.scala | 4 +- .../tool/planparser/SqlPlanParserSuite.scala | 52 +++++++++++++++- 7 files changed, 131 insertions(+), 8 deletions(-) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala index 2c5e7714d18..37ca01c93e4 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala @@ -35,7 +35,11 @@ case class HashAggregateExecParser( val accumId = node.metrics.find( _.name == "time in aggregation build total").map(_.accumulatorId) val maxDuration = SQLPlanParser.getTotalDuration(accumId, app) - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("HashAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git 
a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala index 995c51ee411..c77761aed2b 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala @@ -35,7 +35,11 @@ case class ObjectHashAggregateExecParser( val accumId = node.metrics.find( _.name == "time in aggregation build total").map(_.accumulatorId) val maxDuration = SQLPlanParser.getTotalDuration(accumId, app) - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("ObjectHashAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala index 84e6e1c68e6..659fd18d682 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala @@ -285,7 +285,62 @@ object SQLPlanParser extends Logging { case _ => // NO OP } } - parsedExpressions.toArray + parsedExpressions.distinct.toArray + } + + // This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec + def parseAggregateExpressions(exprStr: String): Array[String] = { + val parsedExpressions = ArrayBuffer[String]() + // (key=[num#83], functions=[partial_collect_list(letter#84, 0, 0), partial_count(letter#84)]) 
+ val pattern = """functions=\[([\w#, +*\\\-\.<>=\`\(\)]+\])""".r + val aggregatesString = pattern.findFirstMatchIn(exprStr) + // This is to split multiple column names in AggregateExec. Each column will be aggregating + // based on the aggregate function. Here "partial_" is removed and only function name is + // preserved. Below regex will first remove the "functions=" from the string followed by + // removing "partial_". That string is split which produces an array containing + // column names. Finally we remove the parentheses from the beginning and end to get only + // the expressions. Result will be as below. + // paranRemoved = Array(collect_list(letter#84, 0, 0),, count(letter#84)) + if (aggregatesString.isDefined) { + val paranRemoved = aggregatesString.get.toString.replaceAll("functions=", ""). + replaceAll("partial_", "").split("(?<=\\),)").map(_.trim). + map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")) + val functionPattern = """(\w+)\(.*\)""".r + paranRemoved.foreach { case expr => + val functionName = getFunctionName(functionPattern, expr) + functionName match { + case Some(func) => parsedExpressions += func + case _ => // NO OP + } + } + } + parsedExpressions.distinct.toArray + } + + def parseSortExpressions(exprStr: String): Array[String] = { + val parsedExpressions = ArrayBuffer[String]() + // Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0 + val pattern = """\[([\w#, \(\)]+\])""".r + val sortString = pattern.findFirstMatchIn(exprStr) + // This is to split multiple column names in SortExec. Sort may have a function on a column. + // The string is split on delimiter containing FIRST, OR LAST, which is the last string + // of each column in SortExec that produces an array containing + // column names. Finally we remove the parentheses from the beginning and end to get only + // the expressions. Result will be as below. 
+ // paranRemoved = Array(round(num#7, 0) ASC NULLS FIRST,, letter#8 DESC NULLS LAST) + if (sortString.isDefined) { + val paranRemoved = sortString.get.toString.split("(?<=FIRST,)|(?<=LAST,)"). + map(_.trim).map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")) + val functionPattern = """(\w+)\(.*\)""".r + paranRemoved.foreach { case expr => + val functionName = getFunctionName(functionPattern, expr) + functionName match { + case Some(func) => parsedExpressions += func + case _ => // NO OP + } + } + } + parsedExpressions.distinct.toArray } def parseFilterExpressions(exprStr: String): Array[String] = { @@ -362,6 +417,6 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.toArray + parsedExpressions.distinct.toArray } } diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala index 425ffe9193b..5109bbd7f58 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala @@ -30,7 +30,11 @@ case class SortAggregateExecParser( override def parse: ExecInfo = { // SortAggregate doesn't have duration val duration = None - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("SortAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala index 
288c484965a..e043a8e828a 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala @@ -30,7 +30,11 @@ case class SortExecParser( override def parse: ExecInfo = { // Sort doesn't have duration val duration = None - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("Sort ", "") + val expressions = SQLPlanParser.parseSortExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala index 2cfeffac1cf..4f14af8fca3 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala @@ -252,7 +252,9 @@ class PluginTypeChecker extends Logging { } def isExprSupported(expr: String): Boolean = { - val exprLowercase = expr.toLowerCase + // Remove _ from the string. Example: collect_list => collectlist. 
+ // collect_list is alias for CollectList aggregate function + val exprLowercase = expr.toLowerCase.replace("_","") if (supportedExprs.contains(exprLowercase)) { val exprSupported = supportedExprs.getOrElse(exprLowercase, "NS") if (exprSupported == "S") { diff --git a/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala index 4186ae30463..ff37207246a 100644 --- a/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala +++ b/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, TrampolineUtil} import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, explode, hex, sum} +import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, sum} import org.apache.spark.sql.rapids.tool.ToolUtils import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo import org.apache.spark.sql.types.StringType @@ -558,6 +558,56 @@ class SQLPlanParserSuite extends FunSuite with BeforeAndAfterEach with Logging { } } + test("Expressions supported in SortAggregateExec") { + TrampolineUtil.withTempDir { eventLogDir => + val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark => + import spark.implicits._ + spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false") + val df1 = Seq((1, "a"), (1, "aa"), (1, "a"), (2, "b"), + (2, "b"), (3, "c"), (3, "c")).toDF("num", "letter") + df1.groupBy("num").agg(collect_list("letter").as("collected_letters"), + count("letter").as("letter_count")) + } + val pluginTypeChecker = new PluginTypeChecker() + val app = createAppFromEventlog(eventLog) + 
assert(app.sqlPlans.size == 1) + val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => + SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app) + } + val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) + val sortAggregate = execInfo.filter(_.exec == "SortAggregate") + assertSizeAndSupported(2, sortAggregate) + } + } + + test("Expressions supported in SortExec") { + TrampolineUtil.withTempDir { parquetoutputLoc => + TrampolineUtil.withTempDir { eventLogDir => + val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, + "ProjectExprsSupported") { spark => + import spark.implicits._ + val df1 = Seq((1.7, "a"), (1.6, "aa"), (1.1, "b"), (2.5, "a"), (2.2, "b"), + (3.2, "a"), (10.6, "c")).toDF("num", "letter") + df1.write.parquet(s"$parquetoutputLoc/testsortExec") + val df2 = spark.read.parquet(s"$parquetoutputLoc/testsortExec") + df2.sort("num").collect + df2.orderBy("num").collect + df2.select(round(col("num")), col("letter")).sort(round(col("num")), col("letter").desc) + } + val pluginTypeChecker = new PluginTypeChecker() + val app = createAppFromEventlog(eventLog) + assert(app.sqlPlans.size == 4) + val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => + SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app) + } + val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) + val sortExec = allExecInfo.filter(_.exec.contains("Sort")) + assert(sortExec.size == 3) + assertSizeAndSupported(3, sortExec, 5.2) + } + } + } + test("Expressions supported in ProjectExec") { TrampolineUtil.withTempDir { parquetoutputLoc => TrampolineUtil.withTempDir { eventLogDir =>