diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala index a0ca80149b2..dfe80e071a6 100644 --- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala +++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala @@ -197,7 +197,7 @@ object BenchUtils { val elapsed = NANOSECONDS.toMillis(end - start) println(s"*** Iteration $i failed after $elapsed msec.") queryTimes.append(-1) - exceptions.append(BenchUtils.toString(e)) + exceptions.append(BenchUtils.stackTraceAsString(e)) e.printStackTrace() } } @@ -241,9 +241,19 @@ object BenchUtils { sparkConf = df.sparkSession.conf.getAll, getSparkVersion) + // if the query plan is invalid, referencing the `executedPlan` lazy val + // can throw an exception + val executedPlanStr = try { + df.queryExecution.executedPlan.toString() + } catch { + case e: Exception => + exceptions.append(stackTraceAsString(e)) + "Failed to capture executedPlan - see exceptions in report" + } + val queryPlan = QueryPlan( df.queryExecution.logical.toString(), - df.queryExecution.executedPlan.toString() + executedPlanStr ) val report = resultsAction match { @@ -662,7 +672,7 @@ object BenchUtils { } } - def toString(e: Exception): String = { + def stackTraceAsString(e: Throwable): String = { val sw = new StringWriter() val w = new PrintWriter(sw) e.printStackTrace(w) @@ -676,12 +686,21 @@ class BenchmarkListener( exceptions: ListBuffer[String]) extends QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - queryPlans += toJson(qe.executedPlan) + addQueryPlan(qe) } override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { - queryPlans += toJson(qe.executedPlan) - exceptions += BenchUtils.toString(exception) + addQueryPlan(qe) + exceptions += BenchUtils.stackTraceAsString(exception) + } + + private def addQueryPlan(qe: QueryExecution) = { + try { + queryPlans += toJson(qe.executedPlan) + } catch { + case e: Exception => + exceptions.append(BenchUtils.stackTraceAsString(e)) + } } private def toJson(plan: SparkPlan): SparkPlanNode = {