diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
index e9140b81a5a..924b8f57cc2 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.execution.{InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.util.QueryExecutionListener
@@ -380,6 +381,8 @@ object BenchUtils {
    *
    * @param df1 DataFrame to compare.
    * @param df2 DataFrame to compare.
+   * @param readPathAction Function to create DataFrame from a path when reading individual
+   *                       partitions from a partitioned data source.
    * @param ignoreOrdering Sort the data collected from the DataFrames before comparing them.
    * @param useIterator When set to true, use `toLocalIterator` to load one partition at a time
    *                    into driver memory, reducing memory usage at the cost of performance
@@ -390,6 +393,7 @@ object BenchUtils {
   def compareResults(
       df1: DataFrame,
       df2: DataFrame,
+      readPathAction: String => DataFrame,
       ignoreOrdering: Boolean,
       useIterator: Boolean = false,
       maxErrors: Int = 10,
@@ -401,14 +405,15 @@
     if (count1 == count2) {
       println(s"Both DataFrames contain $count1 rows")
 
-      if (!ignoreOrdering && (df1.rdd.getNumPartitions > 1 || df2.rdd.getNumPartitions > 1)) {
-        throw new IllegalStateException("Cannot run with ignoreOrdering=false because one or " +
-          "more inputs have multiple partitions")
+      val (result1, result2) = if (!ignoreOrdering &&
+          (df1.rdd.getNumPartitions > 1 || df2.rdd.getNumPartitions > 1)) {
+        (collectPartitioned(df1, readPathAction),
+            collectPartitioned(df2, readPathAction))
+      } else {
+        (collectResults(df1, ignoreOrdering, useIterator),
+            collectResults(df2, ignoreOrdering, useIterator))
       }
 
-      val result1 = collectResults(df1, ignoreOrdering, useIterator)
-      val result2 = collectResults(df2, ignoreOrdering, useIterator)
-
       var errors = 0
       var i = 0
       while (i < count1 && errors < maxErrors) {
@@ -441,12 +446,16 @@
 
     // apply sorting if specified
     val resultDf = if (ignoreOrdering) {
-      // let Spark do the sorting
+      // let Spark do the sorting, ordering by non-float columns first, then float columns
      val nonFloatCols = df.schema.fields
        .filter(field => !(field.dataType == DataTypes.FloatType ||
          field.dataType == DataTypes.DoubleType))
        .map(field => col(field.name))
-      df.sort(nonFloatCols: _*)
+      val floatCols = df.schema.fields
+        .filter(field => field.dataType == DataTypes.FloatType ||
+          field.dataType == DataTypes.DoubleType)
+        .map(field => col(field.name))
+      df.sort((nonFloatCols ++ floatCols): _*)
     } else {
       df
     }
@@ -466,6 +475,23 @@
     it.map(_.toSeq)
   }
 
+  /**
+   * Collect data from a partitioned data source, preserving order by reading files in
+   * alphabetical order.
+   */
+  private def collectPartitioned(
+      df: DataFrame,
+      readPathAction: String => DataFrame): Iterator[Seq[Any]] = {
+    val files = df.rdd.partitions.flatMap {
+      case p: FilePartition => p.files
+      case other =>
+        throw new RuntimeException(s"Expected FilePartition, found ${other.getClass}")
+    }
+    files.map(_.filePath).sorted.flatMap(path => {
+      readPathAction(path).collect()
+    }).toIterator.map(_.toSeq)
+  }
+
   private def rowEqual(row1: Seq[Any], row2: Seq[Any], epsilon: Double): Boolean = {
     row1.zip(row2).forall {
       case (l, r) => compare(l, r, epsilon)
diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala
index 7b54ca4f8c8..42047d533a6 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/CompareResults.scala
@@ -41,7 +41,12 @@ object CompareResults {
   def main(arg: Array[String]): Unit = {
     val conf = new Conf(arg)
 
-    val spark = SparkSession.builder.appName("CompareResults").getOrCreate()
+    val spark = SparkSession.builder
+      .appName("CompareResults")
+      // disable the plugin so that we see FilePartition rather than DatasourceRDDPartition and
+      // can retrieve individual partition filenames
+      .config("spark.rapids.sql.enabled", "false")
+      .getOrCreate()
 
     val (df1, df2) = conf.inputFormat() match {
       case "csv" =>
@@ -50,9 +55,17 @@
         (spark.read.parquet(conf.input1()), spark.read.parquet(conf.input2()))
     }
 
+    val readPathAction = conf.inputFormat() match {
+      case "csv" =>
+        path: String => spark.read.csv(path)
+      case "parquet" =>
+        path: String => spark.read.parquet(path)
+    }
+
     BenchUtils.compareResults(
       df1,
       df2,
+      readPathAction,
       conf.ignoreOrdering(),
       conf.useIterator(),
       conf.maxErrors(),
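
For context, a minimal sketch of how the new `readPathAction` hook plugs into `BenchUtils.compareResults`. The session setup, file paths, and the `epsilon` argument shown here are illustrative assumptions for a caller of this API, not part of the patch itself:

```scala
package com.nvidia.spark.rapids.tests.common

import org.apache.spark.sql.{DataFrame, SparkSession}

object CompareResultsSketch {
  def main(args: Array[String]): Unit = {
    // The RAPIDS plugin is disabled so that df.rdd.partitions exposes
    // FilePartition objects (with file paths) rather than DatasourceRDDPartition
    val spark = SparkSession.builder
      .appName("CompareResultsSketch")
      .config("spark.rapids.sql.enabled", "false")
      .getOrCreate()

    // hypothetical benchmark output directories, e.g. from a CPU run and a GPU run
    val df1 = spark.read.parquet("/data/q4-cpu")
    val df2 = spark.read.parquet("/data/q4-gpu")

    // readPathAction tells compareResults how to re-read a single partition file,
    // so multi-partition output can be compared in a stable (alphabetical) order
    val readPathAction: String => DataFrame = path => spark.read.parquet(path)

    BenchUtils.compareResults(
      df1,
      df2,
      readPathAction,
      ignoreOrdering = false, // exercises the new collectPartitioned path
      useIterator = false,
      maxErrors = 10,
      epsilon = 0.00001) // assumed: the tolerance parameter implied by rowEqual's epsilon
  }
}
```

With `ignoreOrdering = false` and more than one input partition, the old code threw an `IllegalStateException`; after this change the same call instead collects each input one partition file at a time via `readPathAction`, preserving row order without a global sort.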