Compare partitioned files while preserving order (NVIDIA#859)
Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Oct 9, 2020
1 parent 9d2639a commit b01d3d0
Showing 2 changed files with 48 additions and 9 deletions.
BenchUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.execution.{InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.util.QueryExecutionListener
@@ -380,6 +381,8 @@ object BenchUtils {
    *
    * @param df1 DataFrame to compare.
    * @param df2 DataFrame to compare.
+   * @param readPathAction Function to create a DataFrame from a path when reading individual
+   *                       partitions from a partitioned data source.
    * @param ignoreOrdering Sort the data collected from the DataFrames before comparing them.
    * @param useIterator When set to true, use `toLocalIterator` to load one partition at a time
    *                    into driver memory, reducing memory usage at the cost of performance
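
A note on the new parameter: a readPathAction is simply a function from a file path to a DataFrame. A minimal sketch, assuming Parquet results and an illustrative SparkSession named spark:

    import org.apache.spark.sql.{DataFrame, SparkSession}

    val spark: SparkSession = SparkSession.builder.appName("example").getOrCreate()

    // Reads a single partition file; BenchUtils calls this once per file,
    // in sorted path order, so rows come back in a deterministic sequence.
    val readParquetPath: String => DataFrame = path => spark.read.parquet(path)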
@@ -390,6 +393,7 @@ object BenchUtils {
   def compareResults(
       df1: DataFrame,
       df2: DataFrame,
+      readPathAction: String => DataFrame,
       ignoreOrdering: Boolean,
       useIterator: Boolean = false,
       maxErrors: Int = 10,
@@ -401,14 +405,15 @@
     if (count1 == count2) {
       println(s"Both DataFrames contain $count1 rows")

-      if (!ignoreOrdering && (df1.rdd.getNumPartitions > 1 || df2.rdd.getNumPartitions > 1)) {
-        throw new IllegalStateException("Cannot run with ignoreOrdering=false because one or " +
-          "more inputs have multiple partitions")
+      val (result1, result2) = if (!ignoreOrdering &&
+        (df1.rdd.getNumPartitions > 1 || df2.rdd.getNumPartitions > 1)) {
+        (collectPartitioned(df1, readPathAction),
+          collectPartitioned(df2, readPathAction))
+      } else {
+        (collectResults(df1, ignoreOrdering, useIterator),
+          collectResults(df2, ignoreOrdering, useIterator))
       }

-      val result1 = collectResults(df1, ignoreOrdering, useIterator)
-      val result2 = collectResults(df2, ignoreOrdering, useIterator)
-
       var errors = 0
       var i = 0
       while (i < count1 && errors < maxErrors) {
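
To make the new control flow concrete, here is a hedged sketch of a call that takes the collectPartitioned branch (ignoreOrdering = false with multi-partition inputs). The paths and DataFrames are illustrative, and only parameters visible in this diff are passed; the rest keep their defaults:

    // Hypothetical inputs; any multi-partition result sets would do.
    val df1 = spark.read.parquet("/tmp/results-cpu")
    val df2 = spark.read.parquet("/tmp/results-gpu")

    BenchUtils.compareResults(
      df1,
      df2,
      path => spark.read.parquet(path), // readPathAction for per-file reads
      ignoreOrdering = false,           // with > 1 partition, selects collectPartitioned
      useIterator = false)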
@@ -441,12 +446,16 @@

     // apply sorting if specified
     val resultDf = if (ignoreOrdering) {
-      // let Spark do the sorting
+      // let Spark do the sorting, sorting by non-float columns first, then float columns
       val nonFloatCols = df.schema.fields
         .filter(field => !(field.dataType == DataTypes.FloatType ||
           field.dataType == DataTypes.DoubleType))
         .map(field => col(field.name))
-      df.sort(nonFloatCols: _*)
+      val floatCols = df.schema.fields
+        .filter(field => field.dataType == DataTypes.FloatType ||
+          field.dataType == DataTypes.DoubleType)
+        .map(field => col(field.name))
+      df.sort((nonFloatCols ++ floatCols): _*)
     } else {
       df
     }
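
The effect of the new sort is that float and double columns become the lowest-priority sort keys, so epsilon-scale differences in floating-point values cannot reorder otherwise-identical rows between the two result sets. An equivalent sketch using partition, for a hypothetical df with columns (name: String, id: Int, score: Double):

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DataTypes

    // Split fields into floating-point and exact types, then sort with the
    // exact columns first: the key order becomes (name, id, score).
    val (floatFields, nonFloatFields) = df.schema.fields.partition { f =>
      f.dataType == DataTypes.FloatType || f.dataType == DataTypes.DoubleType
    }
    val sortedDf = df.sort((nonFloatFields ++ floatFields).map(f => col(f.name)): _*)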
@@ -466,6 +475,23 @@
     it.map(_.toSeq)
   }

+  /**
+   * Collect data from a partitioned data source, preserving order by reading files in
+   * alphabetical order.
+   */
+  private def collectPartitioned(
+      df: DataFrame,
+      readPathAction: String => DataFrame): Iterator[Seq[Any]] = {
+    val files = df.rdd.partitions.flatMap {
+      case p: FilePartition => p.files
+      case other =>
+        throw new RuntimeException(s"Expected FilePartition, found ${other.getClass}")
+    }
+    files.map(_.filePath).sorted.flatMap(path => {
+      readPathAction(path).collect()
+    }).toIterator.map(_.toSeq)
+  }
+
   private def rowEqual(row1: Seq[Any], row2: Seq[Any], epsilon: Double): Boolean = {
     row1.zip(row2).forall {
       case (l, r) => compare(l, r, epsilon)
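
A sketch of what collectPartitioned relies on: for a file-based scan on the DataSource V1 path, each RDD partition is a FilePartition whose files field lists the backing PartitionedFiles, and filePath is a plain String in the Spark 3.0.x line this commit targets. The df below is illustrative:

    import org.apache.spark.sql.execution.datasources.FilePartition

    // Print the file(s) backing each partition; sorting these paths is what
    // gives collectPartitioned its deterministic, order-preserving traversal.
    df.rdd.partitions.foreach {
      case p: FilePartition => p.files.foreach(f => println(f.filePath))
      case other => println(s"Not a FilePartition: ${other.getClass}")
    }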
CompareResults.scala
@@ -41,7 +41,12 @@ object CompareResults {
   def main(arg: Array[String]): Unit = {
     val conf = new Conf(arg)

-    val spark = SparkSession.builder.appName("CompareResults").getOrCreate()
+    val spark = SparkSession.builder
+      .appName("CompareResults")
+      // disable plugin so that we can see FilePartition rather than DatasourceRDDPartition and
+      // can retrieve individual partition filenames
+      .config("spark.rapids.sql.enabled", "false")
+      .getOrCreate()

     val (df1, df2) = conf.inputFormat() match {
       case "csv" =>
@@ -50,9 +55,17 @@
         (spark.read.parquet(conf.input1()), spark.read.parquet(conf.input2()))
     }

+    val readPathAction = conf.inputFormat() match {
+      case "csv" =>
+        path: String => spark.read.csv(path)
+      case "parquet" =>
+        path: String => spark.read.parquet(path)
+    }
+
     BenchUtils.compareResults(
       df1,
       df2,
+      readPathAction,
       conf.ignoreOrdering(),
       conf.useIterator(),
       conf.maxErrors(),
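
For reference, a hypothetical programmatic invocation of the tool. Only input1, input2, and inputFormat appear in this diff; the dashed flag spellings are an assumption based on Scallop's default option-name guessing, and the remaining Conf options are assumed to have defaults:

    // Compare two hypothetical Parquet result sets via the tool's entry point.
    CompareResults.main(Array(
      "--input1", "/tmp/results-cpu",
      "--input2", "/tmp/results-gpu",
      "--input-format", "parquet"))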