Add cast between date and timestamp cases
Chong Gao committed Oct 26, 2023
1 parent 397f7ca commit 3652632
Showing 2 changed files with 190 additions and 27 deletions.
@@ -32,7 +32,7 @@ object CpuTimeZoneDB {
*
* @return
*/
def convertToUTC(inputVector: ColumnVector, currentTimeZone: ZoneId): ColumnVector = {
def fromTimestampToUtcTimestamp(inputVector: ColumnVector, currentTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_MICROSECONDS)
val zoneStr = currentTimeZone.getId
val rowCount = inputVector.getRowCount.toInt
@@ -60,7 +60,7 @@ object CpuTimeZoneDB {
* @param desiredTimeZone desired time zone
* @return timestamp in the `desiredTimeZone`
*/
def convertFromUTC(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
def fromUtcTimestampToTimestamp(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_MICROSECONDS)
val zoneStr = desiredTimeZone.getId
val rowCount = inputVector.getRowCount.toInt
@@ -80,4 +80,56 @@ object CpuTimeZoneDB {
}
}
}

/**
* Converts a timestamp to a date (days since 1970-01-01) in the given zone ID.
*
* @param inputVector UTC timestamp column in microseconds
* @param currentTimeZone time zone used to compute the local date
* @return date column in days since epoch
*/
def fromTimestampToDate(inputVector: ColumnVector, currentTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_MICROSECONDS)
val rowCount = inputVector.getRowCount.toInt
withResource(inputVector.copyToHost()) { input =>
withResource(HostColumnVector.builder(DType.TIMESTAMP_DAYS, rowCount)) { builder =>
var currRow = 0
while (currRow < rowCount) {
val origin = input.getLong(currRow)
// Spark implementation
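// e.g. micros = -1L (1969-12-31T23:59:59.999999Z) is epoch day -1 in UTC,
// but in Asia/Shanghai (UTC+8) it is 1970-01-01T07:59:59.999999 local, i.e. day 0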
val dist = DateTimeUtils.microsToDays(origin, currentTimeZone)
builder.append(dist)
currRow += 1
}
withResource(builder.build()) { b =>
b.copyToDevice()
}
}
}
}

/**
* Converts a date (days since 1970-01-01) in the given zone ID to microseconds
* since 1970-01-01 00:00:00Z.
*
* @param inputVector date column in days since epoch
* @param desiredTimeZone time zone in which the dates are interpreted
* @return UTC timestamp column in microseconds
*/
def fromDateToTimestamp(inputVector: ColumnVector, desiredTimeZone: ZoneId): ColumnVector = {
assert(inputVector.getType == DType.TIMESTAMP_DAYS)
val rowCount = inputVector.getRowCount.toInt
withResource(inputVector.copyToHost()) { input =>
withResource(HostColumnVector.builder(DType.INT64, rowCount)) { builder =>
var currRow = 0
while (currRow < rowCount) {
val origin = input.getInt(currRow)
// Spark implementation
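// e.g. epoch day 0 in Asia/Shanghai: local midnight 1970-01-01T00:00
// is 1969-12-31T16:00:00Z, i.e. -28800000000L microseconds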
val dist = DateTimeUtils.daysToMicros(origin, desiredTimeZone)
builder.append(dist)
currRow += 1
}
withResource(builder.build()) { b =>
b.copyToDevice()
}
}
}
}
}
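A minimal usage sketch of the two new conversions, assuming the cudf Java factory ColumnVector.timestampMicroSecondsFromLongs and the plugin's withResource helper are in scope:

import java.time.ZoneId

import ai.rapids.cudf.ColumnVector

val zone = ZoneId.of("Asia/Shanghai")
// 0L is 1970-01-01T00:00:00Z, which is already 08:00 local on 1970-01-01 in Shanghai
withResource(ColumnVector.timestampMicroSecondsFromLongs(0L)) { micros =>
  withResource(CpuTimeZoneDB.fromTimestampToDate(micros, zone)) { days =>
    // days holds epoch day 0: the Shanghai local date 1970-01-01
    withResource(CpuTimeZoneDB.fromDateToTimestamp(days, zone)) { utcMicros =>
      // Shanghai local midnight 1970-01-01 is 1969-12-31T16:00:00Z (-28800000000L micros)
      assert(utcMicros.getRowCount == 1)
    }
  }
}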
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.timezone

import java.time.{Instant, ZoneId}
import java.time._
import java.util.concurrent.TimeUnit

import scala.collection.mutable
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.SparkQueryCompareTestSuite

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.types._

class TimeZoneSuite extends SparkQueryCompareTestSuite {
/**
@@ -48,20 +48,47 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
}
}

/**
* Create a date column vector (days since epoch) from LocalDate values.
*/
def createDateColumnVector(epochDays: Array[LocalDate]): ColumnVector = {
val rows = epochDays.length
withResource(HostColumnVector.builder(DType.TIMESTAMP_DAYS, rows)) { builder =>
var idx = 0
while (idx < rows) {
builder.append(epochDays(idx).toEpochDay.toInt)
idx += 1
}
withResource(builder.build()) { b =>
b.copyToDevice()
}
}
}

/**
* Create a Spark data frame with schema [(ts_long: long)].
* @return
*/
def createDF(spark: SparkSession, epochSeconds: Array[Long]): DataFrame = {
val data = epochSeconds.indices.map(i => Row(epochSeconds(i)))
def createLongDF(spark: SparkSession, epochSeconds: Array[Long]): DataFrame = {
val data = epochSeconds.map(l => Row(l))
val schema = StructType(Array(StructField("ts_long", LongType)))
spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
}

/**
* assert result with Spark result
* Create a Spark data frame with schema [(date_col: Date)].
* @return
*/
def assertRet(actualRet: ColumnVector, sparkRet: Seq[Row]): Unit = {
def createDateDF(spark: SparkSession, epochDays: Array[LocalDate]): DataFrame = {
val data = epochDays.map(d => Row(d))
val schema = StructType(Array(StructField("date_col", DateType)))
spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
}

/**
* Assert the timestamp results match Spark's results.
*/
def assertTimestampRet(actualRet: ColumnVector, sparkRet: Seq[Row]): Unit = {
withResource(actualRet.copyToHost()) { host =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
@@ -72,6 +99,19 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
}
}

/**
* Assert the date results match Spark's results.
*/
def assertDateRet(actualRet: ColumnVector, sparkRet: Seq[Row]): Unit = {
withResource(actualRet.copyToHost()) { host =>
assert(actualRet.getRowCount == sparkRet.length)
for (i <- 0 until actualRet.getRowCount.toInt) {
val epochDay = sparkRet(i).getLocalDate(0).toEpochDay
assert(host.getInt(i) == epochDay)
}
}
}

/**
* Get all epoch seconds every 15 minutes in [startYear, endYear].
*/
@@ -86,63 +126,128 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
epochSeconds.toArray
}

def testFromUTC(epochSeconds: Array[Long], zoneStr: String): Unit = {
/**
* Get all LocalDates in [startYear-01-01, endYear-01-01),
* e.g. getEpochDays(1970, 1971) yields the 365 dates 1970-01-01 through 1970-12-31.
*/
def getEpochDays(startYear: Int, endYear: Int): Array[LocalDate] = {
val s = LocalDate.of(startYear, 1, 1).toEpochDay.toInt
val e = LocalDate.of(endYear, 1, 1).toEpochDay.toInt
val epochDays = mutable.ArrayBuffer[LocalDate]()
for (epoch <- s until e) {
epochDays += LocalDate.ofEpochDay(epoch)
}
epochDays.toArray
}

def testFromUtcTimeStampToTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
createDF(spark, epochSeconds).createOrReplaceTempView("tab")
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// refer to https://spark.apache.org/docs/latest/api/sql/#from_utc_timestamp
// convert from UTC timestamp
// first cast long value as timestamp
// first cast(long as timestamp); this cast is not timezone aware
spark.sql(s"select from_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab").collect()
},
new SparkConf()
// use UTC time zone to create the data frame
// from_utc_timestamp's 2nd parameter is the time zone, so using UTC here is OK
.set("spark.sql.session.timeZone", "UTC")
// with this set, Spark outputs timestamps as java.time.Instant instead of
// java.sql.Timestamp, which makes it convenient to compare results via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
// convert from UTC to the specific time zone
CpuTimeZoneDB.convertFromUTC(
CpuTimeZoneDB.fromUtcTimestampToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}

withResource(actualRet) { _ =>
assertRet(actualRet, sparkRet)
assertTimestampRet(actualRet, sparkRet)
}
}
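For one concrete value, a sketch of the semantics this test pins down (given a SparkSession `spark` configured as above, session time zone UTC and java8API enabled):

// from_utc_timestamp treats its input as a UTC instant and returns the wall clock
// time in the given zone: 1970-01-01T00:00:00Z becomes 08:00 in Asia/Shanghai
val row = spark.sql(
  "select from_utc_timestamp(timestamp'1970-01-01 00:00:00', 'Asia/Shanghai')").collect()(0)
// with java8API enabled, row.getInstant(0) is 1970-01-01T08:00:00Z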

def testToUTC(epochSeconds: Array[Long], zoneStr: String): Unit = {
def testFromTimestampToUTCTimestamp(epochSeconds: Array[Long], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
createDF(spark, epochSeconds).createOrReplaceTempView("tab")
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// refer to https://spark.apache.org/docs/latest/api/sql/#to_utc_timestamp
// convert to UTC timezone
// first cast long value as timestamp
// first cast(long as timestamp); this cast is not timezone aware
spark.sql(s"select to_utc_timestamp(cast(ts_long as timestamp), '$zoneStr') from tab").collect()
},
new SparkConf()
// use UTC time zone to create the data frame
// to_utc_timestamp's 2nd parameter is the time zone, so using UTC here is OK
.set("spark.sql.session.timeZone", "UTC")
// with this set, Spark outputs timestamps as java.time.Instant instead of
// java.sql.Timestamp, which makes it convenient to compare results via java.time.Instant
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
// convert from the specific time zone to UTC
CpuTimeZoneDB.convertToUTC(
CpuTimeZoneDB.fromTimestampToUtcTimestamp(
inputCv,
ZoneId.of(zoneStr))
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

def testFromTimestampToDate(epochSeconds: Array[Long], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
createLongDF(spark, epochSeconds).createOrReplaceTempView("tab")
// cast timestamp to date
// first cast(long as timestamp); this cast is not timezone aware
spark.sql(s"select cast(cast(ts_long as timestamp) as date) from tab").collect()
},
new SparkConf()
// cast(timestamp as date) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// with this set, Spark outputs dates as java.time.LocalDate instead of
// java.sql.Date, which makes it convenient to compare results
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createColumnVector(epochSeconds)) { inputCv =>
CpuTimeZoneDB.fromTimestampToDate(
inputCv,
ZoneId.of(zoneStr))
}

withResource(actualRet) { _ =>
assertRet(actualRet, sparkRet)
assertDateRet(actualRet, sparkRet)
}
}

def testFromDateToTimestamp(epochDays: Array[LocalDate], zoneStr: String): Unit = {
// get result from Spark
val sparkRet = withCpuSparkSession(
spark => {
createDateDF(spark, epochDays).createOrReplaceTempView("tab")
// cast date to timestamp
spark.sql(s"select cast(date_col as Timestamp) from tab").collect()
},
new SparkConf()
// cast(date as timestamp) will use this timezone
.set("spark.sql.session.timeZone", zoneStr)
// with this set, Spark outputs timestamps as java.time.Instant instead of
// java.sql.Timestamp, which makes it convenient to compare results
.set("spark.sql.datetime.java8API.enabled", "true"))

// get result from TimeZoneDB
val actualRet = withResource(createDateColumnVector(epochDays)) { inputCv =>
CpuTimeZoneDB.fromDateToTimestamp(
inputCv,
ZoneId.of(zoneStr))
}

withResource(actualRet) { _ =>
assertTimestampRet(actualRet, sparkRet)
}
}

Expand All @@ -160,11 +265,17 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite {
// iterate zones
for (zoneStr <- zones) {
// iterate years
for (year <- 1 until 9999) {
val startYear = 1
val endYear = 9999
for (year <- startYear until endYear) {
val epochSeconds = getEpochSeconds(year, year + 1)
testFromUTC(epochSeconds, zoneStr)
testToUTC(epochSeconds, zoneStr)
testFromUtcTimeStampToTimestamp(epochSeconds, zoneStr)
testFromTimestampToUTCTimestamp(epochSeconds, zoneStr)
testFromTimestampToDate(epochSeconds, zoneStr)
}

val epochDays = getEpochDays(startYear, endYear)
testFromDateToTimestamp(epochDays, zoneStr)
}
}
}
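For reference, the zone-dependent casts exercised above behave like this in a plain Spark session (a sketch; Asia/Shanghai is UTC+8 at these dates):

spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
// cast(timestamp as date) uses the session time zone:
// 60000s = 1970-01-01T16:40:00Z, already 1970-01-02T00:40 local -> date 1970-01-02
spark.sql("select cast(cast(60000 as timestamp) as date)").show()
// cast(date as timestamp) is local midnight in the session time zone:
// date'1970-01-01' -> 1970-01-01T00:00 local = 1969-12-31T16:00:00Z
spark.sql("select cast(date'1970-01-01' as timestamp)").show()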
