Skip to content

Commit

Permalink
Fall back to CPU for unsupported regular expression edge cases with e…
Browse files Browse the repository at this point in the history
…nd of line/string anchors and newlines (NVIDIA#5610)
  • Loading branch information
andygrove authored and HaoYang670 committed Jun 6, 2022
1 parent 1c5824d commit c0391c6
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 78 deletions.
3 changes: 2 additions & 1 deletion docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,8 @@ The following regular expression patterns are not yet supported on the GPU and w
- Line anchor `$` is not supported by `regexp_replace`, and in some rare contexts.
- String anchor `\Z` is not supported by `regexp_replace`, and in some rare contexts.
- String anchor `\z` is not supported by `regexp_replace`
- Line anchor `$` and string anchors `\z` and `\Z` are not supported in patterns containing `\W` or `\D`
- Patterns containing an end of line or string anchor immediately next to a newline or repetition that produces zero
or more results
- Line and string anchors are not supported by `string_split` and `str_to_map`
- Word and non-word boundaries, `\b` and `\B`
- Whitespace and non-whitespace characters, `\s` and `\S`
Expand Down
126 changes: 104 additions & 22 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import java.sql.SQLException

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.RegexParser.toReadableString

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -555,6 +557,21 @@ object RegexParser {
true
}
}

def toReadableString(x: String): String = {
x.map {
case '\r' => "\\r"
case '\n' => "\\n"
case '\t' => "\\t"
case '\f' => "\\f"
case '\u000b' => "\\u000b"
case '\u0085' => "\\u0085"
case '\u2028' => "\\u2028"
case '\u2029' => "\\u2029"
case other => other
}.mkString
}

}

sealed trait RegexMode
Expand Down Expand Up @@ -745,24 +762,83 @@ class CudfRegexTranspiler(mode: RegexMode) {
private def transpile(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {

// look for patterns that we know are problematic before we attempt to rewrite the expression
val negatedWordOrDigit = contains(regex, {
case RegexEscaped('W') | RegexEscaped('D') => true
case _ => false
})
val endOfLineAnchor = contains(regex, {
case RegexChar('$') | RegexEscaped('Z') | RegexEscaped('z') => true
case _ => false
})
def containsBeginAnchor(regex: RegexAST): Boolean = {
contains(regex, {
case RegexChar('^') | RegexEscaped('A') => true
case _ => false
})
}

// this check is quite broad and could potentially be refined to look for \W or \D
// immediately next to a line anchor
if (negatedWordOrDigit && endOfLineAnchor) {
throw new RegexUnsupportedException(
"Combination of \\W or \\D with line anchor $ " +
"or string anchors \\z or \\Z is not supported")
def containsEndAnchor(regex: RegexAST): Boolean = {
contains(regex, {
case RegexChar('$') | RegexEscaped('z') | RegexEscaped('Z') => true
case _ => false
})
}

def containsNewline(regex: RegexAST): Boolean = {
contains(regex, {
case RegexChar('\r') | RegexEscaped('r') => true
case RegexChar('\n') | RegexEscaped('n') => true
case RegexChar('\u0085') | RegexChar('\u2028') | RegexChar('\u2029') => true
case RegexEscaped('s') | RegexEscaped('v') | RegexEscaped('R') => true
case RegexEscaped('W') | RegexEscaped('D') | RegexEscaped('S') | RegexEscaped('V') =>
// these would get transpiled to negated character classes
// that include newlines
true
case RegexCharacterClass(true, _) => true
case _ => false
})
}

def containsEmpty(regex: RegexAST): Boolean = {
contains(regex, {
case RegexRepetition(_, term) => term match {
case SimpleQuantifier('*') | SimpleQuantifier('?') => true
case QuantifierFixedLength(0) => true
case QuantifierVariableLength(0, _) => true
case _ => false
}
case _ => false
})
}

// check a pair of regex ast nodes for unsupported combinations
// of end string/line anchors and newlines or optional items
def checkEndAnchorContext(r1: RegexAST, r2: RegexAST): Unit = {
if ((containsEndAnchor(r1) &&
(containsNewline(r2) || containsEmpty(r2) || containsBeginAnchor(r2))) ||
(containsEndAnchor(r2) &&
(containsNewline(r1) || containsEmpty(r1) || containsBeginAnchor(r1)))) {
throw new RegexUnsupportedException(
s"End of line/string anchor is not supported in this context: " +
s"${toReadableString(r1.toRegexString)}" +
s"${toReadableString(r2.toRegexString)}")
}
}

def checkUnsupported(regex: RegexAST): Unit = {
regex match {
case RegexSequence(parts) =>
for (i <- 1 until parts.length) {
checkEndAnchorContext(parts(i - 1), parts(i))
}
case RegexChoice(l, r) =>
checkUnsupported(l)
checkUnsupported(r)
case RegexGroup(_, term) => checkUnsupported(term)
case RegexRepetition(ast, _) => checkUnsupported(ast)
case RegexCharacterClass(_, components) =>
for (i <- 1 until components.length) {
checkEndAnchorContext(components(i - 1), components(i))
}
case _ =>
// ignore
}
}

checkUnsupported(regex)

rewrite(regex, replacement, previous)
}

Expand Down Expand Up @@ -1231,13 +1307,19 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = regex match {
case RegexSequence(parts) => parts.exists(x => contains(x, f))
case RegexGroup(_, term) => contains(term, f)
case RegexChoice(l, r) => contains(l, f) || contains(r, f)
case RegexRepetition(term, _) => contains(term, f)
case RegexCharacterClass(_, chars) => chars.exists(ch => contains(ch, f))
case leaf => f(leaf)
private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
if (f(regex)) {
true
} else {
regex match {
case RegexSequence(parts) => parts.exists(x => contains(x, f))
case RegexGroup(_, term) => contains(term, f)
case RegexChoice(l, r) => contains(l, f) || contains(r, f)
case RegexRepetition(term, _) => contains(term, f)
case RegexCharacterClass(_, chars) => chars.exists(ch => contains(ch, f))
case leaf => f(leaf)
}
}
}

private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,17 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')")
}

testSparkResultsAreEqual("String regexp_extract regex 1",
extractStrings, conf = conf) {
// https://github.com/NVIDIA/spark-rapids/issues/5659
testGpuFallback("String regexp_extract regex 1",
"ProjectExec", extractStrings, conf = conf,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)")
}

testSparkResultsAreEqual("String regexp_extract regex 2",
extractStrings, conf = conf) {
// https://github.com/NVIDIA/spark-rapids/issues/5659
testGpuFallback("String regexp_extract regex 2",
"ProjectExec", extractStrings, conf = conf,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)")
}

Expand Down
Loading

0 comments on commit c0391c6

Please sign in to comment.