Qualification tool support recognizing decimal operations (#2928)
* add decimal check to potential issues

Signed-off-by: Thomas Graves <tgraves@apache.org>

* add another test

* remove unneeded test and updated readme

Signed-off-by: Thomas Graves <tgraves@nvidia.com>

* Update tests

Signed-off-by: Thomas Graves <tgraves@nvidia.com>
tgravescs authored Jul 15, 2021
1 parent bb8baab commit c8cc36a
Showing 9 changed files with 120 additions and 62 deletions.
5 changes: 3 additions & 2 deletions tools/README.md
@@ -120,8 +120,9 @@ outputs this same report to STDOUT.
The other file is a CSV file that contains more information and can be used for further post processing.

Note, potential problems are reported in the CSV file in a separate column, which is not included in the score. This
-currently only includes some UDFs. The tool won't catch all UDFs, and some of the UDFs can be handled with additional steps.
-Please refer to [supported_ops.md](../docs/supported_ops.md) for more details on UDF.
+currently includes some UDFs and some decimal operations. The tool won't catch all UDFs, and some of the UDFs can be
+handled with additional steps. Please refer to [supported_ops.md](../docs/supported_ops.md) for more details on UDF.
+For decimals, it tries to recognize decimal operations but it may not catch them all.

The CSV output also contains an `Executor CPU Time Percent` column that is not included in the score. This is an estimate
of how much time the tasks spent doing processing on the CPU vs waiting on IO. This is not always a good indicator
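Since the README notes the CSV can be used for further post processing, here is a minimal sketch of one way to do that with Spark itself. The file path is a placeholder, the local master setting is only for illustration, and the column names are copied from the expectation files added later in this commit.

import org.apache.spark.sql.SparkSession

object FilterDecimalApps {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("qual-postprocess").master("local[*]").getOrCreate()
    // Placeholder path: point this at the CSV report written by the qualification tool.
    val qual = spark.read.option("header", "true").csv("/path/to/qualification_output.csv")
    // Keep only the applications whose Potential Problems column flags decimal operations.
    qual.filter(qual("Potential Problems").contains("DECIMAL"))
      .select("App Name", "App ID", "Score", "Potential Problems",
        "SQL Duration with Potential Problems")
      .show(truncate = false)
    spark.stop()
  }
}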
@@ -117,11 +117,19 @@ abstract class AppBase(
}
}

protected def findPotentialIssues(desc: String): Option[String] = {
desc match {
case u if u.matches(".*UDF.*") => Some("UDF")
case _ => None
}
// Decimal support on the GPU is limited to less than 18 digits and decimals
// are configured off by default for now. It would be nice to have this
// based on what the plugin supports at some point.
private val decimalKeyWords = Map(".*promote_precision\\(.*" -> "DECIMAL",
".*decimal\\([0-9]+,[0-9]+\\).*" -> "DECIMAL",
".*DecimalType\\([0-9]+,[0-9]+\\).*" -> "DECIMAL")

private val UDFKeywords = Map(".*UDF.*" -> "UDF")

protected def findPotentialIssues(desc: String): Set[String] = {
val potentialIssuesRegexs = UDFKeywords ++ decimalKeyWords
val issues = potentialIssuesRegexs.filterKeys(desc.matches(_))
issues.values.toSet
}

def getPlanMetaWithSchema(planInfo: SparkPlanInfo): Seq[SparkPlanInfo] = {
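To see what the new findPotentialIssues logic flags, the sketch below copies the keyword maps from the hunk above into a small standalone object and runs them against a made-up plan-node description. The description string is purely illustrative, not taken from a real plan.

object PotentialIssueCheckDemo {
  private val decimalKeyWords = Map(".*promote_precision\\(.*" -> "DECIMAL",
    ".*decimal\\([0-9]+,[0-9]+\\).*" -> "DECIMAL",
    ".*DecimalType\\([0-9]+,[0-9]+\\).*" -> "DECIMAL")
  private val UDFKeywords = Map(".*UDF.*" -> "UDF")

  def findPotentialIssues(desc: String): Set[String] = {
    val potentialIssuesRegexs = UDFKeywords ++ decimalKeyWords
    potentialIssuesRegexs.filterKeys(desc.matches(_)).values.toSet
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical plan node description with a decimal multiply and a UDF call
    val desc = "Project [CheckOverflow(promote_precision(value#0) * promote_precision(value#1), " +
      "DecimalType(9,4)) AS mult#2, UDF(value#0) AS udfcol#3]"
    println(findPotentialIssues(desc)) // Set(UDF, DECIMAL)
  }
}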
@@ -953,52 +953,6 @@ class ApplicationInfo(
|""".stripMargin
}

def qualificationDurationNoMetricsSQL: String = {
s"""select
|first(appName) as `App Name`,
|'$appId' as `App ID`,
|ROUND((sum(sqlQualDuration) * 100) / first(app.duration), 2) as Score,
|concat_ws(",", collect_set(problematic)) as `Potential Problems`,
|sum(sqlQualDuration) as `SQL Dataframe Duration`,
|first(app.duration) as `App Duration`,
|first(app.endDurationEstimated) as `App Duration Estimated`
|from sqlDF_$index sq, appdf_$index app
|where sq.sqlID not in ($sqlIdsForUnsuccessfulJobs)
|""".stripMargin
}

// only include jobs that are marked as succeeded
def qualificationDurationSQL: String = {
s"""select
|$index as appIndex,
|'$appId' as appID,
|app.appName,
|sq.sqlID, sq.description,
|sq.sqlQualDuration as dfDuration,
|app.duration as appDuration,
|app.endDurationEstimated as appEndDurationEstimated,
|problematic as potentialProblems,
|m.executorCPUTime,
|m.executorRunTime
|from sqlDF_$index sq, appdf_$index app
|left join sqlAggMetricsDF m on $index = m.appIndex and sq.sqlID = m.sqlID
|where sq.sqlID not in ($sqlIdsForUnsuccessfulJobs)
|""".stripMargin
}

def qualificationDurationSumSQL: String = {
s"""select first(appName) as `App Name`,
|'$appId' as `App ID`,
|ROUND((sum(dfDuration) * 100) / first(appDuration), 2) as Score,
|concat_ws(",", collect_set(potentialProblems)) as `Potential Problems`,
|sum(dfDuration) as `SQL Dataframe Duration`,
|first(appDuration) as `App Duration`,
|round(sum(executorCPUTime)/sum(executorRunTime)*100,2) as `Executor CPU Time Percent`,
|first(appEndDurationEstimated) as `App Duration Estimated`
|from (${qualificationDurationSQL.stripLineEnd})
|""".stripMargin
}

def profilingDurationSQL: String = {
s"""select
|$index as appIndex,
@@ -55,7 +55,7 @@ class QualAppInfo(
val jobIdToSqlID: HashMap[Int, Long] = HashMap.empty[Int, Long]
val sqlIDtoJobFailures: HashMap[Long, ArrayBuffer[Int]] = HashMap.empty[Long, ArrayBuffer[Int]]

val problematicSQL: ArrayBuffer[ProblematicSQLCase] = ArrayBuffer[ProblematicSQLCase]()
val sqlIDtoProblematic: HashMap[Long, Set[String]] = HashMap[Long, Set[String]]()

// SQL containing any Dataset operation
val sqlIDToDataSetCase: HashSet[Long] = HashSet[Long]()
@@ -118,6 +118,10 @@
}.values.sum
}

private def probNotDataset: HashMap[Long, Set[String]] = {
sqlIDtoProblematic.filterNot { case (sqlID, _) => sqlIDToDataSetCase.contains(sqlID) }
}

// The total task time for all tasks that ran during SQL dataframe
// operations. If the SQL contains a dataset, it isn't counted.
private def calculateTaskDataframeDuration: Long = {
@@ -128,12 +132,12 @@
}

private def getPotentialProblems: String = {
problematicSQL.map(_.reason).toSet.mkString(",")
probNotDataset.values.flatten.toSet.mkString(":")
}

private def getSQLDurationProblematic: Long = {
problematicSQL.map { prob =>
sqlDurationTime.getOrElse(prob.sqlID, 0L)
probNotDataset.keys.map { sqlId =>
sqlDurationTime.getOrElse(sqlId, 0L)
}.sum
}

@@ -219,8 +223,10 @@
if (isDataSetPlan(node.desc)) {
sqlIDToDataSetCase += sqlID
}
findPotentialIssues(node.desc).foreach { issues =>
problematicSQL += ProblematicSQLCase(sqlID, issues)
val issues = findPotentialIssues(node.desc)
if (issues.nonEmpty) {
val existingIssues = sqlIDtoProblematic.getOrElse(sqlID, Set.empty[String])
sqlIDtoProblematic(sqlID) = existingIssues ++ issues
}
}
}
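The QualAppInfo changes above replace the ProblematicSQLCase buffer with a per-SQL-ID map of issue sets. Here is a self-contained sketch, using invented SQL IDs and durations rather than the real class, of how that bookkeeping merges issues, excludes Dataset plans, and derives the problematic duration and the colon-separated problem string.

import scala.collection.mutable.{HashMap, HashSet}

object ProblematicTrackingDemo {
  val sqlIDtoProblematic: HashMap[Long, Set[String]] = HashMap[Long, Set[String]]()
  val sqlIDToDataSetCase: HashSet[Long] = HashSet[Long]()
  // Made-up per-SQL durations in milliseconds
  val sqlDurationTime: HashMap[Long, Long] = HashMap(0L -> 160L, 1L -> 40L, 2L -> 500L)

  def record(sqlID: Long, issues: Set[String]): Unit = {
    if (issues.nonEmpty) {
      val existing = sqlIDtoProblematic.getOrElse(sqlID, Set.empty[String])
      sqlIDtoProblematic(sqlID) = existing ++ issues
    }
  }

  def main(args: Array[String]): Unit = {
    record(0L, Set("DECIMAL"))
    record(0L, Set("UDF"))        // issues for the same SQL ID are merged into one set
    record(1L, Set("DECIMAL"))
    sqlIDToDataSetCase += 1L      // SQL 1 is a Dataset plan, so it is excluded below

    val probNotDataset = sqlIDtoProblematic.filterNot { case (id, _) => sqlIDToDataSetCase.contains(id) }
    println(probNotDataset.values.flatten.toSet.mkString(":"))             // DECIMAL:UDF (set ordering may vary)
    println(probNotDataset.keys.map(sqlDurationTime.getOrElse(_, 0L)).sum) // 160
  }
}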
@@ -110,8 +110,8 @@ class QualEventProcessor() extends EventProcessorBase {
}
app.sqlDurationTime += (event.executionId -> 0)
} else {
// if start time not there, use 0 for duration
val startTime = sqlInfo.map(_.startTime).getOrElse(0L)
// if start time not there, use event end time so duration is 0
val startTime = sqlInfo.map(_.startTime).getOrElse(event.time)
val sqlDuration = event.time - startTime
app.sqlDurationTime += (event.executionId -> sqlDuration)
}
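The QualEventProcessor tweak above is easiest to see with numbers. A tiny sketch, with hypothetical timestamps, of why falling back to the event time rather than 0 keeps a missing start time from turning the SQL duration into a full epoch timestamp:

object SqlDurationFallbackDemo {
  def sqlDuration(startTimeOpt: Option[Long], eventTime: Long): Long = {
    // new behavior: a missing start time falls back to the event time, so the duration is 0
    val startTime = startTimeOpt.getOrElse(eventTime)
    eventTime - startTime
  }

  def main(args: Array[String]): Unit = {
    println(sqlDuration(Some(1626104300000L), 1626104302429L)) // 2429
    println(sqlDuration(None, 1626104302429L))                 // 0, where the old getOrElse(0L) reported 1626104302429
  }
}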
@@ -1,2 +1,2 @@
App Name,App ID,Score,Potential Problems,SQL DF Duration,SQL Dataframe Task Duration,App Duration,Executor CPU Time Percent,App Duration Estimated,SQL Duration with Potential Problems,SQL Ids with Failures,Read Score Percent,Read File Format Score,Unsupported Read File Formats and Types
-Spark shell,local-1626104300434,1211.93,"",2429,1469,131104,88.35,false,0,"",20,12.5,Parquet[decimal];ORC[map:array:struct:decimal]
+Spark shell,local-1626104300434,1211.93,"DECIMAL",2429,1469,131104,88.35,false,160,"",20,12.5,Parquet[decimal];ORC[map:array:struct:decimal]
@@ -0,0 +1,2 @@
App Name,App ID,Score,Potential Problems,SQL DF Duration,SQL Dataframe Task Duration,App Duration,Executor CPU Time Percent,App Duration Estimated,SQL Duration with Potential Problems,SQL Ids with Failures,Read Score Percent,Read File Format Score,Unsupported Read File Formats and Types
Spark shell,local-1626189209260,1052.3,DECIMAL,1314,1238,106033,57.21,false,1023,"",20,25.0,Parquet[decimal]
Binary file not shown.
@@ -28,6 +28,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, SparkSession, TrampolineUtil}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.rapids.tool.{AppFilterImpl, ToolUtils}
import org.apache.spark.sql.rapids.tool.qualification.QualificationSummaryInfo
import org.apache.spark.sql.types._
@@ -316,6 +317,92 @@
runQualificationTest(logFiles, "nds_q86_fail_test_expectation.csv")
}

// this event log has both decimal and non-decimal operations, so it comes out as partial
// it includes reading decimal columns, a decimal multiplication, and a join on a decimal column
test("test decimal problematic") {
val logFiles = Array(s"$logDir/decimal_part_eventlog.zstd")
runQualificationTest(logFiles, "decimal_part_expectation.csv")
}

private def createDecFile(spark: SparkSession, dir: String): Unit = {
import spark.implicits._
val dfGen = Seq("1.32").toDF("value")
.selectExpr("CAST(value AS DECIMAL(4, 2)) AS value")
dfGen.write.parquet(dir)
}

test("test decimal generate udf same") {
TrampolineUtil.withTempDir { outpath =>

TrampolineUtil.withTempDir { eventLogDir =>
val tmpParquet = s"$outpath/decparquet"
createDecFile(sparkSession, tmpParquet)

val eventLog = ToolTestUtils.generateEventLog(eventLogDir, "dot") { spark =>
val plusOne = udf((x: Int) => x + 1)
import spark.implicits._
spark.udf.register("plusOne", plusOne)
val df = spark.read.parquet(tmpParquet)
val df2 = df.withColumn("mult", $"value" * $"value")
val df4 = df2.withColumn("udfcol", plusOne($"value"))
df4
}

val allArgs = Array(
"--output-directory",
outpath.getAbsolutePath())
val appArgs = new QualificationArgs(allArgs ++ Array(eventLog))
val (exit, appSum) = QualificationMain.mainInternal(appArgs)
assert(exit == 0)
assert(appSum.size == 1)
val probApp = appSum.head
assert(probApp.potentialProblems.contains("UDF") &&
probApp.potentialProblems.contains("DECIMAL"))
assert(probApp.sqlDataFrameDuration == probApp.sqlDurationForProblematic)
}
}
}

test("test decimal generate udf different sql ops") {
TrampolineUtil.withTempDir { outpath =>

TrampolineUtil.withTempDir { eventLogDir =>
val tmpParquet = s"$outpath/decparquet"
createDecFile(sparkSession, tmpParquet)

val eventLog = ToolTestUtils.generateEventLog(eventLogDir, "dot") { spark =>
val plusOne = udf((x: Int) => x + 1)
import spark.implicits._
spark.udf.register("plusOne", plusOne)
val df = spark.read.parquet(tmpParquet)
val df2 = df.withColumn("mult", $"value" * $"value")
// first run sql op with decimal only
df2.collect()
// run a separate sql op using just udf
spark.sql("SELECT plusOne(5)").collect()
// then run another sql op that uses neither decimal nor udf
import spark.implicits._
val t1 = Seq((1, 2), (3, 4)).toDF("a", "b")
t1.createOrReplaceTempView("t1")
spark.sql("SELECT a, MAX(b) FROM t1 GROUP BY a ORDER BY a")
}

val allArgs = Array(
"--output-directory",
outpath.getAbsolutePath())
val appArgs = new QualificationArgs(allArgs ++ Array(eventLog))
val (exit, appSum) = QualificationMain.mainInternal(appArgs)
assert(exit == 0)
assert(appSum.size == 1)
val probApp = appSum.head
assert(probApp.potentialProblems.contains("UDF") &&
probApp.potentialProblems.contains("DECIMAL"))
assert(probApp.sqlDurationForProblematic > 0)
assert(probApp.sqlDataFrameDuration > probApp.sqlDurationForProblematic)
}
}
}

test("test read datasource v1") {
val profileLogDir = ToolTestUtils.getTestResourcePath("spark-events-profiling")
val logFiles = Array(s"$profileLogDir/eventlog_dsv1.zstd")
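Outside the test harness, the same qualification run can be driven programmatically. The sketch below reuses only the pieces exercised by the tests above (the --output-directory argument, QualificationArgs, and QualificationMain.mainInternal). The package in the import is an assumption, since the test file's import for these classes is truncated above, and both paths are placeholders.

// Assumed package for the tool classes; both paths are placeholders.
import com.nvidia.spark.rapids.tool.qualification.{QualificationArgs, QualificationMain}

object RunQualification {
  def main(args: Array[String]): Unit = {
    val eventLog = "/path/to/eventlog"   // placeholder: an application event log
    val outputDir = "/path/to/output"    // placeholder: where the CSV report should go
    val appArgs = new QualificationArgs(Array("--output-directory", outputDir, eventLog))
    val (exitCode, summaries) = QualificationMain.mainInternal(appArgs)
    assert(exitCode == 0)
    summaries.foreach { app =>
      // potentialProblems can include UDF and DECIMAL, as the tests above exercise
      println(s"problems=${app.potentialProblems} dfDuration=${app.sqlDataFrameDuration} " +
        s"problematicDuration=${app.sqlDurationForProblematic}")
    }
  }
}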
