From 8c3d24fc0386c4f2e11af968dfab2ed9316a92b5 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Thu, 23 Dec 2021 13:52:48 +0800 Subject: [PATCH] Add DateType support for AST expressions Signed-off-by: remzi <13716567376yh@gmail.com> --- docs/supported_ops.md | 28 +++++++++---------- integration_tests/src/main/python/ast_test.py | 5 ++-- .../com/nvidia/spark/rapids/TypeChecks.scala | 2 +- .../com/nvidia/spark/rapids/literals.scala | 1 + 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index dc2d68ef73f..80506c0e161 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -1774,7 +1774,7 @@ are limited. S S S -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -1795,7 +1795,7 @@ are limited. S S S -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -2664,7 +2664,7 @@ are limited. S S S -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -5183,7 +5183,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -5204,7 +5204,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -6028,7 +6028,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -6049,7 +6049,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -6186,7 +6186,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -6207,7 +6207,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -7509,7 +7509,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -7530,7 +7530,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -7641,7 +7641,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -7662,7 +7662,7 @@ are limited. S NS NS -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS @@ -7825,7 +7825,7 @@ are limited. S S S -NS +S PS
UTC is only supported TZ for TIMESTAMP
NS NS diff --git a/integration_tests/src/main/python/ast_test.py b/integration_tests/src/main/python/ast_test.py index e0419b2ce8c..392e8c5d6b6 100644 --- a/integration_tests/src/main/python/ast_test.py +++ b/integration_tests/src/main/python/ast_test.py @@ -42,6 +42,7 @@ (float_gen, False), (double_gen, False), (timestamp_gen, True), + (date_gen, True), (string_gen, False) ] @@ -69,7 +70,7 @@ def assert_binary_ast(data_descr, func, conf={}): (data_gen, is_supported) = data_descr assert_gpu_ast(is_supported, lambda spark: func(binary_op_df(spark, data_gen)), conf=conf) -@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn) def test_literal(spark_tmp_path, data_gen): # Write data to Parquet so Spark generates a plan using just the count of the data. data_path = spark_tmp_path + '/AST_TEST_DATA' @@ -78,7 +79,7 @@ def test_literal(spark_tmp_path, data_gen): assert_gpu_ast(is_supported=True, func=lambda spark: spark.read.parquet(data_path).select(scalar)) -@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen, date_gen], ids=idfn) def test_null_literal(spark_tmp_path, data_gen): # Write data to Parquet so Spark generates a plan using just the count of the data. data_path = spark_tmp_path + '/AST_TEST_DATA' diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index e04d442d448..5687dfcc6fd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -682,7 +682,7 @@ object TypeSig { (commonCudfTypes + BINARY + DECIMAL_64 + NULL + ARRAY + MAP).nested() + STRUCT /** All types that can appear in AST expressions */ - val astTypes: TypeSig = BOOLEAN + integral + fp + TIMESTAMP + val astTypes: TypeSig = BOOLEAN + integral + fp + TIMESTAMP + DATE /** All AST types that work for comparisons */ val comparisonAstTypes: TypeSig = astTypes - fp diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 25b3bb73292..0984ec2580f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -673,6 +673,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression case TimestampType => ast.Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS, value.asInstanceOf[java.lang.Long]) + case DateType => ast.Literal.ofTimestampDaysFromInt(value.asInstanceOf[java.lang.Integer]) case _ => throw new IllegalStateException(s"$dataType is an unsupported literal type") } }