
Commit

Fix regression supporting to_date with Spark-3.1 (NVIDIA#1742)
Signed-off-by: Niranjan Artal <nartal@nvidia.com>
nartal1 authored Feb 19, 2021
1 parent 3e8ade4 commit c50cae0
Showing 6 changed files with 223 additions and 26 deletions.
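For context on the regression itself: with Spark 3.1, to_date(c0, fmt) is planned as cast(gettimestamp(c0, fmt, ...) as date), so without a GPU rule for GetTimestamp the whole projection fell back to the CPU. A minimal repro sketch (hypothetical session and column names, not part of this commit):

    // assumes a SparkSession `spark` with the RAPIDS plugin enabled
    import org.apache.spark.sql.functions.{col, to_date}

    val df = spark.createDataFrame(Seq(Tuple1("2021-02-19"))).toDF("c0")
    // Spark 3.1 plans this as cast(gettimestamp(c0, yyyy-MM-dd, ...) as date)
    val fixed = df.withColumn("c1", to_date(col("c0"), "yyyy-MM-dd"))
    fixed.explain() // before this fix the ProjectExec stayed on the CPU under Spark 3.1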
1 change: 1 addition & 0 deletions docs/configs.md
@@ -154,6 +154,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
 <a name="sql.expression.GetArrayItem"></a>spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
 <a name="sql.expression.GetMapValue"></a>spark.rapids.sql.expression.GetMapValue| |Gets Value from a Map based on a key|true|None|
 <a name="sql.expression.GetStructField"></a>spark.rapids.sql.expression.GetStructField| |Gets the named field of the struct|true|None|
+<a name="sql.expression.GetTimestamp"></a>spark.rapids.sql.expression.GetTimestamp| |Gets timestamps from strings using given pattern.|true|None|
 <a name="sql.expression.GreaterThan"></a>spark.rapids.sql.expression.GreaterThan|`>`|> operator|true|None|
 <a name="sql.expression.GreaterThanOrEqual"></a>spark.rapids.sql.expression.GreaterThanOrEqual|`>=`|>= operator|true|None|
 <a name="sql.expression.Greatest"></a>spark.rapids.sql.expression.Greatest|`greatest`|Returns the greatest value of all parameters, skipping null values|true|None|
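Like any expression listed in configs.md, the new GetTimestamp support can be toggled per session; a hedged example using the documented config key (the standard spark-rapids config mechanism, not shown in this diff):

    // force GetTimestamp back onto the CPU, assuming a SparkSession `spark`
    spark.conf.set("spark.rapids.sql.expression.GetTimestamp", "false")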
132 changes: 132 additions & 0 deletions docs/supported_ops.md
@@ -5860,6 +5860,138 @@ Accelerator support is described below.
 <td><b>NS</b></td>
 </tr>
+<tr>
+<td rowSpan="6">GetTimestamp</td>
+<td rowSpan="6"> </td>
+<td rowSpan="6">Gets timestamps from strings using given pattern.</td>
+<td rowSpan="6">None</td>
+<td rowSpan="3">project</td>
+<td>timeExp</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td>S</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<td>format</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td><em>PS (A limited number of formats are supported; Literal value only)</em></td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<td>result</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td>S*</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<td rowSpan="3">lambda</td>
+<td>timeExp</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td><b>NS</b></td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<td>format</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td><b>NS</b></td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<td>result</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td><b>NS</b></td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
 <tr>
 <td rowSpan="6">GreaterThan</td>
 <td rowSpan="6">`>`</td>
 <td rowSpan="6">> operator</td>
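Reading the table above: the project/format row is PS with "Literal value only", so only literal format strings are accelerated, and every lambda row is NS. A short illustration of the distinction (hypothetical DataFrame `df` with string column c0, not from this commit):

    import org.apache.spark.sql.functions.{col, to_timestamp}

    // supported on GPU: the format is a string literal
    df.withColumn("ts", to_timestamp(col("c0"), "yyyy-MM-dd"))
    // not accelerated: a format that is not a literal (for example one computed
    // from another column) keeps GetTimestamp on the CPU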
5 changes: 3 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.rapids.TimeStamp
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -2329,8 +2330,8 @@ object GpuOverrides {
 
   // Shim expressions should be last to allow overrides with shim-specific versions
   val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] =
-    commonExpressions ++ GpuHiveOverrides.exprs ++ ShimLoader.getSparkShims.getExprs
-
+    commonExpressions ++ GpuHiveOverrides.exprs ++ ShimLoader.getSparkShims.getExprs ++
+      TimeStamp.getExprs
 
   def wrapScan[INPUT <: Scan](
       scan: INPUT,
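The append order matters here: Scala's Map ++ keeps the right-hand entry on key collisions, which is why shim-specific rules (and now TimeStamp.getExprs) come after commonExpressions. A minimal sketch of that semantics (illustrative keys and values only):

    val common = Map("GetTimestamp" -> "common rule")
    val shims  = Map("GetTimestamp" -> "shim rule")
    // later maps win, so the shim rule overrides the common one
    assert((common ++ shims)("GetTimestamp") == "shim rule")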
41 changes: 41 additions & 0 deletions sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/TimeStamp.scala
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.catalyst.expressions.rapids
+
+import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuExpression, GpuOverrides, TypeEnum, TypeSig}
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetTimestamp}
+import org.apache.spark.sql.rapids.{GpuGetTimestamp, UnixTimeExprMeta}
+
+object TimeStamp {
+
+  def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
+    GpuOverrides.expr[GetTimestamp](
+      "Gets timestamps from strings using given pattern.",
+      ExprChecks.binaryProjectNotLambda(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+        ("timeExp", TypeSig.STRING, TypeSig.STRING),
+        ("format", TypeSig.lit(TypeEnum.STRING)
+            .withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
+          TypeSig.STRING)),
+      (a, conf, p, r) => new UnixTimeExprMeta[GetTimestamp](a, conf, p, r) {
+        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+          GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat)
+        }
+      })
+  ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
+}
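The object sits in org.apache.spark.sql.catalyst.expressions.rapids, presumably so it can reference Spark's GetTimestamp class directly, and getExprs returns a map keyed by expression class. A lookup sketch (hypothetical logical-plan expression `e`, not from this commit):

    val maybeRule = TimeStamp.getExprs.get(e.getClass) // Some(rule) when e is a GetTimestamp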
35 changes: 29 additions & 6 deletions sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -477,11 +477,16 @@ abstract class GpuToTimestamp
       } else { // Timestamp or DateType
         lhs.getBase.asTimestampMicroseconds()
       }
-      withResource(tmp) { r =>
-        // The type we are returning is a long not an actual timestamp
-        withResource(Scalar.fromInt(downScaleFactor)) { downScaleFactor =>
-          withResource(tmp.asLongs()) { longMicroSecs =>
-            longMicroSecs.div(downScaleFactor)
+      // Return the timestamp directly when the expected dataType is TimestampType
+      if (dataType.equals(TimestampType)) {
+        tmp
+      } else {
+        withResource(tmp) { tmp =>
+          // The type we are returning is a long not an actual timestamp
+          withResource(Scalar.fromInt(downScaleFactor)) { downScaleFactor =>
+            withResource(tmp.asLongs()) { longMicroSecs =>
+              longMicroSecs.div(downScaleFactor)
+            }
           }
         }
       }
@@ -600,6 +605,24 @@ case class GpuToUnixTimestampImproved(strTs: Expression,
 
 }
 
+case class GpuGetTimestamp(
+    strTs: Expression,
+    format: Expression,
+    sparkFormat: String,
+    strf: String,
+    timeZoneId: Option[String] = None) extends GpuToTimestamp {
+
+  override def strfFormat = strf
+  override val downScaleFactor = 1
+  override def dataType: DataType = TimestampType
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  override def left: Expression = strTs
+  override def right: Expression = format
+}
+
 case class GpuFromUnixTime(
     sec: Expression,
     format: Expression,
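The downScaleFactor of 1 is the key difference from the unix-time variants: GpuGetTimestamp keeps the parsed microseconds and reports TimestampType, while the unix_timestamp-style expressions divide down to seconds (the 1000000 factor below is an assumption based on the micros-to-seconds conversion, not shown in this diff):

    val micros = 1613692800000000L      // 2021-02-19 00:00:00 UTC in microseconds
    val unixSeconds = micros / 1000000L // what the unix_timestamp family returns
    val tsMicros = micros / 1L          // what GpuGetTimestamp keeps (downScaleFactor = 1)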
35 changes: 17 additions & 18 deletions tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,36 +18,35 @@ package com.nvidia.spark.rapids
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.{col, to_date, unix_timestamp}
+import org.apache.spark.sql.functions.{col, to_date, to_timestamp, unix_timestamp}
 import org.apache.spark.sql.internal.SQLConf
 
 class ParseDateTimeSuite extends SparkQueryCompareTestSuite {
 
-  val execsAllowedNonGpu = ShimLoader.getSparkShims.getSparkShimVersion match {
-    case SparkShimVersion(3, 1, _) =>
-      // The behavior has changed in Spark 3.1.0 and `to_date` gets translated to
-      // `cast(gettimestamp(c0#20108, yyyy-MM-dd, Some(UTC)) as date)` and we do
-      // not currently support `gettimestamp` on GPU
-      // https://github.com/NVIDIA/spark-rapids/issues/1157
-      Seq("ProjectExec,Alias,Cast,GetTimestamp,Literal")
-    case _ =>
-      Seq.empty
-  }
-
   testSparkResultsAreEqual("to_date yyyy-MM-dd",
     datesAsStrings,
-    conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED"),
-    execsAllowedNonGpu = execsAllowedNonGpu) {
+    conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED")) {
     df => df.withColumn("c1", to_date(col("c0"), "yyyy-MM-dd"))
   }
 
   testSparkResultsAreEqual("to_date dd/MM/yyyy",
-      datesAsStrings,
-      conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED"),
-      execsAllowedNonGpu = execsAllowedNonGpu) {
+    datesAsStrings,
+    conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED")) {
     df => df.withColumn("c1", to_date(col("c0"), "dd/MM/yyyy"))
   }
 
+  testSparkResultsAreEqual("to_timestamp yyyy-MM-dd",
+    timestampsAsStrings,
+    conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED")) {
+    df => df.withColumn("c1", to_timestamp(col("c0"), "yyyy-MM-dd"))
+  }
+
+  testSparkResultsAreEqual("to_timestamp dd/MM/yyyy",
+    timestampsAsStrings,
+    conf = new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED")) {
+    df => df.withColumn("c1", to_timestamp(col("c0"), "dd/MM/yyyy"))
+  }
+
   testSparkResultsAreEqual("to_date default pattern",
     datesAsStrings,
     new SparkConf().set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "CORRECTED")) {
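All four to_date/to_timestamp tests pin the parser policy to CORRECTED because Spark 3.x defaults to EXCEPTION, which can fail when the legacy and new parsers disagree. The equivalent session-level setting would be (key spelling per SQLConf.LEGACY_TIME_PARSER_POLICY, stated here as an assumption):

    spark.conf.set("spark.sql.legacy.timeParserPolicy", "CORRECTED")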
