diff --git a/docs/configs.md b/docs/configs.md index 6b9380089bd..ee7ebd380e0 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -276,6 +276,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.PythonUDF| |UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated|true|None| spark.rapids.sql.expression.Quarter|`quarter`|Returns the quarter of the year for date, in the range 1 to 4|true|None| spark.rapids.sql.expression.RLike|`rlike`|Regular expression version of Like|true|None| +spark.rapids.sql.expression.RaiseError|`raise_error`|Throw an exception|true|None| spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None| spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None| spark.rapids.sql.expression.RegExpExtract|`regexp_extract`|Extract a specific group identified by a regular expression|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 3568388cc76..78f183e9439 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10334,6 +10334,53 @@ are limited. UDT +RaiseError +`raise_error` +Throw an exception +None +project +input + + + + + + + + + +S + + + + + + + + + + +result + + + + + + + + + + + +S + + + + + + + + Rand `random`, `rand` Generate a random column with i.i.d. uniformly distributed values in [0, 1) @@ -10627,6 +10674,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Remainder `%`, `mod` Remainder or modulo @@ -10695,32 +10768,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ReplicateRows Given an input row replicates the row N times @@ -10999,6 +11046,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Second `second` Returns the second component of the string/timestamp @@ -11135,32 +11208,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ShiftLeft `shiftleft` Bitwise shift left (<<) @@ -11365,6 +11412,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Signum `sign`, `signum` Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive @@ -11502,32 +11575,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Sinh `sinh` Hyperbolic sine @@ -11733,6 +11780,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + SortOrder Sort order @@ -11874,32 +11947,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Sqrt `sqrt` Square root @@ -12147,10 +12194,36 @@ are limited. -StringLocate -`position`, `locate` -Substring search operator -None +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + +StringLocate +`position`, `locate` +Substring search operator +None project substr @@ -12236,32 +12309,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringRPad `rpad` Pad a string on the right @@ -12508,6 +12555,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringSplit `split` Splits `str` around occurrences that match `regex` @@ -12597,32 +12670,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringToMap `str_to_map` Creates a map after splitting the input string into pairs of key-value strings @@ -12916,6 +12963,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Substring `substr`, `substring` Substring operator @@ -13005,32 +13078,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SubstringIndex `substring_index` substring_index operator @@ -13342,6 +13389,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Tanh `tanh` Hyperbolic tangent @@ -13432,32 +13505,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TimeAdd Adds interval to timestamp @@ -13756,6 +13803,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + TransformValues `transform_values` Transform values in a map using a transform function @@ -13824,32 +13897,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnaryMinus `negative` Negate a numeric value @@ -14150,6 +14197,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -14197,32 +14270,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Upper `upper`, `ucase` String uppercase operator diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py index 41d01ba0065..8802a0eef35 100644 --- a/integration_tests/src/main/python/misc_expr_test.py +++ b/integration_tests/src/main/python/misc_expr_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error from data_gen import * from marks import incompat, approximate_float from pyspark.sql.types import * @@ -31,3 +31,24 @@ def test_part_id(): lambda spark : unary_op_df(spark, short_gen, num_slices=8).select( f.col('a'), f.spark_partition_id())) + +def test_raise_error(): + data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[]) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, data_gen, num_slices=2).select( + f.when(f.col('a') > 30, f.raise_error("unexpected")))) + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: spark.range(0).select(f.raise_error(f.col("id")))) + + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select( + f.raise_error(f.col('a'))).collect(), + conf={}, + error_message="java.lang.RuntimeException") + + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select( + f.raise_error(f.lit("unexpected"))).collect(), + conf={}, + error_message="java.lang.RuntimeException: unexpected") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 77574e8dd95..67da0b3220b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3477,7 +3477,15 @@ object GpuOverrides extends Logging { TypeSig.ARRAY.nested(TypeSig.commonCudfTypesWithNested), TypeSig.ARRAY.nested(TypeSig.all)), (e, conf, p, r) => new GpuGetArrayStructFieldsMeta(e, conf, p, r) - ) + ), + expr[RaiseError]( + "Throw an exception", + ExprChecks.unaryProject( + TypeSig.NULL, TypeSig.NULL, + TypeSig.STRING, TypeSig.STRING), + (a, conf, p, r) => new UnaryExprMeta[RaiseError](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = GpuRaiseError(child) + }) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap // Shim expressions should be last to allow overrides with shim-specific versions diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala new file mode 100644 index 00000000000..159711afa3e --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.{ColumnVector} +import com.nvidia.spark.rapids.{Arm, GpuColumnVector, GpuUnaryExpression} + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType} + +case class GpuRaiseError(child: Expression) extends GpuUnaryExpression with ExpectsInputTypes + with Arm { + + override def dataType: DataType = NullType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def toString: String = s"raise_error($child)" + + /** Could evaluating this expression cause side-effects, such as throwing an exception? */ + override def hasSideEffects: Boolean = true + + override protected def doColumnar(input: GpuColumnVector): ColumnVector = { + if (input.getRowCount <= 0) { + // For the case: when(condition, raise_error(col("a")) + return GpuColumnVector.columnVectorFromNull(0, NullType) + } + + // Take the first one as the error message + withResource(input.getBase.getScalarElement(0)) { scalarMsg => + if (!scalarMsg.isValid()) { + throw new RuntimeException() + } else { + throw new RuntimeException(scalarMsg.getJavaString()) + } + } + } + +} diff --git a/tools/src/main/resources/operatorsScore.csv b/tools/src/main/resources/operatorsScore.csv index c38f3e2c3f8..178069c403c 100644 --- a/tools/src/main/resources/operatorsScore.csv +++ b/tools/src/main/resources/operatorsScore.csv @@ -172,6 +172,7 @@ PromotePrecision,3 PythonUDF,3 Quarter,3 RLike,3 +RaiseError,3 Rand,3 Rank,3 RegExpExtract,3 diff --git a/tools/src/main/resources/supportedExprs.csv b/tools/src/main/resources/supportedExprs.csv index 145cc9a61d6..d98e137e87a 100644 --- a/tools/src/main/resources/supportedExprs.csv +++ b/tools/src/main/resources/supportedExprs.csv @@ -369,6 +369,8 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA RLike,S,`rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +RaiseError,S,`raise_error`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +RaiseError,S,`raise_error`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA Rand,S,`random`; `rand`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Rand,S,`random`; `rand`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS