diff --git a/docs/configs.md b/docs/configs.md
index 44a89c4374f..1d6c3695805 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -198,6 +198,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Subtract|`-`|Subtraction|true|None|
spark.rapids.sql.expression.Tan|`tan`|Tangent|true|None|
spark.rapids.sql.expression.Tanh|`tanh`|Hyperbolic tangent|true|None|
+spark.rapids.sql.expression.TimeAdd| |Adds interval to timestamp|true|None|
spark.rapids.sql.expression.TimeSub| |Subtracts interval from timestamp|true|None|
spark.rapids.sql.expression.ToDegrees|`degrees`|Converts radians to degrees|true|None|
spark.rapids.sql.expression.ToRadians|`radians`|Converts degrees to radians|true|None|
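A note on the new row: `spark.rapids.sql.expression.TimeAdd` is an enable/disable switch like the others in this table. A minimal sketch of toggling it at runtime, assuming a live `SparkSession` named `spark` with the RAPIDS plugin loaded (the session name is an assumption, not from this diff):

```scala
// Force `timestamp + interval` back onto the CPU by disabling the
// GPU override for TimeAdd (hypothetical session `spark`).
spark.conf.set("spark.rapids.sql.expression.TimeAdd", "false")

// Restore the default from the table above.
spark.conf.set("spark.rapids.sql.expression.TimeAdd", "true")
```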
diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py
index 06caf94aba2..80ecc92a0c7 100644
--- a/integration_tests/src/main/python/date_time_test.py
+++ b/integration_tests/src/main/python/date_time_test.py
@@ -24,9 +24,6 @@
-# We only support literal intervals for TimeSub
+# We only support literal intervals for TimeSub and TimeAdd
vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319),
(1885, -2828), (0, 2463), (932, 2286), (0, 0)]
-@pytest.mark.xfail(
- condition=not(is_before_spark_310()),
- reason='https://issues.apache.org/jira/browse/SPARK-32640')
@pytest.mark.parametrize('data_gen', vals, ids=idfn)
def test_timesub(data_gen):
days, seconds = data_gen
@@ -35,6 +32,15 @@ def test_timesub(data_gen):
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))
+@pytest.mark.parametrize('data_gen', vals, ids=idfn)
+def test_timeadd(data_gen):
+ days, seconds = data_gen
+ assert_gpu_and_cpu_are_equal_collect(
+    # We are starting at year 0005 to make sure we don't go before year 0001
+    # or beyond year 10000 while doing TimeAdd
+ lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
+ .selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))
+
@pytest.mark.parametrize('data_gen', date_gens, ids=idfn)
def test_datediff(data_gen):
assert_gpu_and_cpu_are_equal_collect(
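For reference, the expression `test_timeadd` exercises can also be reproduced from Scala. A sketch assuming a hypothetical DataFrame `df` with a `TimestampType` column `a` (both names are assumptions, not from the test harness):

```scala
// Equivalent of the expression built by test_timeadd, using one of the
// (days, seconds) pairs from `vals`: 44 days, 1534 seconds.
val result = df.selectExpr("a + (interval 44 days 1534 seconds)")
result.collect()
```

Only literal intervals pass the GPU tagging below, so a column-valued interval here would fall back to the CPU.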
diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
index a14b4fded4e..86bdf57c832 100644
--- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
+++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala
@@ -36,12 +36,11 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
-import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub, ShuffleManagerShimBase}
+import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
import org.apache.spark.sql.rapids.shims.spark310._
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}
-import org.apache.spark.unsafe.types.CalendarInterval
class Spark310Shims extends Spark301Shims {
@@ -102,30 +101,7 @@ class Spark310Shims extends Spark301Shims {
}
override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
- val exprs310: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
- GpuOverrides.expr[TimeAdd](
- "Subtracts interval from timestamp",
- (a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
- override def tagExprForGpu(): Unit = {
- a.interval match {
- case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
- if (intvl.months != 0) {
- willNotWorkOnGpu("interval months isn't supported")
- }
- case _ =>
- willNotWorkOnGpu("only literals are supported for intervals")
- }
- if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
- willNotWorkOnGpu("Only UTC zone id is supported")
- }
- }
-
- override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
- GpuTimeSub(lhs, rhs)
- }
- )
- ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
- exprs310 ++ super.exprs301
+ super.exprs301
}
override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
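With the `TimeAdd` rule moved into the shared `GpuOverrides` map, the 3.1.0 shim can return `super.exprs301` directly instead of concatenating its own map. A self-contained sketch of the `Map ++` composition the removed code relied on (keys and values are illustrative only):

```scala
// In Scala, `left ++ right` keeps the right-hand entry on key collisions;
// the old shim had no colliding keys, so `exprs310 ++ super.exprs301`
// simply extended the inherited rule set with TimeAdd.
val inherited = Map("Tan" -> "base rule", "Tanh" -> "base rule")
val shimOnly  = Map("TimeAdd" -> "3.1.0 rule")
val merged    = shimOnly ++ inherited
assert(merged.size == 3)
```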
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 63a9febc64f..9b0ae460766 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* Base class for all ReplacementRules
@@ -907,6 +907,28 @@ object GpuOverrides {
GpuDateDiff(lhs, rhs)
}
}),
+ expr[TimeAdd](
+ "Adds interval to timestamp",
+ (a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ a.interval match {
+ case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
+ if (intvl.months != 0) {
+ willNotWorkOnGpu("interval months isn't supported")
+ }
+ case _ =>
+ willNotWorkOnGpu("only literals are supported for intervals")
+ }
+ if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
+ willNotWorkOnGpu("Only UTC zone id is supported")
+ }
+ }
+
+ override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+ GpuTimeAdd(lhs, rhs)
+ }
+ }
+ ),
expr[ToUnixTimestamp](
"Returns the UNIX timestamp of the given time",
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r){
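The tagging above only admits literal intervals whose `months` field is zero, since month arithmetic is calendar-dependent and not handled by this GPU path. A standalone sketch of the same guard, assuming Spark 3.x's three-argument `CalendarInterval` constructor:

```scala
import org.apache.spark.unsafe.types.CalendarInterval

// Mirrors the months check in tagExprForGpu: zero months stays on the GPU.
def gpuSupportedInterval(intvl: CalendarInterval): Boolean =
  intvl.months == 0

assert(gpuSupportedInterval(new CalendarInterval(0, 2, 5000000L)))  // 2 days 5 seconds
assert(!gpuSupportedInterval(new CalendarInterval(1, 0, 0L)))       // 1 month: CPU fallback
```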
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 22e82953c44..532942bae12 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -117,11 +117,12 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression {
GpuColumnVector.from(input.getBase.year())
}
-case class GpuTimeSub(
+abstract class GpuTimeMath(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
- extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+ extends BinaryExpression with GpuExpression with TimeZoneAwareExpression with ExpectsInputTypes
+ with Serializable {
def this(start: Expression, interval: Expression) = this(start, interval, None)
@@ -136,10 +137,6 @@ case class GpuTimeSub(
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
- override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
- copy(timeZoneId = Option(timeZoneId))
- }
-
override def columnarEval(batch: ColumnarBatch): Any = {
var lhs: Any = null
var rhs: Any = null
@@ -156,7 +153,7 @@ case class GpuTimeSub(
if (usToSub != 0) {
withResource(Scalar.fromLong(usToSub)) { us_s =>
withResource(l.getBase.castTo(DType.INT64)) { us =>
- withResource(us.sub(us_s)) {longResult =>
+ withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS))
}
}
@@ -177,6 +174,36 @@ case class GpuTimeSub(
}
}
}
+
+ def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector
+}
+
+case class GpuTimeAdd(start: Expression,
+ interval: Expression,
+ timeZoneId: Option[String] = None)
+ extends GpuTimeMath(start, interval, timeZoneId) {
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
+ copy(timeZoneId = Option(timeZoneId))
+ }
+
+ override def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
+ us.add(us_s)
+ }
+}
+
+case class GpuTimeSub(start: Expression,
+ interval: Expression,
+ timeZoneId: Option[String] = None)
+ extends GpuTimeMath(start, interval, timeZoneId) {
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
+ copy(timeZoneId = Option(timeZoneId))
+ }
+
+  override def intervalMath(us_s: Scalar, us: ColumnVector): ColumnVector = {
+ us.sub(us_s)
+ }
}
case class GpuDateDiff(endDate: Expression, startDate: Expression)
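The refactor is a template-method split: `GpuTimeMath` keeps the shared scalar/column plumbing and each subclass supplies only the direction of the arithmetic via `intervalMath`. A simplified, self-contained sketch of the shape, with the cudf types reduced to `Long` for illustration:

```scala
// Simplified model of the GpuTimeMath hierarchy after this diff.
abstract class TimeMathSketch {
  // Hook implemented by subclasses; stands in for intervalMath.
  def intervalMath(intervalUs: Long, timestampUs: Long): Long

  // Shared driver; stands in for the common columnarEval body.
  def eval(timestampUs: Long, intervalUs: Long): Long =
    if (intervalUs != 0) intervalMath(intervalUs, timestampUs) else timestampUs
}

class TimeAddSketch extends TimeMathSketch {
  override def intervalMath(intervalUs: Long, timestampUs: Long): Long =
    timestampUs + intervalUs
}

class TimeSubSketch extends TimeMathSketch {
  override def intervalMath(intervalUs: Long, timestampUs: Long): Long =
    timestampUs - intervalUs
}
```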