Add support for regexp_extract_all with idx 0 #5947

Closed
2 changes: 2 additions & 0 deletions docs/compatibility.md
@@ -568,6 +568,7 @@ The following Apache Spark regular expression functions and expressions are supp
- `RLIKE`
- `regexp`
- `regexp_extract`
- `regexp_extract_all`
- `regexp_like`
- `regexp_replace`
- `string_split`
@@ -606,6 +607,7 @@ The following regular expression patterns are not yet supported on the GPU and w
- Empty groups: `()`
- Regular expressions containing null characters (unless the pattern is a simple literal string)
- `regexp_replace` does not support back-references
- `regexp_extract_all` only supports group index 0

Work is ongoing to increase the range of regular expressions that can run on the GPU.
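
To illustrate what "group index 0" means in the limitation above, here is a minimal sketch that uses Python's `re` module as a stand-in for Spark's Java regex engine. The helper name `extract_all_idx_zero` is hypothetical and not part of the plugin; Python and Java regex dialects differ, so this only holds for simple patterns like the ones shown.

```python
import re


def extract_all_idx_zero(text: str, pattern: str) -> list:
    """Return the full text of every match (group 0), ignoring sub-groups."""
    # m.group(0) is the whole match, which is what idx = 0 selects.
    return [m.group(0) for m in re.finditer(pattern, text)]


print(extract_all_idx_zero("ab12-34cd", r"(\d+)-(\d+)"))  # ['12-34']
print(extract_all_idx_zero("ab cd ab", "(a)(b)"))          # ['ab', 'ab']
```

With a positive index the result would instead hold the text captured by that group for each match, which is the case this change leaves on the CPU.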

1 change: 1 addition & 0 deletions docs/configs.md
@@ -283,6 +283,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Rand"></a>spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None|
<a name="sql.expression.Rank"></a>spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None|
<a name="sql.expression.RegExpExtract"></a>spark.rapids.sql.expression.RegExpExtract|`regexp_extract`|Extract a specific group identified by a regular expression|true|None|
<a name="sql.expression.RegExpExtractAll"></a>spark.rapids.sql.expression.RegExpExtractAll|`regexp_extract_all`|Extract all strings matching a regular expression corresponding to the regex group index|true|None|
<a name="sql.expression.RegExpReplace"></a>spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|String replace using a regular expression pattern|true|None|
<a name="sql.expression.Remainder"></a>spark.rapids.sql.expression.Remainder|`%`, `mod`|Remainder or modulo|true|None|
<a name="sql.expression.ReplicateRows"></a>spark.rapids.sql.expression.ReplicateRows| |Given an input row replicates the row N times|true|None|
89 changes: 89 additions & 0 deletions docs/supported_ops.md
@@ -10611,6 +10611,95 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="4">RegExpExtractAll</td>
<td rowSpan="4">`regexp_extract_all`</td>
<td rowSpan="4">Extract all strings matching a regular expression corresponding to the regex group index</td>
<td rowSpan="4">None</td>
<td rowSpan="4">project</td>
<td>str</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>regexp</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Literal value only</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>idx</td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Literal value only</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="5">RegExpReplace</td>
<td rowSpan="5">`regexp_replace`</td>
<td rowSpan="5">String replace using a regular expression pattern</td>
43 changes: 43 additions & 0 deletions integration_tests/src/main/python/string_test.py
@@ -1130,3 +1130,46 @@ def test_rlike_fallback_possessive_quantifier():
'a rlike "a*+"'),
'RLike',
conf=_regexp_conf)

def test_regexp_extract_all_idx_zero():
    gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'regexp_extract_all(a, "([a-d]+).*([0-9])", 0)',
            'regexp_extract_all(a, "(a)(b)", 0)',
            'regexp_extract_all(a, "([a-z0-9]([abcd]))", 0)',
            'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 0)',
        ),
        conf=_regexp_conf)

# https://github.com/NVIDIA/spark-rapids/issues/4283
@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_positive_fallback():
    gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'regexp_extract_all(a, "([a-d]+).*([0-9])", 1)',
        ),
        conf=_regexp_conf,
        exist_classes="ProjectExec",
        non_exist_classes="GpuProjectExec")

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_negative():
    gen = mk_str_gen('[abcd]{0,3}')
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'regexp_extract_all(a, "(a)", -1)'
        ).collect(),
        error_message="The specified group index cannot be less than zero",
        conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_out_of_bounds():
    gen = mk_str_gen('[abcd]{0,3}')
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'regexp_extract_all(a, "([a-d]+).*([0-9])", 3)'
        ).collect(),
        error_message="Regex group count is 2, but the specified group index is 3",
        conf=_regexp_conf)
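
The four tests above cover four outcomes that depend only on the pattern and the literal group index. As a rough summary, the sketch below maps a `(pattern, idx)` pair to the outcome the tests expect. The function `expected_outcome` is hypothetical and exists only for illustration; it counts groups with Python's `re`, which counts capturing groups only, whereas the plugin's `countGroups` walks its own regex AST.

```python
import re


def expected_outcome(pattern: str, idx: int) -> str:
    """Summarise which outcome the tests above expect (illustrative only)."""
    group_count = re.compile(pattern).groups  # capturing groups only
    if idx < 0:
        return "error: The specified group index cannot be less than zero"
    if idx > group_count:
        return (f"error: Regex group count is {group_count}, "
                f"but the specified group index is {idx}")
    if idx > 0:
        # Valid index, but only idx = 0 is accelerated by this change.
        return "cpu fallback"
    return "gpu"


print(expected_outcome("([a-d]+).*([0-9])", 0))  # gpu
print(expected_outcome("([a-d]+).*([0-9])", 1))  # cpu fallback
```

The error strings are copied from the `error_message` arguments in the tests; everything else is an assumption made to keep the sketch self-contained.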
@@ -3194,6 +3194,14 @@ object GpuOverrides extends Logging {
          ParamCheck("idx", TypeSig.lit(TypeEnum.INT),
            TypeSig.lit(TypeEnum.INT)))),
      (a, conf, p, r) => new GpuRegExpExtractMeta(a, conf, p, r)),
    expr[RegExpExtractAll](
      "Extract all strings matching a regular expression corresponding to the regex group index",
      ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.STRING),
        TypeSig.ARRAY.nested(TypeSig.STRING),
        Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
          ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
          ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
      (a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)),
    expr[Length](
      "String character length or binary byte length",
      ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT,
@@ -23,7 +23,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{RegExpShim, ShimExpression}

-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, InputFileName, Literal, NullIntolerant, Predicate, RegExpExtract, RLike, StringSplit, StringToMap, SubstringIndex, TernaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, InputFileName, Literal, NullIntolerant, Predicate, RegExpExtract, RegExpExtractAll, RLike, StringSplit, StringToMap, SubstringIndex, TernaryExpression}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -865,6 +865,20 @@ object GpuRegExpUtils {
isASTEmptyRepetition(parseAST(pattern))
}

  /**
   * Returns the number of groups in regexp
   * (includes both capturing and non-capturing groups)
   */
  def countGroups(pattern: String): Int = {
    def countGroups(regexp: RegexAST): Int = {
      regexp match {
        case RegexGroup(_, term) => 1 + countGroups(term)
        case other => other.children().map(countGroups).sum
      }
    }
    countGroups(parseAST(pattern))
  }

}
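
The recursion in `countGroups` can be sketched outside the plugin as follows. This is an illustrative Python model, not the plugin's code: `Node` and `Group` are hypothetical stand-ins for `RegexAST` and `RegexGroup`, and the tree is built by hand rather than parsed from a pattern.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a regex AST node with child nodes."""
    children: List["Node"] = field(default_factory=list)


@dataclass
class Group(Node):
    """Hypothetical stand-in for a group node; capture marks capturing groups."""
    capture: bool = True


def count_groups(node: Node) -> int:
    """Count every group in the tree, capturing or not."""
    nested = sum(count_groups(child) for child in node.children)
    # A group contributes 1 on top of whatever is nested inside it.
    return 1 + nested if isinstance(node, Group) else nested


# Hand-built tree for a pattern shaped like (a)(?:b(c)).
tree = Node([
    Group([Node()]),
    Group([Node(), Group([Node()])], capture=False),
])
print(count_groups(tree))  # 3
```

As in the Scala helper, the non-capturing group is included in the total, so the count can exceed the number of groups addressable by index.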

class GpuRLikeMeta(
@@ -1063,21 +1077,14 @@ class GpuRegExpExtractMeta(
case _ =>
}

-    def countGroups(regexp: RegexAST): Int = {
-      regexp match {
-        case RegexGroup(_, term) => 1 + countGroups(term)
-        case other => other.children().map(countGroups).sum
-      }
-    }

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
-          numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
+          numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
@@ -1174,7 +1181,88 @@ case class GpuRegExpExtract(
}
}
}
}

class GpuRegExpExtractAllMeta(
    expr: RegExpExtractAll,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends TernaryExprMeta[RegExpExtractAll](expr, conf, parent, rule) {

  private var pattern: Option[String] = None
  private var numGroups = 0

  override def tagExprForGpu(): Unit = {
    GpuRegExpUtils.tagForRegExpEnabled(this)

    expr.regexp match {
      case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
        try {
          val javaRegexpPattern = str.toString
          // verify that we support this regex and can transpile it to cuDF format
          pattern = Some(new CudfRegexTranspiler(RegexFindMode)
            .transpile(javaRegexpPattern, None)._1)
          numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
        } catch {
          case e: RegexUnsupportedException =>
            willNotWorkOnGpu(e.getMessage)
        }
      case _ =>
        willNotWorkOnGpu("only non-null literal strings are supported on GPU")
    }

    expr.idx match {
      case Literal(value, DataTypes.IntegerType) =>
        val idx = value.asInstanceOf[Int]
        if (idx < 0) {
          willNotWorkOnGpu("the specified group index cannot be less than zero")
        }
        if (idx > 0) {
          willNotWorkOnGpu("regexp_extract_all only supports group index of 0")
        }
        if (idx > numGroups) {
          willNotWorkOnGpu(
            s"regex group count is $numGroups, but the specified group index is $idx")
        }
      case _ =>
        willNotWorkOnGpu("GPU only supports literal index")
    }
  }

  override def convertToGpu(
      str: Expression,
      regexp: Expression,
      idx: Expression): GpuExpression = {
    val cudfPattern = pattern.getOrElse(
      throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))
    GpuRegExpExtractAll(str, regexp, idx, cudfPattern)
  }
}

case class GpuRegExpExtractAll(
    str: Expression,
    regexp: Expression,
    idx: Expression,
    cudfRegexPattern: String)
  extends GpuRegExpTernaryBase with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = ArrayType(StringType, containsNull = true)
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
  override def first: Expression = str
  override def second: Expression = regexp
  override def third: Expression = idx

  override def prettyName: String = "regexp_extract_all"

  override def doColumnar(
      str: GpuColumnVector,
      regexp: GpuScalar,
      idx: GpuScalar): ColumnVector = {

    val intIdx = idx.getValue.asInstanceOf[Int]
    str.getBase.extractAllRecord(cudfRegexPattern, intIdx)
  }
}

class SubstringIndexMeta(
1 change: 1 addition & 0 deletions tools/src/main/resources/operatorsScore.csv
@@ -177,6 +177,7 @@ RaiseError,4
Rand,4
Rank,4
RegExpExtract,4
RegExpExtractAll,4
RegExpReplace,4
Remainder,4
ReplicateRows,4
4 changes: 4 additions & 0 deletions tools/src/main/resources/supportedExprs.csv
@@ -381,6 +381,10 @@ RegExpExtract,S,`regexp_extract`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,N
RegExpExtract,S,`regexp_extract`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtract,S,`regexp_extract`,None,project,idx,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtract,S,`regexp_extract`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,idx,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,regex,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,pos,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA