From 8c3d24fc0386c4f2e11af968dfab2ed9316a92b5 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Thu, 23 Dec 2021 13:52:48 +0800
Subject: [PATCH] Add DateType support for AST expressions
Signed-off-by: remzi <13716567376yh@gmail.com>
---
docs/supported_ops.md | 28 +++++++++----------
integration_tests/src/main/python/ast_test.py | 5 ++--
.../com/nvidia/spark/rapids/TypeChecks.scala | 2 +-
.../com/nvidia/spark/rapids/literals.scala | 1 +
4 files changed, 19 insertions(+), 17 deletions(-)
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index dc2d68ef73f..80506c0e161 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -1774,7 +1774,7 @@ are limited.
S |
S |
S |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -1795,7 +1795,7 @@ are limited.
S |
S |
S |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -2664,7 +2664,7 @@ are limited.
S |
S |
S |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -5183,7 +5183,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -5204,7 +5204,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -6028,7 +6028,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -6049,7 +6049,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -6186,7 +6186,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -6207,7 +6207,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -7509,7 +7509,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -7530,7 +7530,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -7641,7 +7641,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -7662,7 +7662,7 @@ are limited.
S |
NS |
NS |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
@@ -7825,7 +7825,7 @@ are limited.
S |
S |
S |
-NS |
+S |
PS UTC is only supported TZ for TIMESTAMP |
NS |
NS |
diff --git a/integration_tests/src/main/python/ast_test.py b/integration_tests/src/main/python/ast_test.py
index e0419b2ce8c..392e8c5d6b6 100644
--- a/integration_tests/src/main/python/ast_test.py
+++ b/integration_tests/src/main/python/ast_test.py
@@ -42,6 +42,7 @@
(float_gen, False),
(double_gen, False),
(timestamp_gen, True),
+ (date_gen, True),
(string_gen, False)
]
@@ -69,7 +70,7 @@ def assert_binary_ast(data_descr, func, conf={}):
(data_gen, is_supported) = data_descr
assert_gpu_ast(is_supported, lambda spark: func(binary_op_df(spark, data_gen)), conf=conf)
-@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen], ids=idfn)
+@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
def test_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
@@ -78,7 +79,7 @@ def test_literal(spark_tmp_path, data_gen):
assert_gpu_ast(is_supported=True,
func=lambda spark: spark.read.parquet(data_path).select(scalar))
-@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen], ids=idfn)
+@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn)
def test_null_literal(spark_tmp_path, data_gen):
# Write data to Parquet so Spark generates a plan using just the count of the data.
data_path = spark_tmp_path + '/AST_TEST_DATA'
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index e04d442d448..5687dfcc6fd 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -682,7 +682,7 @@ object TypeSig {
(commonCudfTypes + BINARY + DECIMAL_64 + NULL + ARRAY + MAP).nested() + STRUCT
/** All types that can appear in AST expressions */
- val astTypes: TypeSig = BOOLEAN + integral + fp + TIMESTAMP
+ val astTypes: TypeSig = BOOLEAN + integral + fp + TIMESTAMP + DATE
/** All AST types that work for comparisons */
val comparisonAstTypes: TypeSig = astTypes - fp
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
index 25b3bb73292..0984ec2580f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
@@ -673,6 +673,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression
case TimestampType =>
ast.Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS,
value.asInstanceOf[java.lang.Long])
+ case DateType => ast.Literal.ofTimestampDaysFromInt(value.asInstanceOf[java.lang.Integer])
case _ => throw new IllegalStateException(s"$dataType is an unsupported literal type")
}
}