From c0391c662c43335a6ca83deeab760870504bb1f0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Jun 2022 10:12:56 -0600 Subject: [PATCH] Fall back to CPU for unsupported regular expression edge cases with end of line/string anchors and newlines (#5610) --- docs/compatibility.md | 3 +- .../com/nvidia/spark/rapids/RegexParser.scala | 126 +++++++++++++++--- .../spark/rapids/RegularExpressionSuite.scala | 12 +- .../RegularExpressionTranspilerSuite.scala | 110 ++++++++------- 4 files changed, 173 insertions(+), 78 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index ca59894767b3..a940e1d53d1d 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -592,7 +592,8 @@ The following regular expression patterns are not yet supported on the GPU and w - Line anchor `$` is not supported by `regexp_replace`, and in some rare contexts. - String anchor `\Z` is not supported by `regexp_replace`, and in some rare contexts. - String anchor `\z` is not supported by `regexp_replace` -- Line anchor `$` and string anchors `\z` and `\Z` are not supported in patterns containing `\W` or `\D` +- Patterns containing an end of line or string anchor immediately next to a newline or repetition that produces zero + or more results - Line and string anchors are not supported by `string_split` and `str_to_map` - Word and non-word boundaries, `\b` and `\B` - Whitespace and non-whitespace characters, `\s` and `\S` diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 085a39ae23b2..e1f928746f36 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -19,6 +19,8 @@ import java.sql.SQLException import scala.collection.mutable.ListBuffer +import com.nvidia.spark.rapids.RegexParser.toReadableString + /** * Regular expression parser based on a Pratt Parser design. * @@ -555,6 +557,21 @@ object RegexParser { true } } + + def toReadableString(x: String): String = { + x.map { + case '\r' => "\\r" + case '\n' => "\\n" + case '\t' => "\\t" + case '\f' => "\\f" + case '\u000b' => "\\u000b" + case '\u0085' => "\\u0085" + case '\u2028' => "\\u2028" + case '\u2029' => "\\u2029" + case other => other + }.mkString + } + } sealed trait RegexMode @@ -745,24 +762,83 @@ class CudfRegexTranspiler(mode: RegexMode) { private def transpile(regex: RegexAST, replacement: Option[RegexReplacement], previous: Option[RegexAST]): RegexAST = { - // look for patterns that we know are problematic before we attempt to rewrite the expression - val negatedWordOrDigit = contains(regex, { - case RegexEscaped('W') | RegexEscaped('D') => true - case _ => false - }) - val endOfLineAnchor = contains(regex, { - case RegexChar('$') | RegexEscaped('Z') | RegexEscaped('z') => true - case _ => false - }) + def containsBeginAnchor(regex: RegexAST): Boolean = { + contains(regex, { + case RegexChar('^') | RegexEscaped('A') => true + case _ => false + }) + } - // this check is quite broad and could potentially be refined to look for \W or \D - // immediately next to a line anchor - if (negatedWordOrDigit && endOfLineAnchor) { - throw new RegexUnsupportedException( - "Combination of \\W or \\D with line anchor $ " + - "or string anchors \\z or \\Z is not supported") + def containsEndAnchor(regex: RegexAST): Boolean = { + contains(regex, { + case RegexChar('$') | RegexEscaped('z') | RegexEscaped('Z') => true + case _ => false + }) + } + + def containsNewline(regex: RegexAST): Boolean = { + contains(regex, { + case RegexChar('\r') | RegexEscaped('r') => true + case RegexChar('\n') | RegexEscaped('n') => true + case RegexChar('\u0085') | RegexChar('\u2028') | RegexChar('\u2029') => true + case RegexEscaped('s') | RegexEscaped('v') | RegexEscaped('R') => true + case RegexEscaped('W') | RegexEscaped('D') | RegexEscaped('S') | RegexEscaped('V') => + // these would get transpiled to negated character classes + // that include newlines + true + case RegexCharacterClass(true, _) => true + case _ => false + }) + } + + def containsEmpty(regex: RegexAST): Boolean = { + contains(regex, { + case RegexRepetition(_, term) => term match { + case SimpleQuantifier('*') | SimpleQuantifier('?') => true + case QuantifierFixedLength(0) => true + case QuantifierVariableLength(0, _) => true + case _ => false + } + case _ => false + }) } + // check a pair of regex ast nodes for unsupported combinations + // of end string/line anchors and newlines or optional items + def checkEndAnchorContext(r1: RegexAST, r2: RegexAST): Unit = { + if ((containsEndAnchor(r1) && + (containsNewline(r2) || containsEmpty(r2) || containsBeginAnchor(r2))) || + (containsEndAnchor(r2) && + (containsNewline(r1) || containsEmpty(r1) || containsBeginAnchor(r1)))) { + throw new RegexUnsupportedException( + s"End of line/string anchor is not supported in this context: " + + s"${toReadableString(r1.toRegexString)}" + + s"${toReadableString(r2.toRegexString)}") + } + } + + def checkUnsupported(regex: RegexAST): Unit = { + regex match { + case RegexSequence(parts) => + for (i <- 1 until parts.length) { + checkEndAnchorContext(parts(i - 1), parts(i)) + } + case RegexChoice(l, r) => + checkUnsupported(l) + checkUnsupported(r) + case RegexGroup(_, term) => checkUnsupported(term) + case RegexRepetition(ast, _) => checkUnsupported(ast) + case RegexCharacterClass(_, components) => + for (i <- 1 until components.length) { + checkEndAnchorContext(components(i - 1), components(i)) + } + case _ => + // ignore + } + } + + checkUnsupported(regex) + rewrite(regex, replacement, previous) } @@ -1231,13 +1307,19 @@ class CudfRegexTranspiler(mode: RegexMode) { } } - private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = regex match { - case RegexSequence(parts) => parts.exists(x => contains(x, f)) - case RegexGroup(_, term) => contains(term, f) - case RegexChoice(l, r) => contains(l, f) || contains(r, f) - case RegexRepetition(term, _) => contains(term, f) - case RegexCharacterClass(_, chars) => chars.exists(ch => contains(ch, f)) - case leaf => f(leaf) + private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = { + if (f(regex)) { + true + } else { + regex match { + case RegexSequence(parts) => parts.exists(x => contains(x, f)) + case RegexGroup(_, term) => contains(term, f) + case RegexChoice(l, r) => contains(l, f) || contains(r, f) + case RegexRepetition(term, _) => contains(term, f) + case RegexCharacterClass(_, chars) => chars.exists(ch => contains(ch, f)) + case leaf => f(leaf) + } + } } private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala index b90f6812ea05..a4c45487eef9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala @@ -91,13 +91,17 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite { frame => frame.selectExpr("regexp_replace(strings,'\\(foo\\)','D')") } - testSparkResultsAreEqual("String regexp_extract regex 1", - extractStrings, conf = conf) { + // https://github.com/NVIDIA/spark-rapids/issues/5659 + testGpuFallback("String regexp_extract regex 1", + "ProjectExec", extractStrings, conf = conf, + execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 1)") } - testSparkResultsAreEqual("String regexp_extract regex 2", - extractStrings, conf = conf) { + // https://github.com/NVIDIA/spark-rapids/issues/5659 + testGpuFallback("String regexp_extract regex 2", + "ProjectExec", extractStrings, conf = conf, + execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) { frame => frame.selectExpr("regexp_extract(strings, '^([a-z]*)([0-9]*)([a-z]*)$', 2)") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index d20a91eadf6b..6435fa0c6ef2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.{HashSet, ListBuffer} import scala.util.{Random, Try} import ai.rapids.cudf.{ColumnVector, CudfException} +import com.nvidia.spark.rapids.RegexParser.toReadableString import org.scalatest.FunSuite import org.apache.spark.sql.rapids.GpuRegExpUtils @@ -78,8 +79,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val patterns = Seq("\\W\\Z\\D", "\\W$", "$\\D") patterns.foreach(pattern => assertUnsupported(pattern, RegexFindMode, - "Combination of \\W or \\D with line anchor $ " + - "or string anchors \\z or \\Z is not supported") + "End of line/string anchor is not supported in this context") ) } @@ -119,8 +119,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val patterns = Seq("\r$", "\n$", "\r\n$", "\u0085$", "\u2028$", "\u2029$") patterns.foreach(pattern => assertUnsupported(pattern, RegexReplaceMode, - "Regex sequences with a line terminator character followed by " + - "'$' are not supported in replace mode") + "End of line/string anchor is not supported in this context") ) } @@ -266,25 +265,36 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("line anchor sequence $\\n fall back to CPU") { - assertUnsupported("a$\n", RegexFindMode, "regex sequence $\\n is not supported") + assertUnsupported("a$\n", RegexFindMode, + "End of line/string anchor is not supported in this context") } test("line anchor $ - find") { - val patterns = Seq("$\r", "a$", "\r$", "\f$", "$\f", "\u0085$", "\u2028$", "\u2029$", "\n$", - "\r\n$", "[\r\n]?$", "\\00*[D$3]$", "a$b") + val patterns = Seq("a$", "a$b", "\f$", "$\f") val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028", "\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") assertCpuGpuMatchesRegexpFind(patterns, inputs) + val unsupportedPatterns = Seq("[\r\n]?$", "$\r", "\r$", + "\u0085$", "\u2028$", "\u2029$", "\n$", "\r\n$", "\\00*[D$3]$") + for (pattern <- unsupportedPatterns) { + assertUnsupported(pattern, RegexFindMode, + "End of line/string anchor is not supported in this context") + } } test("string anchor \\Z - find") { - val patterns = Seq("\\Z\r", "a\\Z", "\r\\Z", "\f\\Z", "\\Z\f", "\u0085\\Z", "\u2028\\Z", - "\u2029\\Z", "\n\\Z", "\r\n\\Z", "[\r\n]?\\Z", "\\00*[D$3]\\Z", "a\\Zb", "a\\Z+") + val patterns = Seq("a\\Z", "a\\Zb", "a\\Z+", "\f\\Z", "\\Z\f") val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028", "\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") assertCpuGpuMatchesRegexpFind(patterns, inputs) + val unsupportedPatterns = Seq("[\r\n]?\\Z", "\\Z\r", "\r\\Z", + "\u0085\\Z", "\u2028\\Z", "\u2029\\Z", "\n\\Z", "\r\n\\Z", "\\00*[D$3]\\Z") + for (pattern <- unsupportedPatterns) { + assertUnsupported(pattern, RegexFindMode, + "End of line/string anchor is not supported in this context") + } } test("whitespace boundaries - replace") { @@ -336,7 +346,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assert(transpiled === expected) } - test("transpile complex regex 2") { + // this was a test for transpiling but we did not ever try to run the + // resulting regex to see if it produced valid results + // see https://github.com/NVIDIA/spark-rapids/issues/5656 + ignore("transpile complex regex 2") { val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " + "[0-9]{2}:[0-9]{2}:[0-9]{2})" + "(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?\\z" @@ -362,23 +375,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile $") { doTranspileTest("a$", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") - doTranspileTest("$$\r", "\r(?:[\n\u0085\u2028\u2029])?$") - doTranspileTest("]$\r", "]\r(?:[\n\u0085\u2028\u2029])?$") - doTranspileTest("^$[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$") - doTranspileTest("^$([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$") } test("transpile \\Z") { doTranspileTest("a\\Z", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") - doTranspileTest("\\Z\\Z\r", "\r(?:[\n\u0085\u2028\u2029])?$") - doTranspileTest("]\\Z\r", "]\r(?:[\n\u0085\u2028\u2029])?$") - doTranspileTest("^\\Z[^*A-ZA-Z]", "^(?:[\n\r\u0085\u2028\u2029])$") - doTranspileTest("^\\Z([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$") doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") - doTranspileTest("a\\Z\\V", - "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$[^\u000B\u0085\u2028\u2029\n\f\r]") } test("compare CPU and GPU: character range including unescaped + and -") { @@ -399,6 +402,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpFind(patterns, inputs) } + // Patterns are generated from these characters + // '&' is absent due to https://github.com/NVIDIA/spark-rapids/issues/5655 private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ" private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON @@ -480,11 +485,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("cuDF does not support some uses of line anchors in regexp_replace") { - Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+", "^|$", "^^|$$").foreach( + Seq("^", "$", "^*", "$*", "^+", "$+", "^|$", "^^|$$").foreach( pattern => assertUnsupported(pattern, RegexReplaceMode, "sequences that only contain '^' or '$' are not supported") ) + Seq("^$", "(^)($)", "(((^^^)))$").foreach( + pattern => + assertUnsupported(pattern, RegexReplaceMode, + "End of line/string anchor is not supported in this context") + ) } test("compare CPU and GPU: regexp replace negated character class") { @@ -551,11 +561,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("AST fuzz test - regexp_find") { - doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode) + doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), REGEXP_LIMITED_CHARS_FIND, + RegexFindMode) } test("AST fuzz test - regexp_replace") { - doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode) + doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), REGEXP_LIMITED_CHARS_REPLACE, + RegexReplaceMode) + } + + test("AST fuzz test - regexp_find - anchor focused") { + doAstFuzzTest(validDataChars = Some("\r\nabc"), + validPatternChars = "^$\\AZz\r\n()[]-", mode = RegexFindMode) } test("string split - optimized") { @@ -596,7 +613,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("string split fuzz") { val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE), - RegexSplitMode) + REGEXP_LIMITED_CHARS_REPLACE, RegexSplitMode) for (limit <- Seq(-2, -1, 2, 5)) { doStringSplitTest(patterns, data, limit) } @@ -657,8 +674,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } - private def doAstFuzzTest(validChars: Option[String], mode: RegexMode) { - val (data, patterns) = generateDataAndPatterns(validChars, mode) + private def doAstFuzzTest(validDataChars: Option[String], validPatternChars: String, + mode: RegexMode) { + val (data, patterns) = generateDataAndPatterns(validDataChars, validPatternChars, mode) if (mode == RegexReplaceMode) { assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data) } else { @@ -666,17 +684,19 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } - private def generateDataAndPatterns(validChars: Option[String], mode: RegexMode) - : (Seq[String], Set[String]) = { - val r = new EnhancedRandom(new Random(seed = 0L), - FuzzerOptions(validChars, maxStringLen = 12)) + private def generateDataAndPatterns( + validDataChars: Option[String], + validPatternChars: String, + mode: RegexMode): (Seq[String], Set[String]) = { - val fuzzer = new FuzzRegExp(REGEXP_LIMITED_CHARS_FIND) + val dataGen = new EnhancedRandom(new Random(seed = 0L), + FuzzerOptions(validDataChars, maxStringLen = 12)) val data = Range(0, 1000) - .map(_ => r.nextString()) + .map(_ => dataGen.nextString()) // generate patterns that are valid on both CPU and GPU + val fuzzer = new FuzzRegExp(validPatternChars) val patterns = HashSet[String]() while (patterns.size < 5000) { val pattern = fuzzer.generate(0).toRegexString @@ -690,7 +710,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = { - for (javaPattern <- javaPatterns) { + for ((javaPattern, patternIndex) <- javaPatterns.zipWithIndex) { val cpu = cpuContains(javaPattern, input) val (cudfPattern, _) = new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern, None) @@ -702,7 +722,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } for (i <- input.indices) { if (cpu(i) != gpu(i)) { - fail(s"javaPattern=${toReadableString(javaPattern)}, " + + fail(s"javaPattern[$patternIndex]=${toReadableString(javaPattern)}, " + s"cudfPattern=${toReadableString(cudfPattern)}, " + s"input='${toReadableString(input(i))}', " + s"cpu=${cpu(i)}, gpu=${gpu(i)}") @@ -714,7 +734,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { private def assertCpuGpuMatchesRegexpReplace( javaPatterns: Seq[String], input: Seq[String]) = { - for (javaPattern <- javaPatterns) { + for ((javaPattern, patternIndex) <- javaPatterns.zipWithIndex) { val cpu = cpuReplace(javaPattern, input) val (cudfPattern, replaceString) = (new CudfRegexTranspiler(RegexReplaceMode)).transpile(javaPattern, @@ -729,7 +749,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } for (i <- input.indices) { if (cpu(i) != gpu(i)) { - fail(s"javaPattern=${toReadableString(javaPattern)}, " + + fail(s"javaPattern[$patternIndex]=${toReadableString(javaPattern)}, " + s"cudfPattern=${toReadableString(cudfPattern)}, " + s"input='${toReadableString(input(i))}', " + s"cpu=${toReadableString(cpu(i))}, " + @@ -777,20 +797,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { result } - private def toReadableString(x: String): String = { - x.map { - case '\r' => "\\r" - case '\n' => "\\n" - case '\t' => "\\t" - case '\f' => "\\f" - case '\u000b' => "\\u000b" - case '\u0085' => "\\u0085" - case '\u2028' => "\\u2028" - case '\u2029' => "\\u2029" - case other => other - }.mkString - } - private def cpuContains(pattern: String, input: Seq[String]): Array[Boolean] = { val p = Pattern.compile(pattern) input.map(s => p.matcher(s).find(0)).toArray @@ -840,7 +846,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val e = intercept[RegexUnsupportedException] { transpile(pattern, mode) } - assert(e.getMessage.startsWith(message), pattern) + if (!e.getMessage.startsWith(message)) { + fail(s"Pattern '$pattern': Error was [${e.getMessage}] but expected [$message]'") + } } private def parse(pattern: String): RegexAST = new RegexParser(pattern).parse()