diff --git a/docs/compatibility.md b/docs/compatibility.md index 7b4261f7df6..5b925aa36c3 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -446,6 +446,7 @@ The following Apache Spark regular expression functions and expressions are supp - `RLIKE` - `regexp` +- `regexp_extract` - `regexp_like` - `regexp_replace` @@ -457,6 +458,7 @@ These operations can be enabled on the GPU with the following configuration sett - `spark.rapids.sql.expression.RLike=true` (for `RLIKE`, `regexp`, and `regexp_like`) - `spark.rapids.sql.expression.RegExpReplace=true` for `regexp_replace` +- `spark.rapids.sql.expression.RegExpExtract=true` for `regexp_extract` Even when these expressions are enabled, there are instances where regular expression operations will fall back to CPU when the RAPIDS Accelerator determines that a pattern is either unsupported or would produce incorrect results on the GPU. @@ -475,8 +477,6 @@ Here are some examples of regular expression patterns that are not supported on In addition to these cases that can be detected, there are also known issues that can cause incorrect results: -- `$` does not match the end of a string if the string ends with a line-terminator - ([cuDF issue #9620](https://github.com/rapidsai/cudf/issues/9620)) - Character classes for negative matches have different behavior between CPU and GPU for multiline strings. The pattern `[^a]` will match line-terminators on CPU but not on GPU. diff --git a/docs/configs.md b/docs/configs.md index f0d6db7c4e2..796d8a612c9 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -262,6 +262,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.RLike|`rlike`|RLike|false|This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information.| spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None| spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None| +spark.rapids.sql.expression.RegExpExtract|`regexp_extract`|RegExpExtract|false|This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information.| spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|RegExpReplace support for string literal input patterns|false|This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information.| spark.rapids.sql.expression.Remainder|`%`, `mod`|Remainder or modulo|true|None| spark.rapids.sql.expression.Rint|`rint`|Rounds up a double value to the nearest double equal to an integer|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 0597fe41d6f..9a46a116734 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -9823,6 +9823,95 @@ are limited. +RegExpExtract +`regexp_extract` +RegExpExtract +This is disabled by default because the implementation is not 100% compatible. See the compatibility guide for more information. +project +str + + + + + + + + + +S + + + + + + + + + + +regexp + + + + + + + + + +PS
Literal value only
+ + + + + + + + + + +idx + + + +S + + + + + + + + + + + + + + + + +result + + + + + + + + + +S + + + + + + + + + + RegExpReplace `regexp_replace` RegExpReplace support for string literal input patterns @@ -9980,6 +10069,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Rint `rint` Rounds up a double value to the nearest double equal to an integer @@ -10070,32 +10185,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Round `round` Round an expression to d decimal places using HALF_UP rounding mode @@ -10352,6 +10441,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ShiftRight `shiftright` Bitwise shift right (>>) @@ -10488,32 +10603,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Signum `sign`, `signum` Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive @@ -10741,6 +10830,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Size `size`, `cardinality` The size of an array or a map @@ -10856,32 +10971,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SortOrder Sort order @@ -11113,6 +11202,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StartsWith Starts with @@ -11270,32 +11385,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringLocate `position`, `locate` Substring search operator @@ -11474,6 +11563,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringRepeat `repeat` StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes @@ -11607,54 +11722,28 @@ are limited. - - - -result - - - - - - - - - -S - - - - - - - - - - -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT + + + +result + + + + + + + + + +S + + + + + + + + StringSplit @@ -11882,6 +11971,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringTrimRight `rtrim` StringTrimRight operator @@ -12039,32 +12154,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SubstringIndex `substring_index` substring_index operator @@ -12286,6 +12375,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Tan `tan` Tangent @@ -12466,32 +12581,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TimeAdd Adds interval to timestamp @@ -12675,6 +12764,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ToRadians `radians` Converts degrees to radians @@ -12858,32 +12973,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TransformValues `transform_values` Transform values in a map using a transform function @@ -13042,6 +13131,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnaryPositive `positive` A numeric value with a + in front of it @@ -13252,32 +13367,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -13419,6 +13508,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + WindowExpression Calculates a return value for every input row of a table based on a group (or "window") of rows diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index ffbad689653..7a7843b6357 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error from conftest import is_databricks_runtime from data_gen import * from marks import * @@ -538,6 +538,69 @@ def test_regexp_replace_character_set_negated(): 'regexp_replace(a, "[^\n]", "1")'), conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) +def test_regexp_extract(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 1)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 2)', + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 3)'), + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +def test_regexp_extract_no_match(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 0)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 1)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 2)', + 'regexp_extract(a, "^([0-9]+)([a-z]+)([0-9]+)$", 3)'), + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +# if we determine that the index is out of range we fall back to CPU and let +# Spark take care of the error handling +@allow_non_gpu('ProjectExec', 'RegExpExtract') +def test_regexp_extract_idx_negative(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", -1)').collect(), + error_message = "The specified group index cannot be less than zero", + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +# if we determine that the index is out of range we fall back to CPU and let +# Spark take care of the error handling +@allow_non_gpu('ProjectExec', 'RegExpExtract') +def test_regexp_extract_idx_out_of_bounds(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 4)').collect(), + error_message = "Regex group count is 3, but the specified group index is 4", + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +def test_regexp_extract_multiline(): + gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([\r\n]*)", 2)'), + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +def test_regexp_extract_multiline_negated_character_class(): + gen = mk_str_gen('[abcd]{2}[\r\n]{0,2}[0-9]{2}[\r\n]{0,2}[abcd]{2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([^a-z]*)([a-d]*)$", 2)'), + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + +def test_regexp_extract_idx_0(): + gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_extract(a, "^([a-d]*)([0-9]*)([a-d]*)$", 0)', + 'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)$", 0)'), + conf={'spark.rapids.sql.expression.RegExpExtract': 'true'}) + def test_rlike(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 58174ba1ccf..e40697b6b46 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3030,6 +3030,17 @@ object GpuOverrides extends Logging { (a, conf, p, r) => new GpuRLikeMeta(a, conf, p, r)).disabledByDefault( "the implementation is not 100% compatible. " + "See the compatibility guide for more information."), + expr[RegExpExtract]( + "RegExpExtract", + ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING, + Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING), + ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING), + ParamCheck("idx", TypeSig.lit(TypeEnum.INT), + TypeSig.lit(TypeEnum.INT)))), + (a, conf, p, r) => new GpuRegExpExtractMeta(a, conf, p, r)) + .disabledByDefault( + "the implementation is not 100% compatible. " + + "See the compatibility guide for more information."), expr[Length]( "String character length or binary byte length", ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 6a14f74b561..692be261f6a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.v2.ShimExpression -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, Literal, NullIntolerant, Predicate, RLike, StringSplit, SubstringIndex} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, Literal, NullIntolerant, Predicate, RegExpExtract, RLike, StringSplit, SubstringIndex} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -801,24 +801,9 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String) override def dataType: DataType = BooleanType } -case class GpuRegExpReplace( - srcExpr: Expression, - searchExpr: Expression, - replaceExpr: Expression, - cudfRegexPattern: String) - extends GpuTernaryExpression with ImplicitCastInputTypes { +abstract class GpuRegExpTernaryBase extends GpuTernaryExpression { - override def dataType: DataType = srcExpr.dataType - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - - override def first: Expression = srcExpr - override def second: Expression = searchExpr - override def third: Expression = replaceExpr - - def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String) = { - this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern) - } + override def dataType: DataType = StringType override def doColumnar( strExpr: GpuColumnVector, @@ -852,23 +837,171 @@ case class GpuRegExpReplace( override def doColumnar( strExpr: GpuColumnVector, - searchExpr: GpuScalar, - replaceExpr: GpuScalar): ColumnVector = { - strExpr.getBase.replaceRegex(cudfRegexPattern, replaceExpr.getBase) - } + searchExpr: GpuColumnVector, + replaceExpr: GpuScalar): ColumnVector = + throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar, val2: GpuScalar): ColumnVector = { - withResource(GpuColumnVector.from(val0, numRows, srcExpr.dataType)) { val0Col => + withResource(GpuColumnVector.from(val0, numRows, first.dataType)) { val0Col => doColumnar(val0Col, val1, val2) } } +} + +case class GpuRegExpReplace( + srcExpr: Expression, + searchExpr: Expression, + replaceExpr: Expression, + cudfRegexPattern: String) + extends GpuRegExpTernaryBase with ImplicitCastInputTypes { + + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + + override def first: Expression = srcExpr + override def second: Expression = searchExpr + override def third: Expression = replaceExpr + + def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String) = { + this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern) + } + override def doColumnar( strExpr: GpuColumnVector, - searchExpr: GpuColumnVector, - replaceExpr: GpuScalar): ColumnVector = - throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this") + searchExpr: GpuScalar, + replaceExpr: GpuScalar): ColumnVector = { + strExpr.getBase.replaceRegex(cudfRegexPattern, replaceExpr.getBase) + } + +} + +class GpuRegExpExtractMeta( + expr: RegExpExtract, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends TernaryExprMeta[RegExpExtract](expr, conf, parent, rule) { + + private var pattern: Option[String] = None + private var numGroups = 0 + + override def tagExprForGpu(): Unit = { + + def countGroups(regexp: RegexAST): Int = { + regexp match { + case RegexGroup(_, term) => 1 + countGroups(term) + case other => other.children().map(countGroups).sum + } + } + + expr.regexp match { + case Literal(str: UTF8String, DataTypes.StringType) if str != null => + try { + val javaRegexpPattern = str.toString + // verify that we support this regex and can transpile it to cuDF format + val cudfRegexPattern = new CudfRegexTranspiler(replace = false) + .transpile(javaRegexpPattern) + pattern = Some(cudfRegexPattern) + numGroups = countGroups(new RegexParser(javaRegexpPattern).parse()) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + case _ => + willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") + } + + expr.idx match { + case Literal(value, DataTypes.IntegerType) => + val idx = value.asInstanceOf[Int] + if (idx < 0) { + willNotWorkOnGpu("the specified group index cannot be less than zero") + } + if (idx > numGroups) { + willNotWorkOnGpu( + s"regex group count is $numGroups, but the specified group index is $idx") + } + case _ => + willNotWorkOnGpu("GPU only supports literal index") + } + } + + override def convertToGpu( + str: Expression, + regexp: Expression, + idx: Expression): GpuExpression = { + val cudfPattern = pattern.getOrElse( + throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")) + GpuRegExpExtract(str, regexp, idx, cudfPattern) + } +} + +case class GpuRegExpExtract( + subject: Expression, + regexp: Expression, + idx: Expression, + cudfRegexPattern: String) + extends GpuRegExpTernaryBase with ImplicitCastInputTypes with NullIntolerant { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def first: Expression = subject + override def second: Expression = regexp + override def third: Expression = idx + + override def prettyName: String = "regexp_extract" + + override def doColumnar( + str: GpuColumnVector, + regexp: GpuScalar, + idx: GpuScalar): ColumnVector = { + + val groupIndex = idx.getValue.asInstanceOf[Int] + + // There are some differences in behavior between cuDF and Java so we have + // to handle those cases here. + // + // Given the pattern `^([a-z]*)([0-9]*)([a-z]*)$` the following table + // shows the value that would be extracted for group index 2 given a range + // of inputs. The behavior is mostly consistent except for the case where + // the input is non-null and does not match the pattern. + // + // | Input | Java | cuDF | + // |--------|-------|-------| + // | '' | '' | '' | + // | NULL | NULL | NULL | + // | 'a1a' | '1' | '1' | + // | '1a1' | '' | NULL | + + if (groupIndex == 0) { + withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString => + withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString => + withResource(str.getBase.matchesRe(cudfRegexPattern)) { matches => + withResource(str.getBase.isNull) { isNull => + withResource(matches.ifElse(str.getBase, emptyString)) { + isNull.ifElse(nullString, _) + } + } + } + } + } + } else { + withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString => + withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString => + withResource(str.getBase.extractRe(cudfRegexPattern)) { extract => + withResource(str.getBase.matchesRe(cudfRegexPattern)) { matches => + withResource(str.getBase.isNull) { isNull => + withResource(matches.ifElse(extract.getColumn(groupIndex - 1), emptyString)) { + isNull.ifElse(nullString, _) + } + } + } + } + } + } + } + } + } class SubstringIndexMeta( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala similarity index 76% rename from tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala rename to tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index 6ba3deed5c6..db3a1791317 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFallbackSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -16,10 +16,13 @@ package com.nvidia.spark.rapids import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} -class StringFallbackSuite extends SparkQueryCompareTestSuite { +class RegularExpressionSuite extends SparkQueryCompareTestSuite { - private val conf = new SparkConf().set("spark.rapids.sql.expression.RegExpReplace", "true") + private val conf = new SparkConf() + .set("spark.rapids.sql.expression.RegExpReplace", "true") + .set("spark.rapids.sql.expression.RegExpExtract", "true") testGpuFallback( "String regexp_replace replace str columnar fall back", @@ -90,4 +93,31 @@ class StringFallbackSuite extends SparkQueryCompareTestSuite { nullableStringsFromCsv, conf = conf) { frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')") } + + testSparkResultsAreEqual("String regexp_extract regex 1", + extractStrings, conf = conf) { + frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)") + } + + testSparkResultsAreEqual("String regexp_extract regex 2", + extractStrings, conf = conf) { + frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)") + } + + testSparkResultsAreEqual("String regexp_extract literal input", + extractStrings, conf = conf) { + frame => frame.selectExpr("regexp_extract('abc123def', '^([a-z]*)([0-9]*)([a-z]*)$', 2)") + } + + private def extractStrings(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[(String)]( + (""), + (null), + ("abc123def"), + ("abc\r\n12\r3\ndef"), + ("123abc456") + ).toDF("strings") + } + }