Commit

format, update
Chong Gao committed Oct 27, 2023
1 parent 3652632 commit 32ffb5e
Showing 2 changed files with 73 additions and 49 deletions.
@@ -23,16 +23,18 @@ import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.util.DateTimeUtils

object CpuTimeZoneDB {
object TimeZoneDB {

def cacheDatabase(): Unit = {}

/**
* Interprate a timestamp as a time in the given time zone, and renders that time as a timestamp in UTC
* Interpret a timestamp as a time in the given time zone,
* and render that time as a timestamp in UTC
*
* @return timestamp in UTC
*/
def fromTimestampToUtcTimestamp(inputVector: ColumnVector, currentTimeZone: ZoneId): ColumnVector = {
def fromTimestampToUtcTimestamp(
inputVector: ColumnVector,
currentTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_MICROSECONDS)
val zoneStr = currentTimeZone.getId
val rowCount = inputVector.getRowCount.toInt
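
As a reference for the semantics above, the per-row rule can be sketched with plain java.time; this is an illustration with made-up values, not the kernel code (the body is collapsed in this view):

import java.time.{LocalDateTime, ZoneId, ZoneOffset}

// Decode the stored micros as a wall clock, interpret that wall clock in the
// given zone, then render the same moment as a UTC timestamp.
val zone = ZoneId.of("Asia/Shanghai")
val inputMicros = 1698393600000000L  // encodes wall clock 2023-10-27T08:00:00
val wallClock = LocalDateTime.ofEpochSecond(inputMicros / 1000000L, 0, ZoneOffset.UTC)
val utcMicros = wallClock.atZone(zone).toInstant.getEpochSecond * 1000000L
// utcMicros == 1698364800000000L, i.e. 2023-10-27T00:00:00Z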
@@ -60,7 +62,9 @@ object CpuTimeZoneDB {
* @param desiredTimeZone desired time zone
* @return timestamp in the `desiredTimeZone`
*/
def fromUtcTimestampToTimestamp(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
def fromUtcTimestampToTimestamp(
inputVector: ColumnVector,
desiredTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_MICROSECONDS)
val zoneStr = desiredTimeZone.getId
val rowCount = inputVector.getRowCount.toInt
@@ -82,7 +86,7 @@ object CpuTimeZoneDB {
}
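
The inverse direction, again sketched with java.time rather than the actual implementation:

import java.time.{Instant, LocalDateTime, ZoneId, ZoneOffset}

// Take a UTC timestamp and re-encode its wall clock in the desired zone.
val desired = ZoneId.of("Asia/Shanghai")
val utc = Instant.ofEpochSecond(1698364800L)           // 2023-10-27T00:00:00Z
val wallClock = LocalDateTime.ofInstant(utc, desired)  // 2023-10-27T08:00
val shiftedMicros = wallClock.toEpochSecond(ZoneOffset.UTC) * 1000000L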

/**
* Converts timesamp to date since 1970-01-01 at the given zone ID.
* Converts timestamp to date since 1970-01-01 at the given zone ID.
*
* @return
*/
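
The rule documented here is "which calendar day does this instant fall on in the given zone"; a small java.time sketch with illustrative values (not the collapsed implementation):

import java.time.{Instant, ZoneId}

val zone = ZoneId.of("America/Los_Angeles")
val ts = Instant.parse("2023-01-01T02:00:00Z")  // still Dec 31 in Los Angeles
val day = ts.atZone(zone).toLocalDate           // 2022-12-31
val epochDays = day.toEpochDay.toInt            // what a TIMESTAMP_DAYS column stores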
@@ -113,7 +117,7 @@ object CpuTimeZoneDB {
* @param desiredTimeZone desired time zone
* @return timestamp in the `desiredTimeZone`
*/
def fromDateToTimestap(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
def fromDateToTimestamp(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_DAYS)
val rowCount = inputVector.getRowCount.toInt
withResource(inputVector.copyToHost()) { input =>
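
fromDateToTimestamp maps each day to midnight in the desired zone; the same idea in plain java.time (a sketch, not the collapsed body):

import java.time.{LocalDate, ZoneId}

val zone = ZoneId.of("America/Los_Angeles")
val day = LocalDate.ofEpochDay(19657L)  // 2023-10-27
val micros = day.atStartOfDay(zone).toInstant.getEpochSecond * 1000000L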
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids.timezone
import java.time._
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.collection.mutable

import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
@@ -67,6 +68,7 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {

/**
* create a Spark DataFrame, schema is [(ts_long: long)]
*
* @return
*/
def createLongDF(spark: SparkSession, epochSeconds: Array[Long]): DataFrame = {
Expand All @@ -77,6 +79,7 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {

/**
* create a Spark DataFrame, schema is [(date_int: Int)]
*
* @return
*/
def createDateDF(spark: SparkSession, epochDays: Array[LocalDate]): DataFrame = {
@@ -113,8 +116,8 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
}

/**
* Get all epoch seconds every 15 minutes in [startYear, endYear].
*/
* Get all epoch seconds every 15 minutes in [startYear, endYear].
*/
def getEpochSeconds(startYear: Int, endYear: Int): Array[Long] = {
val s = Instant.parse("%04d-01-01T00:00:00z".format(startYear)).getEpochSecond
val e = Instant.parse("%04d-01-01T00:00:00z".format(endYear)).getEpochSecond
@@ -127,8 +130,8 @@
}
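
A self-contained sketch of the 15-minute sampling described above (the helper name is an assumption; the real body is collapsed in this view):

import java.time.Instant

def epochSecondsEvery15Minutes(startYear: Int, endYear: Int): Array[Long] = {
  val s = Instant.parse(f"$startYear%04d-01-01T00:00:00Z").getEpochSecond
  val e = Instant.parse(f"$endYear%04d-01-01T00:00:00Z").getEpochSecond
  (s until e by 900L).toArray  // 900 seconds = 15 minutes
}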

/**
* Get all LocalDate in [startYear, endYear].
*/
* Get all LocalDate in [startYear, endYear].
*/
def getEpochDays(startYear: Int, endYear: Int): Array[LocalDate] = {
val s = LocalDate.of(startYear, 1, 1).toEpochDay.toInt
val e = LocalDate.of(endYear, 1, 1).toEpochDay.toInt
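
Likewise for the day range (again a sketch under the same caveat; the real body is collapsed):

import java.time.LocalDate

def localDatesIn(startYear: Int, endYear: Int): Array[LocalDate] = {
  val s = LocalDate.of(startYear, 1, 1).toEpochDay
  val e = LocalDate.of(endYear, 1, 1).toEpochDay
  (s until e).map(LocalDate.ofEpochDay).toArray
}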
@@ -145,19 +148,20 @@
spark => {
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// refer to https://spark.apache.org/docs/latest/api/sql/#from_utc_timestamp
// first cast(long as timestamp), it does not timezone awared.
spark.sql(s"select from_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab").collect()
// first cast(long as timestamp); this cast is not timezone aware.
val query = s"select from_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab"
spark.sql(query).collect()
},
new SparkConf()
// from_utc_timestamp's 2nd parameter is timezone, so here use UTC is OK.
.set("spark.sql.session.timeZone", "UTC")
// by setting this, the Spark output for datetime type is java.time.Instant instead
// of java.sql.Timestamp, it's convenient to compare result via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))
// from_utc_timestamp's 2nd parameter is a timezone, so using UTC here is OK.
.set("spark.sql.session.timeZone", "UTC")
// by setting this, the Spark output for datetime type is java.time.Instant instead
// of java.sql.Timestamp, which makes it convenient to compare results via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
CpuTimeZoneDB.fromUtcTimestampToTimestamp(
TimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
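
The java8API setting above matters because Spark then returns java.time.Instant values, which compare cleanly against the kernel's raw microseconds; a hedged sketch of that per-row check (both inputs assumed already extracted from the collected rows and the host column):

import java.time.Instant

def sameMoment(expectedInstant: Instant, actualMicros: Long): Boolean = {
  val expectedMicros =
    expectedInstant.getEpochSecond * 1000000L + expectedInstant.getNano / 1000L
  expectedMicros == actualMicros
}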
@@ -173,19 +177,20 @@
spark => {
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// refer to https://spark.apache.org/docs/latest/api/sql/#to_utc_timestamp
// first cast(long as timestamp), it does not timezone awared.
spark.sql(s"select to_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab").collect()
// first cast(long as timestamp); this cast is not timezone aware.
val query = s"select to_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab"
spark.sql(query).collect()
},
new SparkConf()
// to_utc_timestamp's 2nd parameter is timezone, so here use UTC is OK.
.set("spark.sql.session.timeZone", "UTC")
// by setting this, the Spark output for datetime type is java.time.Instant instead
// of java.sql.Timestamp, it's convenient to compare result via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))
// to_utc_timestamp's 2nd parameter is a timezone, so using UTC here is OK.
.set("spark.sql.session.timeZone", "UTC")
// by setting this, the Spark output for datetime type is java.time.Instant instead
// of java.sql.Timestamp, which makes it convenient to compare results via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
CpuTimeZoneDB.fromTimestampToUtcTimestamp(
TimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
@@ -201,19 +206,19 @@
spark => {
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// cast timestamp to date
// first cast(long as timestamp), it's not timezone awared.
// first cast(long as timestamp); it's not timezone aware.
spark.sql(s"select cast(cast(ts_long as timestamp) as date) from tab").collect()
},
new SparkConf()
// cast(timestamp as date) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// by setting this, the Spark output for date type is java.time.LocalDate instead
// of java.sql.Date, it's convenient to compare result
.set("spark.sql.datetime.java8API.enabled", "true"))
// cast(timestamp as date) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// by setting this, the Spark output for date type is java.time.LocalDate instead
// of java.sql.Date, which makes it convenient to compare results
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
CpuTimeZoneDB.fromTimestampToDate(
TimeZoneDB.fromTimestampToDate(
inputCv,
ZoneId.of(zoneStr))
}
@@ -229,19 +234,19 @@
spark => {
createDateDF(spark, epochDays).createOrReplaceTempView("tab")
// cast date to timestamp
// first cast(long as timestamp), it's not timezone awared.
// cast(date as timestamp) uses the session time zone, so it is timezone aware.
spark.sql(s"select cast(date_col as Timestamp) from tab").collect()
},
new SparkConf()
// cast(date as timestamp) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// by setting this, the Spark output for date type is java.time.LocalDate instead
// of java.sql.Date, it's convenient to compare result
.set("spark.sql.datetime.java8API.enabled", "true"))
// cast(date as timestamp) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// by setting this, the Spark output for date type is java.time.LocalDate instead
// of java.sql.Date, which makes it convenient to compare results
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createDateColumnVector(epochDays)) { inputCv =>
CpuTimeZoneDB.fromDateToTimestap(
TimeZoneDB.fromDateToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}
@@ -251,23 +256,38 @@
}
}

def selectWithRepeatZones: Seq[String] = {
val mustZones = Array[String]("Asia/Shanghai", "America/Los_Angeles")
val repeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
!(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
}
// Random.shuffle returns a new list; use its result instead of discarding it
scala.util.Random.shuffle(repeatZones).slice(0, 2) ++ mustZones
}

def selectNonRepeatZones: Seq[String] = {
val mustZones = Array[String]("Asia/Shanghai", "America/Sao_Paulo")
val nonRepeatZones = ZoneId.getAvailableZoneIds.asScala.toList.filter { z =>
val rules = ZoneId.of(z).getRules
// remove this line when we support repeat rules
(rules.isFixedOffset || rules.getTransitionRules.isEmpty) && !mustZones.contains(z)
}
// Random.shuffle returns a new list; use its result instead of discarding it
scala.util.Random.shuffle(nonRepeatZones).slice(0, 2) ++ mustZones
}
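
The two selectors split zones by whether their rules contain recurring DST transitions; in plain java.time terms:

import java.time.ZoneId

val shanghaiRules = ZoneId.of("Asia/Shanghai").getRules
// Shanghai has historical transitions but no recurring DST rule today,
// so it lands on the non-repeat side.
assert(shanghaiRules.getTransitionRules.isEmpty)

val laRules = ZoneId.of("America/Los_Angeles").getRules
// Los Angeles still alternates PST/PDT via recurring transition rules.
assert(!laRules.getTransitionRules.isEmpty)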

test("test all time zones") {
assume(false,
"Testing all time zones is time consuming, so it is disabled by default")
// val zones = ZoneId.getAvailableZoneIds.asScala.toList.map(z => ZoneId.of(z)).filter { z =>
// val rules = z.getRules
// // remove this line when we support repeat rules
// rules.isFixedOffset || rules.getTransitionRules.isEmpty
// }

// Currently only test one zone
val zones = Array[String]("Asia/Shanghai")
val zones = selectNonRepeatZones
// iterate zones
for (zoneStr <- zones) {
// iterate years
val startYear = 1
val endYear = 9999
for (year <- startYear until endYear) {
for (year <- startYear until endYear by 7) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromUtcTimeStampToTimestamp(epochSeconds, zoneStr)
testFromTimestampToUTCTimestamp(epochSeconds, zoneStr)
