From 7c8071b5c9817facb30d9c2b26ebeef8f4f65aa8 Mon Sep 17 00:00:00 2001 From: Anthony Chang <54450499+anthony-chang@users.noreply.github.com> Date: Wed, 18 May 2022 15:38:06 -0400 Subject: [PATCH] Move expanding of character classes to rewrite stage (#5527) Signed-off-by: Anthony Chang --- .../com/nvidia/spark/rapids/RegexParser.scala | 160 ++++++++++-------- .../RegularExpressionTranspilerSuite.scala | 2 + 2 files changed, 89 insertions(+), 73 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index b1d0d33c7e8..ee57e121512 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -342,40 +342,7 @@ class RegexParser(pattern: String) { // string anchors consumeExpected(ch) RegexEscaped(ch) - case 'h' | 'H' => - // horizontal whitespace - // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html - // under "Predefined character classes" - val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( - RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'), - RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000') - ) - chars += RegexEscaped('t') - chars += RegexCharacterRange('\u2000', '\u200a') - consumeExpected(ch) - RegexCharacterClass(negated = ch.isUpper, characters = chars) - case 'v' | 'V' => - // vertical whitespace - // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html - // under "Predefined character classes" - val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( - RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029') - ) - chars ++= Seq('n', 'f', 'r').map(RegexEscaped) - consumeExpected(ch) - RegexCharacterClass(negated = ch.isUpper, characters = chars) - case 'R' => - // linebreak sequence - // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html - // under "Linebreak matcher" - val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A'))) - val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent]( - RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'), - RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029') - )) - consumeExpected(ch) - RegexGroup(true, RegexChoice(l, r)) - case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => + case 's' | 'S' | 'd' | 'D' | 'w' | 'W' | 'v' | 'V' | 'h' | 'H' | 'R' => // meta sequences consumeExpected(ch) RegexEscaped(ch) @@ -640,7 +607,7 @@ class CudfRegexTranspiler(mode: RegexMode) { private def isSupportedRepetitionBase(e: RegexAST): Boolean = { e match { case RegexEscaped(ch) => ch match { - case 'd' | 'w' | 'h' | 'H' | 'v' | 'V' => true + case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => true case _ => false } @@ -689,6 +656,46 @@ class CudfRegexTranspiler(mode: RegexMode) { } } + private def negateCharacterClass(components: Seq[RegexCharacterClassComponent]): RegexAST = { + // There are differences between cuDF and Java handling of newlines + // for negative character matches. The expression `[^a]` will match + // `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with + // `(?:[\r\n]|[^a])`. We also have to take into account whether any + // newline characters are included in the character range. + // + // Examples: + // + // `[^a]` => `(?:[\r\n]|[^a])` + // `[^a\r]` => `(?:[\n]|[^a])` + // `[^a\n]` => `(?:[\r]|[^a])` + // `[^a\r\n]` => `[^a]` + // `[^\r\n]` => `[^\r\n]` + + val linefeedCharsInPattern = components.flatMap { + case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch) + case RegexEscaped(ch) if ch == 'n' => Seq('\n') + case RegexEscaped(ch) if ch == 'r' => Seq('\r') + case _ => Seq.empty + } + + val onlyLinefeedChars = components.length == linefeedCharsInPattern.length + + val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct) + + if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) { + // special case for `[^\r\n]` and `[^\\r\\n]` + RegexCharacterClass(negated = true, ListBuffer(components: _*)) + } else if (negatedNewlines.isEmpty) { + RegexCharacterClass(negated = true, ListBuffer(components: _*)) + } else { + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated = false, + characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), + RegexCharacterClass(negated = true, ListBuffer(components: _*)))) + } + } + private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement], previous: Option[RegexAST]): RegexAST = { regex match { @@ -814,10 +821,53 @@ class CudfRegexTranspiler(mode: RegexMode) { rewrite(RegexChar('$'), replacement, previous) } case 's' | 'S' => + // whitespace characters val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( RegexChar(' '), RegexChar('\u000b')) chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped) - RegexCharacterClass(negated = ch.isUpper, characters = chars) + if (ch.isUpper) { + negateCharacterClass(chars) + } else { + RegexCharacterClass(negated = false, characters = chars) + } + case 'h' | 'H' => + // horizontal whitespace + // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html + // under "Predefined character classes" + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'), + RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000') + ) + chars += RegexEscaped('t') + chars += RegexCharacterRange('\u2000', '\u200a') + if (ch.isUpper) { + negateCharacterClass(chars) + } else { + RegexCharacterClass(negated = false, characters = chars) + } + case 'v' | 'V' => + // vertical whitespace + // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html + // under "Predefined character classes" + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029') + ) + chars ++= Seq('n', 'f', 'r').map(RegexEscaped) + if (ch.isUpper) { + negateCharacterClass(chars) + } else { + RegexCharacterClass(negated = false, characters = chars) + } + case 'R' => + // linebreak sequence + // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html + // under "Linebreak matcher" + val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A'))) + val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent]( + RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'), + RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029') + )) + RegexGroup(true, RegexChoice(l, r)) case _ => regex } @@ -859,43 +909,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } if (negated) { - // There are differences between cuDF and Java handling of newlines - // for negative character matches. The expression `[^a]` will match - // `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with - // `(?:[\r\n]|[^a])`. We also have to take into account whether any - // newline characters are included in the character range. - // - // Examples: - // - // `[^a]` => `(?:[\r\n]|[^a])` - // `[^a\r]` => `(?:[\n]|[^a])` - // `[^a\n]` => `(?:[\r]|[^a])` - // `[^a\r\n]` => `[^a]` - // `[^\r\n]` => `[^\r\n]` - - val linefeedCharsInPattern = components.flatMap { - case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch) - case RegexEscaped(ch) if ch == 'n' => Seq('\n') - case RegexEscaped(ch) if ch == 'r' => Seq('\r') - case _ => Seq.empty - } - - val onlyLinefeedChars = components.length == linefeedCharsInPattern.length - - val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct) - - if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) { - // special case for `[^\r\n]` and `[^\\r\\n]` - RegexCharacterClass(negated = true, ListBuffer(components: _*)) - } else if (negatedNewlines.isEmpty) { - RegexCharacterClass(negated = true, ListBuffer(components: _*)) - } else { - RegexGroup(capture = false, - RegexChoice( - RegexCharacterClass(negated = false, - characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), - RegexCharacterClass(negated = true, ListBuffer(components: _*)))) - } + negateCharacterClass(components) } else { RegexCharacterClass(negated, ListBuffer(components: _*)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 3d8fdc34625..12d26123e08 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -349,6 +349,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") + doTranspileTest("a\\Z\\V", + "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$[^\u000B\u0085\u2028\u2029\n\f\r]") } test("compare CPU and GPU: character range including unescaped + and -") {