Support unix_timestamp and to_unix_timestamp with non-UTC timezones (non-DST) #9816

Merged · 18 commits · Dec 13, 2023
Changes from 9 commits
7 changes: 7 additions & 0 deletions integration_tests/src/main/python/date_time_test.py
@@ -262,6 +262,13 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen):
"to_unix_timestamp(a, b)"),
"ToUnixTimestamp")

+# TODO: there is another test for this, called test_unix_timestamp
+@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
Collaborator comment:
We should also test string gen. Please refer to the Product plan doc; it requires the String data type.

+def test_unix_timestamp_non_UTC(data_gen):
+    assert_gpu_and_cpu_are_equal_collect(lambda spark: gen_df(
+        spark, [("a", data_gen)], length=10).selectExpr(
+        "unix_timestamp(a, 'yyyy-MM-dd HH:mm:ss')"))


@pytest.mark.parametrize('time_zone', ["UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@@ -177,9 +177,6 @@ abstract class CastExprMetaBase[INPUT <: UnaryExpression with TimeZoneAwareExpression]
def buildTagMessage(entry: ConfEntry[_]): String = {
s"${entry.doc}. To enable this operation on the GPU, set ${entry.key} to true."
}

-  // timezone tagging in type checks is good enough, so always false
-  override protected val needTimezoneTagging: Boolean = false
}

object CastOptions {
@@ -669,9 +669,7 @@ object GpuOverrides extends Logging {
case FloatType => true
case DoubleType => true
case DateType => true
-      case TimestampType =>
-        TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) &&
-          TypeChecks.areTimestampsSupported(SQLConf.get.sessionLocalTimeZone)
+      case TimestampType => true
case StringType => true
case dt: DecimalType if allowDecimal => dt.precision <= DType.DECIMAL64_MAX_PRECISION
case NullType => allowNull
@@ -1682,6 +1680,8 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) {
+      override val isTimezoneSupported: Boolean = true
+
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isImprovedTimestampOpsEnabled) {
// passing the already converted strf string for a little optimization
@@ -1701,6 +1701,8 @@
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) {
+      override val isTimezoneSupported: Boolean = true
+
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isImprovedTimestampOpsEnabled) {
// passing the already converted strf string for a little optimization
@@ -129,7 +129,7 @@ object GpuParquetFileFormat {
SparkShimImpl.parquetRebaseWrite(sqlConf))

if ((int96RebaseMode == DateTimeRebaseLegacy || dateTimeRebaseMode == DateTimeRebaseLegacy)
&& !TypeChecks.areTimestampsSupported()) {
&& !TypeChecks.isUTCTimezone()) {
meta.willNotWorkOnGpu("Only UTC timezone is supported in LEGACY rebase mode. " +
s"Current timezone settings: (JVM : ${ZoneId.systemDefault()}, " +
s"session: ${SQLConf.get.sessionLocalTimeZone}). " +
12 changes: 5 additions & 7 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -374,13 +374,11 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
case Some(value) => ZoneId.of(value)
case None => throw new RuntimeException(s"Driver time zone cannot be determined.")
}
-    if (TypeChecks.areTimestampsSupported(driverTimezone)) {
-      val executorTimezone = ZoneId.systemDefault()
-      if (executorTimezone.normalized() != driverTimezone.normalized()) {
-        throw new RuntimeException(s" Driver and executor timezone mismatch. " +
-          s"Driver timezone is $driverTimezone and executor timezone is " +
-          s"$executorTimezone. Set executor timezone to $driverTimezone.")
-      }
+    val executorTimezone = ZoneId.systemDefault()
+    if (executorTimezone.normalized() != driverTimezone.normalized()) {
+      throw new RuntimeException(s" Driver and executor timezone mismatch. " +
+        s"Driver timezone is $driverTimezone and executor timezone is " +
+        s"$executorTimezone. Set executor timezone to $driverTimezone.")
+    }
}

GpuCoreDumpHandler.executorInit(conf, pluginContext)
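
With the UTC-only guard removed, the driver/executor timezone comparison above now runs for every configured timezone, not just UTC. As an illustration of keeping the two JVMs aligned so the check passes (standard Spark/JVM options, not code from this PR):

# Illustration only: pin the driver and executor JVMs, and the SQL session,
# to one (non-DST) timezone. -Duser.timezone and these conf keys are standard
# Spark/JVM settings, not additions from this PR.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .config("spark.driver.extraJavaOptions", "-Duser.timezone=Asia/Shanghai")
         .config("spark.executor.extraJavaOptions", "-Duser.timezone=Asia/Shanghai")
         .config("spark.sql.session.timeZone", "Asia/Shanghai")
         .getOrCreate())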
53 changes: 31 additions & 22 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable

import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}

-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, WindowExpression, WindowFunction}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -382,13 +382,13 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](

def checkTimeZoneId(sessionZoneId: ZoneId): Unit = {
// Both of the Spark session time zone and JVM's default time zone should be UTC.
-    if (!TypeChecks.areTimestampsSupported(sessionZoneId)) {
+    if (!TypeChecks.isTimestampsSupported(sessionZoneId)) {
willNotWorkOnGpu("Only UTC zone id is supported. " +
s"Actual session local zone id: $sessionZoneId")
}

val defaultZoneId = ZoneId.systemDefault()
-    if (!TypeChecks.areTimestampsSupported(defaultZoneId)) {
+    if (!TypeChecks.isTimestampsSupported(defaultZoneId)) {
willNotWorkOnGpu(s"Only UTC zone id is supported. Actual default zone id: $defaultZoneId")
}
}
@@ -1082,21 +1082,27 @@ abstract class BaseExprMeta[INPUT <: Expression](

val isFoldableNonLitAllowed: Boolean = false

-  /**
-   * Whether to tag a TimeZoneAwareExpression for timezone after all the other tagging
-   * is done.
-   * By default a TimeZoneAwareExpression always requires the timezone tagging, but
-   * there are some exceptions, e.g. 'Cast', who requires timezone tagging only when it
-   * has timezone sensitive type as input or output.
-   *
-   * Override this to match special cases.
-   */
-  protected def needTimezoneTagging: Boolean = {
-    // A TimeZoneAwareExpression with no timezone sensitive types as input/output will
-    // escape from the timezone tagging in the prior type checks. So ask for tagging here.
-    // e.g. 'UnixTimestamp' with 'DateType' as the input, timezone will be taken into
-    // account when converting a Date to a Long.
-    !(dataType +: childExprs.map(_.dataType)).exists(TypeChecks.isTimezoneSensitiveType)
+  // Whether timezone is supported for these expressions needs to be checked.
+  // TODO: use TimezoneDB Utils to tell whether timezone is supported
+  val isTimezoneSupported: Boolean = false
+
+  // Both [[isTimezoneSupported]] and [[needTimezoneCheck]] are needed to decide whether the
+  // timezone check is required. For the cast expression, only some cases need it, depending
+  // on its data type and its child's data type.
+  //+------------------------+-------------------+-----------------------------------------+
+  //| Value                  | needTimezoneCheck | isTimezoneSupported                     |
+  //+------------------------+-------------------+-----------------------------------------+
+  //| TimezoneAwareExpression| True              | False by default, True when implemented |
+  //| UTCTimestamp           | True              | False by default, True when implemented |
+  //| Others                 | False             | N/A (will not be checked)               |
+  //+------------------------+-------------------+-----------------------------------------+
+  lazy val needTimezoneCheck: Boolean = {
+    wrapped match {
+      case _: TimeZoneAwareExpression =>
+        if (wrapped.isInstanceOf[Cast]) wrapped.asInstanceOf[Cast].needsTimeZone else true
+      case _: UTCTimestamp => true
+      case _ => false
+    }
  }

final override def tagSelfForGpu(): Unit = {
@@ -1105,11 +1111,14 @@
s"$wrapped is foldable and operates on non literals")
}
rule.getChecks.foreach(_.tag(this))
+    if (needTimezoneCheck && !isTimezoneSupported) checkTimestampType(this)
tagExprForGpu()
-    wrapped match {
-      case tzAware: TimeZoneAwareExpression if needTimezoneTagging =>
-        checkTimeZoneId(tzAware.zoneId)
-      case _ => // do nothing
-    }

+  def checkTimestampType(meta: RapidsMeta[_, _, _]): Unit = {
+    // FIXME: use new API from TimezoneDB utils to check whether it's supported
+    if (!TypeChecks.isUTCTimezone()) {
+      meta.willNotWorkOnGpu(TypeChecks.timezoneNotSupportedString(meta.wrapped.getClass.toString))
+    }
+  }

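
Taken together with needTimezoneCheck above, the effect is that a timezone-aware expression which has not opted in via isTimezoneSupported is kept on the CPU whenever the session timezone is not UTC, with timezoneNotSupportedString as the reason. A sketch of what that should look like from the user side (assumed behavior for illustration; hour() standing in for an expression that has not implemented timezone support):

# Sketch: with a non-UTC session timezone, an expression without
# isTimezoneSupported = true should fall back to the CPU, and the plan should
# carry the "... is not supported with timezone settings ..." reason.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.session.timeZone", "America/New_York")
df = spark.range(1).selectExpr("hour(current_timestamp())")
df.explain()  # expect a fallback note naming the unsupported expression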
69 changes: 12 additions & 57 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -363,8 +363,7 @@ final class TypeSig private(
case FloatType => check.contains(TypeEnum.FLOAT)
case DoubleType => check.contains(TypeEnum.DOUBLE)
case DateType => check.contains(TypeEnum.DATE)
-      case TimestampType if check.contains(TypeEnum.TIMESTAMP) =>
-        TypeChecks.areTimestampsSupported()
+      case TimestampType => check.contains(TypeEnum.TIMESTAMP)
case StringType => check.contains(TypeEnum.STRING)
case dt: DecimalType =>
check.contains(TypeEnum.DECIMAL) &&
@@ -402,15 +401,6 @@ final class TypeSig private(
}
}

-  private[this] def timezoneNotSupportedMessage(dataType: DataType,
-      te: TypeEnum.Value, check: TypeEnum.ValueSet, isChild: Boolean): Seq[String] = {
-    if (check.contains(te) && !TypeChecks.areTimestampsSupported()) {
-      Seq(withChild(isChild, TypeChecks.timezoneNotSupportedString(dataType)))
-    } else {
-      basicNotSupportedMessage(dataType, te, check, isChild)
-    }
-  }

private[this] def reasonNotSupported(
check: TypeEnum.ValueSet,
dataType: DataType,
@@ -433,7 +423,7 @@
case DateType =>
basicNotSupportedMessage(dataType, TypeEnum.DATE, check, isChild)
case TimestampType =>
-        timezoneNotSupportedMessage(dataType, TypeEnum.TIMESTAMP, check, isChild)
+        basicNotSupportedMessage(dataType, TypeEnum.TIMESTAMP, check, isChild)
case StringType =>
basicNotSupportedMessage(dataType, TypeEnum.STRING, check, isChild)
case dt: DecimalType =>
@@ -780,30 +770,6 @@ abstract class TypeChecks[RET] {
}.mkString(", ")
}

-  /**
-   * Original log does not print enough info when timezone is not UTC,
-   * here check again to add UTC info.
-   */
-  private def tagTimezoneInfoIfHasTimestampType(
-      unsupportedTypes: Map[DataType, Set[String]],
-      meta: RapidsMeta[_, _, _]): Unit = {
-    def checkTimestampType(dataType: DataType): Unit = dataType match {
-      case TimestampType if !TypeChecks.areTimestampsSupported() =>
-        meta.willNotWorkOnGpu(TypeChecks.timezoneNotSupportedString(dataType))
-      case ArrayType(elementType, _) =>
-        checkTimestampType(elementType)
-      case MapType(keyType, valueType, _) =>
-        checkTimestampType(keyType)
-        checkTimestampType(valueType)
-      case StructType(fields) =>
-        fields.foreach(field => checkTimestampType(field.dataType))
-      case _ => // do nothing
-    }
-    unsupportedTypes.foreach { case (dataType, _) =>
-      checkTimestampType(dataType)
-    }
-  }

protected def tagUnsupportedTypes(
meta: RapidsMeta[_, _, _],
sig: TypeSig,
@@ -815,40 +781,29 @@
.groupBy(_.dataType)
.mapValues(_.map(_.name).toSet).toMap

-    tagTimezoneInfoIfHasTimestampType(unsupportedTypes, meta)
-
if (unsupportedTypes.nonEmpty) {
meta.willNotWorkOnGpu(msgFormat.format(stringifyTypeAttributeMap(unsupportedTypes)))
}
}
}

object TypeChecks {
-  /**
-   * Check if the time zone passed is supported by plugin.
-   */
-  def areTimestampsSupported(timezoneId: ZoneId): Boolean = {
-    timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
-  }
-
-  def areTimestampsSupported(zoneIdString: String): Boolean = {
-    val zoneId = DateTimeUtils.getZoneId(zoneIdString)
-    areTimestampsSupported(zoneId)
-  }
-
-  def areTimestampsSupported(): Boolean = {
-    areTimestampsSupported(ZoneId.systemDefault()) &&
-      areTimestampsSupported(SQLConf.get.sessionLocalTimeZone)
+  // TODO: move this to Timezone DB
+  def isTimestampsSupported(timezoneId: ZoneId): Boolean = {
+    timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
  }

-  def isTimezoneSensitiveType(dataType: DataType): Boolean = {
-    dataType == TimestampType
+  def isUTCTimezone(): Boolean = {
+    val zoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
+    zoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
  }

-  def timezoneNotSupportedString(dataType: DataType): String = {
-    s"$dataType is not supported with timezone settings: (JVM:" +
+  // TODO: change the string about supported timezones
+  def timezoneNotSupportedString(exprName: String): String = {
+    s"$exprName is not supported with timezone settings: (JVM:" +
      s" ${ZoneId.systemDefault()}, session: ${SQLConf.get.sessionLocalTimeZone})." +
-      s" Set both of the timezones to UTC to enable $dataType support"
+      s" Set both of the timezones to UTC to enable $exprName support"
  }
}

@@ -25,6 +25,7 @@ import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsConf.TEST_USE_TIMEZONE_CPU_BACKEND
import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.ShimBinaryExpression

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
@@ -370,6 +371,11 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpression]
var sparkFormat: String = _
var strfFormat: String = _
override def tagExprForGpu(): Unit = {
+    val zoneIdStr = SQLConf.get.sessionLocalTimeZone
+    if (!GpuTimeZoneDB.isSupportedTimeZone(zoneIdStr)) {
+      return willNotWorkOnGpu(s"Not supported timezone ID ${zoneIdStr}")
+    }
+
// Date and Timestamp work too
if (expr.right.dataType == StringType) {
extractStringLit(expr.right) match {
@@ -842,7 +848,21 @@ abstract class GpuToTimestamp
failOnError)
}
} else { // Timestamp or DateType
-      lhs.getBase.asTimestampMicroseconds()
+      timeZoneId match {
+        case Some(idStr) => {
+          val zoneId = GpuTimeZoneDB.getZoneId(idStr)
+          if (TimeZoneDB.isUTCTimezone(zoneId)) {
+            lhs.getBase.asTimestampMicroseconds()
+          } else {
+            assert(GpuTimeZoneDB.isSupportedTimeZone(zoneId))
+            withResource(lhs) { gcv =>
+              GpuTimeZoneDB.fromTimestampToUtcTimestamp(gcv.getBase, zoneId)
+                .asTimestampMicroseconds()
+            }
+          }
+        }
+        case None => lhs.getBase.asTimestampMicroseconds()
+      }
}
// Return Timestamp value if dataType it is expecting is of TimestampType
if (dataType.equals(TimestampType)) {
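
The added branch above shifts DATE/TIMESTAMP input from the session timezone to UTC before epoch seconds are taken. A worked example of the semantics being matched, using plain epoch arithmetic (an illustration of CPU Spark behavior, not output from this PR):

# A DATE input to unix_timestamp is interpreted as midnight in the session
# timezone and converted to UTC epoch seconds -- the shift the GPU path above
# reproduces via GpuTimeZoneDB.fromTimestampToUtcTimestamp.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")  # UTC+8, no DST
spark.sql("SELECT unix_timestamp(DATE'2023-01-01')").show()
# 2023-01-01 00:00:00 Asia/Shanghai == 2022-12-31 16:00:00 UTC
# -> 1672531200 (UTC midnight) - 8 * 3600 = 1672502400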
@@ -16,11 +16,15 @@

package com.nvidia.spark.rapids

+import java.sql.{Date, Timestamp}
+import java.time.{ZonedDateTime, ZoneId}
+import java.util.TimeZone
+
+import scala.collection.mutable.ListBuffer
+
import ai.rapids.cudf.{ColumnVector, RegexProgram}
import com.nvidia.spark.rapids.Arm.withResource
-import java.sql.{Date, Timestamp}
import org.scalatest.BeforeAndAfterEach
-import scala.collection.mutable.ListBuffer

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
@@ -280,6 +284,48 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterEach
assert(res)
}

test("literals: ensure time literals are correct with different timezones") {
testTimeWithDiffTimezones("Asia/Shanghai", "America/New_York")
testTimeWithDiffTimezones("Asia/Shanghai", "UTC")
testTimeWithDiffTimezones("UTC", "Asia/Shanghai")
}

private[this] def testTimeWithDiffTimezones(sessionTZStr: String, systemTZStr: String) = {
withTimeZones(sessionTimeZone = sessionTZStr, systemTimeZone = systemTZStr) { conf =>
val df = withGpuSparkSession(spark => {
spark.sql("SELECT current_date(), current_timestamp(), now() FROM RANGE(1, 10)")
}, conf)

val times = df.collect()
val zonedDateTime = ZonedDateTime.now(ZoneId.of("America/New_York"))
val res = times.forall(time => {
val diffDate = zonedDateTime.toLocalDate.toEpochDay - time.getLocalDate(0).toEpochDay
val diffTimestamp =
zonedDateTime.toInstant.getNano - time.getInstant(1).getNano
val diffNow =
zonedDateTime.toInstant.getNano - time.getInstant(2).getNano
// For date, at most 1 day difference when execution is crossing two days
// For timestamp or now, it should be less than 1 second allowing Spark's execution
diffDate.abs <= 1 & diffTimestamp.abs <= 1E9 & diffNow.abs <= 1E9
})
assert(res)
}
}

private def withTimeZones(sessionTimeZone: String,
systemTimeZone: String)(f: SparkConf => Unit): Unit = {
val conf = new SparkConf()
conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sessionTimeZone)
conf.set(SQLConf.DATETIME_JAVA8API_ENABLED.key, "true")
val originTimeZone = TimeZone.getDefault
try {
TimeZone.setDefault(TimeZone.getTimeZone(systemTimeZone))
f(conf)
} finally {
TimeZone.setDefault(originTimeZone)
}
}

private def testRegex(rule: RegexReplace, values: Seq[String], expected: Seq[String]): Unit = {
withResource(ColumnVector.fromStrings(values: _*)) { v =>
withResource(ColumnVector.fromStrings(expected: _*)) { expected =>