diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
index 5048ee29f29..c79bd573260 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
@@ -27,7 +27,8 @@ import org.json4s.DefaultFormats
 import org.json4s.jackson.JsonMethods.parse
 import org.json4s.jackson.Serialization.writePretty
 
-import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
+import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, Success, TaskEndReason}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.execution.{InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
@@ -38,6 +39,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
 
 object BenchUtils {
 
+  val STATUS_COMPLETED = "Completed"
+  val STATUS_COMPLETED_WITH_TASK_FAILURES = "CompletedWithTaskFailures"
+  val STATUS_FAILED = "Failed"
+
   /** Perform benchmark of calling collect */
   def collect(
       spark: SparkSession,
@@ -166,6 +171,7 @@ object BenchUtils {
     val exceptions = new ListBuffer[String]()
 
     var df: DataFrame = null
+    val queryStatus = new ListBuffer[String]()
     val queryTimes = new ListBuffer[Long]()
     for (i <- 0 until iterations) {
       spark.sparkContext.setJobDescription(s"Benchmark Run: query=$queryDescription; iteration=$i")
@@ -186,7 +192,9 @@ object BenchUtils {
       println(s"$logPrefix Start iteration $i:")
       val start = System.nanoTime()
+      val taskFailureListener = new TaskFailureListener
       try {
+        spark.sparkContext.addSparkListener(taskFailureListener)
         df = createDataFrame(spark)
 
         resultsAction match {
@@ -202,25 +210,35 @@ object BenchUtils {
         val end = System.nanoTime()
         val elapsed = NANOSECONDS.toMillis(end - start)
         queryTimes.append(elapsed)
-        println(s"$logPrefix Iteration $i took $elapsed msec.")
+
+        val failureOpt = taskFailureListener.taskFailures.headOption
+        val status = failureOpt.map(_ => STATUS_COMPLETED_WITH_TASK_FAILURES)
+          .getOrElse(STATUS_COMPLETED)
+        failureOpt.foreach(failure => exceptions.append(failure.toString))
+
+        queryStatus.append(status)
+        println(s"$logPrefix Iteration $i took $elapsed msec. Status: $status.")
       } catch {
         case e: Exception =>
           val end = System.nanoTime()
           val elapsed = NANOSECONDS.toMillis(end - start)
           println(s"$logPrefix Iteration $i failed after $elapsed msec.")
-          queryTimes.append(-1)
+          queryStatus.append(STATUS_FAILED)
+          queryTimes.append(elapsed)
           exceptions.append(BenchUtils.stackTraceAsString(e))
           e.printStackTrace()
+      } finally {
+        spark.sparkContext.removeSparkListener(taskFailureListener)
       }
     }
 
     // only show query times if there were no failed queries
-    if (!queryTimes.contains(-1)) {
+    if (!queryStatus.contains(STATUS_FAILED)) {
 
       // summarize all query times
       for (i <- 0 until iterations) {
-        println(s"$logPrefix Iteration $i took ${queryTimes(i)} msec.")
+        println(s"$logPrefix Iteration $i took ${queryTimes(i)} msec. Status: ${queryStatus(i)}")
       }
 
       // for multiple runs, summarize cold/hot timings
@@ -236,8 +254,7 @@ object BenchUtils {
     }
 
     // write results to file
-    val suffix = if (exceptions.isEmpty) "" else "-failed"
-    val filename = s"$filenameStub-${queryStartTime.toEpochMilli}$suffix.json"
+    val filename = s"$filenameStub-${queryStartTime.toEpochMilli}.json"
     println(s"$logPrefix Saving benchmark report to $filename")
 
     // try not to leak secrets
@@ -269,58 +286,35 @@ object BenchUtils {
       executedPlanStr
     )
 
-    val report = resultsAction match {
-      case Collect() => BenchmarkReport(
-        filename,
-        queryStartTime.toEpochMilli,
-        environment,
-        testConfiguration,
-        "collect",
-        Map.empty,
-        queryDescription,
-        queryPlan,
-        queryPlansWithMetrics,
-        queryTimes,
-        exceptions)
-
-      case w: WriteCsv => BenchmarkReport(
-        filename,
-        queryStartTime.toEpochMilli,
-        environment,
-        testConfiguration,
-        "csv",
-        w.writeOptions,
-        queryDescription,
-        queryPlan,
-        queryPlansWithMetrics,
-        queryTimes,
-        exceptions)
-
-      case w: WriteOrc => BenchmarkReport(
-        filename,
-        queryStartTime.toEpochMilli,
-        environment,
-        testConfiguration,
-        "orc",
-        w.writeOptions,
-        queryDescription,
-        queryPlan,
-        queryPlansWithMetrics,
-        queryTimes,
-        exceptions)
-
-      case w: WriteParquet => BenchmarkReport(
-        filename,
-        queryStartTime.toEpochMilli,
-        environment,
-        testConfiguration,
-        "parquet",
-        w.writeOptions,
-        queryDescription,
-        queryPlan,
-        queryPlansWithMetrics,
-        queryTimes,
-        exceptions)
+    var report = BenchmarkReport(
+      filename,
+      queryStartTime.toEpochMilli,
+      environment,
+      testConfiguration,
+      "",
+      Map.empty,
+      queryDescription,
+      queryPlan,
+      queryPlansWithMetrics,
+      queryTimes,
+      queryStatus,
+      exceptions)
+
+    report = resultsAction match {
+      case Collect() => report.copy(
+        action = "collect")
+
+      case w: WriteCsv => report.copy(
+        action = "csv",
+        writeOptions = w.writeOptions)
+
+      case w: WriteOrc => report.copy(
+        action = "orc",
+        writeOptions = w.writeOptions)
+
+      case w: WriteParquet => report.copy(
+        action = "parquet",
+        writeOptions = w.writeOptions)
     }
 
     writeReport(report, filename)
@@ -328,7 +322,7 @@ object BenchUtils {
     if (generateDotGraph) {
       queryPlansWithMetrics.headOption match {
         case Some(plan) =>
-          val filename = s"$filenameStub-${queryStartTime.toEpochMilli}$suffix.dot"
+          val filename = s"$filenameStub-${queryStartTime.toEpochMilli}.dot"
           println(s"$logPrefix Saving query plan diagram to $filename")
           BenchUtils.generateDotGraph(plan, None, filename)
         case _ =>
@@ -723,6 +717,20 @@ object BenchUtils {
   }
 }
 
+class TaskFailureListener extends SparkListener {
+
+  val taskFailures = new ListBuffer[TaskEndReason]()
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    taskEnd.reason match {
+      case Success =>
+      case reason => taskFailures += reason
+    }
+    super.onTaskEnd(taskEnd)
+  }
+
+}
+
 class BenchmarkListener(
     queryPlans: ListBuffer[SparkPlanNode],
     exceptions: ListBuffer[String]) extends QueryExecutionListener {
@@ -789,6 +797,7 @@ case class BenchmarkReport(
     queryPlan: QueryPlan,
     queryPlans: Seq[SparkPlanNode],
     queryTimes: Seq[Long],
+    queryStatus: Seq[String],
    exceptions: Seq[String])
 
 /** Configuration options that affect how the tests are run */
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
index a0a33de21d1..866d1213ef5 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
@@ -53,6 +53,7 @@ class BenchUtilsSuite extends FunSuite with BeforeAndAfterEach {
       queryPlan = QueryPlan("logical", "physical"),
       Seq.empty,
       queryTimes = Seq(99, 88, 77),
+      queryStatus = Seq("Completed", "Completed", "Completed"),
       exceptions = Seq.empty)
 
     val filename = s"$TEST_FILES_ROOT/BenchUtilsSuite-${System.currentTimeMillis()}.json"
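Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the pattern this diff adds to runBench: register a SparkListener that records non-Success TaskEndReason values around a job, detach it in a finally block, and use the recorded failures to choose between the "Completed" and "CompletedWithTaskFailures" statuses. The local SparkSession, the example query, and the TaskFailureListenerExample object name are assumptions made only for this example.

```scala
import scala.collection.mutable.ListBuffer

import org.apache.spark.{Success, TaskEndReason}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

// Same idea as the TaskFailureListener added in this patch: keep the
// TaskEndReason of every task that did not end with Success.
class TaskFailureListener extends SparkListener {
  val taskFailures = new ListBuffer[TaskEndReason]()

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case Success =>
      case reason => taskFailures += reason
    }
  }
}

object TaskFailureListenerExample {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("task-failure-listener-example")
      .getOrCreate()

    val listener = new TaskFailureListener
    spark.sparkContext.addSparkListener(listener)
    try {
      // Any job would do; this one just forces a collect.
      spark.range(0, 1000).selectExpr("sum(id)").collect()
    } finally {
      // Always detach the listener so repeated runs do not accumulate state.
      spark.sparkContext.removeSparkListener(listener)
    }

    // Mirror the status decision in runBench: task failures that did not fail
    // the whole query are reported as CompletedWithTaskFailures.
    val status =
      if (listener.taskFailures.isEmpty) "Completed" else "CompletedWithTaskFailures"
    println(s"Query status: $status")

    spark.stop()
  }
}
```

One caveat worth noting: listener events are delivered asynchronously on Spark's listener bus, so a status read immediately after a very short job is best-effort; the benchmark code tolerates this because the status is informational rather than used for control flow.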