From 76876bc45bc4cb03f539bc774860c7274b85ce3c Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 5 Jan 2021 10:23:31 -0800 Subject: [PATCH 1/2] Support DecimalType for CaseWhen Signed-off-by: Raza Jafri --- docs/supported_ops.md | 4 ++-- integration_tests/src/main/python/conditionals_test.py | 4 +++- integration_tests/src/main/python/data_gen.py | 7 +++++++ integration_tests/src/main/python/generate_expr_test.py | 5 ----- .../main/scala/com/nvidia/spark/rapids/TypeChecks.scala | 2 +- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 8c26e9e97c1..5ecc0f73153 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -2426,7 +2426,7 @@ Accelerator support is described below. S S* S -NS +S* S NS NS @@ -2447,7 +2447,7 @@ Accelerator support is described below. S S* S -NS +S* S NS NS diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index 66126c6279f..1ec590e1f11 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -20,6 +20,8 @@ from pyspark.sql.types import * import pyspark.sql.functions as f +all_gens = all_gen + [NullGen()] + @pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn) def test_if_else(data_gen): (s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen)) @@ -34,7 +36,7 @@ def test_if_else(data_gen): 'IF(a, b, {})'.format(null_lit), 'IF(a, {}, c)'.format(null_lit))) -@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', all_gens, ids=idfn) def test_case_when(data_gen): num_cmps = 20 s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen)) diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index 655ede78500..3c3c64d1a55 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -684,6 +684,8 @@ def to_cast_string(spark_type): return 'TIMESTAMP' elif isinstance(spark_type, StringType): return 'STRING' + elif isinstance(spark_type, DecimalType): + return 'DECIMAL' else: raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type)) @@ -796,3 +798,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False): MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)] allow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'} + +all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(), + FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(), + decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision, + decimal_gen_64bit] diff --git a/integration_tests/src/main/python/generate_expr_test.py b/integration_tests/src/main/python/generate_expr_test.py index d83f16937cb..2abc75688a6 100644 --- a/integration_tests/src/main/python/generate_expr_test.py +++ b/integration_tests/src/main/python/generate_expr_test.py @@ -28,11 +28,6 @@ def four_op_df(spark, gen, length=2048, seed=0): ('c', gen), ('d', gen)], nullable=False), length=length, seed=seed) -all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(), - FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(), - decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision, - decimal_gen_64bit] - #sort locally because of https://github.com/NVIDIA/spark-rapids/issues/84 # After 3.1.0 is the min spark version we can drop this @ignore_order(local=True) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 2fd76aef3b9..0cef1ddeb21 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -601,7 +601,7 @@ case class ExprChecksImpl(contexts: Map[ExpressionContext, ContextChecks]) * This is specific to CaseWhen, because it does not follow the typical parameter convention. */ object CaseWhenCheck extends ExprChecks { - val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL val sparkSig: TypeSig = TypeSig.all override def tag(meta: RapidsMeta[_, _, _]): Unit = { From 550d5a07737e1dd42513e5529130bd5bdccb6534 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 5 Jan 2021 13:05:31 -0800 Subject: [PATCH 2/2] addressed review comments Signed-off-by: Raza Jafri --- integration_tests/src/main/python/data_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index 3c3c64d1a55..ec0dd982019 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -685,7 +685,7 @@ def to_cast_string(spark_type): elif isinstance(spark_type, StringType): return 'STRING' elif isinstance(spark_type, DecimalType): - return 'DECIMAL' + return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale) else: raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type))