Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for regexp_replace with back-references #5087

Merged
merged 8 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,26 @@ def test_re_replace():
'REGEXP_REPLACE(a, "TEST", NULL)'),
conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_re_replace_backrefs():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_fallback_collect(
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}TEST')
andygrove marked this conversation as resolved.
Show resolved Hide resolved
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
andygrove marked this conversation as resolved.
Show resolved Hide resolved
'REGEXP_REPLACE(a, "(TEST)", "[$0]")',
'REGEXP_REPLACE(a, "(TEST)", "[$1]")'),
'RegExpReplace',
'REGEXP_REPLACE(a, "(TEST)", "[\\1]")',
'REGEXP_REPLACE(a, "(T)[a-z]+(T)", "[$2][$1][$0]")',
'REGEXP_REPLACE(a, "([0-9]+)(T)[a-z]+(T)", "[$3][$2][$1]")',
'REGEXP_REPLACE(a, "(TESTT)", "\\0 \\1")' # no match
),
conf=_regexp_conf)

# For GPU runs, cuDF will check the range and throw exception if index is out of range
def test_re_replace_backrefs_idx_out_of_bounds():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_and_cpu_error(lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "(T)(E)(S)(T)", "[$5]")').collect(),
conf=_regexp_conf,
error_message='')

def test_re_replace_backrefs_escaped():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
package com.nvidia.spark.rapids

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -29,13 +29,15 @@ class GpuRegExpReplaceMeta(

private var pattern: Option[String] = None
private var replacement: Option[String] = None
private var canUseGpuStringReplace = false
private var containsBackref: Boolean = false

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
// use GpuStringReplace
canUseGpuStringReplace = true
} else {
try {
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
Expand All @@ -51,10 +53,11 @@ class GpuRegExpReplaceMeta(

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
GpuRegExpUtils.backrefConversion(s.toString) match {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}

Expand All @@ -73,13 +76,16 @@ class GpuRegExpReplaceMeta(
// ignore the pos expression which must be a literal 1 after tagging check
require(childExprs.length == 4,
s"Unexpected child count for RegExpReplace: ${childExprs.length}")
val Seq(subject, regexp, rep) = childExprs.take(3).map(_.convertToGpu())
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(subject, regexp, rep)
if (canUseGpuStringReplace) {
GpuStringReplace(lhs, regexp, rep)
} else {
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,24 +773,44 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char)
object GpuRegExpUtils {

/**
* Determine if a string contains back-references such as `$1` but ignoring
* if preceded by escape character.
* Convert symbols of back-references if input string contains any.
* In spark's regex rule, there are two patterns of back-references:
* \group_index and $group_index
* This method transforms above two patterns into cuDF pattern ${group_index}, except they are
* preceded by escape character.
*
* @param rep replacement string
* @return A pair consists of a boolean indicating whether containing any backref and the
* converted replacement.
*/
def containsBackrefs(s: String): Boolean = {
def backrefConversion(rep: String): (Boolean, String) = {
val b = new StringBuilder
var i = 0
while (i < s.length) {
if (s.charAt(i) == '\\') {
while (i < rep.length) {
// match $group_index or \group_index
if (Seq('$', '\\').contains(rep.charAt(i))
&& i + 1 < rep.length && rep.charAt(i + 1).isDigit) {

b.append("${")
var j = i + 1
while (j + 1 < rep.length && rep.charAt(j).isDigit) {
b.append(rep.charAt(j))
j += 1
}
andygrove marked this conversation as resolved.
Show resolved Hide resolved
b.append("}")
i = j
} else if (rep.charAt(i) == '\\' && i + 1 < rep.length) {
// skip potential \$group_index or \\group_index
b.append('\\').append(rep.charAt(i + 1))
i += 2
} else {
if (s.charAt(i) == '$' && i+1 < s.length) {
if (s.charAt(i+1).isDigit) {
return true
}
}
b.append(rep.charAt(i))
i += 1
}
}
false

val converted = b.toString
!rep.equals(converted) -> converted
}

/**
Expand Down Expand Up @@ -956,6 +976,22 @@ case class GpuRegExpReplace(

}

case class GpuRegExpReplaceWithBackref(
override val child: Expression,
cudfRegexPattern: String,
cudfReplacementString: String)
extends GpuUnaryExpression with ImplicitCastInputTypes {

override def inputTypes: Seq[DataType] = Seq(StringType)

override def dataType: DataType = StringType

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
input.getBase.stringReplaceWithBackrefs(cudfRegexPattern, cudfReplacementString)
}

}

class GpuRegExpExtractMeta(
expr: RegExpExtract,
conf: RapidsConf,
Expand Down