Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support unix_timestamp on GPU for subset of formats #1113

Merged
merged 18 commits into from
Nov 18, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.UnaryPositive"></a>spark.rapids.sql.expression.UnaryPositive|`positive`|A numeric value with a + in front of it|true|None|
<a name="sql.expression.UnboundedFollowing$"></a>spark.rapids.sql.expression.UnboundedFollowing$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
<a name="sql.expression.UnboundedPreceding$"></a>spark.rapids.sql.expression.UnboundedPreceding$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None|
<a name="sql.expression.UnixTimestamp"></a>spark.rapids.sql.expression.UnixTimestamp|`unix_timestamp`|Returns the UNIX timestamp of current or specified time|false|This is not 100% compatible with the Spark version because Incorrectly formatted strings and bogus dates produce garbage data instead of null|
<a name="sql.expression.UnixTimestamp"></a>spark.rapids.sql.expression.UnixTimestamp|`unix_timestamp`|Returns the UNIX timestamp of current or specified time|true|None|
<a name="sql.expression.Upper"></a>spark.rapids.sql.expression.Upper|`upper`, `ucase`|String uppercase operator|false|This is not 100% compatible with the Spark version because in some cases unicode characters change byte width when changing the case. The GPU string conversion does not support these characters. For a full list of unsupported characters see https://github.com/rapidsai/cudf/issues/3132|
<a name="sql.expression.WeekDay"></a>spark.rapids.sql.expression.WeekDay|`weekday`|Returns the day of the week (0 = Monday...6=Sunday)|true|None|
<a name="sql.expression.WindowExpression"></a>spark.rapids.sql.expression.WindowExpression| |Calculates a return value for every input row of a table based on a group (or "window") of rows|true|None|
Expand Down
29 changes: 19 additions & 10 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ object GpuCast {

val INVALID_FLOAT_CAST_MSG = "At least one value is either null or is an invalid number"

val EPOCH = "epoch"
val NOW = "now"
val TODAY = "today"
val YESTERDAY = "yesterday"
val TOMORROW = "tomorrow"

/**
* Returns true iff we can cast `from` to `to` using the GPU.
*/
Expand Down Expand Up @@ -182,6 +188,17 @@ object GpuCast {
case _ => false
}
}

def calculateSpecialDates: Map[String, Int] = {
val now = DateTimeUtils.currentDate(ZoneId.of("UTC"))
Map(
EPOCH -> 0,
NOW -> now,
TODAY -> now,
YESTERDAY -> (now - 1),
TOMORROW -> (now + 1)
)
}
}

/**
Expand Down Expand Up @@ -655,16 +672,6 @@ case class GpuCast(
}
}

// special dates
val now = DateTimeUtils.currentDate(ZoneId.of("UTC"))
val specialDates: Map[String, Int] = Map(
"epoch" -> 0,
"now" -> now,
"today" -> now,
"yesterday" -> (now - 1),
"tomorrow" -> (now + 1)
)

var sanitizedInput = input.incRefCount()

// replace partial months
Expand All @@ -677,6 +684,8 @@ case class GpuCast(
cv.stringReplaceWithBackrefs("-([0-9])([ T](:?[\\r\\n]|.)*)?\\Z", "-0\\1")
}

val specialDates = calculateSpecialDates

withResource(sanitizedInput) { sanitizedInput =>

// convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
Expand Down Expand Up @@ -1069,9 +1070,9 @@ object GpuOverrides {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isImprovedTimestampOpsEnabled) {
// passing the already converted strf string for a little optimization
GpuToUnixTimestampImproved(lhs, rhs, strfFormat)
GpuToUnixTimestampImproved(lhs, rhs, sparkFormat, strfFormat)
} else {
GpuToUnixTimestamp(lhs, rhs, strfFormat)
GpuToUnixTimestamp(lhs, rhs, sparkFormat, strfFormat)
}
}
})
Expand All @@ -1083,14 +1084,12 @@ object GpuOverrides {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isImprovedTimestampOpsEnabled) {
// passing the already converted strf string for a little optimization
GpuUnixTimestampImproved(lhs, rhs, strfFormat)
GpuUnixTimestampImproved(lhs, rhs, sparkFormat, strfFormat)
} else {
GpuUnixTimestamp(lhs, rhs, strfFormat)
GpuUnixTimestamp(lhs, rhs, sparkFormat, strfFormat)
}
}
})
.incompat("Incorrectly formatted strings and bogus dates produce garbage data" +
" instead of null"),
}),
expr[Hour](
"Returns the hour component of the string/timestamp",
(a, conf, p, r) => new UnaryExprMeta[Hour](a, conf, p, r) {
Expand Down Expand Up @@ -2012,6 +2011,16 @@ object GpuOverrides {
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
commonExecs ++ ShimLoader.getSparkShims.getExecs

def getTimeParserPolicy: TimeParserPolicy = {
val policy = SQLConf.get.getConfString(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "EXCEPTION")
policy match {
case "LEGACY" => LegacyTimeParserPolicy
case "EXCEPTION" => ExceptionTimeParserPolicy
case "CORRECTED" => CorrectedTimeParserPolicy
}
}

}
/** Tag the initial plan when AQE is enabled */
case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@

package org.apache.spark.sql.rapids

import java.sql.SQLException
import java.time.ZoneId

import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, DateUtils, GpuBinaryExpression, GpuCast, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
import com.nvidia.spark.rapids.GpuOverrides.extractStringLit
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -285,6 +287,7 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi
(expr: A, conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat) extends BinaryExprMeta[A](expr, conf, parent, rule) {
var sparkFormat: String = _
var strfFormat: String = _
override def tagExprForGpu(): Unit = {
if (ZoneId.of(expr.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
Expand All @@ -293,11 +296,24 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi
// Date and Timestamp work too
if (expr.right.dataType == StringType) {
try {
val rightLit = extractStringLit(expr.right)
if (rightLit.isDefined) {
strfFormat = DateUtils.toStrf(rightLit.get)
} else {
willNotWorkOnGpu("format has to be a string literal")
extractStringLit(expr.right) match {
case Some(rightLit) =>
if (GpuOverrides.getTimeParserPolicy == LegacyTimeParserPolicy) {
willNotWorkOnGpu("legacyTimeParserPolicy is LEGACY")
revans2 marked this conversation as resolved.
Show resolved Hide resolved
} else {
val gpuSupportedFormats = Seq(
revans2 marked this conversation as resolved.
Show resolved Hide resolved
"yyyy-MM-dd",
"yyyy-MM-dd HH:mm:ss"
revans2 marked this conversation as resolved.
Show resolved Hide resolved
)
sparkFormat = rightLit
if (gpuSupportedFormats.contains(sparkFormat)) {
strfFormat = DateUtils.toStrf(sparkFormat)
} else {
willNotWorkOnGpu(s"Unsupported GpuUnixTimestamp format: $sparkFormat")
}
}
case None =>
willNotWorkOnGpu("format has to be a string literal")
}
} catch {
case x: TimestampFormatConversionException =>
Expand All @@ -307,6 +323,11 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi
}
}

sealed trait TimeParserPolicy extends Serializable
object LegacyTimeParserPolicy extends TimeParserPolicy
object ExceptionTimeParserPolicy extends TimeParserPolicy
object CorrectedTimeParserPolicy extends TimeParserPolicy

/**
* A direct conversion of Spark's ToTimestamp class which converts time to UNIX timestamp by
* first converting to microseconds and then dividing by the downScaleFactor
Expand All @@ -316,6 +337,7 @@ abstract class GpuToTimestamp

def downScaleFactor = 1000000 // MICROS IN SECOND

def sparkFormat: String
def strfFormat: String

override def inputTypes: Seq[AbstractDataType] =
Expand All @@ -326,6 +348,8 @@ abstract class GpuToTimestamp

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

private val timeParserPolicy = getTimeParserPolicy

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = {
throw new IllegalArgumentException("rhs has to be a scalar for the unixtimestamp to work")
}
Expand All @@ -338,7 +362,93 @@ abstract class GpuToTimestamp
override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = {
val tmp = if (lhs.dataType == StringType) {
// rhs is ignored we already parsed the format
lhs.getBase.asTimestampMicroseconds(strfFormat)

val DAY_MICROS = 24 * 60 * 60 * 1000000L
val specialDates = GpuCast.calculateSpecialDates
revans2 marked this conversation as resolved.
Show resolved Hide resolved
.map {
case (name, days) => (name, days * DAY_MICROS)
}

def daysScalar(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, specialDates(name))
}

def daysEqual(name: String): ColumnVector = {
lhs.getBase.equalTo(Scalar.fromString(name))
revans2 marked this conversation as resolved.
Show resolved Hide resolved
}

// the cuDF `is_timestamp` function is less restrictive than Spark's behavior for UnixTime
// and ToUnixTime and will support parsing a subset of a string so we check the length of
// the string as well which works well for fixed-length formats but if/when we want to
// support variable-length formats (such as timestamps with milliseconds) then we will need
// to use regex instead.
val isTimestamp = withResource(lhs.getBase.getCharLengths) { actualLen =>
withResource(Scalar.fromInt(sparkFormat.length)) { expectedLen =>
withResource(actualLen.equalTo(expectedLen)) { lengthOk =>
withResource(lhs.getBase.isTimestamp(strfFormat)) { isTimestamp =>
isTimestamp.and(lengthOk)
}
}
}
}

// in addition to date/timestamp strings, we also need to check for special dates and null
// values, since anything else is invalid and should throw an error or be converted to null
// depending on the policy
withResource(isTimestamp) { isTimestamp =>
withResource(daysEqual(GpuCast.EPOCH)) { isEpoch =>
withResource(daysEqual(GpuCast.NOW)) { isNow =>
withResource(daysEqual(GpuCast.TODAY)) { isToday =>
withResource(daysEqual(GpuCast.YESTERDAY)) { isYesterday =>
withResource(daysEqual(GpuCast.TOMORROW)) { isTomorrow =>
withResource(lhs.getBase.isNull) { isNull =>
val canBeConverted = isTimestamp.or(isEpoch.or(isNow.or(isToday.or(
isYesterday.or(isTomorrow.or(isNull))))))

// throw error if legacyTimeParserPolicy is EXCEPTION
if (timeParserPolicy == ExceptionTimeParserPolicy) {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
withResource(Scalar.fromBool(false)) { falseScalar =>
if (canBeConverted.hasNulls || canBeConverted.contains(falseScalar)) {
throw new RuntimeException(
revans2 marked this conversation as resolved.
Show resolved Hide resolved
s"Expression ${this.getClass.getSimpleName} failed to parse one or " +
"more values because they did not match the specified format. Set " +
"spark.sql.legacy.timeParserPolicy to CORRECTED to return null " +
"for invalid values, or to LEGACY for pre-Spark 3.0.0 behavior (" +
"LEGACY will force this expression to run on CPU though)")
}
}
}

// do the conversion
withResource(Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS)) { nullValue =>
withResource(lhs.getBase.asTimestampMicroseconds(strfFormat)) { converted =>
withResource(daysScalar(GpuCast.EPOCH)) { epoch =>
withResource(daysScalar(GpuCast.NOW)) { now =>
withResource(daysScalar(GpuCast.TODAY)) { today =>
withResource(daysScalar(GpuCast.YESTERDAY)) { yesterday =>
withResource(daysScalar(GpuCast.TOMORROW)) { tomorrow =>
isTimestamp.ifElse(converted,
isEpoch.ifElse(epoch,
isNow.ifElse(now,
isToday.ifElse(today,
isYesterday.ifElse(yesterday,
isTomorrow.ifElse(tomorrow,
nullValue))))))
}
}
}
}
}
}
}
}
}
}
}
}
}
}

} else { // Timestamp or DateType
lhs.getBase.asTimestampMicroseconds()
}
Expand Down Expand Up @@ -397,9 +507,10 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp {
}

case class GpuUnixTimestamp(strTs: Expression,
format: Expression,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestamp {
format: Expression,
sparkFormat: String,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestamp {
override def strfFormat = strf
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
Expand All @@ -411,9 +522,10 @@ case class GpuUnixTimestamp(strTs: Expression,
}

case class GpuToUnixTimestamp(strTs: Expression,
format: Expression,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestamp {
format: Expression,
sparkFormat: String,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestamp {
override def strfFormat = strf
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
Expand All @@ -425,9 +537,10 @@ case class GpuToUnixTimestamp(strTs: Expression,
}

case class GpuUnixTimestampImproved(strTs: Expression,
format: Expression,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestampImproved {
format: Expression,
sparkFormat: String,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestampImproved {
override def strfFormat = strf
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
Expand All @@ -439,9 +552,10 @@ case class GpuUnixTimestampImproved(strTs: Expression,
}

case class GpuToUnixTimestampImproved(strTs: Expression,
format: Expression,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestampImproved {
format: Expression,
sparkFormat: String,
strf: String,
timeZoneId: Option[String] = None) extends GpuToTimestampImproved {
override def strfFormat = strf
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
Expand Down
Loading