Fix Spark UT issues in RapidsDataFrameAggregateSuite (#10943)
* Fix Spark UT issues in RapidsDataFrameAggregateSuite

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>

* Added SPARK-24788 back

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>

---------

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
thirtiseven authored Jun 8, 2024
1 parent c7129f5 commit 18c2579
Showing 3 changed files with 65 additions and 9 deletions.
RapidsDataFrameAggregateSuite.scala
@@ -19,12 +19,67 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.suites

-import org.apache.spark.sql.DataFrameAggregateSuite
+import org.apache.spark.sql.{DataFrameAggregateSuite, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.rapids.utils.RapidsSQLTestsTrait
import org.apache.spark.sql.types._

class RapidsDataFrameAggregateSuite extends DataFrameAggregateSuite with RapidsSQLTestsTrait {
-  // example to show how to replace the logic of an excluded test case in Vanilla Spark
-  testRapids("collect functions" ) { // "collect functions" was excluded at RapidsTestSettings
-    // println("...")
+  import testImplicits._
+
+  testRapids("collect functions") {
val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
checkAnswer(
df.select(sort_array(collect_list($"a")), sort_array(collect_list($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
)
checkAnswer(
df.select(sort_array(collect_set($"a")), sort_array(collect_set($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
)

checkDataset(
df.select(sort_array(collect_set($"a")).as("aSet")).as[Set[Int]],
Set(1, 2, 3))
checkDataset(
df.select(sort_array(collect_set($"b")).as("bSet")).as[Set[Int]],
Set(2, 4))
checkDataset(
df.select(sort_array(collect_set($"a")), sort_array(collect_set($"b")))
.as[(Set[Int], Set[Int])], Seq(Set(1, 2, 3) -> Set(2, 4)): _*)
}

testRapids("collect functions structs") {
val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
.toDF("a", "x", "y")
.select($"a", struct($"x", $"y").as("b"))
checkAnswer(
df.select(sort_array(collect_list($"a")), sort_array(collect_list($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
)
checkAnswer(
df.select(sort_array(collect_set($"a")), sort_array(collect_set($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
)
}

testRapids("SPARK-17641: collect functions should not collect null values") {
val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
checkAnswer(
df.select(sort_array(collect_list($"a")), sort_array(collect_list($"b"))),
Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
)
checkAnswer(
df.select(sort_array(collect_set($"a")), sort_array(collect_set($"b"))),
Seq(Row(Seq("1"), Seq(2, 4)))
)
}

testRapids("collect functions should be able to cast to array type with no null values") {
val df = Seq(1, 2).toDF("a")
checkAnswer(df.select(sort_array(collect_list("a")) cast ArrayType(IntegerType, false)),
Seq(Row(Seq(1, 2))))
checkAnswer(df.select(sort_array(collect_set("a")) cast ArrayType(FloatType, false)),
Seq(Row(Seq(1.0, 2.0))))
}
}
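Side note (not part of this commit): each rewritten test above wraps the collect_list/collect_set result in sort_array, because the order in which elements are collected carries no guarantee, so asserting against a fixed sequence can fail spuriously. A minimal, self-contained sketch of that normalization pattern, assuming only a local SparkSession (the object name is illustrative, not from the repository):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, sort_array}

// Illustrative only: shows why the adjusted assertions sort collected arrays first.
object SortArrayNormalizationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")

    // collect_list makes no ordering guarantee, so the raw array may come back
    // in any order; sorting it first gives a deterministic value to assert on.
    val sorted = df.select(sort_array(collect_list($"a"))).head().getSeq[Int](0)
    assert(sorted == Seq(1, 2, 3))

    spark.stop()
  }
}
```

The adjusted suite keeps the vanilla assertions intact and only normalizes ordering, which is why the exclusions further below are tagged ADJUST_UT rather than KNOWN_ISSUE.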
BackendTestSettings.scala
@@ -83,6 +83,7 @@ abstract class BackendTestSettings {
// or a description like "This simply can't work on GPU".
// It should never be "unknown" or "need investigation"
case class KNOWN_ISSUE(reason: String) extends ExcludeReason
+  case class ADJUST_UT(reason: String) extends ExcludeReason
case class WONT_FIX_ISSUE(reason: String) extends ExcludeReason


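Side note (not part of this commit): ADJUST_UT joins the existing exclude reasons to signal that a vanilla test has been replaced by an adjusted copy rather than skipped for a bug. A hedged sketch of how the three reasons are attached when a suite is registered, following the enableSuite/exclude pattern from RapidsTestSettings below (the settings class, test names, and issue number are placeholders):

```scala
// Sketch only (hypothetical settings class); mirrors the registration style
// used in RapidsTestSettings, assuming the same imports as that file.
class ExampleTestSettings extends BackendTestSettings {
  enableSuite[RapidsDataFrameAggregateSuite]
    // A tracked bug: the reason is the GitHub issue link (NNNN is a placeholder).
    .exclude("some failing test", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/NNNN"))
    // The vanilla assertion is order-sensitive; an adjusted copy lives in the Rapids suite.
    .exclude("some order-sensitive test", ADJUST_UT("order of elements in the array is non-deterministic in collect"))
    // Behaviour that will never apply to the GPU backend.
    .exclude("some codegen-only test", WONT_FIX_ISSUE("Codegen related UT, not applicable for GPU"))
}
```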
RapidsTestSettings.scala
@@ -37,11 +37,11 @@ class RapidsTestSettings extends BackendTestSettings {
.exclude("cast string to timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
.exclude("cast string to date", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
enableSuite[RapidsDataFrameAggregateSuite]
.exclude("collect functions", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
.exclude("collect functions structs", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
.exclude("collect functions should be able to cast to array type with no null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
.exclude("SPARK-17641: collect functions should not collect null values", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
.exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10772"))
.exclude("collect functions", ADJUST_UT("order of elements in the array is non-deterministic in collect"))
.exclude("collect functions structs", ADJUST_UT("order of elements in the array is non-deterministic in collect"))
.exclude("collect functions should be able to cast to array type with no null values", ADJUST_UT("order of elements in the array is non-deterministic in collect"))
.exclude("SPARK-17641: collect functions should not collect null values", ADJUST_UT("order of elements in the array is non-deterministic in collect"))
.exclude("SPARK-19471: AggregationIterator does not initialize the generated result projection before using it", WONT_FIX_ISSUE("Codegen related UT, not applicable for GPU"))
.exclude("SPARK-24788: RelationalGroupedDataset.toString with unresolved exprs should not fail", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10801"))
enableSuite[RapidsJsonExpressionsSuite]
.exclude("from_json - invalid data", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849"))
