Skip to content

Commit

Permalink
detect unsupported regexp edge cases
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove committed May 24, 2022
1 parent eb0f23e commit 79883d4
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 20 deletions.
114 changes: 113 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ class RegexParser(pattern: String) {
throw new RegexUnsupportedException("escape at end of string", Some(pos))
case Some(ch) =>
ch match {
case 'r' | 'n' | 'f' =>
// newlines
consumeExpected(ch)
RegexEscaped(ch)
case 'A' | 'Z' | 'z' =>
// string anchors
consumeExpected(ch)
Expand Down Expand Up @@ -559,7 +563,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
val replacement = repl.map(s => new RegexParser(s).parseReplacement(countCaptureGroups(regex)))

// validate that the regex is supported by cuDF
val cudfRegex = rewrite(regex, replacement, None)
val cudfRegex = transpile(regex, replacement, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
Expand Down Expand Up @@ -696,6 +700,99 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

private def transpile(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {

def isMaybeEndAnchor(regex: RegexAST): Boolean = {
contains(regex, {
case RegexChar('$') | RegexEscaped('z') | RegexEscaped('Z') => true
case _ => false
})
}

def isMaybeNewline(regex: RegexAST): Boolean = {
contains(regex, {
case RegexChar('\r') | RegexEscaped('r') => true
case RegexChar('\n') | RegexEscaped('n') => true
case RegexChar('\f') | RegexEscaped('f') => true
case RegexEscaped('s') => true
case RegexEscaped('W') | RegexEscaped('D') =>
// these would get transpiled to negated character classes
// that include newlines
true
case RegexCharacterClass(true, _) => true
//TODO others?
case _ => false
})
}

def isMaybeEmpty(regex: RegexAST): Boolean = {
contains(regex, {
case RegexRepetition(_, term) => term match {
case SimpleQuantifier('*') | SimpleQuantifier('?') => true
case QuantifierFixedLength(0) => true
case QuantifierVariableLength(0, _) => true
case _ => false
}
case _ => false
})
}

def checkUnsupported(r1: RegexAST, r2: RegexAST): Unit = {
// println(s"checkUnsupported $r1 and $r2")
if (isMaybeEndAnchor(r1)) {
// println("r1 is anchor")
val a = isMaybeEmpty(r2)
val b = isMaybeNewline(r2)
val c = isMaybeEndAnchor(r2)
// println(s"r2 isMaybeEmpty: $a, isMaybeNewline, isMaybeEndAnchor: $c")
if (a || b || c) {
//TODO we should throw a more specific error
throw new RegexUnsupportedException(
"End of line/string anchor is not supported in this context")
}
}
if (isMaybeEndAnchor(r2)) {
// println("r2 is anchor")
val a = isMaybeEmpty(r1)
val b = isMaybeNewline(r1)
val c = isMaybeEndAnchor(r1)
// println(s"r1 isMaybeEmpty: $a, isMaybeNewline: $b, isMaybeEndAnchor: $c")
if (a || b || c) {
//TODO we should throw a more specific error
throw new RegexUnsupportedException(
"End of line/string anchor is not supported in this context")
}
}
// println("everything is fine")
}

def checkEndAnchorNearNewline(regex: RegexAST): Unit = {
regex match {
case RegexSequence(parts) =>
parts.indices.foreach { i =>
if (i > 0) {
checkUnsupported(parts(i - 1), parts(i))
}
if (i + 1 > parts.length) {
checkUnsupported(parts(i), parts(i+1))
}
}
case RegexChoice(l, r) =>
checkEndAnchorNearNewline(l)
checkEndAnchorNearNewline(r)
case RegexGroup(_, term) => checkEndAnchorNearNewline(term)
case RegexRepetition(ast, _) => checkEndAnchorNearNewline(ast)
case _ =>
// ignore
}
}

checkEndAnchorNearNewline(regex)

rewrite(regex, replacement, previous)
}

private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {
regex match {
Expand Down Expand Up @@ -1149,6 +1246,21 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
if (f(regex)) {
true
} else {
regex match {
case RegexSequence(parts) => parts.exists(x => contains(x, f))
case RegexGroup(_, term) => contains(term, f)
case RegexChoice(l, r) => contains(l, f) || contains(r, f)
case RegexRepetition(term, _) => contains(term, f)
case RegexCharacterClass(_, chars) => chars.exists(ch => contains(ch, f))
case leaf => f(leaf)
}
}
}

private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match {
case RegexSequence(parts) => parts.nonEmpty && parts.forall(isBeginOrEndLineAnchor)
case RegexGroup(_, term) => isBeginOrEndLineAnchor(term)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}

test("newline before $ in replace mode") {
val patterns = Seq("\r$", "\n$", "\r\n$", "\u0085$", "\u2028$", "\u2029$")
//TODO would be good if we could have more consistent error messages
val patterns = Seq("\r$", "\n$", "\r\n$")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexReplaceMode,
"End of line/string anchor is not supported in this context")
)
val patterns2 = Seq("\u0085$", "\u2028$", "\u2029$")
patterns2.foreach(pattern =>
assertUnsupported(pattern, RegexReplaceMode,
"Regex sequences with a line terminator character followed by " +
"'$' are not supported in replace mode")
Expand Down Expand Up @@ -238,21 +244,27 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}

test("line anchor sequence $\\n fall back to CPU") {
assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported")
assertUnsupported("a$\n", RegexFindMode,
"End of line/string anchor is not supported in this context")
}

test("line anchor $ - find") {
val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$",
"\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b")
// are these special cases that we can allow? what is the rule?
// maybe these are ok because the $ is at the end of the pattern?
// "\r$", "\f$", "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", "[\r\n]?$"
// not sure about these ...
// "$\r", "$\f", "\\00*[D$3]$"
val patterns = Seq("a$", "a$b")
val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028",
"\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r",
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

test("string anchor \\Z - find") {
val patterns = Seq("\\Z\r", "a\\Z", "\r\\Z", "\f\\Z", "\\Z\f", "\u0085\\Z", "\u2028\\Z",
"\u2029\\Z", "\n\\Z", "\r\n\\Z", "[\r\n]?\\Z", "\\00*[D$3]\\Z", "a\\Zb", "a\\Z+")
// "\\Z\r", "\r\\Z", "\f\\Z", "\\Z\f", "\u0085\\Z", "\u2028\\Z",
// "\u2029\\Z", "\n\\Z", "\r\n\\Z", "[\r\n]?\\Z", "\\00*[D$3]\\Z",
val patterns = Seq("a\\Z", "a\\Zb", "a\\Z+")
val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028",
"\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r",
"\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb")
Expand Down Expand Up @@ -308,7 +320,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assert(transpiled === expected)
}

test("transpile complex regex 2") {
// this was a test for transpiling but we did not ever try to run the
// resulting regex to see if it produced valid results
ignore("transpile complex regex 2") {
val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " +
"[0-9]{2}:[0-9]{2}:[0-9]{2})" +
"(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?\\z"
Expand All @@ -334,18 +348,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {

test("transpile $") {
doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("$$\r", "\r(?:[\n\u0085\u2028\u2029])?$")
doTranspileTest("]$\r", "]\r(?:[\n\u0085\u2028\u2029])?$")
doTranspileTest("^$[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$")
doTranspileTest("^$([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$")
// doTranspileTest("$$\r", "\r(?:[\n\u0085\u2028\u2029])?$")
// doTranspileTest("]$\r", "]\r(?:[\n\u0085\u2028\u2029])?$")
// doTranspileTest("^$[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$")
// doTranspileTest("^$([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$")
}

test("transpile \\Z") {
doTranspileTest("a\\Z", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("\\Z\\Z\r", "\r(?:[\n\u0085\u2028\u2029])?$")
doTranspileTest("]\\Z\r", "]\r(?:[\n\u0085\u2028\u2029])?$")
doTranspileTest("^\\Z[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$")
doTranspileTest("^\\Z([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$")
// doTranspileTest("\\Z\\Z\r", "\r(?:[\n\u0085\u2028\u2029])?$")
// doTranspileTest("]\\Z\r", "]\r(?:[\n\u0085\u2028\u2029])?$")
// doTranspileTest("^\\Z[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$")
// doTranspileTest("^\\Z([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$")
doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
Expand All @@ -371,7 +385,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abcf123x\\ \t\r\n\f\u000bBsdwSDWzZ_"

private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON

Expand All @@ -383,6 +397,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

test("fall back to CPU for newline next to line or string anchor") {
// these patterns were discovered during fuzz testing and resulted in different
// results between CPU and GPU
val patterns = Seq(raw"\w[\r,B]\Z", raw"\s\Z\Z", "^$\\s", "$x*\\r", "$\\r")
for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
patterns.foreach(pattern => {
assertUnsupported(pattern, mode,
"End of line/string anchor is not supported in this context")
})
}
}

test("fall back to CPU for \\D") {
// see https://github.com/NVIDIA/spark-rapids/issues/4475
for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
Expand Down Expand Up @@ -461,16 +487,22 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {

test("compare CPU and GPU: regexp replace line anchor supported use cases") {
val inputs = Seq("a", "b", "c", "cat", "", "^", "$", "^a", "t$")
val patterns = Seq("^a", "^a", "(^a|^t)", "^[ac]", "^^^a", "[\\^^]", "a$", "a$$", "\\$$")
val patterns = Seq("^a", "^a", "(^a|^t)", "^[ac]", "^^^a", "[\\^^]", "a$", "\\$$")
// "a$$"
assertCpuGpuMatchesRegexpReplace(patterns, inputs)
}

test("cuDF does not support some uses of line anchors in regexp_replace") {
Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+", "^|$", "^^|$$").foreach(
Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+", "^|$").foreach(
pattern =>
assertUnsupported(pattern, RegexReplaceMode,
"sequences that only contain '^' or '$' are not supported")
)
Seq("^^|$$").foreach(
pattern =>
assertUnsupported(pattern, RegexReplaceMode,
"End of line/string anchor is not supported in this context")
)
}

test("compare CPU and GPU: regexp replace negated character class") {
Expand Down Expand Up @@ -826,7 +858,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
val e = intercept[RegexUnsupportedException] {
transpile(pattern, mode)
}
assert(e.getMessage.startsWith(message), pattern)
if (!e.getMessage.startsWith(message)) {
fail(s"Pattern '$pattern': Error was [${e.getMessage}] but expected [$message]'")
}
}

private def parse(pattern: String): RegexAST = new RegexParser(pattern).parse()
Expand Down

0 comments on commit 79883d4

Please sign in to comment.