Implement test for qualification tool sql metric aggregates (#2591)
* Fix package name

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Rename to ToolTestUtils

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Implement basic structure for test

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Test checks that number of tasks is correct

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add more tests

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add cpuTime test, commented out

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Change executorCpuTime logic to match qualification tool

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* use dedicated spark session in test

* create new SparkSession before each test

Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Jun 5, 2021
1 parent e65e826 commit 35f910d
Showing 4 changed files with 174 additions and 87 deletions.
@@ -16,17 +16,13 @@

package com.nvidia.spark.rapids.tool.qualification

import java.io.FileWriter

import scala.collection.mutable.ArrayBuffer

import com.nvidia.spark.rapids.tool.profiling._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.rapids.tool.profiling._

/**
* A tool to analyze Spark event logs and determine if
@@ -16,14 +16,13 @@

package com.nvidia.spark.rapids.tool

import java.io.File

import scala.collection.mutable.ArrayBuffer
import java.io.{File, FilenameFilter, FileNotFoundException}

import com.nvidia.spark.rapids.tool.profiling.{ProfileArgs, ProfileUtils}
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.{DataFrame, SparkSession, TrampolineUtil}
import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo

object ToolTestUtils extends Logging {
@@ -36,6 +35,42 @@ object ToolTestUtils extends Logging {
getTestResourceFile(file).getCanonicalPath
}

def generateEventLog(eventLogDir: File, appName: String)
(fun: SparkSession => DataFrame): String = {

// we need to close any existing sessions to ensure that we can
// create a session with a new event log dir
TrampolineUtil.cleanupAnyExistingSession()

lazy val spark = SparkSession
.builder()
.master("local[*]")
.appName(appName)
.config("spark.eventLog.enabled", "true")
.config("spark.eventLog.dir", eventLogDir.getAbsolutePath)
.getOrCreate()

// execute the query and generate events
val df = fun(spark)
df.collect()

// close the event log
spark.close()

// find the event log
val files = listFilesMatching(eventLogDir, !_.startsWith("."))
if (files.length != 1) {
throw new FileNotFoundException(s"Could not find event log in ${eventLogDir.getAbsolutePath}")
}
files.head.getAbsolutePath
}

def listFilesMatching(dir: File, matcher: String => Boolean): Array[File] = {
dir.listFiles(new FilenameFilter {
override def accept(file: File, s: String): Boolean = matcher(s)
})
}

def compareDataFrames(df: DataFrame, expectedDf: DataFrame): Unit = {
val diffCount = df.except(expectedDf).union(expectedDf.except(df)).count
if (diffCount != 0) {
@@ -66,3 +101,4 @@ object ToolTestUtils extends Logging {
apps
}
}
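For orientation, here is a minimal sketch of how a suite is expected to use the new helper (the application name and query are illustrative placeholders; the pattern mirrors the refactored suites below). generateEventLog spins up a local SparkSession with event logging pointed at the given directory, runs the supplied query, closes the session, and returns the path of the single event-log file it wrote.

import com.nvidia.spark.rapids.tool.ToolTestUtils
import org.apache.spark.sql.TrampolineUtil

TrampolineUtil.withTempDir { eventLogDir =>
  // "example-app" and the two-column DataFrame are placeholder values
  val eventLog = ToolTestUtils.generateEventLog(eventLogDir, "example-app") { spark =>
    import spark.implicits._
    Seq((1, 2), (3, 4)).toDF("a", "b")
  }
  // eventLog can now be handed to the profiling or qualification tool under test
}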

@@ -15,11 +15,9 @@
*/
package com.nvidia.spark.rapids.tool.profiling

import java.io.{File, FilenameFilter}

import scala.io.Source

import com.google.common.io.Files
import com.nvidia.spark.rapids.tool.ToolTestUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.internal.Logging
@@ -32,76 +30,50 @@ class GenerateDotSuite extends FunSuite with BeforeAndAfterAll with Logging {
}

test("Generate DOT") {
val eventLogDir = Files.createTempDir()
eventLogDir.deleteOnExit()

lazy val spark = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.config("spark.eventLog.enabled", "true")
.config("spark.eventLog.dir", eventLogDir.getAbsolutePath)
.getOrCreate()

// generate some events
import spark.implicits._
val t1 = Seq((1, 2), (3, 4)).toDF("a", "b")
t1.createOrReplaceTempView("t1")
val df = spark.sql("SELECT a, MAX(b) FROM t1 GROUP BY a ORDER BY a")
df.collect()

// close the event log
spark.close()

// find the event log
val files = listFilesMatching(eventLogDir, !_.startsWith("."))
assert(files.length === 1)
val eventLog = files.head.getAbsolutePath

// create new session for tool to use
val spark2 = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.getOrCreate()

val dotFileDir = Files.createTempDir()
dotFileDir.deleteOnExit()

val appArgs = new ProfileArgs(Array(
"--output-directory",
dotFileDir.getAbsolutePath,
"--generate-dot",
eventLog
))
ProfileMain.mainInternal(spark2, appArgs)

// assert that a file was generated
val dotDirs = listFilesMatching(dotFileDir, _.startsWith("local"))
assert(dotDirs.length === 2)
TrampolineUtil.withTempDir { eventLogDir =>
val eventLog = ToolTestUtils.generateEventLog(eventLogDir, "dot") { spark =>
import spark.implicits._
val t1 = Seq((1, 2), (3, 4)).toDF("a", "b")
t1.createOrReplaceTempView("t1")
spark.sql("SELECT a, MAX(b) FROM t1 GROUP BY a ORDER BY a")
}

// assert that the generated files looks something like what we expect
var hashAggCount = 0
for (dir <- dotDirs) {
val dotFiles = listFilesMatching(dir, _.endsWith(".dot"))
assert(dotFiles.length === 1)
val source = Source.fromFile(dotFiles.head)
try {
val lines = source.getLines().toArray
assert(lines.head === "digraph G {")
assert(lines.last === "}")
hashAggCount += lines.count(_.contains("HashAggregate"))
} finally {
source.close()
// create new session for tool to use
val spark2 = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.getOrCreate()

TrampolineUtil.withTempDir { dotFileDir =>
val appArgs = new ProfileArgs(Array(
"--output-directory",
dotFileDir.getAbsolutePath,
"--generate-dot",
eventLog))
ProfileMain.mainInternal(spark2, appArgs)

// assert that a file was generated
val dotDirs = ToolTestUtils.listFilesMatching(dotFileDir, _.startsWith("local"))
assert(dotDirs.length === 2)

// assert that the generated files looks something like what we expect
var hashAggCount = 0
for (dir <- dotDirs) {
val dotFiles = ToolTestUtils.listFilesMatching(dir, _.endsWith(".dot"))
assert(dotFiles.length === 1)
val source = Source.fromFile(dotFiles.head)
try {
val lines = source.getLines().toArray
assert(lines.head === "digraph G {")
assert(lines.last === "}")
hashAggCount += lines.count(_.contains("HashAggregate"))
} finally {
source.close()
}
}
assert(hashAggCount === 2)
}
}
assert(hashAggCount === 2)
}

private def listFilesMatching(dir: File, matcher: String => Boolean): Array[File] = {
dir.listFiles(new FilenameFilter {
override def accept(file: File, s: String): Boolean = matcher(s)
})
}

}
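Both temp directories in the refactored test come from TrampolineUtil.withTempDir, which replaces the earlier Files.createTempDir()/deleteOnExit() pairing. As a rough, assumed sketch of what a helper of that shape does (not the actual Spark implementation):

// Hypothetical shape of a withTempDir-style helper: create a temp directory,
// run the test body against it, and always delete it afterwards.
def withTempDir(body: java.io.File => Unit): Unit = {
  def deleteRecursively(f: java.io.File): Unit = {
    Option(f.listFiles()).getOrElse(Array.empty[java.io.File]).foreach(deleteRecursively)
    f.delete()
  }
  val dir = java.nio.file.Files.createTempDirectory("tool-test").toFile
  try body(dir) finally deleteRecursively(dir)
}

Unlike deleteOnExit(), this removes the directory as soon as the body finishes, so repeated test runs do not accumulate temp files.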
@@ -19,24 +19,29 @@ package com.nvidia.spark.rapids.tool.qualification
import java.io.File

import com.nvidia.spark.rapids.tool.ToolTestUtils
import org.scalatest.FunSuite
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import scala.collection.mutable.ListBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd}
import org.apache.spark.sql.{SparkSession, TrampolineUtil}

class QualificationSuite extends FunSuite with Logging {
class QualificationSuite extends FunSuite with BeforeAndAfterEach with Logging {

lazy val sparkSession = {
SparkSession
private var sparkSession: SparkSession = _

private val expRoot = ToolTestUtils.getTestResourceFile("QualificationExpectations")
private val logDir = ToolTestUtils.getTestResourcePath("spark-events-qualification")

override protected def beforeEach(): Unit = {
TrampolineUtil.cleanupAnyExistingSession()
sparkSession = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.getOrCreate()
}

private val expRoot = ToolTestUtils.getTestResourceFile("QualificationExpectations")
private val logDir = ToolTestUtils.getTestResourcePath("spark-events-qualification")

private def runQualificationTest(eventLogs: Array[String], expectFileName: String) = {
Seq(true, false).foreach { hasExecCpu =>
TrampolineUtil.withTempDir { outpath =>
@@ -87,4 +92,82 @@ class QualificationSuite extends FunSuite with Logging {
val logFiles = Array(s"$logDir/nds_q86_test")
runQualificationTest(logFiles, "nds_q86_test_expectation.csv")
}

test("sql metric agg") {
TrampolineUtil.withTempDir { eventLogDir =>
val listener = new ToolTestListener
val eventLog = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark =>
spark.sparkContext.addSparkListener(listener)
import spark.implicits._
val testData = Seq((1, 2), (3, 4)).toDF("a", "b")
testData.createOrReplaceTempView("t1")
testData.createOrReplaceTempView("t2")
spark.sql("SELECT a, MAX(b) FROM (SELECT t1.a, t2.b " +
"FROM t1 JOIN t2 ON t1.a = t2.a) AS t " +
"GROUP BY a ORDER BY a")
}
assert(listener.completedStages.length == 5)

// run the qualification tool
TrampolineUtil.withTempDir { outpath =>

// create new session for tool to use
val spark2 = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.getOrCreate()

val appArgs = new QualificationArgs(Array(
"--include-exec-cpu-percent",
"--output-directory",
outpath.getAbsolutePath,
eventLog))

val (exit, _) =
QualificationMain.mainInternal(spark2, appArgs, writeOutput = false,
dropTempViews = false)
assert(exit == 0)

val df = spark2.table("sqlAggMetricsDF")

def fieldIndex(name: String) = df.schema.fieldIndex(name)

val rows = df.collect()
assert(rows.length === 1)
val collect = rows.head
assert(collect.getString(fieldIndex("description")).startsWith("collect"))

// parse results from listener
val numTasks = listener.completedStages.map(_.stageInfo.numTasks).sum
val executorCpuTime = listener.executorCpuTime
val executorRunTime = listener.completedStages
.map(_.stageInfo.taskMetrics.executorRunTime).sum
val shuffleBytesRead = listener.completedStages
.map(_.stageInfo.taskMetrics.shuffleReadMetrics.localBytesRead).sum
val shuffleBytesWritten = listener.completedStages
.map(_.stageInfo.taskMetrics.shuffleWriteMetrics.bytesWritten).sum

// compare metrics from event log with metrics from listener
assert(collect.getLong(fieldIndex("numTasks")) === numTasks)
assert(collect.getLong(fieldIndex("executorCPUTime")) === executorCpuTime)
assert(collect.getLong(fieldIndex("executorRunTime")) === executorRunTime)
assert(collect.getLong(fieldIndex("sr_localBytesRead_sum")) === shuffleBytesRead)
assert(collect.getLong(fieldIndex("sw_bytesWritten_sum")) === shuffleBytesWritten)
}
}
}
}

class ToolTestListener extends SparkListener {
val completedStages = new ListBuffer[SparkListenerStageCompleted]()
var executorCpuTime = 0L

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
executorCpuTime += taskEnd.taskMetrics.executorCpuTime / 1000000
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
completedStages.append(stageCompleted)
}
}
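One detail worth noting about ToolTestListener: Spark's TaskMetrics.executorCpuTime is reported in nanoseconds, so the listener divides by 1,000,000 to accumulate milliseconds, and the test then compares that sum against the tool's aggregated executorCPUTime column (which, per the "Change executorCpuTime logic to match qualification tool" commit above, presumably uses the same unit). A tiny illustration of the conversion:

// TaskMetrics.executorCpuTime is in nanoseconds; dividing by 1,000,000 yields
// milliseconds, the unit the test compares against.
val cpuTimeNanos: Long = 2500000000L          // e.g. 2.5 seconds of CPU time for one task
val cpuTimeMillis: Long = cpuTimeNanos / 1000000
assert(cpuTimeMillis == 2500L)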
