Skip to content

Commit

Permalink
Fix IntervalSuite case failures (#5539)
Browse files Browse the repository at this point in the history
* Update checkSchemaCompat for year-month and day-time interval types
Signed-off-by: Chong Gao <res_life@163.com>
  • Loading branch information
res-life authored May 20, 2022
1 parent aabe71a commit 29940af
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import com.nvidia.spark.rapids.ParquetPartitionReader.CopyRange
import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.ParquetFooter
import com.nvidia.spark.rapids.shims.{ParquetFieldIdShims, SparkShimImpl}
import com.nvidia.spark.rapids.shims.{GpuTypeShims, ParquetFieldIdShims, SparkShimImpl}
import java.util
import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
import org.apache.hadoop.conf.Configuration
Expand Down Expand Up @@ -784,9 +784,10 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
case PrimitiveTypeName.BOOLEAN if dt == DataTypes.BooleanType =>
return

// TODO: add YearMonthIntervalType
case PrimitiveTypeName.INT32 =>
if (dt == DataTypes.IntegerType || canReadAsIntDecimal(pt, dt)) {
if (dt == DataTypes.IntegerType || GpuTypeShims.isSupportedYearMonthType(dt)
|| canReadAsIntDecimal(pt, dt)) {
// Year-month interval type is stored as int32 in parquet
return
}
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
Expand All @@ -797,9 +798,10 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
return
}

// TODO: add DayTimeIntervalType
case PrimitiveTypeName.INT64 =>
if (dt == DataTypes.LongType || canReadAsLongDecimal(pt, dt)) {
if (dt == DataTypes.LongType || GpuTypeShims.isSupportedDayTimeType(dt) ||
// Day-time interval type is stored as int64 in parquet
canReadAsLongDecimal(pt, dt)) {
return
}
// TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {
}

test("test GPU cache interval when reading/writing parquet") {
// temporarily skip the test on Spark 3.3.0 - https://github.com/NVIDIA/spark-rapids/issues/5497
assumePriorToSpark330
setPCBS()

val tempFile = File.createTempFile("pcbs", ".parquet")
Expand All @@ -91,8 +89,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {

// CPU write a parquet, then test the reading between CPU and GPU
test("test ANSI interval read") {
// temporarily skip the test on Spark 3.3.0 - https://github.com/NVIDIA/spark-rapids/issues/5497
assumePriorToSpark330
val tmpFile = File.createTempFile("interval", ".parquet")
try {
withCpuSparkSession(spark => getDF(spark).coalesce(1)
Expand All @@ -109,8 +105,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {

// GPU write a parquet, then test the reading between CPU and GPU
test("test ANSI interval write") {
// temporarily skip the test on Spark 3.3.0 - https://github.com/NVIDIA/spark-rapids/issues/5497
assumePriorToSpark330
val tmpFile = File.createTempFile("interval", ".parquet")
try {
withGpuSparkSession(spark => getDF(spark).coalesce(1)
Expand Down

0 comments on commit 29940af

Please sign in to comment.