diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index bca35f180c8..86da3ce5fd6 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -214,7 +214,6 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { case _: DayTimeIntervalType => // Supported } } - checkTimeZoneId(timeAdd.timeZoneId) } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala index f639ef144d0..369f6ceb540 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark330PlusShims.scala @@ -162,7 +162,6 @@ trait Spark330PlusShims extends Spark321PlusShims with Spark320PlusNonDBShims { case _: DayTimeIntervalType => // Supported } } - checkTimeZoneId(timeAdd.timeZoneId) } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala index b035b95646b..907291cc34b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala @@ -167,9 +167,7 @@ object GpuCSVScan { } if (types.contains(TimestampType)) { - if (!TypeChecks.areTimestampsSupported(parsedOptions.zoneId)) { - meta.willNotWorkOnGpu("Only UTC zone id is supported") - } + meta.checkTimeZoneId(parsedOptions.zoneId) GpuTextBasedDateUtils.tagCudfFormat(meta, GpuCsvUtils.timestampFormatInRead(parsedOptions), parseString = true) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5423e3eb9c4..fe46477d9b3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -649,7 +649,9 @@ object GpuOverrides extends Logging { case FloatType => true case DoubleType => true case DateType => true - case TimestampType => TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) + case TimestampType => + TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) && + TypeChecks.areTimestampsSupported(SQLConf.get.sessionLocalTimeZone) case StringType => true case dt: DecimalType if allowDecimal => dt.precision <= DType.DECIMAL64_MAX_PRECISION case NullType => allowNull @@ -1702,7 +1704,6 @@ object GpuOverrides extends Logging { willNotWorkOnGpu("interval months isn't supported") } } - checkTimeZoneId(timeAdd.timeZoneId) } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = @@ -1724,7 +1725,6 @@ object GpuOverrides extends Logging { willNotWorkOnGpu("interval months isn't supported") } } - checkTimeZoneId(dateAddInterval.timeZoneId) } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = @@ -1785,9 +1785,6 @@ object GpuOverrides extends Logging { ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), (hour, conf, p, r) => new UnaryExprMeta[Hour](hour, conf, p, r) { - override def tagExprForGpu(): Unit = { - checkTimeZoneId(hour.timeZoneId) - } override def convertToGpu(expr: Expression): GpuExpression = GpuHour(expr) }), @@ -1796,9 +1793,6 @@ object GpuOverrides extends Logging { ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), (minute, conf, p, r) => new UnaryExprMeta[Minute](minute, conf, p, r) { - override def tagExprForGpu(): Unit = { - checkTimeZoneId(minute.timeZoneId) - } override def convertToGpu(expr: Expression): GpuExpression = GpuMinute(expr) @@ -1808,9 +1802,6 @@ object GpuOverrides extends Logging { ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), (second, conf, p, r) => new UnaryExprMeta[Second](second, conf, p, r) { - override def tagExprForGpu(): Unit = { - checkTimeZoneId(second.timeZoneId) - } override def convertToGpu(expr: Expression): GpuExpression = GpuSecond(expr) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 7b23d6a5b78..0f3bb65e406 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import com.nvidia.spark.rapids.shims.SparkShimImpl -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, WindowExpression, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -367,11 +367,16 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( } } - protected def checkTimeZoneId(timeZoneId: Option[String]): Unit = { - timeZoneId.foreach { zoneId => - if (!TypeChecks.areTimestampsSupported(ZoneId.systemDefault())) { - willNotWorkOnGpu(s"Only UTC zone id is supported. Actual zone id: $zoneId") - } + def checkTimeZoneId(sessionZoneId: ZoneId): Unit = { + // Both of the Spark session time zone and JVM's default time zone should be UTC. + if (!TypeChecks.areTimestampsSupported(sessionZoneId)) { + willNotWorkOnGpu("Only UTC zone id is supported. " + + s"Actual session local zone id: $sessionZoneId") + } + + val defaultZoneId = ZoneId.systemDefault() + if (!TypeChecks.areTimestampsSupported(defaultZoneId)) { + willNotWorkOnGpu(s"Only UTC zone id is supported. Actual default zone id: $defaultZoneId") } } @@ -987,6 +992,10 @@ abstract class BaseExprMeta[INPUT <: Expression]( s"$wrapped is foldable and operates on non literals") } rule.getChecks.foreach(_.tag(this)) + wrapped match { + case tzAware: TimeZoneAwareExpression => checkTimeZoneId(tzAware.zoneId) + case _ => // do nothing + } tagExprForGpu() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 0e58e6405ec..f408559a6f9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -23,6 +23,8 @@ import ai.rapids.cudf.DType import com.nvidia.spark.rapids.shims.{GpuTypeShims, TypeSigUtil} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** Trait of TypeSigUtil for different spark versions */ @@ -358,7 +360,8 @@ final class TypeSig private( case DoubleType => check.contains(TypeEnum.DOUBLE) case DateType => check.contains(TypeEnum.DATE) case TimestampType if check.contains(TypeEnum.TIMESTAMP) => - TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) + TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) && + TypeChecks.areTimestampsSupported(SQLConf.get.sessionLocalTimeZone) case StringType => check.contains(TypeEnum.STRING) case dt: DecimalType => check.contains(TypeEnum.DECIMAL) && @@ -419,10 +422,11 @@ final class TypeSig private( basicNotSupportedMessage(dataType, TypeEnum.DATE, check, isChild) case TimestampType => if (check.contains(TypeEnum.TIMESTAMP) && - (!TypeChecks.areTimestampsSupported(ZoneId.systemDefault()))) { - Seq(withChild(isChild, s"$dataType is not supported when the JVM system " + - s"timezone is set to ${ZoneId.systemDefault()}. Set the timezone to UTC to enable " + - s"$dataType support")) + (!TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) || + !TypeChecks.areTimestampsSupported(SQLConf.get.sessionLocalTimeZone))) { + Seq(withChild(isChild, s"$dataType is not supported with timezone settings: (JVM:" + + s" ${ZoneId.systemDefault()}, session: ${SQLConf.get.sessionLocalTimeZone})." + + s" Set both of the timezones to UTC to enable $dataType support")) } else { basicNotSupportedMessage(dataType, TypeEnum.TIMESTAMP, check, isChild) } @@ -796,6 +800,11 @@ object TypeChecks { def areTimestampsSupported(timezoneId: ZoneId): Boolean = { timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID } + + def areTimestampsSupported(zoneIdString: String): Boolean = { + val zoneId = DateTimeUtils.getZoneId(zoneIdString) + areTimestampsSupported(zoneId) + } } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala index 75a95b41823..deab6ca94ec 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala @@ -129,9 +129,7 @@ object GpuJsonScan { } if (types.contains(TimestampType)) { - if (!TypeChecks.areTimestampsSupported(parsedOptions.zoneId)) { - meta.willNotWorkOnGpu("Only UTC zone id is supported") - } + meta.checkTimeZoneId(parsedOptions.zoneId) GpuTextBasedDateUtils.tagCudfFormat(meta, GpuJsonUtils.timestampFormatInRead(parsedOptions), parseString = true) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 431f1d8bc70..606c8a91b50 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -377,8 +377,6 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi var sparkFormat: String = _ var strfFormat: String = _ override def tagExprForGpu(): Unit = { - checkTimeZoneId(expr.timeZoneId) - // Date and Timestamp work too if (expr.right.dataType == StringType) { extractStringLit(expr.right) match {