diff --git a/docs/configs.md b/docs/configs.md
index 6b9380089bd..ee7ebd380e0 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -276,6 +276,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.PythonUDF| |UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated|true|None|
spark.rapids.sql.expression.Quarter|`quarter`|Returns the quarter of the year for date, in the range 1 to 4|true|None|
spark.rapids.sql.expression.RLike|`rlike`|Regular expression version of Like|true|None|
+spark.rapids.sql.expression.RaiseError|`raise_error`|Throw an exception|true|None|
spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None|
spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None|
spark.rapids.sql.expression.RegExpExtract|`regexp_extract`|Extract a specific group identified by a regular expression|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 3568388cc76..78f183e9439 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -10334,6 +10334,53 @@ are limited.
UDT |
+RaiseError |
+`raise_error` |
+Throw an exception |
+None |
+project |
+input |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
Rand |
`random`, `rand` |
Generate a random column with i.i.d. uniformly distributed values in [0, 1) |
@@ -10627,6 +10674,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Remainder |
`%`, `mod` |
Remainder or modulo |
@@ -10695,32 +10768,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
ReplicateRows |
|
Given an input row replicates the row N times |
@@ -10999,6 +11046,32 @@ are limited.
NS |
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Second |
`second` |
Returns the second component of the string/timestamp |
@@ -11135,32 +11208,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
ShiftLeft |
`shiftleft` |
Bitwise shift left (<<) |
@@ -11365,6 +11412,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Signum |
`sign`, `signum` |
Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive |
@@ -11502,32 +11575,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Sinh |
`sinh` |
Hyperbolic sine |
@@ -11733,6 +11780,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
SortOrder |
|
Sort order |
@@ -11874,32 +11947,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Sqrt |
`sqrt` |
Square root |
@@ -12147,10 +12194,36 @@ are limited.
|
-StringLocate |
-`position`, `locate` |
-Substring search operator |
-None |
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
+StringLocate |
+`position`, `locate` |
+Substring search operator |
+None |
project |
substr |
|
@@ -12236,32 +12309,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StringRPad |
`rpad` |
Pad a string on the right |
@@ -12508,6 +12555,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
StringSplit |
`split` |
Splits `str` around occurrences that match `regex` |
@@ -12597,32 +12670,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
StringToMap |
`str_to_map` |
Creates a map after splitting the input string into pairs of key-value strings |
@@ -12916,6 +12963,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Substring |
`substr`, `substring` |
Substring operator |
@@ -13005,32 +13078,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
SubstringIndex |
`substring_index` |
substring_index operator |
@@ -13342,6 +13389,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Tanh |
`tanh` |
Hyperbolic tangent |
@@ -13432,32 +13505,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
TimeAdd |
|
Adds interval to timestamp |
@@ -13756,6 +13803,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
TransformValues |
`transform_values` |
Transform values in a map using a transform function |
@@ -13824,32 +13897,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
UnaryMinus |
`negative` |
Negate a numeric value |
@@ -14150,6 +14197,32 @@ are limited.
|
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
UnscaledValue |
|
Convert a Decimal to an unscaled long value for some aggregation optimizations |
@@ -14197,32 +14270,6 @@ are limited.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Upper |
`upper`, `ucase` |
String uppercase operator |
diff --git a/integration_tests/src/main/python/misc_expr_test.py b/integration_tests/src/main/python/misc_expr_test.py
index 41d01ba0065..8802a0eef35 100644
--- a/integration_tests/src/main/python/misc_expr_test.py
+++ b/integration_tests/src/main/python/misc_expr_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2021, NVIDIA CORPORATION.
+# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import pytest
-from asserts import assert_gpu_and_cpu_are_equal_collect
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
@@ -31,3 +31,24 @@ def test_part_id():
lambda spark : unary_op_df(spark, short_gen, num_slices=8).select(
f.col('a'),
f.spark_partition_id()))
+
+def test_raise_error():
+ data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])  # presumably only yields 0..20, so 'a' > 30 below never holds
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
+ f.when(f.col('a') > 30, f.raise_error("unexpected"))))
+ # Zero-row input (spark.range(0)): results are compared via collect, so no error is expected.
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: spark.range(0).select(f.raise_error(f.col("id"))))
+ # Message read from a null_gen column: the expected error text is only the exception class name.
+ assert_gpu_and_cpu_error(
+ lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
+ f.raise_error(f.col('a'))).collect(),
+ conf={},
+ error_message="java.lang.RuntimeException")
+ # Literal message: the expected error text also carries the message "unexpected".
+ assert_gpu_and_cpu_error(
+ lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
+ f.raise_error(f.lit("unexpected"))).collect(),
+ conf={},
+ error_message="java.lang.RuntimeException: unexpected")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 77574e8dd95..67da0b3220b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3477,7 +3477,15 @@ object GpuOverrides extends Logging {
TypeSig.ARRAY.nested(TypeSig.commonCudfTypesWithNested),
TypeSig.ARRAY.nested(TypeSig.all)),
(e, conf, p, r) => new GpuGetArrayStructFieldsMeta(e, conf, p, r)
- )
+ ),
+ expr[RaiseError](
+ "Throw an exception",
+ ExprChecks.unaryProject(
+ TypeSig.NULL, TypeSig.NULL,
+ TypeSig.STRING, TypeSig.STRING),
+ (a, conf, p, r) => new UnaryExprMeta[RaiseError](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression = GpuRaiseError(child)
+ })
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
// Shim expressions should be last to allow overrides with shim-specific versions
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala
new file mode 100644
index 00000000000..159711afa3e
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/misc.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf.{ColumnVector}
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector, GpuUnaryExpression}
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
+
+case class GpuRaiseError(child: Expression) extends GpuUnaryExpression with ExpectsInputTypes
+ with Arm {
+ // Columnar raise_error: evaluating a non-empty batch always throws a RuntimeException.
+ override def dataType: DataType = NullType
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+ override def toString: String = s"raise_error($child)"
+
+ /** Evaluation throws, so this expression is reported as having side effects. */
+ override def hasSideEffects: Boolean = true
+ // Returns a zero-row null column for an empty batch; otherwise it never returns normally.
+ override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+ if (input.getRowCount <= 0) {
+ // For when(condition, raise_error(col("a"))): no rows, so return an empty null column
+ return GpuColumnVector.columnVectorFromNull(0, NullType)
+ }
+ // Only row 0 is read: its string value becomes the RuntimeException message.
+ // An invalid scalar (presumably a null message - confirm) throws without a message.
+ withResource(input.getBase.getScalarElement(0)) { scalarMsg =>
+ if (!scalarMsg.isValid()) {
+ throw new RuntimeException()
+ } else {
+ throw new RuntimeException(scalarMsg.getJavaString())
+ }
+ }
+ }
+
+}
diff --git a/tools/src/main/resources/operatorsScore.csv b/tools/src/main/resources/operatorsScore.csv
index c38f3e2c3f8..178069c403c 100644
--- a/tools/src/main/resources/operatorsScore.csv
+++ b/tools/src/main/resources/operatorsScore.csv
@@ -172,6 +172,7 @@ PromotePrecision,3
PythonUDF,3
Quarter,3
RLike,3
+RaiseError,3
Rand,3
Rank,3
RegExpExtract,3
diff --git a/tools/src/main/resources/supportedExprs.csv b/tools/src/main/resources/supportedExprs.csv
index 145cc9a61d6..d98e137e87a 100644
--- a/tools/src/main/resources/supportedExprs.csv
+++ b/tools/src/main/resources/supportedExprs.csv
@@ -369,6 +369,8 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+RaiseError,S,`raise_error`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
+RaiseError,S,`raise_error`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA
Rand,S,`random`; `rand`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rand,S,`random`; `rand`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS