Add support for regexp_extract_all on GPU #5968
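This PR adds GPU acceleration for Spark's `regexp_extract_all` expression. For context, a minimal sketch of the expression's CPU semantics (standard Spark SQL behavior, matching the worked example in the implementation comments further down; the local SparkSession here is only for illustration and is not part of this change):

```scala
// Reference (CPU) semantics of regexp_extract_all in Spark SQL.
import org.apache.spark.sql.SparkSession

object RegexpExtractAllDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("regexp-extract-all-demo")
      .getOrCreate()

    // idx = 0: every full match
    spark.sql("""SELECT regexp_extract_all('1-2, 3-4, 5-6', '(\\d+)-(\\d+)', 0)""").show(false)
    // result: ["1-2", "3-4", "5-6"]

    // idx = 1: capture group 1 of every match
    spark.sql("""SELECT regexp_extract_all('1-2, 3-4, 5-6', '(\\d+)-(\\d+)', 1)""").show(false)
    // result: ["1", "3", "5"]

    spark.stop()
  }
}
```

An index of 0 returns the full match for every occurrence, while a positive index returns that capture group from every occurrence; `GpuRegExpExtractAll` below handles the two cases separately.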

Merged

Changes from 16 commits

Commits (20 total)
2513a74
[WIP] initial work
anthony-chang Jun 2, 2022
1bc597e
[WIP] initial work
anthony-chang Jun 2, 2022
b90c914
Merge branch 'regexp-extract-all' of github.com:anthony-chang/spark-r…
anthony-chang Jun 23, 2022
4a0f2cd
Merge branch 'branch-22.08' of github.com:nvidia/spark-rapids into re…
anthony-chang Jun 23, 2022
c17a594
[WIP] use findall for group idx=0
anthony-chang Jun 23, 2022
fb21eee
Merge branch 'branch-22.08' of github.com:nvidia/spark-rapids into re…
anthony-chang Jun 30, 2022
48145b5
Fix typing issues, add python tests
anthony-chang Jul 4, 2022
e782683
Remove scala tests, add python tests for errors/fallback
anthony-chang Jul 4, 2022
b9a716c
Fix scalastyle
anthony-chang Jul 4, 2022
8cd0f8d
[WIP] initial implementation for regexp extract all
anthony-chang Jul 6, 2022
8385831
Cleanup, add comments, handle idx=0 and > 0 cases
anthony-chang Jul 7, 2022
916d1e3
Merge branch 'branch-22.08' of github.com:nvidia/spark-rapids into re…
anthony-chang Jul 7, 2022
8b9d356
Add python tests and handle null inputs properly
anthony-chang Jul 7, 2022
e02b59a
Update docs
anthony-chang Jul 7, 2022
2af89b7
Cleanup imports
anthony-chang Jul 7, 2022
786d0a9
Rename variables
anthony-chang Jul 8, 2022
c53a9d1
Merge branch 'branch-22.08' of github.com:nvidia/spark-rapids into re…
anthony-chang Jul 8, 2022
429e180
Wrap scalar in withResource call
anthony-chang Jul 8, 2022
1834490
Address feedback about resource exceptions
anthony-chang Jul 8, 2022
eb5b7b3
Merge branch 'branch-22.08' of github.com:NVIDIA/spark-rapids into re…
anthony-chang Jul 11, 2022
1 change: 1 addition & 0 deletions docs/compatibility.md
@@ -568,6 +568,7 @@ The following Apache Spark regular expression functions and expressions are supp
- `RLIKE`
- `regexp`
- `regexp_extract`
- `regexp_extract_all`
- `regexp_like`
- `regexp_replace`
- `string_split`
1 change: 1 addition & 0 deletions docs/configs.md
@@ -283,6 +283,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Rand"></a>spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None|
<a name="sql.expression.Rank"></a>spark.rapids.sql.expression.Rank|`rank`|Window function that returns the rank value within the aggregation window|true|None|
<a name="sql.expression.RegExpExtract"></a>spark.rapids.sql.expression.RegExpExtract|`regexp_extract`|Extract a specific group identified by a regular expression|true|None|
<a name="sql.expression.RegExpExtractAll"></a>spark.rapids.sql.expression.RegExpExtractAll|`regexp_extract_all`|Extract all strings matching a regular expression corresponding to the regex group index|true|None|
<a name="sql.expression.RegExpReplace"></a>spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|String replace using a regular expression pattern|true|None|
<a name="sql.expression.Remainder"></a>spark.rapids.sql.expression.Remainder|`%`, `mod`|Remainder or modulo|true|None|
<a name="sql.expression.ReplicateRows"></a>spark.rapids.sql.expression.ReplicateRows| |Given an input row replicates the row N times|true|None|
89 changes: 89 additions & 0 deletions docs/supported_ops.md
@@ -10611,6 +10611,95 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="4">RegExpExtractAll</td>
<td rowSpan="4">`regexp_extract_all`</td>
<td rowSpan="4">Extract all strings matching a regular expression corresponding to the regex group index</td>
<td rowSpan="4">None</td>
<td rowSpan="4">project</td>
<td>str</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>regexp</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Literal value only</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>idx</td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Literal value only</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="5">RegExpReplace</td>
<td rowSpan="5">`regexp_replace`</td>
<td rowSpan="5">String replace using a regular expression pattern</td>
42 changes: 42 additions & 0 deletions integration_tests/src/main/python/string_test.py
@@ -1130,3 +1130,45 @@ def test_rlike_fallback_possessive_quantifier():
'a rlike "a*+"'),
'RLike',
conf=_regexp_conf)

def test_regexp_extract_all_idx_zero():
gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract_all(a, "([a-d]+).*([0-9])", 0)',
'regexp_extract_all(a, "(a)(b)", 0)',
'regexp_extract_all(a, "([a-z0-9]([abcd]))", 0)',
'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 0)',
),
conf=_regexp_conf)

def test_regexp_extract_all_idx_positive():
gen = mk_str_gen('[abcd]{0,3}[0-9]{0,3}-[0-9]{0,3}[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract_all(a, "([a-d]+).*([0-9])", 1)',
'regexp_extract_all(a, "(a)(b)", 2)',
'regexp_extract_all(a, "([a-z0-9]((([abcd](\\\\d?)))))", 3)',
'regexp_extract_all(a, "(\\\\d+)-(\\\\d+)", 2)',
),
conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_negative():
gen = mk_str_gen('[abcd]{0,3}')
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract_all(a, "(a)", -1)'
).collect(),
error_message="The specified group index cannot be less than zero",
conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_out_of_bounds():
gen = mk_str_gen('[abcd]{0,3}')
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_extract_all(a, "([a-d]+).*([0-9])", 3)'
).collect(),
error_message="Regex group count is 2, but the specified group index is 3",
conf=_regexp_conf)
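The two fallback tests above assert on the error messages that CPU Spark itself raises for an invalid group index. As a rough sketch of how those errors surface when reproduced directly (reusing a SparkSession like the one sketched at the top of this page; the exact exception class and wrapping may vary across Spark versions):

```scala
// Negative group index: rejected before the regex is evaluated.
spark.sql("""SELECT regexp_extract_all('abc', '(a)', -1)""").collect()
// error: The specified group index cannot be less than zero

// Group index larger than the number of capture groups in the pattern.
spark.sql("""SELECT regexp_extract_all('abc', '([a-d]+).*([0-9])', 3)""").collect()
// error: Regex group count is 2, but the specified group index is 3
```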
@@ -3194,6 +3194,14 @@ object GpuOverrides extends Logging {
ParamCheck("idx", TypeSig.lit(TypeEnum.INT),
TypeSig.lit(TypeEnum.INT)))),
(a, conf, p, r) => new GpuRegExpExtractMeta(a, conf, p, r)),
expr[RegExpExtractAll](
"Extract all strings matching a regular expression corresponding to the regex group index",
ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.STRING),
TypeSig.ARRAY.nested(TypeSig.STRING),
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
(a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)),
expr[Length](
"String character length or binary byte length",
ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT,
@@ -23,7 +23,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{RegExpShim, ShimExpression}

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, InputFileName, Literal, NullIntolerant, Predicate, RegExpExtract, RLike, StringSplit, StringToMap, SubstringIndex, TernaryExpression}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, InputFileName, Literal, NullIntolerant, Predicate, RegExpExtract, RegExpExtractAll, RLike, StringSplit, StringToMap, SubstringIndex, TernaryExpression}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -865,6 +865,20 @@ object GpuRegExpUtils {
isASTEmptyRepetition(parseAST(pattern))
}

/**
* Returns the number of groups in regexp
* (includes both capturing and non-capturing groups)
*/
def countGroups(pattern: String): Int = {
def countGroups(regexp: RegexAST): Int = {
regexp match {
case RegexGroup(_, term) => 1 + countGroups(term)
case other => other.children().map(countGroups).sum
}
}
countGroups(parseAST(pattern))
}

}
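Not part of the diff: a standalone sketch of what `countGroups` computes, assuming, per the doc comment above, that every parenthesized group (capturing, non-capturing, or nested) contributes one to the count. It approximates the parser-based implementation with a simple character scan and is only for illustration:

```scala
object CountGroupsSketch {
  // Count '(' characters that open a group, skipping escaped parens and
  // parens inside character classes. Both capturing and non-capturing
  // groups are counted, and nested groups each count once.
  def countGroups(pattern: String): Int = {
    var count = 0
    var inClass = false
    var i = 0
    while (i < pattern.length) {
      pattern(i) match {
        case '\\' => i += 1                 // skip the escaped character
        case '[' if !inClass => inClass = true
        case ']' if inClass  => inClass = false
        case '(' if !inClass => count += 1
        case _ =>
      }
      i += 1
    }
    count
  }

  def main(args: Array[String]): Unit = {
    println(countGroups("(\\d+)-(\\d+)"))       // 2
    println(countGroups("([a-z0-9]([abcd]))"))  // 2 (nested group counted once)
    println(countGroups("\\(not a group\\)"))   // 0
  }
}
```

For `(\\d+)-(\\d+)` this gives 2, which is the bound the `idx` checks in `GpuRegExpExtractAllMeta` below compare against.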

class GpuRLikeMeta(
@@ -1063,21 +1077,14 @@ class GpuRegExpExtractMeta(
case _ =>
}

def countGroups(regexp: RegexAST): Int = {
regexp match {
case RegexGroup(_, term) => 1 + countGroups(term)
case other => other.children().map(countGroups).sum
}
}

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
@@ -1174,7 +1181,136 @@ case class GpuRegExpExtract(
}
}
}
}

class GpuRegExpExtractAllMeta(
expr: RegExpExtractAll,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends TernaryExprMeta[RegExpExtractAll](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var numGroups = 0

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.idx match {
case Literal(value, DataTypes.IntegerType) =>
val idx = value.asInstanceOf[Int]
if (idx < 0) {
willNotWorkOnGpu("the specified group index cannot be less than zero")
}
if (idx > numGroups) {
willNotWorkOnGpu(
s"regex group count is $numGroups, but the specified group index is $idx")
}
case _ =>
willNotWorkOnGpu("GPU only supports literal index")
}
}

override def convertToGpu(
str: Expression,
regexp: Expression,
idx: Expression): GpuExpression = {
val cudfPattern = pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))
GpuRegExpExtractAll(str, regexp, idx, numGroups, cudfPattern)
}
}

case class GpuRegExpExtractAll(
str: Expression,
regexp: Expression,
idx: Expression,
numGroups: Int,
cudfRegexPattern: String)
extends GpuRegExpTernaryBase with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = ArrayType(StringType, containsNull = true)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def first: Expression = str
override def second: Expression = regexp
override def third: Expression = idx

override def prettyName: String = "regexp_extract_all"

override def doColumnar(
str: GpuColumnVector,
regexp: GpuScalar,
idx: GpuScalar): ColumnVector = {

idx.getValue.asInstanceOf[Int] match {
case 0 =>
str.getBase.extractAllRecord(cudfRegexPattern, 0)
case intIdx =>
// Extract matches corresponding to idx. cuDF's extract_all_record does not support
// group idx, so we must manually extract the relevant matches. Example:
// Given the pattern (\d+)-(\d+) and idx=1
//
// | Input | Java | cuDF |
// |-----------------|-----------------|--------------------------------|
// | '1-2, 3-4, 5-6' | ['1', '3', '5'] | ['1', '2', '3', '4', '5', '6'] |
//
// Since idx=1 and the pattern has 2 capture groups, we take the 1st element and every
// 2nd element afterwards from the cuDF list

val rowCount = str.getRowCount

val extractedWithNulls = withResource(
str.getBase.extractAllRecord(cudfRegexPattern, intIdx)) { allExtracted =>
withResource(allExtracted.countElements) { listSizes =>
withResource(listSizes.max) { maxSize =>
val maxSizeInt = maxSize.getInt
val stringCols: Seq[ColumnVector] = Range(intIdx - 1, maxSizeInt, numGroups).map {
i =>
withResource(ColumnVector.fromScalar(Scalar.fromInt(i), rowCount.toInt)) {
index => allExtracted.extractListElement(index)
}
}
ColumnVector.makeList(stringCols: _*)
}
}
}
// Filter out null values in the lists
val extractedStrings = withResource(extractedWithNulls.getListOffsetsView) { offsetsCol =>
withResource(extractedWithNulls.getChildColumnView(0)) { stringCol =>
withResource(stringCol.isNotNull) { isNotNull =>
withResource(isNotNull.makeListFromOffsets(rowCount, offsetsCol)) { booleanMask =>
extractedWithNulls.applyBooleanMask(booleanMask)
}
}
}
}
// If input is null, output should also be null
withResource(GpuScalar.from(null, DataTypes.createArrayType(DataTypes.StringType))) {
nullStringList =>
withResource(str.getBase.isNull) { isInputNull =>
withResource(extractedStrings) {
isInputNull.ifElse(nullStringList, _)
}
}
}
}
}
}
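A small, self-contained illustration (plain Scala collections, no cuDF) of the stride arithmetic used in the `case intIdx` branch above: `extractAllRecord` flattens every capture group of every match into one list per row, so group `idx` of each match sits at positions `idx - 1, idx - 1 + numGroups, idx - 1 + 2 * numGroups, ...`. The real code builds these slices column-wise and then filters out the nulls produced by shorter rows; this sketch only shows the index selection for a single row:

```scala
object StrideSelectSketch {
  // Pick group `idx` of every match out of cuDF's flattened per-row list.
  def selectGroup(flattened: Seq[String], idx: Int, numGroups: Int): Seq[String] =
    Range(idx - 1, flattened.length, numGroups).map(flattened)

  def main(args: Array[String]): Unit = {
    // Row "1-2, 3-4, 5-6" with pattern (\d+)-(\d+): 3 matches x 2 groups, flattened.
    val flattened = Seq("1", "2", "3", "4", "5", "6")
    println(selectGroup(flattened, idx = 1, numGroups = 2)) // Vector(1, 3, 5)
    println(selectGroup(flattened, idx = 2, numGroups = 2)) // Vector(2, 4, 6)
  }
}
```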

class SubstringIndexMeta(
1 change: 1 addition & 0 deletions tools/src/main/resources/operatorsScore.csv
@@ -177,6 +177,7 @@ RaiseError,4
Rand,4
Rank,4
RegExpExtract,4
RegExpExtractAll,4
RegExpReplace,4
Remainder,4
ReplicateRows,4
4 changes: 4 additions & 0 deletions tools/src/main/resources/supportedExprs.csv
@@ -381,6 +381,10 @@ RegExpExtract,S,`regexp_extract`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,N
RegExpExtract,S,`regexp_extract`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtract,S,`regexp_extract`,None,project,idx,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtract,S,`regexp_extract`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,idx,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RegExpExtractAll,S,`regexp_extract_all`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,regex,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
RegExpReplace,S,`regexp_replace`,None,project,pos,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA