diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 571d51acfb6..150d9e56f39 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -86,8 +86,8 @@ There is a build script `build/buildall` that automates the local build process. `./build/buildall --help` for up-to-date use information. By default, it builds everything that is needed to create a distribution jar for all released (noSnapshots) Spark versions except for Databricks. Other profiles that you can pass using `--profile=` include -- `snapshots` -- `minimumFeatureVersionMix` that currently includes 321cdh, 312, 320 is recommended for catching incompatibilities already in the local development cycle +- `snapshots`, which includes all released (noSnapshots) and snapshot Spark versions except for Databricks +- `minimumFeatureVersionMix`, which currently includes 321cdh, 312, 320, and 330, is recommended for catching incompatibilities already during the local development cycle For initial quick iterations we can use `--profile=` to build a single-shim version. e.g., `--profile=311` for Spark 3.1.1. diff --git a/build/buildall b/build/buildall index d0b90a4ff9b..95cd0b59a72 100755 --- a/build/buildall +++ b/build/buildall @@ -159,6 +159,7 @@ case $DIST_PROFILE in 320 321 322 + 330 331 ) ;; @@ -171,6 +172,7 @@ case $DIST_PROFILE in 313 320 321 + 322 330 ) ;; diff --git a/dist/pom.xml b/dist/pom.xml index d8e364d76db..597c4c82d5b 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -47,11 +47,11 @@ 320, 321, 321cdh, + 322, 330 314, - 322, 331 diff --git a/docs/additional-functionality/rapids-shuffle.md b/docs/additional-functionality/rapids-shuffle.md index d667f251ef8..1c06d4c991d 100644 --- a/docs/additional-functionality/rapids-shuffle.md +++ b/docs/additional-functionality/rapids-shuffle.md @@ -286,6 +286,7 @@ In this section, we are using a docker container built using the sample dockerfi | 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager | | 3.2.1 | com.nvidia.spark.rapids.spark321.RapidsShuffleManager | | 3.2.1 CDH | com.nvidia.spark.rapids.spark321cdh.RapidsShuffleManager | + | 3.2.2 | com.nvidia.spark.rapids.spark322.RapidsShuffleManager | | 3.3.0 | com.nvidia.spark.rapids.spark330.RapidsShuffleManager | | Databricks 9.1 | com.nvidia.spark.rapids.spark312db.RapidsShuffleManager | | Databricks 10.4 | com.nvidia.spark.rapids.spark321db.RapidsShuffleManager | diff --git a/docs/configs.md b/docs/configs.md index 8276c1a3ef4..e9f202beac7 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -269,6 +269,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.NaNvl|`nanvl`|Evaluates to `left` iff left is not NaN, `right` otherwise|true|None| spark.rapids.sql.expression.NamedLambdaVariable| |A parameter to a higher order SQL function|true|None| spark.rapids.sql.expression.Not|`!`, `not`|Boolean not operator|true|None| +spark.rapids.sql.expression.NthValue|`nth_value`|nth window operator|true|None| spark.rapids.sql.expression.OctetLength|`octet_length`|The byte length of string data|true|None| spark.rapids.sql.expression.Or|`or`|Logical OR|true|None| spark.rapids.sql.expression.PercentRank|`percent_rank`|Window function that returns the percent rank value within the aggregation window|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index f3b20edf50c..645162b647d 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -9445,6 +9445,74 @@ are limited. +NthValue +`nth_value` +nth window operator +None +window +input +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S +NS +NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+NS + + +offset + + + +S + + + + + + + + + + + + + + + + +result +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S +NS +NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+NS + + OctetLength `octet_length` The byte length of string data @@ -9492,6 +9560,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Or `or` Logical OR @@ -9624,32 +9718,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PercentRank `percent_rank` Window function that returns the percent rank value within the aggregation window @@ -9944,6 +10012,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PreciseTimestampConversion Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing @@ -9991,32 +10085,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PromotePrecision PromotePrecision before arithmetic operations between DecimalType data @@ -15729,44 +15797,44 @@ are limited. window input +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS result +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS @@ -15888,44 +15956,44 @@ are limited. window input +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS result +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +S NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index e3042433474..33c2d8ca31b 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -16,7 +16,7 @@ from asserts import assert_gpu_and_cpu_are_equal_collect from data_gen import * -from spark_session import is_before_spark_320 +from spark_session import is_before_spark_320, is_jvm_charset_utf8 from pyspark.sql.types import * import pyspark.sql.functions as f @@ -199,6 +199,7 @@ def test_conditional_with_side_effects_col_scalar(data_gen): conf = ansi_enabled_conf) @pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,20}')], ids=idfn) +@pytest.mark.skipif(not is_jvm_charset_utf8(), reason="regular expressions require UTF-8") def test_conditional_with_side_effects_cast(data_gen): test_conf=copy_and_update( ansi_enabled_conf, {'spark.rapids.sql.regexp.enabled': True}) @@ -208,6 +209,7 @@ def test_conditional_with_side_effects_cast(data_gen): conf = test_conf) @pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,9}')], ids=idfn) +@pytest.mark.skipif(not is_jvm_charset_utf8(), reason="regular expressions require UTF-8") def test_conditional_with_side_effects_case_when(data_gen): test_conf=copy_and_update( ansi_enabled_conf, {'spark.rapids.sql.regexp.enabled': True}) diff --git a/integration_tests/src/main/python/qa_nightly_select_test.py b/integration_tests/src/main/python/qa_nightly_select_test.py index f8515a09bb9..286eb0bfc07 100644 --- a/integration_tests/src/main/python/qa_nightly_select_test.py +++ b/integration_tests/src/main/python/qa_nightly_select_test.py @@ -22,7 +22,7 @@ from asserts import assert_gpu_and_cpu_are_equal_collect from qa_nightly_sql import * import pytest -from spark_session import with_cpu_session +from spark_session import with_cpu_session, is_jvm_charset_utf8 from marks import approximate_float, ignore_order, incompat, qarun from data_gen import copy_and_update @@ -218,3 +218,16 @@ def test_select_float_order_local(sql_query_line, pytestconfig): with_cpu_session(num_stringDf) assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.sql(sql_query), conf=_qa_conf) + +@approximate_float(abs=1e-6) +@incompat +@ignore_order(local=True) +@qarun +@pytest.mark.parametrize('sql_query_line', SELECT_REGEXP_SQL, ids=idfn) +@pytest.mark.skipif(not is_jvm_charset_utf8(), reason="Regular expressions require UTF-8") +def test_select_regexp(sql_query_line, pytestconfig): + sql_query = sql_query_line[0] + if sql_query: + print(sql_query) + with_cpu_session(num_stringDf) + assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.sql(sql_query), conf=_qa_conf) diff --git a/integration_tests/src/main/python/qa_nightly_sql.py b/integration_tests/src/main/python/qa_nightly_sql.py index 4d750bf9413..c5432091c78 100644 --- a/integration_tests/src/main/python/qa_nightly_sql.py +++ b/integration_tests/src/main/python/qa_nightly_sql.py @@ -194,7 +194,6 @@ ("SELECT * FROM test_table WHERE strF LIKE 'Y%'", "* WHERE strF LIKE 'Y%'"), ("SELECT * FROM test_table WHERE strF LIKE '%an' ", "* WHERE strF LIKE '%an'"), ("SELECT REPLACE(strF, 'Yuan', 'Eric') FROM test_table", "REPLACE(strF, 'Yuan', 'Eric')"), -("SELECT REGEXP_REPLACE(strF, 'Yu', 'Eric') FROM test_table", "REGEXP_REPLACE(strF, 'Yu', 'Eric')"), #("SELECT REGEXP_REPLACE(strF, 'Y*', 'Eric') FROM test_table", "REGEXP_REPLACE(strF, 'Y*', 'Eric')"), ("SELECT CONCAT(strF, strF) FROM test_table", "CONCAT(strF, strF)"), # (" DATETIME", 
"DATETIME"), @@ -816,3 +815,7 @@ ("SELECT IFNULL(floatF, 0) as if_null FROM test_table", "IFNULL(floatF, 0)"), ("SELECT floatF, COALESCE(floatF, 0) FROM test_table", "floatF, COALESCE(floatF,0)"), ] + +SELECT_REGEXP_SQL=[ +("SELECT REGEXP_REPLACE(strF, 'Yu', 'Eric') FROM test_table", "REGEXP_REPLACE(strF, 'Yu', 'Eric')"), +] \ No newline at end of file diff --git a/integration_tests/src/main/python/regexp_no_unicode_test.py b/integration_tests/src/main/python/regexp_no_unicode_test.py index 230c06f4d3f..960a9349bb0 100644 --- a/integration_tests/src/main/python/regexp_no_unicode_test.py +++ b/integration_tests/src/main/python/regexp_no_unicode_test.py @@ -20,7 +20,9 @@ from marks import * from pyspark.sql.types import * -if locale.nl_langinfo(locale.CODESET) == 'UTF-8': +from spark_session import is_jvm_charset_utf8 + +if is_jvm_charset_utf8(): pytestmark = pytest.mark.skip(reason=str("Current locale uses UTF-8, fallback will not occur")) _regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' } diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index d8988840b32..b770c5d974e 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -21,9 +21,9 @@ from data_gen import * from marks import * from pyspark.sql.types import * -from spark_session import is_before_spark_320 +from spark_session import is_before_spark_320, is_jvm_charset_utf8 -if locale.nl_langinfo(locale.CODESET) != 'UTF-8': +if not is_jvm_charset_utf8(): pytestmark = [pytest.mark.regexp, pytest.mark.skip(reason=str("Current locale doesn't support UTF-8, regexp support is disabled"))] else: pytestmark = pytest.mark.regexp @@ -761,4 +761,4 @@ def test_regexp_split_unicode_support(): 'split(a, "[o]{1,2}", -1)', 'split(a, "[bf]", -1)', 'split(a, "[o]", -2)'), - conf=_regexp_conf) \ No newline at end of file + conf=_regexp_conf) diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py index d6feb691222..c497ff91f13 100644 --- a/integration_tests/src/main/python/spark_session.py +++ b/integration_tests/src/main/python/spark_session.py @@ -189,3 +189,10 @@ def get_java_major_version(): elif dash_pos != -1: ver = ver[0:dash_pos] return int(ver) + +def get_jvm_charset(): + sc = _spark.sparkContext + return str(sc._jvm.java.nio.charset.Charset.defaultCharset()) + +def is_jvm_charset_utf8(): + return get_jvm_charset() == 'UTF-8' diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index df8ab18bde8..313e4f57b66 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -275,7 +275,8 @@ def test_window_aggs_for_ranges_numeric_long_overflow(data_gen): ' range between 9223372036854775807 preceding and 9223372036854775807 following) as sum_c_asc, ' 'from window_agg_table') -# In a distributed setup the order of the partitions returend might be different, so we must ignore the order + +# In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches @@ -1066,3 +1067,45 @@ def test_unbounded_to_unbounded_window(): 
assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.range(1024).selectExpr( 'SUM(id) OVER ()', 'COUNT(1) OVER ()')) + + +_nested_gens = array_gens_sample + struct_gens_sample + map_gens_sample +exprs_for_nth_first_last = \ + 'first(a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'first(a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'first(a) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'first(a) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'last (a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'last (a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'last (a) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'last (a) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'NTH_VALUE(a, 1) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'NTH_VALUE(a, 2) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'NTH_VALUE(a, 3) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), ' \ + 'NTH_VALUE(a, 3) OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '\ + 'first(a, true) OVER (PARTITION BY b ORDER BY c), ' \ + 'last (a, true) OVER (PARTITION BY b ORDER BY c), ' \ + 'last (a, true) OVER (PARTITION BY b ORDER BY c) ' +exprs_for_nth_first_last_ignore_nulls = \ + 'NTH_VALUE(a, 1) IGNORE NULLS OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'first(a) IGNORE NULLS OVER (PARTITION BY b ORDER BY c ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), ' \ + 'last(a) IGNORE NULLS OVER (PARTITION BY b ORDER BY c) ' + +@pytest.mark.parametrize('data_gen', all_basic_gens_no_null + decimal_gens + _nested_gens, ids=idfn) +def test_window_first_last_nth(data_gen): + assert_gpu_and_cpu_are_equal_sql( + # Coalesce is to make sure that first and last, which are non-deterministic become deterministic + lambda spark: three_col_df(spark, data_gen, string_gen, int_gen, num_slices=1).coalesce(1), + "window_agg_table", + 'SELECT a, b, c, ' + exprs_for_nth_first_last + + 'FROM window_agg_table') + +@pytest.mark.skipif(is_before_spark_320(), reason='IGNORE NULLS clause is not supported for FIRST(), LAST() and NTH_VALUE in Spark 3.1.x') +@pytest.mark.parametrize('data_gen', all_basic_gens_no_null + decimal_gens + _nested_gens, ids=idfn) +def test_window_first_last_nth_ignore_nulls(data_gen): + assert_gpu_and_cpu_are_equal_sql( + # Coalesce is to make sure that first and last, which are non-deterministic become deterministic + lambda spark: three_col_df(spark, data_gen, string_gen, int_gen, num_slices=1).coalesce(1), + "window_agg_table", + 'SELECT a, b, c, ' + exprs_for_nth_first_last_ignore_nulls + + 'FROM window_agg_table') \ No newline at end of file diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh index 40ae460e1b9..775f5385263 100755 --- a/jenkins/spark-premerge-build.sh +++ b/jenkins/spark-premerge-build.sh @@ -50,7 +50,7 @@ mvn_verify() { # enable UTF-8 for regular expression tests env -u SPARK_HOME LC_ALL="en_US.UTF-8" mvn $MVN_URM_MIRROR -Dbuildver=320 test -Drat.skip=true -Dmaven.javadoc.skip=true -Dskip 
-Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -Dpytest.TEST_TAGS='' -pl '!tools' -DwildcardSuites=com.nvidia.spark.rapids.ConditionalsSuite,com.nvidia.spark.rapids.RegularExpressionSuite,com.nvidia.spark.rapids.RegularExpressionTranspilerSuite env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=321 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am - [[ $BUILD_MAINTENANCE_VERSION_SNAPSHOTS == "true" ]] && env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=322 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am + env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=322 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=330 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am [[ $BUILD_MAINTENANCE_VERSION_SNAPSHOTS == "true" ]] && env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=331 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am [[ $BUILD_FEATURE_VERSION_SNAPSHOTS == "true" ]] && env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=340 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am diff --git a/pom.xml b/pom.xml index 75be178baa8..0494b723bdf 100644 --- a/pom.xml +++ b/pom.xml @@ -1008,7 +1008,7 @@ 3.2.1 3.2.1.3.2.7171000.0-3 3.2.1-databricks - 3.2.2-SNAPSHOT + 3.2.2 3.3.0 3.3.1-SNAPSHOT 3.4.0-SNAPSHOT diff --git a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala index 66d962b0fc6..1df0b81e6b7 100644 --- a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala @@ -20,7 +20,7 @@ import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 2, 2) - val VERSIONNAMES = Seq(s"$VERSION", s"$VERSION-SNAPSHOT") + val VERSIONNAMES = Seq(s"$VERSION") } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index f4aeaed360c..a13e3f69cd2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2346,18 +2346,36 @@ object GpuOverrides extends Logging { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuSum(childExprs.head, a.dataType) }), + expr[NthValue]( + "nth window operator", + ExprChecks.windowOnly( + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + Seq(ParamCheck("input", + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes 
+ TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + ParamCheck("offset", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))) + ), + (a, conf, p, r) => new AggExprMeta[NthValue](a, conf, p, r) { + override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = + GpuNthValue(childExprs.head, a.offset, a.ignoreNulls) + + // nth does not overflow, so it doesn't need the ANSI check + override val needsAnsiCheck: Boolean = false + }), expr[First]( - "first aggregate operator", { - ExprChecks.aggNotWindow( + "first aggregate operator", + ExprChecks.fullAgg( + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + Seq(ParamCheck("input", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), - TypeSig.all, - Seq(ParamCheck("input", - (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + - TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), - TypeSig.all)) - ) - }, + TypeSig.all)) + ), (a, conf, p, r) => new AggExprMeta[First](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuFirst(childExprs.head, a.ignoreNulls) @@ -2366,17 +2384,16 @@ object GpuOverrides extends Logging { override val needsAnsiCheck: Boolean = false }), expr[Last]( - "last aggregate operator", { - ExprChecks.aggNotWindow( + "last aggregate operator", + ExprChecks.fullAgg( + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + Seq(ParamCheck("input", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), - TypeSig.all, - Seq(ParamCheck("input", - (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + - TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), - TypeSig.all)) - ) - }, + TypeSig.all)) + ), (a, conf, p, r) => new AggExprMeta[Last](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = GpuLast(childExprs.head, a.ignoreNulls) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index ed8feccda3a..01ebd6c71e5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -734,34 +734,48 @@ class CudfRegexTranspiler(mode: RegexMode) { case _ => false } } - - private def isSupportedRepetitionBase(e: RegexAST): (Boolean, Option[RegexAST]) = { + + private def getUnsupportedRepetitionBaseOption(e: RegexAST): Option[RegexAST] = { e match { case RegexEscaped(ch) => ch match { - case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => (true, None) - case _ => (false, Some(e)) + case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => None + case _ => Some(e) } - case RegexChar(a) if "$^".contains(a) => // example: "$*" - (false, Some(e)) + Some(e) case RegexRepetition(_, _) => // example: "a*+" - (false, Some(e)) + Some(e) case RegexSequence(parts) => - parts.foreach { part => isSupportedRepetitionBase(part) match { - case (false, unsupportedPart) => return (false, unsupportedPart) - case _ => + parts.foreach { part => getUnsupportedRepetitionBaseOption(part) match { + case r @ Some(_) => return r + case None => } } - (true, None) - + None + case RegexGroup(_, term) => - isSupportedRepetitionBase(term) + 
getUnsupportedRepetitionBaseOption(term) - case _ => (true, None) + case _ => None + } + } + + private def getUnsupportedRepetitionBase(e: RegexAST): RegexAST = { + getUnsupportedRepetitionBaseOption(e) match { + case None => throw new NoSuchElementException( + s"Expected repetition base ${e.toRegexString} to be unsupported but was actually supported") + case Some(unsupportedTerm) => unsupportedTerm + } + } + + private def isSupportedRepetitionBase(e: RegexAST): Boolean = { + getUnsupportedRepetitionBaseOption(e) match { + case None => true + case _ => false + } + } } @@ -1315,7 +1329,7 @@ class CudfRegexTranspiler(mode: RegexMode) { quantifier.position) case (RegexGroup(capture, term), SimpleQuantifier(ch)) - if "+*".contains(ch) && !(isSupportedRepetitionBase(term)._1) => + if "+*".contains(ch) && !isSupportedRepetitionBase(term) => (term, ch) match { // \Z is not supported in groups case (RegexEscaped('A'), '+') | @@ -1327,13 +1341,13 @@ class CudfRegexTranspiler(mode: RegexMode) { // NOTE: (\A)* can be transpiled to (\A)? // however, (\A)? is not supported in libcudf yet case _ => - val unsupportedTerm = isSupportedRepetitionBase(term)._2.get + val unsupportedTerm = getUnsupportedRepetitionBase(term) throw new RegexUnsupportedException( s"cuDF does not support repetition of group containing: " + s"${unsupportedTerm.toRegexString}", term.position) } case (RegexGroup(capture, term), QuantifierVariableLength(n, _)) - if !(isSupportedRepetitionBase(term)._1) => + if !isSupportedRepetitionBase(term) => term match { // \Z is not supported in groups case RegexEscaped('A') | @@ -1345,13 +1359,13 @@ class CudfRegexTranspiler(mode: RegexMode) { // NOTE: (\A)* can be transpiled to (\A)? // however, (\A)? is not supported in libcudf yet case _ => - val unsupportedTerm = isSupportedRepetitionBase(term)._2.get + val unsupportedTerm = getUnsupportedRepetitionBase(term) throw new RegexUnsupportedException( s"cuDF does not support repetition of group containing: " + s"${unsupportedTerm.toRegexString}", term.position) } case (RegexGroup(capture, term), QuantifierFixedLength(n)) - if !(isSupportedRepetitionBase(term)._1) => + if !isSupportedRepetitionBase(term) => term match { // \Z is not supported in groups case RegexEscaped('A') | @@ -1363,7 +1377,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // NOTE: (\A)* can be transpiled to (\A)?
is not supported in libcudf yet case _ => - val unsupportedTerm = isSupportedRepetitionBase(term)._2.get + val unsupportedTerm = getUnsupportedRepetitionBase(term) throw new RegexUnsupportedException( s"cuDF does not support repetition of group containing: " + s"${unsupportedTerm.toRegexString}", term.position) @@ -1386,7 +1400,7 @@ class CudfRegexTranspiler(mode: RegexMode) { // \A{1,5} can be transpiled to \A (dropping the repetition) // \Z{1,} can be transpiled to \Z (dropping the repetition) rewrite(base, replacement, previous) - case _ if isSupportedRepetitionBase(base)._1 => + case _ if isSupportedRepetitionBase(base) => RegexRepetition(rewrite(base, replacement, None), quantifier) case (RegexRepetition(_, SimpleQuantifier('*')), SimpleQuantifier('+')) => throw new RegexUnsupportedException("Possessive quantifier *+ not supported", diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index aea4380e3f4..5c7701f4833 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -390,38 +390,34 @@ class CudfMergeSets(override val dataType: DataType) extends CudfAggregate { override val name: String = "CudfMergeSets" } -abstract class CudfFirstLastBase extends CudfAggregate { - val includeNulls: NullPolicy - val offset: Int +class CudfNthLikeAggregate(opName: String, override val dataType: DataType, offset: Int, + includeNulls: NullPolicy) extends CudfAggregate { + + override val name = includeNulls match { + case NullPolicy.INCLUDE => opName + "IncludeNulls" + case NullPolicy.EXCLUDE => opName + "ExcludeNulls" + } + override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = (col: cudf.ColumnVector) => col.reduce(ReductionAggregation.nth(offset, includeNulls)) + override lazy val groupByAggregate: GroupByAggregation = { GroupByAggregation.nth(offset, includeNulls) } } -class CudfFirstIncludeNulls(override val dataType: DataType) extends CudfFirstLastBase { - override val includeNulls: NullPolicy = NullPolicy.INCLUDE - override val offset: Int = 0 - override lazy val name: String = "CudfFirstIncludeNulls" -} +object CudfNthLikeAggregate { + def newFirstExcludeNulls(dataType: DataType): CudfAggregate = + new CudfNthLikeAggregate("CudfFirst", dataType, 0, NullPolicy.EXCLUDE) -class CudfFirstExcludeNulls(override val dataType: DataType) extends CudfFirstLastBase { - override val includeNulls: NullPolicy = NullPolicy.EXCLUDE - override val offset: Int = 0 - override lazy val name: String = "CudfFirstExcludeNulls" -} + def newFirstIncludeNulls(dataType: DataType): CudfAggregate = + new CudfNthLikeAggregate("CudfFirst", dataType, 0, NullPolicy.INCLUDE) -class CudfLastIncludeNulls(override val dataType: DataType) extends CudfFirstLastBase { - override val includeNulls: NullPolicy = NullPolicy.INCLUDE - override val offset: Int = -1 - override lazy val name: String = "CudfLastIncludeNulls" -} + def newLastExcludeNulls(dataType: DataType): CudfAggregate = + new CudfNthLikeAggregate("CudfLast", dataType, -1, NullPolicy.EXCLUDE) -class CudfLastExcludeNulls(override val dataType: DataType) extends CudfFirstLastBase { - override val includeNulls: NullPolicy = NullPolicy.EXCLUDE - override val offset: Int = -1 - override lazy val name: String = "CudfLastExcludeNulls" + def newLastIncludeNulls(dataType: DataType): CudfAggregate = + new 
CudfNthLikeAggregate("CudfLast", dataType, -1, NullPolicy.INCLUDE) } /** @@ -1193,10 +1189,10 @@ case class GpuPivotFirst( }) override lazy val updateAggregates: Seq[CudfAggregate] = - pivotColAttr.map(c => new CudfLastExcludeNulls(c.dataType)) + pivotColAttr.map(c => CudfNthLikeAggregate.newLastExcludeNulls(c.dataType)) override lazy val mergeAggregates: Seq[CudfAggregate] = - pivotColAttr.map(c => new CudfLastExcludeNulls(c.dataType)) + pivotColAttr.map(c => CudfNthLikeAggregate.newLastExcludeNulls(c.dataType)) override lazy val evaluateExpression: Expression = GpuCreateArray(pivotColAttr, false) @@ -1482,6 +1478,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType) */ case class GpuFirst(child: Expression, ignoreNulls: Boolean) extends GpuAggregateFunction + with GpuAggregateWindowFunction with GpuDeterministicFirstLastCollectShim with ImplicitCastInputTypes with Serializable { @@ -1493,11 +1490,11 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean) Seq(child, GpuLiteral(ignoreNulls, BooleanType)) private lazy val commonExpressions: Seq[CudfAggregate] = if (ignoreNulls) { - Seq(new CudfFirstExcludeNulls(cudfFirst.dataType), - new CudfFirstExcludeNulls(valueSet.dataType)) + Seq(CudfNthLikeAggregate.newFirstExcludeNulls(cudfFirst.dataType), + CudfNthLikeAggregate.newFirstExcludeNulls(valueSet.dataType)) } else { - Seq(new CudfFirstIncludeNulls(cudfFirst.dataType), - new CudfFirstIncludeNulls(valueSet.dataType)) + Seq(CudfNthLikeAggregate.newFirstIncludeNulls(cudfFirst.dataType), + CudfNthLikeAggregate.newFirstIncludeNulls(valueSet.dataType)) } // Expected input data type. @@ -1524,10 +1521,19 @@ case class GpuFirst(child: Expression, ignoreNulls: Boolean) TypeCheckSuccess } } + + // GENERAL WINDOW FUNCTION + override lazy val windowInputProjection: Seq[Expression] = inputProjection + override def windowAggregation( + inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn = + RollingAggregation.nth(0, if (ignoreNulls) NullPolicy.EXCLUDE else NullPolicy.INCLUDE) + .onColumn(inputs.head._2) + } case class GpuLast(child: Expression, ignoreNulls: Boolean) extends GpuAggregateFunction + with GpuAggregateWindowFunction with GpuDeterministicFirstLastCollectShim with ImplicitCastInputTypes with Serializable { @@ -1539,11 +1545,11 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean) Seq(child, GpuLiteral(!ignoreNulls, BooleanType)) private lazy val commonExpressions: Seq[CudfAggregate] = if (ignoreNulls) { - Seq(new CudfLastExcludeNulls(cudfLast.dataType), - new CudfLastExcludeNulls(valueSet.dataType)) + Seq(CudfNthLikeAggregate.newLastExcludeNulls(cudfLast.dataType), + CudfNthLikeAggregate.newLastExcludeNulls(valueSet.dataType)) } else { - Seq(new CudfLastIncludeNulls(cudfLast.dataType), - new CudfLastIncludeNulls(valueSet.dataType)) + Seq(CudfNthLikeAggregate.newLastIncludeNulls(cudfLast.dataType), + CudfNthLikeAggregate.newLastIncludeNulls(valueSet.dataType)) } override lazy val initialValues: Seq[GpuLiteral] = Seq( @@ -1569,6 +1575,49 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean) TypeCheckSuccess } } + + // GENERAL WINDOW FUNCTION + override lazy val windowInputProjection: Seq[Expression] = inputProjection + override def windowAggregation( + inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn = + RollingAggregation.nth(-1, if (ignoreNulls) NullPolicy.EXCLUDE else NullPolicy.INCLUDE) + .onColumn(inputs.head._2) +} + +case class GpuNthValue(child: Expression, offset: Expression, ignoreNulls: Boolean) + extends 
GpuAggregateWindowFunction + with ImplicitCastInputTypes + with Serializable { + + // offset is foldable, get value as Spark does + private lazy val offsetVal = offset.eval().asInstanceOf[Int] + + // Copied from First + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"gpu_nth_value($child, $offset)" + + s"${if (ignoreNulls) " ignore nulls"}" + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + TypeCheckSuccess + } + } + + // GENERAL WINDOW FUNCTION + override lazy val windowInputProjection: Seq[Expression] = + Seq(child) + + override def windowAggregation( + inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn = + RollingAggregation.nth(offsetVal - 1, + if (ignoreNulls) NullPolicy.EXCLUDE else NullPolicy.INCLUDE) + .onColumn(inputs.head._2) } trait GpuCollectBase diff --git a/tools/src/main/resources/operatorsScore.csv b/tools/src/main/resources/operatorsScore.csv index a8353a9e396..07e0f857236 100644 --- a/tools/src/main/resources/operatorsScore.csv +++ b/tools/src/main/resources/operatorsScore.csv @@ -161,6 +161,7 @@ NaNvl,4 NamedLambdaVariable,4 NormalizeNaNAndZero,4 Not,4 +NthValue,4 OctetLength,4 Or,4 PercentRank,4 diff --git a/tools/src/main/resources/supportedExprs.csv b/tools/src/main/resources/supportedExprs.csv index 955f8ce9af3..1ae14023b5d 100644 --- a/tools/src/main/resources/supportedExprs.csv +++ b/tools/src/main/resources/supportedExprs.csv @@ -333,6 +333,9 @@ Not,S,`!`; `not`,None,project,input,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA, Not,S,`!`; `not`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Not,S,`!`; `not`,None,AST,input,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Not,S,`!`; `not`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +NthValue,S,`nth_value`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS +NthValue,S,`nth_value`,None,window,offset,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +NthValue,S,`nth_value`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS OctetLength,S,`octet_length`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NA OctetLength,S,`octet_length`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Or,S,`or`,None,project,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA @@ -594,14 +597,14 @@ First,S,`first_value`; `first`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,N First,S,`first_value`; `first`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS First,S,`first_value`; `first`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS First,S,`first_value`; `first`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS -First,NS,`first_value`; `first`,None,window,input,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS -First,NS,`first_value`; `first`,None,window,result,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS +First,S,`first_value`; `first`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS +First,S,`first_value`; `first`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS Last,S,`last`; `last_value`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS Last,S,`last`; 
`last_value`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS Last,S,`last`; `last_value`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS Last,S,`last`; `last_value`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS -Last,NS,`last`; `last_value`,None,window,input,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS -Last,NS,`last`; `last_value`,None,window,result,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS +Last,S,`last`; `last_value`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS +Last,S,`last`; `last_value`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS Max,S,`max`,None,aggregation,input,S,S,S,S,S,PS,PS,S,PS,S,S,S,NS,NS,NS,NA,PS,NS Max,S,`max`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NA,PS,NS Max,S,`max`,None,reduction,input,S,S,S,S,S,PS,PS,S,PS,S,S,S,NS,NS,NS,NA,PS,NS diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala index 2c5e7714d18..37ca01c93e4 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HashAggregateExecParser.scala @@ -35,7 +35,11 @@ case class HashAggregateExecParser( val accumId = node.metrics.find( _.name == "time in aggregation build total").map(_.accumulatorId) val maxDuration = SQLPlanParser.getTotalDuration(accumId, app) - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("HashAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala index 995c51ee411..c77761aed2b 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ObjectHashAggregateExecParser.scala @@ -35,7 +35,11 @@ case class ObjectHashAggregateExecParser( val accumId = node.metrics.find( _.name == "time in aggregation build total").map(_.accumulatorId) val maxDuration = SQLPlanParser.getTotalDuration(accumId, app) - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("ObjectHashAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala index 84e6e1c68e6..659fd18d682 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala +++ 
b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala @@ -285,7 +285,62 @@ object SQLPlanParser extends Logging { case _ => // NO OP } } - parsedExpressions.toArray + parsedExpressions.distinct.toArray + } + + // This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec + def parseAggregateExpressions(exprStr: String): Array[String] = { + val parsedExpressions = ArrayBuffer[String]() + // (key=[num#83], functions=[partial_collect_list(letter#84, 0, 0), partial_count(letter#84)]) + val pattern = """functions=\[([\w#, +*\\\-\.<>=\`\(\)]+\])""".r + val aggregatesString = pattern.findFirstMatchIn(exprStr) + // This is to split multiple column names in AggregateExec. Each column will be aggregating + // based on the aggregate function. Here "partial_" is removed and only function name is + // preserved. Below regex will first remove the "functions=" from the string followed by + // removing "partial_". That string is split which produces an array containing + // column names. Finally we remove the parentheses from the beginning and end to get only + // the expressions. Result will be as below. + // paranRemoved = Array(collect_list(letter#84, 0, 0),, count(letter#84)) + if (aggregatesString.isDefined) { + val paranRemoved = aggregatesString.get.toString.replaceAll("functions=", ""). + replaceAll("partial_", "").split("(?<=\\),)").map(_.trim). + map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")) + val functionPattern = """(\w+)\(.*\)""".r + paranRemoved.foreach { case expr => + val functionName = getFunctionName(functionPattern, expr) + functionName match { + case Some(func) => parsedExpressions += func + case _ => // NO OP + } + } + } + parsedExpressions.distinct.toArray + } + + def parseSortExpressions(exprStr: String): Array[String] = { + val parsedExpressions = ArrayBuffer[String]() + // Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0 + val pattern = """\[([\w#, \(\)]+\])""".r + val sortString = pattern.findFirstMatchIn(exprStr) + // This is to split multiple column names in SortExec. Project may have a function on a column. + // The string is split on delimiter containing FIRST, OR LAST, which is the last string + // of each column in SortExec that produces an array containing + // column names. Finally we remove the parentheses from the beginning and end to get only + // the expressions. Result will be as below. + // paranRemoved = Array(round(num#7, 0) ASC NULLS FIRST,, letter#8 DESC NULLS LAST) + if (sortString.isDefined) { + val paranRemoved = sortString.get.toString.split("(?<=FIRST,)|(?<=LAST,)"). 
+ map(_.trim).map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "")) + val functionPattern = """(\w+)\(.*\)""".r + paranRemoved.foreach { case expr => + val functionName = getFunctionName(functionPattern, expr) + functionName match { + case Some(func) => parsedExpressions += func + case _ => // NO OP + } + } + } + parsedExpressions.distinct.toArray } def parseFilterExpressions(exprStr: String): Array[String] = { @@ -362,6 +417,6 @@ object SQLPlanParser extends Logging { } } } - parsedExpressions.toArray + parsedExpressions.distinct.toArray } } diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala index 425ffe9193b..5109bbd7f58 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortAggregateExecParser.scala @@ -30,7 +30,11 @@ case class SortAggregateExecParser( override def parse: ExecInfo = { // SortAggregate doesn't have duration val duration = None - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("SortAggregate", "") + val expressions = SQLPlanParser.parseAggregateExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala index 288c484965a..e043a8e828a 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SortExecParser.scala @@ -30,7 +30,11 @@ case class SortExecParser( override def parse: ExecInfo = { // Sort doesn't have duration val duration = None - val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName)) { + val exprString = node.desc.replaceFirst("Sort ", "") + val expressions = SQLPlanParser.parseSortExpressions(exprString) + val isAllExprsSupported = expressions.forall(expr => checker.isExprSupported(expr)) + val (speedupFactor, isSupported) = if (checker.isExecSupported(fullExecName) && + isAllExprsSupported) { (checker.getSpeedupFactor(fullExecName), true) } else { (1.0, false) diff --git a/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala index 2cfeffac1cf..4f14af8fca3 100644 --- a/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala +++ b/tools/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala @@ -252,7 +252,9 @@ class PluginTypeChecker extends Logging { } def isExprSupported(expr: String): Boolean = { - val exprLowercase = expr.toLowerCase + // Remove _ from the string. Example: collect_list => collectlist. 
+ // collect_list is alias for CollectList aggregate function + val exprLowercase = expr.toLowerCase.replace("_","") if (supportedExprs.contains(exprLowercase)) { val exprSupported = supportedExprs.getOrElse(exprLowercase, "NS") if (exprSupported == "S") { diff --git a/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala index 4186ae30463..ff37207246a 100644 --- a/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala +++ b/tools/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, TrampolineUtil} import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, explode, hex, sum} +import org.apache.spark.sql.functions.{broadcast, ceil, col, collect_list, count, explode, hex, round, sum} import org.apache.spark.sql.rapids.tool.ToolUtils import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo import org.apache.spark.sql.types.StringType @@ -558,6 +558,56 @@ class SQLPlanParserSuite extends FunSuite with BeforeAndAfterEach with Logging { } } + test("Expressions supported in SortAggregateExec") { + TrampolineUtil.withTempDir { eventLogDir => + val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, "sqlmetric") { spark => + import spark.implicits._ + spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false") + val df1 = Seq((1, "a"), (1, "aa"), (1, "a"), (2, "b"), + (2, "b"), (3, "c"), (3, "c")).toDF("num", "letter") + df1.groupBy("num").agg(collect_list("letter").as("collected_letters"), + count("letter").as("letter_count")) + } + val pluginTypeChecker = new PluginTypeChecker() + val app = createAppFromEventlog(eventLog) + assert(app.sqlPlans.size == 1) + val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => + SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app) + } + val execInfo = getAllExecsFromPlan(parsedPlans.toSeq) + val sortAggregate = execInfo.filter(_.exec == "SortAggregate") + assertSizeAndSupported(2, sortAggregate) + } + } + + test("Expressions supported in SortExec") { + TrampolineUtil.withTempDir { parquetoutputLoc => + TrampolineUtil.withTempDir { eventLogDir => + val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir, + "ProjectExprsSupported") { spark => + import spark.implicits._ + val df1 = Seq((1.7, "a"), (1.6, "aa"), (1.1, "b"), (2.5, "a"), (2.2, "b"), + (3.2, "a"), (10.6, "c")).toDF("num", "letter") + df1.write.parquet(s"$parquetoutputLoc/testsortExec") + val df2 = spark.read.parquet(s"$parquetoutputLoc/testsortExec") + df2.sort("num").collect + df2.orderBy("num").collect + df2.select(round(col("num")), col("letter")).sort(round(col("num")), col("letter").desc) + } + val pluginTypeChecker = new PluginTypeChecker() + val app = createAppFromEventlog(eventLog) + assert(app.sqlPlans.size == 4) + val parsedPlans = app.sqlPlans.map { case (sqlID, plan) => + SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, pluginTypeChecker, app) + } + val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq) + val sortExec = allExecInfo.filter(_.exec.contains("Sort")) + assert(sortExec.size == 3) + assertSizeAndSupported(3, sortExec, 5.2) + } + } + } + test("Expressions supported in ProjectExec") { TrampolineUtil.withTempDir { 
parquetoutputLoc => TrampolineUtil.withTempDir { eventLogDir =>
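// --- Illustrative usage sketch (not part of the patch above) -------------------------------
// A minimal, hypothetical example of the window expressions this change adds GPU support for:
// first()/last() as window functions and the new nth_value()/GpuNthValue mapping. It assumes a
// Spark 3.1+ cluster with the RAPIDS Accelerator on the classpath
// (spark.plugins=com.nvidia.spark.SQLPlugin); the table and column names are made up.
import org.apache.spark.sql.SparkSession

object FirstLastNthValueWindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("first-last-nth-window-sketch").getOrCreate()
    import spark.implicits._

    Seq((1, "a", 10), (1, "b", 20), (1, "c", 30), (2, "d", 40), (2, "e", 50))
      .toDF("grp", "name", "ord")
      .createOrReplaceTempView("example_table")

    // With the plugin enabled, these window aggregations should now stay on the GPU instead of
    // falling back to the CPU; nth_value maps to cuDF's nth aggregation with a 0-based offset.
    spark.sql(
      """SELECT grp, name, ord,
        |       first(name) OVER (PARTITION BY grp ORDER BY ord) AS first_name,
        |       last(name)  OVER (PARTITION BY grp ORDER BY ord
        |                         ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_name,
        |       nth_value(name, 2) OVER (PARTITION BY grp ORDER BY ord
        |                         ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS second_name
        |FROM example_table""".stripMargin).show()

    spark.stop()
  }
}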