diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 8c26e9e97c1..5ecc0f73153 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -2426,7 +2426,7 @@ Accelerator support is described below.
 S
 S*
 S
-NS
+S*
 S
 NS
 NS
@@ -2447,7 +2447,7 @@ Accelerator support is described below.
 S
 S*
 S
-NS
+S*
 S
 NS
 NS
diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index 66126c6279f..1ec590e1f11 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -20,6 +20,8 @@
 from pyspark.sql.types import *
 import pyspark.sql.functions as f
 
+all_gens = all_gen + [NullGen()]
+
 @pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
 def test_if_else(data_gen):
     (s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -34,7 +36,7 @@ def test_if_else(data_gen):
             'IF(a, b, {})'.format(null_lit),
             'IF(a, {}, c)'.format(null_lit)))
 
-@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
 def test_case_when(data_gen):
     num_cmps = 20
     s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 655ede78500..ec0dd982019 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -684,6 +684,8 @@ def to_cast_string(spark_type):
         return 'TIMESTAMP'
     elif isinstance(spark_type, StringType):
         return 'STRING'
+    elif isinstance(spark_type, DecimalType):
+        return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale)
     else:
         raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type))
 
@@ -796,3 +798,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
                MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
 
 allow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'}
+
+all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
+           FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(),
+           decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,
+           decimal_gen_64bit]
diff --git a/integration_tests/src/main/python/generate_expr_test.py b/integration_tests/src/main/python/generate_expr_test.py
index d83f16937cb..2abc75688a6 100644
--- a/integration_tests/src/main/python/generate_expr_test.py
+++ b/integration_tests/src/main/python/generate_expr_test.py
@@ -28,11 +28,6 @@ def four_op_df(spark, gen, length=2048, seed=0):
             ('c', gen),
             ('d', gen)], nullable=False), length=length, seed=seed)
 
-all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
-           FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(),
-           decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,
-           decimal_gen_64bit]
-
 #sort locally because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
 @ignore_order(local=True)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 2fd76aef3b9..0cef1ddeb21 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -601,7 +601,7 @@ case class ExprChecksImpl(contexts: Map[ExpressionContext, ContextChecks])
  * This is specific to CaseWhen, because it does not follow the typical parameter convention.
  */
 object CaseWhenCheck extends ExprChecks {
-  val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL
+  val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
   val sparkSig: TypeSig = TypeSig.all
 
   override def tag(meta: RapidsMeta[_, _, _]): Unit = {
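
Note: the new DecimalType branch in to_cast_string above is what lets the decimal
generators in all_gen be rendered as SQL cast targets. A minimal standalone sketch
of that formatting behavior follows; decimal_cast_string is a hypothetical name used
only for illustration here (the real logic lives inside to_cast_string in data_gen.py),
and DecimalType is pyspark's own type.

    from pyspark.sql.types import DecimalType

    # Hypothetical standalone mirror of the added branch: precision and scale
    # are interpolated into a SQL DECIMAL(p, s) type string, usable in
    # expressions such as CAST(x AS DECIMAL(18, 3)).
    def decimal_cast_string(spark_type):
        return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale)

    assert decimal_cast_string(DecimalType(18, 3)) == 'DECIMAL(18, 3)'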