diff --git a/docs/get-started/getting-started-workload-qualification.md b/docs/get-started/getting-started-workload-qualification.md
index ce2441e509a..7272f36b0af 100644
--- a/docs/get-started/getting-started-workload-qualification.md
+++ b/docs/get-started/getting-started-workload-qualification.md
@@ -63,13 +63,8 @@ the other is to modify your existing Spark application code to call a function d
 Please note that if using adaptive execution in Spark the explain output may not be perfect
 as the plan could have changed along the way in a way that we wouldn't see by looking at just
-the CPU plan.
-
-### Requirements
-
-- A Spark 3.x CPU cluster
-- The `rapids-4-spark` and `cudf` [jars](../download.md)
-- Ability to modify the existing Spark application code if using the function call directly
+the CPU plan. The same applies if you are using an older version of Spark, since query planning
+may differ slightly when you move to a newer version of Spark.
 
 ### Using the Configuration Flag for Explain Only Mode
 
@@ -78,6 +73,13 @@ This mode allows you to run on a CPU cluster and can help us understand the pote
 if there are any unsupported features. Basically it will log the output which is the same as
 the driver logs with `spark.rapids.sql.explain=all`.
 
+#### Requirements
+
+- A Spark 3.x CPU cluster
+- The `rapids-4-spark` and `cudf` [jars](../download.md)
+
+#### Usage
+
 1. In `spark-shell`, add the `rapids-4-spark` and `cudf` jars into --jars option or put them in the
    Spark classpath and enable the configs `spark.rapids.sql.mode=explainOnly` and
    `spark.plugins=com.nvidia.spark.SQLPlugin`.
@@ -125,20 +127,41 @@ pretty accurate.
 
 ### How to use the Function Call
 
-Starting with version 21.12 of the RAPIDS Accelerator, a new function named
-`explainPotentialGpuPlan` is added which can help us understand the potential GPU plan and if there
-are any unsupported features on a CPU cluster. Basically it can return output which is the same as
-the driver logs with `spark.rapids.sql.explain=all`.
+A function named `explainPotentialGpuPlan` is available that can help you understand the potential
+GPU plan and whether there are any unsupported features on a CPU cluster. It returns the same
+output as the driver logs with `spark.rapids.sql.explain=all`.
 
-1. In `spark-shell`, add the `rapids-4-spark` and `cudf` jars into --jars option or put them in the
+#### Requirements with Spark 3.X
+
+- A Spark 3.X CPU cluster
+- The `rapids-4-spark` and `cudf` [jars](../download.md)
+- Ability to modify the existing Spark application code
+- RAPIDS Accelerator for Apache Spark version 21.12 or newer
+
+#### Requirements with Spark 2.4.X
+
+- A Spark 2.4.X CPU cluster
+- The `rapids-4-spark-sql-meta` [jar](../download.md)
+- Ability to modify the existing Spark application code
+- RAPIDS Accelerator for Apache Spark version 22.02 or newer
+
+#### Usage
+
+1. In `spark-shell`, add the necessary jars to the --jars option or put them in the
    Spark classpath.
 
-   For example:
+   For example, on Spark 3.X:
 
    ```bash
   spark-shell --jars /PathTo/cudf-.jar,/PathTo/rapids-4-spark_.jar
    ```
 
+   For example, on Spark 2.4.X:
+
+   ```bash
+   spark-shell --jars /PathTo/rapids-4-spark-sql-meta-.jar
+   ```
+
 2. Test if the class can be successfully loaded or not.
 
    ```scala
@@ -148,8 +171,8 @@ the driver logs with `spark.rapids.sql.explain=all`.
 3. Enable optional RAPIDS Accelerator related parameters based on your setup. 
    Enabling optional parameters may allow more operations to run on the GPU but please understand
-   the meaning and risk of above parameters before enabling it. Please refer to [configs
-   doc](../configs.md) for details of RAPIDS Accelerator parameters.
+   the meaning and risk of the above parameters before enabling them. Please refer to the
+   [configuration documentation](../configs.md) for details of RAPIDS Accelerator parameters.
 
    For example, if your jobs have `double`, `float` and `decimal` operators together with some
    Scala UDFs, you can set the following parameters:
diff --git a/jenkins/spark-nightly-build.sh b/jenkins/spark-nightly-build.sh
index 04de910e37b..0cb42f719fb 100755
--- a/jenkins/spark-nightly-build.sh
+++ b/jenkins/spark-nightly-build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 #
-# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -74,6 +74,14 @@ function distWithReducedPom {
         $mvnExtaFlags
 }
 
+# build the Spark 2.x explain jar
+mvn -B $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dbuildver=24X clean install -DskipTests
+[[ $SKIP_DEPLOY != 'true' ]] && \
+    mvn -B deploy $MVN_URM_MIRROR \
+        -Dmaven.repo.local=$M2DIR \
+        -DskipTests \
+        -Dbuildver=24X
+
 # build, install, and deploy all the versions we support, but skip deploy of individual dist module since we
 # only want the combined jar to be pushed.
 # Note this does not run any integration tests
diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh
index 767f7747eee..4bff6d663c6 100755
--- a/jenkins/spark-premerge-build.sh
+++ b/jenkins/spark-premerge-build.sh
@@ -35,6 +35,9 @@ mvn_verify() {
     # file size check for pull request. The size of a committed file should be less than 1.5MiB
     pre-commit run check-added-large-files --from-ref $BASE_REF --to-ref HEAD
 
+    # build the Spark 2.x explain jar
+    env -u SPARK_HOME mvn -B $MVN_URM_MIRROR -Dbuildver=24X clean install -DskipTests
+
    # build all the versions but only run unit tests on one 3.0.X version (base version covers this), one 3.1.X version, and one 3.2.X version. 
# All others shims test should be covered in nightly pipelines env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Dbuildver=302 clean install -Drat.skip=true -DskipTests -Dmaven.javadoc.skip=true -Dskip -Dmaven.scalastyle.skip=true -Dcuda.version=$CUDA_CLASSIFIER -pl aggregator -am diff --git a/pom.xml b/pom.xml index d469582be3a..2951ac22cad 100644 --- a/pom.xml +++ b/pom.xml @@ -69,18 +69,25 @@ -6 - - - dist - integration_tests - shims - shuffle-plugin - sql-plugin - tests - udf-compiler - udf-examples - + + release24X + + + buildver + 24X + + + + 2.11 + 2.11.12 + ${spark248.version} + ${spark248.version} + + + spark2-sql-plugin + + release301 @@ -121,6 +128,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples api_validation tools aggregator @@ -166,6 +181,14 @@ spark302 + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples aggregator tools api_validation @@ -210,6 +233,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples api_validation tools aggregator @@ -254,6 +285,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples api_validation tools aggregator @@ -302,6 +341,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples api_validation tools aggregator @@ -359,6 +406,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples aggregator @@ -416,6 +471,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples aggregator @@ -462,6 +525,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator api_validation @@ -511,6 +582,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator api_validation @@ -557,6 +636,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator tests-spark310+ @@ -602,6 +689,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator tests-spark310+ @@ -647,6 +742,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator tests-spark310+ @@ -691,6 +794,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator tests-spark310+ @@ -745,6 +856,14 @@ + dist + integration_tests + shims + shuffle-plugin + sql-plugin + tests + udf-compiler + udf-examples tools aggregator tests-spark310+ @@ -826,6 +945,7 @@ + 2.4.8 3.0.1 3.0.1-databricks 3.0.2 @@ -1208,6 +1328,7 @@ ${spark.version.classifier}tests + true @@ -1240,7 +1361,7 @@ rmm_log.txt dependency-reduced-pom*.xml **/.*/** - src/main/java/com/nvidia/spark/rapids/format/* + **/src/main/java/com/nvidia/spark/rapids/format/* **/*.csv dist/*.txt **/META-INF/com.nvidia.spark.rapids.SparkShimServiceProvider @@ -1248,6 +1369,7 @@ default, but there are some projects that are conditionally included. 
 -->
 **/target/**/*
 **/cufile.log
+scripts/spark2diffs/*.diff
diff --git a/scripts/rundiffspark2.sh b/scripts/rundiffspark2.sh
new file mode 100755
index 00000000000..4dae595d630
--- /dev/null
+++ b/scripts/rundiffspark2.sh
@@ -0,0 +1,472 @@
+#!/bin/bash
+#
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script diffs the code in the spark2-sql-plugin module with the corresponding files and
+# functions in sql-plugin to look for anything that has changed
+
+# Generally speaking this assumes the convertToGpu is the last function in the meta classes,
+# if someone adds something after it we may not catch it.
+# This also doesn't catch if someone adds an override in a shim or someplace we don't diff.
+
+# RapidsUDF.java is just an interface and we don't really expect them to use it on 2.x, so skip diffing:
+# ../spark2-sql-plugin/src/main/java/com/nvidia/spark/RapidsUDF.java
+
+echo "Running diffs of spark2.x files"
+
+tmp_dir=$(mktemp -d -t spark2diff-XXXXXXXXXX)
+echo "Using temporary directory: $tmp_dir"
+
+diff ../sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala > $tmp_dir/GpuHiveOverrides.newdiff
+if [[ $(diff spark2diffs/GpuHiveOverrides.diff $tmp_dir/GpuHiveOverrides.newdiff) ]]; then
+  echo "check diff for ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala"
+fi
+
+sed -n '/class GpuBroadcastNestedLoopJoinMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala > $tmp_dir/GpuBroadcastNestedLoopJoinMeta_new.out
+sed -n '/class GpuBroadcastNestedLoopJoinMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala > $tmp_dir/GpuBroadcastNestedLoopJoinMeta_old.out
+diff $tmp_dir/GpuBroadcastNestedLoopJoinMeta_new.out $tmp_dir/GpuBroadcastNestedLoopJoinMeta_old.out > $tmp_dir/GpuBroadcastNestedLoopJoinMeta.newdiff
+diff -c spark2diffs/GpuBroadcastNestedLoopJoinMeta.diff $tmp_dir/GpuBroadcastNestedLoopJoinMeta.newdiff
+
+sed -n '/object JoinTypeChecks/,/def extractTopLevelAttributes/{/def extractTopLevelAttributes/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala > $tmp_dir/GpuHashJoin_new.out
+sed -n '/object JoinTypeChecks/,/def extractTopLevelAttributes/{/def extractTopLevelAttributes/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala > $tmp_dir/GpuHashJoin_old.out
+diff $tmp_dir/GpuHashJoin_new.out $tmp_dir/GpuHashJoin_old.out > $tmp_dir/GpuHashJoin.newdiff
+diff -c spark2diffs/GpuHashJoin.diff $tmp_dir/GpuHashJoin.newdiff
+
+sed -n '/class GpuShuffleMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' 
../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleMeta.scala > $tmp_dir/GpuShuffleMeta_new.out +sed -n '/class GpuShuffleMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala > $tmp_dir/GpuShuffleMeta_old.out +diff $tmp_dir/GpuShuffleMeta_new.out $tmp_dir/GpuShuffleMeta_old.out > $tmp_dir/GpuShuffleMeta.newdiff +diff -c spark2diffs/GpuShuffleMeta.diff $tmp_dir/GpuShuffleMeta.newdiff + +sed -n '/class GpuBroadcastMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExecMeta.scala > $tmp_dir/GpuBroadcastExchangeExecMeta_new.out +sed -n '/class GpuBroadcastMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala > $tmp_dir/GpuBroadcastExchangeExecMeta_old.out +diff $tmp_dir/GpuBroadcastExchangeExecMeta_new.out $tmp_dir/GpuBroadcastExchangeExecMeta_old.out > $tmp_dir/GpuBroadcastExchangeExecMeta.newdiff +diff -c spark2diffs/GpuBroadcastExchangeExecMeta.diff $tmp_dir/GpuBroadcastExchangeExecMeta.newdiff + +sed -n '/abstract class UnixTimeExprMeta/,/sealed trait TimeParserPolicy/{/sealed trait TimeParserPolicy/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala > $tmp_dir/UnixTimeExprMeta_new.out +sed -n '/abstract class UnixTimeExprMeta/,/sealed trait TimeParserPolicy/{/sealed trait TimeParserPolicy/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/UnixTimeExprMeta_old.out +diff $tmp_dir/UnixTimeExprMeta_new.out $tmp_dir/UnixTimeExprMeta_old.out > $tmp_dir/UnixTimeExprMeta.newdiff +diff -c spark2diffs/UnixTimeExprMeta.diff $tmp_dir/UnixTimeExprMeta.newdiff + +sed -n '/object GpuToTimestamp/,/abstract class UnixTimeExprMeta/{/abstract class UnixTimeExprMeta/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala > $tmp_dir/GpuToTimestamp_new.out +sed -n '/object GpuToTimestamp/,/val REMOVE_WHITESPACE_FROM_MONTH_DAY/{/val REMOVE_WHITESPACE_FROM_MONTH_DAY/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/GpuToTimestamp_old.out +diff $tmp_dir/GpuToTimestamp_new.out $tmp_dir/GpuToTimestamp_old.out > $tmp_dir/GpuToTimestamp.newdiff +diff -c spark2diffs/GpuToTimestamp.diff $tmp_dir/GpuToTimestamp.newdiff + +sed -n '/case class ParseFormatMeta/,/case class RegexReplace/p' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala > $tmp_dir/datemisc_new.out +sed -n '/case class ParseFormatMeta/,/case class RegexReplace/p' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/datemisc_old.out +diff -c $tmp_dir/datemisc_new.out $tmp_dir/datemisc_old.out + +sed -n '/class GpuRLikeMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/GpuRLikeMeta_new.out +sed -n '/class GpuRLikeMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/GpuRLikeMeta_old.out +diff $tmp_dir/GpuRLikeMeta_new.out $tmp_dir/GpuRLikeMeta_old.out > $tmp_dir/GpuRLikeMeta.newdiff +diff -c spark2diffs/GpuRLikeMeta.diff $tmp_dir/GpuRLikeMeta.newdiff + +sed -n '/class 
GpuRegExpExtractMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/GpuRegExpExtractMeta_new.out +sed -n '/class GpuRegExpExtractMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/GpuRegExpExtractMeta_old.out +diff $tmp_dir/GpuRegExpExtractMeta_new.out $tmp_dir/GpuRegExpExtractMeta_old.out > $tmp_dir/GpuRegExpExtractMeta.newdiff +diff -c spark2diffs/GpuRegExpExtractMeta.diff $tmp_dir/GpuRegExpExtractMeta.newdiff + +sed -n '/class SubstringIndexMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/SubstringIndexMeta_new.out +sed -n '/class SubstringIndexMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/SubstringIndexMeta_old.out +diff $tmp_dir/SubstringIndexMeta_new.out $tmp_dir/SubstringIndexMeta_old.out > $tmp_dir/SubstringIndexMeta.newdiff +diff -c spark2diffs/SubstringIndexMeta.diff $tmp_dir/SubstringIndexMeta.newdiff + +sed -n '/object CudfRegexp/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/CudfRegexp_new.out +sed -n '/object CudfRegexp/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/CudfRegexp_old.out +diff -c $tmp_dir/CudfRegexp_new.out $tmp_dir/CudfRegexp_old.out > $tmp_dir/CudfRegexp.newdiff + +sed -n '/object GpuSubstringIndex/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/GpuSubstringIndex_new.out +sed -n '/object GpuSubstringIndex/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/GpuSubstringIndex_old.out +diff -c $tmp_dir/GpuSubstringIndex_new.out $tmp_dir/GpuSubstringIndex_old.out > $tmp_dir/GpuSubstringIndex.newdiff + +sed -n '/class GpuStringSplitMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala > $tmp_dir/GpuStringSplitMeta_new.out +sed -n '/class GpuStringSplitMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala > $tmp_dir/GpuStringSplitMeta_old.out +diff $tmp_dir/GpuStringSplitMeta_new.out $tmp_dir/GpuStringSplitMeta_old.out > $tmp_dir/GpuStringSplitMeta.newdiff +diff -c spark2diffs/GpuStringSplitMeta.diff $tmp_dir/GpuStringSplitMeta.newdiff + +sed -n '/object GpuOrcFileFormat/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala > $tmp_dir/GpuOrcFileFormat_new.out +sed -n '/object GpuOrcFileFormat/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala > $tmp_dir/GpuOrcFileFormat_old.out +diff $tmp_dir/GpuOrcFileFormat_new.out $tmp_dir/GpuOrcFileFormat_old.out > $tmp_dir/GpuOrcFileFormat.newdiff +diff -c spark2diffs/GpuOrcFileFormat.diff $tmp_dir/GpuOrcFileFormat.newdiff + +sed -n '/class GpuSequenceMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala > $tmp_dir/GpuSequenceMeta_new.out +sed -n '/class GpuSequenceMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala > $tmp_dir/GpuSequenceMeta_old.out +diff $tmp_dir/GpuSequenceMeta_new.out 
$tmp_dir/GpuSequenceMeta_old.out > $tmp_dir/GpuSequenceMeta.newdiff +diff -c spark2diffs/GpuSequenceMeta.diff $tmp_dir/GpuSequenceMeta.newdiff + +sed -n '/object GpuDataSource/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala > $tmp_dir/GpuDataSource_new.out +sed -n '/object GpuDataSource/,/val GLOB_PATHS_KEY/{/val GLOB_PATHS_KEY/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala > $tmp_dir/GpuDataSource_old.out +diff $tmp_dir/GpuDataSource_new.out $tmp_dir/GpuDataSource_old.out > $tmp_dir/GpuDataSource.newdiff +diff -c spark2diffs/GpuDataSource.diff $tmp_dir/GpuDataSource.newdiff + +sed -n '/class GpuGetArrayItemMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala > $tmp_dir/GpuGetArrayItemMeta_new.out +sed -n '/class GpuGetArrayItemMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala > $tmp_dir/GpuGetArrayItemMeta_old.out +diff $tmp_dir/GpuGetArrayItemMeta_new.out $tmp_dir/GpuGetArrayItemMeta_old.out > $tmp_dir/GpuGetArrayItemMeta.newdiff +diff -c spark2diffs/GpuGetArrayItemMeta.diff $tmp_dir/GpuGetArrayItemMeta.newdiff + +sed -n '/class GpuGetMapValueMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala > $tmp_dir/GpuGetMapValueMeta_new.out +sed -n '/class GpuGetMapValueMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala > $tmp_dir/GpuGetMapValueMeta_old.out +diff $tmp_dir/GpuGetMapValueMeta_new.out $tmp_dir/GpuGetMapValueMeta_old.out > $tmp_dir/GpuGetMapValueMeta.newdiff +diff -c spark2diffs/GpuGetMapValueMeta.diff $tmp_dir/GpuGetMapValueMeta.newdiff + +sed -n '/abstract class ScalaUDFMetaBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDFMeta.scala> $tmp_dir/ScalaUDFMetaBase_new.out +sed -n '/abstract class ScalaUDFMetaBase/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala > $tmp_dir/ScalaUDFMetaBase_old.out +diff $tmp_dir/ScalaUDFMetaBase_new.out $tmp_dir/ScalaUDFMetaBase_old.out > $tmp_dir/ScalaUDFMetaBase.newdiff +diff -c spark2diffs/ScalaUDFMetaBase.diff $tmp_dir/ScalaUDFMetaBase.newdiff + +sed -n '/object GpuScalaUDF/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDFMeta.scala > $tmp_dir/GpuScalaUDF_new.out +sed -n '/object GpuScalaUDF/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala > $tmp_dir/GpuScalaUDF_old.out +diff -c $tmp_dir/GpuScalaUDF_new.out $tmp_dir/GpuScalaUDF_old.out > $tmp_dir/GpuScalaUDF.newdiff + +sed -n '/object GpuDecimalMultiply/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala > $tmp_dir/GpuDecimalMultiply_new.out +sed -n '/object GpuDecimalMultiply/,/def checkForOverflow/{/def checkForOverflow/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala > $tmp_dir/GpuDecimalMultiply_old.out +diff $tmp_dir/GpuDecimalMultiply_new.out $tmp_dir/GpuDecimalMultiply_old.out > $tmp_dir/GpuDecimalMultiply.newdiff +diff -c spark2diffs/GpuDecimalMultiply.diff $tmp_dir/GpuDecimalMultiply.newdiff + +sed -n '/object GpuDecimalDivide/,/^}/{/^}/!p}' 
../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala > $tmp_dir/GpuDecimalDivide_new.out +sed -n '/object GpuDecimalDivide/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala > $tmp_dir/GpuDecimalDivide_old.out +diff $tmp_dir/GpuDecimalDivide_new.out $tmp_dir/GpuDecimalDivide_old.out > $tmp_dir/GpuDecimalDivide.newdiff +diff -c spark2diffs/GpuDecimalDivide.diff $tmp_dir/GpuDecimalDivide.newdiff + +sed -n '/def isSupportedRelation/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala > $tmp_dir/isSupportedRelation_new.out +sed -n '/def isSupportedRelation/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala > $tmp_dir/isSupportedRelation_old.out +diff -c $tmp_dir/isSupportedRelation_new.out $tmp_dir/isSupportedRelation_old.out + +sed -n '/def dataTypeExistsRecursively/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala > $tmp_dir/dataTypeExistsRecursively_new.out +sed -n '/def dataTypeExistsRecursively/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala > $tmp_dir/dataTypeExistsRecursively_old.out +diff -c $tmp_dir/dataTypeExistsRecursively_new.out $tmp_dir/dataTypeExistsRecursively_old.out + +sed -n '/def getSimpleName/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala > $tmp_dir/getSimpleName_new.out +sed -n '/def getSimpleName/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala > $tmp_dir/getSimpleName_old.out +diff -c $tmp_dir/getSimpleName_new.out $tmp_dir/getSimpleName_old.out + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala > $tmp_dir/RegexParser.newdiff +diff -c spark2diffs/RegexParser.diff $tmp_dir/RegexParser.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala > $tmp_dir/TypeChecks.newdiff +diff -c spark2diffs/TypeChecks.diff $tmp_dir/TypeChecks.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala > $tmp_dir/DataTypeUtils.newdiff +diff -c spark2diffs/DataTypeUtils.diff $tmp_dir/DataTypeUtils.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala > $tmp_dir/GpuOverrides.newdiff +diff -c spark2diffs/GpuOverrides.diff $tmp_dir/GpuOverrides.newdiff + +sed -n '/GpuOverrides.expr\[Cast\]/,/doFloatToIntCheck/p' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/cast_new.out +sed -n '/GpuOverrides.expr\[Cast\]/,/doFloatToIntCheck/p' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/cast_old.out +diff $tmp_dir/cast_new.out $tmp_dir/cast_old.out > $tmp_dir/cast.newdiff +diff -c spark2diffs/cast.diff $tmp_dir/cast.newdiff + +sed -n '/GpuOverrides.expr\[Average\]/,/GpuOverrides.expr\[Abs/p' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/average_new.out +sed -n '/GpuOverrides.expr\[Average\]/,/GpuOverrides.expr\[Abs/p' 
../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/average_old.out +diff $tmp_dir/average_new.out $tmp_dir/average_old.out > $tmp_dir/average.newdiff +diff -c spark2diffs/average.diff $tmp_dir/average.newdiff + +sed -n '/GpuOverrides.expr\[Abs\]/,/GpuOverrides.expr\[RegExpReplace/p' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/abs_new.out +sed -n '/GpuOverrides.expr\[Abs\]/,/GpuOverrides.expr\[RegExpReplace/p' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/abs_old.out +diff $tmp_dir/abs_new.out $tmp_dir/abs_old.out > $tmp_dir/abs.newdiff +diff -c spark2diffs/abs.diff $tmp_dir/abs.newdiff + +sed -n '/GpuOverrides.expr\[RegExpReplace\]/,/GpuOverrides.expr\[TimeSub/{/GpuOverrides.expr\[TimeSub/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/regexreplace_new.out +sed -n '/GpuOverrides.expr\[RegExpReplace\]/,/GpuScalaUDFMeta.exprMeta/{/GpuScalaUDFMeta.exprMeta/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/regexreplace_old.out +diff -c $tmp_dir/regexreplace_new.out $tmp_dir/regexreplace_old.out + +sed -n '/GpuOverrides.expr\[TimeSub\]/,/GpuOverrides.expr\[ScalaUDF/{/GpuOverrides.expr\[ScalaUDF/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/TimeSub_new.out +sed -n '/GpuOverrides.expr\[TimeSub\]/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/TimeSub_old.out +diff -w $tmp_dir/TimeSub_new.out $tmp_dir/TimeSub_old.out > $tmp_dir/TimeSub.newdiff +diff -c spark2diffs/TimeSub.diff $tmp_dir/TimeSub.newdiff + +sed -n '/GpuOverrides.expr\[ScalaUDF\]/,/})/{/})/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/ScalaUDF_new.out +sed -n '/GpuOverrides.expr\[ScalaUDF\]/,/})/{/})/!p}' ../sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala > $tmp_dir/ScalaUDF_old.out +diff -w $tmp_dir/ScalaUDF_new.out $tmp_dir/ScalaUDF_old.out > $tmp_dir/ScalaUDF.newdiff +diff -c spark2diffs/ScalaUDF.diff $tmp_dir/ScalaUDF.newdiff + +sed -n '/GpuOverrides.exec\[FileSourceScanExec\]/,/})/{/})/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/FileSourceScanExec_new.out +sed -n '/GpuOverrides.exec\[FileSourceScanExec\]/,/override def convertToCpu/{/override def convertToCpu/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/FileSourceScanExec_old.out +diff -w $tmp_dir/FileSourceScanExec_new.out $tmp_dir/FileSourceScanExec_old.out > $tmp_dir/FileSourceScanExec.newdiff +diff -c spark2diffs/FileSourceScanExec.diff $tmp_dir/FileSourceScanExec.newdiff + +sed -n '/GpuOverrides.exec\[ArrowEvalPythonExec\]/,/})/{/})/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/ArrowEvalPythonExec_new.out +sed -n '/GpuOverrides.exec\[ArrowEvalPythonExec\]/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/ArrowEvalPythonExec_old.out +diff -w $tmp_dir/ArrowEvalPythonExec_new.out $tmp_dir/ArrowEvalPythonExec_old.out > 
$tmp_dir/ArrowEvalPythonExec.newdiff +diff -c spark2diffs/ArrowEvalPythonExec.diff $tmp_dir/ArrowEvalPythonExec.newdiff + +sed -n '/GpuOverrides.exec\[FlatMapGroupsInPandasExec\]/,/GpuOverrides.exec\[WindowInPandasExec/{/GpuOverrides.exec\[WindowInPandasExec/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/FlatMapGroupsInPandasExec_new.out +sed -n '/GpuOverrides.exec\[FlatMapGroupsInPandasExec\]/,/GpuOverrides.exec\[AggregateInPandasExec/{/GpuOverrides.exec\[AggregateInPandasExec/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/FlatMapGroupsInPandasExec_old.out +diff -c -w $tmp_dir/FlatMapGroupsInPandasExec_new.out $tmp_dir/FlatMapGroupsInPandasExec_old.out + +sed -n '/GpuOverrides.exec\[WindowInPandasExec\]/,/})/{/})/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/WindowInPandasExec_new.out +sed -n '/GpuOverrides.exec\[WindowInPandasExec\]/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/WindowInPandasExec_old.out +diff -c -w --ignore-blank-lines $tmp_dir/WindowInPandasExec_new.out $tmp_dir/WindowInPandasExec_old.out + +sed -n '/GpuOverrides.exec\[AggregateInPandasExec\]/,/)\.collect/{/)\.collect/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala > $tmp_dir/AggregateInPandasExec_new.out +sed -n '/GpuOverrides.exec\[AggregateInPandasExec\]/,/)\.map/{/)\.map/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala > $tmp_dir/AggregateInPandasExec_old.out +diff -c -w $tmp_dir/AggregateInPandasExec_new.out $tmp_dir/AggregateInPandasExec_old.out + +sed -n '/object GpuOrcScanBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala > $tmp_dir/GpuOrcScanBase_new.out +sed -n '/object GpuOrcScanBase/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala > $tmp_dir/GpuOrcScanBase_old.out +diff $tmp_dir/GpuOrcScanBase_new.out $tmp_dir/GpuOrcScanBase_old.out > $tmp_dir/GpuOrcScanBase.newdiff +diff -c spark2diffs/GpuOrcScanBase.diff $tmp_dir/GpuOrcScanBase.newdiff + +sed -n '/class LiteralExprMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/literalsMeta.scala > $tmp_dir/LiteralExprMeta_new.out +sed -n '/class LiteralExprMeta/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala > $tmp_dir/LiteralExprMeta_old.out +diff $tmp_dir/LiteralExprMeta_new.out $tmp_dir/LiteralExprMeta_old.out > $tmp_dir/LiteralExprMeta.newdiff +diff -c spark2diffs/LiteralExprMeta.diff $tmp_dir/LiteralExprMeta.newdiff + +# 2.x doesn't have a base aggregate class so this is much different, check the revision for now +CUR_COMMIT=`git log -1 ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala | grep commit | cut -d ' ' -f 2` +if [ "$CUR_COMMIT" != "b17c685788c0a62763aa8101709e241877f02025" ]; then + echo "sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala has different commit - check manually" +fi + +sed -n '/class GpuGenerateExecSparkPlanMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExecMeta.scala > $tmp_dir/GpuGenerateExecSparkPlanMeta_new.out +sed -n '/class GpuGenerateExecSparkPlanMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' 
../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala > $tmp_dir/GpuGenerateExecSparkPlanMeta_old.out +diff $tmp_dir/GpuGenerateExecSparkPlanMeta_new.out $tmp_dir/GpuGenerateExecSparkPlanMeta_old.out > $tmp_dir/GpuGenerateExecSparkPlanMeta.newdiff +diff -c spark2diffs/GpuGenerateExecSparkPlanMeta.diff $tmp_dir/GpuGenerateExecSparkPlanMeta.newdiff + +sed -n '/abstract class GeneratorExprMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExecMeta.scala > $tmp_dir/GeneratorExprMeta_new.out +sed -n '/abstract class GeneratorExprMeta/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala > $tmp_dir/GeneratorExprMeta_old.out +diff $tmp_dir/GeneratorExprMeta_new.out $tmp_dir/GeneratorExprMeta_old.out > $tmp_dir/GeneratorExprMeta.newdiff +diff -c spark2diffs/GeneratorExprMeta.diff $tmp_dir/GeneratorExprMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExplainPlan.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExplainPlan.scala > $tmp_dir/ExplainPlan.newdiff +diff spark2diffs/ExplainPlan.diff $tmp_dir/ExplainPlan.newdiff + +sed -n '/class GpuProjectExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperatorsMeta.scala > $tmp_dir/GpuProjectExecMeta_new.out +sed -n '/class GpuProjectExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala > $tmp_dir/GpuProjectExecMeta_old.out +diff $tmp_dir/GpuProjectExecMeta_new.out $tmp_dir/GpuProjectExecMeta_old.out > $tmp_dir/GpuProjectExecMeta.newdiff +diff -c spark2diffs/GpuProjectExecMeta.diff $tmp_dir/GpuProjectExecMeta.newdiff + +sed -n '/class GpuSampleExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperatorsMeta.scala > $tmp_dir/GpuSampleExecMeta_new.out +sed -n '/class GpuSampleExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala > $tmp_dir/GpuSampleExecMeta_old.out +diff $tmp_dir/GpuSampleExecMeta_new.out $tmp_dir/GpuSampleExecMeta_old.out > $tmp_dir/GpuSampleExecMeta.newdiff +diff -c spark2diffs/GpuSampleExecMeta.diff $tmp_dir/GpuSampleExecMeta.newdiff + +sed -n '/class GpuSortMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExecMeta.scala > $tmp_dir/GpuSortMeta_new.out +sed -n '/class GpuSortMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala > $tmp_dir/GpuSortMeta_old.out +diff $tmp_dir/GpuSortMeta_new.out $tmp_dir/GpuSortMeta_old.out > $tmp_dir/GpuSortMeta.newdiff +diff -c spark2diffs/GpuSortMeta.diff $tmp_dir/GpuSortMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/InputFileBlockRule.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/InputFileBlockRule.scala > $tmp_dir/InputFileBlockRule.newdiff +diff -c spark2diffs/InputFileBlockRule.diff $tmp_dir/InputFileBlockRule.newdiff + +sed -n '/abstract class GpuWindowInPandasExecMetaBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala > $tmp_dir/GpuWindowInPandasExecMetaBase_new.out +sed -n '/abstract class GpuWindowInPandasExecMetaBase/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala > 
$tmp_dir/GpuWindowInPandasExecMetaBase_old.out +diff $tmp_dir/GpuWindowInPandasExecMetaBase_new.out $tmp_dir/GpuWindowInPandasExecMetaBase_old.out > $tmp_dir/GpuWindowInPandasExecMetaBase.newdiff +diff -c spark2diffs/GpuWindowInPandasExecMetaBase.diff $tmp_dir/GpuWindowInPandasExecMetaBase.newdiff + +sed -n '/class GpuAggregateInPandasExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala > $tmp_dir/GpuAggregateInPandasExecMeta_new.out +sed -n '/class GpuAggregateInPandasExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala > $tmp_dir/GpuAggregateInPandasExecMeta_old.out +diff $tmp_dir/GpuAggregateInPandasExecMeta_new.out $tmp_dir/GpuAggregateInPandasExecMeta_old.out > $tmp_dir/GpuAggregateInPandasExecMeta.newdiff +diff -c spark2diffs/GpuAggregateInPandasExecMeta.diff $tmp_dir/GpuAggregateInPandasExecMeta.newdiff + +sed -n '/class GpuFlatMapGroupsInPandasExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala > $tmp_dir/GpuFlatMapGroupsInPandasExecMeta_new.out +sed -n '/class GpuFlatMapGroupsInPandasExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala > $tmp_dir/GpuFlatMapGroupsInPandasExecMeta_old.out +diff $tmp_dir/GpuFlatMapGroupsInPandasExecMeta_new.out $tmp_dir/GpuFlatMapGroupsInPandasExecMeta_old.out > $tmp_dir/GpuFlatMapGroupsInPandasExecMeta.newdiff +diff -c spark2diffs/GpuFlatMapGroupsInPandasExecMeta.diff $tmp_dir/GpuFlatMapGroupsInPandasExecMeta.newdiff + +sed -n '/class GpuShuffledHashJoinMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala > $tmp_dir/GpuShuffledHashJoinMeta_new.out +sed -n '/class GpuShuffledHashJoinMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala > $tmp_dir/GpuShuffledHashJoinMeta_old.out +diff $tmp_dir/GpuShuffledHashJoinMeta_new.out $tmp_dir/GpuShuffledHashJoinMeta_old.out > $tmp_dir/GpuShuffledHashJoinMeta.newdiff +diff -c spark2diffs/GpuShuffledHashJoinMeta.diff $tmp_dir/GpuShuffledHashJoinMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala ../sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala > $tmp_dir/TreeNode.newdiff +diff -c spark2diffs/TreeNode.diff $tmp_dir/TreeNode.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinMeta.scala > $tmp_dir/GpuSortMergeJoinMeta.newdiff +diff -c spark2diffs/GpuSortMergeJoinMeta.diff $tmp_dir/GpuSortMergeJoinMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala ../sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala > $tmp_dir/GpuJoinUtils.newdiff +diff -c spark2diffs/GpuJoinUtils.diff $tmp_dir/GpuJoinUtils.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala ../sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala > $tmp_dir/TypeSigUtil.newdiff +diff -c spark2diffs/TypeSigUtil.diff 
$tmp_dir/TypeSigUtil.newdiff + +sed -n '/class GpuBroadcastHashJoinMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala > $tmp_dir/GpuBroadcastHashJoinMeta_new.out +sed -n '/class GpuBroadcastHashJoinMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastHashJoinExec.scala > $tmp_dir/GpuBroadcastHashJoinMeta_old.out +diff $tmp_dir/GpuBroadcastHashJoinMeta_new.out $tmp_dir/GpuBroadcastHashJoinMeta_old.out > $tmp_dir/GpuBroadcastHashJoinMeta.newdiff +diff -c spark2diffs/GpuBroadcastHashJoinMeta.diff $tmp_dir/GpuBroadcastHashJoinMeta.newdiff + +sed -n '/object GpuCSVScan/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala > $tmp_dir/GpuCSVScan_new.out +sed -n '/object GpuCSVScan/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala > $tmp_dir/GpuCSVScan_old.out +diff $tmp_dir/GpuCSVScan_new.out $tmp_dir/GpuCSVScan_old.out > $tmp_dir/GpuCSVScan.newdiff +diff -c spark2diffs/GpuCSVScan.diff $tmp_dir/GpuCSVScan.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala ../sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala > $tmp_dir/OffsetWindowFunctionMeta.newdiff +diff -c spark2diffs/OffsetWindowFunctionMeta.diff $tmp_dir/OffsetWindowFunctionMeta.newdiff + +sed -n '/class GpuRegExpReplaceMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala > $tmp_dir/GpuRegExpReplaceMeta_new.out +sed -n '/class GpuRegExpReplaceMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala > $tmp_dir/GpuRegExpReplaceMeta_old.out +diff $tmp_dir/GpuRegExpReplaceMeta_new.out $tmp_dir/GpuRegExpReplaceMeta_old.out > $tmp_dir/GpuRegExpReplaceMeta.newdiff +diff -c spark2diffs/GpuRegExpReplaceMeta.diff $tmp_dir/GpuRegExpReplaceMeta.newdiff + +sed -n '/class GpuWindowExpressionMetaBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowExpressionMetaBase_new.out +sed -n '/class GpuWindowExpressionMetaBase/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala > $tmp_dir/GpuWindowExpressionMetaBase_old.out +diff $tmp_dir/GpuWindowExpressionMetaBase_new.out $tmp_dir/GpuWindowExpressionMetaBase_old.out > $tmp_dir/GpuWindowExpressionMetaBase.newdiff +diff -c spark2diffs/GpuWindowExpressionMetaBase.diff $tmp_dir/GpuWindowExpressionMetaBase.newdiff + +sed -n '/abstract class GpuSpecifiedWindowFrameMetaBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuSpecifiedWindowFrameMetaBase_new.out +sed -n '/abstract class GpuSpecifiedWindowFrameMetaBase/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala > $tmp_dir/GpuSpecifiedWindowFrameMetaBase_old.out +diff $tmp_dir/GpuSpecifiedWindowFrameMetaBase_new.out $tmp_dir/GpuSpecifiedWindowFrameMetaBase_old.out > $tmp_dir/GpuSpecifiedWindowFrameMetaBase.newdiff +diff -c spark2diffs/GpuSpecifiedWindowFrameMetaBase.diff $tmp_dir/GpuSpecifiedWindowFrameMetaBase.newdiff + +sed -n 
'/class GpuSpecifiedWindowFrameMeta(/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuSpecifiedWindowFrameMeta_new.out +sed -n '/class GpuSpecifiedWindowFrameMeta(/,/^}/{/^}/!p}' ../sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuSpecifiedWindowFrameMeta_old.out +diff $tmp_dir/GpuSpecifiedWindowFrameMeta_new.out $tmp_dir/GpuSpecifiedWindowFrameMeta_old.out > $tmp_dir/GpuSpecifiedWindowFrameMeta.newdiff +diff -c spark2diffs/GpuSpecifiedWindowFrameMeta.diff $tmp_dir/GpuSpecifiedWindowFrameMeta.newdiff + +sed -n '/class GpuWindowExpressionMeta(/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowExpressionMeta_new.out +sed -n '/class GpuWindowExpressionMeta(/,/^}/{/^}/!p}' ../sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowExpressionMeta_old.out +diff $tmp_dir/GpuWindowExpressionMeta_new.out $tmp_dir/GpuWindowExpressionMeta_old.out > $tmp_dir/GpuWindowExpressionMeta.newdiff +diff -c spark2diffs/GpuWindowExpressionMeta.diff $tmp_dir/GpuWindowExpressionMeta.newdiff + +sed -n '/object GpuWindowUtil/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowUtil_new.out +sed -n '/object GpuWindowUtil/,/^}/{/^}/!p}' ../sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowUtil_old.out +diff -c $tmp_dir/GpuWindowUtil_new.out $tmp_dir/GpuWindowUtil_old.out + +sed -n '/case class ParsedBoundary/p' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/ParsedBoundary_new.out +sed -n '/case class ParsedBoundary/p' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala > $tmp_dir/ParsedBoundary_old.out +diff -c $tmp_dir/ParsedBoundary_new.out $tmp_dir/ParsedBoundary_old.out + +sed -n '/class GpuWindowSpecDefinitionMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala > $tmp_dir/GpuWindowSpecDefinitionMeta_new.out +sed -n '/class GpuWindowSpecDefinitionMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala > $tmp_dir/GpuWindowSpecDefinitionMeta_old.out +diff $tmp_dir/GpuWindowSpecDefinitionMeta_new.out $tmp_dir/GpuWindowSpecDefinitionMeta_old.out > $tmp_dir/GpuWindowSpecDefinitionMeta.newdiff +diff -c spark2diffs/GpuWindowSpecDefinitionMeta.diff $tmp_dir/GpuWindowSpecDefinitionMeta.newdiff + +sed -n '/abstract class GpuBaseWindowExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowMeta.scala > $tmp_dir/GpuBaseWindowExecMeta_new.out +sed -n '/abstract class GpuBaseWindowExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala > $tmp_dir/GpuBaseWindowExecMeta_old.out +diff $tmp_dir/GpuBaseWindowExecMeta_new.out $tmp_dir/GpuBaseWindowExecMeta_old.out > $tmp_dir/GpuBaseWindowExecMeta.newdiff +diff -c spark2diffs/GpuBaseWindowExecMeta.diff $tmp_dir/GpuBaseWindowExecMeta.newdiff + +sed -n '/class GpuWindowExecMeta(/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowMeta.scala > $tmp_dir/GpuWindowExecMeta_new.out +sed -n '/class GpuWindowExecMeta(/,/^}/{/^}/!p}' 
../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala > $tmp_dir/GpuWindowExecMeta_old.out +diff $tmp_dir/GpuWindowExecMeta_new.out $tmp_dir/GpuWindowExecMeta_old.out > $tmp_dir/GpuWindowExecMeta.newdiff +diff -c spark2diffs/GpuWindowExecMeta.diff $tmp_dir/GpuWindowExecMeta.newdiff + +sed -n '/class GpuExpandExecMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandMeta.scala > $tmp_dir/GpuExpandExecMeta_new.out +sed -n '/class GpuExpandExecMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala > $tmp_dir/GpuExpandExecMeta_old.out +diff $tmp_dir/GpuExpandExecMeta_new.out $tmp_dir/GpuExpandExecMeta_old.out > $tmp_dir/GpuExpandExecMeta.newdiff +diff -c spark2diffs/GpuExpandExecMeta.diff $tmp_dir/GpuExpandExecMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala > $tmp_dir/RapidsMeta.newdiff +diff -c spark2diffs/RapidsMeta.diff $tmp_dir/RapidsMeta.newdiff + +sed -n '/object GpuReadCSVFileFormat/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala > $tmp_dir/GpuReadCSVFileFormat_new.out +sed -n '/object GpuReadCSVFileFormat/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala > $tmp_dir/GpuReadCSVFileFormat_old.out +diff $tmp_dir/GpuReadCSVFileFormat_new.out $tmp_dir/GpuReadCSVFileFormat_old.out > $tmp_dir/GpuReadCSVFileFormat.newdiff +diff -c spark2diffs/GpuReadCSVFileFormat.diff $tmp_dir/GpuReadCSVFileFormat.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala > $tmp_dir/RapidsConf.newdiff +diff -c spark2diffs/RapidsConf.diff $tmp_dir/RapidsConf.newdiff + +sed -n '/object GpuReadParquetFileFormat/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala > $tmp_dir/GpuReadParquetFileFormat_new.out +sed -n '/object GpuReadParquetFileFormat/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala > $tmp_dir/GpuReadParquetFileFormat_old.out +diff $tmp_dir/GpuReadParquetFileFormat_new.out $tmp_dir/GpuReadParquetFileFormat_old.out > $tmp_dir/GpuReadParquetFileFormat.newdiff +diff -c spark2diffs/GpuReadParquetFileFormat.diff $tmp_dir/GpuReadParquetFileFormat.newdiff + +sed -n '/object GpuParquetScanBase/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala > $tmp_dir/GpuParquetScanBase_new.out +sed -n '/object GpuParquetScanBase/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala > $tmp_dir/GpuParquetScanBase_old.out +diff $tmp_dir/GpuParquetScanBase_new.out $tmp_dir/GpuParquetScanBase_old.out > $tmp_dir/GpuParquetScanBase.newdiff +diff -c spark2diffs/GpuParquetScanBase.diff $tmp_dir/GpuParquetScanBase.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastJoinMeta.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastJoinMeta.scala > $tmp_dir/GpuBroadcastJoinMeta.newdiff +diff -c spark2diffs/GpuBroadcastJoinMeta.diff $tmp_dir/GpuBroadcastJoinMeta.newdiff + +sed -n '/object AggregateUtils/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/AggregateUtils.scala > $tmp_dir/AggregateUtils_new.out +sed -n '/object 
AggregateUtils/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala > $tmp_dir/AggregateUtils_old.out +diff $tmp_dir/AggregateUtils_new.out $tmp_dir/AggregateUtils_old.out > $tmp_dir/AggregateUtils.newdiff +diff -c spark2diffs/AggregateUtils.diff $tmp_dir/AggregateUtils.newdiff + +sed -n '/final class CastExprMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala > $tmp_dir/CastExprMeta_new.out +sed -n '/final class CastExprMeta/,/override def convertToGpu/{/override def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala > $tmp_dir/CastExprMeta_old.out +diff $tmp_dir/CastExprMeta_new.out $tmp_dir/CastExprMeta_old.out > $tmp_dir/CastExprMeta.newdiff +diff -c spark2diffs/CastExprMeta.diff $tmp_dir/CastExprMeta.newdiff + +sed -n '/object GpuReadOrcFileFormat/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala > $tmp_dir/GpuReadOrcFileFormat_new.out +sed -n '/object GpuReadOrcFileFormat/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala > $tmp_dir/GpuReadOrcFileFormat_old.out +diff $tmp_dir/GpuReadOrcFileFormat_new.out $tmp_dir/GpuReadOrcFileFormat_old.out > $tmp_dir/GpuReadOrcFileFormat.newdiff +diff -c spark2diffs/GpuReadOrcFileFormat.diff $tmp_dir/GpuReadOrcFileFormat.newdiff + +sed -n '/object GpuParquetFileFormat/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala > $tmp_dir/GpuParquetFileFormat_new.out +sed -n '/object GpuParquetFileFormat/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala > $tmp_dir/GpuParquetFileFormat_old.out +diff $tmp_dir/GpuParquetFileFormat_new.out $tmp_dir/GpuParquetFileFormat_old.out > $tmp_dir/GpuParquetFileFormat.newdiff +diff -c spark2diffs/GpuParquetFileFormat.diff $tmp_dir/GpuParquetFileFormat.newdiff + +sed -n '/def asDecimalType/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/asDecimalType_new.out +sed -n '/def asDecimalType/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/asDecimalType_old.out +diff -c $tmp_dir/asDecimalType_new.out $tmp_dir/asDecimalType_old.out + +sed -n '/def optionallyAsDecimalType/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/optionallyAsDecimalType_new.out +sed -n '/def optionallyAsDecimalType/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/optionallyAsDecimalType_old.out +diff $tmp_dir/optionallyAsDecimalType_new.out $tmp_dir/optionallyAsDecimalType_old.out > $tmp_dir/optionallyAsDecimalType.newdiff +diff -c spark2diffs/optionallyAsDecimalType.diff $tmp_dir/optionallyAsDecimalType.newdiff + +sed -n '/def getPrecisionForIntegralType/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/getPrecisionForIntegralType_new.out +sed -n '/def getPrecisionForIntegralType/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/getPrecisionForIntegralType_old.out +diff $tmp_dir/getPrecisionForIntegralType_new.out $tmp_dir/getPrecisionForIntegralType_old.out > $tmp_dir/getPrecisionForIntegralType.newdiff +diff -c spark2diffs/getPrecisionForIntegralType.diff $tmp_dir/getPrecisionForIntegralType.newdiff + +# not sure this diff works very well due 
to java vs scala and quite a bit different but should find any changes in those functions +sed -n '/def toRapidsStringOrNull/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/toRapidsStringOrNull_new.out +sed -n '/private static DType/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java > $tmp_dir/toRapidsStringOrNull_old.out +diff $tmp_dir/toRapidsStringOrNull_new.out $tmp_dir/toRapidsStringOrNull_old.out > $tmp_dir/toRapidsStringOrNull.newdiff +diff -c spark2diffs/toRapidsStringOrNull.diff $tmp_dir/toRapidsStringOrNull.newdiff + +sed -n '/def createCudfDecimal/,/^ }/{/^ }/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/createCudfDecimal_new.out +sed -n '/def createCudfDecimal/,/^ }/{/^ }/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala > $tmp_dir/createCudfDecimal_old.out +diff $tmp_dir/createCudfDecimal_new.out $tmp_dir/createCudfDecimal_old.out > $tmp_dir/createCudfDecimal.newdiff +diff -c spark2diffs/createCudfDecimal.diff $tmp_dir/createCudfDecimal.newdiff + +sed -n '/abstract class ReplicateRowsExprMeta/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReplicateRowsMeta.scala > $tmp_dir/ReplicateRowsExprMeta_new.out +sed -n '/abstract class ReplicateRowsExprMeta/,/override final def convertToGpu/{/override final def convertToGpu/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala > $tmp_dir/ReplicateRowsExprMeta_old.out +diff $tmp_dir/ReplicateRowsExprMeta_new.out $tmp_dir/ReplicateRowsExprMeta_old.out > $tmp_dir/ReplicateRowsExprMeta.newdiff +diff -c spark2diffs/ReplicateRowsExprMeta.diff $tmp_dir/ReplicateRowsExprMeta.newdiff + +diff ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala > $tmp_dir/DateUtils.newdiff +diff -c spark2diffs/DateUtils.diff $tmp_dir/DateUtils.newdiff + +sed -n '/object CudfTDigest/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala > $tmp_dir/CudfTDigest_new.out +sed -n '/object CudfTDigest/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala > $tmp_dir/CudfTDigest_old.out +diff $tmp_dir/CudfTDigest_new.out $tmp_dir/CudfTDigest_old.out > $tmp_dir/CudfTDigest.newdiff +diff -c spark2diffs/CudfTDigest.diff $tmp_dir/CudfTDigest.newdiff + +sed -n '/sealed trait TimeParserPolicy/p' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/TimeParserPolicy_new.out +sed -n '/sealed trait TimeParserPolicy/p' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/TimeParserPolicy_old.out +diff -c $tmp_dir/TimeParserPolicy_new.out $tmp_dir/TimeParserPolicy_old.out + +sed -n '/object LegacyTimeParserPolicy/p' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/LegacyTimeParserPolicy_new.out +sed -n '/object LegacyTimeParserPolicy/p' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/LegacyTimeParserPolicy_old.out +diff -c $tmp_dir/LegacyTimeParserPolicy_new.out $tmp_dir/LegacyTimeParserPolicy_old.out + +sed -n '/object ExceptionTimeParserPolicy/p' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/ExceptionTimeParserPolicy_new.out +sed -n 
'/object ExceptionTimeParserPolicy/p' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/ExceptionTimeParserPolicy_old.out +diff -c $tmp_dir/ExceptionTimeParserPolicy_new.out $tmp_dir/ExceptionTimeParserPolicy_old.out + +sed -n '/object CorrectedTimeParserPolicy/p' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/CorrectedTimeParserPolicy_new.out +sed -n '/object CorrectedTimeParserPolicy/p' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala > $tmp_dir/CorrectedTimeParserPolicy_old.out +diff -c $tmp_dir/CorrectedTimeParserPolicy_new.out $tmp_dir/CorrectedTimeParserPolicy_old.out + +sed -n '/object GpuFloorCeil/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala > $tmp_dir/GpuFloorCeil_new.out +sed -n '/object GpuFloorCeil/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala > $tmp_dir/GpuFloorCeil_old.out +diff -c $tmp_dir/GpuFloorCeil_new.out $tmp_dir/GpuFloorCeil_old.out + +sed -n '/object GpuFileSourceScanExec/,/^}/{/^}/!p}' ../spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala > $tmp_dir/GpuFileSourceScanExec_new.out +sed -n '/object GpuFileSourceScanExec/,/^}/{/^}/!p}' ../sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala > $tmp_dir/GpuFileSourceScanExec_old.out +diff $tmp_dir/GpuFileSourceScanExec_new.out $tmp_dir/GpuFileSourceScanExec_old.out > $tmp_dir/GpuFileSourceScanExec.newdiff +diff -c spark2diffs/GpuFileSourceScanExec.diff $tmp_dir/GpuFileSourceScanExec.newdiff + +echo "Done running Diffs of spark2.x files" +rm -r $tmp_dir diff --git a/scripts/spark2diffs/AggregateUtils.diff b/scripts/spark2diffs/AggregateUtils.diff new file mode 100644 index 00000000000..c8c643be803 --- /dev/null +++ b/scripts/spark2diffs/AggregateUtils.diff @@ -0,0 +1,43 @@ +32a33,74 +> +> /** +> * Computes a target input batch size based on the assumption that computation can consume up to +> * 4X the configured batch size. +> * @param confTargetSize user-configured maximum desired batch size +> * @param inputTypes input batch schema +> * @param outputTypes output batch schema +> * @param isReductionOnly true if this is a reduction-only aggregation without grouping +> * @return maximum target batch size to keep computation under the 4X configured batch limit +> */ +> def computeTargetBatchSize( +> confTargetSize: Long, +> inputTypes: Seq[DataType], +> outputTypes: Seq[DataType], +> isReductionOnly: Boolean): Long = { +> def typesToSize(types: Seq[DataType]): Long = +> types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum +> val inputRowSize = typesToSize(inputTypes) +> val outputRowSize = typesToSize(outputTypes) +> // The cudf hash table implementation allocates four 32-bit integers per input row. +> val hashTableRowSize = 4 * 4 +> +> // Using the memory management for joins as a reference, target 4X batch size as a budget. 
+> var totalBudget = 4 * confTargetSize +> +> // Compute the amount of memory being consumed per-row in the computation +> var computationBytesPerRow = inputRowSize + hashTableRowSize +> if (isReductionOnly) { +> // Remove the lone output row size from the budget rather than track per-row in computation +> totalBudget -= outputRowSize +> } else { +> // The worst-case memory consumption during a grouping aggregation is the case where the +> // grouping does not combine any input rows, so just as many rows appear in the output. +> computationBytesPerRow += outputRowSize +> } +> +> // Calculate the max rows that can be processed during computation within the budget +> val maxRows = totalBudget / computationBytesPerRow +> +> // Finally compute the input target batching size taking into account the cudf row limits +> Math.min(inputRowSize * maxRows, Int.MaxValue) +> } diff --git a/scripts/spark2diffs/ArrowEvalPythonExec.diff b/scripts/spark2diffs/ArrowEvalPythonExec.diff new file mode 100644 index 00000000000..18ab1fa79a7 --- /dev/null +++ b/scripts/spark2diffs/ArrowEvalPythonExec.diff @@ -0,0 +1,8 @@ +13c13 +< e.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) +--- +> e.resultAttrs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) +14a15 +> +17a19 +> diff --git a/scripts/spark2diffs/CastExprMeta.diff b/scripts/spark2diffs/CastExprMeta.diff new file mode 100644 index 00000000000..97fb9b3dcac --- /dev/null +++ b/scripts/spark2diffs/CastExprMeta.diff @@ -0,0 +1,39 @@ +1c1 +< final class CastExprMeta[INPUT <: Cast]( +--- +> final class CastExprMeta[INPUT <: CastBase]( +5c5 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +21,22c21 +< // 2.x doesn't have the SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING config, so set it to true +< val legacyCastToString: Boolean = true +--- +> val legacyCastToString: Boolean = ShimLoader.getSparkShims.getLegacyComplexTypeToString() +46c45 +< if (dt.precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { +--- +> if (dt.precision > DType.DECIMAL128_MAX_PRECISION) { +48c47 +< s"precision > ${GpuOverrides.DECIMAL128_MAX_PRECISION} is not supported yet") +--- +> s"precision > ${DType.DECIMAL128_MAX_PRECISION} is not supported yet") +81a81 +> YearParseUtil.tagParseStringAsDate(conf, this) +83c83 +< // NOOP for anything prior to 3.2.0 +--- +> YearParseUtil.tagParseStringAsDate(conf, this) +85,91c85 +< // Spark 2.x: removed check for +< // !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported +< // this dealt with handling a bug fix that is only in newer versions of Spark +< // (https://issues.apache.org/jira/browse/SPARK-37451) +< // Since we don't know what version of Spark 3 they will be using +< // just always say it won't work and they can hopefully figure it out from warning. 
+< if (dt.scale < 0) { +--- +> if (dt.scale < 0 && !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported) { +120a115 +> diff --git a/scripts/spark2diffs/CudfTDigest.diff b/scripts/spark2diffs/CudfTDigest.diff new file mode 100644 index 00000000000..aabb3beba06 --- /dev/null +++ b/scripts/spark2diffs/CudfTDigest.diff @@ -0,0 +1,7 @@ +9a10,15 +> +> // Map Spark delta to cuDF delta +> def accuracy(accuracyExpression: GpuLiteral): Int = accuracyExpression.value match { +> case delta: Int => delta.max(1000) +> case _ => 1000 +> } diff --git a/scripts/spark2diffs/DataTypeUtils.diff b/scripts/spark2diffs/DataTypeUtils.diff new file mode 100644 index 00000000000..9c7e90dc82f --- /dev/null +++ b/scripts/spark2diffs/DataTypeUtils.diff @@ -0,0 +1,4 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021, NVIDIA CORPORATION. diff --git a/scripts/spark2diffs/DateUtils.diff b/scripts/spark2diffs/DateUtils.diff new file mode 100644 index 00000000000..38fed34ab05 --- /dev/null +++ b/scripts/spark2diffs/DateUtils.diff @@ -0,0 +1,63 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2020-2021, NVIDIA CORPORATION. +19c19 +< import java.time._ +--- +> import java.time.LocalDate +22a23,25 +> import ai.rapids.cudf.{DType, Scalar} +> import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater +> +23a27 +> import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays +59,60c63,65 +< // Spark 2.x - removed isSpark320orlater checks +< def specialDatesDays: Map[String, Int] = { +--- +> def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) { +> Map.empty +> } else { +71c76,78 +< def specialDatesSeconds: Map[String, Long] = { +--- +> def specialDatesSeconds: Map[String, Long] = if (isSpark320OrLater) { +> Map.empty +> } else { +73,74c80 +< // spark 2.4 Date utils are different +< val now = DateTimeUtils.instantToMicros(Instant.now()) +--- +> val now = DateTimeUtils.currentTimestamp() +84c90,92 +< def specialDatesMicros: Map[String, Long] = { +--- +> def specialDatesMicros: Map[String, Long] = if (isSpark320OrLater) { +> Map.empty +> } else { +86c94 +< val now = DateTimeUtils.instantToMicros(Instant.now()) +--- +> val now = DateTimeUtils.currentTimestamp() +96c104,121 +< def currentDate(): Int = Math.toIntExact(LocalDate.now().toEpochDay) +--- +> def fetchSpecialDates(unit: DType): Map[String, () => Scalar] = unit match { +> case DType.TIMESTAMP_DAYS => +> DateUtils.specialDatesDays.map { case (k, v) => +> k -> (() => Scalar.timestampDaysFromInt(v)) +> } +> case DType.TIMESTAMP_SECONDS => +> DateUtils.specialDatesSeconds.map { case (k, v) => +> k -> (() => Scalar.timestampFromLong(unit, v)) +> } +> case DType.TIMESTAMP_MICROSECONDS => +> DateUtils.specialDatesMicros.map { case (k, v) => +> k -> (() => Scalar.timestampFromLong(unit, v)) +> } +> case _ => +> throw new IllegalArgumentException(s"unsupported DType: $unit") +> } +> +> def currentDate(): Int = localDateToDays(LocalDate.now()) diff --git a/scripts/spark2diffs/ExplainPlan.diff b/scripts/spark2diffs/ExplainPlan.diff new file mode 100644 index 00000000000..423f7c311c9 --- /dev/null +++ b/scripts/spark2diffs/ExplainPlan.diff @@ -0,0 +1,8 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021, NVIDIA CORPORATION. 
+65c65 +< GpuOverrides.explainPotentialGpuPlan(df, explain) +--- +> ShimLoader.newExplainPlan.explainPotentialGpuPlan(df, explain) diff --git a/scripts/spark2diffs/FileSourceScanExec.diff b/scripts/spark2diffs/FileSourceScanExec.diff new file mode 100644 index 00000000000..7633b9cd6b1 --- /dev/null +++ b/scripts/spark2diffs/FileSourceScanExec.diff @@ -0,0 +1,21 @@ +6a7,24 +> // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart +> // if possible. Instead regarding filters as childExprs of current Meta, we create +> // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of +> // FileSourceScan is independent from the replacement of the partitionFilters. It is +> // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters +> // are on the GPU. And vice versa. +> private lazy val partitionFilters = wrapped.partitionFilters.map { filter => +> filter.transformDown { +> case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) +> if inSub.plan.isInstanceOf[SubqueryBroadcastExec] => +> +> val subBcMeta = GpuOverrides.wrapAndTagPlan(inSub.plan, conf) +> subBcMeta.tagForExplain() +> val gpuSubBroadcast = subBcMeta.convertIfNeeded().asInstanceOf[BaseSubqueryExec] +> dpe.copy(inSub.copy(plan = gpuSubBroadcast)) +> } +> } +> +10a29 +> diff --git a/scripts/spark2diffs/GeneratorExprMeta.diff b/scripts/spark2diffs/GeneratorExprMeta.diff new file mode 100644 index 00000000000..fd5db326022 --- /dev/null +++ b/scripts/spark2diffs/GeneratorExprMeta.diff @@ -0,0 +1,4 @@ +4c4 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuAggregateInPandasExecMeta.diff b/scripts/spark2diffs/GpuAggregateInPandasExecMeta.diff new file mode 100644 index 00000000000..a2c23932573 --- /dev/null +++ b/scripts/spark2diffs/GpuAggregateInPandasExecMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +21a22 +> diff --git a/scripts/spark2diffs/GpuBaseWindowExecMeta.diff b/scripts/spark2diffs/GpuBaseWindowExecMeta.diff new file mode 100644 index 00000000000..74cfcdeb8c1 --- /dev/null +++ b/scripts/spark2diffs/GpuBaseWindowExecMeta.diff @@ -0,0 +1,4 @@ +3c3 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuBroadcastExchangeExecMeta.diff b/scripts/spark2diffs/GpuBroadcastExchangeExecMeta.diff new file mode 100644 index 00000000000..a2e66d7281c --- /dev/null +++ b/scripts/spark2diffs/GpuBroadcastExchangeExecMeta.diff @@ -0,0 +1,12 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +13c13 +< def isSupported(rm: RapidsMeta[_, _]): Boolean = rm.wrapped match { +--- +> def isSupported(rm: RapidsMeta[_, _, _]): Boolean = rm.wrapped match { +25c25 +< } +--- +> diff --git a/scripts/spark2diffs/GpuBroadcastHashJoinMeta.diff b/scripts/spark2diffs/GpuBroadcastHashJoinMeta.diff new file mode 100644 index 00000000000..fc56a52a9d2 --- /dev/null +++ b/scripts/spark2diffs/GpuBroadcastHashJoinMeta.diff @@ -0,0 +1,15 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +31c31,36 +< willNotWorkOnGpu("the broadcast for this join must be on the GPU too") +--- +> if (conf.isSqlExplainOnlyEnabled && wrapped.conf.adaptiveExecutionEnabled) { +> willNotWorkOnGpu("explain only mode with AQE, we cannot determine " + +> "if the broadcast for this join is on the GPU too") +> } else { +> willNotWorkOnGpu("the broadcast for this 
join must be on the GPU too") +> } +37a43 +> diff --git a/scripts/spark2diffs/GpuBroadcastJoinMeta.diff b/scripts/spark2diffs/GpuBroadcastJoinMeta.diff new file mode 100644 index 00000000000..e9c8db15154 --- /dev/null +++ b/scripts/spark2diffs/GpuBroadcastJoinMeta.diff @@ -0,0 +1,37 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2020, NVIDIA CORPORATION. +18a19,21 +> import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec +> import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +> import org.apache.spark.sql.rapids.execution.GpuBroadcastExchangeExec +22c25 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +27d29 +< // Spark 2.x - removed some checks only applicable to AQE +28a31,36 +> case bqse: BroadcastQueryStageExec => bqse.plan.isInstanceOf[GpuBroadcastExchangeExec] || +> bqse.plan.isInstanceOf[ReusedExchangeExec] && +> bqse.plan.asInstanceOf[ReusedExchangeExec] +> .child.isInstanceOf[GpuBroadcastExchangeExec] +> case reused: ReusedExchangeExec => reused.child.isInstanceOf[GpuBroadcastExchangeExec] +> case _: GpuBroadcastExchangeExec => true +29a38,52 +> } +> } +> +> def verifyBuildSideWasReplaced(buildSide: SparkPlan): Unit = { +> val buildSideOnGpu = buildSide match { +> case bqse: BroadcastQueryStageExec => bqse.plan.isInstanceOf[GpuBroadcastExchangeExec] || +> bqse.plan.isInstanceOf[ReusedExchangeExec] && +> bqse.plan.asInstanceOf[ReusedExchangeExec] +> .child.isInstanceOf[GpuBroadcastExchangeExec] +> case reused: ReusedExchangeExec => reused.child.isInstanceOf[GpuBroadcastExchangeExec] +> case _: GpuBroadcastExchangeExec => true +> case _ => false +> } +> if (!buildSideOnGpu) { +> throw new IllegalStateException(s"the broadcast must be on the GPU too") diff --git a/scripts/spark2diffs/GpuBroadcastNestedLoopJoinMeta.diff b/scripts/spark2diffs/GpuBroadcastNestedLoopJoinMeta.diff new file mode 100644 index 00000000000..ceeec421192 --- /dev/null +++ b/scripts/spark2diffs/GpuBroadcastNestedLoopJoinMeta.diff @@ -0,0 +1,8 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +49c49 +< } +--- +> diff --git a/scripts/spark2diffs/GpuCSVScan.diff b/scripts/spark2diffs/GpuCSVScan.diff new file mode 100644 index 00000000000..4ae82fa5ac5 --- /dev/null +++ b/scripts/spark2diffs/GpuCSVScan.diff @@ -0,0 +1,54 @@ +27,34c27,34 +< def dateFormatInRead(csvOpts: CSVOptions): Option[String] = { +< // spark 2.x uses FastDateFormat, use getPattern +< Option(csvOpts.dateFormat.getPattern) +< } +< +< def timestampFormatInRead(csvOpts: CSVOptions): Option[String] = { +< // spark 2.x uses FastDateFormat, use getPattern +< Option(csvOpts.timestampFormat.getPattern) +--- +> def tagSupport(scanMeta: ScanMeta[CSVScan]) : Unit = { +> val scan = scanMeta.wrapped +> tagSupport( +> scan.sparkSession, +> scan.dataSchema, +> scan.readDataSchema, +> scan.options.asScala.toMap, +> scanMeta) +42c42 +< meta: RapidsMeta[_, _]): Unit = { +--- +> meta: RapidsMeta[_, _, _]): Unit = { +67,68d66 +< // 2.x only supports delimiter as char +< /* +72d69 +< */ +74,75c71 +< // delimiter is char in 2.x +< if (parsedOptions.delimiter > 127) { +--- +> if (parsedOptions.delimiter.codePointAt(0) > 127) { +105,109d100 +< // 2.x doesn't have linSeparator config +< // CSV text with '\n', '\r' and '\r\n' as line separators. 
+< // Since I have no way to check in 2.x we will just assume it works for explain until +< // they move to 3.x +< /* +113d103 +< */ +143c133 +< dateFormatInRead(parsedOptions).foreach { dateFormat => +--- +> ShimLoader.getSparkShims.dateFormatInRead(parsedOptions).foreach { dateFormat => +190,192c180 +< +< // Spark 2.x doesn't have zoneId, so use timeZone and then to id +< if (!TypeChecks.areTimestampsSupported(parsedOptions.timeZone.toZoneId)) { +--- +> if (!TypeChecks.areTimestampsSupported(parsedOptions.zoneId)) { +195c183 +< timestampFormatInRead(parsedOptions).foreach { tsFormat => +--- +> ShimLoader.getSparkShims.timestampFormatInRead(parsedOptions).foreach { tsFormat => diff --git a/scripts/spark2diffs/GpuDataSource.diff b/scripts/spark2diffs/GpuDataSource.diff new file mode 100644 index 00000000000..8b3e80de905 --- /dev/null +++ b/scripts/spark2diffs/GpuDataSource.diff @@ -0,0 +1,29 @@ +7c7 +< val parquet = classOf[ParquetFileFormat].getCanonicalName +--- +> val parquet = classOf[GpuParquetFileFormat].getCanonicalName +47a48,62 +> def lookupDataSourceWithFallback(className: String, conf: SQLConf): Class[_] = { +> val cls = GpuDataSource.lookupDataSource(className, conf) +> // `providingClass` is used for resolving data source relation for catalog tables. +> // As now catalog for data source V2 is under development, here we fall back all the +> // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works. +> // [[FileDataSourceV2]] will still be used if we call the load()/save() method in +> // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` +> // instead of `providingClass`. +> val fallbackCls = ConstructorUtils.invokeConstructor(cls) match { +> case f: FileDataSourceV2 => f.fallbackFileFormat +> case _ => cls +> } +> // convert to GPU version +> fallbackCls +> } +54c69 +< classOf[OrcFileFormat].getCanonicalName +--- +> classOf[OrcDataSourceV2].getCanonicalName +141a157,160 +> +> /** +> * The key in the "options" map for deciding whether or not to glob paths before use. 
+> */ diff --git a/scripts/spark2diffs/GpuDecimalDivide.diff b/scripts/spark2diffs/GpuDecimalDivide.diff new file mode 100644 index 00000000000..0a29774f57d --- /dev/null +++ b/scripts/spark2diffs/GpuDecimalDivide.diff @@ -0,0 +1,10 @@ +27c27 +< GpuOverrides.DECIMAL128_MAX_PRECISION) +--- +> DType.DECIMAL128_MAX_PRECISION) +51,52c51,52 +< math.min(outputType.precision + 1, GpuOverrides.DECIMAL128_MAX_PRECISION), +< math.min(outputType.scale + 1, GpuOverrides.DECIMAL128_MAX_PRECISION)) +--- +> math.min(outputType.precision + 1, DType.DECIMAL128_MAX_PRECISION), +> math.min(outputType.scale + 1, DType.DECIMAL128_MAX_PRECISION)) diff --git a/scripts/spark2diffs/GpuDecimalMultiply.diff b/scripts/spark2diffs/GpuDecimalMultiply.diff new file mode 100644 index 00000000000..aebfb75661b --- /dev/null +++ b/scripts/spark2diffs/GpuDecimalMultiply.diff @@ -0,0 +1,12 @@ +1c1 +< object GpuDecimalMultiply { +--- +> object GpuDecimalMultiply extends Arm { +68c68 +< GpuOverrides.DECIMAL128_MAX_PRECISION) +--- +> DType.DECIMAL128_MAX_PRECISION) +85c85 +< math.min(outputType.scale + 1, GpuOverrides.DECIMAL128_MAX_PRECISION)) +--- +> math.min(outputType.scale + 1, DType.DECIMAL128_MAX_PRECISION)) diff --git a/scripts/spark2diffs/GpuExpandExecMeta.diff b/scripts/spark2diffs/GpuExpandExecMeta.diff new file mode 100644 index 00000000000..f3230aabc61 --- /dev/null +++ b/scripts/spark2diffs/GpuExpandExecMeta.diff @@ -0,0 +1,9 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +14a15,18 +> +> /** +> * Convert what this wraps to a GPU enabled version. +> */ diff --git a/scripts/spark2diffs/GpuFileSourceScanExec.diff b/scripts/spark2diffs/GpuFileSourceScanExec.diff new file mode 100644 index 00000000000..b5543699ac0 --- /dev/null +++ b/scripts/spark2diffs/GpuFileSourceScanExec.diff @@ -0,0 +1,11 @@ +10a11,20 +> +> def convertFileFormat(format: FileFormat): FileFormat = { +> format match { +> case _: CSVFileFormat => new GpuReadCSVFileFormat +> case f if GpuOrcFileFormat.isSparkOrcFormat(f) => new GpuReadOrcFileFormat +> case _: ParquetFileFormat => new GpuReadParquetFileFormat +> case f => +> throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported") +> } +> } diff --git a/scripts/spark2diffs/GpuFlatMapGroupsInPandasExecMeta.diff b/scripts/spark2diffs/GpuFlatMapGroupsInPandasExecMeta.diff new file mode 100644 index 00000000000..a2c23932573 --- /dev/null +++ b/scripts/spark2diffs/GpuFlatMapGroupsInPandasExecMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +21a22 +> diff --git a/scripts/spark2diffs/GpuGenerateExecSparkPlanMeta.diff b/scripts/spark2diffs/GpuGenerateExecSparkPlanMeta.diff new file mode 100644 index 00000000000..0406c09957b --- /dev/null +++ b/scripts/spark2diffs/GpuGenerateExecSparkPlanMeta.diff @@ -0,0 +1,6 @@ +4c4 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], +17a18 +> diff --git a/scripts/spark2diffs/GpuGetArrayItemMeta.diff b/scripts/spark2diffs/GpuGetArrayItemMeta.diff new file mode 100644 index 00000000000..4fbbd1d5d44 --- /dev/null +++ b/scripts/spark2diffs/GpuGetArrayItemMeta.diff @@ -0,0 +1,4 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuGetMapValueMeta.diff b/scripts/spark2diffs/GpuGetMapValueMeta.diff new file mode 100644 index 00000000000..ad0fd3c812d --- /dev/null +++ b/scripts/spark2diffs/GpuGetMapValueMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: 
Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +6a7 +> diff --git a/scripts/spark2diffs/GpuHashJoin.diff b/scripts/spark2diffs/GpuHashJoin.diff new file mode 100644 index 00000000000..331be606a48 --- /dev/null +++ b/scripts/spark2diffs/GpuHashJoin.diff @@ -0,0 +1,18 @@ +2c2 +< def tagForGpu(joinType: JoinType, meta: RapidsMeta[_, _]): Unit = { +--- +> def tagForGpu(joinType: JoinType, meta: RapidsMeta[_, _, _]): Unit = { +69c69 +< object GpuHashJoin { +--- +> object GpuHashJoin extends Arm { +72c72 +< meta: RapidsMeta[_, _], +--- +> meta: RapidsMeta[_, _, _], +99a100 +> +120c121 +< } +--- +> diff --git a/scripts/spark2diffs/GpuHiveOverrides.diff b/scripts/spark2diffs/GpuHiveOverrides.diff new file mode 100644 index 00000000000..cd73ea94947 --- /dev/null +++ b/scripts/spark2diffs/GpuHiveOverrides.diff @@ -0,0 +1,57 @@ +2c2 +< * Copyright (c) 2020-2021, NVIDIA CORPORATION. +--- +> * Copyright (c) 2022, NVIDIA CORPORATION. +20c20 +< import com.nvidia.spark.rapids.{ExprChecks, ExprMeta, ExprRule, GpuExpression, GpuOverrides, RapidsConf, RepeatingParamCheck, ShimLoader, TypeSig} +--- +> import com.nvidia.spark.rapids.{ExprChecks, ExprMeta, ExprRule, GpuOverrides, RapidsConf, RepeatingParamCheck, TypeSig} +29,30c29,30 +< ShimLoader.loadClass("org.apache.spark.sql.hive.HiveSessionStateBuilder") +< ShimLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") +--- +> getClass().getClassLoader.loadClass("org.apache.spark.sql.hive.HiveSessionStateBuilder") +> getClass().getClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") +67,86d66 +< +< override def convertToGpu(): GpuExpression = { +< opRapidsFunc.map { _ => +< // We use the original HiveGenericUDF `deterministic` method as a proxy +< // for simplicity. +< GpuHiveSimpleUDF( +< a.name, +< a.funcWrapper, +< childExprs.map(_.convertToGpu()), +< a.dataType, +< a.deterministic) +< }.getOrElse { +< // This `require` is just for double check. +< require(conf.isCpuBasedUDFEnabled) +< GpuRowBasedHiveSimpleUDF( +< a.name, +< a.funcWrapper, +< childExprs.map(_.convertToGpu())) +< } +< } +106,126d85 +< } +< } +< +< override def convertToGpu(): GpuExpression = { +< opRapidsFunc.map { _ => +< // We use the original HiveGenericUDF `deterministic` method as a proxy +< // for simplicity. +< GpuHiveGenericUDF( +< a.name, +< a.funcWrapper, +< childExprs.map(_.convertToGpu()), +< a.dataType, +< a.deterministic, +< a.foldable) +< }.getOrElse { +< // This `require` is just for double check. +< require(conf.isCpuBasedUDFEnabled) +< GpuRowBasedHiveGenericUDF( +< a.name, +< a.funcWrapper, +< childExprs.map(_.convertToGpu())) diff --git a/scripts/spark2diffs/GpuJoinUtils.diff b/scripts/spark2diffs/GpuJoinUtils.diff new file mode 100644 index 00000000000..3b7ea872635 --- /dev/null +++ b/scripts/spark2diffs/GpuJoinUtils.diff @@ -0,0 +1,22 @@ +16,18d15 +< package com.nvidia.spark.rapids.shims.v2 +< +< import com.nvidia.spark.rapids.shims.v2._ +20,26c17 +< import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +< +< /** +< * Spark BuildSide, BuildRight, BuildLeft moved packages in Spark 3.1 +< * so create GPU versions of these that can be agnostic to Spark version. 
+< */ +< sealed abstract class GpuBuildSide +--- +> package com.nvidia.spark.rapids.shims.v2 +28c19 +< case object GpuBuildRight extends GpuBuildSide +--- +> import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide} +30c21 +< case object GpuBuildLeft extends GpuBuildSide +--- +> import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} diff --git a/scripts/spark2diffs/GpuOrcFileFormat.diff b/scripts/spark2diffs/GpuOrcFileFormat.diff new file mode 100644 index 00000000000..fe09c17e591 --- /dev/null +++ b/scripts/spark2diffs/GpuOrcFileFormat.diff @@ -0,0 +1,14 @@ +12c12 +< def tagGpuSupport(meta: RapidsMeta[_, _], +--- +> def tagGpuSupport(meta: RapidsMeta[_, _, _], +15c15 +< schema: StructType): Unit = { +--- +> schema: StructType): Option[GpuOrcFileFormat] = { +79a80,84 +> if (meta.canThisBeReplaced) { +> Some(new GpuOrcFileFormat) +> } else { +> None +> } diff --git a/scripts/spark2diffs/GpuOrcScanBase.diff b/scripts/spark2diffs/GpuOrcScanBase.diff new file mode 100644 index 00000000000..3472a7ff290 --- /dev/null +++ b/scripts/spark2diffs/GpuOrcScanBase.diff @@ -0,0 +1,14 @@ +1a2,10 +> def tagSupport(scanMeta: ScanMeta[OrcScan]): Unit = { +> val scan = scanMeta.wrapped +> val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema) +> if (scan.options.getBoolean("mergeSchema", false)) { +> scanMeta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet") +> } +> tagSupport(scan.sparkSession, schema, scanMeta) +> } +> +5c14 +< meta: RapidsMeta[_, _]): Unit = { +--- +> meta: RapidsMeta[_, _, _]): Unit = { diff --git a/scripts/spark2diffs/GpuOverrides.diff b/scripts/spark2diffs/GpuOverrides.diff new file mode 100644 index 00000000000..3d3cc65ee96 --- /dev/null +++ b/scripts/spark2diffs/GpuOverrides.diff @@ -0,0 +1,1681 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2019-2022, NVIDIA CORPORATION. 
+24a25 +> import ai.rapids.cudf.DType +26c27 +< import com.nvidia.spark.rapids.shims.v2._ +--- +> import com.nvidia.spark.rapids.shims.v2.{AQEUtils, GpuHashPartitioning, GpuSpecifiedWindowFrameMeta, GpuWindowExpressionMeta, OffsetWindowFunctionMeta} +31a33,34 +> import org.apache.spark.sql.catalyst.expressions.rapids.TimeStamp +> import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +33a37 +> import org.apache.spark.sql.catalyst.trees.TreeNodeTag +34a39 +> import org.apache.spark.sql.connector.read.Scan +36,37c41,42 +< import org.apache.spark.sql.execution.ScalarSubquery +< import org.apache.spark.sql.execution.aggregate._ +--- +> import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec} +> import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +43a49,50 +> import org.apache.spark.sql.execution.datasources.v2._ +> import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan +45a53 +> import org.apache.spark.sql.execution.python._ +49a58 +> import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand +50a60,61 +> import org.apache.spark.sql.rapids.execution.python._ +> import org.apache.spark.sql.rapids.shims.v2.GpuTimeAdd +63c74 +< abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPUT, BASE]]( +--- +> abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPUT, BASE, _]]( +67c78 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +118c129 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +187c198 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +202c213 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +217d227 +< /* +232c242 +< */ +--- +> +240c250 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +259c269 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +281c291 +< Option[RapidsMeta[_, _]], +--- +> Option[RapidsMeta[_, _, _]], +292d301 +< +296c305 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +300,304c309 +< // spark 2.3 doesn't have this so just code it here +< def sparkSessionActive: SparkSession = { +< SparkSession.getActiveSession.getOrElse(SparkSession.getDefaultSession.getOrElse( +< throw new IllegalStateException("No active or default Spark session found"))) +< } +--- +> private var fileFormat: Option[ColumnarFileFormat] = None +311c316 +< val spark = sparkSessionActive +--- +> val spark = SparkSession.active +313c318 +< cmd.fileFormat match { +--- +> fileFormat = cmd.fileFormat match { +315a321 +> None +317a324 +> None +323a331 +> None +325a334 +> None +328d336 +< } +329a338,356 +> override def convertToGpu(): GpuDataWritingCommand = { +> val format = fileFormat.getOrElse( +> throw new IllegalStateException("fileFormat missing, tagSelfForGpu not called?")) +> +> GpuInsertIntoHadoopFsRelationCommand( +> cmd.outputPath, +> cmd.staticPartitions, +> cmd.ifPartitionNotExists, +> cmd.partitionColumns, +> cmd.bucketSpec, +> format, +> cmd.options, +> cmd.query, +> cmd.mode, +> cmd.catalogTable, +> cmd.fileIndex, +> cmd.outputColumnNames) +> } +> } +334c361 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +339,344c366 +< +< // spark 2.3 doesn't have this so just code it here +< def sparkSessionActive: SparkSession = { +< SparkSession.getActiveSession.getOrElse(SparkSession.getDefaultSession.getOrElse( +< throw new 
IllegalStateException("No active or default Spark session found"))) +< } +--- +> private var gpuProvider: Option[ColumnarFileFormat] = None +354c376 +< val spark = sparkSessionActive +--- +> val spark = SparkSession.active +356c378 +< GpuDataSource.lookupDataSource(cmd.table.provider.get, spark.sessionState.conf) +--- +> GpuDataSource.lookupDataSourceWithFallback(cmd.table.provider.get, spark.sessionState.conf) +359c381 +< origProvider.getConstructor().newInstance() match { +--- +> gpuProvider = origProvider.getConstructor().newInstance() match { +362d383 +< None +366d386 +< None +372a393,404 +> override def convertToGpu(): GpuDataWritingCommand = { +> val newProvider = gpuProvider.getOrElse( +> throw new IllegalStateException("fileFormat unexpected, tagSelfForGpu not called?")) +> +> GpuCreateDataSourceTableAsSelectCommand( +> cmd.table, +> cmd.mode, +> cmd.query, +> cmd.outputColumnNames, +> origProvider, +> newProvider) +> } +374a407,409 +> /** +> * Listener trait so that tests can confirm that the expected optimizations are being applied +> */ +401,403d435 +< // copy here for 2.x +< sealed abstract class Optimization +< +405,409d436 +< // Spark 2.x - don't pull in cudf so hardcode here +< val DECIMAL32_MAX_PRECISION = 9 +< val DECIMAL64_MAX_PRECISION = 18 +< val DECIMAL128_MAX_PRECISION = 38 +< +467a495,572 +> private def convertExprToGpuIfPossible(expr: Expression, conf: RapidsConf): Expression = { +> if (expr.find(_.isInstanceOf[GpuExpression]).isDefined) { +> // already been converted +> expr +> } else { +> val wrapped = wrapExpr(expr, conf, None) +> wrapped.tagForGpu() +> if (wrapped.canExprTreeBeReplaced) { +> wrapped.convertToGpu() +> } else { +> expr +> } +> } +> } +> +> private def convertPartToGpuIfPossible(part: Partitioning, conf: RapidsConf): Partitioning = { +> part match { +> case _: GpuPartitioning => part +> case _ => +> val wrapped = wrapPart(part, conf, None) +> wrapped.tagForGpu() +> if (wrapped.canThisBeReplaced) { +> wrapped.convertToGpu() +> } else { +> part +> } +> } +> } +> +> /** +> * Removes unnecessary CPU shuffles that Spark can add to the plan when it does not realize +> * a GPU partitioning satisfies a CPU distribution because CPU and GPU expressions are not +> * semantically equal. +> */ +> def removeExtraneousShuffles(plan: SparkPlan, conf: RapidsConf): SparkPlan = { +> plan.transformUp { +> case cpuShuffle: ShuffleExchangeExec => +> cpuShuffle.child match { +> case sqse: ShuffleQueryStageExec => +> GpuTransitionOverrides.getNonQueryStagePlan(sqse) match { +> case gpuShuffle: GpuShuffleExchangeExecBase => +> val converted = convertPartToGpuIfPossible(cpuShuffle.outputPartitioning, conf) +> if (converted == gpuShuffle.outputPartitioning) { +> sqse +> } else { +> cpuShuffle +> } +> case _ => cpuShuffle +> } +> case _ => cpuShuffle +> } +> } +> } +> +> /** +> * Searches the plan for ReusedExchangeExec instances containing a GPU shuffle where the +> * output types between the two plan nodes do not match. In such a case the ReusedExchangeExec +> * will be updated to match the GPU shuffle output types. 
+> */ +> def fixupReusedExchangeExecs(plan: SparkPlan): SparkPlan = { +> def outputTypesMatch(a: Seq[Attribute], b: Seq[Attribute]): Boolean = +> a.corresponds(b)((x, y) => x.dataType == y.dataType) +> plan.transformUp { +> case sqse: ShuffleQueryStageExec => +> sqse.plan match { +> case ReusedExchangeExec(output, gsee: GpuShuffleExchangeExecBase) if ( +> !outputTypesMatch(output, gsee.output)) => +> val newOutput = sqse.plan.output.zip(gsee.output).map { case (c, g) => +> assert(c.isInstanceOf[AttributeReference] && g.isInstanceOf[AttributeReference], +> s"Expected AttributeReference but found $c and $g") +> AttributeReference(c.name, g.dataType, c.nullable, c.metadata)(c.exprId, c.qualifier) +> } +> AQEUtils.newReuseInstance(sqse, newOutput) +> case _ => sqse +> } +> } +> } +> +558c663 +< case dt: DecimalType if allowDecimal => dt.precision <= GpuOverrides.DECIMAL64_MAX_PRECISION +--- +> case dt: DecimalType if allowDecimal => dt.precision <= DType.DECIMAL64_MAX_PRECISION +581c686 +< def checkAndTagFloatAgg(dataType: DataType, conf: RapidsConf, meta: RapidsMeta[_,_]): Unit = { +--- +> def checkAndTagFloatAgg(dataType: DataType, conf: RapidsConf, meta: RapidsMeta[_,_,_]): Unit = { +595c700 +< meta: RapidsMeta[_,_]): Unit = { +--- +> meta: RapidsMeta[_,_,_]): Unit = { +605a711,742 +> /** +> * Helper function specific to ANSI mode for the aggregate functions that should +> * fallback, since we don't have the same overflow checks that Spark provides in +> * the CPU +> * @param checkType Something other than `None` triggers logic to detect whether +> * the agg should fallback in ANSI mode. Otherwise (None), it's +> * an automatic fallback. +> * @param meta agg expression meta +> */ +> def checkAndTagAnsiAgg(checkType: Option[DataType], meta: AggExprMeta[_]): Unit = { +> val failOnError = SQLConf.get.ansiEnabled +> if (failOnError) { +> if (checkType.isDefined) { +> val typeToCheck = checkType.get +> val failedType = typeToCheck match { +> case _: DecimalType | LongType | IntegerType | ShortType | ByteType => true +> case _ => false +> } +> if (failedType) { +> meta.willNotWorkOnGpu( +> s"ANSI mode not supported for ${meta.expr} with $typeToCheck result type") +> } +> } else { +> // Average falls into this category, where it produces Doubles, but +> // internally it uses Double and Long, and Long could overflow (technically) +> // and failOnError given that it is based on catalyst Add. 
+> meta.willNotWorkOnGpu( +> s"ANSI mode not supported for ${meta.expr}") +> } +> } +> } +> +609c746 +< doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) +--- +> doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) +616a754,763 +> def scan[INPUT <: Scan]( +> desc: String, +> doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) +> => ScanMeta[INPUT]) +> (implicit tag: ClassTag[INPUT]): ScanRule[INPUT] = { +> assert(desc != null) +> assert(doWrap != null) +> new ScanRule[INPUT](doWrap, desc, tag) +> } +> +620c767 +< doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) +--- +> doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) +638c785 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], +647c794 +< doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) +--- +> doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) +657c804 +< doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) +--- +> doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], DataFromReplacementRule) +668c815 +< parent: Option[RapidsMeta[_, _]]): BaseExprMeta[INPUT] = +--- +> parent: Option[RapidsMeta[_, _, _]]): BaseExprMeta[INPUT] = +694d840 +< +709a856 +> override def convertToGpu(child: Expression): GpuExpression = GpuSignum(child) +718a866,867 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata) +728a878 +> override def convertToGpu(): Expression = att +742a893 +> override def convertToGpu(child: Expression): GpuExpression = GpuPromotePrecision(child) +761,762c912 +< // allowNegativeScaleOfDecimalEnabled is not in 2.x assume its default false +< val t = if (s < 0 && !false) { +--- +> val t = if (s < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { +774c924 +< case PromotePrecision(c: Cast) if c.dataType.isInstanceOf[DecimalType] => +--- +> case PromotePrecision(c: CastBase) if c.dataType.isInstanceOf[DecimalType] => +830c980 +< if (intermediatePrecision > GpuOverrides.DECIMAL128_MAX_PRECISION) { +--- +> if (intermediatePrecision > DType.DECIMAL128_MAX_PRECISION) { +845c995 +< if (intermediatePrecision > GpuOverrides.DECIMAL128_MAX_PRECISION) { +--- +> if (intermediatePrecision > DType.DECIMAL128_MAX_PRECISION) { +858a1009,1026 +> +> override def convertToGpu(): GpuExpression = { +> a.child match { +> case _: Divide => +> // GpuDecimalDivide includes the overflow check in it. 
+> GpuDecimalDivide(lhs.convertToGpu(), rhs.convertToGpu(), wrapped.dataType) +> case _: Multiply => +> // GpuDecimal*Multiply includes the overflow check in it +> val intermediatePrecision = +> GpuDecimalMultiply.nonRoundedIntermediatePrecision(lhsDecimalType, +> rhsDecimalType, a.dataType) +> GpuDecimalMultiply(lhs.convertToGpu(), rhs.convertToGpu(), wrapped.dataType, +> needsExtraOverflowChecks = intermediatePrecision > DType.DECIMAL128_MAX_PRECISION) +> case _ => +> GpuCheckOverflow(childExprs.head.convertToGpu(), +> wrapped.dataType, wrapped.nullOnOverflow) +> } +> } +863a1032 +> override def convertToGpu(child: Expression): GpuToDegrees = GpuToDegrees(child) +868a1038 +> override def convertToGpu(child: Expression): GpuToRadians = GpuToRadians(child) +907a1078 +> override def convertToGpu(): GpuExpression = GpuSpecialFrameBoundary(currentRow) +913a1085 +> override def convertToGpu(): GpuExpression = GpuSpecialFrameBoundary(unboundedPreceding) +919a1092 +> override def convertToGpu(): GpuExpression = GpuSpecialFrameBoundary(unboundedFollowing) +924a1098 +> override def convertToGpu(): GpuExpression = GpuRowNumber +933a1108 +> override def convertToGpu(): GpuExpression = GpuRank(childExprs.map(_.convertToGpu())) +942a1118 +> override def convertToGpu(): GpuExpression = GpuDenseRank(childExprs.map(_.convertToGpu())) +962a1139,1140 +> override def convertToGpu(): GpuExpression = +> GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu()) +982a1161,1162 +> override def convertToGpu(): GpuExpression = +> GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu()) +992a1173,1174 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuPreciseTimestampConversion(child, a.fromType, a.toType) +1001,1002c1183 +< // val ansiEnabled = SQLConf.get.ansiEnabled +< val ansiEnabled = false +--- +> val ansiEnabled = SQLConf.get.ansiEnabled +1005,1006d1185 +< // Spark 2.x - ansi in not in 2.x +< /* +1010,1011d1188 +< +< */ +1012a1190,1192 +> +> override def convertToGpu(child: Expression): GpuExpression = +> GpuUnaryMinus(child, ansiEnabled) +1020a1201 +> override def convertToGpu(child: Expression): GpuExpression = GpuUnaryPositive(child) +1025a1207 +> override def convertToGpu(child: Expression): GpuExpression = GpuYear(child) +1030a1213 +> override def convertToGpu(child: Expression): GpuExpression = GpuMonth(child) +1035a1219 +> override def convertToGpu(child: Expression): GpuExpression = GpuQuarter(child) +1040a1225 +> override def convertToGpu(child: Expression): GpuExpression = GpuDayOfMonth(child) +1045a1231 +> override def convertToGpu(child: Expression): GpuExpression = GpuDayOfYear(child) +1050a1237,1248 +> override def convertToGpu(child: Expression): GpuExpression = GpuAcos(child) +> }), +> expr[Acosh]( +> "Inverse hyperbolic cosine", +> ExprChecks.mathUnaryWithAst, +> (a, conf, p, r) => new UnaryAstExprMeta[Acosh](a, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = +> if (conf.includeImprovedFloat) { +> GpuAcoshImproved(child) +> } else { +> GpuAcoshCompat(child) +> } +1055a1254,1274 +> override def convertToGpu(child: Expression): GpuExpression = GpuAsin(child) +> }), +> expr[Asinh]( +> "Inverse hyperbolic sine", +> ExprChecks.mathUnaryWithAst, +> (a, conf, p, r) => new UnaryAstExprMeta[Asinh](a, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = +> if (conf.includeImprovedFloat) { +> GpuAsinhImproved(child) +> } else { +> GpuAsinhCompat(child) +> } +> +> override def 
tagSelfForAst(): Unit = { +> if (!conf.includeImprovedFloat) { +> // AST is not expressive enough yet to implement the conditional expression needed +> // to emulate Spark's behavior +> willNotWorkInAst("asinh is not AST compatible unless " + +> s"${RapidsConf.IMPROVED_FLOAT_OPS.key} is enabled") +> } +> } +1060a1280 +> override def convertToGpu(child: Expression): GpuExpression = GpuSqrt(child) +1065a1286 +> override def convertToGpu(child: Expression): GpuExpression = GpuCbrt(child) +1077c1298 +< if (precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { +--- +> if (precision > DType.DECIMAL128_MAX_PRECISION) { +1084a1306 +> override def convertToGpu(child: Expression): GpuExpression = GpuFloor(child) +1096c1318 +< if (precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { +--- +> if (precision > DType.DECIMAL128_MAX_PRECISION) { +1103a1326 +> override def convertToGpu(child: Expression): GpuExpression = GpuCeil(child) +1109a1333 +> override def convertToGpu(child: Expression): GpuExpression = GpuNot(child) +1117a1342 +> override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child) +1125a1351 +> override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child) +1131a1358 +> override def convertToGpu(child: Expression): GpuExpression = GpuIsNan(child) +1136a1364 +> override def convertToGpu(child: Expression): GpuExpression = GpuRint(child) +1142a1371 +> override def convertToGpu(child: Expression): GpuExpression = GpuBitwiseNot(child) +1151a1381,1382 +> def convertToGpu(): GpuExpression = +> GpuAtLeastNNonNulls(a.n, childExprs.map(_.convertToGpu())) +1160a1392,1393 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuDateAdd(lhs, rhs) +1169a1403,1404 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuDateSub(lhs, rhs) +1176a1412,1413 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuNaNvl(lhs, rhs) +1183a1421,1422 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuShiftLeft(lhs, rhs) +1190a1430,1431 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuShiftRight(lhs, rhs) +1197a1439,1440 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuShiftRightUnsigned(lhs, rhs) +1205a1449,1450 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuBitwiseAnd(lhs, rhs) +1213a1459,1460 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuBitwiseOr(lhs, rhs) +1221a1469,1470 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuBitwiseXor(lhs, rhs) +1231a1481 +> override def convertToGpu(): GpuExpression = GpuCoalesce(childExprs.map(_.convertToGpu())) +1240a1491 +> override def convertToGpu(): GpuExpression = GpuLeast(childExprs.map(_.convertToGpu())) +1249a1501 +> override def convertToGpu(): GpuExpression = GpuGreatest(childExprs.map(_.convertToGpu())) +1254a1507,1513 +> override def convertToGpu(child: Expression): GpuExpression = GpuAtan(child) +> }), +> expr[Atanh]( +> "Inverse hyperbolic tangent", +> ExprChecks.mathUnaryWithAst, +> (a, conf, p, r) => new UnaryAstExprMeta[Atanh](a, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = GpuAtanh(child) +1259a1519 +> override def convertToGpu(child: Expression): GpuExpression = GpuCos(child) +1264a1525 +> override def convertToGpu(child: Expression): GpuExpression = GpuExp(child) +1269a1531 +> 
override def convertToGpu(child: Expression): GpuExpression = GpuExpm1(child) +1275a1538 +> override def convertToGpu(child: Expression): GpuExpression = GpuInitCap(child) +1280a1544 +> override def convertToGpu(child: Expression): GpuExpression = GpuLog(child) +1285a1550,1554 +> override def convertToGpu(child: Expression): GpuExpression = { +> // No need for overflow checking on the GpuAdd in Double as Double handles overflow +> // the same in all modes. +> GpuLog(GpuAdd(child, GpuLiteral(1d, DataTypes.DoubleType), false)) +> } +1290a1560,1561 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuLogarithm(child, GpuLiteral(2d, DataTypes.DoubleType)) +1295a1567,1568 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuLogarithm(child, GpuLiteral(10d, DataTypes.DoubleType)) +1302a1576,1578 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> // the order of the parameters is transposed intentionally +> GpuLogarithm(rhs, lhs) +1307a1584 +> override def convertToGpu(child: Expression): GpuExpression = GpuSin(child) +1312a1590 +> override def convertToGpu(child: Expression): GpuExpression = GpuSinh(child) +1317a1596 +> override def convertToGpu(child: Expression): GpuExpression = GpuCosh(child) +1322a1602 +> override def convertToGpu(child: Expression): GpuExpression = GpuCot(child) +1327a1608 +> override def convertToGpu(child: Expression): GpuExpression = GpuTanh(child) +1332a1614,1632 +> override def convertToGpu(child: Expression): GpuExpression = GpuTan(child) +> }), +> expr[NormalizeNaNAndZero]( +> "Normalize NaN and zero", +> ExprChecks.unaryProjectInputMatchesOutput( +> TypeSig.DOUBLE + TypeSig.FLOAT, +> TypeSig.DOUBLE + TypeSig.FLOAT), +> (a, conf, p, r) => new UnaryExprMeta[NormalizeNaNAndZero](a, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = +> GpuNormalizeNaNAndZero(child) +> }), +> expr[KnownFloatingPointNormalized]( +> "Tag to prevent redundant normalization", +> ExprChecks.unaryProjectInputMatchesOutput( +> TypeSig.DOUBLE + TypeSig.FLOAT, +> TypeSig.DOUBLE + TypeSig.FLOAT), +> (a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = +> GpuKnownFloatingPointNormalized(child) +1339a1640,1641 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuKnownNotNull(child) +1346a1649,1651 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> GpuDateDiff(lhs, rhs) +> } +1365a1671,1672 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuTimeAdd(lhs, rhs) +1366a1674,1695 +> expr[DateAddInterval]( +> "Adds interval to date", +> ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, +> ("start", TypeSig.DATE, TypeSig.DATE), +> ("interval", TypeSig.lit(TypeEnum.CALENDAR) +> .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), +> TypeSig.CALENDAR)), +> (dateAddInterval, conf, p, r) => +> new BinaryExprMeta[DateAddInterval](dateAddInterval, conf, p, r) { +> override def tagExprForGpu(): Unit = { +> GpuOverrides.extractLit(dateAddInterval.interval).foreach { lit => +> val intvl = lit.value.asInstanceOf[CalendarInterval] +> if (intvl.months != 0) { +> willNotWorkOnGpu("interval months isn't supported") +> } +> } +> checkTimeZoneId(dateAddInterval.timeZoneId) +> } +> +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuDateAddInterval(lhs, rhs) +> }), +1376a1706,1707 +> override def 
convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuDateFormatClass(lhs, rhs, strfFormat) +1389,1390c1720,1730 +< override def shouldFallbackOnAnsiTimestamp: Boolean = false +< // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp +--- +> override def shouldFallbackOnAnsiTimestamp: Boolean = +> ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp +> +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> if (conf.isImprovedTimestampOpsEnabled) { +> // passing the already converted strf string for a little optimization +> GpuToUnixTimestampImproved(lhs, rhs, sparkFormat, strfFormat) +> } else { +> GpuToUnixTimestamp(lhs, rhs, sparkFormat, strfFormat) +> } +> } +1402,1403c1742,1743 +< override def shouldFallbackOnAnsiTimestamp: Boolean = false +< // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp +--- +> override def shouldFallbackOnAnsiTimestamp: Boolean = +> ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp +1404a1745,1752 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> if (conf.isImprovedTimestampOpsEnabled) { +> // passing the already converted strf string for a little optimization +> GpuUnixTimestampImproved(lhs, rhs, sparkFormat, strfFormat) +> } else { +> GpuUnixTimestamp(lhs, rhs, sparkFormat, strfFormat) +> } +> } +1414a1763 +> override def convertToGpu(expr: Expression): GpuExpression = GpuHour(expr) +1424a1774,1775 +> override def convertToGpu(expr: Expression): GpuExpression = +> GpuMinute(expr) +1433a1785,1787 +> +> override def convertToGpu(expr: Expression): GpuExpression = +> GpuSecond(expr) +1439a1794,1795 +> override def convertToGpu(expr: Expression): GpuExpression = +> GpuWeekDay(expr) +1445a1802,1803 +> override def convertToGpu(expr: Expression): GpuExpression = +> GpuDayOfWeek(expr) +1450a1809,1810 +> override def convertToGpu(expr: Expression): GpuExpression = +> GpuLastDay(expr) +1461a1822,1824 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> // passing the already converted strf string for a little optimization +> GpuFromUnixTime(lhs, rhs, strfFormat) +1468a1832,1833 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuPmod(lhs, rhs) +1480c1845 +< private val ansiEnabled = false +--- +> private val ansiEnabled = SQLConf.get.ansiEnabled +1482a1848,1850 +> if (ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) { +> willNotWorkInAst("AST Addition does not support ANSI mode.") +> } +1484a1853,1854 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuAdd(lhs, rhs, failOnError = ansiEnabled) +1496c1866 +< private val ansiEnabled = false +--- +> private val ansiEnabled = SQLConf.get.ansiEnabled +1498a1869,1871 +> if (ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) { +> willNotWorkInAst("AST Subtraction does not support ANSI mode.") +> } +1500a1874,1875 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuSubtract(lhs, rhs, ansiEnabled) +1513a1889,1891 +> if (SQLConf.get.ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) { +> willNotWorkOnGpu("GPU Multiplication does not support ANSI mode") +> } +1515a1894,1901 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> a.dataType match { +> case _: DecimalType => throw new IllegalStateException( +> "Decimal Multiply should be converted in CheckOverflow") +> case _ => +> GpuMultiply(lhs, rhs) +> } +> } +1522a1909,1910 +> override def 
convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuAnd(lhs, rhs) +1529a1918,1919 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuOr(lhs, rhs) +1539a1930,1931 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuEqualNullSafe(lhs, rhs) +1550a1943,1944 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuEqualTo(lhs, rhs) +1561a1956,1957 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuGreaterThan(lhs, rhs) +1572a1969,1970 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuGreaterThanOrEqual(lhs, rhs) +1592a1991,1992 +> override def convertToGpu(): GpuExpression = +> GpuInSet(childExprs.head.convertToGpu(), in.list.asInstanceOf[Seq[Literal]].map(_.value)) +1603a2004,2005 +> override def convertToGpu(): GpuExpression = +> GpuInSet(childExprs.head.convertToGpu(), in.hset.toSeq) +1614a2017,2018 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuLessThan(lhs, rhs) +1625a2030,2031 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuLessThanOrEqual(lhs, rhs) +1630a2037,2048 +> override def convertToGpu(): GpuExpression = { +> val branches = childExprs.grouped(2).flatMap { +> case Seq(cond, value) => Some((cond.convertToGpu(), value.convertToGpu())) +> case Seq(_) => None +> }.toArray.toSeq // force materialization to make the seq serializable +> val elseValue = if (childExprs.size % 2 != 0) { +> Some(childExprs.last.convertToGpu()) +> } else { +> None +> } +> GpuCaseWhen(branches, elseValue) +> } +1647a2066,2069 +> override def convertToGpu(): GpuExpression = { +> val Seq(boolExpr, trueExpr, falseExpr) = childExprs.map(_.convertToGpu()) +> GpuIf(boolExpr, trueExpr, falseExpr) +> } +1655a2078,2079 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuPow(lhs, rhs) +1669a2094,2113 +> // Division of Decimal types is a little odd. To work around some issues with +> // what Spark does the tagging/checks are in CheckOverflow instead of here. +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> a.dataType match { +> case _: DecimalType => +> throw new IllegalStateException("Internal Error: Decimal Divide operations " + +> "should be converted to the GPU in the CheckOverflow rule") +> case _ => +> GpuDivide(lhs, rhs) +> } +> }), +> expr[IntegralDivide]( +> "Division with a integer result", +> ExprChecks.binaryProject( +> TypeSig.LONG, TypeSig.LONG, +> ("lhs", TypeSig.LONG + TypeSig.DECIMAL_128, TypeSig.LONG + TypeSig.DECIMAL_128), +> ("rhs", TypeSig.LONG + TypeSig.DECIMAL_128, TypeSig.LONG + TypeSig.DECIMAL_128)), +> (a, conf, p, r) => new BinaryExprMeta[IntegralDivide](a, conf, p, r) { +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuIntegralDivide(lhs, rhs) +1677a2122,2123 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuRemainder(lhs, rhs) +1692c2138,2139 +< // No filter parameter in 2.x +--- +> private val filter: Option[BaseExprMeta[_]] = +> a.filter.map(GpuOverrides.wrapExpr(_, conf, Some(this))) +1696c2143,2158 +< childrenExprMeta +--- +> childrenExprMeta ++ filter.toSeq +> +> override def convertToGpu(): GpuExpression = { +> // handle the case AggregateExpression has the resultIds parameter where its +> // Seq[ExprIds] instead of single ExprId. 
+> val resultId = try { +> val resultMethod = a.getClass.getMethod("resultId") +> resultMethod.invoke(a).asInstanceOf[ExprId] +> } catch { +> case _: Exception => +> val resultMethod = a.getClass.getMethod("resultIds") +> resultMethod.invoke(a).asInstanceOf[Seq[ExprId]].head +> } +> GpuAggregateExpression(childExprs.head.convertToGpu().asInstanceOf[GpuAggregateFunction], +> a.mode, a.isDistinct, filter.map(_.convertToGpu()), resultId) +> } +1719a2182,2184 +> // One of the few expressions that are not replaced with a GPU version +> override def convertToGpu(): Expression = +> sortOrder.withNewChildren(childExprs.map(_.convertToGpu())) +1744a2210,2216 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { +> val Seq(pivotColumn, valueColumn) = childExprs +> GpuPivotFirst(pivotColumn, valueColumn, pivot.pivotColumnValues) +> } +> +> // Pivot does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +1759a2232,2233 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuCount(childExprs) +1788a2263,2268 +> +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuMax(childExprs.head) +> +> // Max does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +1817a2298,2303 +> +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuMin(childExprs.head) +> +> // Min does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +1830a2317,2318 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuSum(childExprs.head, a.dataType) +1844a2333,2337 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuFirst(childExprs.head, a.ignoreNulls) +> +> // First does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +1858a2352,2356 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuLast(childExprs.head, a.ignoreNulls) +> +> // Last does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +1877a2376,2377 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuBRound(lhs, rhs) +1896a2397,2398 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuRound(lhs, rhs) +1918a2421,2425 +> +> override def convertToGpu(): GpuExpression = +> GpuPythonUDF(a.name, a.func, a.dataType, +> childExprs.map(_.convertToGpu()), +> a.evalType, a.udfDeterministic, a.resultId) +1926a2434 +> override def convertToGpu(child: Expression): GpuExpression = GpuRand(child) +1931a2440 +> override def convertToGpu(): GpuExpression = GpuSparkPartitionID() +1936a2446 +> override def convertToGpu(): GpuExpression = GpuMonotonicallyIncreasingID() +1941a2452 +> override def convertToGpu(): GpuExpression = GpuInputFileName() +1946a2458 +> override def convertToGpu(): GpuExpression = GpuInputFileBlockStart() +1951a2464 +> override def convertToGpu(): GpuExpression = GpuInputFileBlockLength() +1957a2471 +> override def convertToGpu(child: Expression): GpuExpression = GpuMd5(child) +1962a2477 +> override def convertToGpu(child: Expression): GpuExpression = GpuUpper(child) +1968a2484 +> override def convertToGpu(child: Expression): GpuExpression = GpuLower(child) +1985a2502,2506 +> override def convertToGpu( +> str: Expression, +> width: Expression, +> pad: Expression): GpuExpression = +> GpuStringLPad(str, 
width, pad) +2001a2523,2527 +> override def convertToGpu( +> str: Expression, +> width: Expression, +> pad: Expression): GpuExpression = +> GpuStringRPad(str, width, pad) +2022a2549,2550 +> override def convertToGpu(arr: Expression): GpuExpression = +> GpuGetStructField(arr, expr.ordinal, expr.name) +2080a2609,2613 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> // This will be called under 3.0.x version, so set failOnError to false to match CPU +> // behavior +> GpuElementAt(lhs, rhs, failOnError = false) +> } +2091a2625,2626 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuMapKeys(child) +2102a2638,2653 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuMapValues(child) +> }), +> expr[MapEntries]( +> "Returns an unordered array of all entries in the given map", +> ExprChecks.unaryProject( +> // Technically the return type is an array of struct, but we cannot really express that +> TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +> TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.ARRAY.nested(TypeSig.all), +> TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +> TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.MAP.nested(TypeSig.all)), +> (in, conf, p, r) => new UnaryExprMeta[MapEntries](in, conf, p, r) { +> override def convertToGpu(child: Expression): GpuExpression = +> GpuMapEntries(child) +2110,2111c2661,2662 +< .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) +< .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), +--- +> .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) +> .withPsNote(TypeEnum.FLOAT, nanAggPsNote), +2115c2666 +< GpuOverrides.checkAndTagFloatNanAgg("Min", in.dataType, conf, this) +--- +> checkAndTagFloatNanAgg("Min", in.dataType, conf, this) +2116a2668,2670 +> +> override def convertToGpu(child: Expression): GpuExpression = +> GpuArrayMin(child) +2124,2125c2678,2679 +< .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) +< .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), +--- +> .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) +> .withPsNote(TypeEnum.FLOAT, nanAggPsNote), +2129c2683 +< GpuOverrides.checkAndTagFloatNanAgg("Max", in.dataType, conf, this) +--- +> checkAndTagFloatNanAgg("Max", in.dataType, conf, this) +2131a2686,2687 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuArrayMax(child) +2136a2693,2694 +> override def convertToGpu(): GpuExpression = +> GpuCreateNamedStruct(childExprs.map(_.convertToGpu())) +2170a2729,2730 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuArrayContains(lhs, rhs) +2182c2742,2746 +< }), +--- +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { +> GpuSortArray(lhs, rhs) +> } +> } +> ), +2207a2772,2773 +> override def convertToGpu(): GpuExpression = +> GpuCreateArray(childExprs.map(_.convertToGpu()), wrapped.useStringTypeWhenEmpty) +2223a2790,2796 +> override def convertToGpu(): GpuExpression = { +> val func = childExprs.head +> val args = childExprs.tail +> GpuLambdaFunction(func.convertToGpu(), +> args.map(_.convertToGpu().asInstanceOf[NamedExpression]), +> in.hidden) +> } +2231a2805,2807 +> override def convertToGpu(): GpuExpression = { +> GpuNamedLambdaVariable(in.name, in.dataType, in.nullable, in.exprId) +> } +2248a2825,2873 +> override def convertToGpu(): GpuExpression = { +> GpuArrayTransform(childExprs.head.convertToGpu(), childExprs(1).convertToGpu()) +> } +> }), +> 
expr[TransformKeys]( +> "Transform keys in a map using a transform function", +> ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.MAP.nested(TypeSig.all), +> Seq( +> ParamCheck("argument", +> TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +> TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.MAP.nested(TypeSig.all)), +> ParamCheck("function", +> // We need to be able to check for duplicate keys (equality) +> TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, +> TypeSig.all - TypeSig.MAP.nested()))), +> (in, conf, p, r) => new ExprMeta[TransformKeys](in, conf, p, r) { +> override def tagExprForGpu(): Unit = { +> SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toUpperCase match { +> case "EXCEPTION" => // Good we can support this +> case other => +> willNotWorkOnGpu(s"$other is not supported for config setting" + +> s" ${SQLConf.MAP_KEY_DEDUP_POLICY.key}") +> } +> } +> override def convertToGpu(): GpuExpression = { +> GpuTransformKeys(childExprs.head.convertToGpu(), childExprs(1).convertToGpu()) +> } +> }), +> expr[TransformValues]( +> "Transform values in a map using a transform function", +> ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.MAP.nested(TypeSig.all), +> Seq( +> ParamCheck("argument", +> TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +> TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), +> TypeSig.MAP.nested(TypeSig.all)), +> ParamCheck("function", +> (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +> TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), +> TypeSig.all))), +> (in, conf, p, r) => new ExprMeta[TransformValues](in, conf, p, r) { +> override def convertToGpu(): GpuExpression = { +> GpuTransformValues(childExprs.head.convertToGpu(), childExprs(1).convertToGpu()) +> } +2256a2882,2886 +> override def convertToGpu( +> val0: Expression, +> val1: Expression, +> val2: Expression): GpuExpression = +> GpuStringLocate(val0, val1, val2) +2264a2895,2899 +> override def convertToGpu( +> column: Expression, +> position: Expression, +> length: Expression): GpuExpression = +> GpuSubstring(column, position, length) +2280a2916,2918 +> override def convertToGpu( +> input: Expression, +> repeatTimes: Expression): GpuExpression = GpuStringRepeat(input, repeatTimes) +2288a2927,2931 +> override def convertToGpu( +> column: Expression, +> target: Expression, +> replace: Expression): GpuExpression = +> GpuStringReplace(column, target, replace) +2296a2940,2943 +> override def convertToGpu( +> column: Expression, +> target: Option[Expression] = None): GpuExpression = +> GpuStringTrim(column, target) +2305a2953,2956 +> override def convertToGpu( +> column: Expression, +> target: Option[Expression] = None): GpuExpression = +> GpuStringTrimLeft(column, target) +2314a2966,2969 +> override def convertToGpu( +> column: Expression, +> target: Option[Expression] = None): GpuExpression = +> GpuStringTrimRight(column, target) +2321a2977,2978 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuStartsWith(lhs, rhs) +2328a2986,2987 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuEndsWith(lhs, rhs) +2339a2999 +> override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuConcat(child) 
+2356a3017,3018 +> override final def convertToGpu(): GpuExpression = +> GpuConcatWs(childExprs.map(_.convertToGpu())) +2366a3029,3030 +> def convertToGpu(): GpuExpression = +> GpuMurmur3Hash(childExprs.map(_.convertToGpu()), a.seed) +2373a3038,3039 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuContains(lhs, rhs) +2380a3047,3048 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuLike(lhs, rhs, a.escapeChar) +2405a3074 +> override def convertToGpu(child: Expression): GpuExpression = GpuLength(child) +2413a3083,3084 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuSize(child, a.legacySizeOfNull) +2420a3092 +> override def convertToGpu(child: Expression): GpuExpression = GpuUnscaledValue(child) +2426a3099,3100 +> override def convertToGpu(child: Expression): GpuExpression = +> GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow) +2440a3115 +> override def convertToGpu(): GpuExpression = GpuExplode(childExprs.head.convertToGpu()) +2454a3130 +> override def convertToGpu(): GpuExpression = GpuPosExplode(childExprs.head.convertToGpu()) +2470c3146,3212 +< }), +--- +> override def convertToGpu(childExpr: Seq[Expression]): GpuExpression = +> GpuReplicateRows(childExpr) +> }), +> expr[CollectList]( +> "Collect a list of non-unique elements, not supported in reduction", +> // GpuCollectList is not yet supported in Reduction context. +> ExprChecks.aggNotReduction( +> TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP), +> TypeSig.ARRAY.nested(TypeSig.all), +> Seq(ParamCheck("input", +> (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), +> TypeSig.all))), +> (c, conf, p, r) => new TypedImperativeAggExprMeta[CollectList](c, conf, p, r) { +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuCollectList(childExprs.head, c.mutableAggBufferOffset, c.inputAggBufferOffset) +> +> override def aggBufferAttribute: AttributeReference = { +> val aggBuffer = c.aggBufferAttributes.head +> aggBuffer.copy(dataType = c.dataType)(aggBuffer.exprId, aggBuffer.qualifier) +> } +> +> override def createCpuToGpuBufferConverter(): CpuToGpuAggregateBufferConverter = +> new CpuToGpuCollectBufferConverter(c.child.dataType) +> +> override def createGpuToCpuBufferConverter(): GpuToCpuAggregateBufferConverter = +> new GpuToCpuCollectBufferConverter() +> +> override val supportBufferConversion: Boolean = true +> +> // Last does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +> }), +> expr[CollectSet]( +> "Collect a set of unique elements, not supported in reduction", +> // GpuCollectSet is not yet supported in Reduction context. +> // Compared to CollectList, StructType is NOT in GpuCollectSet because underlying +> // method drop_list_duplicates doesn't support nested types. 
+> ExprChecks.aggNotReduction( +> TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.STRUCT), +> TypeSig.ARRAY.nested(TypeSig.all), +> Seq(ParamCheck("input", +> (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + +> TypeSig.NULL + TypeSig.STRUCT).nested(), +> TypeSig.all))), +> (c, conf, p, r) => new TypedImperativeAggExprMeta[CollectSet](c, conf, p, r) { +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuCollectSet(childExprs.head, c.mutableAggBufferOffset, c.inputAggBufferOffset) +> +> override def aggBufferAttribute: AttributeReference = { +> val aggBuffer = c.aggBufferAttributes.head +> aggBuffer.copy(dataType = c.dataType)(aggBuffer.exprId, aggBuffer.qualifier) +> } +> +> override def createCpuToGpuBufferConverter(): CpuToGpuAggregateBufferConverter = +> new CpuToGpuCollectBufferConverter(c.child.dataType) +> +> override def createGpuToCpuBufferConverter(): GpuToCpuAggregateBufferConverter = +> new GpuToCpuCollectBufferConverter() +> +> override val supportBufferConversion: Boolean = true +> +> // Last does not overflow, so it doesn't need the ANSI check +> override val needsAnsiCheck: Boolean = false +> }), +2476a3219,3222 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { +> val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate +> GpuStddevPop(childExprs.head, !legacyStatisticalAggregate) +> } +2484a3231,3234 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { +> val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate +> GpuStddevSamp(childExprs.head, !legacyStatisticalAggregate) +> } +2491a3242,3245 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { +> val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate +> GpuVariancePop(childExprs.head, !legacyStatisticalAggregate) +> } +2498a3253,3256 +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { +> val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate +> GpuVarianceSamp(childExprs.head, !legacyStatisticalAggregate) +> } +2538a3297,3302 +> +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuApproximatePercentile(childExprs.head, +> childExprs(1).asInstanceOf[GpuLiteral], +> childExprs(2).asInstanceOf[GpuLiteral]) +> +2553a3318,3319 +> override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = +> GpuGetJsonObject(lhs, rhs) +2556c3322 +< expr[ScalarSubquery]( +--- +> expr[org.apache.spark.sql.execution.ScalarSubquery]( +2562,2563c3328,3332 +< (a, conf, p, r) => new ExprMeta[ScalarSubquery](a, conf, p, r) { +< }), +--- +> (a, conf, p, r) => +> new ExprMeta[org.apache.spark.sql.execution.ScalarSubquery](a, conf, p, r) { +> override def convertToGpu(): GpuExpression = GpuScalarSubquery(a.plan, a.exprId) +> } +> ), +2568c3337,3339 +< }), +--- +> override def convertToGpu(): GpuExpression = GpuCreateMap(childExprs.map(_.convertToGpu())) +> } +> ), +2585c3356,3387 +< commonExpressions ++ GpuHiveOverrides.exprs ++ ShimGpuOverrides.shimExpressions +--- +> commonExpressions ++ TimeStamp.getExprs ++ GpuHiveOverrides.exprs ++ +> ShimLoader.getSparkShims.getExprs +> +> def wrapScan[INPUT <: Scan]( +> scan: INPUT, +> conf: RapidsConf, +> parent: Option[RapidsMeta[_, _, _]]): ScanMeta[INPUT] = +> scans.get(scan.getClass) +> .map(r => r.wrap(scan, conf, parent, r).asInstanceOf[ScanMeta[INPUT]]) +> 
.getOrElse(new RuleNotFoundScanMeta(scan, conf, parent)) +> +> val commonScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Seq( +> GpuOverrides.scan[CSVScan]( +> "CSV parsing", +> (a, conf, p, r) => new ScanMeta[CSVScan](a, conf, p, r) { +> override def tagSelfForGpu(): Unit = GpuCSVScan.tagSupport(this) +> +> override def convertToGpu(): Scan = +> GpuCSVScan(a.sparkSession, +> a.fileIndex, +> a.dataSchema, +> a.readDataSchema, +> a.readPartitionSchema, +> a.options, +> a.partitionFilters, +> a.dataFilters, +> conf.maxReadBatchSizeRows, +> conf.maxReadBatchSizeBytes) +> })).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap +> +> val scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = +> commonScans ++ ShimLoader.getSparkShims.getScans +2590c3392 +< parent: Option[RapidsMeta[_, _]]): PartMeta[INPUT] = +--- +> parent: Option[RapidsMeta[_, _, _]]): PartMeta[INPUT] = +2605a3408,3409 +> override def convertToGpu(): GpuPartitioning = +> GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions) +2615a3420,3427 +> override def convertToGpu(): GpuPartitioning = { +> if (rp.numPartitions > 1) { +> val gpuOrdering = childExprs.map(_.convertToGpu()).asInstanceOf[Seq[SortOrder]] +> GpuRangePartitioning(gpuOrdering, rp.numPartitions) +> } else { +> GpuSinglePartitioning +> } +> } +2620a3433,3435 +> override def convertToGpu(): GpuPartitioning = { +> GpuRoundRobinPartitioning(rrp.numPartitions) +> } +2625a3441 +> override def convertToGpu(): GpuPartitioning = GpuSinglePartitioning +2632c3448 +< parent: Option[RapidsMeta[_, _]]): DataWritingCommandMeta[INPUT] = +--- +> parent: Option[RapidsMeta[_, _, _]]): DataWritingCommandMeta[INPUT] = +2650c3466 +< parent: Option[RapidsMeta[_, _]]): SparkPlanMeta[INPUT] = +--- +> parent: Option[RapidsMeta[_, _, _]]): SparkPlanMeta[INPUT] = +2673c3489,3507 +< (range, conf, p, r) => new SparkPlanMeta[RangeExec](range, conf, p, r) { +--- +> (range, conf, p, r) => { +> new SparkPlanMeta[RangeExec](range, conf, p, r) { +> override def convertToGpu(): GpuExec = +> GpuRangeExec(range.start, range.end, range.step, range.numSlices, range.output, +> conf.gpuTargetBatchSizeBytes) +> } +> }), +> exec[BatchScanExec]( +> "The backend for most file input", +> ExecChecks( +> (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY + +> TypeSig.DECIMAL_128).nested(), +> TypeSig.all), +> (p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) { +> override val childScans: scala.Seq[ScanMeta[_]] = +> Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this))) +> +> override def convertToGpu(): GpuExec = +> GpuBatchScanExec(p.output, childScans.head.convertToGpu()) +2680a3515,3516 +> override def convertToGpu(): GpuExec = +> GpuCoalesceExec(coalesce.numPartitions, childPlans.head.convertIfNeeded()) +2693a3530,3532 +> override def convertToGpu(): GpuExec = +> GpuDataWritingCommandExec(childDataWriteCmds.head.convertToGpu(), +> childPlans.head.convertIfNeeded()) +2708a3548,3570 +> override def convertToGpu(): GpuExec = { +> // To avoid metrics confusion we split a single stage up into multiple parts but only +> // if there are multiple partitions to make it worth doing. 
+> val so = sortOrder.map(_.convertToGpu().asInstanceOf[SortOrder]) +> if (takeExec.child.outputPartitioning.numPartitions == 1) { +> GpuTopN(takeExec.limit, so, +> projectList.map(_.convertToGpu().asInstanceOf[NamedExpression]), +> childPlans.head.convertIfNeeded())(takeExec.sortOrder) +> } else { +> GpuTopN( +> takeExec.limit, +> so, +> projectList.map(_.convertToGpu().asInstanceOf[NamedExpression]), +> ShimLoader.getSparkShims.getGpuShuffleExchangeExec( +> GpuSinglePartitioning, +> GpuTopN( +> takeExec.limit, +> so, +> takeExec.child.output, +> childPlans.head.convertIfNeeded())(takeExec.sortOrder), +> SinglePartition))(takeExec.sortOrder) +> } +> } +2716a3579,3580 +> override def convertToGpu(): GpuExec = +> GpuLocalLimitExec(localLimitExec.limit, childPlans.head.convertIfNeeded()) +2724a3589,3590 +> override def convertToGpu(): GpuExec = +> GpuGlobalLimitExec(globalLimitExec.limit, childPlans.head.convertIfNeeded()) +2731,2734c3597 +< (collectLimitExec, conf, p, r) => +< new SparkPlanMeta[CollectLimitExec](collectLimitExec, conf, p, r) { +< override val childParts: scala.Seq[PartMeta[_]] = +< Seq(GpuOverrides.wrapPart(collectLimitExec.outputPartitioning, conf, Some(this)))}) +--- +> (collectLimitExec, conf, p, r) => new GpuCollectLimitMeta(collectLimitExec, conf, p, r)) +2742a3606,3607 +> override def convertToGpu(): GpuExec = +> GpuFilterExec(childExprs.head.convertToGpu(), childPlans.head.convertIfNeeded()) +2763a3629,3630 +> override def convertToGpu(): GpuExec = +> GpuUnionExec(childPlans.map(_.convertIfNeeded())) +2794a3662,3672 +> override def convertToGpu(): GpuExec = { +> val Seq(left, right) = childPlans.map(_.convertIfNeeded()) +> val joinExec = GpuCartesianProductExec( +> left, +> right, +> None, +> conf.gpuTargetBatchSizeBytes) +> // The GPU does not yet support conditional joins, so conditions are implemented +> // as a filter after the join when possible. 
+> condition.map(c => GpuFilterExec(c.convertToGpu(), joinExec)).getOrElse(joinExec) +> } +2807a3686,3701 +> exec[ObjectHashAggregateExec]( +> "The backend for hash based aggregations supporting TypedImperativeAggregate functions", +> ExecChecks( +> // note that binary input is allowed here but there are additional checks later on to +> // check that we have can support binary in the context of aggregate buffer conversions +> (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + +> TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.BINARY) +> .nested() +> .withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " + +> "converted between CPU and GPU") +> .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") +> .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") +> .withPsNote(TypeEnum.STRUCT, +> "not allowed for grouping expressions if containing Array or Map as child"), +> TypeSig.all), +> (agg, conf, p, r) => new GpuObjectHashAggregateExecMeta(agg, conf, p, r)), +2814,2815d3707 +< // SPARK 2.x we can't check for the TypedImperativeAggregate properly so +< // map/arrya/struct left off +2818c3710 +< TypeSig.MAP + TypeSig.BINARY) +--- +> TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.BINARY) +2822c3714,3717 +< .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions"), +--- +> .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") +> .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") +> .withPsNote(TypeEnum.STRUCT, +> "not allowed for grouping expressions if containing Array or Map as child"), +2825,2826d3719 +< // SPARK 2.x we can't check for the TypedImperativeAggregate properly so don't say we do the +< // ObjectHashAggregate +2859c3752 +< (sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) {} +--- +> (sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) +2861,2863c3754,3787 +< // ShimLoader.getSparkShims.aqeShuffleReaderExec, +< // ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand, +< neverReplaceExec[ExecutedCommandExec]("Table metadata operation") +--- +> exec[SubqueryBroadcastExec]( +> "Plan to collect and transform the broadcast key values", +> ExecChecks(TypeSig.all, TypeSig.all), +> (s, conf, p, r) => new GpuSubqueryBroadcastMeta(s, conf, p, r) +> ), +> ShimLoader.getSparkShims.aqeShuffleReaderExec, +> exec[FlatMapCoGroupsInPandasExec]( +> "The backend for CoGrouped Aggregation Pandas UDF, it runs on CPU itself now but supports" + +> " scheduling GPU resources for the Python process when enabled", +> ExecChecks.hiddenHack(), +> (flatCoPy, conf, p, r) => new GpuFlatMapCoGroupsInPandasExecMeta(flatCoPy, conf, p, r)) +> .disabledByDefault("Performance is not ideal now"), +> neverReplaceExec[AlterNamespaceSetPropertiesExec]("Namespace metadata operation"), +> neverReplaceExec[CreateNamespaceExec]("Namespace metadata operation"), +> neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"), +> neverReplaceExec[DropNamespaceExec]("Namespace metadata operation"), +> neverReplaceExec[SetCatalogAndNamespaceExec]("Namespace metadata operation"), +> ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand, +> neverReplaceExec[ShowNamespacesExec]("Namespace metadata operation"), +> neverReplaceExec[ExecutedCommandExec]("Table metadata operation"), +> neverReplaceExec[AlterTableExec]("Table metadata operation"), +> neverReplaceExec[CreateTableExec]("Table metadata operation"), +> neverReplaceExec[DeleteFromTableExec]("Table metadata 
operation"), +> neverReplaceExec[DescribeTableExec]("Table metadata operation"), +> neverReplaceExec[DropTableExec]("Table metadata operation"), +> neverReplaceExec[AtomicReplaceTableExec]("Table metadata operation"), +> neverReplaceExec[RefreshTableExec]("Table metadata operation"), +> neverReplaceExec[RenameTableExec]("Table metadata operation"), +> neverReplaceExec[ReplaceTableExec]("Table metadata operation"), +> neverReplaceExec[ShowTablePropertiesExec]("Table metadata operation"), +> neverReplaceExec[ShowTablesExec]("Table metadata operation"), +> neverReplaceExec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan"), +> neverReplaceExec[BroadcastQueryStageExec]("Broadcast query stage"), +> neverReplaceExec[ShuffleQueryStageExec]("Shuffle query stage") +2867c3791 +< commonExecs ++ ShimGpuOverrides.shimExecs +--- +> commonExecs ++ ShimLoader.getSparkShims.getExecs +2870,2872c3794 +< // val key = SQLConf.LEGACY_TIME_PARSER_POLICY.key +< val key = "2xgone" +< val policy = SQLConf.get.getConfString(key, "EXCEPTION") +--- +> val policy = SQLConf.get.getConfString(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "EXCEPTION") +2879a3802,3806 +> val preRowToColProjection = TreeNodeTag[Seq[NamedExpression]]("rapids.gpu.preRowToColProcessing") +> +> val postColToRowProjection = TreeNodeTag[Seq[NamedExpression]]( +> "rapids.gpu.postColToRowProcessing") +> +2885a3813,3820 +> private def doConvertPlan(wrap: SparkPlanMeta[SparkPlan], conf: RapidsConf, +> optimizations: Seq[Optimization]): SparkPlan = { +> val convertedPlan = wrap.convertIfNeeded() +> val sparkPlan = addSortsIfNeeded(convertedPlan, conf) +> GpuOverrides.listeners.foreach(_.optimizedPlan(wrap, sparkPlan, optimizations)) +> sparkPlan +> } +> +2888c3823,3871 +< Seq.empty +--- +> if (conf.optimizerEnabled) { +> // we need to run these rules both before and after CBO because the cost +> // is impacted by forcing operators onto CPU due to other rules that we have +> wrap.runAfterTagRules() +> val optimizer = try { +> ShimLoader.newInstanceOf[Optimizer](conf.optimizerClassName) +> } catch { +> case e: Exception => +> throw new RuntimeException(s"Failed to create optimizer ${conf.optimizerClassName}", e) +> } +> optimizer.optimize(conf, wrap) +> } else { +> Seq.empty +> } +> } +> +> private def addSortsIfNeeded(plan: SparkPlan, conf: RapidsConf): SparkPlan = { +> plan.transformUp { +> case operator: SparkPlan => +> ensureOrdering(operator, conf) +> } +> } +> +> // copied from Spark EnsureRequirements but only does the ordering checks and +> // check to convert any SortExec added to GpuSortExec +> private def ensureOrdering(operator: SparkPlan, conf: RapidsConf): SparkPlan = { +> val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering +> var children: Seq[SparkPlan] = operator.children +> assert(requiredChildOrderings.length == children.length) +> +> // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: +> children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => +> // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort. 
+> if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) { +> child +> } else { +> val sort = SortExec(requiredOrdering, global = false, child = child) +> // just specifically check Sort to see if we can change Sort to GPUSort +> val sortMeta = new GpuSortMeta(sort, conf, None, new SortDataFromReplacementRule) +> sortMeta.initReasons() +> sortMeta.tagPlanForGpu() +> if (sortMeta.canThisBeReplaced) { +> sortMeta.convertToGpu() +> } else { +> sort +> } +> } +> } +> operator.withNewChildren(children) +2898,2899c3881,3887 +< // Only run the explain and don't actually convert or run on GPU. +< def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String = { +--- +> /** +> * Only run the explain and don't actually convert or run on GPU. +> * This gets the plan from the dataframe so it's after catalyst has run through all the +> * rules to modify the plan. This means we have to try to undo some of the last rules +> * to make it close to when the columnar rules would normally run on the plan. +> */ +> def explainPotentialGpuPlan(df: DataFrame, explain: String): String = { +2925a3914,3930 +> /** +> * Use explain mode on an active SQL plan as its processed through catalyst. +> * This path is the same as being run through the plugin running on hosts with +> * GPUs. +> */ +> private def explainCatalystSQLPlan(updatedPlan: SparkPlan, conf: RapidsConf): Unit = { +> val explainSetting = if (conf.shouldExplain) { +> conf.explain +> } else { +> "ALL" +> } +> val explainOutput = explainSinglePlan(updatedPlan, conf, explainSetting) +> if (explainOutput.nonEmpty) { +> logWarning(s"\n$explainOutput") +> } +> } +> +2948c3953 +< // case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child) +--- +> case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child) +2950,2951c3955,3956 +< // case aqe: AdaptiveSparkPlanExec => +< // prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe)) +--- +> case aqe: AdaptiveSparkPlanExec => +> prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe)) +2958,2962c3963,3985 +< // Spark 2.x +< object GpuUserDefinedFunction { +< // UDFs can support all types except UDT which does not have a clear columnar representation. 
+< val udfTypeSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + +< TypeSig.BINARY + TypeSig.CALENDAR + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested() +--- +> class ExplainPlanImpl extends ExplainPlanBase { +> override def explainPotentialGpuPlan(df: DataFrame, explain: String): String = { +> GpuOverrides.explainPotentialGpuPlan(df, explain) +> } +> } +> +> // work around any GpuOverride failures +> object GpuOverrideUtil extends Logging { +> def tryOverride(fn: SparkPlan => SparkPlan): SparkPlan => SparkPlan = { plan => +> val planOriginal = plan.clone() +> val failOnError = TEST_CONF.get(plan.conf) || !SUPPRESS_PLANNING_FAILURE.get(plan.conf) +> try { +> fn(plan) +> } catch { +> case NonFatal(t) if !failOnError => +> logWarning("Failed to apply GPU overrides, falling back on the original plan: " + t, t) +> planOriginal +> case fatal: Throwable => +> logError("Encountered an exception applying GPU overrides " + fatal, fatal) +> throw fatal +> } +> } +> } +2963a3987,4058 +> /** Tag the initial plan when AQE is enabled */ +> case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging { +> override def apply(sparkPlan: SparkPlan): SparkPlan = GpuOverrideUtil.tryOverride { plan => +> // Note that we disregard the GPU plan returned here and instead rely on side effects of +> // tagging the underlying SparkPlan. +> GpuOverrides().apply(plan) +> // return the original plan which is now modified as a side-effect of invoking GpuOverrides +> plan +> }(sparkPlan) +> } +> +> case class GpuOverrides() extends Rule[SparkPlan] with Logging { +> +> // Spark calls this method once for the whole plan when AQE is off. When AQE is on, it +> // gets called once for each query stage (where a query stage is an `Exchange`). +> override def apply(sparkPlan: SparkPlan): SparkPlan = GpuOverrideUtil.tryOverride { plan => +> val conf = new RapidsConf(plan.conf) +> if (conf.isSqlEnabled && conf.isSqlExecuteOnGPU) { +> GpuOverrides.logDuration(conf.shouldExplain, +> t => f"Plan conversion to the GPU took $t%.2f ms") { +> val updatedPlan = updateForAdaptivePlan(plan, conf) +> applyOverrides(updatedPlan, conf) +> } +> } else if (conf.isSqlEnabled && conf.isSqlExplainOnlyEnabled) { +> // this mode logs the explain output and returns the original CPU plan +> val updatedPlan = updateForAdaptivePlan(plan, conf) +> GpuOverrides.explainCatalystSQLPlan(updatedPlan, conf) +> plan +> } else { +> plan +> } +> }(sparkPlan) +> +> private def updateForAdaptivePlan(plan: SparkPlan, conf: RapidsConf): SparkPlan = { +> if (plan.conf.adaptiveExecutionEnabled) { +> // AQE can cause Spark to inject undesired CPU shuffles into the plan because GPU and CPU +> // distribution expressions are not semantically equal. +> val newPlan = GpuOverrides.removeExtraneousShuffles(plan, conf) +> +> // AQE can cause ReusedExchangeExec instance to cache the wrong aggregation buffer type +> // compared to the desired buffer type from a reused GPU shuffle. 
+> GpuOverrides.fixupReusedExchangeExecs(newPlan) +> } else { +> plan +> } +> } +> +> private def applyOverrides(plan: SparkPlan, conf: RapidsConf): SparkPlan = { +> val wrap = GpuOverrides.wrapAndTagPlan(plan, conf) +> val reasonsToNotReplaceEntirePlan = wrap.getReasonsNotToReplaceEntirePlan +> if (conf.allowDisableEntirePlan && reasonsToNotReplaceEntirePlan.nonEmpty) { +> if (conf.shouldExplain) { +> logWarning("Can't replace any part of this plan due to: " + +> s"${reasonsToNotReplaceEntirePlan.mkString(",")}") +> } +> plan +> } else { +> val optimizations = GpuOverrides.getOptimizations(wrap, conf) +> wrap.runAfterTagRules() +> if (conf.shouldExplain) { +> wrap.tagForExplain() +> val explain = wrap.explain(conf.shouldExplainAll) +> if (explain.nonEmpty) { +> logWarning(s"\n$explain") +> if (conf.optimizerShouldExplainAll && optimizations.nonEmpty) { +> logWarning(s"Cost-based optimizations applied:\n${optimizations.mkString("\n")}") +> } +> } +> } +> GpuOverrides.doConvertPlan(wrap, conf, optimizations) +> } +> } diff --git a/scripts/spark2diffs/GpuParquetFileFormat.diff b/scripts/spark2diffs/GpuParquetFileFormat.diff new file mode 100644 index 00000000000..980d6308a00 --- /dev/null +++ b/scripts/spark2diffs/GpuParquetFileFormat.diff @@ -0,0 +1,36 @@ +3c3 +< meta: RapidsMeta[_, _], +--- +> meta: RapidsMeta[_, _, _], +6c6 +< schema: StructType): Unit = { +--- +> schema: StructType): Option[GpuParquetFileFormat] = { +52,56c52 +< // Spark 2.x doesn't have the rebase mode because the changes of calendar type weren't made +< // so just skip the checks, since this is just explain only it would depend on how +< // they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good +< // for us. +< /* +--- +> +78c74,79 +< */ +--- +> +> if (meta.canThisBeReplaced) { +> Some(new GpuParquetFileFormat) +> } else { +> None +> } +81,82c82 +< // SPARK 2.X - just return String rather then CompressionType +< def parseCompressionType(compressionType: String): Option[String] = { +--- +> def parseCompressionType(compressionType: String): Option[CompressionType] = { +84,85c84,85 +< case "NONE" | "UNCOMPRESSED" => Some("NONE") +< case "SNAPPY" => Some("SNAPPY") +--- +> case "NONE" | "UNCOMPRESSED" => Some(CompressionType.NONE) +> case "SNAPPY" => Some(CompressionType.SNAPPY) diff --git a/scripts/spark2diffs/GpuParquetScanBase.diff b/scripts/spark2diffs/GpuParquetScanBase.diff new file mode 100644 index 00000000000..64a998e8382 --- /dev/null +++ b/scripts/spark2diffs/GpuParquetScanBase.diff @@ -0,0 +1,41 @@ +1a2,27 +> def tagSupport(scanMeta: ScanMeta[ParquetScan]): Unit = { +> val scan = scanMeta.wrapped +> val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema) +> tagSupport(scan.sparkSession, schema, scanMeta) +> } +> +> def throwIfNeeded( +> table: Table, +> isCorrectedInt96Rebase: Boolean, +> isCorrectedDateTimeRebase: Boolean, +> hasInt96Timestamps: Boolean): Unit = { +> (0 until table.getNumberOfColumns).foreach { i => +> val col = table.getColumn(i) +> // if col is a day +> if (!isCorrectedDateTimeRebase && RebaseHelper.isDateRebaseNeededInRead(col)) { +> throw DataSourceUtils.newRebaseExceptionInRead("Parquet") +> } +> // if col is a time +> else if (hasInt96Timestamps && !isCorrectedInt96Rebase || +> !hasInt96Timestamps && !isCorrectedDateTimeRebase) { +> if (RebaseHelper.isTimeRebaseNeededInRead(col)) { +> throw DataSourceUtils.newRebaseExceptionInRead("Parquet") +> } +> } +> } +> } +6c32 +< meta: RapidsMeta[_, _]): Unit = { +--- +> meta: RapidsMeta[_, _, _]): Unit 
= { +59,65d84 +< // Spark 2.x doesn't have the rebase mode because the changes of calendar type weren't made +< // so just skip the checks, since this is just explain only it would depend on how +< // they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good +< // for us. +< +< // Spark 2.x doesn't support the rebase mode +< /* +95d113 +< */ diff --git a/scripts/spark2diffs/GpuProjectExecMeta.diff b/scripts/spark2diffs/GpuProjectExecMeta.diff new file mode 100644 index 00000000000..fd5db326022 --- /dev/null +++ b/scripts/spark2diffs/GpuProjectExecMeta.diff @@ -0,0 +1,4 @@ +4c4 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuRLikeMeta.diff b/scripts/spark2diffs/GpuRLikeMeta.diff new file mode 100644 index 00000000000..4f459f3ea36 --- /dev/null +++ b/scripts/spark2diffs/GpuRLikeMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +22a23 +> diff --git a/scripts/spark2diffs/GpuReadCSVFileFormat.diff b/scripts/spark2diffs/GpuReadCSVFileFormat.diff new file mode 100644 index 00000000000..b381da92f26 --- /dev/null +++ b/scripts/spark2diffs/GpuReadCSVFileFormat.diff @@ -0,0 +1,4 @@ +5c5 +< fsse.sqlContext.sparkSession, +--- +> ShimLoader.getSparkShims.sessionFromPlan(fsse), diff --git a/scripts/spark2diffs/GpuReadOrcFileFormat.diff b/scripts/spark2diffs/GpuReadOrcFileFormat.diff new file mode 100644 index 00000000000..3e510390759 --- /dev/null +++ b/scripts/spark2diffs/GpuReadOrcFileFormat.diff @@ -0,0 +1,4 @@ +8c8 +< fsse.sqlContext.sparkSession, +--- +> ShimLoader.getSparkShims.sessionFromPlan(fsse), diff --git a/scripts/spark2diffs/GpuReadParquetFileFormat.diff b/scripts/spark2diffs/GpuReadParquetFileFormat.diff new file mode 100644 index 00000000000..b381da92f26 --- /dev/null +++ b/scripts/spark2diffs/GpuReadParquetFileFormat.diff @@ -0,0 +1,4 @@ +5c5 +< fsse.sqlContext.sparkSession, +--- +> ShimLoader.getSparkShims.sessionFromPlan(fsse), diff --git a/scripts/spark2diffs/GpuRegExpExtractMeta.diff b/scripts/spark2diffs/GpuRegExpExtractMeta.diff new file mode 100644 index 00000000000..2d63b1c8ccf --- /dev/null +++ b/scripts/spark2diffs/GpuRegExpExtractMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +50a51 +> diff --git a/scripts/spark2diffs/GpuRegExpReplaceMeta.diff b/scripts/spark2diffs/GpuRegExpReplaceMeta.diff new file mode 100644 index 00000000000..be38e9d5104 --- /dev/null +++ b/scripts/spark2diffs/GpuRegExpReplaceMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +27a28 +> diff --git a/scripts/spark2diffs/GpuSampleExecMeta.diff b/scripts/spark2diffs/GpuSampleExecMeta.diff new file mode 100644 index 00000000000..fd5db326022 --- /dev/null +++ b/scripts/spark2diffs/GpuSampleExecMeta.diff @@ -0,0 +1,4 @@ +4c4 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuSequenceMeta.diff b/scripts/spark2diffs/GpuSequenceMeta.diff new file mode 100644 index 00000000000..45044eab5f6 --- /dev/null +++ b/scripts/spark2diffs/GpuSequenceMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +12a13 +> diff --git a/scripts/spark2diffs/GpuShuffleMeta.diff b/scripts/spark2diffs/GpuShuffleMeta.diff new file mode 100644 index 00000000000..8498313b618 --- /dev/null +++ b/scripts/spark2diffs/GpuShuffleMeta.diff @@ -0,0 +1,23 @@ +4c4 +< 
parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +44c44 +< if shuffle.sqlContext.sparkSession.sessionState.conf +--- +> if ShimLoader.getSparkShims.sessionFromPlan(shuffle).sessionState.conf +55a56,65 +> // When AQE is enabled, we need to preserve meta data as outputAttributes and +> // availableRuntimeDataTransition to the spark plan for the subsequent query stages. +> // These meta data will be fetched in the SparkPlanMeta of CustomShuffleReaderExec. +> if (wrapped.getTagValue(GpuShuffleMeta.shuffleExOutputAttributes).isEmpty) { +> wrapped.setTagValue(GpuShuffleMeta.shuffleExOutputAttributes, outputAttributes) +> } +> if (wrapped.getTagValue(GpuShuffleMeta.availableRuntimeDataTransition).isEmpty) { +> wrapped.setTagValue(GpuShuffleMeta.availableRuntimeDataTransition, +> availableRuntimeDataTransition) +> } +57c67 +< } +--- +> diff --git a/scripts/spark2diffs/GpuShuffledHashJoinMeta.diff b/scripts/spark2diffs/GpuShuffledHashJoinMeta.diff new file mode 100644 index 00000000000..523abb697b6 --- /dev/null +++ b/scripts/spark2diffs/GpuShuffledHashJoinMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +23a24 +> diff --git a/scripts/spark2diffs/GpuSortMergeJoinMeta.diff b/scripts/spark2diffs/GpuSortMergeJoinMeta.diff new file mode 100644 index 00000000000..a93408146fa --- /dev/null +++ b/scripts/spark2diffs/GpuSortMergeJoinMeta.diff @@ -0,0 +1,34 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2019-2022, NVIDIA CORPORATION. +17,20c17 +< package com.nvidia.spark.rapids.shims.v2 +< +< import com.nvidia.spark.rapids._ +< import com.nvidia.spark.rapids.shims.v2._ +--- +> package com.nvidia.spark.rapids +29c26 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +76a74,91 +> } +> +> override def convertToGpu(): GpuExec = { +> val Seq(left, right) = childPlans.map(_.convertIfNeeded()) +> val joinExec = GpuShuffledHashJoinExec( +> leftKeys.map(_.convertToGpu()), +> rightKeys.map(_.convertToGpu()), +> join.joinType, +> buildSide, +> None, +> left, +> right, +> join.isSkewJoin)( +> join.leftKeys, +> join.rightKeys) +> // The GPU does not yet support conditional joins, so conditions are implemented +> // as a filter after the join when possible. 
+> condition.map(c => GpuFilterExec(c.convertToGpu(), joinExec)).getOrElse(joinExec) diff --git a/scripts/spark2diffs/GpuSortMeta.diff b/scripts/spark2diffs/GpuSortMeta.diff new file mode 100644 index 00000000000..2459d3d0ebb --- /dev/null +++ b/scripts/spark2diffs/GpuSortMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +16a17 +> diff --git a/scripts/spark2diffs/GpuSpecifiedWindowFrameMeta.diff b/scripts/spark2diffs/GpuSpecifiedWindowFrameMeta.diff new file mode 100644 index 00000000000..abe83b0a3f5 --- /dev/null +++ b/scripts/spark2diffs/GpuSpecifiedWindowFrameMeta.diff @@ -0,0 +1,8 @@ +4c4 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], +11c11 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], diff --git a/scripts/spark2diffs/GpuSpecifiedWindowFrameMetaBase.diff b/scripts/spark2diffs/GpuSpecifiedWindowFrameMetaBase.diff new file mode 100644 index 00000000000..c1f6b464417 --- /dev/null +++ b/scripts/spark2diffs/GpuSpecifiedWindowFrameMetaBase.diff @@ -0,0 +1,13 @@ +4c4 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], +47,49d46 +< // Spark 2.x different - no days, just months and microseconds +< // could remove this catch but leaving for now +< /* +53,54d49 +< */ +< ci.microseconds +119a115 +> diff --git a/scripts/spark2diffs/GpuStringSplitMeta.diff b/scripts/spark2diffs/GpuStringSplitMeta.diff new file mode 100644 index 00000000000..a349654ab10 --- /dev/null +++ b/scripts/spark2diffs/GpuStringSplitMeta.diff @@ -0,0 +1,18 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +6c6 +< extends BinaryExprMeta[StringSplit](expr, conf, parent, rule) { +--- +> extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) { +10,11c10 +< // 2.x uses expr.pattern not expr.regex +< val regexp = extractLit(expr.pattern) +--- +> val regexp = extractLit(expr.regex) +27,28d25 +< // 2.x has no limit parameter +< /* +32d28 +< */ diff --git a/scripts/spark2diffs/GpuToTimestamp.diff b/scripts/spark2diffs/GpuToTimestamp.diff new file mode 100644 index 00000000000..182eec2a41e --- /dev/null +++ b/scripts/spark2diffs/GpuToTimestamp.diff @@ -0,0 +1,10 @@ +1c1 +< object GpuToTimestamp { +--- +> object GpuToTimestamp extends Arm { +26a27 +> +44d44 +< } +45a46 +> /** remove whitespace before month and day */ diff --git a/scripts/spark2diffs/GpuWindowExecMeta.diff b/scripts/spark2diffs/GpuWindowExecMeta.diff new file mode 100644 index 00000000000..88f19fb1f69 --- /dev/null +++ b/scripts/spark2diffs/GpuWindowExecMeta.diff @@ -0,0 +1,4 @@ +3c3 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuWindowExpressionMeta.diff b/scripts/spark2diffs/GpuWindowExpressionMeta.diff new file mode 100644 index 00000000000..bf38a9d528a --- /dev/null +++ b/scripts/spark2diffs/GpuWindowExpressionMeta.diff @@ -0,0 +1,4 @@ +4c4 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], diff --git a/scripts/spark2diffs/GpuWindowExpressionMetaBase.diff b/scripts/spark2diffs/GpuWindowExpressionMetaBase.diff new file mode 100644 index 00000000000..5fb307e1a07 --- /dev/null +++ b/scripts/spark2diffs/GpuWindowExpressionMetaBase.diff @@ -0,0 +1,9 @@ +4c4 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], +124a125,128 +> +> /** +> * Convert what this wraps to a GPU enabled version. 
+> */ diff --git a/scripts/spark2diffs/GpuWindowInPandasExecMetaBase.diff b/scripts/spark2diffs/GpuWindowInPandasExecMetaBase.diff new file mode 100644 index 00000000000..4fbbd1d5d44 --- /dev/null +++ b/scripts/spark2diffs/GpuWindowInPandasExecMetaBase.diff @@ -0,0 +1,4 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/GpuWindowSpecDefinitionMeta.diff b/scripts/spark2diffs/GpuWindowSpecDefinitionMeta.diff new file mode 100644 index 00000000000..d9274ace354 --- /dev/null +++ b/scripts/spark2diffs/GpuWindowSpecDefinitionMeta.diff @@ -0,0 +1,21 @@ +4c4 +< parent: Option[RapidsMeta[_,_]], +--- +> parent: Option[RapidsMeta[_,_,_]], +9c9 +< windowSpec.partitionSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this))) +--- +> windowSpec.partitionSpec.map(wrapExpr(_, conf, Some(this))) +11c11 +< windowSpec.orderSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this))) +--- +> windowSpec.orderSpec.map(wrapExpr(_, conf, Some(this))) +13c13 +< GpuOverrides.wrapExpr(windowSpec.frameSpecification, conf, Some(this)) +--- +> wrapExpr(windowSpec.frameSpecification, conf, Some(this)) +21a22,25 +> +> /** +> * Convert what this wraps to a GPU enabled version. +> */ diff --git a/scripts/spark2diffs/GpuhashJoin.diff b/scripts/spark2diffs/GpuhashJoin.diff new file mode 100644 index 00000000000..331be606a48 --- /dev/null +++ b/scripts/spark2diffs/GpuhashJoin.diff @@ -0,0 +1,18 @@ +2c2 +< def tagForGpu(joinType: JoinType, meta: RapidsMeta[_, _]): Unit = { +--- +> def tagForGpu(joinType: JoinType, meta: RapidsMeta[_, _, _]): Unit = { +69c69 +< object GpuHashJoin { +--- +> object GpuHashJoin extends Arm { +72c72 +< meta: RapidsMeta[_, _], +--- +> meta: RapidsMeta[_, _, _], +99a100 +> +120c121 +< } +--- +> diff --git a/scripts/spark2diffs/InputFileBlockRule.diff b/scripts/spark2diffs/InputFileBlockRule.diff new file mode 100644 index 00000000000..fbe426c78a5 --- /dev/null +++ b/scripts/spark2diffs/InputFileBlockRule.diff @@ -0,0 +1,31 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021, NVIDIA CORPORATION. +20d19 +< import org.apache.spark.sql.catalyst.expressions._ +21a21 +> import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +32,43d31 +< /** +< * Check the Expression is or has Input File expressions. +< * @param exec expression to check +< * @return true or false +< */ +< def checkHasInputFileExpressions(exec: Expression): Boolean = exec match { +< case _: InputFileName => true +< case _: InputFileBlockStart => true +< case _: InputFileBlockLength => true +< case e => e.children.exists(checkHasInputFileExpressions) +< } +< +45c33 +< plan.expressions.exists(checkHasInputFileExpressions) +--- +> plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions) +79d66 +< /* +84d70 +< */ +104a91 +> diff --git a/scripts/spark2diffs/LiteralExprMeta.diff b/scripts/spark2diffs/LiteralExprMeta.diff new file mode 100644 index 00000000000..d9d5951aba1 --- /dev/null +++ b/scripts/spark2diffs/LiteralExprMeta.diff @@ -0,0 +1,7 @@ +4c4 +< p: Option[RapidsMeta[_, _]], +--- +> p: Option[RapidsMeta[_, _, _]], +8a9,10 +> +> override def convertToGpu(): GpuExpression = GpuLiteral(lit.value, lit.dataType) diff --git a/scripts/spark2diffs/OffsetWindowFunctionMeta.diff b/scripts/spark2diffs/OffsetWindowFunctionMeta.diff new file mode 100644 index 00000000000..93b5d531340 --- /dev/null +++ b/scripts/spark2diffs/OffsetWindowFunctionMeta.diff @@ -0,0 +1,8 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. 
+--- +> * Copyright (c) 2021, NVIDIA CORPORATION. +27c27 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], diff --git a/scripts/spark2diffs/RapidsConf.diff b/scripts/spark2diffs/RapidsConf.diff new file mode 100644 index 00000000000..8f80b098de0 --- /dev/null +++ b/scripts/spark2diffs/RapidsConf.diff @@ -0,0 +1,29 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2019-2022, NVIDIA CORPORATION. +311c311 +< .createWithDefault(ByteUnit.GiB.toBytes(1).toLong) +--- +> .createWithDefault(ByteUnit.GiB.toBytes(1)) +361c361 +< .createWithDefault(ByteUnit.MiB.toBytes(640).toLong) +--- +> .createWithDefault(ByteUnit.MiB.toBytes(640)) +393c393 +< .createWithDefault(ByteUnit.MiB.toBytes(8).toLong) +--- +> .createWithDefault(ByteUnit.MiB.toBytes(8)) +1369c1369 +< |$SPARK_HOME/bin/spark --jars 'rapids-4-spark_2.12-22.02.0-SNAPSHOT.jar,cudf-22.02.0-SNAPSHOT-cuda11.jar' \ +--- +> |${SPARK_HOME}/bin/spark --jars 'rapids-4-spark_2.12-22.02.0-SNAPSHOT.jar,cudf-22.02.0-SNAPSHOT-cuda11.jar' \ +1424a1425,1428 +> printToggleHeader("Scans\n") +> } +> GpuOverrides.scans.values.toSeq.sortBy(_.tag.toString).foreach(_.confHelp(asTable)) +> if (asTable) { +1431c1435 +< // com.nvidia.spark.rapids.python.PythonConfEntries.init() +--- +> com.nvidia.spark.rapids.python.PythonConfEntries.init() diff --git a/scripts/spark2diffs/RapidsMeta.diff b/scripts/spark2diffs/RapidsMeta.diff new file mode 100644 index 00000000000..93986450c2f --- /dev/null +++ b/scripts/spark2diffs/RapidsMeta.diff @@ -0,0 +1,374 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2019-2021, NVIDIA CORPORATION. +23c23 +< import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} +--- +> import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} +25a26,27 +> import org.apache.spark.sql.catalyst.trees.TreeNodeTag +> import org.apache.spark.sql.connector.read.Scan +27c29 +< import org.apache.spark.sql.execution.aggregate._ +--- +> import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +30c32 +< import org.apache.spark.sql.execution.window.WindowExec +--- +> import org.apache.spark.sql.rapids.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter} +54a57 +> val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported") +67a71,72 +> * @tparam OUTPUT when converting to a GPU enabled version of the plan, the generic base +> * type for all GPU enabled versions. +69c74 +< abstract class RapidsMeta[INPUT <: BASE, BASE]( +--- +> abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( +72c77 +< val parent: Option[RapidsMeta[_, _]], +--- +> val parent: Option[RapidsMeta[_, _, _]], +85a91,95 +> * The wrapped scans that should be examined +> */ +> val childScans: Seq[ScanMeta[_]] +> +> /** +95a106,110 +> * Convert what this wraps to a GPU enabled version. 
+> */ +> def convertToGpu(): OUTPUT +> +> /** +110a126 +> import RapidsMeta.gpuSupportedTag +127a144 +> childScans.foreach(_.recursiveCostPreventsRunningOnGpu()) +133a151 +> childScans.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) +140a159 +> childScans.foreach(_.recursiveSparkPlanRemoved()) +158a178,183 +> wrapped match { +> case p: SparkPlan => +> p.setTagValue(gpuSupportedTag, +> p.getTagValue(gpuSupportedTag).getOrElse(Set.empty) + because) +> case _ => +> } +214a240,244 +> * Returns true iff all of the scans can be replaced. +> */ +> def canScansBeReplaced: Boolean = childScans.forall(_.canThisBeReplaced) +> +> /** +244a275 +> childScans.foreach(_.tagForGpu()) +380a412 +> childScans.foreach(_.print(append, depth + 1, all)) +403c435 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +405,407c437 +< extends RapidsMeta[INPUT, Partitioning](part, conf, parent, rule) { +< // 2.x - replaced GpuPartitioning with Partitioning, should be fine +< // since BASE only used for convert +--- +> extends RapidsMeta[INPUT, Partitioning, GpuPartitioning](part, conf, parent, rule) { +410a441 +> override val childScans: Seq[ScanMeta[_]] = Seq.empty +431c462 +< parent: Option[RapidsMeta[_, _]]) +--- +> parent: Option[RapidsMeta[_, _, _]]) +437a469,505 +> override def convertToGpu(): GpuPartitioning = +> throw new IllegalStateException("Cannot be converted to GPU") +> } +> +> /** +> * Base class for metadata around `Scan`. +> */ +> abstract class ScanMeta[INPUT <: Scan](scan: INPUT, +> conf: RapidsConf, +> parent: Option[RapidsMeta[_, _, _]], +> rule: DataFromReplacementRule) +> extends RapidsMeta[INPUT, Scan, Scan](scan, conf, parent, rule) { +> +> override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty +> override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty +> override val childScans: Seq[ScanMeta[_]] = Seq.empty +> override val childParts: Seq[PartMeta[_]] = Seq.empty +> override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty +> +> override def tagSelfForGpu(): Unit = {} +> } +> +> /** +> * Metadata for `Scan` with no rule found +> */ +> final class RuleNotFoundScanMeta[INPUT <: Scan]( +> scan: INPUT, +> conf: RapidsConf, +> parent: Option[RapidsMeta[_, _, _]]) +> extends ScanMeta[INPUT](scan, conf, parent, new NoRuleDataFromReplacementRule) { +> +> override def tagSelfForGpu(): Unit = { +> willNotWorkOnGpu(s"GPU does not currently support the operator ${scan.getClass}") +> } +> +> override def convertToGpu(): Scan = +> throw new IllegalStateException("Cannot be converted to GPU") +446c514 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +448c516 +< extends RapidsMeta[INPUT, DataWritingCommand](cmd, conf, parent, rule) { +--- +> extends RapidsMeta[INPUT, DataWritingCommand, GpuDataWritingCommand](cmd, conf, parent, rule) { +451a520 +> override val childScans: Seq[ScanMeta[_]] = Seq.empty +464c533 +< parent: Option[RapidsMeta[_, _]]) +--- +> parent: Option[RapidsMeta[_, _, _]]) +469a539,541 +> +> override def convertToGpu(): GpuDataWritingCommand = +> throw new IllegalStateException("Cannot be converted to GPU") +477c549 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +479c551 +< extends RapidsMeta[INPUT, SparkPlan](plan, conf, parent, rule) { +--- +> extends RapidsMeta[INPUT, SparkPlan, GpuExec](plan, conf, parent, rule) { +484a557 +> childScans.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) +489a563 +> childScans.foreach(_.recursiveSparkPlanRemoved()) +514a589 +> 
override val childScans: Seq[ScanMeta[_]] = Seq.empty +550a626,630 +> +> childPlans.head.wrapped +> .getTagValue(GpuOverrides.preRowToColProjection).foreach { r2c => +> wrapped.setTagValue(GpuOverrides.preRowToColProjection, r2c) +> } +592c672 +< /*if (!canScansBeReplaced) { +--- +> if (!canScansBeReplaced) { +594c674 +< } */ +--- +> } +613a694,696 +> wrapped.getTagValue(RapidsMeta.gpuSupportedTag) +> .foreach(_.diff(cannotBeReplacedReasons.get) +> .foreach(willNotWorkOnGpu)) +637c720,724 +< convertToCpu +--- +> if (canThisBeReplaced) { +> convertToGpu() +> } else { +> convertToCpu() +> } +707c794 +< parent: Option[RapidsMeta[_, _]]) +--- +> parent: Option[RapidsMeta[_, _, _]]) +711a799,801 +> +> override def convertToGpu(): GpuExec = +> throw new IllegalStateException("Cannot be converted to GPU") +720c810 +< parent: Option[RapidsMeta[_, _]]) +--- +> parent: Option[RapidsMeta[_, _, _]]) +727a818,820 +> +> override def convertToGpu(): GpuExec = +> throw new IllegalStateException("Cannot be converted to GPU") +768c861 +< case agg: SparkPlan if agg.isInstanceOf[WindowExec] => +--- +> case agg: SparkPlan if ShimLoader.getSparkShims.isWindowFunctionExec(agg) => +770,777c863 +< case agg: HashAggregateExec => +< // Spark 2.x doesn't have the BaseAggregateExec class +< if (agg.groupingExpressions.isEmpty) { +< ReductionAggExprContext +< } else { +< GroupByAggExprContext +< } +< case agg: SortAggregateExec => +--- +> case agg: BaseAggregateExec => +788c874 +< def getRegularOperatorContext(meta: RapidsMeta[_, _]): ExpressionContext = meta.wrapped match { +--- +> def getRegularOperatorContext(meta: RapidsMeta[_, _, _]): ExpressionContext = meta.wrapped match { +844c930 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +846c932 +< extends RapidsMeta[INPUT, Expression](expr, conf, parent, rule) { +--- +> extends RapidsMeta[INPUT, Expression, Expression](expr, conf, parent, rule) { +852a939 +> override val childScans: Seq[ScanMeta[_]] = Seq.empty +991c1078 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +993a1081,1082 +> +> override def convertToGpu(): GpuExpression +1002c1091 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1005a1095,1099 +> override final def convertToGpu(): GpuExpression = +> convertToGpu(childExprs.head.convertToGpu()) +> +> def convertToGpu(child: Expression): GpuExpression +> +1021c1115 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1032c1126 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1037a1132,1134 +> if (needsAnsiCheck) { +> GpuOverrides.checkAndTagAnsiAgg(ansiTypeToCheck, this) +> } +1041a1139,1151 +> +> override final def convertToGpu(): GpuExpression = +> convertToGpu(childExprs.map(_.convertToGpu())) +> +> def convertToGpu(childExprs: Seq[Expression]): GpuExpression +> +> // Set to false if the aggregate doesn't overflow and therefore +> // shouldn't error +> val needsAnsiCheck: Boolean = true +> +> // The type to use to determine whether the aggregate could overflow. 
+> // Set to None, if we should fallback for all types +> val ansiTypeToCheck: Option[DataType] = Some(expr.dataType) +1050c1160 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1052a1163,1164 +> +> def convertToGpu(childExprs: Seq[Expression]): GpuExpression +1061c1173 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1072a1185,1200 +> * Returns a buffer converter who can generate a Expression to transform the aggregation buffer +> * of wrapped function from CPU format to GPU format. The conversion occurs on the CPU, so the +> * generated expression should be a CPU Expression executed by row. +> */ +> def createCpuToGpuBufferConverter(): CpuToGpuAggregateBufferConverter = +> throw new NotImplementedError("The method should be implemented by specific functions") +> +> /** +> * Returns a buffer converter who can generate a Expression to transform the aggregation buffer +> * of wrapped function from GPU format to CPU format. The conversion occurs on the CPU, so the +> * generated expression should be a CPU Expression executed by row. +> */ +> def createGpuToCpuBufferConverter(): GpuToCpuAggregateBufferConverter = +> throw new NotImplementedError("The method should be implemented by specific functions") +> +> /** +1086c1214 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1088a1217,1223 +> +> override final def convertToGpu(): GpuExpression = { +> val Seq(lhs, rhs) = childExprs.map(_.convertToGpu()) +> convertToGpu(lhs, rhs) +> } +> +> def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression +1095c1230 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1113c1248,1267 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +> rule: DataFromReplacementRule) +> extends ExprMeta[INPUT](expr, conf, parent, rule) { +> +> override final def convertToGpu(): GpuExpression = { +> val Seq(child0, child1, child2) = childExprs.map(_.convertToGpu()) +> convertToGpu(child0, child1, child2) +> } +> +> def convertToGpu(val0: Expression, val1: Expression, +> val2: Expression): GpuExpression +> } +> +> /** +> * Base class for metadata around `QuaternaryExpression`. 
+> */ +> abstract class QuaternaryExprMeta[INPUT <: QuaternaryExpression]( +> expr: INPUT, +> conf: RapidsConf, +> parent: Option[RapidsMeta[_, _, _]], +1115a1270,1277 +> +> override final def convertToGpu(): GpuExpression = { +> val Seq(child0, child1, child2, child3) = childExprs.map(_.convertToGpu()) +> convertToGpu(child0, child1, child2, child3) +> } +> +> def convertToGpu(val0: Expression, val1: Expression, +> val2: Expression, val3: Expression): GpuExpression +1121c1283 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1123a1286,1292 +> +> override final def convertToGpu(): GpuExpression = { +> val gpuCol :: gpuTrimParam = childExprs.map(_.convertToGpu()) +> convertToGpu(gpuCol, gpuTrimParam.headOption) +> } +> +> def convertToGpu(column: Expression, target: Option[Expression] = None): GpuExpression +1132c1301 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +1134a1304,1307 +> override final def convertToGpu(): GpuExpression = +> convertToGpu(childExprs.map(_.convertToGpu())) +> +> def convertToGpu(childExprs: Seq[Expression]): GpuExpression +1143c1316 +< parent: Option[RapidsMeta[_, _]]) +--- +> parent: Option[RapidsMeta[_, _, _]]) +1147a1321,1323 +> +> override def convertToGpu(): GpuExpression = +> throw new IllegalStateException("Cannot be converted to GPU") diff --git a/scripts/spark2diffs/RegexParser.diff b/scripts/spark2diffs/RegexParser.diff new file mode 100644 index 00000000000..00a37973fc6 --- /dev/null +++ b/scripts/spark2diffs/RegexParser.diff @@ -0,0 +1,9 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021-2022, NVIDIA CORPORATION. +794c794 +< } +--- +> } +\ No newline at end of file diff --git a/scripts/spark2diffs/ReplicateRowsExprMeta.diff b/scripts/spark2diffs/ReplicateRowsExprMeta.diff new file mode 100644 index 00000000000..5779c3dda25 --- /dev/null +++ b/scripts/spark2diffs/ReplicateRowsExprMeta.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +6a7 +> diff --git a/scripts/spark2diffs/ScalaUDF.diff b/scripts/spark2diffs/ScalaUDF.diff new file mode 100644 index 00000000000..22fc16800b5 --- /dev/null +++ b/scripts/spark2diffs/ScalaUDF.diff @@ -0,0 +1,14 @@ +1c1 +< GpuOverrides.expr[ScalaUDF]( +--- +> def exprMeta: ExprRule[ScalaUDF] = GpuOverrides.expr[ScalaUDF]( +9a10,18 +> override protected def rowBasedScalaUDF: GpuRowBasedScalaUDFBase = +> GpuRowBasedScalaUDF( +> expr.function, +> expr.dataType, +> childExprs.map(_.convertToGpu()), +> expr.inputEncoders, +> expr.udfName, +> expr.nullable, +> expr.udfDeterministic) diff --git a/scripts/spark2diffs/ScalaUDFMetaBase.diff b/scripts/spark2diffs/ScalaUDFMetaBase.diff new file mode 100644 index 00000000000..003edf94153 --- /dev/null +++ b/scripts/spark2diffs/ScalaUDFMetaBase.diff @@ -0,0 +1,6 @@ +4c4 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +17a18 +> diff --git a/scripts/spark2diffs/SubstringIndexMeta.diff b/scripts/spark2diffs/SubstringIndexMeta.diff new file mode 100644 index 00000000000..33e5885ee4a --- /dev/null +++ b/scripts/spark2diffs/SubstringIndexMeta.diff @@ -0,0 +1,6 @@ +4c4 +< override val parent: Option[RapidsMeta[_, _]], +--- +> override val parent: Option[RapidsMeta[_, _, _]], +20a21 +> diff --git a/scripts/spark2diffs/TimeSub.diff b/scripts/spark2diffs/TimeSub.diff new file mode 100644 index 00000000000..222d72f4684 --- /dev/null +++ b/scripts/spark2diffs/TimeSub.diff @@ -0,0 +1,8 @@ +1c1 +< 
GpuOverrides.expr[TimeSub]( +--- +> getExprsSansTimeSub + (classOf[TimeSub] -> GpuOverrides.expr[TimeSub]( +18c18 +< }), +--- +> diff --git a/scripts/spark2diffs/TreeNode.diff b/scripts/spark2diffs/TreeNode.diff new file mode 100644 index 00000000000..9c7e90dc82f --- /dev/null +++ b/scripts/spark2diffs/TreeNode.diff @@ -0,0 +1,4 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021, NVIDIA CORPORATION. diff --git a/scripts/spark2diffs/TypeChecks.diff b/scripts/spark2diffs/TypeChecks.diff new file mode 100644 index 00000000000..ae1f49a9d61 --- /dev/null +++ b/scripts/spark2diffs/TypeChecks.diff @@ -0,0 +1,118 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2020-2021, NVIDIA CORPORATION. +21a22 +> import ai.rapids.cudf.DType +24d24 +< import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION} +168c168 +< private val maxAllowedDecimalPrecision: Int = GpuOverrides.DECIMAL64_MAX_PRECISION, +--- +> private val maxAllowedDecimalPrecision: Int = DType.DECIMAL64_MAX_PRECISION, +270c270 +< meta: RapidsMeta[_, _], +--- +> meta: RapidsMeta[_, _, _], +562c562 +< val DECIMAL_64: TypeSig = decimal(GpuOverrides.DECIMAL64_MAX_PRECISION) +--- +> val DECIMAL_64: TypeSig = decimal(DType.DECIMAL64_MAX_PRECISION) +569c569 +< val DECIMAL_128: TypeSig = decimal(GpuOverrides.DECIMAL128_MAX_PRECISION) +--- +> val DECIMAL_128: TypeSig = decimal(DType.DECIMAL128_MAX_PRECISION) +703c703 +< def tag(meta: RapidsMeta[_, _]): Unit +--- +> def tag(meta: RapidsMeta[_, _, _]): Unit +716c716 +< meta: RapidsMeta[_, _], +--- +> meta: RapidsMeta[_, _, _], +773c773 +< override def tag(rapidsMeta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(rapidsMeta: RapidsMeta[_, _, _]): Unit = { +777c777 +< private[this] def tagBase(rapidsMeta: RapidsMeta[_, _], willNotWork: String => Unit): Unit = { +--- +> private[this] def tagBase(rapidsMeta: RapidsMeta[_, _, _], willNotWork: String => Unit): Unit = { +835c835 +< def tag(meta: RapidsMeta[_, _], +--- +> def tag(meta: RapidsMeta[_, _, _], +846c846 +< override def tag(meta: RapidsMeta[_, _]): Unit = +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = +872c872 +< def tag(meta: RapidsMeta[_, _], +--- +> def tag(meta: RapidsMeta[_, _, _], +893c893 +< override def tag(rapidsMeta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(rapidsMeta: RapidsMeta[_, _, _]): Unit = { +968c968 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1020c1020 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1063c1063 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1113c1113 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1159c1159 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1195c1195 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1345c1345 +< override def tag(meta: RapidsMeta[_, _]): Unit = { +--- +> override def tag(meta: RapidsMeta[_, _, _]): Unit = { +1355c1355 +< private[this] def tagBase(meta: RapidsMeta[_, _], willNotWork: String => Unit): Unit = { +--- +> private[this] def tagBase(meta: RapidsMeta[_, _, _], willNotWork: String => Unit): Unit = { +1690,1698d1689 +< def getSparkVersion: String = { +< // hack for databricks, try to 
find something more reliable? +< if (SPARK_BUILD_USER.equals("Databricks")) { +< SPARK_VERSION + "-databricks" +< } else { +< SPARK_VERSION +< } +< } +< +1713c1704 +< println(s"against version ${getSparkVersion} of Spark. Most of this should still") +--- +> println(s"against version ${ShimLoader.getSparkVersion} of Spark. Most of this should still") +1721c1712 +< println(s"supports a precision up to ${GpuOverrides.DECIMAL64_MAX_PRECISION} digits. Note that") +--- +> println(s"supports a precision up to ${DType.DECIMAL64_MAX_PRECISION} digits. Note that") +1823c1814 +< val allData = allSupportedTypes.toList.map { t => +--- +> val allData = allSupportedTypes.map { t => +1906c1897 +< val allData = allSupportedTypes.toList.map { t => +--- +> val allData = allSupportedTypes.map { t => +2010c2001 +< val allData = allSupportedTypes.toList.map { t => +--- +> val allData = allSupportedTypes.map { t => diff --git a/scripts/spark2diffs/TypeSigUtil.diff b/scripts/spark2diffs/TypeSigUtil.diff new file mode 100644 index 00000000000..3911cb60122 --- /dev/null +++ b/scripts/spark2diffs/TypeSigUtil.diff @@ -0,0 +1,12 @@ +2c2 +< * Copyright (c) 2022, NVIDIA CORPORATION. +--- +> * Copyright (c) 2021, NVIDIA CORPORATION. +19c19 +< import com.nvidia.spark.rapids.{TypeEnum, TypeSig} +--- +> import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtilBase} +24c24 +< object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtilBase { +--- +> object TypeSigUtil extends TypeSigUtilBase { diff --git a/scripts/spark2diffs/UnixTimeExprMeta.diff b/scripts/spark2diffs/UnixTimeExprMeta.diff new file mode 100644 index 00000000000..6925f619975 --- /dev/null +++ b/scripts/spark2diffs/UnixTimeExprMeta.diff @@ -0,0 +1,14 @@ +3c3 +< parent: Option[RapidsMeta[_, _]], +--- +> parent: Option[RapidsMeta[_, _, _]], +36,37d35 +< // Spark 2.x - ansi not available +< /* +40,41c38 +< } else */ +< if (!conf.incompatDateFormats) { +--- +> } else if (!conf.incompatDateFormats) { +76a74 +> diff --git a/scripts/spark2diffs/abs.diff b/scripts/spark2diffs/abs.diff new file mode 100644 index 00000000000..c5bd73367e2 --- /dev/null +++ b/scripts/spark2diffs/abs.diff @@ -0,0 +1,3 @@ +6a7,8 +> // ANSI support for ABS was added in 3.2.0 SPARK-33275 +> override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false) diff --git a/scripts/spark2diffs/average.diff b/scripts/spark2diffs/average.diff new file mode 100644 index 00000000000..243dcaf2728 --- /dev/null +++ b/scripts/spark2diffs/average.diff @@ -0,0 +1,9 @@ +32c32,37 +< +--- +> +> override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = +> GpuAverage(childExprs.head) +> +> // Average is not supported in ANSI mode right now, no matter the type +> override val ansiTypeToCheck: Option[DataType] = None diff --git a/scripts/spark2diffs/cast.diff b/scripts/spark2diffs/cast.diff new file mode 100644 index 00000000000..0b4eac8a499 --- /dev/null +++ b/scripts/spark2diffs/cast.diff @@ -0,0 +1,5 @@ +4c4,5 +< (cast, conf, p, r) => new CastExprMeta[Cast](cast, false, conf, p, r, +--- +> (cast, conf, p, r) => new CastExprMeta[Cast](cast, +> SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r, diff --git a/scripts/spark2diffs/createCudfDecimal.diff b/scripts/spark2diffs/createCudfDecimal.diff new file mode 100644 index 00000000000..152836ed88c --- /dev/null +++ b/scripts/spark2diffs/createCudfDecimal.diff @@ -0,0 +1,20 @@ +1,7c1,9 +< def createCudfDecimal(precision: Int, scale: Int): Option[String] = { +< if (precision <= GpuOverrides.DECIMAL32_MAX_PRECISION) { 
+< Some("DECIMAL32") +< } else if (precision <= GpuOverrides.DECIMAL64_MAX_PRECISION) { +< Some("DECIMAL64") +< } else if (precision <= GpuOverrides.DECIMAL128_MAX_PRECISION) { +< Some("DECIMAL128") +--- +> def createCudfDecimal(dt: DecimalType): DType = { +> createCudfDecimal(dt.precision, dt.scale) +> def createCudfDecimal(precision: Int, scale: Int): DType = { +> if (precision <= DType.DECIMAL32_MAX_PRECISION) { +> DType.create(DType.DTypeEnum.DECIMAL32, -scale) +> } else if (precision <= DType.DECIMAL64_MAX_PRECISION) { +> DType.create(DType.DTypeEnum.DECIMAL64, -scale) +> } else if (precision <= DType.DECIMAL128_MAX_PRECISION) { +> DType.create(DType.DTypeEnum.DECIMAL128, -scale) +10d11 +< None diff --git a/scripts/spark2diffs/getPrecisionForIntegralType.diff b/scripts/spark2diffs/getPrecisionForIntegralType.diff new file mode 100644 index 00000000000..ec4b160039e --- /dev/null +++ b/scripts/spark2diffs/getPrecisionForIntegralType.diff @@ -0,0 +1,12 @@ +1,5c1,5 +< def getPrecisionForIntegralType(input: String): Int = input match { +< case "INT8" => 3 // -128 to 127 +< case "INT16" => 5 // -32768 to 32767 +< case "INT32" => 10 // -2147483648 to 2147483647 +< case "INT64" => 19 // -9223372036854775808 to 9223372036854775807 +--- +> def getPrecisionForIntegralType(input: DType): Int = input match { +> case DType.INT8 => 3 // -128 to 127 +> case DType.INT16 => 5 // -32768 to 32767 +> case DType.INT32 => 10 // -2147483648 to 2147483647 +> case DType.INT64 => 19 // -9223372036854775808 to 9223372036854775807 diff --git a/scripts/spark2diffs/optionallyAsDecimalType.diff b/scripts/spark2diffs/optionallyAsDecimalType.diff new file mode 100644 index 00000000000..682fc15f583 --- /dev/null +++ b/scripts/spark2diffs/optionallyAsDecimalType.diff @@ -0,0 +1,4 @@ +4c4 +< val prec = DecimalUtil.getPrecisionForIntegralType(getNonNestedRapidsType(t)) +--- +> val prec = DecimalUtil.getPrecisionForIntegralType(GpuColumnVector.getNonNestedRapidsType(t)) diff --git a/scripts/spark2diffs/toRapidsStringOrNull.diff b/scripts/spark2diffs/toRapidsStringOrNull.diff new file mode 100644 index 00000000000..485a19004ed --- /dev/null +++ b/scripts/spark2diffs/toRapidsStringOrNull.diff @@ -0,0 +1,55 @@ +1,20c1,31 +< def toRapidsStringOrNull(dtype: DataType): Option[String] = { +< dtype match { +< case _: LongType => Some("INT64") +< case _: DoubleType => Some("FLOAT64") +< case _: ByteType => Some("INT8") +< case _: BooleanType => Some("BOOL8") +< case _: ShortType => Some("INT16") +< case _: IntegerType => Some("INT32") +< case _: FloatType => Some("FLOAT32") +< case _: DateType => Some("TIMESTAMP_DAYS") +< case _: TimestampType => Some("TIMESTAMP_MICROSECONDS") +< case _: StringType => Some("STRING") +< case _: BinaryType => Some("LIST") +< case _: NullType => Some("INT8") +< case _: DecimalType => +< // Decimal supportable check has been conducted in the GPU plan overriding stage. +< // So, we don't have to handle decimal-supportable problem at here. 
+< val dt = dtype.asInstanceOf[DecimalType] +< createCudfDecimal(dt.precision, dt.scale) +< case _ => None +--- +> private static DType toRapidsOrNull(DataType type) { +> if (type instanceof LongType) { +> return DType.INT64; +> } else if (type instanceof DoubleType) { +> return DType.FLOAT64; +> } else if (type instanceof ByteType) { +> return DType.INT8; +> } else if (type instanceof BooleanType) { +> return DType.BOOL8; +> } else if (type instanceof ShortType) { +> return DType.INT16; +> } else if (type instanceof IntegerType) { +> return DType.INT32; +> } else if (type instanceof FloatType) { +> return DType.FLOAT32; +> } else if (type instanceof DateType) { +> return DType.TIMESTAMP_DAYS; +> } else if (type instanceof TimestampType) { +> return DType.TIMESTAMP_MICROSECONDS; +> } else if (type instanceof StringType) { +> return DType.STRING; +> } else if (type instanceof BinaryType) { +> return DType.LIST; +> } else if (type instanceof NullType) { +> // INT8 is used for both in this case +> return DType.INT8; +> } else if (type instanceof DecimalType) { +> // Decimal supportable check has been conducted in the GPU plan overriding stage. +> // So, we don't have to handle decimal-supportable problem at here. +> DecimalType dt = (DecimalType) type; +> return DecimalUtil.createCudfDecimal(dt.precision(), dt.scale()); +21a33 +> return null; diff --git a/spark2-sql-plugin/pom.xml b/spark2-sql-plugin/pom.xml new file mode 100644 index 00000000000..37b06295f14 --- /dev/null +++ b/spark2-sql-plugin/pom.xml @@ -0,0 +1,183 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-parent + 22.02.0-SNAPSHOT + + rapids-4-spark-sql-meta_2.11 + RAPIDS Accelerator for Apache Spark SQL Plugin Base Meta + The RAPIDS SQL plugin for Apache Spark Base Meta Information + 22.02.0-SNAPSHOT + + + 2.11 + 2.11.12 + ${spark248.version} + ${spark.version} + spark24 + spark24 + + + + org.scala-lang + scala-library + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + + with-classifier + + true + + + + org.apache.spark + spark-hive_${scala.binary.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + + + + + + + + + + ${project.build.directory}/extra-resources + true + + + ${project.basedir}/.. 
+ META-INF + + + LICENSE + NOTICE + + + + + + org.apache.maven.plugins + maven-jar-plugin + + ${spark.version.classifier} + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-profile-src-20+ + generate-sources + + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + + eclipse-add-source + + add-source + + + + scala-compile-first + process-resources + + compile + + + + scala-test-compile-first + process-test-resources + + testCompile + + + + attach-scaladocs + verify + + doc-jar + + + + -doc-external-doc:${java.home}/lib/rt.jar#https://docs.oracle.com/javase/${java.major.version}/docs/api/index.html + -doc-external-doc:${settings.localRepository}/${scala.local-lib.path}#https://scala-lang.org/api/${scala.version}/ + -doc-external-doc:${settings.localRepository}/org/apache/spark/spark-sql_${scala.binary.version}/${spark.version}/spark-sql_${scala.binary.version}-${spark.version}.jar#https://spark.apache.org/docs/${spark.version}/api/scala/index.html + + + + + + ${scala.version} + true + true + incremental + + -unchecked + -deprecation + -feature + -explaintypes + -Yno-adapted-args + -Xlint:missing-interpolator + -Xfatal-warnings + + + -Xms1024m + -Xmx1024m + + + -source + ${maven.compiler.source} + -target + ${maven.compiler.target} + -Xlint:all,-serial,-path,-try + + + + + org.apache.rat + apache-rat-plugin + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/spark2-sql-plugin/src/main/java/com/nvidia/spark/RapidsUDF.java b/spark2-sql-plugin/src/main/java/com/nvidia/spark/RapidsUDF.java new file mode 100644 index 00000000000..bc240b66360 --- /dev/null +++ b/spark2-sql-plugin/src/main/java/com/nvidia/spark/RapidsUDF.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark; + +/** A RAPIDS accelerated version of a user-defined function (UDF). */ +public interface RapidsUDF { + /** + * Evaluate a user-defined function with RAPIDS cuDF columnar inputs + * producing a cuDF column as output. The method must return a column of + * the appropriate type that corresponds to the type returned by the CPU + * implementation of the UDF (e.g.: INT32 for int, FLOAT64 for double, + * STRING for String, etc) or a runtime exception will occur when the + * results are marshalled into the expected Spark result type for the UDF. + *
<p>
+ * Note that the inputs should NOT be closed by this method, as they will + * be closed by the caller. This method must close any intermediate cuDF + * results produced during the computation (e.g.: `Table`, `ColumnVector` + * or `Scalar` instances). + * @param args columnar inputs to the UDF that will be closed by the caller + * and should not be closed within this method. + * @return columnar output from the user-defined function + */ + // Spark 2.x -just keep the interface, we probably could remove all those check + // but leave for printing the will not work info + // ColumnVector evaluateColumnar(ColumnVector... args); +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/AggregateUtils.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/AggregateUtils.scala new file mode 100644 index 00000000000..aab90a21184 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/AggregateUtils.scala @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, If} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.DataType + +object AggregateUtils { + + private val aggs = List("min", "max", "avg", "sum", "count", "first", "last") + + /** + * Return true if the Attribute passed is one of aggregates in the aggs list. + * Use it with caution. We are comparing the name of a column looking for anything that matches + * with the values in aggs. + */ + def validateAggregate(attributes: AttributeSet): Boolean = { + attributes.toSeq.exists(attr => aggs.exists(agg => attr.name.contains(agg))) + } + + /** + * Return true if there are multiple distinct functions along with non-distinct functions. + */ + def shouldFallbackMultiDistinct(aggExprs: Seq[AggregateExpression]): Boolean = { + // Check if there is an `If` within `First`. This is included in the plan for non-distinct + // functions only when multiple distincts along with non-distinct functions are present in the + // query. We fall back to CPU in this case when references of `If` are an aggregate. We cannot + // call `isDistinct` here on aggregateExpressions to get the total number of distinct functions. + // If there are multiple distincts, the plan is rewritten by `RewriteDistinctAggregates` where + // regular aggregations and every distinct aggregation is calculated in a separate group. 
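+    // Concretely: fall back when any remaining aggregate function is a First(If(_, _, _))
+    // whose references include one of the aggregate-style column names checked by
+    // validateAggregate above.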
+ aggExprs.map(e => e.aggregateFunction).exists { + func => { + func match { + case First(If(_, _, _), _) if validateAggregate(func.references) => true + case _ => false + } + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala new file mode 100644 index 00000000000..c18fa2baea6 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DataTypeUtils.scala @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} + +object DataTypeUtils { + def isNestedType(dataType: DataType): Boolean = dataType match { + case _: ArrayType | _: MapType | _: StructType => true + case _ => false + } + + def hasNestedTypes(schema: StructType): Boolean = + schema.exists(f => isNestedType(f.dataType)) +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala new file mode 100644 index 00000000000..e856e306596 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import java.time._ + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +/** + * Class for helper functions for Date + */ +object DateUtils { + val unsupportedCharacter = Set( + 'k', 'K','z', 'V', 'c', 'F', 'W', 'Q', 'q', 'G', 'A', 'n', 'N', + 'O', 'X', 'p', '\'', '[', ']', '#', '{', '}', 'Z', 'w', 'e', 'E', 'x', 'Z', 'Y') + + val unsupportedWord = Set( + "u", "uu", "uuu", "uuuu", "uuuuu", "uuuuuu", "uuuuuuu", "uuuuuuuu", "uuuuuuuuu", "uuuuuuuuuu", + "y", "yyy", "yyyyy", "yyyyyy", "yyyyyyy", "yyyyyyyy", "yyyyyyyyy", "yyyyyyyyyy", + "D", "DD", "DDD", "s", "m", "H", "h", "M", "MMM", "MMMM", "MMMMM", "L", "LLL", "LLLL", "LLLLL", + "d", "S", "SS", "SSS", "SSSS", "SSSSS", "SSSSSSSSS", "SSSSSSS", "SSSSSSSS") + + // we support "yy" in some cases, but not when parsing strings + // https://github.com/NVIDIA/spark-rapids/issues/2118 + val unsupportedWordParseFromString = unsupportedWord ++ Set("yy") + + val conversionMap = Map( + "MM" -> "%m", "LL" -> "%m", "dd" -> "%d", "mm" -> "%M", "ss" -> "%S", "HH" -> "%H", + "yy" -> "%y", "yyyy" -> "%Y", "SSSSSS" -> "%f") + + val ONE_SECOND_MICROSECONDS = 1000000 + + val ONE_DAY_SECONDS = 86400L + + val ONE_DAY_MICROSECONDS = 86400000000L + + val EPOCH = "epoch" + val NOW = "now" + val TODAY = "today" + val YESTERDAY = "yesterday" + val TOMORROW = "tomorrow" + + // Spark 2.x - removed isSpark320orlater checks + def specialDatesDays: Map[String, Int] = { + val today = currentDate() + Map( + EPOCH -> 0, + NOW -> today, + TODAY -> today, + YESTERDAY -> (today - 1), + TOMORROW -> (today + 1) + ) + } + + def specialDatesSeconds: Map[String, Long] = { + val today = currentDate() + // spark 2.4 Date utils are different + val now = DateTimeUtils.instantToMicros(Instant.now()) + Map( + EPOCH -> 0, + NOW -> now / 1000000L, + TODAY -> today * ONE_DAY_SECONDS, + YESTERDAY -> (today - 1) * ONE_DAY_SECONDS, + TOMORROW -> (today + 1) * ONE_DAY_SECONDS + ) + } + + def specialDatesMicros: Map[String, Long] = { + val today = currentDate() + val now = DateTimeUtils.instantToMicros(Instant.now()) + Map( + EPOCH -> 0, + NOW -> now, + TODAY -> today * ONE_DAY_MICROSECONDS, + YESTERDAY -> (today - 1) * ONE_DAY_MICROSECONDS, + TOMORROW -> (today + 1) * ONE_DAY_MICROSECONDS + ) + } + + def currentDate(): Int = Math.toIntExact(LocalDate.now().toEpochDay) + + case class FormatKeywordToReplace(word: String, startIndex: Int, endIndex: Int) + + /** + * This function converts a java time format string to a strf format string + * Supports %m,%p,%j,%d,%I,%H,%M,%S,%y,%Y,%f format specifiers. 
+ * %d Day of the month: 01-31 + * %m Month of the year: 01-12 + * %y Year without century: 00-99c + * %Y Year with century: 0001-9999 + * %H 24-hour of the day: 00-23 + * %M Minute of the hour: 00-59 + * %S Second of the minute: 00-59 + * %f 6-digit microsecond: 000000-999999 + * + * reported bugs + * https://github.com/rapidsai/cudf/issues/4160 after the bug is fixed this method + * should also support + * "hh" -> "%I" (12 hour clock) + * "a" -> "%p" ('AM', 'PM') + * "DDD" -> "%j" (Day of the year) + * + * @param format Java time format string + * @param parseString True if we're parsing a string + */ + def toStrf(format: String, parseString: Boolean): String = { + val javaPatternsToReplace = identifySupportedFormatsToReplaceElseThrow( + format, parseString) + replaceFormats(format, javaPatternsToReplace) + } + + def replaceFormats( + format: String, + javaPatternsToReplace: ListBuffer[FormatKeywordToReplace]): String = { + val strf = new StringBuilder(format.length).append(format) + for (pattern <- javaPatternsToReplace.reverse) { + if (conversionMap.contains(pattern.word)) { + strf.replace(pattern.startIndex, pattern.endIndex, conversionMap(pattern.word)) + } + } + strf.toString + } + + def identifySupportedFormatsToReplaceElseThrow( + format: String, + parseString: Boolean): ListBuffer[FormatKeywordToReplace] = { + + val unsupportedWordContextAware = if (parseString) { + unsupportedWordParseFromString + } else { + unsupportedWord + } + var sb = new StringBuilder() + var index = 0; + val patterns = new ListBuffer[FormatKeywordToReplace] + format.foreach(character => { + // We are checking to see if this char is a part of a previously read pattern + // or start of a new one. + if (sb.isEmpty || sb.last == character) { + if (unsupportedCharacter(character)) { + throw TimestampFormatConversionException(s"Unsupported character: $character") + } + sb.append(character) + } else { + // its a new pattern, check if the previous pattern was supported. If its supported, + // add it to the groups and add this char as a start of a new pattern else throw exception + val word = sb.toString + if (unsupportedWordContextAware(word)) { + throw TimestampFormatConversionException(s"Unsupported word: $word") + } + val startIndex = index - word.length + patterns += FormatKeywordToReplace(word, startIndex, startIndex + word.length) + sb = new StringBuilder(format.length) + if (unsupportedCharacter(character)) { + throw TimestampFormatConversionException(s"Unsupported character: $character") + } + sb.append(character) + } + index = index + 1 + }) + if (sb.nonEmpty) { + val word = sb.toString() + if (unsupportedWordContextAware(word)) { + throw TimestampFormatConversionException(s"Unsupported word: $word") + } + val startIndex = format.length - word.length + patterns += FormatKeywordToReplace(word, startIndex, startIndex + word.length) + } + patterns + } + + case class TimestampFormatConversionException(reason: String) extends Exception +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala new file mode 100644 index 00000000000..90c4cdc67d1 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.types._ + +object DecimalUtil { + + def getNonNestedRapidsType(dtype: DataType): String = { + val res = toRapidsStringOrNull(dtype) + res.getOrElse(throw new + IllegalArgumentException(dtype + " is not supported for GPU processing yet.")) + } + + def createCudfDecimal(precision: Int, scale: Int): Option[String] = { + if (precision <= GpuOverrides.DECIMAL32_MAX_PRECISION) { + Some("DECIMAL32") + } else if (precision <= GpuOverrides.DECIMAL64_MAX_PRECISION) { + Some("DECIMAL64") + } else if (precision <= GpuOverrides.DECIMAL128_MAX_PRECISION) { + Some("DECIMAL128") + } else { + throw new IllegalArgumentException(s"precision overflow: $precision") + None + } + } + + // don't want to pull in cudf for explain only so use strings + // instead of DType + def toRapidsStringOrNull(dtype: DataType): Option[String] = { + dtype match { + case _: LongType => Some("INT64") + case _: DoubleType => Some("FLOAT64") + case _: ByteType => Some("INT8") + case _: BooleanType => Some("BOOL8") + case _: ShortType => Some("INT16") + case _: IntegerType => Some("INT32") + case _: FloatType => Some("FLOAT32") + case _: DateType => Some("TIMESTAMP_DAYS") + case _: TimestampType => Some("TIMESTAMP_MICROSECONDS") + case _: StringType => Some("STRING") + case _: BinaryType => Some("LIST") + case _: NullType => Some("INT8") + case _: DecimalType => + // Decimal supportable check has been conducted in the GPU plan overriding stage. + // So, we don't have to handle decimal-supportable problem at here. 
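+        // Map this DecimalType to the smallest cudf decimal kind name
+        // (DECIMAL32/64/128) whose precision can hold it.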
+ val dt = dtype.asInstanceOf[DecimalType] + createCudfDecimal(dt.precision, dt.scale) + case _ => None + } + } + + /** + * Get the number of decimal places needed to hold the integral type held by this column + */ + def getPrecisionForIntegralType(input: String): Int = input match { + case "INT8" => 3 // -128 to 127 + case "INT16" => 5 // -32768 to 32767 + case "INT32" => 10 // -2147483648 to 2147483647 + case "INT64" => 19 // -9223372036854775808 to 9223372036854775807 + case t => throw new IllegalArgumentException(s"Unsupported type $t") + } + // The following types were copied from Spark's DecimalType class + private val BooleanDecimal = DecimalType(1, 0) + + def optionallyAsDecimalType(t: DataType): Option[DecimalType] = t match { + case dt: DecimalType => Some(dt) + case ByteType | ShortType | IntegerType | LongType => + val prec = DecimalUtil.getPrecisionForIntegralType(getNonNestedRapidsType(t)) + Some(DecimalType(prec, 0)) + case BooleanType => Some(BooleanDecimal) + case _ => None + } + + def asDecimalType(t: DataType): DecimalType = optionallyAsDecimalType(t) match { + case Some(dt) => dt + case _ => + throw new IllegalArgumentException( + s"Internal Error: type $t cannot automatically be cast to a supported DecimalType") + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExplainPlan.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExplainPlan.scala new file mode 100644 index 00000000000..af973043a39 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExplainPlan.scala @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import scala.util.control.NonFatal + +import org.apache.spark.sql.DataFrame + +// Base trait visible publicly outside of parallel world packaging. +// It can't be named the same as ExplainPlan object to allow calling from PySpark. +trait ExplainPlanBase { + def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String +} + +object ExplainPlan { + /** + * Looks at the CPU plan associated with the dataframe and outputs information + * about which parts of the query the RAPIDS Accelerator for Apache Spark + * could place on the GPU. This only applies to the initial plan, so if running + * with adaptive query execution enable, it will not be able to show any changes + * in the plan due to that. + * + * This is very similar output you would get by running the query with the + * Rapids Accelerator enabled and with the config `spark.rapids.sql.enabled` enabled. + * + * Requires the RAPIDS Accelerator for Apache Spark jar and RAPIDS cudf jar be included + * in the classpath but the RAPIDS Accelerator for Apache Spark should be disabled. 
+ * + * {{{ + * val output = com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df) + * }}} + * + * Calling from PySpark: + * + * {{{ + * output = sc._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL") + * }}} + * + * @param df The Spark DataFrame to get the query plan from + * @param explain If ALL returns all the explain data, otherwise just returns what does not + * work on the GPU. Default is ALL. + * @return String containing the explained plan. + * @throws java.lang.IllegalArgumentException if an argument is invalid or it is unable to + * determine the Spark version + * @throws java.lang.IllegalStateException if the plugin gets into an invalid state while trying + * to process the plan or there is an unexepected exception. + */ + @throws[IllegalArgumentException] + @throws[IllegalStateException] + def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String = { + try { + GpuOverrides.explainPotentialGpuPlan(df, explain) + } catch { + case ia: IllegalArgumentException => throw ia + case is: IllegalStateException => throw is + case NonFatal(e) => + val msg = "Unexpected exception trying to run explain on the plan!" + throw new IllegalStateException(msg, e) + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala new file mode 100644 index 00000000000..9feec11050d --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.types._ + +object CudfTDigest { + val dataType: DataType = StructType(Array( + StructField("centroids", ArrayType(StructType(Array( + StructField("mean", DataTypes.DoubleType, nullable = false), + StructField("weight", DataTypes.DoubleType, nullable = false) + )), containsNull = false)), + StructField("min", DataTypes.DoubleType, nullable = false), + StructField("max", DataTypes.DoubleType, nullable = false) + )) +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastJoinMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastJoinMeta.scala new file mode 100644 index 00000000000..a4499404606 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastJoinMeta.scala @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import org.apache.spark.sql.execution.SparkPlan + +abstract class GpuBroadcastJoinMeta[INPUT <: SparkPlan](plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[INPUT](plan, conf, parent, rule) { + + def canBuildSideBeReplaced(buildSide: SparkPlanMeta[_]): Boolean = { + // Spark 2.x - removed some checks only applicable to AQE + buildSide.wrapped match { + case _ => buildSide.canThisBeReplaced + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala new file mode 100644 index 00000000000..5f9b226385e --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.types._ + +/** Meta-data for cast and ansi_cast. */ +final class CastExprMeta[INPUT <: Cast]( + cast: INPUT, + val ansiEnabled: Boolean, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule, + doFloatToIntCheck: Boolean, + // stringToDate supports ANSI mode from Spark v3.2.0. Here is the details. 
+ // https://github.com/apache/spark/commit/6e862792fb + // We do not want to create a shim class for this small change + stringToAnsiDate: Boolean, + toTypeOverride: Option[DataType] = None) + extends UnaryExprMeta[INPUT](cast, conf, parent, rule) { + + def withToTypeOverride(newToType: DecimalType): CastExprMeta[INPUT] = + new CastExprMeta[INPUT](cast, ansiEnabled, conf, parent, rule, + doFloatToIntCheck, stringToAnsiDate, Some(newToType)) + + val fromType: DataType = cast.child.dataType + val toType: DataType = toTypeOverride.getOrElse(cast.dataType) + // 2.x doesn't have the SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING config, so set it to true + val legacyCastToString: Boolean = true + + override def tagExprForGpu(): Unit = recursiveTagExprForGpuCheck() + + private def recursiveTagExprForGpuCheck( + fromDataType: DataType = fromType, + toDataType: DataType = toType, + depth: Int = 0): Unit = { + val checks = rule.getChecks.get.asInstanceOf[CastChecks] + if (depth > 0 && + !checks.gpuCanCast(fromDataType, toDataType)) { + willNotWorkOnGpu(s"Casting child type $fromDataType to $toDataType is not supported") + } + + (fromDataType, toDataType) match { + case (FloatType | DoubleType, ByteType | ShortType | IntegerType | LongType) if + doFloatToIntCheck && !conf.isCastFloatToIntegralTypesEnabled => + willNotWorkOnGpu(buildTagMessage(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)) + case (dt: DecimalType, _: StringType) => + if (!conf.isCastDecimalToStringEnabled) { + willNotWorkOnGpu("the GPU does not produce the exact same string as Spark produces, " + + s"set ${RapidsConf.ENABLE_CAST_DECIMAL_TO_STRING} to true if semantically " + + s"equivalent decimal strings are sufficient for your application.") + } + if (dt.precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { + willNotWorkOnGpu(s"decimal to string with a " + + s"precision > ${GpuOverrides.DECIMAL128_MAX_PRECISION} is not supported yet") + } + case ( _: DecimalType, _: FloatType | _: DoubleType) if !conf.isCastDecimalToFloatEnabled => + willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal " + + "to convert decimal data types to floating point and this can produce results that " + + "slightly differ from the default behavior in Spark. To enable this operation on " + + s"the GPU, set ${RapidsConf.ENABLE_CAST_DECIMAL_TO_FLOAT} to true.") + case (_: FloatType | _: DoubleType, _: DecimalType) if !conf.isCastFloatToDecimalEnabled => + willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal " + + "to convert floating point data types to decimals and this can produce results that " + + "slightly differ from the default behavior in Spark. To enable this operation on " + + s"the GPU, set ${RapidsConf.ENABLE_CAST_FLOAT_TO_DECIMAL} to true.") + case (_: FloatType | _: DoubleType, _: StringType) if !conf.isCastFloatToStringEnabled => + willNotWorkOnGpu("the GPU will use different precision than Java's toString method when " + + "converting floating point data types to strings and this can produce results that " + + "differ from the default behavior in Spark. To enable this operation on the GPU, set" + + s" ${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to true.") + case (_: StringType, dt: DecimalType) if dt.precision + 1 > DecimalType.MAX_PRECISION => + willNotWorkOnGpu(s"Because of rounding requirements we cannot support $dt on the GPU") + case (_: StringType, _: FloatType | _: DoubleType) if !conf.isCastStringToFloatEnabled => + willNotWorkOnGpu("Currently hex values aren't supported on the GPU. 
Also note " + + "that casting from string to float types on the GPU returns incorrect results when " + + "the string represents any number \"1.7976931348623158E308\" <= x < " + + "\"1.7976931348623159E308\" and \"-1.7976931348623159E308\" < x <= " + + "\"-1.7976931348623158E308\" in both these cases the GPU returns Double.MaxValue " + + "while CPU returns \"+Infinity\" and \"-Infinity\" respectively. To enable this " + + s"operation on the GPU, set ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") + case (_: StringType, _: TimestampType) => + if (!conf.isCastStringToTimestampEnabled) { + willNotWorkOnGpu("the GPU only supports a subset of formats " + + "when casting strings to timestamps. Refer to the CAST documentation " + + "for more details. To enable this operation on the GPU, set" + + s" ${RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP} to true.") + } + case (_: StringType, _: DateType) => + // NOOP for anything prior to 3.2.0 + case (_: StringType, dt:DecimalType) => + // Spark 2.x: removed check for + // !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported + // this dealt with handling a bug fix that is only in newer versions of Spark + // (https://issues.apache.org/jira/browse/SPARK-37451) + // Since we don't know what version of Spark 3 they will be using + // just always say it won't work and they can hopefully figure it out from warning. + if (dt.scale < 0) { + willNotWorkOnGpu("RAPIDS doesn't support casting string to decimal for " + + "negative scale decimal in this version of Spark because of SPARK-37451") + } + case (structType: StructType, StringType) => + structType.foreach { field => + recursiveTagExprForGpuCheck(field.dataType, StringType, depth + 1) + } + case (fromStructType: StructType, toStructType: StructType) => + fromStructType.zip(toStructType).foreach { + case (fromChild, toChild) => + recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1) + } + case (ArrayType(elementType, _), StringType) => + recursiveTagExprForGpuCheck(elementType, StringType, depth + 1) + + case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => + recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1) + + case (MapType(keyFrom, valueFrom, _), MapType(keyTo, valueTo, _)) => + recursiveTagExprForGpuCheck(keyFrom, keyTo, depth + 1) + recursiveTagExprForGpuCheck(valueFrom, valueTo, depth + 1) + + case _ => + } + } + + def buildTagMessage(entry: ConfEntry[_]): String = { + s"${entry.doc}. To enable this operation on the GPU, set ${entry.key} to true." + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandMeta.scala new file mode 100644 index 00000000000..c11c4aa0ec0 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandMeta.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{ExpandExec, SparkPlan} + +class GpuExpandExecMeta( + expand: ExpandExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[ExpandExec](expand, conf, parent, rule) { + + private val gpuProjections: Seq[Seq[BaseExprMeta[_]]] = + expand.projections.map(_.map(GpuOverrides.wrapExpr(_, conf, Some(this)))) + + private val outputs: Seq[BaseExprMeta[_]] = + expand.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = gpuProjections.flatten ++ outputs +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExecMeta.scala new file mode 100644 index 00000000000..f3b0b492e59 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExecMeta.scala @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.Generator +import org.apache.spark.sql.execution.GenerateExec + +class GpuGenerateExecSparkPlanMeta( + gen: GenerateExec, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule) extends SparkPlanMeta[GenerateExec](gen, conf, p, r) { + + override val childExprs: Seq[BaseExprMeta[_]] = { + (Seq(gen.generator) ++ gen.requiredChildOutput).map( + GpuOverrides.wrapExpr(_, conf, Some(this))) + } + + override def tagPlanForGpu(): Unit = { + if (gen.outer && + !childExprs.head.asInstanceOf[GeneratorExprMeta[Generator]].supportOuter) { + willNotWorkOnGpu(s"outer is not currently supported with ${gen.generator.nodeName}") + } + } +} + +abstract class GeneratorExprMeta[INPUT <: Generator]( + gen: INPUT, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule) extends ExprMeta[INPUT](gen, conf, p, r) { + /* whether supporting outer generate or not */ + val supportOuter: Boolean = false +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala new file mode 100644 index 00000000000..fb9e1d953a2 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +object GpuOrcScanBase { + def tagSupport( + sparkSession: SparkSession, + schema: StructType, + meta: RapidsMeta[_, _]): Unit = { + if (!meta.conf.isOrcEnabled) { + meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_ORC} to true") + } + + if (!meta.conf.isOrcReadEnabled) { + meta.willNotWorkOnGpu("ORC input has been disabled. To enable set" + + s"${RapidsConf.ENABLE_ORC_READ} to true") + } + + FileFormatChecks.tag(meta, schema, OrcFormatType, ReadFileOp) + + if (sparkSession.conf + .getOption("spark.sql.orc.mergeSchema").exists(_.toBoolean)) { + meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala new file mode 100644 index 00000000000..0545a3c87da --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -0,0 +1,2964 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import java.time.ZoneId + +import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF} +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.ScalarSubquery +import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.hive.rapids.GpuHiveOverrides +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids._ +import org.apache.spark.sql.rapids.execution._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +/** + * Base class for all ReplacementRules + * @param doWrap wraps a part of the plan in a [[RapidsMeta]] for further processing. + * @param desc a description of what this part of the plan does. + * @param tag metadata used to determine what INPUT is at runtime. + * @tparam INPUT the exact type of the class we are wrapping. + * @tparam BASE the generic base class for this type of stage, i.e. SparkPlan, Expression, etc. + * @tparam WRAP_TYPE base class that should be returned by doWrap. + */ +abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPUT, BASE]]( + protected var doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => WRAP_TYPE, + protected var desc: String, + protected val checks: Option[TypeChecks[_]], + final val tag: ClassTag[INPUT]) extends DataFromReplacementRule { + + private var _incompatDoc: Option[String] = None + private var _disabledDoc: Option[String] = None + private var _visible: Boolean = true + + def isVisible: Boolean = _visible + def description: String = desc + + override def incompatDoc: Option[String] = _incompatDoc + override def disabledMsg: Option[String] = _disabledDoc + override def getChecks: Option[TypeChecks[_]] = checks + + /** + * Mark this expression as incompatible with the original Spark version + * @param str a description of how it is incompatible. + * @return this for chaining. + */ + final def incompat(str: String) : this.type = { + _incompatDoc = Some(str) + this + } + + /** + * Mark this expression as disabled by default. + * @param str a description of why it is disabled by default. + * @return this for chaining. 
+ */ + final def disabledByDefault(str: String) : this.type = { + _disabledDoc = Some(str) + this + } + + final def invisible(): this.type = { + _visible = false + this + } + + /** + * Provide a function that will wrap a spark type in a [[RapidsMeta]] instance that is used for + * conversion to a GPU version. + * @param func the function + * @return this for chaining. + */ + final def wrap(func: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => WRAP_TYPE): this.type = { + doWrap = func + this + } + + /** + * Set a description of what the operation does. + * @param str the description. + * @return this for chaining + */ + final def desc(str: String): this.type = { + this.desc = str + this + } + + private var confKeyCache: String = null + protected val confKeyPart: String + + override def confKey: String = { + if (confKeyCache == null) { + confKeyCache = "spark.rapids.sql." + confKeyPart + "." + tag.runtimeClass.getSimpleName + } + confKeyCache + } + + def notes(): Option[String] = if (incompatDoc.isDefined) { + Some(s"This is not 100% compatible with the Spark version because ${incompatDoc.get}") + } else if (disabledMsg.isDefined) { + Some(s"This is disabled by default because ${disabledMsg.get}") + } else { + None + } + + def confHelp(asTable: Boolean = false, sparkSQLFunctions: Option[String] = None): Unit = { + if (_visible) { + val notesMsg = notes() + if (asTable) { + import ConfHelper.makeConfAnchor + print(s"${makeConfAnchor(confKey)}") + if (sparkSQLFunctions.isDefined) { + print(s"|${sparkSQLFunctions.get}") + } + print(s"|$desc|${notesMsg.isEmpty}|") + if (notesMsg.isDefined) { + print(s"${notesMsg.get}") + } else { + print("None") + } + println("|") + } else { + println(s"$confKey:") + println(s"\tEnable (true) or disable (false) the $tag $operationName.") + if (sparkSQLFunctions.isDefined) { + println(s"\tsql function: ${sparkSQLFunctions.get}") + } + println(s"\t$desc") + if (notesMsg.isDefined) { + println(s"\t${notesMsg.get}") + } + println(s"\tdefault: ${notesMsg.isEmpty}") + println() + } + } + } + + final def wrap( + op: BASE, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule): WRAP_TYPE = { + doWrap(op.asInstanceOf[INPUT], conf, parent, r) + } + + def getClassFor: Class[_] = tag.runtimeClass +} + +/** + * Holds everything that is needed to replace an Expression with a GPU enabled version. + */ +class ExprRule[INPUT <: Expression]( + doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => BaseExprMeta[INPUT], + desc: String, + checks: Option[ExprChecks], + tag: ClassTag[INPUT]) + extends ReplacementRule[INPUT, Expression, BaseExprMeta[INPUT]]( + doWrap, desc, checks, tag) { + + override val confKeyPart = "expression" + override val operationName = "Expression" +} + +/** + * Holds everything that is needed to replace a `Scan` with a GPU enabled version. + */ +/* +class ScanRule[INPUT <: Scan]( + doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _, _]], + DataFromReplacementRule) => ScanMeta[INPUT], + desc: String, + tag: ClassTag[INPUT]) + extends ReplacementRule[INPUT, Scan, ScanMeta[INPUT]]( + doWrap, desc, None, tag) { + + override val confKeyPart: String = "input" + override val operationName: String = "Input" +} +*/ +/** + * Holds everything that is needed to replace a `Partitioning` with a GPU enabled version. 
+ */ +class PartRule[INPUT <: Partitioning]( + doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => PartMeta[INPUT], + desc: String, + checks: Option[PartChecks], + tag: ClassTag[INPUT]) + extends ReplacementRule[INPUT, Partitioning, PartMeta[INPUT]]( + doWrap, desc, checks, tag) { + + override val confKeyPart: String = "partitioning" + override val operationName: String = "Partitioning" +} + +/** + * Holds everything that is needed to replace a `SparkPlan` with a GPU enabled version. + */ +class ExecRule[INPUT <: SparkPlan]( + doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => SparkPlanMeta[INPUT], + desc: String, + checks: Option[ExecChecks], + tag: ClassTag[INPUT]) + extends ReplacementRule[INPUT, SparkPlan, SparkPlanMeta[INPUT]]( + doWrap, desc, checks, tag) { + + // TODO finish this... + + override val confKeyPart: String = "exec" + override val operationName: String = "Exec" +} + +/** + * Holds everything that is needed to replace a `DataWritingCommand` with a + * GPU enabled version. + */ +class DataWritingCommandRule[INPUT <: DataWritingCommand]( + doWrap: ( + INPUT, + RapidsConf, + Option[RapidsMeta[_, _]], + DataFromReplacementRule) => DataWritingCommandMeta[INPUT], + desc: String, + tag: ClassTag[INPUT]) + extends ReplacementRule[INPUT, DataWritingCommand, DataWritingCommandMeta[INPUT]]( + doWrap, desc, None, tag) { + + override val confKeyPart: String = "output" + override val operationName: String = "Output" +} + + +final class InsertIntoHadoopFsRelationCommandMeta( + cmd: InsertIntoHadoopFsRelationCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends DataWritingCommandMeta[InsertIntoHadoopFsRelationCommand](cmd, conf, parent, rule) { + + // spark 2.3 doesn't have this so just code it here + def sparkSessionActive: SparkSession = { + SparkSession.getActiveSession.getOrElse(SparkSession.getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + + override def tagSelfForGpu(): Unit = { + if (cmd.bucketSpec.isDefined) { + willNotWorkOnGpu("bucketing is not supported") + } + + val spark = sparkSessionActive + + cmd.fileFormat match { + case _: CSVFileFormat => + willNotWorkOnGpu("CSV output is not supported") + case _: JsonFileFormat => + willNotWorkOnGpu("JSON output is not supported") + case f if GpuOrcFileFormat.isSparkOrcFormat(f) => + GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema) + case _: ParquetFileFormat => + GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema) + case _: TextFileFormat => + willNotWorkOnGpu("text output is not supported") + case f => + willNotWorkOnGpu(s"unknown file format: ${f.getClass.getCanonicalName}") + } + } +} + + +final class CreateDataSourceTableAsSelectCommandMeta( + cmd: CreateDataSourceTableAsSelectCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends DataWritingCommandMeta[CreateDataSourceTableAsSelectCommand](cmd, conf, parent, rule) { + + private var origProvider: Class[_] = _ + + // spark 2.3 doesn't have this so just code it here + def sparkSessionActive: SparkSession = { + SparkSession.getActiveSession.getOrElse(SparkSession.getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + + override def tagSelfForGpu(): Unit = { + if (cmd.table.bucketSpec.isDefined) { + willNotWorkOnGpu("bucketing 
is not supported") + } + if (cmd.table.provider.isEmpty) { + willNotWorkOnGpu("provider must be defined") + } + + val spark = sparkSessionActive + origProvider = + GpuDataSource.lookupDataSource(cmd.table.provider.get, spark.sessionState.conf) + // Note that the data source V2 always fallsback to the V1 currently. + // If that changes then this will start failing because we don't have a mapping. + origProvider.getConstructor().newInstance() match { + case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) => + GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties, cmd.query.schema) + None + case _: ParquetFileFormat => + GpuParquetFileFormat.tagGpuSupport(this, spark, + cmd.table.storage.properties, cmd.query.schema) + None + case ds => + willNotWorkOnGpu(s"Data source class not supported: ${ds}") + None + } + } + +} + +trait GpuOverridesListener { + def optimizedPlan( + plan: SparkPlanMeta[SparkPlan], + sparkPlan: SparkPlan, + costOptimizations: Seq[Optimization]) +} + +sealed trait FileFormatType +object CsvFormatType extends FileFormatType { + override def toString = "CSV" +} +object ParquetFormatType extends FileFormatType { + override def toString = "Parquet" +} +object OrcFormatType extends FileFormatType { + override def toString = "ORC" +} + +sealed trait FileFormatOp +object ReadFileOp extends FileFormatOp { + override def toString = "read" +} +object WriteFileOp extends FileFormatOp { + override def toString = "write" +} + +// copy here for 2.x +sealed abstract class Optimization + +object GpuOverrides extends Logging { + // Spark 2.x - don't pull in cudf so hardcode here + val DECIMAL32_MAX_PRECISION = 9 + val DECIMAL64_MAX_PRECISION = 18 + val DECIMAL128_MAX_PRECISION = 38 + + val FLOAT_DIFFERS_GROUP_INCOMPAT = + "when enabling these, there may be extra groups produced for floating point grouping " + + "keys (e.g. -0.0, and 0.0)" + val CASE_MODIFICATION_INCOMPAT = + "the Unicode version used by cuDF and the JVM may differ, resulting in some " + + "corner-case characters not changing case correctly." + val UTC_TIMEZONE_ID = ZoneId.of("UTC").normalized() + // Based on https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html + private[this] lazy val regexList: Seq[String] = Seq("\\", "\u0000", "\\x", "\t", "\n", "\r", + "\f", "\\a", "\\e", "\\cx", "[", "]", "^", "&", ".", "*", "\\d", "\\D", "\\h", "\\H", "\\s", + "\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R", + "?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">") + + /** + * Provides a way to log an info message about how long an operation took in milliseconds. 
+ */ + def logDuration[T](shouldLog: Boolean, msg: Double => String)(block: => T): T = { + val start = System.nanoTime() + val ret = block + val end = System.nanoTime() + if (shouldLog) { + val timeTaken = (end - start).toDouble / java.util.concurrent.TimeUnit.MILLISECONDS.toNanos(1) + logInfo(msg(timeTaken)) + } + ret + } + + private[this] val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + + val pluginSupportedOrderableSig: TypeSig = + _gpuCommonTypes + TypeSig.STRUCT.nested(_gpuCommonTypes) + + private[this] def isStructType(dataType: DataType) = dataType match { + case StructType(_) => true + case _ => false + } + + // this listener mechanism is global and is intended for use by unit tests only + private lazy val listeners: ListBuffer[GpuOverridesListener] = + new ListBuffer[GpuOverridesListener]() + + def addListener(listener: GpuOverridesListener): Unit = { + listeners += listener + } + + def removeListener(listener: GpuOverridesListener): Unit = { + listeners -= listener + } + + def removeAllListeners(): Unit = { + listeners.clear() + } + + def canRegexpBeTreatedLikeARegularString(strLit: UTF8String): Boolean = { + val s = strLit.toString + !regexList.exists(pattern => s.contains(pattern)) + } + + @scala.annotation.tailrec + def extractLit(exp: Expression): Option[Literal] = exp match { + case l: Literal => Some(l) + case a: Alias => extractLit(a.child) + case _ => None + } + + def isOfType(l: Option[Literal], t: DataType): Boolean = l.exists(_.dataType == t) + + def isStringLit(exp: Expression): Boolean = + isOfType(extractLit(exp), StringType) + + def extractStringLit(exp: Expression): Option[String] = extractLit(exp) match { + case Some(Literal(v: UTF8String, StringType)) => + val s = if (v == null) null else v.toString + Some(s) + case _ => None + } + + def isLit(exp: Expression): Boolean = extractLit(exp).isDefined + + def isNullLit(lit: Literal): Boolean = { + lit.value == null + } + + def isSupportedStringReplacePattern(exp: Expression): Boolean = { + extractLit(exp) match { + case Some(Literal(null, _)) => false + case Some(Literal(value: UTF8String, DataTypes.StringType)) => + val strLit = value.toString + if (strLit.isEmpty) { + false + } else { + // check for regex special characters, except for \u0000 which we can support + !regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern)) + } + case _ => false + } + } + + def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupportedType(_)) + + /** + * Is this particular type supported or not. + * @param dataType the type to check + * @param allowNull should NullType be allowed + * @param allowDecimal should DecimalType be allowed + * @param allowBinary should BinaryType be allowed + * @param allowCalendarInterval should CalendarIntervalType be allowed + * @param allowArray should ArrayType be allowed + * @param allowStruct should StructType be allowed + * @param allowStringMaps should a Map[String, String] specifically be allowed + * @param allowMaps should MapType be allowed generically + * @param allowNesting should nested types like array struct and map allow nested types + * within them, or only primitive types. 
+ * @return true if it is allowed else false + */ + def isSupportedType(dataType: DataType, + allowNull: Boolean = false, + allowDecimal: Boolean = false, + allowBinary: Boolean = false, + allowCalendarInterval: Boolean = false, + allowArray: Boolean = false, + allowStruct: Boolean = false, + allowStringMaps: Boolean = false, + allowMaps: Boolean = false, + allowNesting: Boolean = false): Boolean = { + def checkNested(dataType: DataType): Boolean = { + isSupportedType(dataType, + allowNull = allowNull, + allowDecimal = allowDecimal, + allowBinary = allowBinary && allowNesting, + allowCalendarInterval = allowCalendarInterval && allowNesting, + allowArray = allowArray && allowNesting, + allowStruct = allowStruct && allowNesting, + allowStringMaps = allowStringMaps && allowNesting, + allowMaps = allowMaps && allowNesting, + allowNesting = allowNesting) + } + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case DateType => true + case TimestampType => TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) + case StringType => true + case dt: DecimalType if allowDecimal => dt.precision <= GpuOverrides.DECIMAL64_MAX_PRECISION + case NullType => allowNull + case BinaryType => allowBinary + case CalendarIntervalType => allowCalendarInterval + case ArrayType(elementType, _) if allowArray => checkNested(elementType) + case MapType(StringType, StringType, _) if allowStringMaps => true + case MapType(keyType, valueType, _) if allowMaps => + checkNested(keyType) && checkNested(valueType) + case StructType(fields) if allowStruct => + fields.map(_.dataType).forall(checkNested) + case _ => false + } + } + + /** + * Checks to see if any expressions are a String Literal + */ + def isAnyStringLit(expressions: Seq[Expression]): Boolean = + expressions.exists(isStringLit) + + def isOrContainsFloatingPoint(dataType: DataType): Boolean = + TrampolineUtil.dataTypeExistsRecursively(dataType, dt => dt == FloatType || dt == DoubleType) + + def checkAndTagFloatAgg(dataType: DataType, conf: RapidsConf, meta: RapidsMeta[_,_]): Unit = { + if (!conf.isFloatAggEnabled && isOrContainsFloatingPoint(dataType)) { + meta.willNotWorkOnGpu("the GPU will aggregate floating point values in" + + " parallel and the result is not always identical each time. This can cause" + + " some Spark queries to produce an incorrect answer if the value is computed" + + " more than once as part of the same query. To enable this anyways set" + + s" ${RapidsConf.ENABLE_FLOAT_AGG} to true.") + } + } + + def checkAndTagFloatNanAgg( + op: String, + dataType: DataType, + conf: RapidsConf, + meta: RapidsMeta[_,_]): Unit = { + if (conf.hasNans && isOrContainsFloatingPoint(dataType)) { + meta.willNotWorkOnGpu(s"$op aggregation on floating point columns that can contain NaNs " + + "will compute incorrect results. If it is known that there are no NaNs, set " + + s" ${RapidsConf.HAS_NANS} to false.") + } + } + + private val nanAggPsNote = "Input must not contain NaNs and" + + s" ${RapidsConf.HAS_NANS} must be false." 
+ + def expr[INPUT <: Expression]( + desc: String, + pluginChecks: ExprChecks, + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) + => BaseExprMeta[INPUT]) + (implicit tag: ClassTag[INPUT]): ExprRule[INPUT] = { + assert(desc != null) + assert(doWrap != null) + new ExprRule[INPUT](doWrap, desc, Some(pluginChecks), tag) + } + + def part[INPUT <: Partitioning]( + desc: String, + checks: PartChecks, + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) + => PartMeta[INPUT]) + (implicit tag: ClassTag[INPUT]): PartRule[INPUT] = { + assert(desc != null) + assert(doWrap != null) + new PartRule[INPUT](doWrap, desc, Some(checks), tag) + } + + /** + * Create an exec rule that should never be replaced, because it is something that should always + * run on the CPU, or should just be ignored totally for what ever reason. + */ + def neverReplaceExec[INPUT <: SparkPlan](desc: String) + (implicit tag: ClassTag[INPUT]): ExecRule[INPUT] = { + assert(desc != null) + def doWrap( + exec: INPUT, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + cc: DataFromReplacementRule) = + new DoNotReplaceOrWarnSparkPlanMeta[INPUT](exec, conf, p) + new ExecRule[INPUT](doWrap, desc, None, tag).invisible() + } + + def exec[INPUT <: SparkPlan]( + desc: String, + pluginChecks: ExecChecks, + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) + => SparkPlanMeta[INPUT]) + (implicit tag: ClassTag[INPUT]): ExecRule[INPUT] = { + assert(desc != null) + assert(doWrap != null) + new ExecRule[INPUT](doWrap, desc, Some(pluginChecks), tag) + } + + def dataWriteCmd[INPUT <: DataWritingCommand]( + desc: String, + doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _]], DataFromReplacementRule) + => DataWritingCommandMeta[INPUT]) + (implicit tag: ClassTag[INPUT]): DataWritingCommandRule[INPUT] = { + assert(desc != null) + assert(doWrap != null) + new DataWritingCommandRule[INPUT](doWrap, desc, tag) + } + + def wrapExpr[INPUT <: Expression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]): BaseExprMeta[INPUT] = + expressions.get(expr.getClass) + .map(r => r.wrap(expr, conf, parent, r).asInstanceOf[BaseExprMeta[INPUT]]) + .getOrElse(new RuleNotFoundExprMeta(expr, conf, parent)) + + lazy val fileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map( + (CsvFormatType, FileFormatChecks( + cudfRead = TypeSig.commonCudfTypes, + cudfWrite = TypeSig.none, + sparkSig = TypeSig.cpuAtomics)), + (ParquetFormatType, FileFormatChecks( + cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.ARRAY + TypeSig.MAP).nested(), + cudfWrite = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.ARRAY + TypeSig.MAP).nested(), + sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.UDT).nested())), + (OrcFormatType, FileFormatChecks( + cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.DECIMAL_128 + + TypeSig.STRUCT + TypeSig.MAP).nested(), + cudfWrite = (TypeSig.commonCudfTypes + TypeSig.ARRAY + + // Note Map is not put into nested, now CUDF only support single level map + TypeSig.STRUCT + TypeSig.DECIMAL_128).nested() + TypeSig.MAP, + sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.UDT).nested()))) + + + val commonExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( + expr[Literal]( + "Holds a static value from the query", + ExprChecks.projectAndAst( + TypeSig.astTypes, + 
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.CALENDAR + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT), + TypeSig.all), + (lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)), + expr[Signum]( + "Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[Signum](a, conf, p, r) { + }), + expr[Alias]( + "Gives a column a name", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.astTypes, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + (a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) { + }), + expr[AttributeReference]( + "References an input column", + ExprChecks.projectAndAst( + TypeSig.astTypes, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + (att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) { + // This is the only NOOP operator. It goes away when things are bound + + // There are so many of these that we don't need to print them out, unless it + // will not work on the GPU + override def print(append: StringBuilder, depth: Int, all: Boolean): Unit = { + if (!this.canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) { + super.print(append, depth, all) + } + } + }), + expr[PromotePrecision]( + "PromotePrecision before arithmetic operations between DecimalType data", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.DECIMAL_128, + TypeSig.DECIMAL_128), + (a, conf, p, r) => new UnaryExprMeta[PromotePrecision](a, conf, p, r) { + }), + expr[CheckOverflow]( + "CheckOverflow after arithmetic operations between DecimalType data", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.DECIMAL_128, + TypeSig.DECIMAL_128), + (a, conf, p, r) => new ExprMeta[CheckOverflow](a, conf, p, r) { + private[this] def extractOrigParam(expr: BaseExprMeta[_]): BaseExprMeta[_] = + expr.wrapped match { + case lit: Literal if lit.dataType.isInstanceOf[DecimalType] => + val dt = lit.dataType.asInstanceOf[DecimalType] + // Lets figure out if we can make the Literal value smaller + val (newType, value) = lit.value match { + case null => + (DecimalType(0, 0), null) + case dec: Decimal => + val stripped = Decimal(dec.toJavaBigDecimal.stripTrailingZeros()) + val p = stripped.precision + val s = stripped.scale + // allowNegativeScaleOfDecimalEnabled is not in 2.x assume its default false + val t = if (s < 0 && !false) { + // need to adjust to avoid errors about negative scale + DecimalType(p - s, 0) + } else { + DecimalType(p, s) + } + (t, stripped) + case other => + throw new IllegalArgumentException(s"Unexpected decimal literal value $other") + } + expr.asInstanceOf[LiteralExprMeta].withNewLiteral(Literal(value, newType)) + // We avoid unapply for Cast because it changes between versions of Spark + case PromotePrecision(c: Cast) if c.dataType.isInstanceOf[DecimalType] => + val to = c.dataType.asInstanceOf[DecimalType] + val fromType = DecimalUtil.optionallyAsDecimalType(c.child.dataType) + fromType match { + case Some(from) => + val minScale = math.min(from.scale, to.scale) + val fromWhole = from.precision - from.scale + val toWhole = to.precision - to.scale + val minWhole = if (to.scale < from.scale) { + // If the scale is getting smaller in the worst case we need an + // 
extra whole part to handle rounding up. + math.min(fromWhole + 1, toWhole) + } else { + math.min(fromWhole, toWhole) + } + val newToType = DecimalType(minWhole + minScale, minScale) + if (newToType == from) { + // We can remove the cast totally + val castExpr = expr.childExprs.head + castExpr.childExprs.head + } else if (newToType == to) { + // The cast is already ideal + expr + } else { + val castExpr = expr.childExprs.head.asInstanceOf[CastExprMeta[_]] + castExpr.withToTypeOverride(newToType) + } + case _ => + expr + } + case _ => expr + } + private[this] lazy val binExpr = childExprs.head + private[this] lazy val lhs = extractOrigParam(binExpr.childExprs.head) + private[this] lazy val rhs = extractOrigParam(binExpr.childExprs(1)) + private[this] lazy val lhsDecimalType = + DecimalUtil.asDecimalType(lhs.wrapped.asInstanceOf[Expression].dataType) + private[this] lazy val rhsDecimalType = + DecimalUtil.asDecimalType(rhs.wrapped.asInstanceOf[Expression].dataType) + + override def tagExprForGpu(): Unit = { + a.child match { + // Division and Multiplication of Decimal types is a little odd. Spark will cast the + // inputs to a common wider value where the scale is the max of the two input scales, + // and the precision is max of the two input non-scale portions + the new scale. Then it + // will do the divide or multiply as a BigDecimal value but lie about the return type. + // Finally here in CheckOverflow it will reset the scale and check the precision so that + // Spark knows it fits in the final desired result. + // Here we try to strip out the extra casts, etc to get to as close to the original + // query as possible. This lets us then calculate what CUDF needs to get the correct + // answer, which in some cases is a lot smaller. + case _: Divide => + val intermediatePrecision = + GpuDecimalDivide.nonRoundedIntermediateArgPrecision(lhsDecimalType, + rhsDecimalType, a.dataType) + + if (intermediatePrecision > GpuOverrides.DECIMAL128_MAX_PRECISION) { + if (conf.needDecimalGuarantees) { + binExpr.willNotWorkOnGpu(s"the intermediate precision of " + + s"$intermediatePrecision that is required to guarantee no overflow issues " + + s"for this divide is too large to be supported on the GPU") + } else { + logWarning("Decimal overflow guarantees disabled for " + + s"${lhs.dataType} / ${rhs.dataType} produces ${a.dataType} with an " + + s"intermediate precision of $intermediatePrecision") + } + } + case _: Multiply => + val intermediatePrecision = + GpuDecimalMultiply.nonRoundedIntermediatePrecision(lhsDecimalType, + rhsDecimalType, a.dataType) + if (intermediatePrecision > GpuOverrides.DECIMAL128_MAX_PRECISION) { + if (conf.needDecimalGuarantees) { + binExpr.willNotWorkOnGpu(s"the intermediate precision of " + + s"$intermediatePrecision that is required to guarantee no overflow issues " + + s"for this multiply is too large to be supported on the GPU") + } else { + logWarning("Decimal overflow guarantees disabled for " + + s"${lhs.dataType} * ${rhs.dataType} produces ${a.dataType} with an " + + s"intermediate precision of $intermediatePrecision") + } + } + case _ => // NOOP + } + } + }), + expr[ToDegrees]( + "Converts radians to degrees", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[ToDegrees](a, conf, p, r) { + }), + expr[ToRadians]( + "Converts degrees to radians", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[ToRadians](a, conf, p, r) { + }), + expr[WindowExpression]( + "Calculates a return value for every input row of a table based on a group (or " + + 
"\"window\") of rows", + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all, + Seq(ParamCheck("windowFunction", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all), + ParamCheck("windowSpec", + TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_64, + TypeSig.numericAndInterval))), + (windowExpression, conf, p, r) => new GpuWindowExpressionMeta(windowExpression, conf, p, r)), + expr[SpecifiedWindowFrame]( + "Specification of the width of the group (or \"frame\") of input rows " + + "around which a window function is evaluated", + ExprChecks.projectOnly( + TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral, + TypeSig.numericAndInterval, + Seq( + ParamCheck("lower", + TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral, + TypeSig.numericAndInterval), + ParamCheck("upper", + TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral, + TypeSig.numericAndInterval))), + (windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r) ), + expr[WindowSpecDefinition]( + "Specification of a window function, indicating the partitioning-expression, the row " + + "ordering, and the width of the window", + WindowSpecCheck, + (windowSpec, conf, p, r) => new GpuWindowSpecDefinitionMeta(windowSpec, conf, p, r)), + expr[CurrentRow.type]( + "Special boundary for a window frame, indicating stopping at the current row", + ExprChecks.projectOnly(TypeSig.NULL, TypeSig.NULL), + (currentRow, conf, p, r) => new ExprMeta[CurrentRow.type](currentRow, conf, p, r) { + }), + expr[UnboundedPreceding.type]( + "Special boundary for a window frame, indicating all rows preceding the current row", + ExprChecks.projectOnly(TypeSig.NULL, TypeSig.NULL), + (unboundedPreceding, conf, p, r) => + new ExprMeta[UnboundedPreceding.type](unboundedPreceding, conf, p, r) { + }), + expr[UnboundedFollowing.type]( + "Special boundary for a window frame, indicating all rows preceding the current row", + ExprChecks.projectOnly(TypeSig.NULL, TypeSig.NULL), + (unboundedFollowing, conf, p, r) => + new ExprMeta[UnboundedFollowing.type](unboundedFollowing, conf, p, r) { + }), + expr[RowNumber]( + "Window function that returns the index for the row within the aggregation window", + ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT), + (rowNumber, conf, p, r) => new ExprMeta[RowNumber](rowNumber, conf, p, r) { + }), + expr[Rank]( + "Window function that returns the rank value within the aggregation window", + ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT, + repeatingParamCheck = + Some(RepeatingParamCheck("ordering", + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, + TypeSig.all))), + (rank, conf, p, r) => new ExprMeta[Rank](rank, conf, p, r) { + }), + expr[DenseRank]( + "Window function that returns the dense rank value within the aggregation window", + ExprChecks.windowOnly(TypeSig.INT, TypeSig.INT, + repeatingParamCheck = + Some(RepeatingParamCheck("ordering", + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, + TypeSig.all))), + (denseRank, conf, p, r) => new ExprMeta[DenseRank](denseRank, conf, p, r) { + }), + expr[Lead]( + "Window function that returns N entries ahead of this one", + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all, + Seq( + ParamCheck("input", + (TypeSig.commonCudfTypes + 
TypeSig.DECIMAL_128 + + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + ParamCheck("offset", TypeSig.INT, TypeSig.INT), + ParamCheck("default", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all) + ) + ), + (lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) { + }), + expr[Lag]( + "Window function that returns N entries behind this one", + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all, + Seq( + ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + ParamCheck("offset", TypeSig.INT, TypeSig.INT), + ParamCheck("default", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all) + ) + ), + (lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) { + }), + expr[PreciseTimestampConversion]( + "Expression used internally to convert the TimestampType to Long and back without losing " + + "precision, i.e. in microseconds. Used in time windowing", + ExprChecks.unaryProject( + TypeSig.TIMESTAMP + TypeSig.LONG, + TypeSig.TIMESTAMP + TypeSig.LONG, + TypeSig.TIMESTAMP + TypeSig.LONG, + TypeSig.TIMESTAMP + TypeSig.LONG), + (a, conf, p, r) => new UnaryExprMeta[PreciseTimestampConversion](a, conf, p, r) { + }), + expr[UnaryMinus]( + "Negate a numeric value", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.implicitCastsAstTypes, + TypeSig.gpuNumeric, + TypeSig.numericAndInterval), + (a, conf, p, r) => new UnaryAstExprMeta[UnaryMinus](a, conf, p, r) { + // val ansiEnabled = SQLConf.get.ansiEnabled + val ansiEnabled = false + + override def tagSelfForAst(): Unit = { + // Spark 2.x - ansi in not in 2.x + /* + if (ansiEnabled && GpuAnsi.needBasicOpOverflowCheck(a.dataType)) { + willNotWorkInAst("AST unary minus does not support ANSI mode.") + } + + */ + } + }), + expr[UnaryPositive]( + "A numeric value with a + in front of it", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.astTypes, + TypeSig.gpuNumeric, + TypeSig.numericAndInterval), + (a, conf, p, r) => new UnaryAstExprMeta[UnaryPositive](a, conf, p, r) { + }), + expr[Year]( + "Returns the year from a date or timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[Year](a, conf, p, r) { + }), + expr[Month]( + "Returns the month from a date or timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[Month](a, conf, p, r) { + }), + expr[Quarter]( + "Returns the quarter of the year for date, in the range 1 to 4", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[Quarter](a, conf, p, r) { + }), + expr[DayOfMonth]( + "Returns the day of the month from a date or timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[DayOfMonth](a, conf, p, r) { + }), + expr[DayOfYear]( + "Returns the day of the year from a date or timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[DayOfYear](a, conf, p, r) { + }), + expr[Acos]( + "Inverse cosine", + ExprChecks.mathUnaryWithAst, + (a, conf, 
p, r) => new UnaryAstExprMeta[Acos](a, conf, p, r) { + }), + expr[Asin]( + "Inverse sine", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Asin](a, conf, p, r) { + }), + expr[Sqrt]( + "Square root", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Sqrt](a, conf, p, r) { + }), + expr[Cbrt]( + "Cube root", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Cbrt](a, conf, p, r) { + }), + expr[Floor]( + "Floor of a number", + ExprChecks.unaryProjectInputMatchesOutput( + TypeSig.DOUBLE + TypeSig.LONG + TypeSig.DECIMAL_128, + TypeSig.DOUBLE + TypeSig.LONG + TypeSig.DECIMAL_128), + (a, conf, p, r) => new UnaryExprMeta[Floor](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + a.dataType match { + case dt: DecimalType => + val precision = GpuFloorCeil.unboundedOutputPrecision(dt) + if (precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { + willNotWorkOnGpu(s"output precision $precision would require overflow " + + s"checks, which are not supported yet") + } + case _ => // NOOP + } + } + + }), + expr[Ceil]( + "Ceiling of a number", + ExprChecks.unaryProjectInputMatchesOutput( + TypeSig.DOUBLE + TypeSig.LONG + TypeSig.DECIMAL_128, + TypeSig.DOUBLE + TypeSig.LONG + TypeSig.DECIMAL_128), + (a, conf, p, r) => new UnaryExprMeta[Ceil](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + a.dataType match { + case dt: DecimalType => + val precision = GpuFloorCeil.unboundedOutputPrecision(dt) + if (precision > GpuOverrides.DECIMAL128_MAX_PRECISION) { + willNotWorkOnGpu(s"output precision $precision would require overflow " + + s"checks, which are not supported yet") + } + case _ => // NOOP + } + } + + }), + expr[Not]( + "Boolean not operator", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.astTypes, TypeSig.BOOLEAN, TypeSig.BOOLEAN), + (a, conf, p, r) => new UnaryAstExprMeta[Not](a, conf, p, r) { + }), + expr[IsNull]( + "Checks if a value is null", + ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + (a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) { + }), + expr[IsNotNull]( + "Checks if a value is not null", + ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + (a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) { + }), + expr[IsNaN]( + "Checks if a value is NaN", + ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + TypeSig.DOUBLE + TypeSig.FLOAT, TypeSig.DOUBLE + TypeSig.FLOAT), + (a, conf, p, r) => new UnaryExprMeta[IsNaN](a, conf, p, r) { + }), + expr[Rint]( + "Rounds up a double value to the nearest double equal to an integer", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Rint](a, conf, p, r) { + }), + expr[BitwiseNot]( + "Returns the bitwise NOT of the operands", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.implicitCastsAstTypes, TypeSig.integral, TypeSig.integral), + (a, conf, p, r) => new UnaryAstExprMeta[BitwiseNot](a, conf, p, r) { + }), + expr[AtLeastNNonNulls]( + "Checks if number of non null/Nan values is greater than a given value", + ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + repeatingParamCheck = Some(RepeatingParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP + + 
TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[AtLeastNNonNulls](a, conf, p, r) { + }), + expr[DateAdd]( + "Returns the date that is num_days after start_date", + ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, + ("startDate", TypeSig.DATE, TypeSig.DATE), + ("days", + TypeSig.INT + TypeSig.SHORT + TypeSig.BYTE, + TypeSig.INT + TypeSig.SHORT + TypeSig.BYTE)), + (a, conf, p, r) => new BinaryExprMeta[DateAdd](a, conf, p, r) { + }), + expr[DateSub]( + "Returns the date that is num_days before start_date", + ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, + ("startDate", TypeSig.DATE, TypeSig.DATE), + ("days", + TypeSig.INT + TypeSig.SHORT + TypeSig.BYTE, + TypeSig.INT + TypeSig.SHORT + TypeSig.BYTE)), + (a, conf, p, r) => new BinaryExprMeta[DateSub](a, conf, p, r) { + }), + expr[NaNvl]( + "Evaluates to `left` iff left is not NaN, `right` otherwise", + ExprChecks.binaryProject(TypeSig.fp, TypeSig.fp, + ("lhs", TypeSig.fp, TypeSig.fp), + ("rhs", TypeSig.fp, TypeSig.fp)), + (a, conf, p, r) => new BinaryExprMeta[NaNvl](a, conf, p, r) { + }), + expr[ShiftLeft]( + "Bitwise shift left (<<)", + ExprChecks.binaryProject(TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG, + ("value", TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG), + ("amount", TypeSig.INT, TypeSig.INT)), + (a, conf, p, r) => new BinaryExprMeta[ShiftLeft](a, conf, p, r) { + }), + expr[ShiftRight]( + "Bitwise shift right (>>)", + ExprChecks.binaryProject(TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG, + ("value", TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG), + ("amount", TypeSig.INT, TypeSig.INT)), + (a, conf, p, r) => new BinaryExprMeta[ShiftRight](a, conf, p, r) { + }), + expr[ShiftRightUnsigned]( + "Bitwise unsigned shift right (>>>)", + ExprChecks.binaryProject(TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG, + ("value", TypeSig.INT + TypeSig.LONG, TypeSig.INT + TypeSig.LONG), + ("amount", TypeSig.INT, TypeSig.INT)), + (a, conf, p, r) => new BinaryExprMeta[ShiftRightUnsigned](a, conf, p, r) { + }), + expr[BitwiseAnd]( + "Returns the bitwise AND of the operands", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, TypeSig.integral, TypeSig.integral, + ("lhs", TypeSig.integral, TypeSig.integral), + ("rhs", TypeSig.integral, TypeSig.integral)), + (a, conf, p, r) => new BinaryAstExprMeta[BitwiseAnd](a, conf, p, r) { + }), + expr[BitwiseOr]( + "Returns the bitwise OR of the operands", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, TypeSig.integral, TypeSig.integral, + ("lhs", TypeSig.integral, TypeSig.integral), + ("rhs", TypeSig.integral, TypeSig.integral)), + (a, conf, p, r) => new BinaryAstExprMeta[BitwiseOr](a, conf, p, r) { + }), + expr[BitwiseXor]( + "Returns the bitwise XOR of the operands", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, TypeSig.integral, TypeSig.integral, + ("lhs", TypeSig.integral, TypeSig.integral), + ("rhs", TypeSig.integral, TypeSig.integral)), + (a, conf, p, r) => new BinaryAstExprMeta[BitwiseXor](a, conf, p, r) { + }), + expr[Coalesce] ( + "Returns the first non-null argument if exists. 
Otherwise, null", + ExprChecks.projectOnly( + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all, + repeatingParamCheck = Some(RepeatingParamCheck("param", + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[Coalesce](a, conf, p, r) { + }), + expr[Least] ( + "Returns the least value of all parameters, skipping null values", + ExprChecks.projectOnly( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable, + repeatingParamCheck = Some(RepeatingParamCheck("param", + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable))), + (a, conf, p, r) => new ExprMeta[Least](a, conf, p, r) { + }), + expr[Greatest] ( + "Returns the greatest value of all parameters, skipping null values", + ExprChecks.projectOnly( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.orderable, + repeatingParamCheck = Some(RepeatingParamCheck("param", + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable))), + (a, conf, p, r) => new ExprMeta[Greatest](a, conf, p, r) { + }), + expr[Atan]( + "Inverse tangent", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Atan](a, conf, p, r) { + }), + expr[Cos]( + "Cosine", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Cos](a, conf, p, r) { + }), + expr[Exp]( + "Euler's number e raised to a power", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Exp](a, conf, p, r) { + }), + expr[Expm1]( + "Euler's number e raised to a power minus 1", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Expm1](a, conf, p, r) { + }), + expr[InitCap]( + "Returns str with the first letter of each word in uppercase. 
" + + "All other letters are in lowercase", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new UnaryExprMeta[InitCap](a, conf, p, r) { + }).incompat(CASE_MODIFICATION_INCOMPAT), + expr[Log]( + "Natural log", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[Log](a, conf, p, r) { + }), + expr[Log1p]( + "Natural log 1 + expr", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[Log1p](a, conf, p, r) { + }), + expr[Log2]( + "Log base 2", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[Log2](a, conf, p, r) { + }), + expr[Log10]( + "Log base 10", + ExprChecks.mathUnary, + (a, conf, p, r) => new UnaryExprMeta[Log10](a, conf, p, r) { + }), + expr[Logarithm]( + "Log variable base", + ExprChecks.binaryProject(TypeSig.DOUBLE, TypeSig.DOUBLE, + ("value", TypeSig.DOUBLE, TypeSig.DOUBLE), + ("base", TypeSig.DOUBLE, TypeSig.DOUBLE)), + (a, conf, p, r) => new BinaryExprMeta[Logarithm](a, conf, p, r) { + }), + expr[Sin]( + "Sine", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Sin](a, conf, p, r) { + }), + expr[Sinh]( + "Hyperbolic sine", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Sinh](a, conf, p, r) { + }), + expr[Cosh]( + "Hyperbolic cosine", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Cosh](a, conf, p, r) { + }), + expr[Cot]( + "Cotangent", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Cot](a, conf, p, r) { + }), + expr[Tanh]( + "Hyperbolic tangent", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Tanh](a, conf, p, r) { + }), + expr[Tan]( + "Tangent", + ExprChecks.mathUnaryWithAst, + (a, conf, p, r) => new UnaryAstExprMeta[Tan](a, conf, p, r) { + }), + expr[KnownNotNull]( + "Tag an expression as known to not be null", + ExprChecks.unaryProjectInputMatchesOutput( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.BINARY + TypeSig.CALENDAR + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested(), TypeSig.all), + (k, conf, p, r) => new UnaryExprMeta[KnownNotNull](k, conf, p, r) { + }), + expr[DateDiff]( + "Returns the number of days from startDate to endDate", + ExprChecks.binaryProject(TypeSig.INT, TypeSig.INT, + ("lhs", TypeSig.DATE, TypeSig.DATE), + ("rhs", TypeSig.DATE, TypeSig.DATE)), + (a, conf, p, r) => new BinaryExprMeta[DateDiff](a, conf, p, r) { + }), + expr[TimeAdd]( + "Adds interval to timestamp", + ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ("start", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + ("interval", TypeSig.lit(TypeEnum.CALENDAR) + .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), + TypeSig.CALENDAR)), + (timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) { + override def tagExprForGpu(): Unit = { + GpuOverrides.extractLit(timeAdd.interval).foreach { lit => + val intvl = lit.value.asInstanceOf[CalendarInterval] + if (intvl.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + } + checkTimeZoneId(timeAdd.timeZoneId) + } + + }), + expr[DateFormatClass]( + "Converts timestamp to a value of string in the format specified by the date format", + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.STRING, + ("timestamp", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + ("strfmt", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), + TypeSig.STRING)), + (a, conf, p, r) => new UnixTimeExprMeta[DateFormatClass](a, conf, p, r) { + override 
def shouldFallbackOnAnsiTimestamp: Boolean = false + + } + ), + expr[ToUnixTimestamp]( + "Returns the UNIX timestamp of the given time", + ExprChecks.binaryProject(TypeSig.LONG, TypeSig.LONG, + ("timeExp", + TypeSig.STRING + TypeSig.DATE + TypeSig.TIMESTAMP, + TypeSig.STRING + TypeSig.DATE + TypeSig.TIMESTAMP), + ("format", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), + TypeSig.STRING)), + (a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) { + override def shouldFallbackOnAnsiTimestamp: Boolean = false + // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + }), + expr[UnixTimestamp]( + "Returns the UNIX timestamp of current or specified time", + ExprChecks.binaryProject(TypeSig.LONG, TypeSig.LONG, + ("timeExp", + TypeSig.STRING + TypeSig.DATE + TypeSig.TIMESTAMP, + TypeSig.STRING + TypeSig.DATE + TypeSig.TIMESTAMP), + ("format", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), + TypeSig.STRING)), + (a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) { + override def shouldFallbackOnAnsiTimestamp: Boolean = false + // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + + }), + expr[Hour]( + "Returns the hour component of the string/timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + (hour, conf, p, r) => new UnaryExprMeta[Hour](hour, conf, p, r) { + override def tagExprForGpu(): Unit = { + checkTimeZoneId(hour.timeZoneId) + } + + }), + expr[Minute]( + "Returns the minute component of the string/timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + (minute, conf, p, r) => new UnaryExprMeta[Minute](minute, conf, p, r) { + override def tagExprForGpu(): Unit = { + checkTimeZoneId(minute.timeZoneId) + } + + }), + expr[Second]( + "Returns the second component of the string/timestamp", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + (second, conf, p, r) => new UnaryExprMeta[Second](second, conf, p, r) { + override def tagExprForGpu(): Unit = { + checkTimeZoneId(second.timeZoneId) + } + }), + expr[WeekDay]( + "Returns the day of the week (0 = Monday...6=Sunday)", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[WeekDay](a, conf, p, r) { + }), + expr[DayOfWeek]( + "Returns the day of the week (1 = Sunday...7=Saturday)", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[DayOfWeek](a, conf, p, r) { + }), + expr[LastDay]( + "Returns the last day of the month which the date belongs to", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.DATE, TypeSig.DATE), + (a, conf, p, r) => new UnaryExprMeta[LastDay](a, conf, p, r) { + }), + expr[FromUnixTime]( + "Get the string from a unix timestamp", + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.STRING, + ("sec", TypeSig.LONG, TypeSig.LONG), + ("format", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"), + TypeSig.STRING)), + (a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) { + override def shouldFallbackOnAnsiTimestamp: Boolean = false + + }), + expr[Pmod]( + "Pmod", + ExprChecks.binaryProject(TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric, + ("lhs", TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric), + ("rhs", 
TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric)), + (a, conf, p, r) => new BinaryExprMeta[Pmod](a, conf, p, r) { + }), + expr[Add]( + "Addition", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, + TypeSig.gpuNumeric, TypeSig.numericAndInterval, + ("lhs", TypeSig.gpuNumeric, + TypeSig.numericAndInterval), + ("rhs", TypeSig.gpuNumeric, + TypeSig.numericAndInterval)), + (a, conf, p, r) => new BinaryAstExprMeta[Add](a, conf, p, r) { + private val ansiEnabled = false + + override def tagSelfForAst(): Unit = { + } + + }), + expr[Subtract]( + "Subtraction", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, + TypeSig.gpuNumeric, TypeSig.numericAndInterval, + ("lhs", TypeSig.gpuNumeric, + TypeSig.numericAndInterval), + ("rhs", TypeSig.gpuNumeric, + TypeSig.numericAndInterval)), + (a, conf, p, r) => new BinaryAstExprMeta[Subtract](a, conf, p, r) { + private val ansiEnabled = false + + override def tagSelfForAst(): Unit = { + } + + }), + expr[Multiply]( + "Multiplication", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, + TypeSig.gpuNumeric + TypeSig.psNote(TypeEnum.DECIMAL, + "Because of Spark's inner workings the full range of decimal precision " + + "(even for 128-bit values) is not supported."), + TypeSig.cpuNumeric, + ("lhs", TypeSig.gpuNumeric, TypeSig.cpuNumeric), + ("rhs", TypeSig.gpuNumeric, TypeSig.cpuNumeric)), + (a, conf, p, r) => new BinaryAstExprMeta[Multiply](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + } + + }), + expr[And]( + "Logical AND", + ExprChecks.binaryProjectAndAst(TypeSig.BOOLEAN, TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.BOOLEAN, TypeSig.BOOLEAN), + ("rhs", TypeSig.BOOLEAN, TypeSig.BOOLEAN)), + (a, conf, p, r) => new BinaryExprMeta[And](a, conf, p, r) { + }), + expr[Or]( + "Logical OR", + ExprChecks.binaryProjectAndAst(TypeSig.BOOLEAN, TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.BOOLEAN, TypeSig.BOOLEAN), + ("rhs", TypeSig.BOOLEAN, TypeSig.BOOLEAN)), + (a, conf, p, r) => new BinaryExprMeta[Or](a, conf, p, r) { + }), + expr[EqualNullSafe]( + "Check if the values are equal including nulls <=>", + ExprChecks.binaryProject( + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.comparable), + ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.comparable)), + (a, conf, p, r) => new BinaryExprMeta[EqualNullSafe](a, conf, p, r) { + }), + expr[EqualTo]( + "Check if the values are equal", + ExprChecks.binaryProjectAndAst( + TypeSig.comparisonAstTypes, + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.comparable), + ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.comparable)), + (a, conf, p, r) => new BinaryAstExprMeta[EqualTo](a, conf, p, r) { + }), + expr[GreaterThan]( + "> operator", + ExprChecks.binaryProjectAndAst( + TypeSig.comparisonAstTypes, + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable), + ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable)), + (a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) { + }), + expr[GreaterThanOrEqual]( + ">= operator", + ExprChecks.binaryProjectAndAst( + TypeSig.comparisonAstTypes, + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable), + ("rhs", 
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable)), + (a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) { + }), + expr[In]( + "IN operator", + ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + Seq(ParamCheck("value", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.comparable)), + Some(RepeatingParamCheck("list", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128).withAllLit(), + TypeSig.comparable))), + (in, conf, p, r) => new ExprMeta[In](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + val unaliased = in.list.map(extractLit) + val hasNullLiteral = unaliased.exists { + case Some(l) => l.value == null + case _ => false + } + if (hasNullLiteral) { + willNotWorkOnGpu("nulls are not supported") + } + } + }), + expr[InSet]( + "INSET operator", + ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, TypeSig.comparable), + (in, conf, p, r) => new ExprMeta[InSet](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (in.hset.contains(null)) { + willNotWorkOnGpu("nulls are not supported") + } + } + }), + expr[LessThan]( + "< operator", + ExprChecks.binaryProjectAndAst( + TypeSig.comparisonAstTypes, + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable), + ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable)), + (a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) { + }), + expr[LessThanOrEqual]( + "<= operator", + ExprChecks.binaryProjectAndAst( + TypeSig.comparisonAstTypes, + TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable), + ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.orderable)), + (a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) { + }), + expr[CaseWhen]( + "CASE WHEN expression", + CaseWhenCheck, + (a, conf, p, r) => new ExprMeta[CaseWhen](a, conf, p, r) { + }), + expr[If]( + "IF expression", + ExprChecks.projectOnly( + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.MAP).nested(), + TypeSig.all, + Seq(ParamCheck("predicate", TypeSig.BOOLEAN, TypeSig.BOOLEAN), + ParamCheck("trueValue", + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.MAP).nested(), + TypeSig.all), + ParamCheck("falseValue", + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.MAP).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[If](a, conf, p, r) { + }), + expr[Pow]( + "lhs ^ rhs", + ExprChecks.binaryProjectAndAst( + TypeSig.implicitCastsAstTypes, TypeSig.DOUBLE, TypeSig.DOUBLE, + ("lhs", TypeSig.DOUBLE, TypeSig.DOUBLE), + ("rhs", TypeSig.DOUBLE, TypeSig.DOUBLE)), + (a, conf, p, r) => new BinaryAstExprMeta[Pow](a, conf, p, r) { + }), + expr[Divide]( + "Division", + ExprChecks.binaryProject( + TypeSig.DOUBLE + TypeSig.DECIMAL_128 + + TypeSig.psNote(TypeEnum.DECIMAL, + "Because of Spark's inner workings the full range of decimal precision " + + "(even for 128-bit values) is not supported."), + TypeSig.DOUBLE + TypeSig.DECIMAL_128, + ("lhs", TypeSig.DOUBLE + TypeSig.DECIMAL_128, + TypeSig.DOUBLE + TypeSig.DECIMAL_128), + ("rhs", TypeSig.DOUBLE + TypeSig.DECIMAL_128, + TypeSig.DOUBLE + TypeSig.DECIMAL_128)), + (a, conf, p, r) => new BinaryExprMeta[Divide](a, conf, p, r) { + }), 
+ expr[Remainder]( + "Remainder or modulo", + ExprChecks.binaryProject( + TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric, + ("lhs", TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric), + ("rhs", TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric)), + (a, conf, p, r) => new BinaryExprMeta[Remainder](a, conf, p, r) { + }), + expr[AggregateExpression]( + "Aggregate expression", + ExprChecks.fullAgg( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all, + Seq(ParamCheck( + "aggFunc", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all)), + Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))), + (a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) { + // No filter parameter in 2.x + private val childrenExprMeta: Seq[BaseExprMeta[Expression]] = + a.children.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = + childrenExprMeta + }), + expr[SortOrder]( + "Sort order", + ExprChecks.projectOnly( + (pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + TypeSig.STRUCT).nested(), + TypeSig.orderable, + Seq(ParamCheck( + "input", + (pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + TypeSig.STRUCT).nested(), + TypeSig.orderable))), + (sortOrder, conf, p, r) => new BaseExprMeta[SortOrder](sortOrder, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (isStructType(sortOrder.dataType)) { + val nullOrdering = sortOrder.nullOrdering + val directionDefaultNullOrdering = sortOrder.direction.defaultNullOrdering + val direction = sortOrder.direction.sql + if (nullOrdering != directionDefaultNullOrdering) { + willNotWorkOnGpu(s"only default null ordering $directionDefaultNullOrdering " + + s"for direction $direction is supported for nested types; actual: ${nullOrdering}") + } + } + } + + }), + expr[PivotFirst]( + "PivotFirst operator", + ExprChecks.reductionAndGroupByAgg( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_64), + TypeSig.all, + Seq(ParamCheck( + "pivotColumn", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.all), + ParamCheck("valueColumn", + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64, + TypeSig.all))), + (pivot, conf, p, r) => new ImperativeAggExprMeta[PivotFirst](pivot, conf, p, r) { + override def tagAggForGpu(): Unit = { + checkAndTagFloatNanAgg("Pivot", pivot.pivotColumn.dataType, conf, this) + // If pivotColumnValues doesn't have distinct values, fall back to CPU + if (pivot.pivotColumnValues.distinct.lengthCompare(pivot.pivotColumnValues.length) != 0) { + willNotWorkOnGpu("PivotFirst does not work on the GPU when there are duplicate" + + " pivot values provided") + } + } + }), + expr[Count]( + "Count aggregate operator", + ExprChecks.fullAgg( + TypeSig.LONG, TypeSig.LONG, + repeatingParamCheck = Some(RepeatingParamCheck( + "input", _gpuCommonTypes + TypeSig.DECIMAL_128 + + TypeSig.STRUCT.nested(_gpuCommonTypes + TypeSig.DECIMAL_128), + TypeSig.all))), + (count, conf, p, r) => new AggExprMeta[Count](count, conf, p, r) { + override def tagAggForGpu(): Unit = { + if (count.children.size > 1) { + willNotWorkOnGpu("count of multiple columns not supported") + } + } + }), + expr[Max]( + "Max aggregate operator", + ExprChecksImpl( + 
ExprChecks.reductionAndGroupByAgg( + // Max supports single level struct, e.g.: max(struct(string, string)) + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts + ++ + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), + (max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) { + override def tagAggForGpu(): Unit = { + val dataType = max.child.dataType + checkAndTagFloatNanAgg("Max", dataType, conf, this) + } + }), + expr[Min]( + "Min aggregate operator", + ExprChecksImpl( + ExprChecks.reductionAndGroupByAgg( + // Min supports single level struct, e.g.: max(struct(string, string)) + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts + ++ + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), + (a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) { + override def tagAggForGpu(): Unit = { + val dataType = a.child.dataType + checkAndTagFloatNanAgg("Min", dataType, conf, this) + } + }), + expr[Sum]( + "Sum aggregate operator", + ExprChecks.fullAgg( + TypeSig.LONG + TypeSig.DOUBLE + TypeSig.DECIMAL_128, + TypeSig.LONG + TypeSig.DOUBLE + TypeSig.DECIMAL_128, + Seq(ParamCheck("input", TypeSig.gpuNumeric, TypeSig.cpuNumeric))), + (a, conf, p, r) => new AggExprMeta[Sum](a, conf, p, r) { + override def tagAggForGpu(): Unit = { + val inputDataType = a.child.dataType + checkAndTagFloatAgg(inputDataType, conf, this) + } + + }), + expr[First]( + "first aggregate operator", { + ExprChecks.aggNotWindow( + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + Seq(ParamCheck("input", + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all)) + ) + }, + (a, conf, p, r) => new AggExprMeta[First](a, conf, p, r) { + }), + expr[Last]( + "last aggregate operator", { + ExprChecks.aggNotWindow( + (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + Seq(ParamCheck("input", + 
(TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(), + TypeSig.all)) + ) + }, + (a, conf, p, r) => new AggExprMeta[Last](a, conf, p, r) { + }), + expr[BRound]( + "Round an expression to d decimal places using HALF_EVEN rounding mode", + ExprChecks.binaryProject( + TypeSig.gpuNumeric, TypeSig.cpuNumeric, + ("value", TypeSig.gpuNumeric + + TypeSig.psNote(TypeEnum.FLOAT, "result may round slightly differently") + + TypeSig.psNote(TypeEnum.DOUBLE, "result may round slightly differently"), + TypeSig.cpuNumeric), + ("scale", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))), + (a, conf, p, r) => new BinaryExprMeta[BRound](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + a.child.dataType match { + case FloatType | DoubleType if !conf.isIncompatEnabled => + willNotWorkOnGpu("rounding floating point numbers may be slightly off " + + s"compared to Spark's result, to enable set ${RapidsConf.INCOMPATIBLE_OPS}") + case _ => // NOOP + } + } + }), + expr[Round]( + "Round an expression to d decimal places using HALF_UP rounding mode", + ExprChecks.binaryProject( + TypeSig.gpuNumeric, TypeSig.cpuNumeric, + ("value", TypeSig.gpuNumeric + + TypeSig.psNote(TypeEnum.FLOAT, "result may round slightly differently") + + TypeSig.psNote(TypeEnum.DOUBLE, "result may round slightly differently"), + TypeSig.cpuNumeric), + ("scale", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))), + (a, conf, p, r) => new BinaryExprMeta[Round](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + a.child.dataType match { + case FloatType | DoubleType if !conf.isIncompatEnabled => + willNotWorkOnGpu("rounding floating point numbers may be slightly off " + + s"compared to Spark's result, to enable set ${RapidsConf.INCOMPATIBLE_OPS}") + case _ => // NOOP + } + } + }), + expr[PythonUDF]( + "UDF run in an external python process. Does not actually run on the GPU, but " + + "the transfer of data to/from it can be accelerated", + ExprChecks.fullAggAndProject( + // Different types of Pandas UDF support different sets of output type. Please refer to + // https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98 + // for more details. + // It is impossible to specify the exact type signature for each Pandas UDF type in a single + // expression 'PythonUDF'. + // So use the 'unionOfPandasUdfOut' to cover all types for Spark. The type signature of + // plugin is also an union of all the types of Pandas UDF. + (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested() + TypeSig.STRUCT, + TypeSig.unionOfPandasUdfOut, + repeatingParamCheck = Some(RepeatingParamCheck( + "param", + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[PythonUDF](a, conf, p, r) { + override def replaceMessage: String = "not block GPU acceleration" + override def noReplacementPossibleMessage(reasons: String): String = + s"blocks running on GPU because $reasons" + }), + expr[Rand]( + "Generate a random column with i.i.d. 
uniformly distributed values in [0, 1)", + ExprChecks.projectOnly(TypeSig.DOUBLE, TypeSig.DOUBLE, + Seq(ParamCheck("seed", + (TypeSig.INT + TypeSig.LONG).withAllLit(), + (TypeSig.INT + TypeSig.LONG).withAllLit()))), + (a, conf, p, r) => new UnaryExprMeta[Rand](a, conf, p, r) { + }), + expr[SparkPartitionID] ( + "Returns the current partition id", + ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT), + (a, conf, p, r) => new ExprMeta[SparkPartitionID](a, conf, p, r) { + }), + expr[MonotonicallyIncreasingID] ( + "Returns monotonically increasing 64-bit integers", + ExprChecks.projectOnly(TypeSig.LONG, TypeSig.LONG), + (a, conf, p, r) => new ExprMeta[MonotonicallyIncreasingID](a, conf, p, r) { + }), + expr[InputFileName] ( + "Returns the name of the file being read, or empty string if not available", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new ExprMeta[InputFileName](a, conf, p, r) { + }), + expr[InputFileBlockStart] ( + "Returns the start offset of the block being read, or -1 if not available", + ExprChecks.projectOnly(TypeSig.LONG, TypeSig.LONG), + (a, conf, p, r) => new ExprMeta[InputFileBlockStart](a, conf, p, r) { + }), + expr[InputFileBlockLength] ( + "Returns the length of the block being read, or -1 if not available", + ExprChecks.projectOnly(TypeSig.LONG, TypeSig.LONG), + (a, conf, p, r) => new ExprMeta[InputFileBlockLength](a, conf, p, r) { + }), + expr[Md5] ( + "MD5 hash operator", + ExprChecks.unaryProject(TypeSig.STRING, TypeSig.STRING, + TypeSig.BINARY, TypeSig.BINARY), + (a, conf, p, r) => new UnaryExprMeta[Md5](a, conf, p, r) { + }), + expr[Upper]( + "String uppercase operator", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new UnaryExprMeta[Upper](a, conf, p, r) { + }) + .incompat(CASE_MODIFICATION_INCOMPAT), + expr[Lower]( + "String lowercase operator", + ExprChecks.unaryProjectInputMatchesOutput(TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new UnaryExprMeta[Lower](a, conf, p, r) { + }) + .incompat(CASE_MODIFICATION_INCOMPAT), + expr[StringLPad]( + "Pad a string on the left", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), + ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => new TernaryExprMeta[StringLPad](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + extractLit(in.pad).foreach { padLit => + if (padLit.value != null && + padLit.value.asInstanceOf[UTF8String].toString.length != 1) { + willNotWorkOnGpu("only a single character is supported for pad") + } + } + } + }), + expr[StringRPad]( + "Pad a string on the right", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT), + ParamCheck("pad", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => new TernaryExprMeta[StringRPad](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + extractLit(in.pad).foreach { padLit => + if (padLit.value != null && + padLit.value.asInstanceOf[UTF8String].toString.length != 1) { + willNotWorkOnGpu("only a single character is supported for pad") + } + } + } + }), + expr[StringSplit]( + "Splits `str` around occurrences that match `regex`", + ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.STRING), + TypeSig.ARRAY.nested(TypeSig.STRING), + Seq(ParamCheck("str", TypeSig.STRING, 
TypeSig.STRING), + ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "very limited subset of regex supported"), + TypeSig.STRING), + ParamCheck("limit", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + (in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)), + expr[GetStructField]( + "Gets the named field of the struct", + ExprChecks.unaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + + TypeSig.DECIMAL_128).nested(), + TypeSig.all, + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128), + TypeSig.STRUCT.nested(TypeSig.all)), + (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) { + }), + expr[GetArrayItem]( + "Gets the field at `ordinal` in the Array", + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), + TypeSig.all, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)), + (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)), + expr[GetMapValue]( + "Gets Value from a Map based on a key", + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)), + expr[ElementAt]( + "Returns element of array at given(1-based) index in value if column is array. " + + "Returns value for the given key in value if column is map", + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all, + ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) + + TypeSig.MAP.nested(TypeSig.STRING) + .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."), + TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), + ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) + .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " + + "not as maps keys") + .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " + + "not array indexes"), + TypeSig.all)), + (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + // To distinguish the supported nested type between Array and Map + val checks = in.left.dataType match { + case _: MapType => + // Match exactly with the checks for GetMapValue + ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)) + case _: ArrayType => + // Match exactly with the checks for GetArrayItem + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), + TypeSig.all, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)) + case _ 
=> throw new IllegalStateException("Only Array or Map is supported as input.") + } + checks.tag(this) + } + }), + expr[MapKeys]( + "Returns an unordered array containing the keys of the map", + ExprChecks.unaryProject( + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.ARRAY.nested(TypeSig.all - TypeSig.MAP), // Maps cannot have other maps as keys + TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.MAP.nested(TypeSig.all)), + (in, conf, p, r) => new UnaryExprMeta[MapKeys](in, conf, p, r) { + }), + expr[MapValues]( + "Returns an unordered array containing the values of the map", + ExprChecks.unaryProject( + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all), + TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.MAP.nested(TypeSig.all)), + (in, conf, p, r) => new UnaryExprMeta[MapValues](in, conf, p, r) { + }), + expr[ArrayMin]( + "Returns the minimum value in the array", + ExprChecks.unaryProject( + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, + TypeSig.orderable, + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), + TypeSig.ARRAY.nested(TypeSig.orderable)), + (in, conf, p, r) => new UnaryExprMeta[ArrayMin](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + GpuOverrides.checkAndTagFloatNanAgg("Min", in.dataType, conf, this) + } + }), + expr[ArrayMax]( + "Returns the maximum value in the array", + ExprChecks.unaryProject( + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, + TypeSig.orderable, + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, GpuOverrides.nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, GpuOverrides.nanAggPsNote), + TypeSig.ARRAY.nested(TypeSig.orderable)), + (in, conf, p, r) => new UnaryExprMeta[ArrayMax](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + GpuOverrides.checkAndTagFloatNanAgg("Max", in.dataType, conf, this) + } + + }), + expr[CreateNamedStruct]( + "Creates a struct with the given field names and values", + CreateNamedStructCheck, + (in, conf, p, r) => new ExprMeta[CreateNamedStruct](in, conf, p, r) { + }), + expr[ArrayContains]( + "Returns a boolean if the array contains the passed in key", + ExprChecks.binaryProject( + TypeSig.BOOLEAN, + TypeSig.BOOLEAN, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL), + TypeSig.ARRAY.nested(TypeSig.all)), + ("key", TypeSig.commonCudfTypes + .withPsNote(TypeEnum.DOUBLE, "NaN literals are not supported. Columnar input" + + s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false.") + .withPsNote(TypeEnum.FLOAT, "NaN literals are not supported. 
Columnar input" + + s" must not contain NaNs and ${RapidsConf.HAS_NANS} must be false."), + TypeSig.all)), + (in, conf, p, r) => new BinaryExprMeta[ArrayContains](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + // do not support literal arrays as LHS + if (extractLit(in.left).isDefined) { + willNotWorkOnGpu("Literal arrays are not supported for array_contains") + } + + val rhsVal = extractLit(in.right) + val mightHaveNans = (in.right.dataType, rhsVal) match { + case (FloatType, Some(f: Literal)) => f.value.asInstanceOf[Float].isNaN + case (DoubleType, Some(d: Literal)) => d.value.asInstanceOf[Double].isNaN + case (FloatType | DoubleType, None) => conf.hasNans // RHS is a column + case _ => false + } + if (mightHaveNans) { + willNotWorkOnGpu("Comparisons with NaN values are not supported and" + + "will compute incorrect results. If it is known that there are no NaNs, set " + + s" ${RapidsConf.HAS_NANS} to false.") + } + } + }), + expr[SortArray]( + "Returns a sorted array with the input array and the ascending / descending order", + ExprChecks.binaryProject( + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), + TypeSig.ARRAY.nested(TypeSig.all), + ("array", TypeSig.ARRAY.nested( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ascendingOrder", TypeSig.lit(TypeEnum.BOOLEAN), TypeSig.lit(TypeEnum.BOOLEAN))), + (sortExpression, conf, p, r) => new BinaryExprMeta[SortArray](sortExpression, conf, p, r) { + }), + expr[CreateArray]( + "Returns an array with the given elements", + ExprChecks.projectOnly( + TypeSig.ARRAY.nested(TypeSig.gpuNumeric + + TypeSig.NULL + TypeSig.STRING + TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + + TypeSig.ARRAY + TypeSig.STRUCT), + TypeSig.ARRAY.nested(TypeSig.all), + repeatingParamCheck = Some(RepeatingParamCheck("arg", + TypeSig.gpuNumeric + TypeSig.NULL + TypeSig.STRING + + TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.STRUCT + + TypeSig.ARRAY.nested(TypeSig.gpuNumeric + TypeSig.NULL + TypeSig.STRING + + TypeSig.BOOLEAN + TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.STRUCT + + TypeSig.ARRAY), + TypeSig.all))), + (in, conf, p, r) => new ExprMeta[CreateArray](in, conf, p, r) { + + override def tagExprForGpu(): Unit = { + wrapped.dataType match { + case ArrayType(ArrayType(ArrayType(_, _), _), _) => + willNotWorkOnGpu("Only support to create array or array of array, Found: " + + s"${wrapped.dataType}") + case _ => + } + } + + }), + expr[LambdaFunction]( + "Holds a higher order SQL function", + ExprChecks.projectOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all, + Seq(ParamCheck("function", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all)), + Some(RepeatingParamCheck("arguments", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all))), + (in, conf, p, r) => new ExprMeta[LambdaFunction](in, conf, p, r) { + }), + expr[NamedLambdaVariable]( + "A parameter to a higher order SQL function", + ExprChecks.projectOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all), + (in, conf, p, r) => new ExprMeta[NamedLambdaVariable](in, conf, p, r) { + }), + expr[ArrayTransform]( + "Transform elements in an array 
using the transform function. This is similar to a `map` " + + "in functional programming", + ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all), + Seq( + ParamCheck("argument", + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ParamCheck("function", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all))), + (in, conf, p, r) => new ExprMeta[ArrayTransform](in, conf, p, r) { + }), + expr[StringLocate]( + "Substring search operator", + ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT, + Seq(ParamCheck("substr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("start", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + (in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) { + }), + expr[Substring]( + "Substring operator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), + ParamCheck("pos", TypeSig.lit(TypeEnum.INT), TypeSig.INT), + ParamCheck("len", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + (in, conf, p, r) => new TernaryExprMeta[Substring](in, conf, p, r) { + }), + expr[SubstringIndex]( + "substring_index operator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("delim", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING), + ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))), + (in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)), + expr[StringRepeat]( + "StringRepeat operator that repeats the given strings with numbers of times " + + "given by repeatTimes", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("input", TypeSig.STRING, TypeSig.STRING), + ParamCheck("repeatTimes", TypeSig.INT, TypeSig.INT))), + (in, conf, p, r) => new BinaryExprMeta[StringRepeat](in, conf, p, r) { + }), + expr[StringReplace]( + "StringReplace operator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING), + ParamCheck("search", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("replace", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => new TernaryExprMeta[StringReplace](in, conf, p, r) { + }), + expr[StringTrim]( + "StringTrim operator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + // Should really be an OptionalParam + Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => new String2TrimExpressionMeta[StringTrim](in, conf, p, r) { + }), + expr[StringTrimLeft]( + "StringTrimLeft operator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + // Should really be an OptionalParam + Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => + new String2TrimExpressionMeta[StringTrimLeft](in, conf, p, r) { + }), + expr[StringTrimRight]( + "StringTrimRight operator", + 
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("src", TypeSig.STRING, TypeSig.STRING)), + // Should really be an OptionalParam + Some(RepeatingParamCheck("trimStr", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (in, conf, p, r) => + new String2TrimExpressionMeta[StringTrimRight](in, conf, p, r) { + }), + expr[StartsWith]( + "Starts with", + ExprChecks.binaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("src", TypeSig.STRING, TypeSig.STRING), + ("search", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[StartsWith](a, conf, p, r) { + }), + expr[EndsWith]( + "Ends with", + ExprChecks.binaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("src", TypeSig.STRING, TypeSig.STRING), + ("search", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[EndsWith](a, conf, p, r) { + }), + expr[Concat]( + "List/String concatenate", + ExprChecks.projectOnly((TypeSig.STRING + TypeSig.ARRAY).nested( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), + (TypeSig.STRING + TypeSig.BINARY + TypeSig.ARRAY).nested(TypeSig.all), + repeatingParamCheck = Some(RepeatingParamCheck("input", + (TypeSig.STRING + TypeSig.ARRAY).nested( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128), + (TypeSig.STRING + TypeSig.BINARY + TypeSig.ARRAY).nested(TypeSig.all)))), + (a, conf, p, r) => new ComplexTypeMergingExprMeta[Concat](a, conf, p, r) { + }), + expr[ConcatWs]( + "Concatenates multiple input strings or array of strings into a single " + + "string using a given separator", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + repeatingParamCheck = Some(RepeatingParamCheck("input", + (TypeSig.STRING + TypeSig.ARRAY).nested(TypeSig.STRING), + (TypeSig.STRING + TypeSig.ARRAY).nested(TypeSig.STRING)))), + (a, conf, p, r) => new ExprMeta[ConcatWs](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + if (a.children.size <= 1) { + // If only a separator is specified and it's a column, Spark returns an empty + // string for all entries unless they are null, then it returns null. + // This seems like an edge case, so instead of handling it on GPU just fall back.
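+          // Illustrative example (not taken from the code): a call such as concat_ws(sep_col)
+          // with no other inputs hits this branch, so the expression stays on the CPU.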
+ willNotWorkOnGpu("Only specifying separator column not supported on GPU") + } + } + }), + expr[Murmur3Hash] ( + "Murmur3 hash operator", + ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT, + repeatingParamCheck = Some(RepeatingParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[Murmur3Hash](a, conf, p, r) { + override val childExprs: Seq[BaseExprMeta[_]] = a.children + .map(GpuOverrides.wrapExpr(_, conf, Some(this))) + }), + expr[Contains]( + "Contains", + ExprChecks.binaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("src", TypeSig.STRING, TypeSig.STRING), + ("search", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[Contains](a, conf, p, r) { + }), + expr[Like]( + "Like", + ExprChecks.binaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("src", TypeSig.STRING, TypeSig.STRING), + ("search", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[Like](a, conf, p, r) { + }), + expr[RLike]( + "RLike", + ExprChecks.binaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + ("str", TypeSig.STRING, TypeSig.STRING), + ("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)), + (a, conf, p, r) => new GpuRLikeMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), + expr[RegExpExtract]( + "RegExpExtract", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("idx", TypeSig.lit(TypeEnum.INT), + TypeSig.lit(TypeEnum.INT)))), + (a, conf, p, r) => new GpuRegExpExtractMeta(a, conf, p, r)) + .disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), + expr[Length]( + "String character length or binary byte length", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY), + (a, conf, p, r) => new UnaryExprMeta[Length](a, conf, p, r) { + }), + expr[Size]( + "The size of an array or a map", + ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all)), + (a, conf, p, r) => new UnaryExprMeta[Size](a, conf, p, r) { + }), + expr[UnscaledValue]( + "Convert a Decimal to an unscaled long value for some aggregation optimizations", + ExprChecks.unaryProject(TypeSig.LONG, TypeSig.LONG, + TypeSig.DECIMAL_64, TypeSig.DECIMAL_128), + (a, conf, p, r) => new UnaryExprMeta[UnscaledValue](a, conf, p, r) { + override val isFoldableNonLitAllowed: Boolean = true + }), + expr[MakeDecimal]( + "Create a Decimal from an unscaled long value for some aggregation optimizations", + ExprChecks.unaryProject(TypeSig.DECIMAL_64, TypeSig.DECIMAL_128, + TypeSig.LONG, TypeSig.LONG), + (a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) { + }), + expr[Explode]( + "Given an input array produces a sequence of rows for each value in the array", + ExprChecks.unaryProject( + // Here is a walk-around representation, since multi-level nested type is not supported yet. 
+ // related issue: https://github.com/NVIDIA/spark-rapids/issues/1901 + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all), + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all)), + (a, conf, p, r) => new GeneratorExprMeta[Explode](a, conf, p, r) { + override val supportOuter: Boolean = true + }), + expr[PosExplode]( + "Given an input array produces a sequence of rows for each value in the array", + ExprChecks.unaryProject( + // Here is a walk-around representation, since multi-level nested type is not supported yet. + // related issue: https://github.com/NVIDIA/spark-rapids/issues/1901 + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all), + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP), + (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all)), + (a, conf, p, r) => new GeneratorExprMeta[PosExplode](a, conf, p, r) { + override val supportOuter: Boolean = true + }), + expr[ReplicateRows]( + "Given an input row replicates the row N times", + ExprChecks.projectOnly( + // The plan is optimized to run HashAggregate on the rows to be replicated. + // HashAggregateExec doesn't support grouping by 128-bit decimal value yet. + // Issue to track decimal 128 support: https://github.com/NVIDIA/spark-rapids/issues/4410 + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT), + TypeSig.ARRAY.nested(TypeSig.all), + repeatingParamCheck = Some(RepeatingParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all))), + (a, conf, p, r) => new ReplicateRowsExprMeta[ReplicateRows](a, conf, p, r) { + }), + expr[StddevPop]( + "Aggregation computing population standard deviation", + ExprChecks.groupByOnly( + TypeSig.DOUBLE, TypeSig.DOUBLE, + Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + (a, conf, p, r) => new AggExprMeta[StddevPop](a, conf, p, r) { + }), + expr[StddevSamp]( + "Aggregation computing sample standard deviation", + ExprChecks.aggNotReduction( + TypeSig.DOUBLE, TypeSig.DOUBLE, + Seq(ParamCheck("input", TypeSig.DOUBLE, + TypeSig.DOUBLE))), + (a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) { + }), + expr[VariancePop]( + "Aggregation computing population variance", + ExprChecks.groupByOnly( + TypeSig.DOUBLE, TypeSig.DOUBLE, + Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + (a, conf, p, r) => new AggExprMeta[VariancePop](a, conf, p, r) { + }), + expr[VarianceSamp]( + "Aggregation computing sample variance", + ExprChecks.groupByOnly( + TypeSig.DOUBLE, TypeSig.DOUBLE, + Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), + (a, conf, p, r) => new AggExprMeta[VarianceSamp](a, conf, p, r) { + }), + expr[ApproximatePercentile]( + "Approximate percentile", + ExprChecks.groupByOnly( + // note that output can be single number or array depending on whether percentiles param + // is a single number or an array + TypeSig.gpuNumeric + + TypeSig.ARRAY.nested(TypeSig.gpuNumeric), + TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP + 
TypeSig.ARRAY.nested( + TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP), + Seq( + ParamCheck("input", + TypeSig.gpuNumeric, + TypeSig.cpuNumeric + TypeSig.DATE + TypeSig.TIMESTAMP), + ParamCheck("percentage", + TypeSig.DOUBLE + TypeSig.ARRAY.nested(TypeSig.DOUBLE), + TypeSig.DOUBLE + TypeSig.ARRAY.nested(TypeSig.DOUBLE)), + ParamCheck("accuracy", TypeSig.INT, TypeSig.INT))), + (c, conf, p, r) => new TypedImperativeAggExprMeta[ApproximatePercentile](c, conf, p, r) { + + override def tagAggForGpu(): Unit = { + // check if the percentile expression can be supported on GPU + childExprs(1).wrapped match { + case lit: Literal => lit.value match { + case null => + willNotWorkOnGpu( + "approx_percentile on GPU only supports non-null literal percentiles") + case a: ArrayData if a.numElements == 0 => + willNotWorkOnGpu( + "approx_percentile on GPU does not support empty percentiles arrays") + case a: ArrayData if (0 until a.numElements).exists(a.isNullAt) => + willNotWorkOnGpu( + "approx_percentile on GPU does not support percentiles arrays containing nulls") + case _ => + // this is fine + } + case _ => + willNotWorkOnGpu("approx_percentile on GPU only supports literal percentiles") + } + } + override def aggBufferAttribute: AttributeReference = { + // Spark's ApproxPercentile has an aggregation buffer named "buf" with type "BinaryType" + // so we need to replace that here with the GPU aggregation buffer reference, which is + // a t-digest type + val aggBuffer = c.aggBufferAttributes.head + aggBuffer.copy(dataType = CudfTDigest.dataType)(aggBuffer.exprId, aggBuffer.qualifier) + } + }).incompat("the GPU implementation of approx_percentile is not bit-for-bit " + + s"compatible with Apache Spark. To enable it, set ${RapidsConf.INCOMPATIBLE_OPS}"), + expr[GetJsonObject]( + "Extracts a json object from path", + ExprChecks.projectOnly( + TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("json", TypeSig.STRING, TypeSig.STRING), + ParamCheck("path", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (a, conf, p, r) => new BinaryExprMeta[GetJsonObject](a, conf, p, r) { + } + ), + expr[ScalarSubquery]( + "Subquery that will return only one row and one column", + ExprChecks.projectOnly( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128, + Nil, None), + (a, conf, p, r) => new ExprMeta[ScalarSubquery](a, conf, p, r) { + }), + expr[CreateMap]( + desc = "Create a map", + CreateMapCheck, + (a, conf, p, r) => new ExprMeta[CreateMap](a, conf, p, r) { + }), + expr[Sequence]( + desc = "Sequence", + ExprChecks.projectOnly( + TypeSig.ARRAY.nested(TypeSig.integral), TypeSig.ARRAY.nested(TypeSig.integral + + TypeSig.TIMESTAMP + TypeSig.DATE), + Seq(ParamCheck("start", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + + TypeSig.DATE), + ParamCheck("stop", TypeSig.integral, TypeSig.integral + TypeSig.TIMESTAMP + + TypeSig.DATE)), + Some(RepeatingParamCheck("step", TypeSig.integral, TypeSig.integral + TypeSig.CALENDAR))), + (a, conf, p, r) => new GpuSequenceMeta(a, conf, p, r) + ) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + + // Shim expressions should be last to allow overrides with shim-specific versions + val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = + commonExpressions ++ GpuHiveOverrides.exprs ++ ShimGpuOverrides.shimExpressions + + def wrapPart[INPUT <: Partitioning]( + part: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]): PartMeta[INPUT] = + 
parts.get(part.getClass) + .map(r => r.wrap(part, conf, parent, r).asInstanceOf[PartMeta[INPUT]]) + .getOrElse(new RuleNotFoundPartMeta(part, conf, parent)) + + val parts : Map[Class[_ <: Partitioning], PartRule[_ <: Partitioning]] = Seq( + part[HashPartitioning]( + "Hash based partitioning", + // This needs to match what murmur3 supports. + PartChecks(RepeatingParamCheck("hash_key", + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested(), + TypeSig.all)), + (hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) { + override val childExprs: Seq[BaseExprMeta[_]] = + hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + }), + part[RangePartitioning]( + "Range partitioning", + PartChecks(RepeatingParamCheck("order_key", + (pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + TypeSig.STRUCT).nested(), + TypeSig.orderable)), + (rp, conf, p, r) => new PartMeta[RangePartitioning](rp, conf, p, r) { + override val childExprs: Seq[BaseExprMeta[_]] = + rp.ordering.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + }), + part[RoundRobinPartitioning]( + "Round robin partitioning", + PartChecks(), + (rrp, conf, p, r) => new PartMeta[RoundRobinPartitioning](rrp, conf, p, r) { + }), + part[SinglePartition.type]( + "Single partitioning", + PartChecks(), + (sp, conf, p, r) => new PartMeta[SinglePartition.type](sp, conf, p, r) { + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Partitioning]), r)).toMap + + def wrapDataWriteCmds[INPUT <: DataWritingCommand]( + writeCmd: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]): DataWritingCommandMeta[INPUT] = + dataWriteCmds.get(writeCmd.getClass) + .map(r => r.wrap(writeCmd, conf, parent, r).asInstanceOf[DataWritingCommandMeta[INPUT]]) + .getOrElse(new RuleNotFoundDataWritingCommandMeta(writeCmd, conf, parent)) + + val dataWriteCmds: Map[Class[_ <: DataWritingCommand], + DataWritingCommandRule[_ <: DataWritingCommand]] = Seq( + dataWriteCmd[InsertIntoHadoopFsRelationCommand]( + "Write to Hadoop filesystem", + (a, conf, p, r) => new InsertIntoHadoopFsRelationCommandMeta(a, conf, p, r)), + dataWriteCmd[CreateDataSourceTableAsSelectCommand]( + "Create table with select command", + (a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r)) + ).map(r => (r.getClassFor.asSubclass(classOf[DataWritingCommand]), r)).toMap + + def wrapPlan[INPUT <: SparkPlan]( + plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]): SparkPlanMeta[INPUT] = + execs.get(plan.getClass) + .map(r => r.wrap(plan, conf, parent, r).asInstanceOf[SparkPlanMeta[INPUT]]) + .getOrElse(new RuleNotFoundSparkPlanMeta(plan, conf, parent)) + + val commonExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq( + exec[GenerateExec] ( + "The backend for operations that generate more output rows than input rows like explode", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.MAP).nested(), + TypeSig.all), + (gen, conf, p, r) => new GpuGenerateExecSparkPlanMeta(gen, conf, p, r)), + exec[ProjectExec]( + "The backend for most select, withColumn and dropColumn statements", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), + TypeSig.all), + (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)), + exec[RangeExec]( + "The backend for range operator", + ExecChecks(TypeSig.LONG, TypeSig.LONG), + (range, conf, p, r) => new 
SparkPlanMeta[RangeExec](range, conf, p, r) { + }), + exec[CoalesceExec]( + "The backend for the dataframe coalesce method", + ExecChecks((_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY + + TypeSig.MAP).nested(), + TypeSig.all), + (coalesce, conf, parent, r) => new SparkPlanMeta[CoalesceExec](coalesce, conf, parent, r) { + }), + exec[DataWritingCommandExec]( + "Writing data", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128.withPsNote( + TypeEnum.DECIMAL, "128bit decimal only supported for Orc and Parquet") + + TypeSig.STRUCT.withPsNote(TypeEnum.STRUCT, "Only supported for Parquet") + + TypeSig.MAP.withPsNote(TypeEnum.MAP, "Only supported for Parquet") + + TypeSig.ARRAY.withPsNote(TypeEnum.ARRAY, "Only supported for Parquet")).nested(), + TypeSig.all), + (p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) { + override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] = + Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this))) + + }), + exec[TakeOrderedAndProjectExec]( + "Take the first limit elements as defined by the sortOrder, and do projection if needed", + // The SortOrder TypeSig will govern what types can actually be used as sorting key data type. + // The types below are allowed as inputs and outputs. + ExecChecks((pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all), + (takeExec, conf, p, r) => + new SparkPlanMeta[TakeOrderedAndProjectExec](takeExec, conf, p, r) { + val sortOrder: Seq[BaseExprMeta[SortOrder]] = + takeExec.sortOrder.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val projectList: Seq[BaseExprMeta[NamedExpression]] = + takeExec.projectList.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = sortOrder ++ projectList + + }), + exec[LocalLimitExec]( + "Per-partition limiting of results", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all), + (localLimitExec, conf, p, r) => + new SparkPlanMeta[LocalLimitExec](localLimitExec, conf, p, r) { + }), + exec[GlobalLimitExec]( + "Limiting of results across partitions", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all), + (globalLimitExec, conf, p, r) => + new SparkPlanMeta[GlobalLimitExec](globalLimitExec, conf, p, r) { + }), + exec[CollectLimitExec]( + "Reduce to single partition and apply limit", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all), + (collectLimitExec, conf, p, r) => + new SparkPlanMeta[CollectLimitExec](collectLimitExec, conf, p, r) { + override val childParts: scala.Seq[PartMeta[_]] = + Seq(GpuOverrides.wrapPart(collectLimitExec.outputPartitioning, conf, Some(this)))}) + .disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " + + "of rows in a batch it could help by limiting the number of rows transferred from " + + "GPU to CPU"), + exec[FilterExec]( + "The backend for most filter statements", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all), + (filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) { + }), + exec[ShuffleExchangeExec]( + "The backend for most data being 
exchanged between processes", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested() + .withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " + + s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true") + .withPsNote(TypeEnum.ARRAY, "Round-robin partitioning is not supported if " + + s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true") + .withPsNote(TypeEnum.MAP, "Round-robin partitioning is not supported if " + + s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"), + TypeSig.all), + (shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)), + exec[UnionExec]( + "The backend for the union operator", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT).nested() + .withPsNote(TypeEnum.STRUCT, + "unionByName will not optionally impute nulls for missing struct fields " + + "when the column is a struct and there are non-overlapping fields"), TypeSig.all), + (union, conf, p, r) => new SparkPlanMeta[UnionExec](union, conf, p, r) { + }), + exec[BroadcastExchangeExec]( + "The backend for broadcast exchange of data", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(TypeSig.commonCudfTypes + + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.STRUCT), + TypeSig.all), + (exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)), + exec[BroadcastHashJoinExec]( + "Implementation of join using broadcast data", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)), + exec[BroadcastNestedLoopJoinExec]( + "Implementation of join using brute force. 
Full outer joins and joins where the " + "broadcast side matches the join side (e.g.: LeftOuter with left broadcast) are not " + "supported", + JoinTypeChecks.nonEquiJoinChecks, + (join, conf, p, r) => new GpuBroadcastNestedLoopJoinMeta(join, conf, p, r)), + exec[CartesianProductExec]( + "Implementation of join using brute force", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT), + TypeSig.all), + (join, conf, p, r) => new SparkPlanMeta[CartesianProductExec](join, conf, p, r) { + val condition: Option[BaseExprMeta[_]] = + join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = condition.toSeq + + }), + exec[HashAggregateExec]( + "The backend for hash based aggregations", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT) + .nested() + .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") + .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(TypeEnum.STRUCT, + "not allowed for grouping expressions if containing Array or Map as child"), + TypeSig.all), + (agg, conf, p, r) => new GpuHashAggregateMeta(agg, conf, p, r)), + exec[ShuffledHashJoinExec]( + "Implementation of join using hashed shuffled data", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)), + exec[SortAggregateExec]( + "The backend for sort based aggregations", + // SPARK 2.x we can't check for the TypedImperativeAggregate properly so + // map/array/struct are left off + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.MAP + TypeSig.BINARY) + .nested() + .withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " + + "converted between CPU and GPU") + .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions"), + TypeSig.all), + (agg, conf, p, r) => new GpuSortAggregateExecMeta(agg, conf, p, r)), + // SPARK 2.x we can't check for the TypedImperativeAggregate properly so don't say we do the + // ObjectHashAggregate + exec[SortExec]( + "The backend for the sort operator", + // The SortOrder TypeSig will govern what types can actually be used as sorting key data type. + // The types below are allowed as inputs and outputs.
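+      // In other words, columns of the types below may pass through SortExec as input/output
+      // even when they cannot themselves be used as sort keys; the SortOrder rule decides that.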
+ ExecChecks((pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + TypeSig.ARRAY + + TypeSig.STRUCT +TypeSig.MAP + TypeSig.BINARY).nested(), TypeSig.all), + (sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)), + exec[SortMergeJoinExec]( + "Sort merge join, replacing with shuffled hash join", + JoinTypeChecks.equiJoinExecChecks, + (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)), + exec[ExpandExec]( + "The backend for the expand operator", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all), + (expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)), + exec[WindowExec]( + "Window-operator backend", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), + TypeSig.all, + Map("partitionSpec" -> + InputCheck(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64, TypeSig.all))), + (windowOp, conf, p, r) => + new GpuWindowExecMeta(windowOp, conf, p, r) + ), + exec[SampleExec]( + "The backend for the sample operator", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all), + (sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) {} + ), + // ShimLoader.getSparkShims.aqeShuffleReaderExec, + // ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand, + neverReplaceExec[ExecutedCommandExec]("Table metadata operation") + ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap + + lazy val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = + commonExecs ++ ShimGpuOverrides.shimExecs + + def getTimeParserPolicy: TimeParserPolicy = { + // val key = SQLConf.LEGACY_TIME_PARSER_POLICY.key + val key = "2xgone" + val policy = SQLConf.get.getConfString(key, "EXCEPTION") + policy match { + case "LEGACY" => LegacyTimeParserPolicy + case "EXCEPTION" => ExceptionTimeParserPolicy + case "CORRECTED" => CorrectedTimeParserPolicy + } + } + + def wrapAndTagPlan(plan: SparkPlan, conf: RapidsConf): SparkPlanMeta[SparkPlan] = { + val wrap = GpuOverrides.wrapPlan(plan, conf, None) + wrap.tagForGpu() + wrap + } + + private def getOptimizations(wrap: SparkPlanMeta[SparkPlan], + conf: RapidsConf): Seq[Optimization] = { + Seq.empty + } + + private final class SortDataFromReplacementRule extends DataFromReplacementRule { + override val operationName: String = "Exec" + override def confKey = "spark.rapids.sql.exec.SortExec" + + override def getChecks: Option[TypeChecks[_]] = None + } + + // Only run the explain and don't actually convert or run on GPU. + def explainPotentialGpuPlan(df: DataFrame, explain: String = "ALL"): String = { + val plan = df.queryExecution.executedPlan + val conf = new RapidsConf(plan.conf) + val updatedPlan = prepareExplainOnly(plan) + // Here we look for subqueries to pull out and do the explain separately on them. 
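+    // (for example, a scalar subquery referenced by a filter has its own physical plan, so it
+    // is explained separately and combined with the top-level plan's explain output)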
+ val subQueryExprs = getSubQueriesFromPlan(plan) + val preparedSubPlans = subQueryExprs.map(_.plan).map(prepareExplainOnly(_)) + val subPlanExplains = preparedSubPlans.map(explainSinglePlan(_, conf, explain)) + val topPlanExplain = explainSinglePlan(updatedPlan, conf, explain) + (subPlanExplains :+ topPlanExplain).mkString("\n") + } + + private def explainSinglePlan(updatedPlan: SparkPlan, conf: RapidsConf, + explain: String): String = { + val wrap = wrapAndTagPlan(updatedPlan, conf) + val reasonsToNotReplaceEntirePlan = wrap.getReasonsNotToReplaceEntirePlan + if (conf.allowDisableEntirePlan && reasonsToNotReplaceEntirePlan.nonEmpty) { + "Can't replace any part of this plan due to: " + + s"${reasonsToNotReplaceEntirePlan.mkString(",")}" + } else { + wrap.runAfterTagRules() + wrap.tagForExplain() + val shouldExplainAll = explain.equalsIgnoreCase("ALL") + wrap.explain(shouldExplainAll) + } + } + + private def getSubqueryExpressions(e: Expression): Seq[ExecSubqueryExpression] = { + val childExprs = e.children.flatMap(getSubqueryExpressions(_)) + val res = e match { + case sq: ExecSubqueryExpression => Seq(sq) + case _ => Seq.empty + } + childExprs ++ res + } + + private def getSubQueriesFromPlan(plan: SparkPlan): Seq[ExecSubqueryExpression] = { + val childPlans = plan.children.flatMap(getSubQueriesFromPlan) + val pSubs = plan.expressions.flatMap(getSubqueryExpressions) + childPlans ++ pSubs + } + + private def prepareExplainOnly(plan: SparkPlan): SparkPlan = { + // Strip out things that would have been added after our GPU plugin would have + // processed the plan. + // AQE we look at the input plan so pretty much just like if AQE wasn't enabled. + val planAfter = plan.transformUp { + case ia: InputAdapter => prepareExplainOnly(ia.child) + case ws: WholeStageCodegenExec => prepareExplainOnly(ws.child) + // case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child) + case re: ReusedExchangeExec => prepareExplainOnly(re.child) + // case aqe: AdaptiveSparkPlanExec => + // prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe)) + case sub: SubqueryExec => prepareExplainOnly(sub.child) + } + planAfter + } +} + +// Spark 2.x +object GpuUserDefinedFunction { + // UDFs can support all types except UDT which does not have a clear columnar representation. + val udfTypeSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.BINARY + TypeSig.CALENDAR + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested() + +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala new file mode 100644 index 00000000000..3126a241cfe --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.TrampolineUtil +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types._ + +object GpuParquetFileFormat { + def tagGpuSupport( + meta: RapidsMeta[_, _], + spark: SparkSession, + options: Map[String, String], + schema: StructType): Unit = { + + val sqlConf = spark.sessionState.conf + val parquetOptions = new ParquetOptions(options, sqlConf) + + if (!meta.conf.isParquetEnabled) { + meta.willNotWorkOnGpu("Parquet input and output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_PARQUET} to true") + } + + if (!meta.conf.isParquetWriteEnabled) { + meta.willNotWorkOnGpu("Parquet output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_PARQUET_WRITE} to true") + } + + FileFormatChecks.tag(meta, schema, ParquetFormatType, WriteFileOp) + + parseCompressionType(parquetOptions.compressionCodecClassName) + .getOrElse(meta.willNotWorkOnGpu( + s"compression codec ${parquetOptions.compressionCodecClassName} is not supported")) + + if (sqlConf.writeLegacyParquetFormat) { + meta.willNotWorkOnGpu("Spark legacy format is not supported") + } + + if (!meta.conf.isParquetInt96WriteEnabled && sqlConf.parquetOutputTimestampType == + ParquetOutputTimestampType.INT96) { + meta.willNotWorkOnGpu(s"Writing INT96 is disabled, if you want to enable it turn it on by " + + s"setting the ${RapidsConf.ENABLE_PARQUET_INT96_WRITE} to true. NOTE: check " + + "out the compatibility.md to know about the limitations associated with INT96 writer") + } + + val schemaHasTimestamps = schema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType]) + } + if (schemaHasTimestamps) { + if(!isOutputTimestampTypeSupported(sqlConf.parquetOutputTimestampType)) { + meta.willNotWorkOnGpu(s"Output timestamp type " + + s"${sqlConf.parquetOutputTimestampType} is not supported") + } + } + + val schemaHasDates = schema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType]) + } + + // Spark 2.x doesn't have the rebase mode because the changes of calendar type weren't made + // so just skip the checks, since this is just explain only it would depend on how + // they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good + // for us. 
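+    // The commented-out block below is the Spark 3.x rebase-mode check that is intentionally
+    // skipped in this Spark 2.x explain-only build.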
+ /* + ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match { + case "EXCEPTION" => + case "CORRECTED" => + case "LEGACY" => + if (schemaHasTimestamps) { + meta.willNotWorkOnGpu("LEGACY rebase mode for int96 timestamps is not supported") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96") + } + + ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match { + case "EXCEPTION" => //Good + case "CORRECTED" => //Good + case "LEGACY" => + if (schemaHasDates || schemaHasTimestamps) { + meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported rebase mode") + } + */ + } + + // SPARK 2.X - just return String rather then CompressionType + def parseCompressionType(compressionType: String): Option[String] = { + compressionType match { + case "NONE" | "UNCOMPRESSED" => Some("NONE") + case "SNAPPY" => Some("SNAPPY") + case _ => None + } + } + + def isOutputTimestampTypeSupported( + outputTimestampType: ParquetOutputTimestampType.Value): Boolean = { + outputTimestampType match { + case ParquetOutputTimestampType.TIMESTAMP_MICROS | + ParquetOutputTimestampType.TIMESTAMP_MILLIS | + ParquetOutputTimestampType.INT96 => true + case _ => false + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala new file mode 100644 index 00000000000..e0b2fa20dc9 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.{FileSourceScanExec, TrampolineUtil} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +object GpuReadParquetFileFormat { + def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + val fsse = meta.wrapped + GpuParquetScanBase.tagSupport( + fsse.sqlContext.sparkSession, + fsse.requiredSchema, + meta + ) + } +} + +object GpuParquetScanBase { + + def tagSupport( + sparkSession: SparkSession, + readSchema: StructType, + meta: RapidsMeta[_, _]): Unit = { + val sqlConf = sparkSession.conf + + if (!meta.conf.isParquetEnabled) { + meta.willNotWorkOnGpu("Parquet input and output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_PARQUET} to true") + } + + if (!meta.conf.isParquetReadEnabled) { + meta.willNotWorkOnGpu("Parquet input has been disabled. 
To enable set" + + s"${RapidsConf.ENABLE_PARQUET_READ} to true") + } + + FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp) + + val schemaHasStrings = readSchema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[StringType]) + } + + if (sqlConf.get(SQLConf.PARQUET_BINARY_AS_STRING.key, + SQLConf.PARQUET_BINARY_AS_STRING.defaultValueString).toBoolean && schemaHasStrings) { + meta.willNotWorkOnGpu(s"GpuParquetScan does not support" + + s" ${SQLConf.PARQUET_BINARY_AS_STRING.key}") + } + + val schemaHasTimestamps = readSchema.exists { field => + TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType]) + } + def isTsOrDate(dt: DataType) : Boolean = dt match { + case TimestampType | DateType => true + case _ => false + } + val schemaMightNeedNestedRebase = readSchema.exists { field => + if (DataTypeUtils.isNestedType(field.dataType)) { + TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate) + } else { + false + } + } + + // Currently timestamp conversion is not supported. + // If support needs to be added then we need to follow the logic in Spark's + // ParquetPartitionReaderFactory and VectorizedColumnReader which essentially + // does the following: + // - check if Parquet file was created by "parquet-mr" + // - if not then look at SQLConf.SESSION_LOCAL_TIMEZONE and assume timestamps + // were written in that timezone and convert them to UTC timestamps. + // Essentially this should boil down to a vector subtract of the scalar delta + // between the configured timezone's delta from UTC on the timestamp data. + if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) { + meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion") + } + + // Spark 2.x doesn't have the rebase mode because the changes of calendar type weren't made + // so just skip the checks, since this is just explain only it would depend on how + // they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good + // for us. + + // Spark 2.x doesn't support the rebase mode + /* + sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match { + case "EXCEPTION" => if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION") + } + case "CORRECTED" => // Good + case "LEGACY" => // really is EXCEPTION for us... + if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") + } + + sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match { + case "EXCEPTION" => if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION") + } + case "CORRECTED" => // Good + case "LEGACY" => // really is EXCEPTION for us... 
+ if (schemaMightNeedNestedRebase) { + meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + + s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY") + } + case other => + meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") + } + */ + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala new file mode 100644 index 00000000000..22723a9a22d --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import com.nvidia.spark.rapids.shims.v2.GpuCSVScan + +import org.apache.spark.sql.execution.FileSourceScanExec + +object GpuReadCSVFileFormat { + def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + val fsse = meta.wrapped + GpuCSVScan.tagSupport( + fsse.sqlContext.sparkSession, + fsse.relation.dataSchema, + fsse.output.toStructType, + fsse.relation.options, + meta + ) + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala new file mode 100644 index 00000000000..ab2a409be1a --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.execution.FileSourceScanExec + +object GpuReadOrcFileFormat { + def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + val fsse = meta.wrapped + if (fsse.relation.options.getOrElse("mergeSchema", "false").toBoolean) { + meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet") + } + GpuOrcScanBase.tagSupport( + fsse.sqlContext.sparkSession, + fsse.requiredSchema, + meta + ) + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReplicateRowsMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReplicateRowsMeta.scala new file mode 100644 index 00000000000..f188439d6f5 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReplicateRowsMeta.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.ReplicateRows + +/** + * Base class for metadata around GeneratorExprMeta. + */ +abstract class ReplicateRowsExprMeta[INPUT <: ReplicateRows]( + gen: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends GeneratorExprMeta[INPUT](gen, conf, parent, rule) { +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExecMeta.scala new file mode 100644 index 00000000000..c180908a5b9 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExecMeta.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.execution.SortExec + +sealed trait SortExecType extends Serializable + +object OutOfCoreSort extends SortExecType +object FullSortSingleBatch extends SortExecType +object SortEachBatch extends SortExecType + +class GpuSortMeta( + sort: SortExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[SortExec](sort, conf, parent, rule) { + + // Uses output attributes of child plan because SortExec will not change the attributes, + // and we need to propagate possible type conversions on the output attributes of + // GpuSortAggregateExec. 
+ override protected val useOutputAttributesOfChild: Boolean = true + + // For transparent plan like ShuffleExchange, the accessibility of runtime data transition is + // depended on the next non-transparent plan. So, we need to trace back. + override val availableRuntimeDataTransition: Boolean = + childPlans.head.availableRuntimeDataTransition +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowMeta.scala new file mode 100644 index 00000000000..9aa5ebf5ab4 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowMeta.scala @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder, UnboundedPreceding} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.types._ + +/** + * Base class for GPU Execs that implement window functions. This abstracts the method + * by which the window function's input expressions, partition specs, order-by specs, etc. + * are extracted from the specific WindowExecType. + * + * @tparam WindowExecType The Exec class that implements window functions + * (E.g. o.a.s.sql.execution.window.WindowExec.) + */ +abstract class GpuBaseWindowExecMeta[WindowExecType <: SparkPlan] (windowExec: WindowExecType, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[WindowExecType](windowExec, conf, parent, rule) with Logging { + + /** + * Extracts window-expression from WindowExecType. + * The implementation varies, depending on the WindowExecType class. + */ + def getInputWindowExpressions: Seq[NamedExpression] + + /** + * Extracts partition-spec from WindowExecType. + * The implementation varies, depending on the WindowExecType class. + */ + def getPartitionSpecs: Seq[Expression] + + /** + * Extracts order-by spec from WindowExecType. + * The implementation varies, depending on the WindowExecType class. + */ + def getOrderSpecs: Seq[SortOrder] + + /** + * Indicates the output column semantics for the WindowExecType, + * i.e. whether to only return the window-expression result columns (as in some Spark + * distributions) or also include the input columns (as in Apache Spark). 
+ */ + def getResultColumnsOnly: Boolean + + val windowExpressions: Seq[BaseExprMeta[NamedExpression]] = + getInputWindowExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val partitionSpec: Seq[BaseExprMeta[Expression]] = + getPartitionSpecs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val orderSpec: Seq[BaseExprMeta[SortOrder]] = + getOrderSpecs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + lazy val inputFields: Seq[BaseExprMeta[Attribute]] = + windowExec.children.head.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + + override def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = Map( + "partitionSpec" -> partitionSpec + ) + + override def tagPlanForGpu(): Unit = { + // Implementation depends on receiving a `NamedExpression` wrapped WindowExpression. + windowExpressions.map(meta => meta.wrapped) + .filter(expr => !expr.isInstanceOf[NamedExpression]) + .foreach(_ => willNotWorkOnGpu("Unexpected query plan with Windowing functions; " + + "cannot convert for GPU execution. " + + "(Detail: WindowExpression not wrapped in `NamedExpression`.)")) + } + +} + +/** + * Specialization of GpuBaseWindowExecMeta for org.apache.spark.sql.window.WindowExec. + * This class implements methods to extract the window-expressions, partition columns, + * order-by columns, etc. from WindowExec. + */ +class GpuWindowExecMeta(windowExec: WindowExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends GpuBaseWindowExecMeta[WindowExec](windowExec, conf, parent, rule) { + + /** + * Fetches WindowExpressions in input `windowExec`, via reflection. + * As a byproduct, determines whether to return the original input columns, + * as part of the output. + * + * (Spark versions that use `projectList` expect result columns + * *not* to include the input columns. + * Apache Spark expects the input columns, before the aggregation output columns.) + * + * @return WindowExpressions within windowExec, + * and a boolean, indicating the result column semantics + * (i.e. whether result columns should be returned *without* including the + * input columns). + */ + def getWindowExpression: (Seq[NamedExpression], Boolean) = { + var resultColumnsOnly : Boolean = false + val expr = try { + val resultMethod = windowExec.getClass.getMethod("windowExpression") + resultMethod.invoke(windowExec).asInstanceOf[Seq[NamedExpression]] + } catch { + case _: NoSuchMethodException => + resultColumnsOnly = true + val winExpr = windowExec.getClass.getMethod("projectList") + winExpr.invoke(windowExec).asInstanceOf[Seq[NamedExpression]] + } + (expr, resultColumnsOnly) + } + + private lazy val (inputWindowExpressions, resultColumnsOnly) = getWindowExpression + + override def getInputWindowExpressions: Seq[NamedExpression] = inputWindowExpressions + override def getPartitionSpecs: Seq[Expression] = windowExec.partitionSpec + override def getOrderSpecs: Seq[SortOrder] = windowExec.orderSpec + override def getResultColumnsOnly: Boolean = resultColumnsOnly +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/InputFileBlockRule.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/InputFileBlockRule.scala new file mode 100644 index 00000000000..be18c41d7ba --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/InputFileBlockRule.scala @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +/** + * InputFileBlockRule is to prevent the SparkPlans + * [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU + * + * See https://github.com/NVIDIA/spark-rapids/issues/3333 + */ +object InputFileBlockRule { + + /** + * Check the Expression is or has Input File expressions. + * @param exec expression to check + * @return true or false + */ + def checkHasInputFileExpressions(exec: Expression): Boolean = exec match { + case _: InputFileName => true + case _: InputFileBlockStart => true + case _: InputFileBlockLength => true + case e => e.children.exists(checkHasInputFileExpressions) + } + + private def checkHasInputFileExpressions(plan: SparkPlan): Boolean = { + plan.expressions.exists(checkHasInputFileExpressions) + } + + // Apply the rule on SparkPlanMeta + def apply(plan: SparkPlanMeta[SparkPlan]) = { + /** + * key: the SparkPlanMeta where has the first input_file_xxx expression + * value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan) + */ + val resultOps = LinkedHashMap[SparkPlanMeta[SparkPlan], ArrayBuffer[SparkPlanMeta[SparkPlan]]]() + recursivelyResolve(plan, None, resultOps) + + // If we've found some chains, we should prevent the transition. + resultOps.foreach { item => + item._2.foreach(p => p.inputFilePreventsRunningOnGpu()) + } + } + + /** + * Recursively apply the rule on the plan + * @param plan the plan to be resolved. 
+ * @param key the SparkPlanMeta with the first input_file_xxx + * @param resultOps the found SparkPlan chain + */ + private def recursivelyResolve( + plan: SparkPlanMeta[SparkPlan], + key: Option[SparkPlanMeta[SparkPlan]], + resultOps: LinkedHashMap[SparkPlanMeta[SparkPlan], + ArrayBuffer[SparkPlanMeta[SparkPlan]]]): Unit = { + + plan.wrapped match { + case _: ShuffleExchangeExec => // Exchange will invalid the input_file_xxx + key.map(p => resultOps.remove(p)) // Remove the chain from Map + plan.childPlans.foreach(p => recursivelyResolve(p, None, resultOps)) + /* + case _: FileSourceScanExec | _: BatchScanExec => + if (plan.canThisBeReplaced) { // FileScan can be replaced + key.map(p => resultOps.remove(p)) // Remove the chain from Map + } + */ + case _: LeafExecNode => // We've reached the LeafNode but without any FileScan + key.map(p => resultOps.remove(p)) // Remove the chain from Map + case _ => + val newKey = if (key.isDefined) { + // The node is in the middle of chain [SparkPlan with input_file_xxx, FileScan) + resultOps.getOrElseUpdate(key.get, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan + key + } else { // There is no parent Node who has input_file_xxx + if (checkHasInputFileExpressions(plan.wrapped)) { + // Current node has input_file_xxx. Mark it as the first Node with input_file_xxx + resultOps.getOrElseUpdate(plan, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan + Some(plan) + } else { + None + } + } + + plan.childPlans.foreach(p => recursivelyResolve(p, newKey, resultOps)) + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala new file mode 100644 index 00000000000..c3e2633cd33 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -0,0 +1,1811 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.nvidia.spark.rapids
+
+import java.io.{File, FileOutputStream}
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap, ListBuffer}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.internal.SQLConf
+
+object ConfHelper {
+ def toBoolean(s: String, key: String): Boolean = {
+ try {
+ s.trim.toBoolean
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be boolean, but was $s")
+ }
+ }
+
+ def toInteger(s: String, key: String): Integer = {
+ try {
+ s.trim.toInt
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be integer, but was $s")
+ }
+ }
+
+ def toLong(s: String, key: String): Long = {
+ try {
+ s.trim.toLong
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be long, but was $s")
+ }
+ }
+
+ def toDouble(s: String, key: String): Double = {
+ try {
+ s.trim.toDouble
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be double, but was $s")
+ }
+ }
+
+ def stringToSeq(str: String): Seq[String] = {
+ str.split(",").map(_.trim()).filter(_.nonEmpty)
+ }
+
+ def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
+ stringToSeq(str).map(converter)
+ }
+
+ def seqToString[T](v: Seq[T], stringConverter: T => String): String = {
+ v.map(stringConverter).mkString(",")
+ }
+
+ def byteFromString(str: String, unit: ByteUnit): Long = {
+ val (input, multiplier) =
+ if (str.nonEmpty && str.head == '-') {
+ (str.substring(1), -1)
+ } else {
+ (str, 1)
+ }
+ multiplier * JavaUtils.byteStringAs(input, unit)
+ }
+
+ def makeConfAnchor(key: String, text: String = null): String = {
+ val t = if (text != null) text else key
+ // The anchor cannot be too long, so for now
+ val a = key.replaceFirst("spark.rapids.", "")
+ "<a name=\"" + a + "\"></a>" + t
+ }
+
+ def getSqlFunctionsForClass[T](exprClass: Class[T]): Option[Seq[String]] = {
+ sqlFunctionsByClass.get(exprClass.getCanonicalName)
+ }
+
+ lazy val sqlFunctionsByClass: Map[String, Seq[String]] = {
+ val functionsByClass = new HashMap[String, Seq[String]]
+ FunctionRegistry.expressions.foreach { case (sqlFn, (expressionInfo, _)) =>
+ val className = expressionInfo.getClassName
+ val fnSeq = functionsByClass.getOrElse(className, Seq[String]())
+ val fnCleaned = if (sqlFn != "|") {
+ sqlFn
+ } else {
+ "\\|"
+ }
+ functionsByClass.update(className, fnSeq :+ s"`$fnCleaned`")
+ }
+ functionsByClass.toMap
+ }
+}
+
+abstract class ConfEntry[T](val key: String, val converter: String => T,
+ val doc: String, val isInternal: Boolean) {
+
+ def get(conf: Map[String, String]): T
+ def get(conf: SQLConf): T
+ def help(asTable: Boolean = false): Unit
+
+ override def toString: String = key
+}
+
+class ConfEntryWithDefault[T](key: String, converter: String => T, doc: String,
+ isInternal: Boolean, val defaultValue: T)
+ extends ConfEntry[T](key, converter, doc, isInternal) {
+
+ override def get(conf: Map[String, String]): T = {
+ conf.get(key).map(converter).getOrElse(defaultValue)
+ }
+
+ override def get(conf: SQLConf): T = {
+ val tmp = conf.getConfString(key, null)
+ if (tmp == null) {
+ defaultValue
+ } else {
+ converter(tmp)
+ }
+ }
+
+ override def help(asTable: Boolean = false): Unit = {
+ if (!isInternal) {
+ if (asTable) {
import ConfHelper.makeConfAnchor + println(s"${makeConfAnchor(key)}|$doc|$defaultValue") + } else { + println(s"$key:") + println(s"\t$doc") + println(s"\tdefault $defaultValue") + println() + } + } + } +} + +class OptionalConfEntry[T](key: String, val rawConverter: String => T, doc: String, + isInternal: Boolean) + extends ConfEntry[Option[T]](key, s => Some(rawConverter(s)), doc, isInternal) { + + override def get(conf: Map[String, String]): Option[T] = { + conf.get(key).map(rawConverter) + } + + override def get(conf: SQLConf): Option[T] = { + val tmp = conf.getConfString(key, null) + if (tmp == null) { + None + } else { + Some(rawConverter(tmp)) + } + } + + override def help(asTable: Boolean = false): Unit = { + if (!isInternal) { + if (asTable) { + import ConfHelper.makeConfAnchor + println(s"${makeConfAnchor(key)}|$doc|None") + } else { + println(s"$key:") + println(s"\t$doc") + println("\tNone") + println() + } + } + } +} + +class TypedConfBuilder[T]( + val parent: ConfBuilder, + val converter: String => T, + val stringConverter: T => String) { + + def this(parent: ConfBuilder, converter: String => T) = { + this(parent, converter, Option(_).map(_.toString).orNull) + } + + /** Apply a transformation to the user-provided values of the config entry. */ + def transform(fn: T => T): TypedConfBuilder[T] = { + new TypedConfBuilder(parent, s => fn(converter(s)), stringConverter) + } + + /** Checks if the user-provided value for the config matches the validator. */ + def checkValue(validator: T => Boolean, errorMsg: String): TypedConfBuilder[T] = { + transform { v => + if (!validator(v)) { + throw new IllegalArgumentException(errorMsg) + } + v + } + } + + /** Check that user-provided values for the config match a pre-defined set. */ + def checkValues(validValues: Set[T]): TypedConfBuilder[T] = { + transform { v => + if (!validValues.contains(v)) { + throw new IllegalArgumentException( + s"The value of ${parent.key} should be one of ${validValues.mkString(", ")}, but was $v") + } + v + } + } + + def createWithDefault(value: T): ConfEntry[T] = { + val ret = new ConfEntryWithDefault[T](parent.key, converter, + parent.doc, parent.isInternal, value) + parent.register(ret) + ret + } + + /** Turns the config entry into a sequence of values of the underlying type. 
*/
+ def toSequence: TypedConfBuilder[Seq[T]] = {
+ new TypedConfBuilder(parent, ConfHelper.stringToSeq(_, converter),
+ ConfHelper.seqToString(_, stringConverter))
+ }
+
+ def createOptional: OptionalConfEntry[T] = {
+ val ret = new OptionalConfEntry[T](parent.key, converter,
+ parent.doc, parent.isInternal)
+ parent.register(ret)
+ ret
+ }
+}
+
+class ConfBuilder(val key: String, val register: ConfEntry[_] => Unit) {
+
+ import ConfHelper._
+
+ var doc: String = null
+ var isInternal: Boolean = false
+
+ def doc(data: String): ConfBuilder = {
+ this.doc = data
+ this
+ }
+
+ def internal(): ConfBuilder = {
+ this.isInternal = true
+ this
+ }
+
+ def booleanConf: TypedConfBuilder[Boolean] = {
+ new TypedConfBuilder[Boolean](this, toBoolean(_, key))
+ }
+
+ def bytesConf(unit: ByteUnit): TypedConfBuilder[Long] = {
+ new TypedConfBuilder[Long](this, byteFromString(_, unit))
+ }
+
+ def integerConf: TypedConfBuilder[Integer] = {
+ new TypedConfBuilder[Integer](this, toInteger(_, key))
+ }
+
+ def longConf: TypedConfBuilder[Long] = {
+ new TypedConfBuilder[Long](this, toLong(_, key))
+ }
+
+ def doubleConf: TypedConfBuilder[Double] = {
+ new TypedConfBuilder(this, toDouble(_, key))
+ }
+
+ def stringConf: TypedConfBuilder[String] = {
+ new TypedConfBuilder[String](this, identity[String])
+ }
+}
+
+object RapidsConf {
+ private val registeredConfs = new ListBuffer[ConfEntry[_]]()
+
+ private def register(entry: ConfEntry[_]): Unit = {
+ registeredConfs += entry
+ }
+
+ def conf(key: String): ConfBuilder = {
+ new ConfBuilder(key, register)
+ }
+
+ // Resource Configuration
+
+ val PINNED_POOL_SIZE = conf("spark.rapids.memory.pinnedPool.size")
+ .doc("The size of the pinned memory pool in bytes unless otherwise specified. " +
+ "Use 0 to disable the pool.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(0)
+
+ val PAGEABLE_POOL_SIZE = conf("spark.rapids.memory.host.pageablePool.size")
+ .doc("The size of the pageable memory pool in bytes unless otherwise specified. " +
+ "Use 0 to disable the pool.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(ByteUnit.GiB.toBytes(1).toLong)
+
+ val RMM_DEBUG = conf("spark.rapids.memory.gpu.debug")
+ .doc("Provides a log of GPU memory allocations and frees. If set to " +
+ "STDOUT or STDERR the logging will go there. Setting it to NONE disables logging. " +
+ "All other values are reserved for possible future expansion and in the mean time will " +
+ "disable logging.")
+ .stringConf
+ .createWithDefault("NONE")
+
+ val GPU_OOM_DUMP_DIR = conf("spark.rapids.memory.gpu.oomDumpDir")
+ .doc("The path to a local directory where a heap dump will be created if the GPU " +
+ "encounters an unrecoverable out-of-memory (OOM) error. The filename will be of the " +
+ "form: \"gpu-oom-<pid>.hprof\" where <pid> is the process ID.")
+ .stringConf
+ .createOptional
+
+ private val RMM_ALLOC_MAX_FRACTION_KEY = "spark.rapids.memory.gpu.maxAllocFraction"
+ private val RMM_ALLOC_MIN_FRACTION_KEY = "spark.rapids.memory.gpu.minAllocFraction"
+ private val RMM_ALLOC_RESERVE_KEY = "spark.rapids.memory.gpu.reserve"
+
+ val RMM_ALLOC_FRACTION = conf("spark.rapids.memory.gpu.allocFraction")
+ .doc("The fraction of available (free) GPU memory that should be allocated for pooled " +
+ "memory. 
This must be less than or equal to the maximum limit configured via " + + s"$RMM_ALLOC_MAX_FRACTION_KEY, and greater than or equal to the minimum limit configured " + + s"via $RMM_ALLOC_MIN_FRACTION_KEY.") + .doubleConf + .checkValue(v => v >= 0 && v <= 1, "The fraction value must be in [0, 1].") + .createWithDefault(1) + + val RMM_ALLOC_MAX_FRACTION = conf(RMM_ALLOC_MAX_FRACTION_KEY) + .doc("The fraction of total GPU memory that limits the maximum size of the RMM pool. " + + s"The value must be greater than or equal to the setting for $RMM_ALLOC_FRACTION. " + + "Note that this limit will be reduced by the reserve memory configured in " + + s"$RMM_ALLOC_RESERVE_KEY.") + .doubleConf + .checkValue(v => v >= 0 && v <= 1, "The fraction value must be in [0, 1].") + .createWithDefault(1) + + val RMM_ALLOC_MIN_FRACTION = conf(RMM_ALLOC_MIN_FRACTION_KEY) + .doc("The fraction of total GPU memory that limits the minimum size of the RMM pool. " + + s"The value must be less than or equal to the setting for $RMM_ALLOC_FRACTION.") + .doubleConf + .checkValue(v => v >= 0 && v <= 1, "The fraction value must be in [0, 1].") + .createWithDefault(0.25) + + val RMM_ALLOC_RESERVE = conf(RMM_ALLOC_RESERVE_KEY) + .doc("The amount of GPU memory that should remain unallocated by RMM and left for " + + "system use such as memory needed for kernels and kernel launches.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(ByteUnit.MiB.toBytes(640).toLong) + + val HOST_SPILL_STORAGE_SIZE = conf("spark.rapids.memory.host.spillStorageSize") + .doc("Amount of off-heap host memory to use for buffering spilled GPU data before spilling " + + "to local disk. Use -1 to set the amount to the combined size of pinned and pageable " + + "memory pools.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(-1) + + val UNSPILL = conf("spark.rapids.memory.gpu.unspill.enabled") + .doc("When a spilled GPU buffer is needed again, should it be unspilled, or only copied " + + "back into GPU memory temporarily. Unspilling may be useful for GPU buffers that are " + + "needed frequently, for example, broadcast variables; however, it may also increase GPU " + + "memory usage") + .booleanConf + .createWithDefault(false) + + val GDS_SPILL = conf("spark.rapids.memory.gpu.direct.storage.spill.enabled") + .doc("Should GPUDirect Storage (GDS) be used to spill GPU memory buffers directly to disk. " + + "GDS must be enabled and the directory `spark.local.dir` must support GDS. This is an " + + "experimental feature. For more information on GDS, see " + + "https://docs.nvidia.com/gpudirect-storage/.") + .booleanConf + .createWithDefault(false) + + val GDS_SPILL_BATCH_WRITE_BUFFER_SIZE = + conf("spark.rapids.memory.gpu.direct.storage.spill.batchWriteBuffer.size") + .doc("The size of the GPU memory buffer used to batch small buffers when spilling to GDS. " + + "Note that this buffer is mapped to the PCI Base Address Register (BAR) space, which may " + + "be very limited on some GPUs (e.g. the NVIDIA T4 only has 256 MiB), and it is also used " + + "by UCX bounce buffers.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(ByteUnit.MiB.toBytes(8).toLong) + + val POOLED_MEM = conf("spark.rapids.memory.gpu.pooling.enabled") + .doc("Should RMM act as a pooling allocator for GPU memory, or should it just pass " + + "through to CUDA memory allocation directly. 
DEPRECATED: please use " + + "spark.rapids.memory.gpu.pool instead.") + .booleanConf + .createWithDefault(true) + + val RMM_POOL = conf("spark.rapids.memory.gpu.pool") + .doc("Select the RMM pooling allocator to use. Valid values are \"DEFAULT\", \"ARENA\", " + + "\"ASYNC\", and \"NONE\". With \"DEFAULT\", the RMM pool allocator is used; with " + + "\"ARENA\", the RMM arena allocator is used; with \"ASYNC\", the new CUDA stream-ordered " + + "memory allocator in CUDA 11.2+ is used. If set to \"NONE\", pooling is disabled and RMM " + + "just passes through to CUDA memory allocation directly. Note: \"ARENA\" is the " + + "recommended pool allocator if CUDF is built with Per-Thread Default Stream (PTDS).") + .stringConf + .createWithDefault("ARENA") + + val CONCURRENT_GPU_TASKS = conf("spark.rapids.sql.concurrentGpuTasks") + .doc("Set the number of tasks that can execute concurrently per GPU. " + + "Tasks may temporarily block when the number of concurrent tasks in the executor " + + "exceeds this amount. Allowing too many concurrent tasks on the same GPU may lead to " + + "GPU out of memory errors.") + .integerConf + .createWithDefault(1) + + val SHUFFLE_SPILL_THREADS = conf("spark.rapids.sql.shuffle.spillThreads") + .doc("Number of threads used to spill shuffle data to disk in the background.") + .integerConf + .createWithDefault(6) + + val GPU_BATCH_SIZE_BYTES = conf("spark.rapids.sql.batchSizeBytes") + .doc("Set the target number of bytes for a GPU batch. Splits sizes for input data " + + "is covered by separate configs. The maximum setting is 2 GB to avoid exceeding the " + + "cudf row count limit of a column.") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v >= 0 && v <= Integer.MAX_VALUE, + s"Batch size must be positive and not exceed ${Integer.MAX_VALUE} bytes.") + .createWithDefault(Integer.MAX_VALUE) + + val MAX_READER_BATCH_SIZE_ROWS = conf("spark.rapids.sql.reader.batchSizeRows") + .doc("Soft limit on the maximum number of rows the reader will read per batch. " + + "The orc and parquet readers will read row groups until this limit is met or exceeded. " + + "The limit is respected by the csv reader.") + .integerConf + .createWithDefault(Integer.MAX_VALUE) + + val MAX_READER_BATCH_SIZE_BYTES = conf("spark.rapids.sql.reader.batchSizeBytes") + .doc("Soft limit on the maximum number of bytes the reader reads per batch. " + + "The readers will read chunks of data until this limit is met or exceeded. " + + "Note that the reader may estimate the number of bytes that will be used on the GPU " + + "in some cases based on the schema and number of rows in each batch.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Integer.MAX_VALUE) + + val DRIVER_TIMEZONE = conf("spark.rapids.driver.user.timezone") + .doc("This config is used to inform the executor plugin about the driver's timezone " + + "and is not intended to be set by the user.") + .internal() + .stringConf + .createOptional + + // Internal Features + + val UVM_ENABLED = conf("spark.rapids.memory.uvm.enabled") + .doc("UVM or universal memory can allow main host memory to act essentially as swap " + + "for device(GPU) memory. This allows the GPU to process more data than fits in memory, but " + + "can result in slower processing. This is an experimental feature.") + .internal() + .booleanConf + .createWithDefault(false) + + val EXPORT_COLUMNAR_RDD = conf("spark.rapids.sql.exportColumnarRdd") + .doc("Spark has no simply way to export columnar RDD data. 
This turns on special " +
+ "processing/tagging that allows the RDD to be picked back apart into a Columnar RDD.")
+ .internal()
+ .booleanConf
+ .createWithDefault(false)
+
+ val STABLE_SORT = conf("spark.rapids.sql.stableSort.enabled")
+ .doc("Enable or disable stable sorting. Apache Spark's sorting is typically a stable " +
+ "sort, but sort stability cannot be guaranteed in distributed workloads because the " +
+ "order in which upstream data arrives to a task is not guaranteed. Sort stability then " +
+ "only matters when reading and sorting data from a file using a single task/partition. " +
+ "Because of limitations in the plugin when you enable stable sorting all of the data " +
+ "for a single task will be combined into a single batch before sorting. This currently " +
+ "disables spilling from GPU memory if the data size is too large.")
+ .booleanConf
+ .createWithDefault(false)
+
+ // METRICS
+
+ val METRICS_LEVEL = conf("spark.rapids.sql.metrics.level")
+ .doc("GPU plans can produce a lot more metrics than CPU plans do. In very large " +
+ "queries this can sometimes result in going over the max result size limit for the " +
+ "driver. Supported values include " +
+ "DEBUG which will enable all metrics supported and typically only needs to be enabled " +
+ "when debugging the plugin. " +
+ "MODERATE which should output enough metrics to understand how long each part of the " +
+ "query is taking and how much data is going to each part of the query. " +
+ "ESSENTIAL which disables most metrics except those Apache Spark CPU plans will also " +
+ "report or their equivalents.")
+ .stringConf
+ .transform(_.toUpperCase(java.util.Locale.ROOT))
+ .checkValues(Set("DEBUG", "MODERATE", "ESSENTIAL"))
+ .createWithDefault("MODERATE")
+
+ // ENABLE/DISABLE PROCESSING
+
+ val IMPROVED_TIMESTAMP_OPS =
+ conf("spark.rapids.sql.improvedTimeOps.enabled")
+ .doc("When set to true, some operators will avoid overflowing by converting epoch days " +
+ " directly to seconds without first converting to microseconds")
+ .booleanConf
+ .createWithDefault(false)
+
+ val SQL_ENABLED = conf("spark.rapids.sql.enabled")
+ .doc("Enable (true) or disable (false) sql operations on the GPU")
+ .booleanConf
+ .createWithDefault(true)
+
+ val SQL_MODE = conf("spark.rapids.sql.mode")
+ .doc("Set the mode for the Rapids Accelerator. The supported modes are explainOnly and " +
+ "executeOnGPU. This config cannot be changed at runtime; you must restart the " +
+ "application for it to take effect. The default mode is executeOnGPU, which means " +
+ "the RAPIDS Accelerator plugin converts the Spark operations and executes them on the " +
+ "GPU when possible. The explainOnly mode allows running queries on the CPU and the " +
+ "RAPIDS Accelerator will evaluate the queries as if it was going to run on the GPU. " +
+ "The explanations of what would have run on the GPU and why are output in log " +
+ "messages. When using explainOnly mode, the default explain output is ALL; this can " +
+ "be changed by setting spark.rapids.sql.explain. 
See that config for more details.") + .stringConf + .transform(_.toLowerCase(java.util.Locale.ROOT)) + .checkValues(Set("explainonly", "executeongpu")) + .createWithDefault("executeongpu") + + val UDF_COMPILER_ENABLED = conf("spark.rapids.sql.udfCompiler.enabled") + .doc("When set to true, Scala UDFs will be considered for compilation as Catalyst expressions") + .booleanConf + .createWithDefault(false) + + val INCOMPATIBLE_OPS = conf("spark.rapids.sql.incompatibleOps.enabled") + .doc("For operations that work, but are not 100% compatible with the Spark equivalent " + + "set if they should be enabled by default or disabled by default.") + .booleanConf + .createWithDefault(false) + + val INCOMPATIBLE_DATE_FORMATS = conf("spark.rapids.sql.incompatibleDateFormats.enabled") + .doc("When parsing strings as dates and timestamps in functions like unix_timestamp, some " + + "formats are fully supported on the GPU and some are unsupported and will fall back to " + + "the CPU. Some formats behave differently on the GPU than the CPU. Spark on the CPU " + + "interprets date formats with unsupported trailing characters as nulls, while Spark on " + + "the GPU will parse the date with invalid trailing characters. More detail can be found " + + "at [parsing strings as dates or timestamps]" + + "(compatibility.md#parsing-strings-as-dates-or-timestamps).") + .booleanConf + .createWithDefault(false) + + val IMPROVED_FLOAT_OPS = conf("spark.rapids.sql.improvedFloatOps.enabled") + .doc("For some floating point operations spark uses one way to compute the value " + + "and the underlying cudf implementation can use an improved algorithm. " + + "In some cases this can result in cudf producing an answer when spark overflows. " + + "Because this is not as compatible with spark, we have it disabled by default.") + .booleanConf + .createWithDefault(false) + + val HAS_NANS = conf("spark.rapids.sql.hasNans") + .doc("Config to indicate if your data has NaN's. Cudf doesn't " + + "currently support NaN's properly so you can get corrupt data if you have NaN's in your " + + "data and it runs on the GPU.") + .booleanConf + .createWithDefault(true) + + val NEED_DECIMAL_OVERFLOW_GUARANTEES = conf("spark.rapids.sql.decimalOverflowGuarantees") + .doc("FOR TESTING ONLY. DO NOT USE IN PRODUCTION. Please see the decimal section of " + + "the compatibility documents for more information on this config.") + .booleanConf + .createWithDefault(true) + + val ENABLE_FLOAT_AGG = conf("spark.rapids.sql.variableFloatAgg.enabled") + .doc("Spark assumes that all operations produce the exact same result each time. " + + "This is not true for some floating point aggregations, which can produce slightly " + + "different results on the GPU as the aggregation is done in parallel. This can enable " + + "those operations if you know the query is only computing it once.") + .booleanConf + .createWithDefault(false) + + val ENABLE_REPLACE_SORTMERGEJOIN = conf("spark.rapids.sql.replaceSortMergeJoin.enabled") + .doc("Allow replacing sortMergeJoin with HashJoin") + .booleanConf + .createWithDefault(true) + + val ENABLE_HASH_OPTIMIZE_SORT = conf("spark.rapids.sql.hashOptimizeSort.enabled") + .doc("Whether sorts should be inserted after some hashed operations to improve " + + "output ordering. 
This can improve output file sizes when saving to columnar formats.") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_FLOAT_TO_DECIMAL = conf("spark.rapids.sql.castFloatToDecimal.enabled") + .doc("Casting from floating point types to decimal on the GPU returns results that have " + + "tiny difference compared to results returned from CPU.") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_FLOAT_TO_STRING = conf("spark.rapids.sql.castFloatToString.enabled") + .doc("Casting from floating point types to string on the GPU returns results that have " + + "a different precision than the default results of Spark.") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES = + conf("spark.rapids.sql.castFloatToIntegralTypes.enabled") + .doc("Casting from floating point types to integral types on the GPU supports a " + + "slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST " + + "documentation for more details.") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_DECIMAL_TO_FLOAT = conf("spark.rapids.sql.castDecimalToFloat.enabled") + .doc("Casting from decimal to floating point types on the GPU returns results that have " + + "tiny difference compared to results returned from CPU.") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_STRING_TO_FLOAT = conf("spark.rapids.sql.castStringToFloat.enabled") + .doc("When set to true, enables casting from strings to float types (float, double) " + + "on the GPU. Currently hex values aren't supported on the GPU. Also note that casting from " + + "string to float types on the GPU returns incorrect results when the string represents any " + + "number \"1.7976931348623158E308\" <= x < \"1.7976931348623159E308\" " + + "and \"-1.7976931348623158E308\" >= x > \"-1.7976931348623159E308\" in both these cases " + + "the GPU returns Double.MaxValue while CPU returns \"+Infinity\" and \"-Infinity\" " + + "respectively") + .booleanConf + .createWithDefault(false) + + val ENABLE_CAST_STRING_TO_TIMESTAMP = conf("spark.rapids.sql.castStringToTimestamp.enabled") + .doc("When set to true, casting from string to timestamp is supported on the GPU. The GPU " + + "only supports a subset of formats when casting strings to timestamps. Refer to the CAST " + + "documentation for more details.") + .booleanConf + .createWithDefault(false) + + val HAS_EXTENDED_YEAR_VALUES = conf("spark.rapids.sql.hasExtendedYearValues") + .doc("Spark 3.2.0+ extended parsing of years in dates and " + + "timestamps to support the full range of possible values. Prior " + + "to this it was limited to a positive 4 digit year. The Accelerator does not " + + "support the extended range yet. This config indicates if your data includes " + + "this extended range or not, or if you don't care about getting the correct " + + "values on values with the extended range.") + .booleanConf + .createWithDefault(true) + + val ENABLE_CAST_DECIMAL_TO_STRING = conf("spark.rapids.sql.castDecimalToString.enabled") + .doc("When set to true, casting from decimal to string is supported on the GPU. The GPU " + + "does NOT produce exact same string as spark produces, but producing strings which are " + + "semantically equal. 
For instance, given input BigDecimal(123, -2), the GPU produces " + + "\"12300\", which spark produces \"1.23E+4\".") + .booleanConf + .createWithDefault(false) + + val ENABLE_INNER_JOIN = conf("spark.rapids.sql.join.inner.enabled") + .doc("When set to true inner joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_CROSS_JOIN = conf("spark.rapids.sql.join.cross.enabled") + .doc("When set to true cross joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_LEFT_OUTER_JOIN = conf("spark.rapids.sql.join.leftOuter.enabled") + .doc("When set to true left outer joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_RIGHT_OUTER_JOIN = conf("spark.rapids.sql.join.rightOuter.enabled") + .doc("When set to true right outer joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_FULL_OUTER_JOIN = conf("spark.rapids.sql.join.fullOuter.enabled") + .doc("When set to true full outer joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_LEFT_SEMI_JOIN = conf("spark.rapids.sql.join.leftSemi.enabled") + .doc("When set to true left semi joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_LEFT_ANTI_JOIN = conf("spark.rapids.sql.join.leftAnti.enabled") + .doc("When set to true left anti joins are enabled on the GPU") + .booleanConf + .createWithDefault(true) + + val ENABLE_PROJECT_AST = conf("spark.rapids.sql.projectAstEnabled") + .doc("Enable project operations to use cudf AST expressions when possible.") + .internal() + .booleanConf + .createWithDefault(false) + + // FILE FORMATS + val ENABLE_PARQUET = conf("spark.rapids.sql.format.parquet.enabled") + .doc("When set to false disables all parquet input and output acceleration") + .booleanConf + .createWithDefault(true) + + val ENABLE_PARQUET_INT96_WRITE = conf("spark.rapids.sql.format.parquet.writer.int96.enabled") + .doc("When set to false, disables accelerated parquet write if the " + + "spark.sql.parquet.outputTimestampType is set to INT96") + .booleanConf + .createWithDefault(true) + + // This is an experimental feature now. And eventually, should be enabled or disabled depending + // on something that we don't know yet but would try to figure out. + val ENABLE_CPU_BASED_UDF = conf("spark.rapids.sql.rowBasedUDF.enabled") + .doc("When set to true, optimizes a row-based UDF in a GPU operation by transferring " + + "only the data it needs between GPU and CPU inside a query operation, instead of falling " + + "this operation back to CPU. This is an experimental feature, and this config might be " + + "removed in the future.") + .booleanConf + .createWithDefault(false) + + object ParquetReaderType extends Enumeration { + val AUTO, COALESCING, MULTITHREADED, PERFILE = Value + } + + val PARQUET_READER_TYPE = conf("spark.rapids.sql.format.parquet.reader.type") + .doc("Sets the parquet reader type. We support different types that are optimized for " + + "different environments. The original Spark style reader can be selected by setting this " + + "to PERFILE which individually reads and copies files to the GPU. Loading many small files " + + "individually has high overhead, and using either COALESCING or MULTITHREADED is " + + "recommended instead. The COALESCING reader is good when using a local file system where " + + "the executors are on the same nodes or close to the nodes the data is being read on. 
" + + "This reader coalesces all the files assigned to a task into a single host buffer before " + + "sending it down to the GPU. It copies blocks from a single file into a host buffer in " + + "separate threads in parallel, see " + + "spark.rapids.sql.format.parquet.multiThreadedRead.numThreads. " + + "MULTITHREADED is good for cloud environments where you are reading from a blobstore " + + "that is totally separate and likely has a higher I/O read cost. Many times the cloud " + + "environments also get better throughput when you have multiple readers in parallel. " + + "This reader uses multiple threads to read each file in parallel and each file is sent " + + "to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. " + + "See spark.rapids.sql.format.parquet.multiThreadedRead.numThreads and " + + "spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel to control " + + "the number of threads and amount of memory used. " + + "By default this is set to AUTO so we select the reader we think is best. This will " + + "either be the COALESCING or the MULTITHREADED based on whether we think the file is " + + "in the cloud. See spark.rapids.cloudSchemes.") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(ParquetReaderType.values.map(_.toString)) + .createWithDefault(ParquetReaderType.AUTO.toString) + + /** List of schemes that are always considered cloud storage schemes */ + private lazy val DEFAULT_CLOUD_SCHEMES = + Seq("abfs", "abfss", "dbfs", "gs", "s3", "s3a", "s3n", "wasbs") + + val CLOUD_SCHEMES = conf("spark.rapids.cloudSchemes") + .doc("Comma separated list of additional URI schemes that are to be considered cloud based " + + s"filesystems. Schemes already included: ${DEFAULT_CLOUD_SCHEMES.mkString(", ")}. Cloud " + + "based stores generally would be total separate from the executors and likely have a " + + "higher I/O read cost. Many times the cloud filesystems also get better throughput when " + + "you have multiple readers in parallel. This is used with " + + "spark.rapids.sql.format.parquet.reader.type") + .stringConf + .toSequence + .createOptional + + val PARQUET_MULTITHREAD_READ_NUM_THREADS = + conf("spark.rapids.sql.format.parquet.multiThreadedRead.numThreads") + .doc("The maximum number of threads, on the executor, to use for reading small " + + "parquet files in parallel. This can not be changed at runtime after the executor has " + + "started. Used with COALESCING and MULTITHREADED reader, see " + + "spark.rapids.sql.format.parquet.reader.type.") + .integerConf + .createWithDefault(20) + + val PARQUET_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL = + conf("spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel") + .doc("A limit on the maximum number of files per task processed in parallel on the CPU " + + "side before the file is sent to the GPU. This affects the amount of host memory used " + + "when reading the files in parallel. 
Used with MULTITHREADED reader, see " +
+ "spark.rapids.sql.format.parquet.reader.type")
+ .integerConf
+ .checkValue(v => v > 0, "The maximum number of files must be greater than 0.")
+ .createWithDefault(Integer.MAX_VALUE)
+
+ val ENABLE_PARQUET_READ = conf("spark.rapids.sql.format.parquet.read.enabled")
+ .doc("When set to false disables parquet input acceleration")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ENABLE_PARQUET_WRITE = conf("spark.rapids.sql.format.parquet.write.enabled")
+ .doc("When set to false disables parquet output acceleration")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ENABLE_ORC = conf("spark.rapids.sql.format.orc.enabled")
+ .doc("When set to false disables all orc input and output acceleration")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ENABLE_ORC_READ = conf("spark.rapids.sql.format.orc.read.enabled")
+ .doc("When set to false disables orc input acceleration")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ENABLE_ORC_WRITE = conf("spark.rapids.sql.format.orc.write.enabled")
+ .doc("When set to true enables orc output acceleration. We default it to false because " +
+ "there is an ORC bug where the ORC Java library fails to read ORC files without statistics in " +
+ "the RowIndex. For more details, please refer to https://issues.apache.org/jira/browse/ORC-1075")
+ .booleanConf
+ .createWithDefault(false)
+
+ // This will be deleted when COALESCING is implemented for ORC
+ object OrcReaderType extends Enumeration {
+ val AUTO, COALESCING, MULTITHREADED, PERFILE = Value
+ }
+
+ val ORC_READER_TYPE = conf("spark.rapids.sql.format.orc.reader.type")
+ .doc("Sets the orc reader type. We support different types that are optimized for " +
+ "different environments. The original Spark style reader can be selected by setting this " +
+ "to PERFILE which individually reads and copies files to the GPU. Loading many small files " +
+ "individually has high overhead, and using either COALESCING or MULTITHREADED is " +
+ "recommended instead. The COALESCING reader is good when using a local file system where " +
+ "the executors are on the same nodes or close to the nodes the data is being read on. " +
+ "This reader coalesces all the files assigned to a task into a single host buffer before " +
+ "sending it down to the GPU. It copies blocks from a single file into a host buffer in " +
+ "separate threads in parallel, see " +
+ "spark.rapids.sql.format.orc.multiThreadedRead.numThreads. " +
+ "MULTITHREADED is good for cloud environments where you are reading from a blobstore " +
+ "that is totally separate and likely has a higher I/O read cost. Many times the cloud " +
+ "environments also get better throughput when you have multiple readers in parallel. " +
+ "This reader uses multiple threads to read each file in parallel and each file is sent " +
+ "to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. " +
+ "See spark.rapids.sql.format.orc.multiThreadedRead.numThreads and " +
+ "spark.rapids.sql.format.orc.multiThreadedRead.maxNumFilesParallel to control " +
+ "the number of threads and amount of memory used. " +
+ "By default this is set to AUTO so we select the reader we think is best. This will " +
+ "either be the COALESCING or the MULTITHREADED based on whether we think the file is " +
+ "in the cloud. 
See spark.rapids.cloudSchemes.") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(OrcReaderType.values.map(_.toString)) + .createWithDefault(OrcReaderType.AUTO.toString) + + val ORC_MULTITHREAD_READ_NUM_THREADS = + conf("spark.rapids.sql.format.orc.multiThreadedRead.numThreads") + .doc("The maximum number of threads, on the executor, to use for reading small " + + "orc files in parallel. This can not be changed at runtime after the executor has " + + "started. Used with MULTITHREADED reader, see " + + "spark.rapids.sql.format.orc.reader.type.") + .integerConf + .createWithDefault(20) + + val ORC_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL = + conf("spark.rapids.sql.format.orc.multiThreadedRead.maxNumFilesParallel") + .doc("A limit on the maximum number of files per task processed in parallel on the CPU " + + "side before the file is sent to the GPU. This affects the amount of host memory used " + + "when reading the files in parallel. Used with MULTITHREADED reader, see " + + "spark.rapids.sql.format.orc.reader.type") + .integerConf + .checkValue(v => v > 0, "The maximum number of files must be greater than 0.") + .createWithDefault(Integer.MAX_VALUE) + + val ENABLE_CSV = conf("spark.rapids.sql.format.csv.enabled") + .doc("When set to false disables all csv input and output acceleration. " + + "(only input is currently supported anyways)") + .booleanConf + .createWithDefault(true) + + val ENABLE_CSV_READ = conf("spark.rapids.sql.format.csv.read.enabled") + .doc("When set to false disables csv input acceleration") + .booleanConf + .createWithDefault(true) + + // TODO should we change this config? + val ENABLE_CSV_TIMESTAMPS = conf("spark.rapids.sql.csvTimestamps.enabled") + .doc("When set to true, enables the CSV parser to read timestamps. The default output " + + "format for Spark includes a timezone at the end. Anything except the UTC timezone is " + + "not supported. 
Timestamps after 2038 and before 1902 are also not supported.") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_DATES = conf("spark.rapids.sql.csv.read.date.enabled") + .doc("Parsing invalid CSV dates produces different results from Spark") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_BOOLS = conf("spark.rapids.sql.csv.read.bool.enabled") + .doc("Parsing an invalid CSV boolean value produces true instead of null") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_BYTES = conf("spark.rapids.sql.csv.read.byte.enabled") + .doc("Parsing CSV bytes is much more lenient and will return 0 for some " + + "malformed values instead of null") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_SHORTS = conf("spark.rapids.sql.csv.read.short.enabled") + .doc("Parsing CSV shorts is much more lenient and will return 0 for some " + + "malformed values instead of null") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_INTEGERS = conf("spark.rapids.sql.csv.read.integer.enabled") + .doc("Parsing CSV integers is much more lenient and will return 0 for some " + + "malformed values instead of null") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_LONGS = conf("spark.rapids.sql.csv.read.long.enabled") + .doc("Parsing CSV longs is much more lenient and will return 0 for some " + + "malformed values instead of null") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_FLOATS = conf("spark.rapids.sql.csv.read.float.enabled") + .doc("Parsing CSV floats has some issues at the min and max values for floating" + + "point numbers and can be more lenient on parsing inf and -inf values") + .booleanConf + .createWithDefault(false) + + val ENABLE_READ_CSV_DOUBLES = conf("spark.rapids.sql.csv.read.double.enabled") + .doc("Parsing CSV double has some issues at the min and max values for floating" + + "point numbers and can be more lenient on parsing inf and -inf values") + .booleanConf + .createWithDefault(false) + + val ENABLE_RANGE_WINDOW_BYTES = conf("spark.rapids.sql.window.range.byte.enabled") + .doc("When the order-by column of a range based window is byte type and " + + "the range boundary calculated for a value has overflow, CPU and GPU will get " + + "the different results. When set to false disables the range window acceleration for the " + + "byte type order-by column") + .booleanConf + .createWithDefault(false) + + val ENABLE_RANGE_WINDOW_SHORT = conf("spark.rapids.sql.window.range.short.enabled") + .doc("When the order-by column of a range based window is short type and " + + "the range boundary calculated for a value has overflow, CPU and GPU will get " + + "the different results. When set to false disables the range window acceleration for the " + + "short type order-by column") + .booleanConf + .createWithDefault(false) + + val ENABLE_RANGE_WINDOW_INT = conf("spark.rapids.sql.window.range.int.enabled") + .doc("When the order-by column of a range based window is int type and " + + "the range boundary calculated for a value has overflow, CPU and GPU will get " + + "the different results. 
When set to false disables the range window acceleration for the " + + "int type order-by column") + .booleanConf + .createWithDefault(true) + + val ENABLE_RANGE_WINDOW_LONG = conf("spark.rapids.sql.window.range.long.enabled") + .doc("When the order-by column of a range based window is long type and " + + "the range boundary calculated for a value has overflow, CPU and GPU will get " + + "the different results. When set to false disables the range window acceleration for the " + + "long type order-by column") + .booleanConf + .createWithDefault(true) + + // INTERNAL TEST AND DEBUG CONFIGS + + val TEST_CONF = conf("spark.rapids.sql.test.enabled") + .doc("Intended to be used by unit tests, if enabled all operations must run on the " + + "GPU or an error happens.") + .internal() + .booleanConf + .createWithDefault(false) + + val TEST_ALLOWED_NONGPU = conf("spark.rapids.sql.test.allowedNonGpu") + .doc("Comma separate string of exec or expression class names that are allowed " + + "to not be GPU accelerated for testing.") + .internal() + .stringConf + .toSequence + .createWithDefault(Nil) + + val TEST_VALIDATE_EXECS_ONGPU = conf("spark.rapids.sql.test.validateExecsInGpuPlan") + .doc("Comma separate string of exec class names to validate they " + + "are GPU accelerated. Used for testing.") + .internal() + .stringConf + .toSequence + .createWithDefault(Nil) + + val PARQUET_DEBUG_DUMP_PREFIX = conf("spark.rapids.sql.parquet.debug.dumpPrefix") + .doc("A path prefix where Parquet split file data is dumped for debugging.") + .internal() + .stringConf + .createWithDefault(null) + + val ORC_DEBUG_DUMP_PREFIX = conf("spark.rapids.sql.orc.debug.dumpPrefix") + .doc("A path prefix where ORC split file data is dumped for debugging.") + .internal() + .stringConf + .createWithDefault(null) + + val HASH_AGG_REPLACE_MODE = conf("spark.rapids.sql.hashAgg.replaceMode") + .doc("Only when hash aggregate exec has these modes (\"all\" by default): " + + "\"all\" (try to replace all aggregates, default), " + + "\"complete\" (exclusively replace complete aggregates), " + + "\"partial\" (exclusively replace partial aggregates), " + + "\"final\" (exclusively replace final aggregates)." + + " These modes can be connected with &(AND) or |(OR) to form sophisticated patterns.") + .internal() + .stringConf + .createWithDefault("all") + + val PARTIAL_MERGE_DISTINCT_ENABLED = conf("spark.rapids.sql.partialMerge.distinct.enabled") + .doc("Enables aggregates that are in PartialMerge mode to run on the GPU if true") + .internal() + .booleanConf + .createWithDefault(true) + + val SHUFFLE_MANAGER_ENABLED = conf("spark.rapids.shuffle.enabled") + .doc("Enable or disable the RAPIDS Shuffle Manager at runtime. " + + "The [RAPIDS Shuffle Manager](additional-functionality/rapids-shuffle.md) must " + + "already be configured. When set to `false`, the built-in Spark shuffle will be used. ") + .booleanConf + .createWithDefault(true) + + val SHUFFLE_TRANSPORT_ENABLE = conf("spark.rapids.shuffle.transport.enabled") + .doc("Enable the RAPIDS Shuffle Transport for accelerated shuffle. By default, this " + + "requires UCX to be installed in the system. Consider setting to false if running with " + + "a single executor and UCX is not available, for short-circuit cached shuffle " + + "(i.e. 
for testing purposes)") + .internal() + .booleanConf + .createWithDefault(true) + + val SHUFFLE_TRANSPORT_EARLY_START = conf("spark.rapids.shuffle.transport.earlyStart") + .doc("Enable early connection establishment for RAPIDS Shuffle") + .booleanConf + .createWithDefault(true) + + val SHUFFLE_TRANSPORT_EARLY_START_HEARTBEAT_INTERVAL = + conf("spark.rapids.shuffle.transport.earlyStart.heartbeatInterval") + .doc("Shuffle early start heartbeat interval (milliseconds). " + + "Executors will send a heartbeat RPC message to the driver at this interval") + .integerConf + .createWithDefault(5000) + + val SHUFFLE_TRANSPORT_EARLY_START_HEARTBEAT_TIMEOUT = + conf("spark.rapids.shuffle.transport.earlyStart.heartbeatTimeout") + .doc(s"Shuffle early start heartbeat timeout (milliseconds). " + + s"Executors that don't heartbeat within this timeout will be considered stale. " + + s"This timeout must be higher than the value for " + + s"${SHUFFLE_TRANSPORT_EARLY_START_HEARTBEAT_INTERVAL.key}") + .integerConf + .createWithDefault(10000) + + val SHUFFLE_TRANSPORT_CLASS_NAME = conf("spark.rapids.shuffle.transport.class") + .doc("The class of the specific RapidsShuffleTransport to use during the shuffle.") + .internal() + .stringConf + .createWithDefault("com.nvidia.spark.rapids.shuffle.ucx.UCXShuffleTransport") + + val SHUFFLE_TRANSPORT_MAX_RECEIVE_INFLIGHT_BYTES = + conf("spark.rapids.shuffle.transport.maxReceiveInflightBytes") + .doc("Maximum aggregate amount of bytes that be fetched at any given time from peers " + + "during shuffle") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024 * 1024 * 1024) + + val SHUFFLE_UCX_ACTIVE_MESSAGES_FORCE_RNDV = + conf("spark.rapids.shuffle.ucx.activeMessages.forceRndv") + .doc("Set to true to force 'rndv' mode for all UCX Active Messages. " + + "This should only be required with UCX 1.10.x. UCX 1.11.x deployments should " + + "set to false.") + .booleanConf + .createWithDefault(false) + + val SHUFFLE_UCX_USE_WAKEUP = conf("spark.rapids.shuffle.ucx.useWakeup") + .doc("When set to true, use UCX's event-based progress (epoll) in order to wake up " + + "the progress thread when needed, instead of a hot loop.") + .booleanConf + .createWithDefault(true) + + val SHUFFLE_UCX_LISTENER_START_PORT = conf("spark.rapids.shuffle.ucx.listenerStartPort") + .doc("Starting port to try to bind the UCX listener.") + .internal() + .integerConf + .createWithDefault(0) + + val SHUFFLE_UCX_MGMT_SERVER_HOST = conf("spark.rapids.shuffle.ucx.managementServerHost") + .doc("The host to be used to start the management server") + .stringConf + .createWithDefault(null) + + val SHUFFLE_UCX_MGMT_CONNECTION_TIMEOUT = + conf("spark.rapids.shuffle.ucx.managementConnectionTimeout") + .doc("The timeout for client connections to a remote peer") + .internal() + .integerConf + .createWithDefault(0) + + val SHUFFLE_UCX_BOUNCE_BUFFERS_SIZE = conf("spark.rapids.shuffle.ucx.bounceBuffers.size") + .doc("The size of bounce buffer to use in bytes. 
Note that this size will be the same " + + "for device and host memory") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(4 * 1024 * 1024) + + val SHUFFLE_UCX_BOUNCE_BUFFERS_DEVICE_COUNT = + conf("spark.rapids.shuffle.ucx.bounceBuffers.device.count") + .doc("The number of bounce buffers to pre-allocate from device memory") + .internal() + .integerConf + .createWithDefault(32) + + val SHUFFLE_UCX_BOUNCE_BUFFERS_HOST_COUNT = + conf("spark.rapids.shuffle.ucx.bounceBuffers.host.count") + .doc("The number of bounce buffers to pre-allocate from host memory") + .internal() + .integerConf + .createWithDefault(32) + + val SHUFFLE_MAX_CLIENT_THREADS = conf("spark.rapids.shuffle.maxClientThreads") + .doc("The maximum number of threads that the shuffle client should be allowed to start") + .internal() + .integerConf + .createWithDefault(50) + + val SHUFFLE_MAX_CLIENT_TASKS = conf("spark.rapids.shuffle.maxClientTasks") + .doc("The maximum number of tasks shuffle clients will queue before adding threads " + + s"(up to spark.rapids.shuffle.maxClientThreads), or slowing down the transport") + .internal() + .integerConf + .createWithDefault(100) + + val SHUFFLE_CLIENT_THREAD_KEEPALIVE = conf("spark.rapids.shuffle.clientThreadKeepAlive") + .doc("The number of seconds that the ThreadPoolExecutor will allow an idle client " + + "shuffle thread to stay alive, before reclaiming.") + .internal() + .integerConf + .createWithDefault(30) + + val SHUFFLE_MAX_SERVER_TASKS = conf("spark.rapids.shuffle.maxServerTasks") + .doc("The maximum number of tasks the shuffle server will queue up for its thread") + .internal() + .integerConf + .createWithDefault(1000) + + val SHUFFLE_MAX_METADATA_SIZE = conf("spark.rapids.shuffle.maxMetadataSize") + .doc("The maximum size of a metadata message that the shuffle plugin will keep in its " + + "direct message pool. ") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(500 * 1024) + + val SHUFFLE_COMPRESSION_CODEC = conf("spark.rapids.shuffle.compression.codec") + .doc("The GPU codec used to compress shuffle data when using RAPIDS shuffle. " + + "Supported codecs: lz4, copy, none") + .internal() + .stringConf + .createWithDefault("none") + + val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.lz4.chunkSize") + .doc("A configurable chunk size to use when compressing with LZ4.") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(64 * 1024) + + // ALLUXIO CONFIGS + + val ALLUXIO_PATHS_REPLACE = conf("spark.rapids.alluxio.pathsToReplace") + .doc("List of paths to be replaced with corresponding alluxio scheme. Eg, when configure" + + "is set to \"s3:/foo->alluxio://0.1.2.3:19998/foo,gcs:/bar->alluxio://0.1.2.3:19998/bar\", " + + "which means: " + + " s3:/foo/a.csv will be replaced to alluxio://0.1.2.3:19998/foo/a.csv and " + + " gcs:/bar/b.csv will be replaced to alluxio://0.1.2.3:19998/bar/b.csv") + .stringConf + .toSequence + .createOptional + + // USER FACING DEBUG CONFIGS + + val SHUFFLE_COMPRESSION_MAX_BATCH_MEMORY = + conf("spark.rapids.shuffle.compression.maxBatchMemory") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024 * 1024 * 1024) + + val EXPLAIN = conf("spark.rapids.sql.explain") + .doc("Explain why some parts of a query were not placed on a GPU or not. 
Possible " + + "values are ALL: print everything, NONE: print nothing, NOT_ON_GPU: print only parts of " + + "a query that did not go on the GPU") + .stringConf + .createWithDefault("NONE") + + val SHIMS_PROVIDER_OVERRIDE = conf("spark.rapids.shims-provider-override") + .internal() + .doc("Overrides the automatic Spark shim detection logic and forces a specific shims " + + "provider class to be used. Set to the fully qualified shims provider class to use. " + + "If you are using a custom Spark version such as Spark 3.0.1.0 then this can be used to " + + "specify the shims provider that matches the base Spark version of Spark 3.0.1, i.e.: " + + "com.nvidia.spark.rapids.shims.spark301.SparkShimServiceProvider. If you modified Spark " + + "then there is no guarantee the RAPIDS Accelerator will function properly." + + "When tested in a combined jar with other Shims, it's expected that the provided " + + "implementation follows the same convention as existing Spark shims. If its class" + + " name has the form com.nvidia.spark.rapids.shims..YourSparkShimServiceProvider. " + + "The last package name component, i.e., shimId, can be used in the combined jar as the root" + + " directory /shimId for any incompatible classes. When tested in isolation, no special " + + "jar root is required" + ) + .stringConf + .createOptional + + val CUDF_VERSION_OVERRIDE = conf("spark.rapids.cudfVersionOverride") + .internal() + .doc("Overrides the cudf version compatibility check between cudf jar and RAPIDS Accelerator " + + "jar. If you are sure that the cudf jar which is mentioned in the classpath is compatible " + + "with the RAPIDS Accelerator version, then set this to true.") + .booleanConf + .createWithDefault(false) + + val ALLOW_DISABLE_ENTIRE_PLAN = conf("spark.rapids.allowDisableEntirePlan") + .internal() + .doc("The plugin has the ability to detect possibe incompatibility with some specific " + + "queries and cluster configurations. In those cases the plugin will disable GPU support " + + "for the entire query. Set this to false if you want to override that behavior, but use " + + "with caution.") + .booleanConf + .createWithDefault(true) + + val OPTIMIZER_ENABLED = conf("spark.rapids.sql.optimizer.enabled") + .internal() + .doc("Enable cost-based optimizer that will attempt to avoid " + + "transitions to GPU for operations that will not result in improved performance " + + "over CPU") + .booleanConf + .createWithDefault(false) + + val OPTIMIZER_EXPLAIN = conf("spark.rapids.sql.optimizer.explain") + .internal() + .doc("Explain why some parts of a query were not placed on a GPU due to " + + "optimization rules. Possible values are ALL: print everything, NONE: print nothing") + .stringConf + .createWithDefault("NONE") + + val OPTIMIZER_DEFAULT_ROW_COUNT = conf("spark.rapids.sql.optimizer.defaultRowCount") + .internal() + .doc("The cost-based optimizer uses estimated row counts to calculate costs and sometimes " + + "there is no row count available so we need a default assumption to use in this case") + .longConf + .createWithDefault(1000000) + + val OPTIMIZER_CLASS_NAME = conf("spark.rapids.sql.optimizer.className") + .internal() + .doc("Optimizer implementation class name. 
The class must implement the " + + "com.nvidia.spark.rapids.Optimizer trait") + .stringConf + .createWithDefault("com.nvidia.spark.rapids.CostBasedOptimizer") + + val OPTIMIZER_DEFAULT_CPU_OPERATOR_COST = conf("spark.rapids.sql.optimizer.cpu.exec.default") + .internal() + .doc("Default per-row CPU cost of executing an operator, in seconds") + .doubleConf + .createWithDefault(0.0002) + + val OPTIMIZER_DEFAULT_CPU_EXPRESSION_COST = conf("spark.rapids.sql.optimizer.cpu.expr.default") + .internal() + .doc("Default per-row CPU cost of evaluating an expression, in seconds") + .doubleConf + .createWithDefault(0.0) + + val OPTIMIZER_DEFAULT_GPU_OPERATOR_COST = conf("spark.rapids.sql.optimizer.gpu.exec.default") + .internal() + .doc("Default per-row GPU cost of executing an operator, in seconds") + .doubleConf + .createWithDefault(0.0001) + + val OPTIMIZER_DEFAULT_GPU_EXPRESSION_COST = conf("spark.rapids.sql.optimizer.gpu.expr.default") + .internal() + .doc("Default per-row GPU cost of evaluating an expression, in seconds") + .doubleConf + .createWithDefault(0.0) + + val OPTIMIZER_CPU_READ_SPEED = conf( + "spark.rapids.sql.optimizer.cpuReadSpeed") + .internal() + .doc("Speed of reading data from CPU memory in GB/s") + .doubleConf + .createWithDefault(30.0) + + val OPTIMIZER_CPU_WRITE_SPEED = conf( + "spark.rapids.sql.optimizer.cpuWriteSpeed") + .internal() + .doc("Speed of writing data to CPU memory in GB/s") + .doubleConf + .createWithDefault(30.0) + + val OPTIMIZER_GPU_READ_SPEED = conf( + "spark.rapids.sql.optimizer.gpuReadSpeed") + .internal() + .doc("Speed of reading data from GPU memory in GB/s") + .doubleConf + .createWithDefault(320.0) + + val OPTIMIZER_GPU_WRITE_SPEED = conf( + "spark.rapids.sql.optimizer.gpuWriteSpeed") + .internal() + .doc("Speed of writing data to GPU memory in GB/s") + .doubleConf + .createWithDefault(320.0) + + val USE_ARROW_OPT = conf("spark.rapids.arrowCopyOptimizationEnabled") + .doc("Option to turn off using the optimized Arrow copy code when reading from " + + "ArrowColumnVector in HostColumnarToGpu. Left as internal as user shouldn't " + + "have to turn it off, but its convenient for testing.") + .internal() + .booleanConf + .createWithDefault(true) + + val FORCE_SHIMCALLER_CLASSLOADER = conf("spark.rapids.force.caller.classloader") + .doc("Option to statically add shim's parallel world classloader URLs to " + + "the classloader of the ShimLoader class, typically Bootstrap classloader. This option" + + " uses reflection with setAccessible true on a classloader that is not created by Spark.") + .internal() + .booleanConf + .createWithDefault(value = true) + + val SPARK_GPU_RESOURCE_NAME = conf("spark.rapids.gpu.resourceName") + .doc("The name of the Spark resource that represents a GPU that you want the plugin to use " + + "if using custom resources with Spark.") + .stringConf + .createWithDefault("gpu") + + val SUPPRESS_PLANNING_FAILURE = conf("spark.rapids.sql.suppressPlanningFailure") + .doc("Option to fallback an individual query to CPU if an unexpected condition prevents the " + + "query plan from being converted to a GPU-enabled one. Note this is different from " + + "a normal CPU fallback for a yet-to-be-supported Spark SQL feature. If this happens " + + "the error should be reported and investigated as a GitHub issue.") + .booleanConf + .createWithDefault(value = false) + + val ENABLE_FAST_SAMPLE = conf("spark.rapids.sql.fast.sample") + .doc("Option to turn on fast sample. 
If enable it is inconsistent with CPU sample " + + "because of GPU sample algorithm is inconsistent with CPU.") + .booleanConf + .createWithDefault(value = false) + + private def printSectionHeader(category: String): Unit = + println(s"\n### $category") + + private def printToggleHeader(category: String): Unit = { + printSectionHeader(category) + println("Name | Description | Default Value | Notes") + println("-----|-------------|---------------|------------------") + } + + private def printToggleHeaderWithSqlFunction(category: String): Unit = { + printSectionHeader(category) + println("Name | SQL Function(s) | Description | Default Value | Notes") + println("-----|-----------------|-------------|---------------|------") + } + + def help(asTable: Boolean = false): Unit = { + if (asTable) { + println("---") + println("layout: page") + println("title: Configuration") + println("nav_order: 4") + println("---") + println(s"") + // scalastyle:off line.size.limit + println("""# RAPIDS Accelerator for Apache Spark Configuration + |The following is the list of options that `rapids-plugin-4-spark` supports. + | + |On startup use: `--conf [conf key]=[conf value]`. For example: + | + |``` + |$SPARK_HOME/bin/spark --jars 'rapids-4-spark_2.12-22.02.0-SNAPSHOT.jar,cudf-22.02.0-SNAPSHOT-cuda11.jar' \ + |--conf spark.plugins=com.nvidia.spark.SQLPlugin \ + |--conf spark.rapids.sql.incompatibleOps.enabled=true + |``` + | + |At runtime use: `spark.conf.set("[conf key]", [conf value])`. For example: + | + |``` + |scala> spark.conf.set("spark.rapids.sql.incompatibleOps.enabled", true) + |``` + | + | All configs can be set on startup, but some configs, especially for shuffle, will not + | work if they are set at runtime. + |""".stripMargin) + // scalastyle:on line.size.limit + + println("\n## General Configuration\n") + println("Name | Description | Default Value") + println("-----|-------------|--------------") + } else { + println("Rapids Configs:") + } + registeredConfs.sortBy(_.key).foreach(_.help(asTable)) + if (asTable) { + println("") + // scalastyle:off line.size.limit + println("""## Supported GPU Operators and Fine Tuning + |_The RAPIDS Accelerator for Apache Spark_ can be configured to enable or disable specific + |GPU accelerated expressions. Enabled expressions are candidates for GPU execution. If the + |expression is configured as disabled, the accelerator plugin will not attempt replacement, + |and it will run on the CPU. + | + |Please leverage the [`spark.rapids.sql.explain`](#sql.explain) setting to get + |feedback from the plugin as to why parts of a query may not be executing on the GPU. 
+ | + |**NOTE:** Setting + |[`spark.rapids.sql.incompatibleOps.enabled=true`](#sql.incompatibleOps.enabled) + |will enable all the settings in the table below which are not enabled by default due to + |incompatibilities.""".stripMargin) + // scalastyle:on line.size.limit + + printToggleHeaderWithSqlFunction("Expressions\n") + } + GpuOverrides.expressions.values.toSeq.sortBy(_.tag.toString).foreach { rule => + val sqlFunctions = + ConfHelper.getSqlFunctionsForClass(rule.tag.runtimeClass).map(_.mkString(", ")) + + // this is only for formatting, this is done to ensure the table has a column for a + // row where there isn't a SQL function + rule.confHelp(asTable, Some(sqlFunctions.getOrElse(" "))) + } + if (asTable) { + printToggleHeader("Execution\n") + } + GpuOverrides.execs.values.toSeq.sortBy(_.tag.toString).foreach(_.confHelp(asTable)) + if (asTable) { + printToggleHeader("Partitioning\n") + } + GpuOverrides.parts.values.toSeq.sortBy(_.tag.toString).foreach(_.confHelp(asTable)) + } + def main(args: Array[String]): Unit = { + // Include the configs in PythonConfEntries + // com.nvidia.spark.rapids.python.PythonConfEntries.init() + val out = new FileOutputStream(new File(args(0))) + Console.withOut(out) { + Console.withErr(out) { + RapidsConf.help(true) + } + } + } +} + +class RapidsConf(conf: Map[String, String]) extends Logging { + + import ConfHelper._ + import RapidsConf._ + + def this(sqlConf: SQLConf) = { + this(sqlConf.getAllConfs) + } + + def this(sparkConf: SparkConf) = { + this(Map(sparkConf.getAll: _*)) + } + + def get[T](entry: ConfEntry[T]): T = { + entry.get(conf) + } + + lazy val rapidsConfMap: util.Map[String, String] = conf.filterKeys( + _.startsWith("spark.rapids.")).asJava + + lazy val metricsLevel: String = get(METRICS_LEVEL) + + lazy val isSqlEnabled: Boolean = get(SQL_ENABLED) + + lazy val isSqlExecuteOnGPU: Boolean = get(SQL_MODE).equals("executeongpu") + + lazy val isSqlExplainOnlyEnabled: Boolean = get(SQL_MODE).equals("explainonly") + + lazy val isUdfCompilerEnabled: Boolean = get(UDF_COMPILER_ENABLED) + + lazy val exportColumnarRdd: Boolean = get(EXPORT_COLUMNAR_RDD) + + lazy val stableSort: Boolean = get(STABLE_SORT) + + lazy val isIncompatEnabled: Boolean = get(INCOMPATIBLE_OPS) + + lazy val incompatDateFormats: Boolean = get(INCOMPATIBLE_DATE_FORMATS) + + lazy val includeImprovedFloat: Boolean = get(IMPROVED_FLOAT_OPS) + + lazy val pinnedPoolSize: Long = get(PINNED_POOL_SIZE) + + lazy val pageablePoolSize: Long = get(PAGEABLE_POOL_SIZE) + + lazy val concurrentGpuTasks: Int = get(CONCURRENT_GPU_TASKS) + + lazy val isTestEnabled: Boolean = get(TEST_CONF) + + lazy val testingAllowedNonGpu: Seq[String] = get(TEST_ALLOWED_NONGPU) + + lazy val validateExecsInGpuPlan: Seq[String] = get(TEST_VALIDATE_EXECS_ONGPU) + + lazy val rmmDebugLocation: String = get(RMM_DEBUG) + + lazy val gpuOomDumpDir: Option[String] = get(GPU_OOM_DUMP_DIR) + + lazy val isUvmEnabled: Boolean = get(UVM_ENABLED) + + lazy val isPooledMemEnabled: Boolean = get(POOLED_MEM) + + lazy val rmmPool: String = get(RMM_POOL) + + lazy val rmmAllocFraction: Double = get(RMM_ALLOC_FRACTION) + + lazy val rmmAllocMaxFraction: Double = get(RMM_ALLOC_MAX_FRACTION) + + lazy val rmmAllocMinFraction: Double = get(RMM_ALLOC_MIN_FRACTION) + + lazy val rmmAllocReserve: Long = get(RMM_ALLOC_RESERVE) + + lazy val hostSpillStorageSize: Long = get(HOST_SPILL_STORAGE_SIZE) + + lazy val isUnspillEnabled: Boolean = get(UNSPILL) + + lazy val isGdsSpillEnabled: Boolean = get(GDS_SPILL) + + lazy val 
gdsSpillBatchWriteBufferSize: Long = get(GDS_SPILL_BATCH_WRITE_BUFFER_SIZE) + + lazy val hasNans: Boolean = get(HAS_NANS) + + lazy val needDecimalGuarantees: Boolean = get(NEED_DECIMAL_OVERFLOW_GUARANTEES) + + lazy val gpuTargetBatchSizeBytes: Long = get(GPU_BATCH_SIZE_BYTES) + + lazy val isFloatAggEnabled: Boolean = get(ENABLE_FLOAT_AGG) + + lazy val explain: String = get(EXPLAIN) + + lazy val shouldExplain: Boolean = !explain.equalsIgnoreCase("NONE") + + lazy val shouldExplainAll: Boolean = explain.equalsIgnoreCase("ALL") + + lazy val isImprovedTimestampOpsEnabled: Boolean = get(IMPROVED_TIMESTAMP_OPS) + + lazy val maxReadBatchSizeRows: Int = get(MAX_READER_BATCH_SIZE_ROWS) + + lazy val maxReadBatchSizeBytes: Long = get(MAX_READER_BATCH_SIZE_BYTES) + + lazy val parquetDebugDumpPrefix: String = get(PARQUET_DEBUG_DUMP_PREFIX) + + lazy val orcDebugDumpPrefix: String = get(ORC_DEBUG_DUMP_PREFIX) + + lazy val hashAggReplaceMode: String = get(HASH_AGG_REPLACE_MODE) + + lazy val partialMergeDistinctEnabled: Boolean = get(PARTIAL_MERGE_DISTINCT_ENABLED) + + lazy val enableReplaceSortMergeJoin: Boolean = get(ENABLE_REPLACE_SORTMERGEJOIN) + + lazy val enableHashOptimizeSort: Boolean = get(ENABLE_HASH_OPTIMIZE_SORT) + + lazy val areInnerJoinsEnabled: Boolean = get(ENABLE_INNER_JOIN) + + lazy val areCrossJoinsEnabled: Boolean = get(ENABLE_CROSS_JOIN) + + lazy val areLeftOuterJoinsEnabled: Boolean = get(ENABLE_LEFT_OUTER_JOIN) + + lazy val areRightOuterJoinsEnabled: Boolean = get(ENABLE_RIGHT_OUTER_JOIN) + + lazy val areFullOuterJoinsEnabled: Boolean = get(ENABLE_FULL_OUTER_JOIN) + + lazy val areLeftSemiJoinsEnabled: Boolean = get(ENABLE_LEFT_SEMI_JOIN) + + lazy val areLeftAntiJoinsEnabled: Boolean = get(ENABLE_LEFT_ANTI_JOIN) + + lazy val isCastDecimalToFloatEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_FLOAT) + + lazy val isCastFloatToDecimalEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_DECIMAL) + + lazy val isCastFloatToStringEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_STRING) + + lazy val isCastStringToTimestampEnabled: Boolean = get(ENABLE_CAST_STRING_TO_TIMESTAMP) + + lazy val hasExtendedYearValues: Boolean = get(HAS_EXTENDED_YEAR_VALUES) + + lazy val isCastStringToFloatEnabled: Boolean = get(ENABLE_CAST_STRING_TO_FLOAT) + + lazy val isCastFloatToIntegralTypesEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES) + + lazy val isCsvTimestampReadEnabled: Boolean = get(ENABLE_CSV_TIMESTAMPS) + + lazy val isCsvDateReadEnabled: Boolean = get(ENABLE_READ_CSV_DATES) + + lazy val isCsvBoolReadEnabled: Boolean = get(ENABLE_READ_CSV_BOOLS) + + lazy val isCsvByteReadEnabled: Boolean = get(ENABLE_READ_CSV_BYTES) + + lazy val isCsvShortReadEnabled: Boolean = get(ENABLE_READ_CSV_SHORTS) + + lazy val isCsvIntReadEnabled: Boolean = get(ENABLE_READ_CSV_INTEGERS) + + lazy val isCsvLongReadEnabled: Boolean = get(ENABLE_READ_CSV_LONGS) + + lazy val isCsvFloatReadEnabled: Boolean = get(ENABLE_READ_CSV_FLOATS) + + lazy val isCsvDoubleReadEnabled: Boolean = get(ENABLE_READ_CSV_DOUBLES) + + lazy val isCastDecimalToStringEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_STRING) + + lazy val isProjectAstEnabled: Boolean = get(ENABLE_PROJECT_AST) + + lazy val isParquetEnabled: Boolean = get(ENABLE_PARQUET) + + lazy val isParquetInt96WriteEnabled: Boolean = get(ENABLE_PARQUET_INT96_WRITE) + + lazy val isParquetPerFileReadEnabled: Boolean = + ParquetReaderType.withName(get(PARQUET_READER_TYPE)) == ParquetReaderType.PERFILE + + lazy val isParquetAutoReaderEnabled: Boolean = + 
ParquetReaderType.withName(get(PARQUET_READER_TYPE)) == ParquetReaderType.AUTO + + lazy val isParquetCoalesceFileReadEnabled: Boolean = isParquetAutoReaderEnabled || + ParquetReaderType.withName(get(PARQUET_READER_TYPE)) == ParquetReaderType.COALESCING + + lazy val isParquetMultiThreadReadEnabled: Boolean = isParquetAutoReaderEnabled || + ParquetReaderType.withName(get(PARQUET_READER_TYPE)) == ParquetReaderType.MULTITHREADED + + lazy val parquetMultiThreadReadNumThreads: Int = get(PARQUET_MULTITHREAD_READ_NUM_THREADS) + + lazy val maxNumParquetFilesParallel: Int = get(PARQUET_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL) + + lazy val isParquetReadEnabled: Boolean = get(ENABLE_PARQUET_READ) + + lazy val isParquetWriteEnabled: Boolean = get(ENABLE_PARQUET_WRITE) + + lazy val isOrcEnabled: Boolean = get(ENABLE_ORC) + + lazy val isOrcReadEnabled: Boolean = get(ENABLE_ORC_READ) + + lazy val isOrcWriteEnabled: Boolean = get(ENABLE_ORC_WRITE) + + lazy val isOrcPerFileReadEnabled: Boolean = + OrcReaderType.withName(get(ORC_READER_TYPE)) == OrcReaderType.PERFILE + + lazy val isOrcAutoReaderEnabled: Boolean = + OrcReaderType.withName(get(ORC_READER_TYPE)) == OrcReaderType.AUTO + + lazy val isOrcCoalesceFileReadEnabled: Boolean = isOrcAutoReaderEnabled || + OrcReaderType.withName(get(ORC_READER_TYPE)) == OrcReaderType.COALESCING + + lazy val isOrcMultiThreadReadEnabled: Boolean = isOrcAutoReaderEnabled || + OrcReaderType.withName(get(ORC_READER_TYPE)) == OrcReaderType.MULTITHREADED + + lazy val orcMultiThreadReadNumThreads: Int = get(ORC_MULTITHREAD_READ_NUM_THREADS) + + lazy val maxNumOrcFilesParallel: Int = get(ORC_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL) + + lazy val isCsvEnabled: Boolean = get(ENABLE_CSV) + + lazy val isCsvReadEnabled: Boolean = get(ENABLE_CSV_READ) + + lazy val shuffleManagerEnabled: Boolean = get(SHUFFLE_MANAGER_ENABLED) + + lazy val shuffleTransportEnabled: Boolean = get(SHUFFLE_TRANSPORT_ENABLE) + + lazy val shuffleTransportClassName: String = get(SHUFFLE_TRANSPORT_CLASS_NAME) + + lazy val shuffleTransportEarlyStartHeartbeatInterval: Int = get( + SHUFFLE_TRANSPORT_EARLY_START_HEARTBEAT_INTERVAL) + + lazy val shuffleTransportEarlyStartHeartbeatTimeout: Int = get( + SHUFFLE_TRANSPORT_EARLY_START_HEARTBEAT_TIMEOUT) + + lazy val shuffleTransportEarlyStart: Boolean = get(SHUFFLE_TRANSPORT_EARLY_START) + + lazy val shuffleTransportMaxReceiveInflightBytes: Long = get( + SHUFFLE_TRANSPORT_MAX_RECEIVE_INFLIGHT_BYTES) + + lazy val shuffleUcxActiveMessagesForceRndv: Boolean = get(SHUFFLE_UCX_ACTIVE_MESSAGES_FORCE_RNDV) + + lazy val shuffleUcxUseWakeup: Boolean = get(SHUFFLE_UCX_USE_WAKEUP) + + lazy val shuffleUcxListenerStartPort: Int = get(SHUFFLE_UCX_LISTENER_START_PORT) + + lazy val shuffleUcxMgmtHost: String = get(SHUFFLE_UCX_MGMT_SERVER_HOST) + + lazy val shuffleUcxMgmtConnTimeout: Int = get(SHUFFLE_UCX_MGMT_CONNECTION_TIMEOUT) + + lazy val shuffleUcxBounceBuffersSize: Long = get(SHUFFLE_UCX_BOUNCE_BUFFERS_SIZE) + + lazy val shuffleUcxDeviceBounceBuffersCount: Int = get(SHUFFLE_UCX_BOUNCE_BUFFERS_DEVICE_COUNT) + + lazy val shuffleUcxHostBounceBuffersCount: Int = get(SHUFFLE_UCX_BOUNCE_BUFFERS_HOST_COUNT) + + lazy val shuffleMaxClientThreads: Int = get(SHUFFLE_MAX_CLIENT_THREADS) + + lazy val shuffleMaxClientTasks: Int = get(SHUFFLE_MAX_CLIENT_TASKS) + + lazy val shuffleClientThreadKeepAliveTime: Int = get(SHUFFLE_CLIENT_THREAD_KEEPALIVE) + + lazy val shuffleMaxServerTasks: Int = get(SHUFFLE_MAX_SERVER_TASKS) + + lazy val shuffleMaxMetadataSize: Long = get(SHUFFLE_MAX_METADATA_SIZE) + + 
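
The typed accessors in this class are thin wrappers that look the entries defined above up in a plain string-to-string map, which is why the same `RapidsConf` can be driven from a `SQLConf`, a `SparkConf`, or a raw `Map`. As a minimal, illustrative sketch (not part of this patch; the specific values and the `myscheme` entry are made up for the example), the pattern could be exercised like this:

```scala
import org.apache.spark.SparkConf
import com.nvidia.spark.rapids.RapidsConf

// Build a SparkConf carrying a few of the entries defined above. The keys are
// ordinary Spark properties, so they could equally come from --conf at launch.
val sparkConf = new SparkConf()
  .set("spark.rapids.sql.format.parquet.reader.type", "MULTITHREADED")
  .set("spark.rapids.sql.format.parquet.multiThreadedRead.numThreads", "40")
  .set("spark.rapids.cloudSchemes", "myscheme") // appended to the built-in cloud schemes

// The auxiliary constructor copies sparkConf.getAll into the backing map.
val rapidsConf = new RapidsConf(sparkConf)

// The lazy vals resolve the raw strings into typed values.
assert(rapidsConf.isParquetMultiThreadReadEnabled)        // MULTITHREADED selects this reader
assert(rapidsConf.parquetMultiThreadReadNumThreads == 40) // integerConf parses "40"
assert(rapidsConf.getCloudSchemes.contains("myscheme"))   // user schemes extend the defaults
```

Keeping the lookups keyed by plain strings is what lets the explain-only Spark 2.x build and the main plugin share the same configuration surface without depending on Spark-version-specific config machinery.
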
lazy val shuffleCompressionCodec: String = get(SHUFFLE_COMPRESSION_CODEC) + + lazy val shuffleCompressionLz4ChunkSize: Long = get(SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE) + + lazy val shuffleCompressionMaxBatchMemory: Long = get(SHUFFLE_COMPRESSION_MAX_BATCH_MEMORY) + + lazy val shimsProviderOverride: Option[String] = get(SHIMS_PROVIDER_OVERRIDE) + + lazy val cudfVersionOverride: Boolean = get(CUDF_VERSION_OVERRIDE) + + lazy val allowDisableEntirePlan: Boolean = get(ALLOW_DISABLE_ENTIRE_PLAN) + + lazy val useArrowCopyOptimization: Boolean = get(USE_ARROW_OPT) + + lazy val getCloudSchemes: Seq[String] = + DEFAULT_CLOUD_SCHEMES ++ get(CLOUD_SCHEMES).getOrElse(Seq.empty) + + lazy val optimizerEnabled: Boolean = get(OPTIMIZER_ENABLED) + + lazy val optimizerExplain: String = get(OPTIMIZER_EXPLAIN) + + lazy val optimizerShouldExplainAll: Boolean = optimizerExplain.equalsIgnoreCase("ALL") + + lazy val optimizerClassName: String = get(OPTIMIZER_CLASS_NAME) + + lazy val defaultRowCount: Long = get(OPTIMIZER_DEFAULT_ROW_COUNT) + + lazy val defaultCpuOperatorCost: Double = get(OPTIMIZER_DEFAULT_CPU_OPERATOR_COST) + + lazy val defaultCpuExpressionCost: Double = get(OPTIMIZER_DEFAULT_CPU_EXPRESSION_COST) + + lazy val defaultGpuOperatorCost: Double = get(OPTIMIZER_DEFAULT_GPU_OPERATOR_COST) + + lazy val defaultGpuExpressionCost: Double = get(OPTIMIZER_DEFAULT_GPU_EXPRESSION_COST) + + lazy val cpuReadMemorySpeed: Double = get(OPTIMIZER_CPU_READ_SPEED) + + lazy val cpuWriteMemorySpeed: Double = get(OPTIMIZER_CPU_WRITE_SPEED) + + lazy val gpuReadMemorySpeed: Double = get(OPTIMIZER_GPU_READ_SPEED) + + lazy val gpuWriteMemorySpeed: Double = get(OPTIMIZER_GPU_WRITE_SPEED) + + lazy val getAlluxioPathsToReplace: Option[Seq[String]] = get(ALLUXIO_PATHS_REPLACE) + + lazy val driverTimeZone: Option[String] = get(DRIVER_TIMEZONE) + + lazy val isRangeWindowByteEnabled: Boolean = get(ENABLE_RANGE_WINDOW_BYTES) + + lazy val isRangeWindowShortEnabled: Boolean = get(ENABLE_RANGE_WINDOW_SHORT) + + lazy val isRangeWindowIntEnabled: Boolean = get(ENABLE_RANGE_WINDOW_INT) + + lazy val isRangeWindowLongEnabled: Boolean = get(ENABLE_RANGE_WINDOW_LONG) + + lazy val getSparkGpuResourceName: String = get(SPARK_GPU_RESOURCE_NAME) + + lazy val isCpuBasedUDFEnabled: Boolean = get(ENABLE_CPU_BASED_UDF) + + lazy val isFastSampleEnabled: Boolean = get(ENABLE_FAST_SAMPLE) + + private val optimizerDefaults = Map( + // this is not accurate because CPU projections do have a cost due to appending values + // to each row that is produced, but this needs to be a really small number because + // GpuProject cost is zero (in our cost model) and we don't want to encourage moving to + // the GPU just to do a trivial projection, so we pretend the overhead of a + // CPU projection (beyond evaluating the expressions) is also zero + "spark.rapids.sql.optimizer.cpu.exec.ProjectExec" -> "0", + // The cost of a GPU projection is mostly the cost of evaluating the expressions + // to produce the projected columns + "spark.rapids.sql.optimizer.gpu.exec.ProjectExec" -> "0", + // union does not further process data produced by its children + "spark.rapids.sql.optimizer.cpu.exec.UnionExec" -> "0", + "spark.rapids.sql.optimizer.gpu.exec.UnionExec" -> "0" + ) + + def isOperatorEnabled(key: String, incompat: Boolean, isDisabledByDefault: Boolean): Boolean = { + val default = !(isDisabledByDefault || incompat) || (incompat && isIncompatEnabled) + conf.get(key).map(toBoolean(_, key)).getOrElse(default) + } + + /** + * Get the GPU cost of an expression, for use in the 
cost-based optimizer. + */ + def getGpuExpressionCost(operatorName: String): Option[Double] = { + val key = s"spark.rapids.sql.optimizer.gpu.expr.$operatorName" + getOptionalCost(key) + } + + /** + * Get the GPU cost of an operator, for use in the cost-based optimizer. + */ + def getGpuOperatorCost(operatorName: String): Option[Double] = { + val key = s"spark.rapids.sql.optimizer.gpu.exec.$operatorName" + getOptionalCost(key) + } + + /** + * Get the CPU cost of an expression, for use in the cost-based optimizer. + */ + def getCpuExpressionCost(operatorName: String): Option[Double] = { + val key = s"spark.rapids.sql.optimizer.cpu.expr.$operatorName" + getOptionalCost(key) + } + + /** + * Get the CPU cost of an operator, for use in the cost-based optimizer. + */ + def getCpuOperatorCost(operatorName: String): Option[Double] = { + val key = s"spark.rapids.sql.optimizer.cpu.exec.$operatorName" + getOptionalCost(key) + } + + private def getOptionalCost(key: String) = { + // user-provided value takes precedence, then look in defaults map + conf.get(key).orElse(optimizerDefaults.get(key)).map(toDouble(_, key)) + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala new file mode 100644 index 00000000000..6014a21396f --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -0,0 +1,1148 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.time.ZoneId + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.types.DataType + +trait DataFromReplacementRule { + val operationName: String + def incompatDoc: Option[String] = None + def disabledMsg: Option[String] = None + + def confKey: String + + def getChecks: Option[TypeChecks[_]] +} + +/** + * A version of DataFromReplacementRule that is used when no replacement rule can be found. 
+ */ +final class NoRuleDataFromReplacementRule extends DataFromReplacementRule { + override val operationName: String = "" + + override def confKey = "NOT_FOUND" + + override def getChecks: Option[TypeChecks[_]] = None +} + +object RapidsMeta { +} + +/** + * Holds metadata about a stage in the physical plan that is separate from the plan itself. + * This is helpful in deciding when to replace part of the plan with a GPU enabled version. + * + * @param wrapped what we are wrapping + * @param conf the config + * @param parent the parent of this node, if there is one. + * @param rule holds information related to the config for this object, typically this is the rule + * used to wrap the stage. + * @tparam INPUT the exact type of the class we are wrapping. + * @tparam BASE the generic base class for this type of stage, i.e. SparkPlan, Expression, etc. + */ +abstract class RapidsMeta[INPUT <: BASE, BASE]( + val wrapped: INPUT, + val conf: RapidsConf, + val parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) { + + /** + * The wrapped plans that should be examined + */ + val childPlans: Seq[SparkPlanMeta[_]] + + /** + * The wrapped expressions that should be examined + */ + val childExprs: Seq[BaseExprMeta[_]] + + /** + * The wrapped partitioning that should be examined + */ + val childParts: Seq[PartMeta[_]] + + /** + * The wrapped data writing commands that should be examined + */ + val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] + + /** + * Keep this on the CPU, but possibly convert its children under it to run on the GPU if enabled. + * By default this just returns what is wrapped by this. For some types of operators/stages, + * like SparkPlan, each part of the query can be converted independent of other parts. As such in + * a subclass this should be overridden to do the correct thing. + */ + def convertToCpu(): BASE = wrapped + + protected var cannotBeReplacedReasons: Option[mutable.Set[String]] = None + private var mustBeReplacedReasons: Option[mutable.Set[String]] = None + private var cannotReplaceAnyOfPlanReasons: Option[mutable.Set[String]] = None + private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None + private var typeConversionReasons: Option[mutable.Set[String]] = None + protected var cannotRunOnGpuBecauseOfSparkPlan: Boolean = false + protected var cannotRunOnGpuBecauseOfCost: Boolean = false + + + /** + * Recursively force a section of the plan back onto CPU, stopping once a plan + * is reached that is already on CPU. 
+ */ + final def recursiveCostPreventsRunningOnGpu(): Unit = { + if (canThisBeReplaced && !mustThisBeReplaced) { + costPreventsRunningOnGpu() + childDataWriteCmds.foreach(_.recursiveCostPreventsRunningOnGpu()) + } + } + + final def costPreventsRunningOnGpu(): Unit = { + cannotRunOnGpuBecauseOfCost = true + willNotWorkOnGpu("Removed by cost-based optimizer") + childExprs.foreach(_.recursiveCostPreventsRunningOnGpu()) + childParts.foreach(_.recursiveCostPreventsRunningOnGpu()) + } + + final def recursiveSparkPlanPreventsRunningOnGpu(): Unit = { + cannotRunOnGpuBecauseOfSparkPlan = true + childExprs.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + childParts.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + childDataWriteCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + } + + final def recursiveSparkPlanRemoved(): Unit = { + shouldBeRemoved("parent plan is removed") + childExprs.foreach(_.recursiveSparkPlanRemoved()) + childParts.foreach(_.recursiveSparkPlanRemoved()) + childDataWriteCmds.foreach(_.recursiveSparkPlanRemoved()) + } + + final def inputFilePreventsRunningOnGpu(): Unit = { + if (canThisBeReplaced) { + willNotWorkOnGpu("Removed by InputFileBlockRule preventing plans " + + "[SparkPlan(with input_file_xxx), FileScan) running on GPU") + } + } + + /** + * Call this to indicate that this should not be replaced with a GPU enabled version + * @param because why it should not be replaced. + */ + final def willNotWorkOnGpu(because: String): Unit = { + cannotBeReplacedReasons.get.add(because) + // annotate the real spark plan with the reason as well so that the information is available + // during query stage planning when AQE is on + } + + final def mustBeReplaced(because: String): Unit = { + mustBeReplacedReasons.get.add(because) + } + + /** + * Call this if there is a condition found that the entire plan is not allowed + * to run on the GPU. + */ + final def entirePlanWillNotWork(because: String): Unit = { + cannotReplaceAnyOfPlanReasons.get.add(because) + } + + final def shouldBeRemoved(because: String): Unit = + shouldBeRemovedReasons.get.add(because) + + /** + * Call this method to record information about type conversions via DataTypeMeta. + */ + final def addConvertedDataType(expression: Expression, typeMeta: DataTypeMeta): Unit = { + typeConversionReasons.get.add( + s"$expression: ${typeMeta.reasonForConversion}") + } + + /** + * Returns true if this node should be removed. + */ + final def shouldThisBeRemoved: Boolean = shouldBeRemovedReasons.exists(_.nonEmpty) + + /** + * Returns true iff this could be replaced. + */ + final def canThisBeReplaced: Boolean = cannotBeReplacedReasons.exists(_.isEmpty) + + /** + * Returns true iff this must be replaced because its children have already been + * replaced and this needs to also be replaced for compatibility. + */ + final def mustThisBeReplaced: Boolean = mustBeReplacedReasons.exists(_.nonEmpty) + + /** + * Returns the list of reasons the entire plan can't be replaced. An empty + * set means the entire plan is ok to be replaced, do the normal checking + * per exec and children. + */ + final def entirePlanExcludedReasons: Seq[String] = { + cannotReplaceAnyOfPlanReasons.getOrElse(mutable.Set.empty).toSeq + } + + /** + * Returns true iff all of the expressions and their children could be replaced. + */ + def canExprTreeBeReplaced: Boolean = childExprs.forall(_.canExprTreeBeReplaced) + + /** + * Returns true iff all of the partitioning can be replaced. 
+ */ + def canPartsBeReplaced: Boolean = childParts.forall(_.canThisBeReplaced) + + /** + * Returns true iff all of the data writing commands can be replaced. + */ + def canDataWriteCmdsBeReplaced: Boolean = childDataWriteCmds.forall(_.canThisBeReplaced) + + def confKey: String = rule.confKey + final val operationName: String = rule.operationName + final val incompatDoc: Option[String] = rule.incompatDoc + def isIncompat: Boolean = incompatDoc.isDefined + final val disabledMsg: Option[String] = rule.disabledMsg + def isDisabledByDefault: Boolean = disabledMsg.isDefined + + def initReasons(): Unit = { + cannotBeReplacedReasons = Some(mutable.Set[String]()) + mustBeReplacedReasons = Some(mutable.Set[String]()) + shouldBeRemovedReasons = Some(mutable.Set[String]()) + cannotReplaceAnyOfPlanReasons = Some(mutable.Set[String]()) + typeConversionReasons = Some(mutable.Set[String]()) + } + + /** + * Tag all of the children to see if they are GPU compatible first. + * Do basic common verification for the operators, and then call + * [[tagSelfForGpu]] + */ + final def tagForGpu(): Unit = { + childParts.foreach(_.tagForGpu()) + childExprs.foreach(_.tagForGpu()) + childDataWriteCmds.foreach(_.tagForGpu()) + childPlans.foreach(_.tagForGpu()) + + initReasons() + + if (!conf.isOperatorEnabled(confKey, isIncompat, isDisabledByDefault)) { + if (isIncompat && !conf.isIncompatEnabled) { + willNotWorkOnGpu(s"the GPU version of ${wrapped.getClass.getSimpleName}" + + s" is not 100% compatible with the Spark version. ${incompatDoc.get}. To enable this" + + s" $operationName despite the incompatibilities please set the config" + + s" $confKey to true. You could also set ${RapidsConf.INCOMPATIBLE_OPS} to true" + + s" to enable all incompatible ops") + } else if (isDisabledByDefault) { + willNotWorkOnGpu(s"the $operationName ${wrapped.getClass.getSimpleName} has" + + s" been disabled, and is disabled by default because ${disabledMsg.get}. Set $confKey" + + s" to true if you wish to enable it") + } else { + willNotWorkOnGpu(s"the $operationName ${wrapped.getClass.getSimpleName} has" + + s" been disabled. Set $confKey to true if you wish to enable it") + } + } + + tagSelfForGpu() + } + + /** + * Do any extra checks and tag yourself if you are compatible or not. Be aware that this may + * already have been marked as incompatible for a number of reasons. + * + * All of your children should have already been tagged so if there are situations where you + * may need to disqualify your children for various reasons you may do it here too. 
+ */ + def tagSelfForGpu(): Unit + + protected def indent(append: StringBuilder, depth: Int): Unit = + append.append(" " * depth) + + def replaceMessage: String = "run on GPU" + def noReplacementPossibleMessage(reasons: String): String = s"cannot run on GPU because $reasons" + def suppressWillWorkOnGpuInfo: Boolean = false + + private def willWorkOnGpuInfo: String = cannotBeReplacedReasons match { + case None => "NOT EVALUATED FOR GPU YET" + case Some(v) if v.isEmpty && + (cannotRunOnGpuBecauseOfSparkPlan || shouldThisBeRemoved) => "could " + replaceMessage + case Some(v) if v.isEmpty => "will " + replaceMessage + case Some(v) => + noReplacementPossibleMessage(v.mkString("; ")) + } + + private def willBeRemovedInfo: String = shouldBeRemovedReasons match { + case None => "" + case Some(v) if v.isEmpty => "" + case Some(v) => + val reasons = v.mkString("; ") + s" but is going to be removed because $reasons" + } + + private def typeConversionInfo: String = typeConversionReasons match { + case None => "" + case Some(v) if v.isEmpty => "" + case Some(v) => + "The data type of following expressions will be converted in GPU runtime: " + + v.mkString("; ") + } + + /** + * When converting this to a string should we include the string representation of what this + * wraps too? This is off by default. + */ + protected val printWrapped = false + + final private def getIndicatorChar: String = { + if (shouldThisBeRemoved) { + "#" + } else if (cannotRunOnGpuBecauseOfCost) { + "$" + } else if (canThisBeReplaced) { + if (cannotRunOnGpuBecauseOfSparkPlan) { + "@" + } else if (cannotRunOnGpuBecauseOfCost) { + "$" + } else { + "*" + } + } else { + "!" + } + } + + protected def checkTimeZoneId(timeZoneId: Option[String]): Unit = { + timeZoneId.foreach { zoneId => + if (!TypeChecks.areTimestampsSupported(ZoneId.systemDefault())) { + willNotWorkOnGpu(s"Only UTC zone id is supported. Actual zone id: $zoneId") + } + } + } + + /** + * Create a string representation of this in append. + * @param strBuilder where to place the string representation. + * @param depth how far down the tree this is. + * @param all should all the data be printed or just what does not work on the GPU? + */ + protected def print(strBuilder: StringBuilder, depth: Int, all: Boolean): Unit = { + if ((all || !canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) && + !suppressWillWorkOnGpuInfo) { + indent(strBuilder, depth) + strBuilder.append(getIndicatorChar) + + strBuilder.append(operationName) + .append(" <") + .append(wrapped.getClass.getSimpleName) + .append("> ") + + if (printWrapped) { + strBuilder.append(wrapped) + .append(" ") + } + + strBuilder.append(willWorkOnGpuInfo). + append(willBeRemovedInfo) + + typeConversionInfo match { + case info if info.isEmpty => + case info => strBuilder.append(". ").append(info) + } + + strBuilder.append("\n") + } + printChildren(strBuilder, depth, all) + } + + private final def printChildren(append: StringBuilder, depth: Int, all: Boolean): Unit = { + childParts.foreach(_.print(append, depth + 1, all)) + childExprs.foreach(_.print(append, depth + 1, all)) + childDataWriteCmds.foreach(_.print(append, depth + 1, all)) + childPlans.foreach(_.print(append, depth + 1, all)) + } + + def explain(all: Boolean): String = { + val appender = new StringBuilder() + print(appender, 0, all) + appender.toString() + } + + override def toString: String = { + explain(true) + } +} + +/** + * Base class for metadata around `Partitioning`. 
+ */ +abstract class PartMeta[INPUT <: Partitioning](part: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends RapidsMeta[INPUT, Partitioning](part, conf, parent, rule) { + // 2.x - replaced GpuPartitioning with Partitioning, should be fine + // since BASE only used for convert + + override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + override val childParts: Seq[PartMeta[_]] = Seq.empty + override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty + + override final def tagSelfForGpu(): Unit = { + rule.getChecks.foreach(_.tag(this)) + if (!canExprTreeBeReplaced) { + willNotWorkOnGpu("not all expressions can be replaced") + } + tagPartForGpu() + } + + def tagPartForGpu(): Unit = {} +} + +/** + * Metadata for Partitioning with no rule found + */ +final class RuleNotFoundPartMeta[INPUT <: Partitioning]( + part: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]) + extends PartMeta[INPUT](part, conf, parent, new NoRuleDataFromReplacementRule) { + + override def tagPartForGpu(): Unit = { + willNotWorkOnGpu(s"GPU does not currently support the operator ${part.getClass}") + } + +} + +/** + * Base class for metadata around `DataWritingCommand`. + */ +abstract class DataWritingCommandMeta[INPUT <: DataWritingCommand]( + cmd: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends RapidsMeta[INPUT, DataWritingCommand](cmd, conf, parent, rule) { + + override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + override val childParts: Seq[PartMeta[_]] = Seq.empty + override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty + + override def tagSelfForGpu(): Unit = {} +} + +/** + * Metadata for `DataWritingCommand` with no rule found + */ +final class RuleNotFoundDataWritingCommandMeta[INPUT <: DataWritingCommand]( + cmd: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]) + extends DataWritingCommandMeta[INPUT](cmd, conf, parent, new NoRuleDataFromReplacementRule) { + + override def tagSelfForGpu(): Unit = { + willNotWorkOnGpu(s"GPU does not currently support the operator ${cmd.getClass}") + } +} + +/** + * Base class for metadata around `SparkPlan`. + */ +abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends RapidsMeta[INPUT, SparkPlan](plan, conf, parent, rule) { + + def tagForExplain(): Unit = { + if (!canThisBeReplaced) { + childExprs.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + childParts.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + childDataWriteCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu()) + } + if (shouldThisBeRemoved) { + childExprs.foreach(_.recursiveSparkPlanRemoved()) + childParts.foreach(_.recursiveSparkPlanRemoved()) + childDataWriteCmds.foreach(_.recursiveSparkPlanRemoved()) + } + childPlans.foreach(_.tagForExplain()) + } + + def requireAstForGpuOn(exprMeta: BaseExprMeta[_]): Unit = { + // willNotWorkOnGpu does not deduplicate reasons. Most of the time that is fine + // but here we want to avoid adding the reason twice, because this method can be + // called multiple times, and also the reason can automatically be added in if + // a child expression would not work in the non-AST case either. 
+ // So only add it if canExprTreeBeReplaced changed after requiring that the + // given expression is AST-able. + val previousExprReplaceVal = canExprTreeBeReplaced + exprMeta.requireAstForGpu() + val newExprReplaceVal = canExprTreeBeReplaced + if (previousExprReplaceVal != newExprReplaceVal && + !newExprReplaceVal) { + willNotWorkOnGpu("not all expressions can be replaced") + } + } + + override val childPlans: Seq[SparkPlanMeta[SparkPlan]] = + plan.children.map(GpuOverrides.wrapPlan(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = + plan.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childParts: Seq[PartMeta[_]] = Seq.empty + override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty + + def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = Map.empty + + var cpuCost: Double = 0 + var gpuCost: Double = 0 + var estimatedOutputRows: Option[BigInt] = None + + override def convertToCpu(): SparkPlan = { + wrapped.withNewChildren(childPlans.map(_.convertIfNeeded())) + } + + def getReasonsNotToReplaceEntirePlan: Seq[String] = { + val childReasons = childPlans.flatMap(_.getReasonsNotToReplaceEntirePlan) + entirePlanExcludedReasons ++ childReasons + } + + // For adaptive execution we have to ensure we mark everything properly + // the first time through and that has to match what happens when AQE + // splits things up and does the subquery analysis at the shuffle boundaries. + // If the AQE subquery analysis changes the plan from what is originally + // marked we can end up with mismatches like happened in: + // https://github.com/NVIDIA/spark-rapids/issues/1423 + // AQE splits subqueries at shuffle boundaries which means that it only + // sees the children at that point. So in our fix up exchange we only + // look at the children and mark is at will not work on GPU if the + // child can't be replaced. + private def fixUpExchangeOverhead(): Unit = { + childPlans.foreach(_.fixUpExchangeOverhead()) + if (wrapped.isInstanceOf[ShuffleExchangeExec] && + !childPlans.exists(_.canThisBeReplaced) && + (plan.conf.adaptiveExecutionEnabled || + !parent.exists(_.canThisBeReplaced))) { + + willNotWorkOnGpu("Columnar exchange without columnar children is inefficient") + } + } + + /** + * Run rules that happen for the entire tree after it has been tagged initially. + */ + def runAfterTagRules(): Unit = { + // In the first pass tagSelfForGpu will deal with each operator individually. + // Children will be tagged first and then their parents will be tagged. This gives + // flexibility when tagging yourself to look at your children and disable yourself if your + // children are not all on the GPU. In some cases we need to be able to disable our + // children too, or in this case run a rule that will disable operations when looking at + // more of the tree. These exceptions should be documented here. We need to take special care + // that we take into account all side-effects of these changes, because we are **not** + // re-triggering the rules associated with parents, grandparents, etc. If things get too + // complicated we may need to update this to have something with triggers, but then we would + // have to be very careful to avoid loops in the rules. + + // RULES: + // 1) If file scan plan runs on the CPU, and the following plans run on GPU, then + // GpuRowToColumnar will be inserted. GpuRowToColumnar will invalid input_file_xxx operations, + // So input_file_xxx in the following GPU operators will get empty value. 
+ // InputFileBlockRule is to prevent the SparkPlans + // [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU + InputFileBlockRule.apply(this.asInstanceOf[SparkPlanMeta[SparkPlan]]) + + // 2) For ShuffledHashJoin and SortMergeJoin we need to verify that all of the exchanges + // feeding them are either all on the GPU or all on the CPU, because the hashing is not + // consistent between the two implementations. This is okay because it is only impacting + // shuffled exchanges. So broadcast exchanges are not impacted which could have an impact on + // BroadcastHashJoin, and shuffled exchanges are not used to disable anything downstream. + fixUpExchangeOverhead() + } + + override final def tagSelfForGpu(): Unit = { + rule.getChecks.foreach(_.tag(this)) + + if (!canExprTreeBeReplaced) { + willNotWorkOnGpu("not all expressions can be replaced") + } + + /*if (!canScansBeReplaced) { + willNotWorkOnGpu("not all scans can be replaced") + } */ + + if (!canPartsBeReplaced) { + willNotWorkOnGpu("not all partitioning can be replaced") + } + + if (!canDataWriteCmdsBeReplaced) { + willNotWorkOnGpu("not all data writing commands can be replaced") + } + + checkExistingTags() + + tagPlanForGpu() + } + + /** + * When AQE is enabled and we are planning a new query stage, we need to look at meta-data + * previously stored on the spark plan to determine whether this operator can run on GPU + */ + def checkExistingTags(): Unit = { + } + + /** + * Called to verify that this plan will work on the GPU. Generic checks will have already been + * done. In general this method should only tag this operator as bad. If it needs to tag + * one of its children please take special care to update the comment inside + * `tagSelfForGpu` so we don't end up with something that could be cyclical. + */ + def tagPlanForGpu(): Unit = {} + + /** + * If this is enabled to be converted to a GPU version convert it and return the result, else + * do what is needed to possibly convert the rest of the plan. + */ + final def convertIfNeeded(): SparkPlan = { + if (shouldThisBeRemoved) { + if (childPlans.isEmpty) { + throw new IllegalStateException("can't remove when plan has no children") + } else if (childPlans.size > 1) { + throw new IllegalStateException("can't remove when plan has more than 1 child") + } + childPlans.head.convertIfNeeded() + } else { + convertToCpu + } + } + + /** + * Gets output attributes of current SparkPlanMeta, which is supposed to be called during + * type checking for the current plan. + * + * By default, it simply returns the output of wrapped plan. For specific plans, they can + * override outputTypeMetas to apply custom conversions on the output of wrapped plan. For plans + * which just pass through the schema of childPlan, they can set useOutputAttributesOfChild to + * true, in order to propagate the custom conversions of childPlan if they exist. 
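+ *
+ * In short: override `outputTypeMetas` when this plan itself changes its output types, or set
+ * `useOutputAttributesOfChild` to true when the plan simply forwards its single child's schema.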
+ */ + def outputAttributes: Seq[Attribute] = outputTypeMetas match { + case Some(typeMetas) => + require(typeMetas.length == wrapped.output.length, + "The length of outputTypeMetas doesn't match to the length of plan's output") + wrapped.output.zip(typeMetas).map { + case (ar, meta) if meta.typeConverted => + addConvertedDataType(ar, meta) + AttributeReference(ar.name, meta.dataType.get, ar.nullable, ar.metadata)( + ar.exprId, ar.qualifier) + case (ar, _) => + ar + } + case None if useOutputAttributesOfChild => + require(wrapped.children.length == 1, + "useOutputAttributesOfChild ONLY works on UnaryPlan") + // We pass through the outputAttributes of the child plan only if it will be really applied + // in the runtime. We can pass through either if child plan can be replaced by GPU overrides; + // or if child plan is available for runtime type conversion. The later condition indicates + // the CPU to GPU data transition will be introduced as the pre-processing of the adjacent + // GpuRowToColumnarExec, though the child plan can't produce output attributes for GPU. + // Otherwise, we should fetch the outputAttributes from the wrapped plan. + // + // We can safely call childPlan.canThisBeReplaced here, because outputAttributes is called + // via tagSelfForGpu. At this point, tagging of the child plan has already taken place. + if (childPlans.head.canThisBeReplaced || childPlans.head.availableRuntimeDataTransition) { + childPlans.head.outputAttributes + } else { + wrapped.output + } + case None => + wrapped.output + } + + /** + * Overrides this method to implement custom conversions for specific plans. + */ + protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] = None + + /** + * Whether to pass through the outputAttributes of childPlan's meta, only for UnaryPlan + */ + protected val useOutputAttributesOfChild: Boolean = false + + /** + * Whether there exists runtime data transition for the wrapped plan, if true, the overriding + * of output attributes will always work even when the wrapped plan can't be replaced by GPU + * overrides. + */ + val availableRuntimeDataTransition: Boolean = false +} + +/** + * Metadata for `SparkPlan` with no rule found + */ +final class RuleNotFoundSparkPlanMeta[INPUT <: SparkPlan]( + plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]) + extends SparkPlanMeta[INPUT](plan, conf, parent, new NoRuleDataFromReplacementRule) { + + override def tagPlanForGpu(): Unit = + willNotWorkOnGpu(s"GPU does not currently support the operator ${plan.getClass}") +} + +/** + * Metadata for `SparkPlan` that should not be replaced or have any kind of warning for + */ +final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan]( + plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]) + extends SparkPlanMeta[INPUT](plan, conf, parent, new NoRuleDataFromReplacementRule) { + + /** We don't want to spam the user with messages about these operators */ + override def suppressWillWorkOnGpuInfo: Boolean = true + + override def tagPlanForGpu(): Unit = + willNotWorkOnGpu(s"there is no need to replace ${plan.getClass}") +} + +sealed abstract class ExpressionContext +object ProjectExprContext extends ExpressionContext { + override def toString: String = "project" +} +/** + * This is a special context. All other contexts are determined by the Spark query in a generic way. + * AST support in many cases is an optimization and so it is tagged and checked after it is + * determined that this operation will run on the GPU. In other cases it is required. 
In those cases + * AST support is determined and used when tagging the metas to see if they will work on the GPU or + * not. This part is not done automatically. + */ +object AstExprContext extends ExpressionContext { + override def toString: String = "AST" + + val notSupportedMsg = "this expression does not support AST" +} +object GroupByAggExprContext extends ExpressionContext { + override def toString: String = "aggregation" +} +object ReductionAggExprContext extends ExpressionContext { + override def toString: String = "reduction" +} +object WindowAggExprContext extends ExpressionContext { + override def toString: String = "window" +} + +object ExpressionContext { + private[this] def findParentPlanMeta(meta: BaseExprMeta[_]): Option[SparkPlanMeta[_]] = + meta.parent match { + case Some(p: BaseExprMeta[_]) => findParentPlanMeta(p) + case Some(p: SparkPlanMeta[_]) => Some(p) + case _ => None + } + + def getAggregateFunctionContext(meta: BaseExprMeta[_]): ExpressionContext = { + val parent = findParentPlanMeta(meta) + assert(parent.isDefined, "It is expected that an aggregate function is a child of a SparkPlan") + parent.get.wrapped match { + case agg: SparkPlan if agg.isInstanceOf[WindowExec] => + WindowAggExprContext + case agg: HashAggregateExec => + // Spark 2.x doesn't have the BaseAggregateExec class + if (agg.groupingExpressions.isEmpty) { + ReductionAggExprContext + } else { + GroupByAggExprContext + } + case agg: SortAggregateExec => + if (agg.groupingExpressions.isEmpty) { + ReductionAggExprContext + } else { + GroupByAggExprContext + } + case _ => throw new IllegalStateException( + s"Found an aggregation function in an unexpected context $parent") + } + } + + def getRegularOperatorContext(meta: RapidsMeta[_, _]): ExpressionContext = meta.wrapped match { + case _: Expression if meta.parent.isDefined => getRegularOperatorContext(meta.parent.get) + case _ => ProjectExprContext + } +} + +/** + * The metadata around `DataType`, which records the original data type, the desired data type for + * GPU overrides, and the reason of potential conversion. The metadata is to ensure TypeChecks + * tagging the actual data types for GPU runtime, since data types of GPU overrides may slightly + * differ from original CPU counterparts. + */ +class DataTypeMeta( + val wrapped: Option[DataType], + desired: Option[DataType] = None, + reason: Option[String] = None) { + + lazy val dataType: Option[DataType] = desired match { + case Some(dt) => Some(dt) + case None => wrapped + } + + // typeConverted will only be true if there exists DataType in wrapped expression + lazy val typeConverted: Boolean = dataType.nonEmpty && dataType != wrapped + + /** + * Returns the reason for conversion if exists + */ + def reasonForConversion: String = { + val reasonMsg = (if (typeConverted) reason else None) + .map(r => s", because $r").getOrElse("") + s"Converted ${wrapped.getOrElse("N/A")} to " + + s"${dataType.getOrElse("N/A")}" + reasonMsg + } +} + +object DataTypeMeta { + /** + * create DataTypeMeta from Expression + */ + def apply(expr: Expression, overrideType: Option[DataType]): DataTypeMeta = { + val wrapped = try { + Some(expr.dataType) + } catch { + case _: java.lang.UnsupportedOperationException => None + } + new DataTypeMeta(wrapped, overrideType) + } +} + +/** + * Base class for metadata around `Expression`. 
+ */ +abstract class BaseExprMeta[INPUT <: Expression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends RapidsMeta[INPUT, Expression](expr, conf, parent, rule) { + + private val cannotBeAstReasons: mutable.Set[String] = mutable.Set.empty + + override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty + override val childExprs: Seq[BaseExprMeta[_]] = + expr.children.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childParts: Seq[PartMeta[_]] = Seq.empty + override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty + + override val printWrapped: Boolean = true + + def dataType: DataType = expr.dataType + + val ignoreUnsetDataTypes = false + + override def canExprTreeBeReplaced: Boolean = + canThisBeReplaced && super.canExprTreeBeReplaced + + /** + * Gets the DataTypeMeta of current BaseExprMeta, which is supposed to be called in the + * tag methods of expression-level type checks. + * + * By default, it simply returns the data type of wrapped expression. But for specific + * expressions, they can easily override data type for type checking through calling the + * method `overrideDataType`. + */ + def typeMeta: DataTypeMeta = DataTypeMeta(wrapped.asInstanceOf[Expression], overrideType) + + /** + * Overrides the data type of the wrapped expression during type checking. + * + * NOTICE: This method will NOT modify the wrapped expression itself. Therefore, the actual + * transition on data type is still necessary when converting this expression to GPU. + */ + def overrideDataType(dt: DataType): Unit = overrideType = Some(dt) + + private var overrideType: Option[DataType] = None + + lazy val context: ExpressionContext = expr match { + case _: WindowExpression => WindowAggExprContext + case _: WindowFunction => WindowAggExprContext + case _: AggregateFunction => ExpressionContext.getAggregateFunctionContext(this) + case _: AggregateExpression => ExpressionContext.getAggregateFunctionContext(this) + case _ => ExpressionContext.getRegularOperatorContext(this) + } + + val isFoldableNonLitAllowed: Boolean = false + + final override def tagSelfForGpu(): Unit = { + if (wrapped.foldable && !GpuOverrides.isLit(wrapped) && !isFoldableNonLitAllowed) { + willNotWorkOnGpu(s"Cannot run on GPU. Is ConstantFolding excluded? Expression " + + s"$wrapped is foldable and operates on non literals") + } + rule.getChecks.foreach(_.tag(this)) + tagExprForGpu() + } + + /** + * Called to verify that this expression will work on the GPU. For most expressions without + * extra checks all of the checks should have already been done. + */ + def tagExprForGpu(): Unit = {} + + final def willNotWorkInAst(because: String): Unit = cannotBeAstReasons.add(because) + + final def canThisBeAst: Boolean = { + tagForAst() + childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty + } + + final def requireAstForGpu(): Unit = { + tagForAst() + cannotBeAstReasons.foreach { reason => + willNotWorkOnGpu(s"AST is required and $reason") + } + childExprs.foreach(_.requireAstForGpu()) + } + + private var taggedForAst = false + private final def tagForAst(): Unit = { + if (!taggedForAst) { + if (wrapped.foldable && !GpuOverrides.isLit(wrapped)) { + willNotWorkInAst(s"Cannot convert to AST. Is ConstantFolding excluded? 
Expression " + + s"$wrapped is foldable and operates on non literals") + } + + rule.getChecks.foreach { + case exprCheck: ExprChecks => exprCheck.tagAst(this) + case other => throw new IllegalArgumentException(s"Unexpected check found $other") + } + + tagSelfForAst() + taggedForAst = true + } + } + + /** Called to verify that this expression will work as a GPU AST expression. */ + protected def tagSelfForAst(): Unit = { + // NOOP + } + + protected def willWorkInAstInfo: String = { + if (cannotBeAstReasons.isEmpty) { + "will run in AST" + } else { + s"cannot be converted to GPU AST because ${cannotBeAstReasons.mkString(";")}" + } + } + + /** + * Create a string explanation for whether this expression tree can be converted to an AST + * @param strBuilder where to place the string representation. + * @param depth how far down the tree this is. + * @param all should all the data be printed or just what does not work in the AST? + */ + protected def printAst(strBuilder: StringBuilder, depth: Int, all: Boolean): Unit = { + if (all || !canThisBeAst) { + indent(strBuilder, depth) + strBuilder.append(operationName) + .append(" <") + .append(wrapped.getClass.getSimpleName) + .append("> ") + + if (printWrapped) { + strBuilder.append(wrapped) + .append(" ") + } + + strBuilder.append(willWorkInAstInfo).append("\n") + } + childExprs.foreach(_.printAst(strBuilder, depth + 1, all)) + } + + def explainAst(all: Boolean): String = { + tagForAst() + val appender = new StringBuilder() + printAst(appender, 0, all) + appender.toString() + } +} + +abstract class ExprMeta[INPUT <: Expression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends BaseExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** + * Base class for metadata around `UnaryExpression`. + */ +abstract class UnaryExprMeta[INPUT <: UnaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { + + /** + * `ConstantFolding` executes early in the logical plan process, which + * simplifies many things before we get to the physical plan. If you enable + * AQE, some optimizations can cause new expressions to show up that would have been + * folded in by the logical plan optimizer (like `cast(null as bigint)` which just + * becomes Literal(null, Long) after `ConstantFolding`), so enabling this here + * allows us to handle these when they are generated by an AQE rule. + */ + override val isFoldableNonLitAllowed: Boolean = true +} + +/** Base metadata class for unary expressions that support conversion to AST as well */ +abstract class UnaryAstExprMeta[INPUT <: UnaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends UnaryExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** + * Base class for metadata around `AggregateFunction`. + */ +abstract class AggExprMeta[INPUT <: AggregateFunction]( + val expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { + + override final def tagExprForGpu(): Unit = { + tagAggForGpu() + } + + // not all aggs overwrite this + def tagAggForGpu(): Unit = {} +} + +/** + * Base class for metadata around `ImperativeAggregate`. 
+ */ +abstract class ImperativeAggExprMeta[INPUT <: ImperativeAggregate]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends AggExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** + * Base class for metadata around `TypedImperativeAggregate`. + */ +abstract class TypedImperativeAggExprMeta[INPUT <: TypedImperativeAggregate[_]]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ImperativeAggExprMeta[INPUT](expr, conf, parent, rule) { + + /** + * Returns aggregation buffer with the actual data type under GPU runtime. This method is + * called to override the data types of typed imperative aggregation buffers during GPU + * overriding. + */ + def aggBufferAttribute: AttributeReference + + /** + * Whether buffers of current Aggregate is able to be converted from CPU to GPU format and + * reversely in runtime. If true, it assumes both createCpuToGpuBufferConverter and + * createGpuToCpuBufferConverter are implemented. + */ + val supportBufferConversion: Boolean = false +} + +/** + * Base class for metadata around `BinaryExpression`. + */ +abstract class BinaryExprMeta[INPUT <: BinaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** Base metadata class for binary expressions that support conversion to AST */ +abstract class BinaryAstExprMeta[INPUT <: BinaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[INPUT](expr, conf, parent, rule) { + + override def tagSelfForAst(): Unit = { + if (wrapped.left.dataType != wrapped.right.dataType) { + willNotWorkInAst("AST binary expression operand types must match, found " + + s"${wrapped.left.dataType},${wrapped.right.dataType}") + } + } +} + +/** + * Base class for metadata around `TernaryExpression`. + */ +abstract class TernaryExprMeta[INPUT <: TernaryExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { +} + +abstract class String2TrimExpressionMeta[INPUT <: String2TrimExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** + * Base class for metadata around `ComplexTypeMergingExpression`. 
+ */ +abstract class ComplexTypeMergingExprMeta[INPUT <: ComplexTypeMergingExpression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { +} + +/** + * Metadata for `Expression` with no rule found + */ +final class RuleNotFoundExprMeta[INPUT <: Expression]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]]) + extends ExprMeta[INPUT](expr, conf, parent, new NoRuleDataFromReplacementRule) { + + override def tagExprForGpu(): Unit = + willNotWorkOnGpu(s"GPU does not currently support the operator ${expr.getClass}") +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala new file mode 100644 index 00000000000..5f4ee4d54d0 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -0,0 +1,794 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import java.sql.SQLException + +import scala.collection.mutable.ListBuffer + +/** + * Regular expression parser based on a Pratt Parser design. + * + * The goal of this parser is to build a minimal AST that allows us + * to validate that we can support the expression on the GPU. The goal + * is not to parse with the level of detail that would be required if + * we were building an evaluation engine. For example, operator precedence is + * largely ignored but could be added if we need it later. + * + * The Java and cuDF regular expression documentation has been used as a reference: + * + * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html + * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html + * + * The following blog posts provide some background on Pratt Parsers and parsing regex. 
+ * + * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ + * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/ + */ +class RegexParser(pattern: String) { + + /** index of current position within the string being parsed */ + private var pos = 0 + + def parse(): RegexAST = { + val ast = parseUntil(() => eof()) + if (!eof()) { + throw new RegexUnsupportedException("failed to parse full regex") + } + ast + } + + private def parseUntil(until: () => Boolean): RegexAST = { + val term = parseTerm(() => until() || peek().contains('|')) + if (!eof() && peek().contains('|')) { + consumeExpected('|') + RegexChoice(term, parseUntil(until)) + } else { + term + } + } + + private def parseTerm(until: () => Boolean): RegexAST = { + val sequence = RegexSequence(new ListBuffer()) + while (!eof() && !until()) { + parseFactor(until) match { + case RegexSequence(parts) => + sequence.parts ++= parts + case other => + sequence.parts += other + } + } + sequence + } + + private def isValidQuantifierAhead(): Boolean = { + if (peek().contains('{')) { + val bookmark = pos + consumeExpected('{') + val q = parseQuantifierOrLiteralBrace() + pos = bookmark + q match { + case _: QuantifierFixedLength | _: QuantifierVariableLength => true + case _ => false + } + } else { + false + } + } + + private def parseFactor(until: () => Boolean): RegexAST = { + var base = parseBase() + while (!eof() && !until() + && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') + || isValidQuantifierAhead())) { + + val quantifier = if (peek().contains('{')) { + consumeExpected('{') + parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier] + } else { + SimpleQuantifier(consume()) + } + base = RegexRepetition(base, quantifier) + } + base + } + + private def parseBase(): RegexAST = { + consume() match { + case '(' => + parseGroup() + case '[' => + parseCharacterClass() + case '\\' => + parseEscapedCharacter() + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(pos)) + case '*' | '+' | '?' => + throw new RegexUnsupportedException( + "base expression cannot start with quantifier", Some(pos)) + case other => + RegexChar(other) + } + } + + private def parseGroup(): RegexAST = { + val captureGroup = if (pos + 1 < pattern.length + && pattern.charAt(pos) == '?' + && pattern.charAt(pos+1) == ':') { + pos += 2 + false + } else { + true + } + val term = parseUntil(() => peek().contains(')')) + consumeExpected(')') + RegexGroup(captureGroup, term) + } + + private def parseCharacterClass(): RegexCharacterClass = { + val start = pos + val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) + // loop until the end of the character class or EOF + var characterClassComplete = false + while (!eof() && !characterClassComplete) { + val ch = consume() + ch match { + case '[' => + // treat as a literal character and add to the character class + characterClass.append(ch) + case ']' if (!characterClass.negated && pos > start + 1) || + (characterClass.negated && pos > start + 2) => + // "[]" is not a valid character class + // "[]a]" is a valid character class containing the characters "]" and "a" + // "[^]a]" is a valid negated character class containing the characters "]" and "a" + characterClassComplete = true + case '^' if pos == start + 1 => + // Negates the character class, causing it to match a single character not listed in + // the character class. 
Only valid immediately after the opening '[' + characterClass.negated = true + case '\n' | '\r' | '\t' | '\b' | '\f' | '\u0007' => + // treat as a literal character and add to the character class + characterClass.append(ch) + case '\\' => + peek() match { + case None => + throw new RegexUnsupportedException( + s"Unclosed character class", Some(pos)) + case Some(ch) => + // typically an escaped metacharacter ('\\', '^', '-', ']', '+') + // within the character class, but could be any escaped character + characterClass.appendEscaped(consumeExpected(ch)) + } + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(pos)) + case _ => + // check for range + val start = ch + peek() match { + case Some('-') => + consumeExpected('-') + peek() match { + case Some(']') => + // '-' at end of class e.g. "[abc-]" + characterClass.append(ch) + characterClass.append('-') + case Some(end) => + skip() + characterClass.appendRange(start, end) + case _ => + throw new RegexUnsupportedException( + "unexpected EOF while parsing character range", + Some(pos)) + } + case _ => + // treat as supported literal character + characterClass.append(ch) + } + } + } + if (!characterClassComplete) { + throw new RegexUnsupportedException( + s"Unclosed character class", Some(pos)) + } + characterClass + } + + + /** + * Parse a quantifier in one of the following formats: + * + * {n} + * {n,} + * {n,m} (only valid if m >= n) + */ + private def parseQuantifierOrLiteralBrace(): RegexAST = { + + // assumes that '{' has already been consumed + val start = pos + + def treatAsLiteralBrace() = { + // this was not a quantifier, just a literal '{' + pos = start + 1 + RegexChar('{') + } + + consumeInt match { + case Some(minLength) => + peek() match { + case Some(',') => + consumeExpected(',') + val max = consumeInt() + if (peek().contains('}')) { + consumeExpected('}') + max match { + case None => + QuantifierVariableLength(minLength, None) + case Some(m) => + if (m >= minLength) { + QuantifierVariableLength(minLength, max) + } else { + treatAsLiteralBrace() + } + } + } else { + treatAsLiteralBrace() + } + case Some('}') => + consumeExpected('}') + QuantifierFixedLength(minLength) + case _ => + treatAsLiteralBrace() + } + case None => + treatAsLiteralBrace() + } + } + + private def parseEscapedCharacter(): RegexAST = { + peek() match { + case None => + throw new RegexUnsupportedException("escape at end of string", Some(pos)) + case Some(ch) => + ch match { + case 'A' | 'Z' | 'z' => + // string anchors + consumeExpected(ch) + RegexEscaped(ch) + case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => + // meta sequences + consumeExpected(ch) + RegexEscaped(ch) + case 'B' | 'b' => + // word boundaries + consumeExpected(ch) + RegexEscaped(ch) + case '[' | '\\' | '^' | '$' | '.' | '⎮' | '?' 
| '*' | '+' | '(' | ')' | '{' | '}' => + // escaped metacharacter + consumeExpected(ch) + RegexEscaped(ch) + case 'x' => + consumeExpected(ch) + parseHexDigit + case _ if Character.isDigit(ch) => + parseOctalDigit + case other => + throw new RegexUnsupportedException( + s"invalid or unsupported escape character '$other'", Some(pos - 1)) + } + } + } + + private def isHexDigit(ch: Char): Boolean = ch.isDigit || + (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F') + + private def parseHexDigit: RegexHexDigit = { + // \xhh The character with hexadecimal value 0xhh + // \x{h...h} The character with hexadecimal value 0xh...h + // (Character.MIN_CODE_POINT <= 0xh...h <= Character.MAX_CODE_POINT) + + val start = pos + while (!eof() && isHexDigit(pattern.charAt(pos))) { + pos += 1 + } + val hexDigit = pattern.substring(start, pos) + + if (hexDigit.length < 2) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + val value = Integer.parseInt(hexDigit, 16) + if (value < Character.MIN_CODE_POINT || value > Character.MAX_CODE_POINT) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + RegexHexDigit(hexDigit) + } + + private def isOctalDigit(ch: Char): Boolean = ch >= '0' && ch <= '7' + + private def parseOctalDigit: RegexOctalChar = { + // \0n The character with octal value 0n (0 <= n <= 7) + // \0nn The character with octal value 0nn (0 <= n <= 7) + // \0mnn The character with octal value 0mnn (0 <= m <= 3, 0 <= n <= 7) + + def parseOctalDigits(n: Integer): RegexOctalChar = { + val octal = pattern.substring(pos, pos + n) + pos += n + RegexOctalChar(octal) + } + + if (!eof() && isOctalDigit(pattern.charAt(pos))) { + if (pos + 1 < pattern.length && isOctalDigit(pattern.charAt(pos + 1))) { + if (pos + 2 < pattern.length && isOctalDigit(pattern.charAt(pos + 2)) + && pattern.charAt(pos) <= '3') { + parseOctalDigits(3) + } else { + parseOctalDigits(2) + } + } else { + parseOctalDigits(1) + } + } else { + throw new RegexUnsupportedException( + "Invalid octal digit", Some(pos)) + } + } + + /** Determine if we are at the end of the input */ + private def eof(): Boolean = pos == pattern.length + + /** Advance the index by one */ + private def skip(): Unit = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) + } + pos += 1 + } + + /** Get the next character and advance the index by one */ + private def consume(): Char = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) + } else { + pos += 1 + pattern.charAt(pos - 1) + } + } + + /** Consume the next character if it is the one we expect */ + private def consumeExpected(expected: Char): Char = { + val consumed = consume() + if (consumed != expected) { + throw new RegexUnsupportedException( + s"Expected '$expected' but found '$consumed'", Some(pos-1)) + } + consumed + } + + /** Peek at the next character without consuming it */ + private def peek(): Option[Char] = { + if (eof()) { + None + } else { + Some(pattern.charAt(pos)) + } + } + + private def consumeInt(): Option[Int] = { + val start = pos + while (!eof() && peek().exists(_.isDigit)) { + skip() + } + if (start == pos) { + None + } else { + Some(pattern.substring(start, pos).toInt) + } + } + +} + +/** + * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception + * if this is not possible. 
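+ *
+ * A minimal usage sketch (assuming a pattern that the parser and the rewrite rules accept):
+ * {{{
+ *   val cudfPattern = new CudfRegexTranspiler(replace = false).transpile("a[0-9]+")
+ * }}}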
+ * + * @param replace True if performing a replacement (regexp_replace), false + * if matching only (rlike) + */ +class CudfRegexTranspiler(replace: Boolean) { + + // cuDF throws a "nothing to repeat" exception for many of the edge cases that are + // rejected by the transpiler + private val nothingToRepeat = "nothing to repeat" + + /** + * Parse Java regular expression and translate into cuDF regular expression. + * + * @param pattern Regular expression that is valid in Java's engine + * @return Regular expression in cuDF format + */ + def transpile(pattern: String): String = { + // parse the source regular expression + val regex = new RegexParser(pattern).parse() + // validate that the regex is supported by cuDF + val cudfRegex = rewrite(regex) + // write out to regex string, performing minor transformations + // such as adding additional escaping + cudfRegex.toRegexString + } + + private def rewrite(regex: RegexAST): RegexAST = { + regex match { + + case RegexChar(ch) => ch match { + case '.' => + // workaround for https://github.com/rapidsai/cudf/issues/9619 + RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n'))) + case '$' => + // see https://github.com/NVIDIA/spark-rapids/issues/4533 + throw new RegexUnsupportedException("line anchor $ is not supported") + case _ => + regex + } + + case RegexOctalChar(_) => + // see https://github.com/NVIDIA/spark-rapids/issues/4288 + throw new RegexUnsupportedException( + s"cuDF does not support octal digits consistently with Spark") + + case RegexHexDigit(_) => + // see https://github.com/NVIDIA/spark-rapids/issues/4486 + throw new RegexUnsupportedException( + s"cuDF does not support hex digits consistently with Spark") + + case RegexEscaped(ch) => ch match { + case 'D' => + // see https://github.com/NVIDIA/spark-rapids/issues/4475 + throw new RegexUnsupportedException("non-digit class \\D is not supported") + case 'W' => + // see https://github.com/NVIDIA/spark-rapids/issues/4475 + throw new RegexUnsupportedException("non-word class \\W is not supported") + case 'b' | 'B' => + // see https://github.com/NVIDIA/spark-rapids/issues/4517 + throw new RegexUnsupportedException("word boundaries are not supported") + case 's' | 'S' => + // see https://github.com/NVIDIA/spark-rapids/issues/4528 + throw new RegexUnsupportedException("whitespace classes are not supported") + case 'z' => + if (replace) { + // see https://github.com/NVIDIA/spark-rapids/issues/4425 + throw new RegexUnsupportedException( + "string anchor \\z is not supported in replace mode") + } + // cuDF does not support "\z" but supports "$", which is equivalent + RegexChar('$') + case 'Z' => + // see https://github.com/NVIDIA/spark-rapids/issues/4532 + throw new RegexUnsupportedException("string anchor \\Z is not supported") + case _ => + regex + } + + case RegexCharacterRange(_, _) => + regex + + case RegexCharacterClass(negated, characters) => + characters.foreach { + case RegexChar(ch) if ch == '[' || ch == ']' => + // examples: + // - "[a[]" should match the literal characters "a" and "[" + // - "[a-b[c-d]]" is supported by Java but not cuDF + throw new RegexUnsupportedException("nested character classes are not supported") + case _ => + } + val components: Seq[RegexCharacterClassComponent] = characters + .map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent]) + + if (negated) { + // There are differences between cuDF and Java handling of newlines + // for negative character matches. 
The expression `[^a]` will match + // `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with + // `(?:[\r\n]|[^a])`. We also have to take into account whether any + // newline characters are included in the character range. + // + // Examples: + // + // `[^a]` => `(?:[\r\n]|[^a])` + // `[^a\r]` => `(?:[\n]|[^a])` + // `[^a\n]` => `(?:[\r]|[^a])` + // `[^a\r\n]` => `[^a]` + // `[^\r\n]` => `[^\r\n]` + + val linefeedCharsInPattern = components.flatMap { + case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch) + case RegexEscaped(ch) if ch == 'n' => Seq('\n') + case RegexEscaped(ch) if ch == 'r' => Seq('\r') + case _ => Seq.empty + } + + val onlyLinefeedChars = components.length == linefeedCharsInPattern.length + + val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct) + + if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) { + // special case for `[^\r\n]` and `[^\\r\\n]` + RegexCharacterClass(negated = true, ListBuffer(components: _*)) + } else if (negatedNewlines.isEmpty) { + RegexCharacterClass(negated = true, ListBuffer(components: _*)) + } else { + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated = false, + characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), + RegexCharacterClass(negated = true, ListBuffer(components: _*)))) + } + } else { + RegexCharacterClass(negated, ListBuffer(components: _*)) + } + + case RegexSequence(parts) => + if (parts.isEmpty) { + // examples: "", "()", "a|", "|b" + throw new RegexUnsupportedException("empty sequence not supported") + } + if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { + // examples: "a|", "|b" + throw new RegexUnsupportedException(nothingToRepeat) + } + if (isRegexChar(parts.head, '{')) { + // example: "{" + // cuDF would treat this as a quantifier even though in this + // context (being at the start of a sequence) it is not quantifying anything + // note that we could choose to escape this in the transpiler rather than + // falling back to CPU + throw new RegexUnsupportedException(nothingToRepeat) + } + if (parts.forall(isBeginOrEndLineAnchor)) { + throw new RegexUnsupportedException( + "sequences that only contain '^' or '$' are not supported") + } + RegexSequence(parts.map(rewrite)) + + case RegexRepetition(base, quantifier) => (base, quantifier) match { + case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) => + // example: pattern " ?", input "] b[", replace with "X": + // java: X]XXbX[X + // cuDF: XXXX] b[ + // see https://github.com/NVIDIA/spark-rapids/issues/4468 + throw new RegexUnsupportedException( + "regexp_replace on GPU does not support repetition with ? or *") + + case (RegexEscaped(ch), _) if ch != 'd' && ch != 'w' => + // example: "\B?" 
+ throw new RegexUnsupportedException(nothingToRepeat) + + case (RegexChar(a), _) if "$^".contains(a) => + // example: "$*" + throw new RegexUnsupportedException(nothingToRepeat) + + case (RegexRepetition(_, _), _) => + // example: "a*+" + throw new RegexUnsupportedException(nothingToRepeat) + + case _ => + RegexRepetition(rewrite(base), quantifier) + + } + + case RegexChoice(l, r) => + val ll = rewrite(l) + val rr = rewrite(r) + + // cuDF does not support repetition on one side of a choice, such as "a*|a" + def isRepetition(e: RegexAST): Boolean = { + e match { + case RegexRepetition(_, _) => true + case RegexGroup(_, term) => isRepetition(term) + case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last) + case _ => false + } + } + if (isRepetition(ll) || isRepetition(rr)) { + throw new RegexUnsupportedException(nothingToRepeat) + } + + // cuDF does not support terms ending with line anchors on one side + // of a choice, such as "^|$" + def endsWithLineAnchor(e: RegexAST): Boolean = { + e match { + case RegexSequence(parts) if parts.nonEmpty => + isBeginOrEndLineAnchor(parts.last) + case _ => false + } + } + if (endsWithLineAnchor(ll) || endsWithLineAnchor(rr)) { + throw new RegexUnsupportedException(nothingToRepeat) + } + + RegexChoice(ll, rr) + + case RegexGroup(capture, term) => + RegexGroup(capture, rewrite(term)) + + case other => + throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other") + } + } + + private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match { + case RegexSequence(parts) => parts.nonEmpty && parts.forall(isBeginOrEndLineAnchor) + case RegexGroup(_, term) => isBeginOrEndLineAnchor(term) + case RegexChoice(l, r) => isBeginOrEndLineAnchor(l) && isBeginOrEndLineAnchor(r) + case RegexRepetition(term, _) => isBeginOrEndLineAnchor(term) + case RegexChar(ch) => ch == '^' || ch == '$' + case RegexEscaped('z') => true // \z gets translated to $ + case _ => false + } + + private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match { + case RegexChar(ch) => ch == value + case _ => false + } +} + +sealed trait RegexAST { + def children(): Seq[RegexAST] + def toRegexString: String +} + +sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { + override def children(): Seq[RegexAST] = parts + override def toRegexString: String = parts.map(_.toRegexString).mkString +} + +sealed case class RegexGroup(capture: Boolean, term: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(term) + override def toRegexString: String = if (capture) { + s"(${term.toRegexString})" + } else { + s"(?:${term.toRegexString})" + } +} + +sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a, b) + override def toRegexString: String = s"${a.toRegexString}|${b.toRegexString}" +} + +sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a) + override def toRegexString: String = s"${a.toRegexString}${quantifier.toRegexString}" +} + +sealed trait RegexQuantifier extends RegexAST + +sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = ch.toString +} + +sealed case class QuantifierFixedLength(length: Int) + extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + 
s"{$length}" + } +} + +sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int]) + extends RegexQuantifier{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + maxLength match { + case Some(max) => + s"{$minLength,$max}" + case _ => + s"{$minLength,}" + } + } +} + +sealed trait RegexCharacterClassComponent extends RegexAST + +sealed case class RegexHexDigit(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\x$a" +} + +sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$a" +} + +sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexCharacterRange(start: Char, end: Char) + extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$start-$end" +} + +sealed case class RegexCharacterClass( + var negated: Boolean, + var characters: ListBuffer[RegexCharacterClassComponent]) + extends RegexAST { + + override def children(): Seq[RegexAST] = characters + def append(ch: Char): Unit = { + characters += RegexChar(ch) + } + + def appendEscaped(ch: Char): Unit = { + characters += RegexEscaped(ch) + } + + def appendRange(start: Char, end: Char): Unit = { + characters += RegexCharacterRange(start, end) + } + + override def toRegexString: String = { + val builder = new StringBuilder("[") + if (negated) { + builder.append("^") + } + for (a <- characters) { + a match { + case RegexChar(ch) if requiresEscaping(ch) => + // cuDF has stricter escaping requirements for certain characters + // within a character class compared to Java or Python regex + builder.append(s"\\$ch") + case other => + builder.append(other.toRegexString) + } + } + builder.append("]") + builder.toString() + } + + private def requiresEscaping(ch: Char): Boolean = { + // there are likely other cases that we will need to add here but this + // covers everything we have seen so far during fuzzing + ch match { + case '-' => + // cuDF requires '-' to be escaped when used as a character within a character + // to disambiguate from the character range syntax 'a-b' + true + case _ => + false + } + } +} + +class RegexUnsupportedException(message: String, index: Option[Int] = None) + extends SQLException { + override def getMessage: String = { + index match { + case Some(i) => s"$message near index $i" + case _ => message + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala new file mode 100644 index 00000000000..9b8981e143c --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, WindowInPandasExec} +import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.rapids._ +import org.apache.spark.sql.rapids.execution._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +// Overrides that are in the shim in normal sql-plugin, moved here for easier diffing +object ShimGpuOverrides extends Logging { + + val shimExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = + Seq( + GpuOverrides.expr[Cast]( + "Convert a column of one type of data into another type", + new CastChecks(), + (cast, conf, p, r) => new CastExprMeta[Cast](cast, false, conf, p, r, + doFloatToIntCheck = false, stringToAnsiDate = false)), + GpuOverrides.expr[Average]( + "Average aggregate operator", + ExprChecks.fullAgg( + TypeSig.DOUBLE + TypeSig.DECIMAL_128, + TypeSig.DOUBLE + TypeSig.DECIMAL_128, + Seq(ParamCheck("input", + TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128, + TypeSig.cpuNumeric))), + (a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) { + override def tagAggForGpu(): Unit = { + // For Decimal Average the SUM adds a precision of 10 to avoid overflowing + // then it divides by the count with an output scale that is 4 more than the input + // scale. With how our divide works to match Spark, this means that we will need a + // precision of 5 more. So 38 - 10 - 5 = 23 + val dataType = a.child.dataType + dataType match { + case dt: DecimalType => + if (dt.precision > 23) { + if (conf.needDecimalGuarantees) { + willNotWorkOnGpu("GpuAverage cannot guarantee proper overflow checks for " + + s"a precision large than 23. 
The current precision is ${dt.precision}") + } else { + logWarning("Decimal overflow guarantees disabled for " + + s"Average(${a.child.dataType}) produces $dt with an " + + s"intermediate precision of ${dt.precision + 15}") + } + } + case _ => // NOOP + } + GpuOverrides.checkAndTagFloatAgg(dataType, conf, this) + } + + }), + GpuOverrides.expr[Abs]( + "Absolute value", + ExprChecks.unaryProjectAndAstInputMatchesOutput( + TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric, + TypeSig.cpuNumeric), + (a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) { + }), + GpuOverrides.expr[RegExpReplace]( + "RegExpReplace support for string literal input patterns", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))), + (a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)).disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), + GpuOverrides.expr[TimeSub]( + "Subtracts interval from timestamp", + ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ("start", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + ("interval", TypeSig.lit(TypeEnum.CALENDAR) + .withPsNote(TypeEnum.CALENDAR, "months not supported"), TypeSig.CALENDAR)), + (timeSub, conf, p, r) => new BinaryExprMeta[TimeSub](timeSub, conf, p, r) { + override def tagExprForGpu(): Unit = { + timeSub.interval match { + case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) => + if (intvl.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + case _ => + } + checkTimeZoneId(timeSub.timeZoneId) + } + }), + GpuOverrides.expr[ScalaUDF]( + "User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface " + + "to get better performance.", + ExprChecks.projectOnly( + GpuUserDefinedFunction.udfTypeSig, + TypeSig.all, + repeatingParamCheck = + Some(RepeatingParamCheck("param", GpuUserDefinedFunction.udfTypeSig, TypeSig.all))), + (expr, conf, p, r) => new ScalaUDFMetaBase(expr, conf, p, r) { + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + + val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq( + GpuOverrides.exec[FileSourceScanExec]( + "Reading data from files, often from Hive tables", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all), + (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) { + + // partition filters and data filters are not run on the GPU + override val childExprs: Seq[ExprMeta[_]] = Seq.empty + + override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this) + }), + GpuOverrides.exec[ArrowEvalPythonExec]( + "The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" + + " Java process and the Python process. 
It also supports scheduling GPU resources" + + " for the Python process when enabled", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(), + TypeSig.all), + (e, conf, p, r) => + new SparkPlanMeta[ArrowEvalPythonExec](e, conf, p, r) { + val udfs: Seq[BaseExprMeta[PythonUDF]] = + e.udfs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val resultAttrs: Seq[BaseExprMeta[Attribute]] = + e.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = udfs ++ resultAttrs + override def replaceMessage: String = "partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + }), + GpuOverrides.exec[FlatMapGroupsInPandasExec]( + "The backend for Flat Map Groups Pandas UDF, Accelerates the data transfer between the" + + " Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks(TypeSig.commonCudfTypes, TypeSig.all), + (flatPy, conf, p, r) => new GpuFlatMapGroupsInPandasExecMeta(flatPy, conf, p, r)), + GpuOverrides.exec[WindowInPandasExec]( + "The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between" + + " the Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled. For now it only supports row based window frame.", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested(TypeSig.commonCudfTypes), + TypeSig.all), + (winPy, conf, p, r) => new GpuWindowInPandasExecMetaBase(winPy, conf, p, r) { + override val windowExpressions: Seq[BaseExprMeta[NamedExpression]] = + winPy.windowExpression.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + }).disabledByDefault("it only supports row based frame for now"), + GpuOverrides.exec[AggregateInPandasExec]( + "The backend for an Aggregation Pandas UDF, this accelerates the data transfer between" + + " the Java process and the Python process. It also supports scheduling GPU resources" + + " for the Python process when enabled.", + ExecChecks(TypeSig.commonCudfTypes, TypeSig.all), + (aggPy, conf, p, r) => new GpuAggregateInPandasExecMeta(aggPy, conf, p, r)) + ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap +} + diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala new file mode 100644 index 00000000000..f36614d9fe6 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -0,0 +1,2173 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import java.io.{File, FileOutputStream} +import java.time.ZoneId + +import com.nvidia.spark.rapids.shims.v2.TypeSigUtil + +import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition} +import org.apache.spark.sql.types._ + +/** Trait of TypeSigUtil for different spark versions */ +trait TypeSigUtilBase { + + /** + * Check if this type of Spark-specific is supported by the plugin or not. + * @param check the Supported Types + * @param dataType the data type to be checked + * @return true if it is allowed else false. + */ + def isSupported(check: TypeEnum.ValueSet, dataType: DataType): Boolean + + /** + * Get all supported types for the spark-specific + * @return the all supported typ + */ + def getAllSupportedTypes: TypeEnum.ValueSet + + /** + * Return the reason why this type is not supported.\ + * @param check the Supported Types + * @param dataType the data type to be checked + * @param notSupportedReason the reason for not supporting + * @return the reason + */ + def reasonNotSupported( + check: TypeEnum.ValueSet, + dataType: DataType, + notSupportedReason: Seq[String]): Seq[String] + + /** + * Map DataType to TypeEnum + * @param dataType the data type to be mapped + * @return the TypeEnum + */ + def mapDataTypeToTypeEnum(dataType: DataType): TypeEnum.Value + + /** Get numeric and interval TypeSig */ + def getNumericAndInterval: TypeSig +} + +/** + * The level of support that the plugin has for a given type. Used for documentation generation. + */ +sealed abstract class SupportLevel { + def htmlTag: String + def text: String +} + +/** + * N/A neither spark nor the plugin supports this. + */ +object NotApplicable extends SupportLevel { + override def htmlTag: String = " " + override def text: String = "NA" +} + +/** + * Spark supports this but the plugin does not. + */ +object NotSupported extends SupportLevel { + override def htmlTag: String = s"$text" + override def text: String = "NS" +} + +/** + * Both Spark and the plugin support this. + */ +class Supported() extends SupportLevel { + override def htmlTag: String = s"$text" + override def text: String = "S" +} + +/** + * The plugin partially supports this type. + * @param missingChildTypes child types that are not supported + * @param needsLitWarning true if we need to warn that we only support a literal value when Spark + * does not. + * @param note any other notes we want to include about not complete support. + */ +class PartiallySupported( + val missingChildTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(), + val needsLitWarning: Boolean = false, + val note: Option[String] = None) extends SupportLevel { + override def htmlTag: String = { + val typeStr = if (missingChildTypes.isEmpty) { + None + } else { + Some("unsupported child types " + missingChildTypes.mkString(", ")) + } + val litOnly = if (needsLitWarning) { + Some("Literal value only") + } else { + None + } + val extraInfo = (note.toSeq ++ litOnly.toSeq ++ typeStr.toSeq).mkString(";
") + val allText = s"$text
$extraInfo" + s"$allText" + } + + // don't include the extra info in the supported text field for now + // as the qualification tool doesn't use it + override def text: String = "PS" +} + +/** + * The Supported Types. The TypeSig API should be preferred for this, except in a few cases when + * TypeSig asks for a TypeEnum. + */ +object TypeEnum extends Enumeration { + type TypeEnum = Value + + val BOOLEAN: Value = Value + val BYTE: Value = Value + val SHORT: Value = Value + val INT: Value = Value + val LONG: Value = Value + val FLOAT: Value = Value + val DOUBLE: Value = Value + val DATE: Value = Value + val TIMESTAMP: Value = Value + val STRING: Value = Value + val DECIMAL: Value = Value + val NULL: Value = Value + val BINARY: Value = Value + val CALENDAR: Value = Value + val ARRAY: Value = Value + val MAP: Value = Value + val STRUCT: Value = Value + val UDT: Value = Value + val DAYTIME: Value = Value + val YEARMONTH: Value = Value +} + +/** + * A type signature. This is a bit limited in what it supports right now, but can express + * a set of base types and a separate set of types that can be nested under the base types + * (child types). It can also express if a particular base type has to be a literal or not. + */ +final class TypeSig private( + private val initialTypes: TypeEnum.ValueSet, + private val maxAllowedDecimalPrecision: Int = GpuOverrides.DECIMAL64_MAX_PRECISION, + private val childTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(), + private val litOnlyTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(), + private val notes: Map[TypeEnum.Value, String] = Map.empty) { + + /** + * Add a literal restriction to the signature + * @param dataType the type that has to be literal. Will be added if it does not already exist. + * @return the new signature. + */ + def withLit(dataType: TypeEnum.Value): TypeSig = { + val it = initialTypes + dataType + val lt = litOnlyTypes + dataType + new TypeSig(it, maxAllowedDecimalPrecision, childTypes, lt, notes) + } + + /** + * All currently supported types can only be literal values. + * @return the new signature. + */ + def withAllLit(): TypeSig = { + // don't need to combine initialTypes with litOnlyTypes because litOnly should be a subset + new TypeSig(initialTypes, maxAllowedDecimalPrecision, childTypes, initialTypes, notes) + } + + /** + * Combine two type signatures together. Base types and child types will be the union of + * both as will limitations on literal values. + * @param other what to combine with. + * @return the new signature + */ + def + (other: TypeSig): TypeSig = { + val it = initialTypes ++ other.initialTypes + val nt = childTypes ++ other.childTypes + val lt = litOnlyTypes ++ other.litOnlyTypes + val dp = Math.max(maxAllowedDecimalPrecision, other.maxAllowedDecimalPrecision) + // TODO nested types is not always going to do what you want, so we might want to warn + val nts = notes ++ other.notes + new TypeSig(it, dp, nt, lt, nts) + } + + /** + * Remove a type signature. 
The reverse of + + * @param other what to remove + * @return the new signature + */ + def - (other: TypeSig): TypeSig = { + val it = initialTypes -- other.initialTypes + val nt = childTypes -- other.childTypes + val lt = litOnlyTypes -- other.litOnlyTypes + val nts = notes -- other.notes.keySet + new TypeSig(it, maxAllowedDecimalPrecision, nt, lt, nts) + } + + def intersect(other: TypeSig): TypeSig = { + val it = initialTypes & other.initialTypes + val nt = childTypes & other.childTypes + val lt = litOnlyTypes & other.initialTypes + val nts = notes.filterKeys(other.initialTypes) + new TypeSig(it, maxAllowedDecimalPrecision, nt, lt, nts) + } + + /** + * Add child types to this type signature. Note that these do not stack so if childTypes has + * child types too they are ignored. + * @param childTypes the basic types to add. + * @return the new type signature + */ + def nested(childTypes: TypeSig): TypeSig = { + val mp = Math.max(maxAllowedDecimalPrecision, childTypes.maxAllowedDecimalPrecision) + new TypeSig(initialTypes, mp, this.childTypes ++ childTypes.initialTypes, litOnlyTypes, notes) + } + + /** + * Update this type signature to be nested with the initial types too. + * @return the update type signature + */ + def nested(): TypeSig = + new TypeSig(initialTypes, maxAllowedDecimalPrecision, initialTypes ++ childTypes, + litOnlyTypes, notes) + + /** + * Add a note about a given type that marks it as partially supported. + * @param dataType the type this note is for. + * @param note the note itself + * @return the updated TypeSignature. + */ + def withPsNote(dataType: TypeEnum.Value, note: String): TypeSig = + new TypeSig(initialTypes + dataType, maxAllowedDecimalPrecision, childTypes, litOnlyTypes, + notes.+((dataType, note))) + + private def isSupportedType(dataType: TypeEnum.Value): Boolean = + initialTypes.contains(dataType) + + /** + * Given an expression tag the associated meta for it to be supported or not. + * + * @param meta the meta that gets marked for support or not. + * @param exprMeta the meta of expression to check against. + * @param name the name of the expression (typically a parameter name) + */ + def tagExprParam( + meta: RapidsMeta[_, _], + exprMeta: BaseExprMeta[_], + name: String, + willNotWork: String => Unit): Unit = { + val typeMeta = exprMeta.typeMeta + // This is for a parameter so skip it if there is no data type for the expression + typeMeta.dataType.foreach { dt => + val expr = exprMeta.wrapped.asInstanceOf[Expression] + + if (!isSupportedByPlugin(dt)) { + willNotWork(s"$name expression ${expr.getClass.getSimpleName} $expr " + + reasonNotSupported(dt).mkString("(", ", ", ")")) + } else if (isLitOnly(dt) && !GpuOverrides.isLit(expr)) { + willNotWork(s"$name only supports $dt if it is a literal value") + } + if (typeMeta.typeConverted) { + meta.addConvertedDataType(expr, typeMeta) + } + } + } + + /** + * Check if this type is supported by the plugin or not. + * @param dataType the data type to be checked + * @return true if it is allowed else false. 
+ */ + def isSupportedByPlugin(dataType: DataType): Boolean = + isSupported(initialTypes, dataType) + + private [this] def isLitOnly(dataType: DataType): Boolean = dataType match { + case BooleanType => litOnlyTypes.contains(TypeEnum.BOOLEAN) + case ByteType => litOnlyTypes.contains(TypeEnum.BYTE) + case ShortType => litOnlyTypes.contains(TypeEnum.SHORT) + case IntegerType => litOnlyTypes.contains(TypeEnum.INT) + case LongType => litOnlyTypes.contains(TypeEnum.LONG) + case FloatType => litOnlyTypes.contains(TypeEnum.FLOAT) + case DoubleType => litOnlyTypes.contains(TypeEnum.DOUBLE) + case DateType => litOnlyTypes.contains(TypeEnum.DATE) + case TimestampType => litOnlyTypes.contains(TypeEnum.TIMESTAMP) + case StringType => litOnlyTypes.contains(TypeEnum.STRING) + case _: DecimalType => litOnlyTypes.contains(TypeEnum.DECIMAL) + case NullType => litOnlyTypes.contains(TypeEnum.NULL) + case BinaryType => litOnlyTypes.contains(TypeEnum.BINARY) + case CalendarIntervalType => litOnlyTypes.contains(TypeEnum.CALENDAR) + case _: ArrayType => litOnlyTypes.contains(TypeEnum.ARRAY) + case _: MapType => litOnlyTypes.contains(TypeEnum.MAP) + case _: StructType => litOnlyTypes.contains(TypeEnum.STRUCT) + case _ => TypeSigUtil.isSupported(litOnlyTypes, dataType) + } + + def isSupportedBySpark(dataType: DataType): Boolean = + isSupported(initialTypes, dataType) + + private[this] def isSupported( + check: TypeEnum.ValueSet, + dataType: DataType): Boolean = + dataType match { + case BooleanType => check.contains(TypeEnum.BOOLEAN) + case ByteType => check.contains(TypeEnum.BYTE) + case ShortType => check.contains(TypeEnum.SHORT) + case IntegerType => check.contains(TypeEnum.INT) + case LongType => check.contains(TypeEnum.LONG) + case FloatType => check.contains(TypeEnum.FLOAT) + case DoubleType => check.contains(TypeEnum.DOUBLE) + case DateType => check.contains(TypeEnum.DATE) + case TimestampType if check.contains(TypeEnum.TIMESTAMP) => + TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) + case StringType => check.contains(TypeEnum.STRING) + case dt: DecimalType => + check.contains(TypeEnum.DECIMAL) && + dt.precision <= maxAllowedDecimalPrecision + case NullType => check.contains(TypeEnum.NULL) + case BinaryType => check.contains(TypeEnum.BINARY) + case CalendarIntervalType => check.contains(TypeEnum.CALENDAR) + case ArrayType(elementType, _) if check.contains(TypeEnum.ARRAY) => + isSupported(childTypes, elementType) + case MapType(keyType, valueType, _) if check.contains(TypeEnum.MAP) => + isSupported(childTypes, keyType) && + isSupported(childTypes, valueType) + case StructType(fields) if check.contains(TypeEnum.STRUCT) => + fields.map(_.dataType).forall { t => + isSupported(childTypes, t) + } + case _ => TypeSigUtil.isSupported(check, dataType) + } + + def reasonNotSupported(dataType: DataType): Seq[String] = + reasonNotSupported(initialTypes, dataType, isChild = false) + + private[this] def withChild(isChild: Boolean, msg: String): String = if (isChild) { + "child " + msg + } else { + msg + } + + private[this] def basicNotSupportedMessage(dataType: DataType, + te: TypeEnum.Value, check: TypeEnum.ValueSet, isChild: Boolean): Seq[String] = { + if (check.contains(te)) { + Seq.empty + } else { + Seq(withChild(isChild, s"$dataType is not supported")) + } + } + + private[this] def reasonNotSupported( + check: TypeEnum.ValueSet, + dataType: DataType, + isChild: Boolean): Seq[String] = + dataType match { + case BooleanType => + basicNotSupportedMessage(dataType, TypeEnum.BOOLEAN, check, isChild) + case 
ByteType => + basicNotSupportedMessage(dataType, TypeEnum.BYTE, check, isChild) + case ShortType => + basicNotSupportedMessage(dataType, TypeEnum.SHORT, check, isChild) + case IntegerType => + basicNotSupportedMessage(dataType, TypeEnum.INT, check, isChild) + case LongType => + basicNotSupportedMessage(dataType, TypeEnum.LONG, check, isChild) + case FloatType => + basicNotSupportedMessage(dataType, TypeEnum.FLOAT, check, isChild) + case DoubleType => + basicNotSupportedMessage(dataType, TypeEnum.DOUBLE, check, isChild) + case DateType => + basicNotSupportedMessage(dataType, TypeEnum.DATE, check, isChild) + case TimestampType => + if (check.contains(TypeEnum.TIMESTAMP) && + (!TypeChecks.areTimestampsSupported(ZoneId.systemDefault()))) { + Seq(withChild(isChild, s"$dataType is not supported when the JVM system " + + s"timezone is set to ${ZoneId.systemDefault()}. Set the timezone to UTC to enable " + + s"$dataType support")) + } else { + basicNotSupportedMessage(dataType, TypeEnum.TIMESTAMP, check, isChild) + } + case StringType => + basicNotSupportedMessage(dataType, TypeEnum.STRING, check, isChild) + case dt: DecimalType => + if (check.contains(TypeEnum.DECIMAL)) { + var reasons = Seq[String]() + if (dt.precision > maxAllowedDecimalPrecision) { + reasons ++= Seq(withChild(isChild, s"$dataType precision is larger " + + s"than we support $maxAllowedDecimalPrecision")) + } + reasons + } else { + basicNotSupportedMessage(dataType, TypeEnum.DECIMAL, check, isChild) + } + case NullType => + basicNotSupportedMessage(dataType, TypeEnum.NULL, check, isChild) + case BinaryType => + basicNotSupportedMessage(dataType, TypeEnum.BINARY, check, isChild) + case CalendarIntervalType => + basicNotSupportedMessage(dataType, TypeEnum.CALENDAR, check, isChild) + case ArrayType(elementType, _) => + if (check.contains(TypeEnum.ARRAY)) { + reasonNotSupported(childTypes, elementType, isChild = true) + } else { + basicNotSupportedMessage(dataType, TypeEnum.ARRAY, check, isChild) + } + case MapType(keyType, valueType, _) => + if (check.contains(TypeEnum.MAP)) { + reasonNotSupported(childTypes, keyType, isChild = true) ++ + reasonNotSupported(childTypes, valueType, isChild = true) + } else { + basicNotSupportedMessage(dataType, TypeEnum.MAP, check, isChild) + } + case StructType(fields) => + if (check.contains(TypeEnum.STRUCT)) { + fields.flatMap { sf => + reasonNotSupported(childTypes, sf.dataType, isChild = true) + } + } else { + basicNotSupportedMessage(dataType, TypeEnum.STRUCT, check, isChild) + } + case _ => TypeSigUtil.reasonNotSupported(check, dataType, + Seq(withChild(isChild, s"$dataType is not supported"))) + } + + def areAllSupportedByPlugin(types: Seq[DataType]): Boolean = + types.forall(isSupportedByPlugin) + + /** + * Get the level of support for a given type compared to what Spark supports. + * Used for documentation. + */ + def getSupportLevel(dataType: TypeEnum.Value, allowed: TypeSig): SupportLevel = { + if (!allowed.isSupportedType(dataType)) { + NotApplicable + } else if (!isSupportedType(dataType)) { + NotSupported + } else { + var note = notes.get(dataType) + val needsLitWarning = litOnlyTypes.contains(dataType) && + !allowed.litOnlyTypes.contains(dataType) + val lowerPrecision = + dataType == TypeEnum.DECIMAL && maxAllowedDecimalPrecision < DecimalType.MAX_PRECISION + if (lowerPrecision) { + val msg = s"max DECIMAL precision of $maxAllowedDecimalPrecision" + note = if (note.isEmpty) { + Some(msg) + } else { + Some(note.get + ";
" + msg) + } + } + + if (dataType == TypeEnum.TIMESTAMP) { + val msg = s"UTC is only supported TZ for TIMESTAMP" + note = if (note.isEmpty) { + Some(msg) + } else { + Some(note.get + ";
" + msg) + } + } + + dataType match { + case TypeEnum.ARRAY | TypeEnum.MAP | TypeEnum.STRUCT => + val subTypeLowerPrecision = childTypes.contains(TypeEnum.DECIMAL) && + maxAllowedDecimalPrecision < DecimalType.MAX_PRECISION + if (subTypeLowerPrecision) { + val msg = s"max child DECIMAL precision of $maxAllowedDecimalPrecision" + note = if (note.isEmpty) { + Some(msg) + } else { + Some(note.get + ";
" + msg) + } + } + + if (childTypes.contains(TypeEnum.TIMESTAMP)) { + val msg = s"UTC is only supported TZ for child TIMESTAMP" + note = if (note.isEmpty) { + Some(msg) + } else { + Some(note.get + ";
" + msg) + } + } + + val subTypesMissing = allowed.childTypes -- childTypes + if (subTypesMissing.isEmpty && note.isEmpty && !needsLitWarning) { + new Supported() + } else { + new PartiallySupported(missingChildTypes = subTypesMissing, + needsLitWarning = needsLitWarning, + note = note) + } + case _ if note.isDefined || needsLitWarning => + new PartiallySupported(needsLitWarning = needsLitWarning, note = note) + case _ => + new Supported() + } + } + } +} + +object TypeSig { + /** + * Create a TypeSig that only supports a literal of the given type. + */ + def lit(dataType: TypeEnum.Value): TypeSig = + TypeSig.none.withLit(dataType) + + /** + * Create a TypeSig that has partial support for the given type. + */ + def psNote(dataType: TypeEnum.Value, note: String): TypeSig = + TypeSig.none.withPsNote(dataType, note) + + def decimal(maxPrecision: Int): TypeSig = + new TypeSig(TypeEnum.ValueSet(TypeEnum.DECIMAL), maxPrecision) + + /** + * All types nested and not nested + */ + val all: TypeSig = { + val allSupportedTypes = TypeSigUtil.getAllSupportedTypes() + new TypeSig(allSupportedTypes, DecimalType.MAX_PRECISION, allSupportedTypes) + } + + /** + * No types supported at all + */ + val none: TypeSig = new TypeSig(TypeEnum.ValueSet()) + + val BOOLEAN: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.BOOLEAN)) + val BYTE: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.BYTE)) + val SHORT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.SHORT)) + val INT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.INT)) + val LONG: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.LONG)) + val FLOAT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.FLOAT)) + val DOUBLE: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.DOUBLE)) + val DATE: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.DATE)) + val TIMESTAMP: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.TIMESTAMP)) + val STRING: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.STRING)) + val DECIMAL_64: TypeSig = decimal(GpuOverrides.DECIMAL64_MAX_PRECISION) + + /** + * Full support for 128 bit DECIMAL. In the future we expect to have other types with + * slightly less than full DECIMAL support. This are things like math operations where + * we cannot replicate the overflow behavior of Spark. These will be added when needed. + */ + val DECIMAL_128: TypeSig = decimal(GpuOverrides.DECIMAL128_MAX_PRECISION) + + val NULL: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.NULL)) + val BINARY: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.BINARY)) + val CALENDAR: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.CALENDAR)) + /** + * ARRAY type support, but not very useful on its own because no child types under + * it are supported + */ + val ARRAY: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.ARRAY)) + /** + * MAP type support, but not very useful on its own because no child types under + * it are supported + */ + val MAP: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.MAP)) + /** + * STRUCT type support, but only matches empty structs unless you add child types to it. 
+ */ + val STRUCT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.STRUCT)) + /** + * User Defined Type (We don't support these in the plugin yet) + */ + val UDT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.UDT)) + + /** + * DayTimeIntervalType of Spark 3.2.0+ support + */ + val DAYTIME: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.DAYTIME)) + + /** + * YearMonthIntervalType of Spark 3.2.0+ support + */ + val YEARMONTH: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.YEARMONTH)) + + /** + * A signature for types that are generally supported by the plugin/CUDF. Please make sure to + * check what Spark actually supports instead of blindly using this in a signature. + */ + val commonCudfTypes: TypeSig = BOOLEAN + BYTE + SHORT + INT + LONG + FLOAT + DOUBLE + DATE + + TIMESTAMP + STRING + + /** + * All floating point types + */ + val fp: TypeSig = FLOAT + DOUBLE + + /** + * All integer types + */ + val integral: TypeSig = BYTE + SHORT + INT + LONG + + /** + * All numeric types fp + integral + DECIMAL_64 + */ + val gpuNumeric: TypeSig = integral + fp + DECIMAL_128 + + /** + * All numeric types fp + integral + DECIMAL_128 + */ + val cpuNumeric: TypeSig = integral + fp + DECIMAL_128 + + /** + * All values that correspond to Spark's AtomicType but supported by GPU + */ + val gpuAtomics: TypeSig = gpuNumeric + BINARY + BOOLEAN + DATE + STRING + TIMESTAMP + + /** + * All values that correspond to Spark's AtomicType + */ + val cpuAtomics: TypeSig = cpuNumeric + BINARY + BOOLEAN + DATE + STRING + TIMESTAMP + + /** + * numeric + CALENDAR but only for GPU + */ + val gpuNumericAndInterval: TypeSig = gpuNumeric + CALENDAR + + /** + * numeric + CALENDAR + */ + val numericAndInterval: TypeSig = TypeSigUtil.getNumericAndInterval() + + /** + * All types that CUDF supports sorting/ordering on. + */ + val gpuOrderable: TypeSig = (BOOLEAN + BYTE + SHORT + INT + LONG + FLOAT + DOUBLE + DATE + + TIMESTAMP + STRING + DECIMAL_64 + NULL + STRUCT).nested() + + /** + * All types that Spark supports sorting/ordering on (really everything but MAP) + */ + val orderable: TypeSig = (BOOLEAN + BYTE + SHORT + INT + LONG + FLOAT + DOUBLE + DATE + + TIMESTAMP + STRING + DECIMAL_128 + NULL + BINARY + CALENDAR + ARRAY + STRUCT + + UDT).nested() + + /** + * All types that Spark supports for comparison operators (really everything but MAP according + * to https://spark.apache.org/docs/latest/api/sql/index.html#_12), e.g. "<=>", "=", "==". + */ + val comparable: TypeSig = (BOOLEAN + BYTE + SHORT + INT + LONG + FLOAT + DOUBLE + DATE + + TIMESTAMP + STRING + DECIMAL_128 + NULL + BINARY + CALENDAR + ARRAY + STRUCT + + UDT).nested() + + /** + * Different types of Pandas UDF support different sets of output type. Please refer to + * https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98 + * for more details. + * + * It is impossible to specify the exact type signature for each Pandas UDF type in a single + * expression 'PythonUDF'. + * + * So here comes the union of all the sets of supported type, to cover all the cases. 
+ */ + val unionOfPandasUdfOut: TypeSig = + (commonCudfTypes + BINARY + DECIMAL_64 + NULL + ARRAY + MAP).nested() + STRUCT + + /** All types that can appear in AST expressions */ + val astTypes: TypeSig = BOOLEAN + integral + fp + TIMESTAMP + DATE + + /** All AST types that work for comparisons */ + val comparisonAstTypes: TypeSig = astTypes - fp + + /** All types that can appear in an implicit cast AST expression */ + val implicitCastsAstTypes: TypeSig = astTypes - BYTE - SHORT + + def getDataType(expr: Expression): Option[DataType] = { + try { + Some(expr.dataType) + } catch { + case _: java.lang.UnsupportedOperationException =>None + } + } +} + +abstract class TypeChecks[RET] { + def tag(meta: RapidsMeta[_, _]): Unit + + def support(dataType: TypeEnum.Value): RET + + val shown: Boolean = true + + private def stringifyTypeAttributeMap(groupedByType: Map[DataType, Set[String]]): String = { + groupedByType.map { case (dataType, nameSet) => + dataType + " " + nameSet.mkString("[", ", ", "]") + }.mkString(", ") + } + + protected def tagUnsupportedTypes( + meta: RapidsMeta[_, _], + sig: TypeSig, + fields: Seq[StructField], + msgFormat: String + ): Unit = { + val unsupportedTypes: Map[DataType, Set[String]] = fields + .filterNot(attr => sig.isSupportedByPlugin(attr.dataType)) + .groupBy(_.dataType) + .mapValues(_.map(_.name).toSet) + + if (unsupportedTypes.nonEmpty) { + meta.willNotWorkOnGpu(msgFormat.format(stringifyTypeAttributeMap(unsupportedTypes))) + } + } +} + +object TypeChecks { + /** + * Check if the time zone passed is supported by plugin. + */ + def areTimestampsSupported(timezoneId: ZoneId): Boolean = { + timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID + } +} + +/** + * Checks a set of named inputs to an SparkPlan node against a TypeSig + */ +case class InputCheck(cudf: TypeSig, spark: TypeSig, notes: List[String] = List.empty) + +/** + * Checks a single parameter by position against a TypeSig + */ +case class ParamCheck(name: String, cudf: TypeSig, spark: TypeSig) + +/** + * Checks the type signature for a parameter that repeats (Can only be used at the end of a list + * of position parameters) + */ +case class RepeatingParamCheck(name: String, cudf: TypeSig, spark: TypeSig) + +/** + * Checks an expression that have input parameters and a single output. This is intended to be + * given for a specific ExpressionContext. If your expression does not meet this pattern you may + * need to create a custom ExprChecks instance. 
+ */ +case class ContextChecks( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None) + extends TypeChecks[Map[String, SupportLevel]] { + + def tagAst(exprMeta: BaseExprMeta[_]): Unit = { + tagBase(exprMeta, exprMeta.willNotWorkInAst) + } + + override def tag(rapidsMeta: RapidsMeta[_, _]): Unit = { + tagBase(rapidsMeta, rapidsMeta.willNotWorkOnGpu) + } + + private[this] def tagBase(rapidsMeta: RapidsMeta[_, _], willNotWork: String => Unit): Unit = { + val meta = rapidsMeta.asInstanceOf[BaseExprMeta[_]] + val expr = meta.wrapped.asInstanceOf[Expression] + meta.typeMeta.dataType match { + case Some(dt: DataType) => + if (!outputCheck.isSupportedByPlugin(dt)) { + willNotWork(s"expression ${expr.getClass.getSimpleName} $expr " + + s"produces an unsupported type $dt") + } + if (meta.typeMeta.typeConverted) { + meta.addConvertedDataType(expr, meta.typeMeta) + } + case None => + if (!meta.ignoreUnsetDataTypes) { + willNotWork(s"expression ${expr.getClass.getSimpleName} $expr " + + s" does not have a corresponding dataType.") + } + } + + val children = meta.childExprs + val fixedChecks = paramCheck.toArray + assert (fixedChecks.length <= children.length, + s"${expr.getClass.getSimpleName} expected at least ${fixedChecks.length} but " + + s"found ${children.length}") + fixedChecks.zipWithIndex.foreach { case (check, i) => + check.cudf.tagExprParam(meta, children(i), check.name, willNotWork) + } + if (repeatingParamCheck.isEmpty) { + assert(fixedChecks.length == children.length, + s"${expr.getClass.getSimpleName} expected ${fixedChecks.length} but " + + s"found ${children.length}") + } else { + val check = repeatingParamCheck.get + (fixedChecks.length until children.length).foreach { i => + check.cudf.tagExprParam(meta, children(i), check.name, willNotWork) + } + } + } + + override def support(dataType: TypeEnum.Value): Map[String, SupportLevel] = { + val fixed = paramCheck.map(check => + (check.name, check.cudf.getSupportLevel(dataType, check.spark))) + val variable = repeatingParamCheck.map(check => + (check.name, check.cudf.getSupportLevel(dataType, check.spark))) + val output = ("result", outputCheck.getSupportLevel(dataType, sparkOutputSig)) + + (fixed ++ variable ++ Seq(output)).toMap + } +} + +/** + * Checks for either a read or a write of a given file format. + */ +class FileFormatChecks private ( + sig: TypeSig, + sparkSig: TypeSig) + extends TypeChecks[SupportLevel] { + + def tag(meta: RapidsMeta[_, _], + schema: StructType, + fileType: FileFormatType, + op: FileFormatOp): Unit = { + tagUnsupportedTypes(meta, sig, schema.fields, + s"unsupported data types %s in $op for $fileType") + } + + override def support(dataType: TypeEnum.Value): SupportLevel = + sig.getSupportLevel(dataType, sparkSig) + + override def tag(meta: RapidsMeta[_, _]): Unit = + throw new IllegalStateException("Internal Error not supported") + + def getFileFormat: TypeSig = sig +} + +object FileFormatChecks { + /** + * File format checks with separate read and write signatures for cudf. + */ + def apply( + cudfRead: TypeSig, + cudfWrite: TypeSig, + sparkSig: TypeSig): Map[FileFormatOp, FileFormatChecks] = Map( + (ReadFileOp, new FileFormatChecks(cudfRead, sparkSig)), + (WriteFileOp, new FileFormatChecks(cudfWrite, sparkSig)) + ) + + /** + * File format checks where read and write have the same signature for cudf. 
+ */ + def apply( + cudfReadWrite: TypeSig, + sparkSig: TypeSig): Map[FileFormatOp, FileFormatChecks] = + apply(cudfReadWrite, cudfReadWrite, sparkSig) + + def tag(meta: RapidsMeta[_, _], + schema: StructType, + fileType: FileFormatType, + op: FileFormatOp): Unit = { + GpuOverrides.fileFormats(fileType)(op).tag(meta, schema, fileType, op) + } +} + +/** + * Checks the input and output types supported by a SparkPlan node. We don't currently separate + * input checks from output checks. We can add this in if something needs it. + * + * The namedChecks map can be used to provide checks for specific groups of expressions. + */ +class ExecChecks private( + check: TypeSig, + sparkSig: TypeSig, + val namedChecks: Map[String, InputCheck], + override val shown: Boolean = true) + extends TypeChecks[Map[String, SupportLevel]] { + + override def tag(rapidsMeta: RapidsMeta[_, _]): Unit = { + val meta = rapidsMeta.asInstanceOf[SparkPlanMeta[_]] + + // expression.toString to capture ids in not-on-GPU tags + def toStructField(a: Attribute) = StructField(name = a.toString(), dataType = a.dataType) + + tagUnsupportedTypes(meta, check, meta.outputAttributes.map(toStructField), + "unsupported data types in output: %s") + tagUnsupportedTypes(meta, check, meta.childPlans.flatMap(_.outputAttributes.map(toStructField)), + "unsupported data types in input: %s") + + val namedChildExprs = meta.namedChildExprs + + val missing = namedChildExprs.keys.filterNot(namedChecks.contains) + if (missing.nonEmpty) { + throw new IllegalStateException(s"${meta.getClass.getSimpleName} " + + s"is missing ExecChecks for ${missing.mkString(",")}") + } + + namedChecks.foreach { + case (fieldName, pc) => + val fieldMeta = namedChildExprs(fieldName) + .flatMap(_.typeMeta.dataType) + .zipWithIndex + .map(t => StructField(s"c${t._2}", t._1)) + tagUnsupportedTypes(meta, pc.cudf, fieldMeta, + s"unsupported data types in '$fieldName': %s") + } + } + + override def support(dataType: TypeEnum.Value): Map[String, SupportLevel] = { + val groups = namedChecks.map { case (name, pc) => + (name, pc.cudf.getSupportLevel(dataType, pc.spark)) + } + groups ++ Map("Input/Output" -> check.getSupportLevel(dataType, sparkSig)) + } + + def supportNotes: Map[String, List[String]] = { + namedChecks.map { case (name, pc) => + (name, pc.notes) + }.filter { + case (_, notes) => notes.nonEmpty + } + } +} + +/** + * gives users an API to create ExecChecks. 
+ */ +object ExecChecks { + def apply(check: TypeSig, sparkSig: TypeSig) : ExecChecks = { + new ExecChecks(check, sparkSig, Map.empty) + } + + def apply(check: TypeSig, + sparkSig: TypeSig, + namedChecks: Map[String, InputCheck]): ExecChecks = { + new ExecChecks(check, sparkSig, namedChecks) + } + + def hiddenHack(): ExecChecks = { + new ExecChecks(TypeSig.all, TypeSig.all, Map.empty, shown = false) + } +} + +/** + * Base class all Partition checks must follow + */ +abstract class PartChecks extends TypeChecks[Map[String, SupportLevel]] + +case class PartChecksImpl( + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None) + extends PartChecks { + + override def tag(meta: RapidsMeta[_, _]): Unit = { + val part = meta.wrapped + val children = meta.childExprs + + val fixedChecks = paramCheck.toArray + assert (fixedChecks.length <= children.length, + s"${part.getClass.getSimpleName} expected at least ${fixedChecks.length} but " + + s"found ${children.length}") + fixedChecks.zipWithIndex.foreach { case (check, i) => + check.cudf.tagExprParam(meta, children(i), check.name, meta.willNotWorkOnGpu) + } + if (repeatingParamCheck.isEmpty) { + assert(fixedChecks.length == children.length, + s"${part.getClass.getSimpleName} expected ${fixedChecks.length} but " + + s"found ${children.length}") + } else { + val check = repeatingParamCheck.get + (fixedChecks.length until children.length).foreach { i => + check.cudf.tagExprParam(meta, children(i), check.name, meta.willNotWorkOnGpu) + } + } + } + + override def support(dataType: TypeEnum.Value): Map[String, SupportLevel] = { + val fixed = paramCheck.map(check => + (check.name, check.cudf.getSupportLevel(dataType, check.spark))) + val variable = repeatingParamCheck.map(check => + (check.name, check.cudf.getSupportLevel(dataType, check.spark))) + + (fixed ++ variable).toMap + } +} + +object PartChecks { + def apply(repeatingParamCheck: RepeatingParamCheck): PartChecks = + PartChecksImpl(Seq.empty, Some(repeatingParamCheck)) + + def apply(): PartChecks = PartChecksImpl() +} + +/** + * Base class all Expression checks must follow. + */ +abstract class ExprChecks extends TypeChecks[Map[ExpressionContext, Map[String, SupportLevel]]] { + /** + * Tag this for AST or not. + */ + def tagAst(meta: BaseExprMeta[_]): Unit +} + +case class ExprChecksImpl(contexts: Map[ExpressionContext, ContextChecks]) + extends ExprChecks { + override def tag(meta: RapidsMeta[_, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + val checks = contexts.get(context) + if (checks.isEmpty) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } else { + checks.get.tag(meta) + } + } + + override def support( + dataType: TypeEnum.Value): Map[ExpressionContext, Map[String, SupportLevel]] = { + contexts.map { + case (expContext: ExpressionContext, check: ContextChecks) => + (expContext, check.support(dataType)) + } + } + + override def tagAst(exprMeta: BaseExprMeta[_]): Unit = { + val checks = contexts.get(AstExprContext) + if (checks.isEmpty) { + exprMeta.willNotWorkInAst(AstExprContext.notSupportedMsg) + } else { + checks.get.tagAst(exprMeta) + } + } +} + +/** + * This is specific to CaseWhen, because it does not follow the typical parameter convention. 
+ */
+object CaseWhenCheck extends ExprChecks {
+  val check: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+    TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested()
+
+  val sparkSig: TypeSig = TypeSig.all
+
+  override def tagAst(meta: BaseExprMeta[_]): Unit = {
+    meta.willNotWorkInAst(AstExprContext.notSupportedMsg)
+    // when this supports AST tagBase(exprMeta, meta.willNotWorkInAst)
+  }
+
+  override def tag(meta: RapidsMeta[_, _]): Unit = {
+    val exprMeta = meta.asInstanceOf[BaseExprMeta[_]]
+    val context = exprMeta.context
+    if (context != ProjectExprContext) {
+      meta.willNotWorkOnGpu(s"this is not supported in the $context context")
+    } else {
+      tagBase(exprMeta, meta.willNotWorkOnGpu)
+    }
+  }
+
+  private[this] def tagBase(exprMeta: BaseExprMeta[_], willNotWork: String => Unit): Unit = {
+    // children of CaseWhen: branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue (Optional)
+    //
+    // The length of children will be odd if elseValue is not None, which means we can detect
+    // both branch pair and possible elseValue via a size 2 grouped iterator.
+    exprMeta.childExprs.grouped(2).foreach {
+      case Seq(pred, value) =>
+        TypeSig.BOOLEAN.tagExprParam(exprMeta, pred, "predicate", willNotWork)
+        check.tagExprParam(exprMeta, value, "value", willNotWork)
+      case Seq(elseValue) =>
+        check.tagExprParam(exprMeta, elseValue, "else", willNotWork)
+    }
+  }
+
+  override def support(dataType: TypeEnum.Value):
+      Map[ExpressionContext, Map[String, SupportLevel]] = {
+    val projectSupport = check.getSupportLevel(dataType, sparkSig)
+    val projectPredSupport = TypeSig.BOOLEAN.getSupportLevel(dataType, TypeSig.BOOLEAN)
+    Map((ProjectExprContext,
+      Map(
+        ("predicate", projectPredSupport),
+        ("value", projectSupport),
+        ("result", projectSupport))))
+  }
+}
+
+/**
+ * This is specific to WindowSpec, because it does not follow the typical parameter convention.
+ */
+ */ +object WindowSpecCheck extends ExprChecks { + val check: TypeSig = + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + val sparkSig: TypeSig = TypeSig.all + + override def tagAst(meta: BaseExprMeta[_]): Unit = { + meta.willNotWorkInAst(AstExprContext.notSupportedMsg) + // when this supports AST tagBase(exprMeta, meta.willNotWorkInAst) + } + + override def tag(meta: RapidsMeta[_, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + if (context != ProjectExprContext) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } else { + tagBase(exprMeta, meta.willNotWorkOnGpu) + } + } + + private [this] def tagBase(exprMeta: BaseExprMeta[_], willNotWork: String => Unit): Unit = { + val win = exprMeta.wrapped.asInstanceOf[WindowSpecDefinition] + // children of WindowSpecDefinition: partitionSpec ++ orderSpec :+ frameSpecification + win.partitionSpec.indices.foreach(i => + check.tagExprParam(exprMeta, exprMeta.childExprs(i), "partition", willNotWork)) + val partSize = win.partitionSpec.length + win.orderSpec.indices.foreach(i => + check.tagExprParam(exprMeta, exprMeta.childExprs(i + partSize), "order", + willNotWork)) + } + + override def support(dataType: TypeEnum.Value): + Map[ExpressionContext, Map[String, SupportLevel]] = { + val projectSupport = check.getSupportLevel(dataType, sparkSig) + Map((ProjectExprContext, + Map( + ("partition", projectSupport), + ("value", projectSupport), + ("result", projectSupport)))) + } +} + +object CreateMapCheck extends ExprChecks { + + // Spark supports all types except for Map for key (Map is not supported + // even in child types) + private val keySig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.STRUCT).nested() + + private val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested() + + override def tagAst(meta: BaseExprMeta[_]): Unit = { + meta.willNotWorkInAst("CreateMap is not supported by AST") + } + + override def tag(meta: RapidsMeta[_, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + if (context != ProjectExprContext) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } + } + + override def support( + dataType: TypeEnum.Value): Map[ExpressionContext, Map[String, SupportLevel]] = { + Map((ProjectExprContext, + Map( + ("key", keySig.getSupportLevel(dataType, keySig)), + ("value", valueSig.getSupportLevel(dataType, valueSig))))) + } +} + + +/** + * A check for CreateNamedStruct. The parameter values alternate between one type and another. + * If this pattern shows up again we can make this more generic at that point. 
+ */ +object CreateNamedStructCheck extends ExprChecks { + val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING) + val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING) + val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested() + val sparkValueSig: TypeSig = TypeSig.all + val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig) + val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig) + + override def tagAst(meta: BaseExprMeta[_]): Unit = { + meta.willNotWorkInAst(AstExprContext.notSupportedMsg) + // when this supports AST tagBase(exprMeta, meta.willNotWorkInAst) + } + + override def tag(meta: RapidsMeta[_, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + if (context != ProjectExprContext) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } else { + tagBase(exprMeta, meta.willNotWorkOnGpu) + } + } + + private[this] def tagBase(exprMeta: BaseExprMeta[_], willNotWork: String => Unit): Unit = { + exprMeta.childExprs.grouped(2).foreach { + case Seq(nameMeta, valueMeta) => + nameSig.tagExprParam(exprMeta, nameMeta, "name", willNotWork) + valueSig.tagExprParam(exprMeta, valueMeta, "value", willNotWork) + } + exprMeta.typeMeta.dataType.foreach { dt => + if (!resultSig.isSupportedByPlugin(dt)) { + willNotWork(s"unsupported data type in output: $dt") + } + } + } + + override def support(dataType: TypeEnum.Value): + Map[ExpressionContext, Map[String, SupportLevel]] = { + val nameProjectSupport = nameSig.getSupportLevel(dataType, sparkNameSig) + val valueProjectSupport = valueSig.getSupportLevel(dataType, sparkValueSig) + val resultProjectSupport = resultSig.getSupportLevel(dataType, sparkResultSig) + Map((ProjectExprContext, + Map( + ("name", nameProjectSupport), + ("value", valueProjectSupport), + ("result", resultProjectSupport)))) + } +} + +class CastChecks extends ExprChecks { + // Don't show this with other operators show it in a different location + override val shown: Boolean = false + + // When updating these please check child classes too + import TypeSig._ + val nullChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + DATE + STRING + + NULL + DECIMAL_128 + val sparkNullSig: TypeSig = all + + val booleanChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + STRING + DECIMAL_128 + val sparkBooleanSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + + val integralChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + + BINARY + val sparkIntegralSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + BINARY + + val fpToStringPsNote: String = s"Conversion may produce different results and requires " + + s"${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to be true." 
+ val fpChecks: TypeSig = (gpuNumeric + BOOLEAN + TIMESTAMP + STRING) + .withPsNote(TypeEnum.STRING, fpToStringPsNote) + val sparkFpSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + + val dateChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + DATE + STRING + val sparkDateSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + DATE + STRING + + val timestampChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + DATE + STRING + val sparkTimestampSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + DATE + STRING + + val stringChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + DATE + STRING + + BINARY + val sparkStringSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + DATE + CALENDAR + STRING + BINARY + + val binaryChecks: TypeSig = none + val sparkBinarySig: TypeSig = STRING + BINARY + + val decimalChecks: TypeSig = gpuNumeric + STRING + val sparkDecimalSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + + val calendarChecks: TypeSig = none + val sparkCalendarSig: TypeSig = CALENDAR + STRING + + val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " + + "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128 + NULL + + ARRAY + BINARY + STRUCT + MAP) + + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " + + "desired child type(s)") + + val sparkArraySig: TypeSig = STRING + ARRAY.nested(all) + + val mapChecks: TypeSig = MAP.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + + STRUCT + MAP) + + psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " + + "desired child types") + val sparkMapSig: TypeSig = STRING + MAP.nested(all) + + val structChecks: TypeSig = psNote(TypeEnum.STRING, "the struct's children must also support " + + "being cast to string") + + STRUCT.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT + MAP) + + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + + "desired child type(s)") + val sparkStructSig: TypeSig = STRING + STRUCT.nested(all) + + val udtChecks: TypeSig = none + val sparkUdtSig: TypeSig = STRING + UDT + + val daytimeChecks: TypeSig = none + val sparkDaytimeChecks: TypeSig = DAYTIME + STRING + + val yearmonthChecks: TypeSig = none + val sparkYearmonthChecks: TypeSig = YEARMONTH + STRING + + private[this] def getChecksAndSigs(from: DataType): (TypeSig, TypeSig) = from match { + case NullType => (nullChecks, sparkNullSig) + case BooleanType => (booleanChecks, sparkBooleanSig) + case ByteType | ShortType | IntegerType | LongType => (integralChecks, sparkIntegralSig) + case FloatType | DoubleType => (fpChecks, sparkFpSig) + case DateType => (dateChecks, sparkDateSig) + case TimestampType => (timestampChecks, sparkTimestampSig) + case StringType => (stringChecks, sparkStringSig) + case BinaryType => (binaryChecks, sparkBinarySig) + case _: DecimalType => (decimalChecks, sparkDecimalSig) + case CalendarIntervalType => (calendarChecks, sparkCalendarSig) + case _: ArrayType => (arrayChecks, sparkArraySig) + case _: MapType => (mapChecks, sparkMapSig) + case _: StructType => (structChecks, sparkStructSig) + case _ => getChecksAndSigs(TypeSigUtil.mapDataTypeToTypeEnum(from)) + } + + private[this] def getChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) = from match { + case TypeEnum.NULL => (nullChecks, sparkNullSig) + case TypeEnum.BOOLEAN => (booleanChecks, sparkBooleanSig) + case TypeEnum.BYTE | TypeEnum.SHORT | TypeEnum.INT | TypeEnum.LONG => + (integralChecks, 
sparkIntegralSig) + case TypeEnum.FLOAT | TypeEnum.DOUBLE => (fpChecks, sparkFpSig) + case TypeEnum.DATE => (dateChecks, sparkDateSig) + case TypeEnum.TIMESTAMP => (timestampChecks, sparkTimestampSig) + case TypeEnum.STRING => (stringChecks, sparkStringSig) + case TypeEnum.BINARY => (binaryChecks, sparkBinarySig) + case TypeEnum.DECIMAL => (decimalChecks, sparkDecimalSig) + case TypeEnum.CALENDAR => (calendarChecks, sparkCalendarSig) + case TypeEnum.ARRAY => (arrayChecks, sparkArraySig) + case TypeEnum.MAP => (mapChecks, sparkMapSig) + case TypeEnum.STRUCT => (structChecks, sparkStructSig) + case TypeEnum.UDT => (udtChecks, sparkUdtSig) + case TypeEnum.DAYTIME => (daytimeChecks, sparkDaytimeChecks) + case TypeEnum.YEARMONTH => (yearmonthChecks, sparkYearmonthChecks) + } + + override def tagAst(meta: BaseExprMeta[_]): Unit = { + meta.willNotWorkInAst(AstExprContext.notSupportedMsg) + // when this supports AST tagBase(meta, meta.willNotWorkInAst) + } + + override def tag(meta: RapidsMeta[_, _]): Unit = { + val exprMeta = meta.asInstanceOf[BaseExprMeta[_]] + val context = exprMeta.context + if (context != ProjectExprContext) { + meta.willNotWorkOnGpu(s"this is not supported in the $context context") + } else { + tagBase(meta, meta.willNotWorkOnGpu) + } + } + + private[this] def tagBase(meta: RapidsMeta[_, _], willNotWork: String => Unit): Unit = { + val cast = meta.wrapped.asInstanceOf[UnaryExpression] + val from = cast.child.dataType + val to = cast.dataType + if (!gpuCanCast(from, to)) { + willNotWork(s"${meta.wrapped.getClass.getSimpleName} from $from to $to is not supported") + } + } + + override def support( + dataType: TypeEnum.Value): Map[ExpressionContext, Map[String, SupportLevel]] = { + throw new IllegalStateException("support is different for cast") + } + + def support(from: TypeEnum.Value, to: TypeEnum.Value): SupportLevel = { + val (checks, sparkSig) = getChecksAndSigs(from) + checks.getSupportLevel(to, sparkSig) + } + + def sparkCanCast(from: DataType, to: DataType): Boolean = { + val (_, sparkSig) = getChecksAndSigs(from) + sparkSig.isSupportedBySpark(to) + } + + def gpuCanCast(from: DataType, to: DataType): Boolean = { + val (checks, _) = getChecksAndSigs(from) + checks.isSupportedByPlugin(to) + } +} + +object ExprChecks { + /** + * A check for an expression that only supports project. + */ + def projectOnly( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = + ExprChecksImpl(Map( + (ProjectExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + + /** + * A check for an expression that supports project and as much of AST as it can. 
+ */ + def projectAndAst( + allowedAstTypes: TypeSig, + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { + val astOutputCheck = outputCheck.intersect(allowedAstTypes) + val astParamCheck = paramCheck.map { pc => + ParamCheck(pc.name, pc.cudf.intersect(allowedAstTypes), pc.spark) + } + val astRepeatingParamCheck = repeatingParamCheck.map { rpc => + RepeatingParamCheck(rpc.name, rpc.cudf.intersect(allowedAstTypes), rpc.spark) + } + ExprChecksImpl(Map( + ProjectExprContext -> + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck), + AstExprContext -> + ContextChecks(astOutputCheck, sparkOutputSig, astParamCheck, astRepeatingParamCheck) + )) + } + + /** + * A check for a unary expression that only support project. + */ + def unaryProject( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + inputCheck: TypeSig, + sparkInputSig: TypeSig): ExprChecks = + projectOnly(outputCheck, sparkOutputSig, + Seq(ParamCheck("input", inputCheck, sparkInputSig))) + + /** + * A check for a unary expression that supports project and as much AST as it can. + */ + def unaryProjectAndAst( + allowedAstTypes: TypeSig, + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + inputCheck: TypeSig, + sparkInputSig: TypeSig): ExprChecks = + projectAndAst(allowedAstTypes, outputCheck, sparkOutputSig, + Seq(ParamCheck("input", inputCheck, sparkInputSig))) + + /** + * Unary expression checks for project where the input matches the output. + */ + def unaryProjectInputMatchesOutput(check: TypeSig, sparkSig: TypeSig): ExprChecks = + unaryProject(check, sparkSig, check, sparkSig) + + /** + * Unary expression checks for project where the input matches the output and it also + * supports as much of AST as it can. + */ + def unaryProjectAndAstInputMatchesOutput( + allowedAstTypes: TypeSig, + check: TypeSig, + sparkSig: TypeSig): ExprChecks = + unaryProjectAndAst(allowedAstTypes, check, sparkSig, check, sparkSig) + + /** + * Math unary checks where input and output are both DoubleType. + */ + val mathUnary: ExprChecks = unaryProjectInputMatchesOutput(TypeSig.DOUBLE, TypeSig.DOUBLE) + + /** + * Math unary checks where input and output are both DoubleType and AST is supported. + */ + val mathUnaryWithAst: ExprChecks = + unaryProjectAndAstInputMatchesOutput( + TypeSig.implicitCastsAstTypes, TypeSig.DOUBLE, TypeSig.DOUBLE) + + /** + * Helper function for a binary expression where the plugin only supports project. + */ + def binaryProject( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + param1: (String, TypeSig, TypeSig), + param2: (String, TypeSig, TypeSig)): ExprChecks = + projectOnly(outputCheck, sparkOutputSig, + Seq(ParamCheck(param1._1, param1._2, param1._3), + ParamCheck(param2._1, param2._2, param2._3))) + + /** + * Helper function for a binary expression where the plugin supports project and AST. + */ + def binaryProjectAndAst( + allowedAstTypes: TypeSig, + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + param1: (String, TypeSig, TypeSig), + param2: (String, TypeSig, TypeSig)): ExprChecks = + projectAndAst(allowedAstTypes, outputCheck, sparkOutputSig, + Seq(ParamCheck(param1._1, param1._2, param1._3), + ParamCheck(param2._1, param2._2, param2._3))) + + /** + * Aggregate operation where only group by agg and reduction is supported in the plugin and in + * Spark. 
+ */ + def reductionAndGroupByAgg( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = + ExprChecksImpl(Map( + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (ReductionAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + + /** + * Aggregate operation where window, reduction, and group by agg are all supported the same. + */ + def fullAgg( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = + ExprChecksImpl(Map( + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (ReductionAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (WindowAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + + /** + * For a generic expression that can work as both an aggregation and in the project context. + * This is really just for PythonUDF. + */ + def fullAggAndProject( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = + ExprChecksImpl(Map( + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (ReductionAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (WindowAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (ProjectExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + + /** + * An aggregation check where group by and reduction are supported by the plugin, but Spark + * also supports window operations on these. + */ + def aggNotWindow( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { + val windowParamCheck = paramCheck.map { pc => + ParamCheck(pc.name, TypeSig.none, pc.spark) + } + val windowRepeat = repeatingParamCheck.map { pc => + RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + } + ExprChecksImpl(Map( + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (ReductionAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (WindowAggExprContext, + ContextChecks(TypeSig.none, sparkOutputSig, windowParamCheck, windowRepeat)))) + } + + /** + * Window only operations. Spark does not support these operations as anything but a window + * operation. + */ + def windowOnly( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = + ExprChecksImpl(Map( + (WindowAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + + + /** + * An aggregation check where group by is supported by the plugin, but Spark also supports + * reduction and window operations on these. 
+ */ + def groupByOnly( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { + val noneParamCheck = paramCheck.map { pc => + ParamCheck(pc.name, TypeSig.none, pc.spark) + } + val noneRepeatCheck = repeatingParamCheck.map { pc => + RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + } + ExprChecksImpl(Map( + (ReductionAggExprContext, + ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (WindowAggExprContext, + ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)))) + } + + /** + * An aggregation check where group by and window operations are supported by the plugin, but + * Spark also supports reduction on these. + */ + def aggNotReduction( + outputCheck: TypeSig, + sparkOutputSig: TypeSig, + paramCheck: Seq[ParamCheck] = Seq.empty, + repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { + val noneParamCheck = paramCheck.map { pc => + ParamCheck(pc.name, TypeSig.none, pc.spark) + } + val noneRepeatCheck = repeatingParamCheck.map { pc => + RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) + } + ExprChecksImpl(Map( + (ReductionAggExprContext, + ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), + (WindowAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + } +} + +/** + * Used for generating the support docs. + */ +object SupportedOpsDocs { + private lazy val allSupportedTypes = + TypeSigUtil.getAllSupportedTypes() + + private def execChecksHeaderLine(): Unit = { + println("") + println("Executor") + println("Description") + println("Notes") + println("Param(s)") + allSupportedTypes.foreach { t => + println(s"$t") + } + println("") + } + + private def exprChecksHeaderLine(): Unit = { + println("") + println("Expression") + println("SQL Functions(s)") + println("Description") + println("Notes") + println("Context") + println("Param/Output") + allSupportedTypes.foreach { t => + println(s"$t") + } + println("") + } + + private def partChecksHeaderLine(): Unit = { + println("") + println("Partition") + println("Description") + println("Notes") + println("Param") + allSupportedTypes.foreach { t => + println(s"$t") + } + println("") + } + + private def ioChecksHeaderLine(): Unit = { + println("") + println("Format") + println("Direction") + allSupportedTypes.foreach { t => + println(s"$t") + } + println("") + } + + def getSparkVersion: String = { + // hack for databricks, try to find something more reliable? + if (SPARK_BUILD_USER.equals("Databricks")) { + SPARK_VERSION + "-databricks" + } else { + SPARK_VERSION + } + } + + def help(): Unit = { + val headerEveryNLines = 15 + // scalastyle:off line.size.limit + println("---") + println("layout: page") + println("title: Supported Operators") + println("nav_order: 6") + println("---") + println("") + println("Apache Spark supports processing various types of data. Not all expressions") + println("support all data types. The RAPIDS Accelerator for Apache Spark has further") + println("restrictions on what types are supported for processing. 
This tries") + println("to document what operations are supported and what data types each operation supports.") + println("Because Apache Spark is under active development too and this document was generated") + println(s"against version ${getSparkVersion} of Spark. Most of this should still") + println("apply to other versions of Spark, but there may be slight changes.") + println() + println("# General limitations") + println("## `Decimal`") + println("The `Decimal` type in Spark supports a precision") + println("up to 38 digits (128-bits). The RAPIDS Accelerator in most cases stores values up to") + println("64-bits and will support 128-bit in the future. As such the accelerator currently only") + println(s"supports a precision up to ${GpuOverrides.DECIMAL64_MAX_PRECISION} digits. Note that") + println("decimals are disabled by default in the plugin, because it is supported by a relatively") + println("small number of operations presently. This can result in a lot of data movement to and") + println("from the GPU, slowing down processing in some cases.") + println("Result `Decimal` precision and scale follow the same rule as CPU mode in Apache Spark:") + println() + println("```") + println(" * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2") + println(" * respectively, then the following operations have the following precision / scale:") + println(" *") + println(" * Operation Result Precision Result Scale") + println(" * ------------------------------------------------------------------------") + println(" * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)") + println(" * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)") + println(" * e1 * e2 p1 + p2 + 1 s1 + s2") + println(" * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)") + println(" * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)") + println(" * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)") + println("```") + println() + println("However, Spark inserts `PromotePrecision` to CAST both sides to the same type.") + println("GPU mode may fall back to CPU even if the result Decimal precision is within 18 digits.") + println("For example, `Decimal(8,2)` x `Decimal(6,3)` resulting in `Decimal (15,5)` runs on CPU,") + println("because due to `PromotePrecision`, GPU mode assumes the result is `Decimal(19,6)`.") + println("There are even extreme cases where Spark can temporarily return a Decimal value") + println("larger than what can be stored in 128-bits and then uses the `CheckOverflow`") + println("operator to round it to a desired precision and scale. This means that even when") + println("the accelerator supports 128-bit decimal, we might not be able to support all") + println("operations that Spark can support.") + println() + println("## `Timestamp`") + println("Timestamps in Spark will all be converted to the local time zone before processing") + println("and are often converted to UTC before being stored, like in Parquet or ORC.") + println("The RAPIDS Accelerator only supports UTC as the time zone for timestamps.") + println() + println("## `CalendarInterval`") + println("In Spark `CalendarInterval`s store three values, months, days, and microseconds.") + println("Support for this type is still very limited in the accelerator. 
In some cases") + println("only a a subset of the type is supported, like window ranges only support days currently.") + println() + println("## Configuration") + println("There are lots of different configuration values that can impact if an operation") + println("is supported or not. Some of these are a part of the RAPIDS Accelerator and cover") + println("the level of compatibility with Apache Spark. Those are covered [here](configs.md).") + println("Others are a part of Apache Spark itself and those are a bit harder to document.") + println("The work of updating this to cover that support is still ongoing.") + println() + println("In general though if you ever have any question about why an operation is not running") + println("on the GPU you may set `spark.rapids.sql.explain` to ALL and it will try to give all of") + println("the reasons why this particular operator or expression is on the CPU or GPU.") + println() + println("# Key") + println("## Types") + println() + println("|Type Name|Type Description|") + println("|---------|----------------|") + println("|BOOLEAN|Holds true or false values.|") + println("|BYTE|Signed 8-bit integer value.|") + println("|SHORT|Signed 16-bit integer value.|") + println("|INT|Signed 32-bit integer value.|") + println("|LONG|Signed 64-bit integer value.|") + println("|FLOAT|32-bit floating point value.|") + println("|DOUBLE|64-bit floating point value.|") + println("|DATE|A date with no time component. Stored as 32-bit integer with days since Jan 1, 1970.|") + println("|TIMESTAMP|A date and time. Stored as 64-bit integer with microseconds since Jan 1, 1970 in the current time zone.|") + println("|STRING|A text string. Stored as UTF-8 encoded bytes.|") + println("|DECIMAL|A fixed point decimal value with configurable precision and scale.|") + println("|NULL|Only stores null values and is typically only used when no other type can be determined from the SQL.|") + println("|BINARY|An array of non-nullable bytes.|") + println("|CALENDAR|Represents a period of time. Stored as months, days and microseconds.|") + println("|ARRAY|A sequence of elements.|") + println("|MAP|A set of key value pairs, the keys cannot be null.|") + println("|STRUCT|A series of named fields.|") + println("|UDT|User defined types and java Objects. These are not standard SQL types.|") + println() + println("## Support") + println() + println("|Value|Description|") + println("|---------|----------------|") + println("|S| (Supported) Both Apache Spark and the RAPIDS Accelerator support this type fully.|") + println("| | (Not Applicable) Neither Spark not the RAPIDS Accelerator support this type in this situation.|") + println("|_PS_| (Partial Support) Apache Spark supports this type, but the RAPIDS Accelerator only partially supports it. An explanation for what is missing will be included with this.|") + println("|**NS**| (Not Supported) Apache Spark supports this type but the RAPIDS Accelerator does not.") + println() + println("# SparkPlan or Executor Nodes") + println("Apache Spark uses a Directed Acyclic Graph(DAG) of processing to build a query.") + println("The nodes in this graph are instances of `SparkPlan` and represent various high") + println("level operations like doing a filter or project. 
The operations that the RAPIDS")
+    println("Accelerator supports are described below.")
+    println("<table>")
+    execChecksHeaderLine()
+    var totalCount = 0
+    var nextOutputAt = headerEveryNLines
+    GpuOverrides.execs.values.toSeq.sortBy(_.tag.toString).foreach { rule =>
+      val checks = rule.getChecks
+      if (rule.isVisible && checks.forall(_.shown)) {
+        if (totalCount >= nextOutputAt) {
+          execChecksHeaderLine()
+          nextOutputAt = totalCount + headerEveryNLines
+        }
+        println("<tr>")
+        val execChecks = checks.get.asInstanceOf[ExecChecks]
+        val allData = allSupportedTypes.toList.map { t =>
+          (t, execChecks.support(t))
+        }.toMap
+
+        val notes = execChecks.supportNotes
+        // Now we should get the same keys for each type, so we are only going to look at the first
+        // type for now
+        val totalSpan = allData.values.head.size
+        val inputs = allData.values.head.keys
+
+        println(s"""<td rowSpan="$totalSpan">${rule.tag.runtimeClass.getSimpleName}</td>""")
+        println(s"""<td rowSpan="$totalSpan">${rule.description}</td>""")
+        println(s"""<td rowSpan="$totalSpan">${rule.notes().getOrElse("None")}</td>""")
+        var count = 0
+        inputs.foreach { input =>
+          val named = notes.get(input)
+              .map(l => input + "<br/>(" + l.mkString(";<br/>") + ")")
+              .getOrElse(input)
+          println(s"<td>$named</td>")
+          allSupportedTypes.foreach { t =>
+            println(allData(t)(input).htmlTag)
+          }
+          println("</tr>")
+          count += 1
+          if (count < totalSpan) {
+            println("<tr>")
+          }
+        }
+
+        totalCount += totalSpan
+      }
+    }
+    println("</table>
") + println() + println("# Expression and SQL Functions") + println("Inside each node in the DAG there can be one or more trees of expressions") + println("that describe various types of processing that happens in that part of the plan.") + println("These can be things like adding two numbers together or checking for null.") + println("These expressions can have multiple input parameters and one output value.") + println("These expressions also can happen in different contexts. Because of how the") + println("accelerator works different contexts have different levels of support.") + println() + println("The most common expression context is `project`. In this context values from a single") + println("input row go through the expression and the result will also be use to produce") + println("something in the same row. Be aware that even in the case of aggregation and window") + println("operations most of the processing is still done in the project context either before") + println("or after the other processing happens.") + println() + println("Aggregation operations like count or sum can take place in either the `aggregation`,") + println("`reduction`, or `window` context. `aggregation` is when the operation was done while") + println("grouping the data by one or more keys. `reduction` is when there is no group by and") + println("there is a single result for an entire column. `window` is for window operations.") + println() + println("The final expression context is `AST` or Abstract Syntax Tree.") + println("Before explaining AST we first need to explain in detail how project context operations") + println("work. Generally for a project context operation the plan Spark developed is read") + println("on the CPU and an appropriate set of GPU kernels are selected to do those") + println("operations. For example `a >= b + 1`. Would result in calling a GPU kernel to add") + println("`1` to `b`, followed by another kernel that is called to compare `a` to that result.") + println("The interpretation is happening on the CPU, and the GPU is used to do the processing.") + println("For AST the interpretation for some reason cannot happen on the CPU and instead must") + println("be done in the GPU kernel itself. An example of this is conditional joins. If you") + println("want to join on `A.a >= B.b + 1` where `A` and `B` are separate tables or data") + println("frames, the `+` and `>=` operations cannot run as separate independent kernels") + println("because it is done on a combination of rows in both `A` and `B`. Instead part of the") + println("plan that Spark developed is turned into an abstract syntax tree and sent to the GPU") + println("where it can be interpreted. 
The number and types of operations supported in this") + println("are limited.") + println("") + exprChecksHeaderLine() + totalCount = 0 + nextOutputAt = headerEveryNLines + GpuOverrides.expressions.values.toSeq.sortBy(_.tag.toString).foreach { rule => + val checks = rule.getChecks + if (rule.isVisible && checks.isDefined && checks.forall(_.shown)) { + if (totalCount >= nextOutputAt) { + exprChecksHeaderLine() + nextOutputAt = totalCount + headerEveryNLines + } + val sqlFunctions = + ConfHelper.getSqlFunctionsForClass(rule.tag.runtimeClass).map(_.mkString(", ")) + val exprChecks = checks.get.asInstanceOf[ExprChecks] + // Params can change between contexts, but should not + val allData = allSupportedTypes.toList.map { t => + (t, exprChecks.support(t)) + }.toMap + // Now we should get the same keys for each type, so we are only going to look at the first + // type for now + val totalSpan = allData.values.head.map { + case (_, m: Map[String, SupportLevel]) => m.size + }.sum + val representative = allData.values.head + println("") + println("") + println("") + println("") + println("") + var count = 0 + representative.foreach { + case (context, data) => + val contextSpan = data.size + println("") + data.keys.foreach { param => + println(s"") + allSupportedTypes.foreach { t => + println(allData(t)(context)(param).htmlTag) + } + println("") + count += 1 + if (count < totalSpan) { + println("") + } + } + } + totalCount += totalSpan + } + } + println("
" + + s"${rule.tag.runtimeClass.getSimpleName}" + s"${sqlFunctions.getOrElse(" ")}" + s"${rule.description}" + s"${rule.notes().getOrElse("None")}" + s"$context$param
") + println() + println("## Casting") + println("The above table does not show what is and is not supported for cast.") + println("This table shows the matrix of supported casts.") + println("Nested types like MAP, Struct, and Array can only be cast if the child types") + println("can be cast.") + println() + println("Some of the casts to/from string on the GPU are not 100% the same and are disabled") + println("by default. Please see the configs for more details on these specific cases.") + println() + println("Please note that even though casting from one type to another is supported") + println("by Spark it does not mean they all produce usable results. For example casting") + println("from a date to a boolean always produces a null. This is for Hive compatibility") + println("and the accelerator produces the same result.") + println() + GpuOverrides.expressions.values.toSeq.sortBy(_.tag.toString).foreach { rule => + rule.getChecks match { + case Some(cc: CastChecks) => + println(s"### `${rule.tag.runtimeClass.getSimpleName}`") + println() + println("") + val numTypes = allSupportedTypes.size + println("") + println("") + allSupportedTypes.foreach { t => + println(s"") + } + println("") + + println("") + var count = 0 + allSupportedTypes.foreach { from => + println(s"") + allSupportedTypes.foreach { to => + println(cc.support(from, to).htmlTag) + } + println("") + count += 1 + if (count < numTypes) { + println("") + } + } + println("
TO
$t
FROM$from
") + println() + case _ => // Nothing + } + } + println() + println("# Partitioning") + println("When transferring data between different tasks the data is partitioned in") + println("specific ways depending on requirements in the plan. Be aware that the types") + println("included below are only for rows that impact where the data is partitioned.") + println("So for example if we are doing a join on the column `a` the data would be") + println("hash partitioned on `a`, but all of the other columns in the same data frame") + println("as `a` don't show up in the table. They are controlled by the rules for") + println("`ShuffleExchangeExec` which uses the `Partitioning`.") + println("") + partChecksHeaderLine() + totalCount = 0 + nextOutputAt = headerEveryNLines + GpuOverrides.parts.values.toSeq.sortBy(_.tag.toString).foreach { rule => + val checks = rule.getChecks + if (rule.isVisible && checks.isDefined && checks.forall(_.shown)) { + if (totalCount >= nextOutputAt) { + partChecksHeaderLine() + nextOutputAt = totalCount + headerEveryNLines + } + val partChecks = checks.get.asInstanceOf[PartChecks] + val allData = allSupportedTypes.toList.map { t => + (t, partChecks.support(t)) + }.toMap + // Now we should get the same keys for each type, so we are only going to look at the first + // type for now + val totalSpan = allData.values.head.size + if (totalSpan > 0) { + val representative = allData.values.head + println("") + println("") + println("") + println("") + var count = 0 + representative.keys.foreach { param => + println(s"") + allSupportedTypes.foreach { t => + println(allData(t)(param).htmlTag) + } + println("") + count += 1 + if (count < totalSpan) { + println("") + } + } + totalCount += totalSpan + } else { + // No arguments... + println("") + println(s"") + println(s"") + println(s"") + println(NotApplicable.htmlTag) // param + allSupportedTypes.foreach { _ => + println(NotApplicable.htmlTag) + } + println("") + totalCount += 1 + } + } + } + println("
" + + s"${rule.tag.runtimeClass.getSimpleName}" + s"${rule.description}" + s"${rule.notes().getOrElse("None")}$param
${rule.tag.runtimeClass.getSimpleName}${rule.description}${rule.notes().getOrElse("None")}
") + println() + println("## Input/Output") + println("For Input and Output it is not cleanly exposed what types are supported and which are not.") + println("This table tries to clarify that. Be aware that some types may be disabled in some") + println("cases for either reads or writes because of processing limitations, like rebasing") + println("dates or timestamps, or for a lack of type coercion support.") + println("") + ioChecksHeaderLine() + totalCount = 0 + nextOutputAt = headerEveryNLines + GpuOverrides.fileFormats.toSeq.sortBy(_._1.toString).foreach { + case (format, ioMap) => + if (totalCount >= nextOutputAt) { + ioChecksHeaderLine() + nextOutputAt = totalCount + headerEveryNLines + } + val read = ioMap(ReadFileOp) + val write = ioMap(WriteFileOp) + println("") + println("") + println("") + allSupportedTypes.foreach { t => + println(read.support(t).htmlTag) + } + println("") + println("") + println("") + allSupportedTypes.foreach { t => + println(write.support(t).htmlTag) + } + println("") + totalCount += 2 + } + println("
" + s"$formatRead
Write
") + // scalastyle:on line.size.limit + } + + def main(args: Array[String]): Unit = { + val out = new FileOutputStream(new File(args(0))) + Console.withOut(out) { + Console.withErr(out) { + SupportedOpsDocs.help() + } + } + } +} + +object SupportedOpsForTools { + + private lazy val allSupportedTypes = + TypeSigUtil.getAllSupportedTypes() + + private def outputSupportIO() { + // Look at what we have for defaults for some configs because if the configs are off + // it likely means something isn't completely compatible. + val conf = new RapidsConf(Map.empty[String, String]) + val types = allSupportedTypes.toSeq + val header = Seq("Format", "Direction") ++ types + val writeOps: Array[String] = Array.fill(types.size)("NA") + println(header.mkString(",")) + GpuOverrides.fileFormats.toSeq.sortBy(_._1.toString).foreach { + case (format, ioMap) => + val formatLowerCase = format.toString.toLowerCase + val formatEnabled = formatLowerCase match { + case "csv" => conf.isCsvEnabled && conf.isCsvReadEnabled + case "parquet" => conf.isParquetEnabled && conf.isParquetReadEnabled + case "orc" => conf.isOrcEnabled && conf.isOrcReadEnabled + case _ => + throw new IllegalArgumentException("Format is unknown we need to add it here!") + } + val read = ioMap(ReadFileOp) + // we have lots of configs for various operations, just try to get the main ones + val readOps = types.map { t => + val typeEnabled = if (format.toString.toLowerCase.equals("csv")) { + t.toString match { + case "BOOLEAN" => conf.isCsvBoolReadEnabled + case "BYTE" => conf.isCsvByteReadEnabled + case "SHORT" => conf.isCsvShortReadEnabled + case "INT" => conf.isCsvIntReadEnabled + case "LONG" => conf.isCsvLongReadEnabled + case "FLOAT" => conf.isCsvFloatReadEnabled + case "DOUBLE" => conf.isCsvDoubleReadEnabled + case "TIMESTAMP" => conf.isCsvTimestampReadEnabled + case "DATE" => conf.isCsvDateReadEnabled + case _ => true + } + } else { + t.toString match { + case _ => true + } + } + if (!formatEnabled || !typeEnabled) { + // indicate configured off by default + "CO" + } else { + read.support(t).text + } + } + // print read formats and types + println(s"${(Seq(format, "read") ++ readOps).mkString(",")}") + + val writeFileFormat = ioMap(WriteFileOp).getFileFormat + // print supported write formats and NA for types. Cannot determine types from event logs. + if (writeFileFormat != TypeSig.none) { + println(s"${(Seq(format, "write") ++ writeOps).mkString(",")}") + } + } + } + + def help(): Unit = { + outputSupportIO() + } + + def main(args: Array[String]): Unit = { + val out = new FileOutputStream(new File(args(0))) + Console.withOut(out) { + Console.withErr(out) { + SupportedOpsForTools.help() + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregateMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregateMeta.scala new file mode 100644 index 00000000000..12dd00d0915 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregateMeta.scala @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExprId, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.execution.{SortExec, SparkPlan, TrampolineUtil} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, MapType} + +// Spark 2.x - had to copy the GpuBaseAggregateMeta into each Hash and Sort Meta because no +// BaseAggregateExec class in Spark 2.x + +class GpuHashAggregateMeta( + val agg: HashAggregateExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[HashAggregateExec](agg, conf, parent, rule) { + + val groupingExpressions: Seq[BaseExprMeta[_]] = + agg.groupingExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val aggregateExpressions: Seq[BaseExprMeta[_]] = + agg.aggregateExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val aggregateAttributes: Seq[BaseExprMeta[_]] = + agg.aggregateAttributes.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val resultExpressions: Seq[BaseExprMeta[_]] = + agg.resultExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = + groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions + + override def tagPlanForGpu(): Unit = { + // We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So, + // we need to run recursive type check on the structs. + val arrayOrMapGroupings = agg.groupingExpressions.exists(e => + TrampolineUtil.dataTypeExistsRecursively(e.dataType, + dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType])) + if (arrayOrMapGroupings) { + willNotWorkOnGpu("ArrayTypes or MapTypes in grouping expressions are not supported") + } + + val dec128Grouping = agg.groupingExpressions.exists(e => + TrampolineUtil.dataTypeExistsRecursively(e.dataType, + dt => dt.isInstanceOf[DecimalType] && + dt.asInstanceOf[DecimalType].precision > GpuOverrides.DECIMAL64_MAX_PRECISION)) + if (dec128Grouping) { + willNotWorkOnGpu("grouping by a 128-bit decimal value is not currently supported") + } + + tagForReplaceMode() + } + + /** + * Tagging checks tied to configs that control the aggregation modes that are replaced. + * + * The rule of replacement is determined by `spark.rapids.sql.hashAgg.replaceMode`, which + * is a string configuration consisting of AggregateMode names in lower cases connected by + * &(AND) and |(OR). The default value of this config is `all`, which indicates replacing all + * aggregates if possible. + * + * The `|` serves as the outer connector, which represents patterns of both sides are able to be + * replaced. For instance, `final|partialMerge` indicates that aggregate plans purely in either + * Final mode or PartialMerge mode can be replaced. But aggregate plans also contain + * AggExpressions of other mode will NOT be replaced, such as: stage 3 of single distinct + * aggregate who contains both Partial and PartialMerge. + * + * On the contrary, the `&` serves as the inner connector, which intersects modes of both sides + * to form a mode pattern. 
The replacement only takes place for aggregate plans who have the + * exact same mode pattern as what defined the rule. For instance, `partial&partialMerge` means + * that aggregate plans can be only replaced if they contain AggExpressions of Partial and + * contain AggExpressions of PartialMerge and don't contain AggExpressions of other modes. + * + * In practice, we need to combine `|` and `&` to form some sophisticated patterns. For instance, + * `final&complete|final|partialMerge` represents aggregate plans in three different patterns are + * GPU-replaceable: plans contain both Final and Complete modes; plans only contain Final mode; + * plans only contain PartialMerge mode. + */ + private def tagForReplaceMode(): Unit = { + val aggPattern = agg.aggregateExpressions.map(_.mode).toSet + val strPatternToReplace = conf.hashAggReplaceMode.toLowerCase + + if (aggPattern.nonEmpty && strPatternToReplace != "all") { + val aggPatternsCanReplace = strPatternToReplace.split("\\|").map { subPattern => + subPattern.split("&").map { + case "partial" => Partial + case "partialmerge" => PartialMerge + case "final" => Final + case "complete" => Complete + case s => throw new IllegalArgumentException(s"Invalid Aggregate Mode $s") + }.toSet + } + if (!aggPatternsCanReplace.contains(aggPattern)) { + val message = aggPattern.map(_.toString).mkString(",") + willNotWorkOnGpu(s"Replacing mode pattern `$message` hash aggregates disabled") + } else if (aggPattern == Set(Partial)) { + // In partial mode, if there are non-distinct functions and multiple distinct functions, + // non-distinct functions are computed using the First operator. The final result would be + // incorrect for non-distinct functions for partition size > 1. Reason for this is - if + // the first batch computed and sent to CPU doesn't contain all the rows required to + // compute non-distinct function(s), then Spark would consider that value as final result + // (due to First). Fall back to CPU in this case. + if (AggregateUtils.shouldFallbackMultiDistinct(agg.aggregateExpressions)) { + willNotWorkOnGpu("Aggregates of non-distinct functions with multiple distinct " + + "functions are non-deterministic for non-distinct functions as it is " + + "computed using First.") + } + } + } + + if (!conf.partialMergeDistinctEnabled && aggPattern.contains(PartialMerge)) { + willNotWorkOnGpu("Replacing Partial Merge aggregates disabled. " + + s"Set ${conf.partialMergeDistinctEnabled} to true if desired") + } + } +} + +class GpuSortAggregateExecMeta( + val agg: SortAggregateExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[SortAggregateExec](agg, conf, parent, rule) { + + val groupingExpressions: Seq[BaseExprMeta[_]] = + agg.groupingExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val aggregateExpressions: Seq[BaseExprMeta[_]] = + agg.aggregateExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val aggregateAttributes: Seq[BaseExprMeta[_]] = + agg.aggregateAttributes.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val resultExpressions: Seq[BaseExprMeta[_]] = + agg.resultExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = + groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions + + override def tagPlanForGpu(): Unit = { + // We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So, + // we need to run recursive type check on the structs. 
+ val arrayOrMapGroupings = agg.groupingExpressions.exists(e => + TrampolineUtil.dataTypeExistsRecursively(e.dataType, + dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType])) + if (arrayOrMapGroupings) { + willNotWorkOnGpu("ArrayTypes or MapTypes in grouping expressions are not supported") + } + + val dec128Grouping = agg.groupingExpressions.exists(e => + TrampolineUtil.dataTypeExistsRecursively(e.dataType, + dt => dt.isInstanceOf[DecimalType] && + dt.asInstanceOf[DecimalType].precision > GpuOverrides.DECIMAL64_MAX_PRECISION)) + if (dec128Grouping) { + willNotWorkOnGpu("grouping by a 128-bit decimal value is not currently supported") + } + + tagForReplaceMode() + + // Make sure this is the last check - if this is SortAggregate, the children can be sorts and we + // want to validate they can run on GPU and remove them before replacing this with a + // HashAggregate. We don't want to do this if there is a first or last aggregate, + // because dropping the sort will make them no longer deterministic. + // In the future we might be able to pull the sort functionality into the aggregate so + // we can sort a single batch at a time and sort the combined result as well which would help + // with data skew. + val hasFirstOrLast = agg.aggregateExpressions.exists { agg => + agg.aggregateFunction match { + case _: First | _: Last => true + case _ => false + } + } + if (canThisBeReplaced && !hasFirstOrLast) { + childPlans.foreach { plan => + if (plan.wrapped.isInstanceOf[SortExec]) { + if (!plan.canThisBeReplaced) { + willNotWorkOnGpu("one of the preceding SortExec's cannot be replaced") + } else { + plan.shouldBeRemoved("replacing sort aggregate with hash aggregate") + } + } + } + } + } + + /** + * Tagging checks tied to configs that control the aggregation modes that are replaced. + * + * The rule of replacement is determined by `spark.rapids.sql.hashAgg.replaceMode`, which + * is a string configuration consisting of AggregateMode names in lower cases connected by + * &(AND) and |(OR). The default value of this config is `all`, which indicates replacing all + * aggregates if possible. + * + * The `|` serves as the outer connector, which represents patterns of both sides are able to be + * replaced. For instance, `final|partialMerge` indicates that aggregate plans purely in either + * Final mode or PartialMerge mode can be replaced. But aggregate plans also contain + * AggExpressions of other mode will NOT be replaced, such as: stage 3 of single distinct + * aggregate who contains both Partial and PartialMerge. + * + * On the contrary, the `&` serves as the inner connector, which intersects modes of both sides + * to form a mode pattern. The replacement only takes place for aggregate plans who have the + * exact same mode pattern as what defined the rule. For instance, `partial&partialMerge` means + * that aggregate plans can be only replaced if they contain AggExpressions of Partial and + * contain AggExpressions of PartialMerge and don't contain AggExpressions of other modes. + * + * In practice, we need to combine `|` and `&` to form some sophisticated patterns. For instance, + * `final&complete|final|partialMerge` represents aggregate plans in three different patterns are + * GPU-replaceable: plans contain both Final and Complete modes; plans only contain Final mode; + * plans only contain PartialMerge mode. 
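+   *
+   * As a rough illustration only (not code from this patch), a pattern string such as the
+   * example above can be thought of as decomposing into sets of mode names, where an aggregate
+   * is replaceable if the set of modes used by its AggExpressions exactly matches one of the
+   * decomposed sets:
+   * {{{
+   *   val pattern = "final&complete|final|partialMerge"
+   *   val replaceable: Array[Set[String]] = pattern.split("\\|").map(_.split("&").toSet)
+   *   // replaceable == Array(Set("final", "complete"), Set("final"), Set("partialMerge"))
+   * }}}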
+ */ + private def tagForReplaceMode(): Unit = { + val aggPattern = agg.aggregateExpressions.map(_.mode).toSet + val strPatternToReplace = conf.hashAggReplaceMode.toLowerCase + + if (aggPattern.nonEmpty && strPatternToReplace != "all") { + val aggPatternsCanReplace = strPatternToReplace.split("\\|").map { subPattern => + subPattern.split("&").map { + case "partial" => Partial + case "partialmerge" => PartialMerge + case "final" => Final + case "complete" => Complete + case s => throw new IllegalArgumentException(s"Invalid Aggregate Mode $s") + }.toSet + } + if (!aggPatternsCanReplace.contains(aggPattern)) { + val message = aggPattern.map(_.toString).mkString(",") + willNotWorkOnGpu(s"Replacing mode pattern `$message` hash aggregates disabled") + } else if (aggPattern == Set(Partial)) { + // In partial mode, if there are non-distinct functions and multiple distinct functions, + // non-distinct functions are computed using the First operator. The final result would be + // incorrect for non-distinct functions for partition size > 1. Reason for this is - if + // the first batch computed and sent to CPU doesn't contain all the rows required to + // compute non-distinct function(s), then Spark would consider that value as final result + // (due to First). Fall back to CPU in this case. + if (AggregateUtils.shouldFallbackMultiDistinct(agg.aggregateExpressions)) { + willNotWorkOnGpu("Aggregates of non-distinct functions with multiple distinct " + + "functions are non-deterministic for non-distinct functions as it is " + + "computed using First.") + } + } + } + + if (!conf.partialMergeDistinctEnabled && aggPattern.contains(PartialMerge)) { + willNotWorkOnGpu("Replacing Partial Merge aggregates disabled. " + + s"Set ${conf.partialMergeDistinctEnabled} to true if desired") + } + } +} + +// SPARK 2.x we can't check for the TypedImperativeAggregate properly so don't say we do the +// ObjectHashAggregate +/* +class GpuObjectHashAggregateExecMeta( + override val agg: ObjectHashAggregateExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends GpuTypedImperativeSupportedAggregateExecMeta(agg, + agg.requiredChildDistributionExpressions, conf, parent, rule) + +*/ diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperatorsMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperatorsMeta.scala new file mode 100644 index 00000000000..a3b5ea0afa6 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperatorsMeta.scala @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.{ProjectExec, SampleExec} + +class GpuProjectExecMeta( + proj: ProjectExec, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule) extends SparkPlanMeta[ProjectExec](proj, conf, p, r) + with Logging { +} + +class GpuSampleExecMeta( + sample: SampleExec, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule) extends SparkPlanMeta[SampleExec](sample, conf, p, r) + with Logging { +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/literalsMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/literalsMeta.scala new file mode 100644 index 00000000000..7f82b28e4ab --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/literalsMeta.scala @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types._ + +class LiteralExprMeta( + lit: Literal, + conf: RapidsConf, + p: Option[RapidsMeta[_, _]], + r: DataFromReplacementRule) extends ExprMeta[Literal](lit, conf, p, r) { + + def withNewLiteral(newLiteral: Literal): LiteralExprMeta = + new LiteralExprMeta(newLiteral, conf, p, r) + + // There are so many of these that we don't need to print them out, unless it + // will not work on the GPU + override def print(append: StringBuilder, depth: Int, all: Boolean): Unit = { + if (!this.canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) { + super.print(append, depth, all) + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala new file mode 100644 index 00000000000..92aa71728e4 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.rapids.execution.{GpuHashJoin, JoinTypeChecks} + +class GpuBroadcastHashJoinMeta( + join: BroadcastHashJoinExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends GpuBroadcastJoinMeta[BroadcastHashJoinExec](join, conf, parent, rule) { + + val leftKeys: Seq[BaseExprMeta[_]] = + join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val rightKeys: Seq[BaseExprMeta[_]] = + join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val condition: Option[BaseExprMeta[_]] = + join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val buildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide) + + override val namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = + JoinTypeChecks.equiJoinMeta(leftKeys, rightKeys, condition) + + override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition + + override def tagPlanForGpu(): Unit = { + GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys, + join.condition) + val Seq(leftChild, rightChild) = childPlans + val buildSideMeta = buildSide match { + case GpuBuildLeft => leftChild + case GpuBuildRight => rightChild + } + + if (!canBuildSideBeReplaced(buildSideMeta)) { + willNotWorkOnGpu("the broadcast for this join must be on the GPU too") + } + + if (!canThisBeReplaced) { + buildSideMeta.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala new file mode 100644 index 00000000000..5285974d0a0 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.v2 + +import java.nio.charset.StandardCharsets + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.execution.datasources.csv.CSVOptions +import org.apache.spark.sql.types._ + +object GpuCSVScan { + private val supportedDateFormats = Set( + "yyyy-MM-dd", + "yyyy/MM/dd", + "yyyy-MM", + "yyyy/MM", + "MM-yyyy", + "MM/yyyy", + "MM-dd-yyyy", + "MM/dd/yyyy" + // TODO "dd-MM-yyyy" and "dd/MM/yyyy" can also be supported, but only if we set + // dayfirst to true in the parser config. 
This is not plumbed into the java cudf yet + // and would need to coordinate with the timestamp format too, because both cannot + // coexist + ) + + private val supportedTsPortionFormats = Set( + "HH:mm:ss.SSSXXX", + "HH:mm:ss[.SSS][XXX]", + "HH:mm", + "HH:mm:ss", + "HH:mm[:ss]", + "HH:mm:ss.SSS", + "HH:mm:ss[.SSS]" + ) + + def dateFormatInRead(csvOpts: CSVOptions): Option[String] = { + // spark 2.x uses FastDateFormat, use getPattern + Option(csvOpts.dateFormat.getPattern) + } + + def timestampFormatInRead(csvOpts: CSVOptions): Option[String] = { + // spark 2.x uses FastDateFormat, use getPattern + Option(csvOpts.timestampFormat.getPattern) + } + + def tagSupport( + sparkSession: SparkSession, + dataSchema: StructType, + readSchema: StructType, + options: Map[String, String], + meta: RapidsMeta[_, _]): Unit = { + val parsedOptions: CSVOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + if (!meta.conf.isCsvEnabled) { + meta.willNotWorkOnGpu("CSV input and output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_CSV} to true") + } + + if (!meta.conf.isCsvReadEnabled) { + meta.willNotWorkOnGpu("CSV input has been disabled. To enable set" + + s"${RapidsConf.ENABLE_CSV_READ} to true") + } + + if (!parsedOptions.enforceSchema) { + meta.willNotWorkOnGpu("GpuCSVScan always enforces schemas") + } + + if (dataSchema == null || dataSchema.isEmpty) { + meta.willNotWorkOnGpu("GpuCSVScan requires a specified data schema") + } + + // 2.x only supports delimiter as char + /* + if (parsedOptions.delimiter.length > 1) { + meta.willNotWorkOnGpu("GpuCSVScan does not support multi-character delimiters") + } + */ + + // delimiter is char in 2.x + if (parsedOptions.delimiter > 127) { + meta.willNotWorkOnGpu("GpuCSVScan does not support non-ASCII delimiters") + } + + if (parsedOptions.quote > 127) { + meta.willNotWorkOnGpu("GpuCSVScan does not support non-ASCII quote chars") + } + + if (parsedOptions.comment > 127) { + meta.willNotWorkOnGpu("GpuCSVScan does not support non-ASCII comment chars") + } + + if (parsedOptions.escape != '\\') { + meta.willNotWorkOnGpu("GpuCSVScan does not support modified escape chars") + } + + if (parsedOptions.charToEscapeQuoteEscaping.isDefined) { + meta.willNotWorkOnGpu("GPU CSV Parsing does not support charToEscapeQuoteEscaping") + } + + if (StandardCharsets.UTF_8.name() != parsedOptions.charset && + StandardCharsets.US_ASCII.name() != parsedOptions.charset) { + meta.willNotWorkOnGpu("GpuCSVScan only supports UTF8 encoded data") + } + + // TODO parsedOptions.ignoreLeadingWhiteSpaceInRead cudf always does this, but not for strings + // TODO parsedOptions.ignoreTrailingWhiteSpaceInRead cudf always does this, but not for strings + // TODO parsedOptions.multiLine cudf always does this, but it is not the default and it is not + // consistent + + // 2.x doesn't have linSeparator config + // CSV text with '\n', '\r' and '\r\n' as line separators. 
+ // Since I have no way to check in 2.x we will just assume it works for explain until + // they move to 3.x + /* + if (parsedOptions.lineSeparator.getOrElse("\n") != "\n") { + meta.willNotWorkOnGpu("GpuCSVScan only supports \"\\n\" as a line separator") + } + */ + + if (parsedOptions.parseMode != PermissiveMode) { + meta.willNotWorkOnGpu("GpuCSVScan only supports Permissive CSV parsing") + } + + // TODO parsedOptions.nanValue This is here by default so we should support it, but cudf + // make it null https://github.com/NVIDIA/spark-rapids/issues/125 + parsedOptions.positiveInf.toLowerCase() match { + case "inf" | "+inf" | "infinity" | "+infinity" => + case _ => + meta.willNotWorkOnGpu(s"the positive infinity value '${parsedOptions.positiveInf}'" + + s" is not supported'") + } + parsedOptions.negativeInf.toLowerCase() match { + case "-inf" | "-infinity" => + case _ => + meta.willNotWorkOnGpu(s"the positive infinity value '${parsedOptions.positiveInf}'" + + s" is not supported'") + } + // parsedOptions.maxCharsPerColumn does not impact the final output it is a performance + // improvement if you know the maximum size + + // parsedOptions.maxColumns was originally a performance optimization but is not used any more + + if (readSchema.map(_.dataType).contains(DateType)) { + if (!meta.conf.isCsvDateReadEnabled) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading dates. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_DATES} to true.") + } + dateFormatInRead(parsedOptions).foreach { dateFormat => + if (!supportedDateFormats.contains(dateFormat)) { + meta.willNotWorkOnGpu(s"the date format '${dateFormat}' is not supported'") + } + } + } + + if (!meta.conf.isCsvBoolReadEnabled && readSchema.map(_.dataType).contains(BooleanType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading boolean. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_BOOLS} to true.") + } + + if (!meta.conf.isCsvByteReadEnabled && readSchema.map(_.dataType).contains(ByteType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading bytes. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_BYTES} to true.") + } + + if (!meta.conf.isCsvShortReadEnabled && readSchema.map(_.dataType).contains(ShortType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading shorts. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_SHORTS} to true.") + } + + if (!meta.conf.isCsvIntReadEnabled && readSchema.map(_.dataType).contains(IntegerType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading integers. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_INTEGERS} to true.") + } + + if (!meta.conf.isCsvLongReadEnabled && readSchema.map(_.dataType).contains(LongType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading longs. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_LONGS} to true.") + } + + if (!meta.conf.isCsvFloatReadEnabled && readSchema.map(_.dataType).contains(FloatType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading floats. " + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_FLOATS} to true.") + } + + if (!meta.conf.isCsvDoubleReadEnabled && readSchema.map(_.dataType).contains(DoubleType)) { + meta.willNotWorkOnGpu("CSV reading is not 100% compatible when reading doubles. 
" + + s"To enable it please set ${RapidsConf.ENABLE_READ_CSV_DOUBLES} to true.") + } + + if (readSchema.map(_.dataType).contains(TimestampType)) { + if (!meta.conf.isCsvTimestampReadEnabled) { + meta.willNotWorkOnGpu("GpuCSVScan does not support parsing timestamp types. To " + + s"enable it please set ${RapidsConf.ENABLE_CSV_TIMESTAMPS} to true.") + } + + // Spark 2.x doesn't have zoneId, so use timeZone and then to id + if (!TypeChecks.areTimestampsSupported(parsedOptions.timeZone.toZoneId)) { + meta.willNotWorkOnGpu("Only UTC zone id is supported") + } + timestampFormatInRead(parsedOptions).foreach { tsFormat => + val parts = tsFormat.split("'T'", 2) + if (parts.isEmpty) { + meta.willNotWorkOnGpu(s"the timestamp format '$tsFormat' is not supported") + } + if (parts.headOption.exists(h => !supportedDateFormats.contains(h))) { + meta.willNotWorkOnGpu(s"the timestamp format '$tsFormat' is not supported") + } + if (parts.length > 1 && !supportedTsPortionFormats.contains(parts(1))) { + meta.willNotWorkOnGpu(s"the timestamp format '$tsFormat' is not supported") + } + } + } + // TODO parsedOptions.emptyValueInRead + + FileFormatChecks.tag(meta, readSchema, CsvFormatType, ReadFileOp) + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala new file mode 100644 index 00000000000..032b3aa39b8 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + +/** + * Spark BuildSide, BuildRight, BuildLeft moved packages in Spark 3.1 + * so create GPU versions of these that can be agnostic to Spark version. + */ +sealed abstract class GpuBuildSide + +case object GpuBuildRight extends GpuBuildSide + +case object GpuBuildLeft extends GpuBuildSide + +object GpuJoinUtils { + def getGpuBuildSide(buildSide: BuildSide): GpuBuildSide = { + buildSide match { + case BuildRight => GpuBuildRight + case BuildLeft => GpuBuildLeft + case _ => throw new Exception(s"unknown build side type $buildSide") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala new file mode 100644 index 00000000000..71943973d20 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} +import org.apache.spark.sql.types.DataTypes +import org.apache.spark.unsafe.types.UTF8String + +class GpuRegExpReplaceMeta( + expr: RegExpReplace, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) { + + private var pattern: Option[String] = None + + override def tagExprForGpu(): Unit = { + expr.regexp match { + case Literal(s: UTF8String, DataTypes.StringType) if s != null => + if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + // use GpuStringReplace + } else { + try { + pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString)) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } + + case _ => + willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala new file mode 100644 index 00000000000..8b3382b3db8 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec +import org.apache.spark.sql.rapids.execution.{GpuHashJoin, JoinTypeChecks} + +class GpuShuffledHashJoinMeta( + join: ShuffledHashJoinExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[ShuffledHashJoinExec](join, conf, parent, rule) { + val leftKeys: Seq[BaseExprMeta[_]] = + join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val rightKeys: Seq[BaseExprMeta[_]] = + join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val condition: Option[BaseExprMeta[_]] = + join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val buildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide) + + override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition + + override val namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = + JoinTypeChecks.equiJoinMeta(leftKeys, rightKeys, condition) + + override def tagPlanForGpu(): Unit = { + GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys, + join.condition) + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala new file mode 100644 index 00000000000..fa78118b484 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.sql.execution.SortExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.rapids.execution.{GpuHashJoin, JoinTypeChecks} + +class GpuSortMergeJoinMeta( + join: SortMergeJoinExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[SortMergeJoinExec](join, conf, parent, rule) { + + val leftKeys: Seq[BaseExprMeta[_]] = + join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val rightKeys: Seq[BaseExprMeta[_]] = + join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val condition: Option[BaseExprMeta[_]] = join.condition.map( + GpuOverrides.wrapExpr(_, conf, Some(this))) + val buildSide: GpuBuildSide = if (GpuHashJoin.canBuildRight(join.joinType)) { + GpuBuildRight + } else if (GpuHashJoin.canBuildLeft(join.joinType)) { + GpuBuildLeft + } else { + throw new IllegalStateException(s"Cannot build either side for ${join.joinType} join") + } + + override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition + + override val namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = + JoinTypeChecks.equiJoinMeta(leftKeys, rightKeys, condition) + + override def tagPlanForGpu(): Unit = { + // Use conditions from Hash Join + GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys, + join.condition) + + if (!conf.enableReplaceSortMergeJoin) { + willNotWorkOnGpu(s"Not replacing sort merge join with hash join, " + + s"see ${RapidsConf.ENABLE_REPLACE_SORTMERGEJOIN.key}") + } + + // make sure this is the last check - if this is SortMergeJoin, the children can be Sorts and we + // want to validate they can run on GPU and remove them before replacing this with a + // ShuffleHashJoin + if (canThisBeReplaced) { + childPlans.foreach { plan => + if (plan.wrapped.isInstanceOf[SortExec]) { + if (!plan.canThisBeReplaced) { + willNotWorkOnGpu(s"can't replace sortMergeJoin because one of the SortExec's before " + + s"can't be replaced.") + } else { + plan.shouldBeRemoved("replacing sortMergeJoin with shuffleHashJoin") + } + } + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala new file mode 100644 index 00000000000..a245ebebfdc --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} + +import org.apache.spark.sql.catalyst.expressions.{Lag, Lead, Literal, OffsetWindowFunction} +import org.apache.spark.sql.types.IntegerType + +abstract class OffsetWindowFunctionMeta[INPUT <: OffsetWindowFunction] ( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[INPUT](expr, conf, parent, rule) { + lazy val input: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.input, conf, Some(this)) + lazy val offset: BaseExprMeta[_] = { + expr match { + case _: Lead => // Supported. + case _: Lag => // Supported. + case other => + throw new IllegalStateException( + s"Only LEAD/LAG offset window functions are supported. Found: $other") + } + + val literalOffset = GpuOverrides.extractLit(expr.offset) match { + case Some(Literal(offset: Int, IntegerType)) => + Literal(offset, IntegerType) + case _ => + throw new IllegalStateException( + s"Only integer literal offsets are supported for LEAD/LAG. Found: ${expr.offset}") + } + + GpuOverrides.wrapExpr(literalOffset, conf, Some(this)) + } + lazy val default: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.default, conf, Some(this)) + + override val childExprs: Seq[BaseExprMeta[_]] = Seq(input, offset, default) + + override def tagExprForGpu(): Unit = { + expr match { + case _: Lead => // Supported. + case _: Lag => // Supported. + case other => + willNotWorkOnGpu( s"Only LEAD/LAG offset window functions are supported. Found: $other") + } + + if (GpuOverrides.extractLit(expr.offset).isEmpty) { // Not a literal offset. + willNotWorkOnGpu( + s"Only integer literal offsets are supported for LEAD/LAG. Found: ${expr.offset}") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala new file mode 100644 index 00000000000..7378c93aed6 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, TernaryExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.plans.logical.Command +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, UnaryExecNode} + +trait ShimExpression extends Expression + +trait ShimUnaryExpression extends UnaryExpression + +trait ShimBinaryExpression extends BinaryExpression + +trait ShimTernaryExpression extends TernaryExpression { + def first: Expression + def second: Expression + def third: Expression + final def children: Seq[Expression] = IndexedSeq(first, second, third) +} + +trait ShimSparkPlan extends SparkPlan + +trait ShimUnaryExecNode extends UnaryExecNode + +trait ShimBinaryExecNode extends BinaryExecNode + +trait ShimUnaryCommand extends Command diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala new file mode 100644 index 00000000000..39b7af5f6d9 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.v2 + +import com.nvidia.spark.rapids.{TypeEnum, TypeSig} + +import org.apache.spark.sql.types.DataType + +/** TypeSig Support for [3.0.1, 3.2.0) */ +object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtilBase { + + /** + * Check if this type of Spark-specific is supported by the plugin or not. + * + * @param check the Supported Types + * @param dataType the data type to be checked + * @return true if it is allowed else false. 
+ */
+  override def isSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType): Boolean = false
+
+  /**
+   * Get all supported types for the Spark-specific shim.
+   *
+   * @return all the supported types
+   */
+  override def getAllSupportedTypes(): TypeEnum.ValueSet =
+    TypeEnum.values - TypeEnum.DAYTIME - TypeEnum.YEARMONTH
+
+  /**
+   * Return the reason why this type is not supported.
+   *
+   * @param check the supported types
+   * @param dataType the data type to be checked
+   * @param notSupportedReason the reason for not supporting
+   * @return the reason
+   */
+  override def reasonNotSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      notSupportedReason: Seq[String]): Seq[String] = notSupportedReason
+
+  /**
+   * Map a DataType to a TypeEnum value.
+   *
+   * @param dataType the data type to be mapped
+   * @return the TypeEnum
+   */
+  override def mapDataTypeToTypeEnum(dataType: DataType): TypeEnum.Value = TypeEnum.UDT
+
+  /** Get the numeric and interval TypeSig */
+  override def getNumericAndInterval(): TypeSig =
+    TypeSig.cpuNumeric + TypeSig.CALENDAR
+}
diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala
new file mode 100644
index 00000000000..22f60110a35
--- /dev/null
+++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, WindowInPandasExec}
+import org.apache.spark.sql.types._
+
+abstract class GpuWindowInPandasExecMetaBase(
+    winPandas: WindowInPandasExec,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _]],
+    rule: DataFromReplacementRule)
+  extends SparkPlanMeta[WindowInPandasExec](winPandas, conf, parent, rule) {
+
+  override def replaceMessage: String = "partially run on GPU"
+  override def noReplacementPossibleMessage(reasons: String): String =
+    s"cannot run even partially on the GPU because $reasons"
+
+  val windowExpressions: Seq[BaseExprMeta[NamedExpression]]
+
+  val partitionSpec: Seq[BaseExprMeta[Expression]] =
+    winPandas.partitionSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+
+  val orderSpec: Seq[BaseExprMeta[SortOrder]] =
+    winPandas.orderSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+
+  // Same check as that in GpuWindowExecMeta
+  override def tagPlanForGpu(): Unit = {
+    // Implementation depends on receiving a `NamedExpression` wrapped WindowExpression.
+    windowExpressions.map(meta => meta.wrapped)
+      .filter(expr => !expr.isInstanceOf[NamedExpression])
+      .foreach(_ => willNotWorkOnGpu(because = "Unexpected query plan with Windowing" +
+        " Pandas UDF; cannot convert for GPU execution.
" + + "(Detail: WindowExpression not wrapped in `NamedExpression`.)")) + + // Early check for the frame type, only supporting RowFrame for now, which is different from + // the node GpuWindowExec. + windowExpressions + .flatMap(meta => meta.wrapped.collect { case e: SpecifiedWindowFrame => e }) + .filter(swf => swf.frameType.equals(RangeFrame)) + .foreach(rf => willNotWorkOnGpu(because = s"Only support RowFrame for now," + + s" but found ${rf.frameType}")) + } +} + +class GpuAggregateInPandasExecMeta( + aggPandas: AggregateInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[AggregateInPandasExec](aggPandas, conf, parent, rule) { + + override def replaceMessage: String = "partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + private val groupingNamedExprs: Seq[BaseExprMeta[NamedExpression]] = + aggPandas.groupingExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + private val udfs: Seq[BaseExprMeta[PythonUDF]] = + aggPandas.udfExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + private val resultNamedExprs: Seq[BaseExprMeta[NamedExpression]] = + aggPandas.resultExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = groupingNamedExprs ++ udfs ++ resultNamedExprs +} + +class GpuFlatMapGroupsInPandasExecMeta( + flatPandas: FlatMapGroupsInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[FlatMapGroupsInPandasExec](flatPandas, conf, parent, rule) { + + override def replaceMessage: String = "partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + private val groupingAttrs: Seq[BaseExprMeta[Attribute]] = + flatPandas.groupingAttributes.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + private val udf: BaseExprMeta[PythonUDF] = GpuOverrides.wrapExpr( + flatPandas.func.asInstanceOf[PythonUDF], conf, Some(this)) + + private val resultAttrs: Seq[BaseExprMeta[Attribute]] = + flatPandas.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + override val childExprs: Seq[BaseExprMeta[_]] = groupingAttrs ++ resultAttrs :+ udf +} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala new file mode 100644 index 00000000000..4111942c863 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.v2 + +import java.util.concurrent.TimeUnit + +import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +abstract class GpuWindowExpressionMetaBase( + windowExpression: WindowExpression, + conf: RapidsConf, + parent: Option[RapidsMeta[_,_]], + rule: DataFromReplacementRule) + extends ExprMeta[WindowExpression](windowExpression, conf, parent, rule) { + + private def getAndCheckRowBoundaryValue(boundary: Expression) : Int = boundary match { + case literal: Literal => + literal.dataType match { + case IntegerType => + literal.value.asInstanceOf[Int] + case t => + willNotWorkOnGpu(s"unsupported window boundary type $t") + -1 + } + case UnboundedPreceding => Int.MinValue + case UnboundedFollowing => Int.MaxValue + case CurrentRow => 0 + case _ => + willNotWorkOnGpu("unsupported window boundary type") + -1 + } + + /** Tag if RangeFrame expression is supported */ + def tagOtherTypesForRangeFrame(bounds: Expression): Unit = { + willNotWorkOnGpu(s"the type of boundary is not supported in a window range" + + s" function, found $bounds") + } + + override def tagExprForGpu(): Unit = { + + // Must have two children: + // 1. An AggregateExpression as the window function: SUM, MIN, MAX, COUNT + // 2. A WindowSpecDefinition, defining the window-bounds, partitioning, and ordering. + val windowFunction = wrapped.windowFunction + + wrapped.windowSpec.frameSpecification match { + case spec: SpecifiedWindowFrame => + spec.frameType match { + case RowFrame => + // Will also verify that the types are what we expect. + val lower = getAndCheckRowBoundaryValue(spec.lower) + val upper = getAndCheckRowBoundaryValue(spec.upper) + windowFunction match { + case _: Lead | _: Lag => // ignored we are good + case _ => + // need to be sure that the lower/upper are acceptable + if (lower > 0) { + willNotWorkOnGpu(s"lower-bounds ahead of current row is not supported. " + + s"Found $lower") + } + if (upper < 0) { + willNotWorkOnGpu(s"upper-bounds behind the current row is not supported. " + + s"Found $upper") + } + } + case RangeFrame => + // Spark by default does a RangeFrame if no RowFrame is given + // even for columns that are not time type columns. We can switch this to row + // based iff the ranges we are looking at both unbounded. + if (spec.isUnbounded) { + // this is okay because we will translate it to be a row query + } else { + // check whether order by column is supported or not + val orderSpec = wrapped.windowSpec.orderSpec + if (orderSpec.length > 1) { + // We only support a single order by column + willNotWorkOnGpu("only a single date/time or integral (Boolean exclusive)" + + "based column in window range functions is supported") + } + val orderByTypeSupported = orderSpec.forall { so => + so.dataType match { + case ByteType | ShortType | IntegerType | LongType | + DateType | TimestampType => true + case _ => false + } + } + if (!orderByTypeSupported) { + willNotWorkOnGpu(s"the type of orderBy column is not supported in a window" + + s" range function, found ${orderSpec.head.dataType}") + } + + def checkRangeBoundaryConfig(dt: DataType): Unit = { + dt match { + case ByteType => if (!conf.isRangeWindowByteEnabled) willNotWorkOnGpu( + s"Range window frame is not 100% compatible when the order by type is " + + s"byte and the range value calculated has overflow. 
" + + s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_BYTES} to true.") + case ShortType => if (!conf.isRangeWindowShortEnabled) willNotWorkOnGpu( + s"Range window frame is not 100% compatible when the order by type is " + + s"short and the range value calculated has overflow. " + + s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_SHORT} to true.") + case IntegerType => if (!conf.isRangeWindowIntEnabled) willNotWorkOnGpu( + s"Range window frame is not 100% compatible when the order by type is " + + s"int and the range value calculated has overflow. " + + s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_INT} to true.") + case LongType => if (!conf.isRangeWindowLongEnabled) willNotWorkOnGpu( + s"Range window frame is not 100% compatible when the order by type is " + + s"long and the range value calculated has overflow. " + + s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_LONG} to true.") + case _ => // never reach here + } + } + + // check whether the boundaries are supported or not. + Seq(spec.lower, spec.upper).foreach { + case l @ Literal(_, ByteType | ShortType | IntegerType | LongType) => + checkRangeBoundaryConfig(l.dataType) + case Literal(ci: CalendarInterval, CalendarIntervalType) => + // interval is only working for TimeStampType + if (ci.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + case UnboundedFollowing | UnboundedPreceding | CurrentRow => + case anythings => tagOtherTypesForRangeFrame(anythings) + } + } + } + case other => + willNotWorkOnGpu(s"only SpecifiedWindowFrame is a supported window-frame specification. " + + s"Found ${other.prettyName}") + } + } +} + +abstract class GpuSpecifiedWindowFrameMetaBase( + windowFrame: SpecifiedWindowFrame, + conf: RapidsConf, + parent: Option[RapidsMeta[_,_]], + rule: DataFromReplacementRule) + extends ExprMeta[SpecifiedWindowFrame](windowFrame, conf, parent, rule) { + + // SpecifiedWindowFrame has no associated dataType. + override val ignoreUnsetDataTypes: Boolean = true + + /** + * Tag RangeFrame for other types and get the value + */ + def getAndTagOtherTypesForRangeFrame(bounds : Expression, isLower : Boolean): Long = { + willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in Integral" + + s" type (Boolean exclusive) or CalendarInterval. Found ${bounds.dataType}") + if (isLower) -1 else 1 // not check again + } + + override def tagExprForGpu(): Unit = { + if (windowFrame.frameType.equals(RangeFrame)) { + // Expect either SpecialFrame (UNBOUNDED PRECEDING/FOLLOWING, or CURRENT ROW), + // or CalendarIntervalType in days. + + // Check that: + // 1. if `bounds` is specified as a Literal, it is specified in DAYS. + // 2. if `bounds` is a lower-bound, it can't be ahead of the current row. + // 3. if `bounds` is an upper-bound, it can't be behind the current row. + def checkIfInvalid(bounds : Expression, isLower : Boolean) : Option[String] = { + + if (!bounds.isInstanceOf[Literal]) { + // Bounds are likely SpecialFrameBoundaries (CURRENT_ROW, UNBOUNDED PRECEDING/FOLLOWING). 
+ return None + } + + val value: Long = bounds match { + case Literal(value, ByteType) => value.asInstanceOf[Byte].toLong + case Literal(value, ShortType) => value.asInstanceOf[Short].toLong + case Literal(value, IntegerType) => value.asInstanceOf[Int].toLong + case Literal(value, LongType) => value.asInstanceOf[Long] + case Literal(ci: CalendarInterval, CalendarIntervalType) => + if (ci.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + // return the total microseconds + try { + // Spark 2.x different - no days, just months and microseconds + // could remove this catch but leaving for now + /* + Math.addExact( + Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)), + ci.microseconds) + */ + ci.microseconds + } catch { + case _: ArithmeticException => + willNotWorkOnGpu("windows over timestamps are converted to microseconds " + + s"and $ci is too large to fit") + if (isLower) -1 else 1 // not check again + } + case _ => getAndTagOtherTypesForRangeFrame(bounds, isLower) + } + + if (isLower && value > 0) { + Some(s"Lower-bounds ahead of current row is not supported. Found: $value") + } else if (!isLower && value < 0) { + Some(s"Upper-bounds behind current row is not supported. Found: $value") + } else { + None + } + } + + val invalidUpper = checkIfInvalid(windowFrame.upper, isLower = false) + if (invalidUpper.nonEmpty) { + willNotWorkOnGpu(invalidUpper.get) + } + + val invalidLower = checkIfInvalid(windowFrame.lower, isLower = true) + if (invalidLower.nonEmpty) { + willNotWorkOnGpu(invalidLower.get) + } + } + + if (windowFrame.frameType.equals(RowFrame)) { + + windowFrame.lower match { + case literal : Literal => + if (!literal.value.isInstanceOf[Int]) { + willNotWorkOnGpu(s"Literal Lower-bound of ROWS window-frame must be of INT type. " + + s"Found ${literal.dataType}") + } + // We don't support a lower bound > 0 except for lead/lag where it is required + // That check is done in GpuWindowExpressionMeta where it knows what type of operation + // is being done + case UnboundedPreceding => + case CurrentRow => + case _ => + willNotWorkOnGpu(s"Lower-bound of ROWS window-frame must be an INT literal," + + s"UNBOUNDED PRECEDING, or CURRENT ROW. " + + s"Found unexpected bound: ${windowFrame.lower.prettyName}") + } + + windowFrame.upper match { + case literal : Literal => + if (!literal.value.isInstanceOf[Int]) { + willNotWorkOnGpu(s"Literal Upper-bound of ROWS window-frame must be of INT type. " + + s"Found ${literal.dataType}") + } + // We don't support a upper bound < 0 except for lead/lag where it is required + // That check is done in GpuWindowExpressionMeta where it knows what type of operation + // is being done + case UnboundedFollowing => + case CurrentRow => + case _ => willNotWorkOnGpu(s"Upper-bound of ROWS window-frame must be an INT literal," + + s"UNBOUNDED FOLLOWING, or CURRENT ROW. 
" + + s"Found unexpected bound: ${windowFrame.upper.prettyName}") + } + } + } +} + +class GpuSpecifiedWindowFrameMeta( + windowFrame: SpecifiedWindowFrame, + conf: RapidsConf, + parent: Option[RapidsMeta[_,_]], + rule: DataFromReplacementRule) + extends GpuSpecifiedWindowFrameMetaBase(windowFrame, conf, parent, rule) {} + +class GpuWindowExpressionMeta( + windowExpression: WindowExpression, + conf: RapidsConf, + parent: Option[RapidsMeta[_,_]], + rule: DataFromReplacementRule) + extends GpuWindowExpressionMetaBase(windowExpression, conf, parent, rule) {} + +object GpuWindowUtil { + + /** + * Check if the type of RangeFrame is valid in GpuWindowSpecDefinition + * @param orderSpecType the first order by data type + * @param ft the first frame boundary data type + * @return true to valid, false to invalid + */ + def isValidRangeFrameType(orderSpecType: DataType, ft: DataType): Boolean = { + (orderSpecType, ft) match { + case (DateType, IntegerType) => true + case (TimestampType, CalendarIntervalType) => true + case (a, b) => a == b + } + } + + def getRangeBoundaryValue(boundary: Expression): ParsedBoundary = boundary match { + case anything => throw new UnsupportedOperationException("Unsupported window frame" + + s" expression $anything") + } +} + +case class ParsedBoundary(isUnbounded: Boolean, valueAsLong: Long) + +class GpuWindowSpecDefinitionMeta( + windowSpec: WindowSpecDefinition, + conf: RapidsConf, + parent: Option[RapidsMeta[_,_]], + rule: DataFromReplacementRule) + extends ExprMeta[WindowSpecDefinition](windowSpec, conf, parent, rule) { + + val partitionSpec: Seq[BaseExprMeta[Expression]] = + windowSpec.partitionSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val orderSpec: Seq[BaseExprMeta[SortOrder]] = + windowSpec.orderSpec.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val windowFrame: BaseExprMeta[WindowFrame] = + GpuOverrides.wrapExpr(windowSpec.frameSpecification, conf, Some(this)) + + override val ignoreUnsetDataTypes: Boolean = true + + override def tagExprForGpu(): Unit = { + if (!windowSpec.frameSpecification.isInstanceOf[SpecifiedWindowFrame]) { + willNotWorkOnGpu(s"WindowFunctions without a SpecifiedWindowFrame are unsupported.") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala new file mode 100644 index 00000000000..d092ead6225 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/TrampolineUtil.scala @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.Utils + +object TrampolineUtil { + + // package has to be different to access 2.x HashedRelationBroadcastMode + def isSupportedRelation(mode: BroadcastMode): Boolean = mode match { + case _ : HashedRelationBroadcastMode => true + case IdentityBroadcastMode => true + case _ => false + } + + /** + * Return true if the provided predicate function returns true for any + * type node within the datatype tree. + */ + def dataTypeExistsRecursively(dt: DataType, f: DataType => Boolean): Boolean = { + dt.existsRecursively(f) + } + + /** Get the simple name of a class with fixup for any Scala internal errors */ + def getSimpleName(cls: Class[_]): String = { + Utils.getSimpleName(cls) + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala new file mode 100644 index 00000000000..2b623584560 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveOverrides.scala @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.rapids + +import com.nvidia.spark.RapidsUDF +import com.nvidia.spark.rapids.{ExprChecks, ExprMeta, ExprRule, GpuOverrides, RapidsConf, RepeatingParamCheck, TypeSig} +import com.nvidia.spark.rapids.GpuUserDefinedFunction.udfTypeSig + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.hive.{HiveGenericUDF, HiveSimpleUDF} + +object GpuHiveOverrides { + def isSparkHiveAvailable: Boolean = { + try { + getClass().getClassLoader.loadClass("org.apache.spark.sql.hive.HiveSessionStateBuilder") + getClass().getClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + true + } catch { + case _: ClassNotFoundException | _: NoClassDefFoundError => false + } + } + + /** + * Builds the rules that are specific to spark-hive Catalyst nodes. This will return an empty + * mapping if spark-hive is unavailable. 
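+   *
+   * Availability is checked at runtime by attempting to load the Hive session-state and
+   * HiveConf classes (see isSparkHiveAvailable above); when they are missing these rules
+   * are simply skipped.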
+ */ + def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + if (!isSparkHiveAvailable) { + return Map.empty + } + + Seq( + GpuOverrides.expr[HiveSimpleUDF]( + "Hive UDF, the UDF can choose to implement a RAPIDS accelerated interface to" + + " get better performance", + ExprChecks.projectOnly( + udfTypeSig, + TypeSig.all, + repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), + (a, conf, p, r) => new ExprMeta[HiveSimpleUDF](a, conf, p, r) { + private val opRapidsFunc = a.function match { + case rapidsUDF: RapidsUDF => Some(rapidsUDF) + case _ => None + } + + override def tagExprForGpu(): Unit = { + if (opRapidsFunc.isEmpty && !conf.isCpuBasedUDFEnabled) { + willNotWorkOnGpu(s"Hive SimpleUDF ${a.name} implemented by " + + s"${a.funcWrapper.functionClassName} does not provide a GPU implementation " + + s"and CPU-based UDFs are not enabled by `${RapidsConf.ENABLE_CPU_BASED_UDF.key}`") + } + } + }), + GpuOverrides.expr[HiveGenericUDF]( + "Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to" + + " get better performance", + ExprChecks.projectOnly( + udfTypeSig, + TypeSig.all, + repeatingParamCheck = Some(RepeatingParamCheck("param", udfTypeSig, TypeSig.all))), + (a, conf, p, r) => new ExprMeta[HiveGenericUDF](a, conf, p, r) { + private val opRapidsFunc = a.function match { + case rapidsUDF: RapidsUDF => Some(rapidsUDF) + case _ => None + } + + override def tagExprForGpu(): Unit = { + if (opRapidsFunc.isEmpty && !conf.isCpuBasedUDFEnabled) { + willNotWorkOnGpu(s"Hive GenericUDF ${a.name} implemented by " + + s"${a.funcWrapper.functionClassName} does not provide a GPU implementation " + + s"and CPU-based UDFs are not enabled by `${RapidsConf.ENABLE_CPU_BASED_UDF.key}`") + } + } + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala new file mode 100644 index 00000000000..bed39dfed88 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils + +object GpuDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap: Map[String, String] = { + val jdbc = classOf[JdbcRelationProvider].getCanonicalName + val json = classOf[JsonFileFormat].getCanonicalName + val parquet = classOf[ParquetFileFormat].getCanonicalName + val csv = classOf[CSVFileFormat].getCanonicalName + val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" + val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" + val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName + + Map( + "org.apache.spark.sql.jdbc" -> jdbc, + "org.apache.spark.sql.jdbc.DefaultSource" -> jdbc, + "org.apache.spark.sql.execution.datasources.jdbc.DefaultSource" -> jdbc, + "org.apache.spark.sql.execution.datasources.jdbc" -> jdbc, + "org.apache.spark.sql.json" -> json, + "org.apache.spark.sql.json.DefaultSource" -> json, + "org.apache.spark.sql.execution.datasources.json" -> json, + "org.apache.spark.sql.execution.datasources.json.DefaultSource" -> json, + "org.apache.spark.sql.parquet" -> parquet, + "org.apache.spark.sql.parquet.DefaultSource" -> parquet, + "org.apache.spark.sql.execution.datasources.parquet" -> parquet, + "org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet, + "org.apache.spark.sql.hive.orc.DefaultSource" -> orc, + "org.apache.spark.sql.hive.orc" -> orc, + "org.apache.spark.sql.execution.datasources.orc.DefaultSource" -> nativeOrc, + "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, + "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, + "org.apache.spark.ml.source.libsvm" -> libsvm, + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate + ) + } + + /** + * Class that were removed in Spark 2.0. Used to detect incompatibility libraries for Spark 2.0. + */ + private val spark2RemovedClasses = Set( + "org.apache.spark.sql.DataFrame", + "org.apache.spark.sql.sources.HadoopFsRelationProvider", + "org.apache.spark.Logging") + + + /** Given a provider name, look up the data source class definition. 
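+   * For example, the alias "parquet" resolves to ParquetFileFormat via the DataSourceRegister
+   * service loader, while a fully qualified class name (optionally with a trailing
+   * ".DefaultSource") is loaded directly from the classpath.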
*/ + def lookupDataSource(provider: String, conf: SQLConf): Class[_] = { + val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match { + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" => + classOf[OrcFileFormat].getCanonicalName + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => + "org.apache.spark.sql.hive.orc.OrcFileFormat" + case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled => + "org.apache.spark.sql.avro.AvroFileFormat" + case name => name + } + val provider2 = s"$provider1.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + try { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match { + // the provider format did not match any given registered aliases + case Nil => + try { + Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider1.startsWith("org.apache.spark.sql.hive.orc")) { + throw new AnalysisException( + "Hive built-in ORC data source must be used with Hive support enabled. " + + "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + + "'native'") + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || + provider1 == "com.databricks.spark.avro" || + provider1 == "org.apache.spark.sql.avro") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Avro is built-in but external data " + + "source module since Spark 2.4. Please deploy the application as per " + + "the deployment section of \"Apache Avro Data Source Guide\".") + } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Please deploy the application as " + + "per the deployment section of " + + "\"Structured Streaming + Kafka Integration Guide\".") + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider1. Please find packages at " + + "http://spark.apache.org/third-party-projects.html", + error) + } + } + } catch { + case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " + + "Please check if your library is compatible with Spark 2.0", e) + } else { + throw e + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input. If there is single datasource + // that has "org.apache.spark" package in the prefix, we use it considering it is an + // internal datasource within Spark. 
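+            // e.g. when a built-in source and a third-party package register the same alias,
+            // the single org.apache.spark implementation wins (with a warning); otherwise the
+            // caller must use the fully qualified class name instead of the alias.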
+ val sourceNames = sources.map(_.getClass.getName) + val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark")) + if (internalSources.size == 1) { + logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " + + s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).") + internalSources.head.getClass + } else { + throw new AnalysisException(s"Multiple sources found for $provider1 " + + s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.") + } + } + } catch { + case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] => + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getCause.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " + + "Please remove the incompatible library from classpath or upgrade it. " + + s"Error: ${e.getMessage}", e) + } else { + throw e + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala new file mode 100644 index 00000000000..a5c110ad23a --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat + +object GpuFileSourceScanExec { + def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + meta.wrapped.relation.fileFormat match { + case _: CSVFileFormat => GpuReadCSVFileFormat.tagSupport(meta) + case f if GpuOrcFileFormat.isSparkOrcFormat(f) => GpuReadOrcFileFormat.tagSupport(meta) + case _: ParquetFileFormat => GpuReadParquetFileFormat.tagSupport(meta) + case f => + meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala new file mode 100644 index 00000000000..abea9852ec0 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import com.nvidia.spark.rapids._ +import org.apache.orc.OrcConf +import org.apache.orc.OrcConf._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcUtils} +import org.apache.spark.sql.types._ + +object GpuOrcFileFormat extends Logging { + // The classname used when Spark is configured to use the Hive implementation for ORC. + // Spark is not always compiled with Hive support so we cannot import from Spark jars directly. + private val HIVE_IMPL_CLASS = "org.apache.spark.sql.hive.orc.OrcFileFormat" + + def isSparkOrcFormat(format: FileFormat): Boolean = format match { + case _: OrcFileFormat => true + case f if f.getClass.getCanonicalName.equals(HIVE_IMPL_CLASS) => true + case _ => false + } + + def tagGpuSupport(meta: RapidsMeta[_, _], + spark: SparkSession, + options: Map[String, String], + schema: StructType): Unit = { + + if (!meta.conf.isOrcEnabled) { + meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" + + s"${RapidsConf.ENABLE_ORC} to true") + } + + if (!meta.conf.isOrcWriteEnabled) { + meta.willNotWorkOnGpu("ORC output has been disabled. To enable set " + + s"${RapidsConf.ENABLE_ORC_WRITE} to true.\n" + + "Please note that, the ORC file written by spark-rapids will not include statistics " + + "in RowIndex, which will result in Spark 3.1.1+ failed to read ORC file when the filter " + + "is pushed down. 
This is an ORC issue, " + + "please refer to https://issues.apache.org/jira/browse/ORC-1075") + } + + FileFormatChecks.tag(meta, schema, OrcFormatType, WriteFileOp) + + val sqlConf = spark.sessionState.conf + + val parameters = CaseInsensitiveMap(options) + + case class ConfDataForTagging(orcConf: OrcConf, defaultValue: Any, message: String) + + def tagIfOrcOrHiveConfNotSupported(params: ConfDataForTagging): Unit = { + val conf = params.orcConf + val defaultValue = params.defaultValue + val message = params.message + val confValue = parameters.get(conf.getAttribute) + .orElse(parameters.get(conf.getHiveConfName)) + if (confValue.isDefined && confValue.get != defaultValue) { + logInfo(message) + } + } + + val orcOptions = new OrcOptions(options, sqlConf) + orcOptions.compressionCodec match { + case "NONE" | "SNAPPY" => + case c => meta.willNotWorkOnGpu(s"compression codec $c is not supported") + } + + // hard coding the default value as it could change in future + val supportedConf = Map( + STRIPE_SIZE.ordinal() -> + ConfDataForTagging(STRIPE_SIZE, 67108864L, "only 64MB stripe size is supported"), + BUFFER_SIZE.ordinal() -> + ConfDataForTagging(BUFFER_SIZE, 262144, "only 256KB block size is supported"), + ROW_INDEX_STRIDE.ordinal() -> + ConfDataForTagging(ROW_INDEX_STRIDE, 10000, "only 10,000 row index stride is supported"), + BLOCK_PADDING.ordinal() -> + ConfDataForTagging(BLOCK_PADDING, true, "Block padding isn't supported")) + + OrcConf.values().foreach(conf => { + if (supportedConf.contains(conf.ordinal())) { + tagIfOrcOrHiveConfNotSupported(supportedConf(conf.ordinal())) + } else { + if ((conf.getHiveConfName != null && parameters.contains(conf.getHiveConfName)) + || parameters.contains(conf.getAttribute)) { + // these configurations are implementation specific and don't apply to cudf + // The user has set them so we can't run on GPU + logInfo(s"${conf.name()} is unsupported configuration") + } + } + }) + + } +} + diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDFMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDFMeta.scala new file mode 100644 index 00000000000..693345469fa --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDFMeta.scala @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.lang.invoke.SerializedLambda + +import com.nvidia.spark.RapidsUDF +import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprMeta, RapidsConf, RapidsMeta} + +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.execution.TrampolineUtil + +abstract class ScalaUDFMetaBase( + expr: ScalaUDF, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) extends ExprMeta(expr, conf, parent, rule) { + + lazy val opRapidsFunc = GpuScalaUDF.getRapidsUDFInstance(expr.function) + + override def tagExprForGpu(): Unit = { + if (opRapidsFunc.isEmpty && !conf.isCpuBasedUDFEnabled) { + val udfName = expr.udfName.getOrElse("UDF") + val udfClass = expr.function.getClass + willNotWorkOnGpu(s"neither $udfName implemented by $udfClass provides " + + s"a GPU implementation, nor the conf `${RapidsConf.ENABLE_CPU_BASED_UDF.key}` " + + s"is enabled") + } + } +} + +object GpuScalaUDF { + /** + * Determine if the UDF function implements the [[com.nvidia.spark.RapidsUDF]] interface, + * returning the instance if it does. The lambda wrapper that Spark applies to Java UDFs will be + * inspected if necessary to locate the user's UDF instance. + */ + def getRapidsUDFInstance(function: AnyRef): Option[RapidsUDF] = { + function match { + case f: RapidsUDF => Some(f) + case f => + try { + // This may be a lambda that Spark's UDFRegistration wrapped around a Java UDF instance. + val clazz = f.getClass + if (TrampolineUtil.getSimpleName(clazz).toLowerCase().contains("lambda")) { + // Try to find a `writeReplace` method, further indicating it is likely a lambda + // instance, and invoke it to serialize the lambda. Once serialized, captured arguments + // can be examine to locate the Java UDF instance. + // Note this relies on implementation details of Spark's UDFRegistration class. + val writeReplace = clazz.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda] + if (serializedLambda.getCapturedArgCount == 1) { + serializedLambda.getCapturedArg(0) match { + case c: RapidsUDF => Some(c) + case _ => None + } + } else { + None + } + } else { + None + } + } catch { + case _: ClassCastException | _: NoSuchMethodException | _: SecurityException => None + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala new file mode 100644 index 00000000000..023386ba5b4 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.math.BigInteger + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, ExpectsInputTypes, Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +object GpuDecimalMultiply { + // For Spark the final desired output is + // new_scale = lhs.scale + rhs.scale + // new_precision = lhs.precision + rhs.precision + 1 + // But Spark will round the final result, so we need at least one more + // decimal place on the scale to be able to do the rounding too. + + // In CUDF the output scale is the same lhs.scale + rhs.scale, but because we need one more + // we will need to increase the scale for either the lhs or the rhs so it works. We will pick + // the one with the smallest precision to do it, because it minimises the chance of requiring a + // larger data type to do the multiply. + + /** + * Get the scales that are needed for the lhs and rhs to produce the desired result. + */ + def lhsRhsNeededScales( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): (Int, Int) = { + val cudfIntermediateScale = lhs.scale + rhs.scale + val requiredIntermediateScale = outputType.scale + 1 + if (requiredIntermediateScale > cudfIntermediateScale) { + // In practice this should only ever be 1, but just to be cautious... + val neededScaleDiff = requiredIntermediateScale - cudfIntermediateScale + // So we need to add some to the LHS and some to the RHS. + var addToLhs = 0 + var addToRhs = 0 + // We start by trying + // to bring them both to the same precision. + val precisionDiff = lhs.precision - rhs.precision + if (precisionDiff > 0) { + addToRhs = math.min(precisionDiff, neededScaleDiff) + } else { + addToLhs = math.min(math.abs(precisionDiff), neededScaleDiff) + } + val stillNeeded = neededScaleDiff - (addToLhs + addToRhs) + if (stillNeeded > 0) { + // We need to split it between the two + val l = stillNeeded/2 + val r = stillNeeded - l + addToLhs += l + addToRhs += r + } + (lhs.scale + addToLhs, rhs.scale + addToRhs) + } else { + (lhs.scale, rhs.scale) + } + } + + def nonRoundedIntermediatePrecision( + l: DecimalType, + r: DecimalType, + outputType: DecimalType): Int = { + // CUDF ignores the precision, except for the underlying device type, so in general we + // need to find the largest precision needed between the LHS, RHS, and intermediate output + // In practice this should probably always be outputType.precision + 1, but just to be + // cautions we calculate it all out. 
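+    // A worked example with hypothetical inputs: multiplying Decimal(8, 2) by Decimal(6, 3)
+    // gives a Spark output type of Decimal(15, 5). lhsRhsNeededScales bumps the rhs scale from
+    // 3 to 4 so the intermediate scale is 6 (output scale + 1 for rounding), and the precision
+    // computed below is max(max(8 - 2 + 2, 6 - 3 + 4), 15 + 1) = 16, i.e. outputType.precision + 1.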
+ val (lhsScale, rhsScale) = lhsRhsNeededScales(l, r, outputType) + val lhsPrecision = l.precision - l.scale + lhsScale + val rhsPrecision = r.precision - r.scale + rhsScale + // we add 1 to the output precision so we can round the final result to match Spark + math.max(math.max(lhsPrecision, rhsPrecision), outputType.precision + 1) + } + + def intermediatePrecision(lhs: DecimalType, rhs: DecimalType, outputType: DecimalType): Int = + math.min( + nonRoundedIntermediatePrecision(lhs, rhs, outputType), + GpuOverrides.DECIMAL128_MAX_PRECISION) + + def intermediateLhsRhsTypes( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): (DecimalType, DecimalType) = { + val precision = intermediatePrecision(lhs, rhs, outputType) + val (lhsScale, rhsScale) = lhsRhsNeededScales(lhs, rhs, outputType) + (DecimalType(precision, lhsScale), DecimalType(precision, rhsScale)) + } + + def intermediateResultType( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): DecimalType = { + val precision = intermediatePrecision(lhs, rhs, outputType) + DecimalType(precision, + math.min(outputType.scale + 1, GpuOverrides.DECIMAL128_MAX_PRECISION)) + } + + private[this] lazy val max128Int = new BigInteger(Array(2.toByte)).pow(127) + .subtract(BigInteger.ONE) + private[this] lazy val min128Int = new BigInteger(Array(2.toByte)).pow(127) + .negate() + +} + +object GpuDecimalDivide { + // For Spark the final desired output is + // new_scale = max(6, lhs.scale + rhs.precision + 1) + // new_precision = lhs.precision - lhs.scale + rhs.scale + new_scale + // But Spark will round the final result, so we need at least one more + // decimal place on the scale to be able to do the rounding too. + + def lhsNeededScale(rhs: DecimalType, outputType: DecimalType): Int = + outputType.scale + rhs.scale + 1 + + def lhsNeededPrecision(lhs: DecimalType, rhs: DecimalType, outputType: DecimalType): Int = { + val neededLhsScale = lhsNeededScale(rhs, outputType) + (lhs.precision - lhs.scale) + neededLhsScale + } + + def nonRoundedIntermediateArgPrecision( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): Int = { + val neededLhsPrecision = lhsNeededPrecision(lhs, rhs, outputType) + math.max(neededLhsPrecision, rhs.precision) + } + + def intermediateArgPrecision(lhs: DecimalType, rhs: DecimalType, outputType: DecimalType): Int = + math.min( + nonRoundedIntermediateArgPrecision(lhs, rhs, outputType), + GpuOverrides.DECIMAL128_MAX_PRECISION) + + def intermediateLhsType( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): DecimalType = { + val precision = intermediateArgPrecision(lhs, rhs, outputType) + val scale = math.min(lhsNeededScale(rhs, outputType), precision) + DecimalType(precision, scale) + } + + def intermediateRhsType( + lhs: DecimalType, + rhs: DecimalType, + outputType: DecimalType): DecimalType = { + val precision = intermediateArgPrecision(lhs, rhs, outputType) + DecimalType(precision, rhs.scale) + } + + def intermediateResultType(outputType: DecimalType): DecimalType = { + // If the user says that this will not overflow we will still + // try to do rounding for a correct answer, unless we cannot + // because it is already a scale of 38 + DecimalType( + math.min(outputType.precision + 1, GpuOverrides.DECIMAL128_MAX_PRECISION), + math.min(outputType.scale + 1, GpuOverrides.DECIMAL128_MAX_PRECISION)) + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala 
b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala new file mode 100644 index 00000000000..d753532ed8f --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import com.nvidia.spark.rapids.{DataFromReplacementRule, ExprMeta, RapidsConf, RapidsMeta} + +import org.apache.spark.sql.catalyst.expressions.Sequence + +class GpuSequenceMeta( + expr: Sequence, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends ExprMeta[Sequence](expr, conf, parent, rule) { + + override def tagExprForGpu(): Unit = { + // We have to fall back to the CPU if the timeZoneId is not UTC when + // we are processing date/timestamp. + // Date/Timestamp are not enabled right now so this is probably fine. + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala new file mode 100644 index 00000000000..757efd4a37f --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, DataTypeUtils, GpuOverrides, RapidsConf, RapidsMeta} + +import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, GetMapValue} + +class GpuGetArrayItemMeta( + expr: GetArrayItem, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[GetArrayItem](expr, conf, parent, rule) { + import GpuOverrides._ + + override def tagExprForGpu(): Unit = { + extractLit(expr.ordinal).foreach { litOrd => + // Once literal array/struct types are supported this can go away + val ord = litOrd.value + if ((ord == null || ord.asInstanceOf[Int] < 0) && DataTypeUtils.isNestedType(expr.dataType)) { + willNotWorkOnGpu("negative and null indexes are not supported for nested types") + } + } + } +} + +class GpuGetMapValueMeta( + expr: GetMapValue, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[GetMapValue](expr, conf, parent, rule) { +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala new file mode 100644 index 00000000000..a259126d633 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +sealed trait TimeParserPolicy extends Serializable +object LegacyTimeParserPolicy extends TimeParserPolicy +object ExceptionTimeParserPolicy extends TimeParserPolicy +object CorrectedTimeParserPolicy extends TimeParserPolicy diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala new file mode 100644 index 00000000000..f0dcb5addb3 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsMeta.scala @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException +import com.nvidia.spark.rapids.GpuOverrides.extractStringLit + +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, TimeZoneAwareExpression} +import org.apache.spark.sql.types._ + +case class ParseFormatMeta(separator: Char, isTimestamp: Boolean, validRegex: String) + +case class RegexReplace(search: String, replace: String) + +object GpuToTimestamp { + // We are compatible with Spark for these formats when the timeParserPolicy is CORRECTED + // or EXCEPTION. It is possible that other formats may be supported but these are the only + // ones that we have tests for. + val CORRECTED_COMPATIBLE_FORMATS = Map( + "yyyy-MM-dd" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{4}-\d{2}-\d{2}\Z"), + "yyyy/MM/dd" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{4}/\d{1,2}/\d{1,2}\Z"), + "yyyy-MM" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{4}-\d{2}\Z"), + "yyyy/MM" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{4}/\d{2}\Z"), + "dd/MM/yyyy" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{2}/\d{2}/\d{4}\Z"), + "yyyy-MM-dd HH:mm:ss" -> ParseFormatMeta('-', isTimestamp = true, + raw"\A\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}\Z"), + "MM-dd" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{2}-\d{2}\Z"), + "MM/dd" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{2}/\d{2}\Z"), + "dd-MM" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{2}-\d{2}\Z"), + "dd/MM" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{2}/\d{2}\Z") + ) + // We are compatible with Spark for these formats when the timeParserPolicy is LEGACY. It + // is possible that other formats may be supported but these are the only ones that we have + // tests for. 
+ val LEGACY_COMPATIBLE_FORMATS = Map( + "yyyy-MM-dd" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{4}-\d{1,2}-\d{1,2}(\D|\s|\Z)"), + "yyyy/MM/dd" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{4}/\d{1,2}/\d{1,2}(\D|\s|\Z)"), + "dd-MM-yyyy" -> ParseFormatMeta('-', isTimestamp = false, + raw"\A\d{1,2}-\d{1,2}-\d{4}(\D|\s|\Z)"), + "dd/MM/yyyy" -> ParseFormatMeta('/', isTimestamp = false, + raw"\A\d{1,2}/\d{1,2}/\d{4}(\D|\s|\Z)"), + "yyyy-MM-dd HH:mm:ss" -> ParseFormatMeta('-', isTimestamp = true, + raw"\A\d{4}-\d{1,2}-\d{1,2}[ T]\d{1,2}:\d{1,2}:\d{1,2}(\D|\s|\Z)"), + "yyyy/MM/dd HH:mm:ss" -> ParseFormatMeta('/', isTimestamp = true, + raw"\A\d{4}/\d{1,2}/\d{1,2}[ T]\d{1,2}:\d{1,2}:\d{1,2}(\D|\s|\Z)") + ) +} + +abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpression] + (expr: A, conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[A](expr, conf, parent, rule) { + + def shouldFallbackOnAnsiTimestamp: Boolean + + var sparkFormat: String = _ + var strfFormat: String = _ + override def tagExprForGpu(): Unit = { + checkTimeZoneId(expr.timeZoneId) + + if (shouldFallbackOnAnsiTimestamp) { + willNotWorkOnGpu("ANSI mode is not supported") + } + + // Date and Timestamp work too + if (expr.right.dataType == StringType) { + extractStringLit(expr.right) match { + case Some(rightLit) => + sparkFormat = rightLit + if (GpuOverrides.getTimeParserPolicy == LegacyTimeParserPolicy) { + try { + // try and convert the format to cuDF format - this will throw an exception if + // the format contains unsupported characters or words + strfFormat = DateUtils.toStrf(sparkFormat, + expr.left.dataType == DataTypes.StringType) + // format parsed ok but we have no 100% compatible formats in LEGACY mode + if (GpuToTimestamp.LEGACY_COMPATIBLE_FORMATS.contains(sparkFormat)) { + // LEGACY support has a number of issues that mean we cannot guarantee + // compatibility with CPU + // - we can only support 4 digit years but Spark supports a wider range + // - we use a proleptic Gregorian calender but Spark uses a hybrid Julian+Gregorian + // calender in LEGACY mode + // Spark 2.x - ansi not available + /* + if (SQLConf.get.ansiEnabled) { + willNotWorkOnGpu("LEGACY format in ANSI mode is not supported on the GPU") + } else */ + if (!conf.incompatDateFormats) { + willNotWorkOnGpu(s"LEGACY format '$sparkFormat' on the GPU is not guaranteed " + + s"to produce the same results as Spark on CPU. Set " + + s"${RapidsConf.INCOMPATIBLE_DATE_FORMATS.key}=true to force onto GPU.") + } + } else { + willNotWorkOnGpu(s"LEGACY format '$sparkFormat' is not supported on the GPU.") + } + } catch { + case e: TimestampFormatConversionException => + willNotWorkOnGpu(s"Failed to convert ${e.reason} ${e.getMessage}") + } + } else { + try { + // try and convert the format to cuDF format - this will throw an exception if + // the format contains unsupported characters or words + strfFormat = DateUtils.toStrf(sparkFormat, + expr.left.dataType == DataTypes.StringType) + // format parsed ok, so it is either compatible (tested/certified) or incompatible + if (!GpuToTimestamp.CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat) && + !conf.incompatDateFormats) { + willNotWorkOnGpu(s"CORRECTED format '$sparkFormat' on the GPU is not guaranteed " + + s"to produce the same results as Spark on CPU. 
Set " + + s"${RapidsConf.INCOMPATIBLE_DATE_FORMATS.key}=true to force onto GPU.") + } + } catch { + case e: TimestampFormatConversionException => + willNotWorkOnGpu(s"Failed to convert ${e.reason} ${e.getMessage}") + } + } + case None => + willNotWorkOnGpu("format has to be a string literal") + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExecMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExecMeta.scala new file mode 100644 index 00000000000..11f7aad25f2 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExecMeta.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.execution.TrampolineUtil +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} + +class GpuBroadcastMeta( + exchange: BroadcastExchangeExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) extends + SparkPlanMeta[BroadcastExchangeExec](exchange, conf, parent, rule) { + + override def tagPlanForGpu(): Unit = { + if (!TrampolineUtil.isSupportedRelation(exchange.mode)) { + willNotWorkOnGpu( + "Broadcast exchange is only supported for HashedJoin or BroadcastNestedLoopJoin") + } + def isSupported(rm: RapidsMeta[_, _]): Boolean = rm.wrapped match { + case _: BroadcastHashJoinExec => true + case _: BroadcastNestedLoopJoinExec => true + case _ => false + } + if (parent.isDefined) { + if (!parent.exists(isSupported)) { + willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " + + "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec") + } + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala new file mode 100644 index 00000000000..8793829485c --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, BuildLeft, BuildRight, BuildSide} + +class GpuBroadcastNestedLoopJoinMeta( + join: BroadcastNestedLoopJoinExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends GpuBroadcastJoinMeta[BroadcastNestedLoopJoinExec](join, conf, parent, rule) { + + val conditionMeta: Option[BaseExprMeta[_]] = + join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + + val gpuBuildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide) + + override def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = + JoinTypeChecks.nonEquiJoinMeta(conditionMeta) + + override val childExprs: Seq[BaseExprMeta[_]] = conditionMeta.toSeq + + override def tagPlanForGpu(): Unit = { + JoinTypeChecks.tagForGpu(join.joinType, this) + join.joinType match { + case _: InnerLike => + case LeftOuter | RightOuter | LeftSemi | LeftAnti => + conditionMeta.foreach(requireAstForGpuOn) + case _ => willNotWorkOnGpu(s"${join.joinType} currently is not supported") + } + join.joinType match { + case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => + willNotWorkOnGpu(s"build left not supported for ${join.joinType}") + case RightOuter if gpuBuildSide == GpuBuildRight => + willNotWorkOnGpu(s"build right not supported for ${join.joinType}") + case _ => + } + + val Seq(leftPlan, rightPlan) = childPlans + val buildSide = gpuBuildSide match { + case GpuBuildLeft => leftPlan + case GpuBuildRight => rightPlan + } + + if (!canBuildSideBeReplaced(buildSide)) { + willNotWorkOnGpu("the broadcast for this join must be on the GPU too") + } + + if (!canThisBeReplaced) { + buildSide.willNotWorkOnGpu( + "the BroadcastNestedLoopJoin this feeds is not on the GPU") + } + } +} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala new file mode 100644 index 00000000000..c01dc173c64 --- /dev/null +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.rapids.execution + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2._ + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types._ + +object JoinTypeChecks { + def tagForGpu(joinType: JoinType, meta: RapidsMeta[_, _]): Unit = { + val conf = meta.conf + joinType match { + case Inner if !conf.areInnerJoinsEnabled => + meta.willNotWorkOnGpu("inner joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_INNER_JOIN.key} to true") + case Cross if !conf.areCrossJoinsEnabled => + meta.willNotWorkOnGpu("cross joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_CROSS_JOIN.key} to true") + case LeftOuter if !conf.areLeftOuterJoinsEnabled => + meta.willNotWorkOnGpu("left outer joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_LEFT_OUTER_JOIN.key} to true") + case RightOuter if !conf.areRightOuterJoinsEnabled => + meta.willNotWorkOnGpu("right outer joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_RIGHT_OUTER_JOIN.key} to true") + case FullOuter if !conf.areFullOuterJoinsEnabled => + meta.willNotWorkOnGpu("full outer joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_FULL_OUTER_JOIN.key} to true") + case LeftSemi if !conf.areLeftSemiJoinsEnabled => + meta.willNotWorkOnGpu("left semi joins have been disabled. To enable set " + + s"${RapidsConf.ENABLE_LEFT_SEMI_JOIN.key} to true") + case LeftAnti if !conf.areLeftAntiJoinsEnabled => + meta.willNotWorkOnGpu("left anti joins have been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_LEFT_ANTI_JOIN.key} to true") + case _ => // not disabled + } + } + + val LEFT_KEYS = "leftKeys" + val RIGHT_KEYS = "rightKeys" + val CONDITION = "condition" + + private[this] val cudfSupportedKeyTypes = + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.STRUCT).nested() + private[this] val sparkSupportedJoinKeyTypes = TypeSig.all - TypeSig.MAP.nested() + + private[this] val joinRideAlongTypes = + (cudfSupportedKeyTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.MAP).nested() + + val equiJoinExecChecks: ExecChecks = ExecChecks( + joinRideAlongTypes, + TypeSig.all, + Map( + LEFT_KEYS -> InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes), + RIGHT_KEYS -> InputCheck(cudfSupportedKeyTypes, sparkSupportedJoinKeyTypes), + CONDITION -> InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN))) + + def equiJoinMeta(leftKeys: Seq[BaseExprMeta[_]], + rightKeys: Seq[BaseExprMeta[_]], + condition: Option[BaseExprMeta[_]]): Map[String, Seq[BaseExprMeta[_]]] = { + Map( + LEFT_KEYS -> leftKeys, + RIGHT_KEYS -> rightKeys, + CONDITION -> condition.toSeq) + } + + val nonEquiJoinChecks: ExecChecks = ExecChecks( + joinRideAlongTypes, + TypeSig.all, + Map(CONDITION -> InputCheck(TypeSig.BOOLEAN, TypeSig.BOOLEAN, + notes = List("A non-inner join only is supported if the condition expression can be " + + "converted to a GPU AST expression")))) + + def nonEquiJoinMeta(condition: Option[BaseExprMeta[_]]): Map[String, Seq[BaseExprMeta[_]]] = + Map(CONDITION -> condition.toSeq) +} + +object GpuHashJoin { + + def tagJoin( + meta: RapidsMeta[_, _], + joinType: JoinType, + buildSide: GpuBuildSide, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression]): Unit = { + val keyDataTypes = (leftKeys ++ rightKeys).map(_.dataType) + + def unSupportNonEqualCondition(): Unit = if (condition.isDefined) { + meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions") + } + def unSupportStructKeys(): Unit = if (keyDataTypes.exists(_.isInstanceOf[StructType])) { + meta.willNotWorkOnGpu(s"$joinType joins currently do not support with struct keys") + } + JoinTypeChecks.tagForGpu(joinType, meta) + joinType match { + case _: InnerLike => + case RightOuter | LeftOuter | LeftSemi | LeftAnti => + unSupportNonEqualCondition() + case FullOuter => + unSupportNonEqualCondition() + // FullOuter join cannot support with struct keys as two issues below + // * https://github.com/NVIDIA/spark-rapids/issues/2126 + // * https://github.com/rapidsai/cudf/issues/7947 + unSupportStructKeys() + case _ => + meta.willNotWorkOnGpu(s"$joinType currently is not supported") + } + buildSide match { + case GpuBuildLeft if !canBuildLeft(joinType) => + meta.willNotWorkOnGpu(s"$joinType does not support left-side build") + case GpuBuildRight if !canBuildRight(joinType) => + meta.willNotWorkOnGpu(s"$joinType does not support right-side build") + case _ => + } + } + + /** Determine if this type of join supports using the right side of the join as the build side. */ + def canBuildRight(joinType: JoinType): Boolean = joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + + /** Determine if this type of join supports using the left side of the join as the build side. 
*/
+  def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+    case _: InnerLike | RightOuter | FullOuter => true
+    case _ => false
+  }
+}
diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleMeta.scala
new file mode 100644
index 00000000000..e1abb8bd916
--- /dev/null
+++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleMeta.scala
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.execution
+
+import scala.collection.AbstractIterator
+import scala.concurrent.Future
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute}
+import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.MutablePair
+
+class GpuShuffleMeta(
+    shuffle: ShuffleExchangeExec,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _]],
+    rule: DataFromReplacementRule)
+  extends SparkPlanMeta[ShuffleExchangeExec](shuffle, conf, parent, rule) {
+  // Some kinds of Partitioning are a type of expression, but Partitioning itself is not
+  // so don't let them leak through as expressions
+  override val childExprs: scala.Seq[ExprMeta[_]] = Seq.empty
+  override val childParts: scala.Seq[PartMeta[_]] =
+    Seq(GpuOverrides.wrapPart(shuffle.outputPartitioning, conf, Some(this)))
+
+  // Propagate possible type conversions on the output attributes of map-side plans to
+  // reduce-side counterparts. We can pass through the outputs of the child because Shuffle will
+  // not change the data schema. We also need to pass them through because Shuffle itself and
+  // reduce-side plans may fail the type check when they are tagged with CPU data types rather
+  // than their GPU counterparts.
+  //
+  // Taking AggregateExec with a TypedImperativeAggregate function as an example:
+  //    Assume we have a query: SELECT a, COLLECT_LIST(b) FROM table GROUP BY a, whose physical
+  //    plan looks like:
+  //    ObjectHashAggregate(keys=[a#10], functions=[collect_list(b#11, 0, 0)],
+  //                        output=[a#10, collect_list(b)#17])
+  //      +- Exchange hashpartitioning(a#10, 200), true, [id=#13]
+  //        +- ObjectHashAggregate(keys=[a#10], functions=[partial_collect_list(b#11, 0, 0)],
+  //                               output=[a#10, buf#21])
+  //          +- LocalTableScan [a#10, b#11]
+  //
+  //    We will override the data type of buf#21 in GpuNoHashAggregateMeta. Otherwise, the partial
+  //    Aggregate will fall back to CPU because buf#21 produces a GPU-unsupported type: BinaryType.
+  //    Just like the partial Aggregate, the ShuffleExchange will also fall back to CPU unless we
+  //    apply the same type overriding as its child plan: the partial Aggregate.
+  override protected val useOutputAttributesOfChild: Boolean = true
+
+  // For a transparent plan like ShuffleExchange, the accessibility of the runtime data transition
+  // depends on the next non-transparent plan, so we need to trace back.
+  override val availableRuntimeDataTransition: Boolean =
+    childPlans.head.availableRuntimeDataTransition
+
+  override def tagPlanForGpu(): Unit = {
+
+    shuffle.outputPartitioning match {
+      case _: RoundRobinPartitioning
+        if shuffle.sqlContext.sparkSession.sessionState.conf
+          .sortBeforeRepartition =>
+        val orderableTypes = GpuOverrides.pluginSupportedOrderableSig + TypeSig.DECIMAL_128
+        shuffle.output.map(_.dataType)
+          .filterNot(orderableTypes.isSupportedByPlugin)
+          .foreach { dataType =>
+            willNotWorkOnGpu(s"round-robin partitioning cannot sort $dataType to run " +
+              s"this on the GPU; set ${SQLConf.SORT_BEFORE_REPARTITION.key} to false")
+          }
+      case _ =>
+    }
+
+  }
+}
diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala
new file mode 100644
index 00000000000..39810113132
--- /dev/null
+++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import org.apache.spark.sql.types._
+
+object GpuFloorCeil {
+  def unboundedOutputPrecision(dt: DecimalType): Int = {
+    if (dt.scale == 0) {
+      dt.precision
+    } else {
+      dt.precision - dt.scale + 1
+    }
+  }
+}
diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala
new file mode 100644
index 00000000000..71ab0d831ef
--- /dev/null
+++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import scala.collection.mutable.ArrayBuffer + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.v2.ShimExpression + +import org.apache.spark.sql.catalyst.expressions.{Literal, RegExpExtract, RLike, StringSplit, SubstringIndex} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class GpuRLikeMeta( + expr: RLike, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) extends BinaryExprMeta[RLike](expr, conf, parent, rule) { + + private var pattern: Option[String] = None + + override def tagExprForGpu(): Unit = { + expr.right match { + case Literal(str: UTF8String, DataTypes.StringType) if str != null => + try { + // verify that we support this regex and can transpile it to cuDF format + pattern = Some(new CudfRegexTranspiler(replace = false).transpile(str.toString)) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + case _ => + willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") + } + } +} + +class GpuRegExpExtractMeta( + expr: RegExpExtract, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[RegExpExtract](expr, conf, parent, rule) { + + private var pattern: Option[String] = None + private var numGroups = 0 + + override def tagExprForGpu(): Unit = { + + def countGroups(regexp: RegexAST): Int = { + regexp match { + case RegexGroup(_, term) => 1 + countGroups(term) + case other => other.children().map(countGroups).sum + } + } + + expr.regexp match { + case Literal(str: UTF8String, DataTypes.StringType) if str != null => + try { + val javaRegexpPattern = str.toString + // verify that we support this regex and can transpile it to cuDF format + val cudfRegexPattern = new CudfRegexTranspiler(replace = false) + .transpile(javaRegexpPattern) + pattern = Some(cudfRegexPattern) + numGroups = countGroups(new RegexParser(javaRegexpPattern).parse()) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + case _ => + willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") + } + + expr.idx match { + case Literal(value, DataTypes.IntegerType) => + val idx = value.asInstanceOf[Int] + if (idx < 0) { + willNotWorkOnGpu("the specified group index cannot be less than zero") + } + if (idx > numGroups) { + willNotWorkOnGpu( + s"regex group count is $numGroups, but the specified group index is $idx") + } + case _ => + willNotWorkOnGpu("GPU only supports literal index") + } + } +} + +class SubstringIndexMeta( + expr: SubstringIndex, + override val conf: RapidsConf, + override val parent: Option[RapidsMeta[_, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[SubstringIndex](expr, conf, parent, rule) { + private var regexp: String = _ + + override def tagExprForGpu(): Unit = { + val delim = 
GpuOverrides.extractStringLit(expr.delimExpr).getOrElse("")
+    if (delim == null || delim.length != 1) {
+      willNotWorkOnGpu("only a single character delimiter is supported")
+    }
+
+    val count = GpuOverrides.extractLit(expr.countExpr)
+    if (canThisBeReplaced) {
+      val c = count.get.value.asInstanceOf[Integer]
+      this.regexp = GpuSubstringIndex.makeExtractRe(delim, c)
+    }
+  }
+}
+
+object CudfRegexp {
+  val escapeForCudfCharSet = Seq('^', '-', ']')
+
+  def notCharSet(c: Char): String = c match {
+    case '\n' => "(?:.|\r)"
+    case '\r' => "(?:.|\n)"
+    case chr if escapeForCudfCharSet.contains(chr) => "(?:[^\\" + chr + "]|\r|\n)"
+    case chr => "(?:[^" + chr + "]|\r|\n)"
+  }
+
+  val escapeForCudf = Seq('[', '^', '$', '.', '|', '?', '*','+', '(', ')', '\\', '{', '}')
+
+  def cudfQuote(c: Character): String = c match {
+    case chr if escapeForCudf.contains(chr) => "\\" + chr
+    case chr => Character.toString(chr)
+  }
+}
+
+object GpuSubstringIndex {
+  def makeExtractRe(delim: String, count: Integer): String = {
+    if (delim.length != 1) {
+      throw new IllegalStateException("NOT SUPPORTED")
+    }
+    val quotedDelim = CudfRegexp.cudfQuote(delim.charAt(0))
+    val notDelim = CudfRegexp.notCharSet(delim.charAt(0))
+    // substring_index has a delimiter and a count. If the count is positive then
+    // you get back a substring from 0 until the Nth delimiter is found.
+    // If the count is negative it goes in reverse
+    if (count == 0) {
+      // Count is zero so return a null regexp as a special case
+      null
+    } else if (count == 1) {
+      // If the count is 1 we want to match everything from the beginning of the string until we
+      // find the first occurrence of the delimiter or the end of the string
+      "\\A(" + notDelim + "*)"
+    } else if (count > 0) {
+      // If the count is > 1 we first match 0 up to count - 1 occurrences of the pattern
+      // `not the delimiter 0 or more times followed by the delimiter`
+      // After that we go back to matching everything until we find the delimiter or the end of
+      // the string
+      "\\A((?:" + notDelim + "*" + quotedDelim + "){0," + (count - 1) + "}" + notDelim + "*)"
+    } else if (count == -1) {
+      // A -1 looks like 1 but we start looking at the end of the string
+      "(" + notDelim + "*)\\Z"
+    } else { //count < 0
+      // All others look like a positive count, but again we are matching starting at the end of
+      // the string instead of the beginning
+      "((?:" + notDelim + "*" + quotedDelim + "){0," + ((-count) - 1) + "}" + notDelim + "*)\\Z"
+    }
+  }
+}
+
+class GpuStringSplitMeta(
+    expr: StringSplit,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _]],
+    rule: DataFromReplacementRule)
+  extends BinaryExprMeta[StringSplit](expr, conf, parent, rule) {
+  import GpuOverrides._
+
+  override def tagExprForGpu(): Unit = {
+    // 2.x uses expr.pattern not expr.regex
+    val regexp = extractLit(expr.pattern)
+    if (regexp.isEmpty) {
+      willNotWorkOnGpu("only literal regexp values are supported")
+    } else {
+      val str = regexp.get.value.asInstanceOf[UTF8String]
+      if (str != null) {
+        if (!canRegexpBeTreatedLikeARegularString(str)) {
+          willNotWorkOnGpu("regular expressions are not supported yet")
+        }
+        if (str.numChars() == 0) {
+          willNotWorkOnGpu("An empty regex is not supported yet")
+        }
+      } else {
+        willNotWorkOnGpu("null regex is not supported yet")
+      }
+    }
+    // 2.x has no limit parameter
+    /*
+    if (!isLit(expr.limit)) {
+      willNotWorkOnGpu("only literal limit is supported")
+    }
+    */
+  }
+}
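
To make the `substring_index` regex construction above a bit more concrete, here is a small, hypothetical Scala sketch (not part of the patch) that just prints the patterns `GpuSubstringIndex.makeExtractRe` builds for a `.` delimiter with a few count values. It assumes the spark2-sql-plugin classes shown above are on the classpath; the demo object name is made up for illustration.

```scala
// Hypothetical, standalone demo -- not part of the plugin sources above.
import org.apache.spark.sql.rapids.GpuSubstringIndex

object SubstringIndexRegexDemo {
  def main(args: Array[String]): Unit = {
    // count = 1: capture from the start of the string (\A) up to the first '.'
    println(GpuSubstringIndex.makeExtractRe(".", 1))
    // count = 2: allow at most one '.' inside the capture before stopping
    println(GpuSubstringIndex.makeExtractRe(".", 2))
    // count = -1: same shape as count = 1 but anchored at the end of the string (\Z)
    println(GpuSubstringIndex.makeExtractRe(".", -1))
    // count = 0 is the special case that returns a null regex
    println(GpuSubstringIndex.makeExtractRe(".", 0))
  }
}
```

Printing the patterns side by side makes it easy to see that positive counts produce patterns anchored at the start of the string while negative counts anchor at the end, mirroring the comments in `makeExtractRe`.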