diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index c7ab100ecc9..cdecc8050d6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -36,7 +36,7 @@ import com.nvidia.spark.rapids.ParquetPartitionReader.CopyRange
 import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.jni.ParquetFooter
-import com.nvidia.spark.rapids.shims.{ParquetFieldIdShims, SparkShimImpl}
+import com.nvidia.spark.rapids.shims.{GpuTypeShims, ParquetFieldIdShims, SparkShimImpl}
 import java.util
 import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
 import org.apache.hadoop.conf.Configuration
@@ -784,9 +784,10 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
       case PrimitiveTypeName.BOOLEAN if dt == DataTypes.BooleanType =>
         return
 
-      // TODO: add YearMonthIntervalType
       case PrimitiveTypeName.INT32 =>
-        if (dt == DataTypes.IntegerType || canReadAsIntDecimal(pt, dt)) {
+        if (dt == DataTypes.IntegerType || GpuTypeShims.isSupportedYearMonthType(dt)
+            || canReadAsIntDecimal(pt, dt)) {
+          // Year-month interval type is stored as int32 in parquet
           return
         }
         // TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
@@ -797,9 +798,10 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
           return
         }
 
-      // TODO: add DayTimeIntervalType
       case PrimitiveTypeName.INT64 =>
-        if (dt == DataTypes.LongType || canReadAsLongDecimal(pt, dt)) {
+        if (dt == DataTypes.LongType || GpuTypeShims.isSupportedDayTimeType(dt) ||
+            // Day-time interval type is stored as int64 in parquet
+            canReadAsLongDecimal(pt, dt)) {
           return
         }
         // TODO: After we deprecate Spark 3.1, replace OriginalType with LogicalTypeAnnotation
diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalSuite.scala
index cddc26f1ce1..7d5d3e787ed 100644
--- a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalSuite.scala
+++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalSuite.scala
@@ -68,8 +68,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {
   }
 
   test("test GPU cache interval when reading/writing parquet") {
-    // temporarily skip the test on Spark 3.3.0-https://github.com/NVIDIA/spark-rapids/issues/5497
-    assumePriorToSpark330
     setPCBS()
     val tempFile = File.createTempFile("pcbs", ".parquet")
     try {
@@ -91,8 +89,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {
 
   // CPU write a parquet, then test the reading between CPU and GPU
   test("test ANSI interval read") {
-    // temporarily skip the test on Spark 3.3.0-https://github.com/NVIDIA/spark-rapids/issues/5497
-    assumePriorToSpark330
     val tmpFile = File.createTempFile("interval", ".parquet")
     try {
       withCpuSparkSession(spark => getDF(spark).coalesce(1)
@@ -109,8 +105,6 @@ class IntervalSuite extends SparkQueryCompareTestSuite {
 
   // GPU write a parquet, then test the reading between CPU and GPU
   test("test ANSI interval write") {
-    // temporarily skip the test on Spark 3.3.0-https://github.com/NVIDIA/spark-rapids/issues/5497
-    assumePriorToSpark330
     val tmpFile = File.createTempFile("interval", ".parquet")
     try {
       withGpuSparkSession(spark => getDF(spark).coalesce(1)
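
Note on the shims used above: GpuTypeShims.isSupportedYearMonthType and GpuTypeShims.isSupportedDayTimeType are version-gated predicates, so the shared GpuParquetScan code compiles against every supported Spark version while only the Spark 3.3.0+ shims actually accept ANSI interval types. Below is a minimal sketch of what a 330 shim could look like; the method names come from this diff, but the bodies are illustrative assumptions, not the plugin's actual shim code.

    import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}

    // Hypothetical Spark 3.3.0+ shim (illustrative sketch only)
    object GpuTypeShims {
      // Year-month intervals are stored as int32 in parquet,
      // so they are acceptable in the PrimitiveTypeName.INT32 branch
      def isSupportedYearMonthType(dt: DataType): Boolean =
        dt.isInstanceOf[YearMonthIntervalType]

      // Day-time intervals are stored as int64 in parquet,
      // so they are acceptable in the PrimitiveTypeName.INT64 branch
      def isSupportedDayTimeType(dt: DataType): Boolean =
        dt.isInstanceOf[DayTimeIntervalType]
    }

Pre-3.3.0 shims would presumably return false from both predicates, which preserves the old behavior of rejecting interval columns on those versions.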