Skip to content

Commit

Permalink
Qualification tool: Parse expressions in WindowExec (#6106)
Browse files Browse the repository at this point in the history
Signed-off-by: Niranjan Artal <nartal@nvidia.com>
  • Loading branch information
nartal1 authored Jul 27, 2022
1 parent 3aceabc commit 9ebff9b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,41 @@ object SQLPlanParser extends Logging {
parsedExpressions.distinct.toArray
}

def parseWindowExpressions(exprStr:String): Array[String] = {
val parsedExpressions = ArrayBuffer[String]()
// [sum(cast(level#30 as bigint)) windowspecdefinition(device#29, id#28 ASC NULLS FIRST,
// specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum#35L,
// row_number() windowspecdefinition(device#29, id#28 ASC NULLS FIRST, specifiedwindowframe
// (RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#41], [device#29],
// [id#28 ASC NULLS FIRST]

// This splits the string to get only the expressions in WindowExec. So we first split the
// string on closing bracket ] and get the first element from the array. This is followed
// by removing the first and last parenthesis and removing the cast as it is not an expr.
// Lastly we split the string by keyword windowsspecdefinition so that each array element
// except the last element contains one window aggregate function.
// sum(level#30 as bigint))
// (device#29, id#28 ASC NULLS FIRST, ..... AS sum#35L, row_number()
// (device#29, id#28 ASC NULLS FIRST, ...... AS row_number#41
val windowExprs = exprStr.split("(?<=\\])")(0).
trim.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "").
replaceAll("cast\\(", "").split("windowspecdefinition").map(_.trim)
val functionPattern = """(\w+)\(""".r

// Get functionname from each array element except the last one as it doesn't contain
// any window function
for ( i <- 0 to windowExprs.size - 1 ) {
val windowFunc = functionPattern.findAllIn(windowExprs(i)).toList
val expr = windowFunc(windowFunc.size -1)
val functionName = getFunctionName(functionPattern, expr)
functionName match {
case Some(func) => parsedExpressions += func
case _ => // NO OP
}
}
parsedExpressions.distinct.toArray
}

def parseSortExpressions(exprStr: String): Array[String] = {
val parsedExpressions = ArrayBuffer[String]()
// Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ case class WindowExecParser(
override def parse: ExecInfo = {
// Window doesn't have duration
val duration = None
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) {
val exprString = node.desc.replaceFirst("Window ", "")
val expressions = SQLPlanParser.parseWindowExpressions(exprString)
val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr))
val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) &&
isAllExprsSupported) {
(checker.getSpeedupFactor(fullExecName), true)
} else {
(1.0, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, TrampolineUtil}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, sum}
import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, row_number, sum}
import org.apache.spark.sql.rapids.tool.ToolUtils
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -471,16 +471,16 @@ class SQLPlanParserSuite extends FunSuite with BeforeAndAfterEach with Logging {
}
}

test("WindowExec") {
test("WindowExec and expressions within WIndowExec") {
TrampolineUtil.withTempDir { eventLogDir =>
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark =>
import spark.implicits._
val metrics = Seq(
(0, 0, 0), (1, 0, 1), (2, 5, 2), (3, 0, 3), (4, 0, 1), (5, 5, 3), (6, 5, 0)
).toDF("id", "device", "level")
val rangeWithTwoDevicesById = Window.partitionBy('device).orderBy('id).
rangeBetween(start = -1, end = Window.currentRow)
val rangeWithTwoDevicesById = Window.partitionBy('device).orderBy('id)
metrics.withColumn("sum", sum('level) over rangeWithTwoDevicesById)
.withColumn("row_number", row_number.over(rangeWithTwoDevicesById))
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
Expand Down

0 comments on commit 9ebff9b

Please sign in to comment.