Detect task failures in benchmarks (NVIDIA#1750)
andygrove authored Feb 19, 2021
1 parent 4b55183 commit 5758be5
Showing 2 changed files with 70 additions and 60 deletions.
@@ -27,7 +27,8 @@ import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse
import org.json4s.jackson.Serialization.writePretty

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, Success, TaskEndReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.execution.{InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
@@ -38,6 +39,10 @@ import org.apache.spark.sql.util.QueryExecutionListener

object BenchUtils {

val STATUS_COMPLETED = "Completed"
val STATUS_COMPLETED_WITH_TASK_FAILURES = "CompletedWithTaskFailures"
val STATUS_FAILED = "Failed"

/** Perform benchmark of calling collect */
def collect(
spark: SparkSession,
@@ -166,6 +171,7 @@ object BenchUtils {
val exceptions = new ListBuffer[String]()

var df: DataFrame = null
val queryStatus = new ListBuffer[String]()
val queryTimes = new ListBuffer[Long]()
for (i <- 0 until iterations) {
spark.sparkContext.setJobDescription(s"Benchmark Run: query=$queryDescription; iteration=$i")
@@ -186,7 +192,9 @@

println(s"$logPrefix Start iteration $i:")
val start = System.nanoTime()
val taskFailureListener = new TaskFailureListener
try {
spark.sparkContext.addSparkListener(taskFailureListener)
df = createDataFrame(spark)

resultsAction match {
@@ -202,25 +210,35 @@
val end = System.nanoTime()
val elapsed = NANOSECONDS.toMillis(end - start)
queryTimes.append(elapsed)
println(s"$logPrefix Iteration $i took $elapsed msec.")

val failureOpt = taskFailureListener.taskFailures.headOption
val status = failureOpt.map(_ => STATUS_COMPLETED_WITH_TASK_FAILURES)
.getOrElse(STATUS_COMPLETED)
failureOpt.foreach(failure => exceptions.append(failure.toString))

queryStatus.append(status)
println(s"$logPrefix Iteration $i took $elapsed msec. Status: $status.")

} catch {
case e: Exception =>
val end = System.nanoTime()
val elapsed = NANOSECONDS.toMillis(end - start)
println(s"$logPrefix Iteration $i failed after $elapsed msec.")
queryTimes.append(-1)
queryStatus.append(STATUS_FAILED)
queryTimes.append(elapsed)
exceptions.append(BenchUtils.stackTraceAsString(e))
e.printStackTrace()
} finally {
spark.sparkContext.removeSparkListener(taskFailureListener)
}
}

// only show query times if there were no failed queries
if (!queryTimes.contains(-1)) {
if (!queryStatus.contains(STATUS_FAILED)) {

// summarize all query times
for (i <- 0 until iterations) {
println(s"$logPrefix Iteration $i took ${queryTimes(i)} msec.")
println(s"$logPrefix Iteration $i took ${queryTimes(i)} msec. Status: ${queryStatus(i)}")
}

// for multiple runs, summarize cold/hot timings
@@ -236,8 +254,7 @@
}

// write results to file
val suffix = if (exceptions.isEmpty) "" else "-failed"
val filename = s"$filenameStub-${queryStartTime.toEpochMilli}$suffix.json"
val filename = s"$filenameStub-${queryStartTime.toEpochMilli}.json"
println(s"$logPrefix Saving benchmark report to $filename")

// try not to leak secrets
@@ -269,66 +286,43 @@
executedPlanStr
)

val report = resultsAction match {
case Collect() => BenchmarkReport(
filename,
queryStartTime.toEpochMilli,
environment,
testConfiguration,
"collect",
Map.empty,
queryDescription,
queryPlan,
queryPlansWithMetrics,
queryTimes,
exceptions)

case w: WriteCsv => BenchmarkReport(
filename,
queryStartTime.toEpochMilli,
environment,
testConfiguration,
"csv",
w.writeOptions,
queryDescription,
queryPlan,
queryPlansWithMetrics,
queryTimes,
exceptions)

case w: WriteOrc => BenchmarkReport(
filename,
queryStartTime.toEpochMilli,
environment,
testConfiguration,
"orc",
w.writeOptions,
queryDescription,
queryPlan,
queryPlansWithMetrics,
queryTimes,
exceptions)

case w: WriteParquet => BenchmarkReport(
filename,
queryStartTime.toEpochMilli,
environment,
testConfiguration,
"parquet",
w.writeOptions,
queryDescription,
queryPlan,
queryPlansWithMetrics,
queryTimes,
exceptions)
var report = BenchmarkReport(
filename,
queryStartTime.toEpochMilli,
environment,
testConfiguration,
"",
Map.empty,
queryDescription,
queryPlan,
queryPlansWithMetrics,
queryTimes,
queryStatus,
exceptions)

report = resultsAction match {
case Collect() => report.copy(
action = "collect")

case w: WriteCsv => report.copy(
action = "csv",
writeOptions = w.writeOptions)

case w: WriteOrc => report.copy(
action = "orc",
writeOptions = w.writeOptions)

case w: WriteParquet => report.copy(
action = "parquet",
writeOptions = w.writeOptions)
}

writeReport(report, filename)

if (generateDotGraph) {
queryPlansWithMetrics.headOption match {
case Some(plan) =>
val filename = s"$filenameStub-${queryStartTime.toEpochMilli}$suffix.dot"
val filename = s"$filenameStub-${queryStartTime.toEpochMilli}.dot"
println(s"$logPrefix Saving query plan diagram to $filename")
BenchUtils.generateDotGraph(plan, None, filename)
case _ =>
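
The rewritten block above replaces four near-identical `BenchmarkReport` constructions with a single base report that each `resultsAction` branch specializes via `copy`. A minimal sketch of that pattern, using hypothetical trimmed-down types rather than the project's actual classes:

object ReportCopyDemo {

  // Hypothetical, trimmed-down types; the real BenchmarkReport and actions carry more fields.
  case class MiniReport(
      action: String,
      writeOptions: Map[String, String],
      queryTimes: Seq[Long],
      queryStatus: Seq[String])

  sealed trait ResultsAction
  case object Collect extends ResultsAction
  case class WriteParquet(writeOptions: Map[String, String]) extends ResultsAction

  def buildReport(
      action: ResultsAction,
      queryTimes: Seq[Long],
      queryStatus: Seq[String]): MiniReport = {
    // Populate the fields shared by every action once ...
    val base = MiniReport("", Map.empty, queryTimes, queryStatus)
    // ... then override only what differs per action.
    action match {
      case Collect => base.copy(action = "collect")
      case w: WriteParquet => base.copy(action = "parquet", writeOptions = w.writeOptions)
    }
  }

  def main(args: Array[String]): Unit = {
    val report = buildReport(
      WriteParquet(Map("compression" -> "snappy")),
      queryTimes = Seq(1234L, 987L),
      queryStatus = Seq("Completed", "Completed"))
    println(report)
  }
}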
@@ -723,6 +717,20 @@
}
}

class TaskFailureListener extends SparkListener {

val taskFailures = new ListBuffer[TaskEndReason]()

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskEnd.reason match {
case Success =>
case reason => taskFailures += reason
}
super.onTaskEnd(taskEnd)
}

}
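
The `TaskFailureListener` above records the `TaskEndReason` of every task that did not end with `Success`. A minimal, self-contained sketch of the same register-run-check pattern used in the benchmark loop; the demo query, object name, and local-mode session are illustrative assumptions, not part of the patch:

import scala.collection.mutable.ListBuffer

import org.apache.spark.{Success, TaskEndReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

object TaskFailureListenerDemo {

  // Same shape as the listener in the patch: keep the reason for every failed task.
  class FailureListener extends SparkListener {
    val taskFailures = new ListBuffer[TaskEndReason]()
    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
      case Success =>
      case reason => taskFailures += reason
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("task-failure-listener-demo")
      .getOrCreate()

    val listener = new FailureListener
    spark.sparkContext.addSparkListener(listener)
    try {
      // Stand-in for the benchmark query.
      spark.range(0, 1000).selectExpr("sum(id)").collect()
    } finally {
      // Detach the listener so failures from later queries are not attributed to this run.
      spark.sparkContext.removeSparkListener(listener)
    }

    val status =
      if (listener.taskFailures.isEmpty) "Completed" else "CompletedWithTaskFailures"
    println(s"Query status: $status")

    spark.stop()
  }
}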

class BenchmarkListener(
queryPlans: ListBuffer[SparkPlanNode],
exceptions: ListBuffer[String]) extends QueryExecutionListener {
@@ -789,6 +797,7 @@ case class BenchmarkReport(
queryPlan: QueryPlan,
queryPlans: Seq[SparkPlanNode],
queryTimes: Seq[Long],
queryStatus: Seq[String],
exceptions: Seq[String])

/** Configuration options that affect how the tests are run */
@@ -53,6 +53,7 @@ class BenchUtilsSuite extends FunSuite with BeforeAndAfterEach {
queryPlan = QueryPlan("logical", "physical"),
Seq.empty,
queryTimes = Seq(99, 88, 77),
queryStatus = Seq("Completed", "Completed", "Completed"),
exceptions = Seq.empty)

val filename = s"$TEST_FILES_ROOT/BenchUtilsSuite-${System.currentTimeMillis()}.json"