Skip to content

Commit

Permalink
Move expanding of character classes to rewrite stage (#5527)
Browse files Browse the repository at this point in the history
Signed-off-by: Anthony Chang <antchang@nvidia.com>
  • Loading branch information
anthony-chang authored May 18, 2022
1 parent c0fe6e4 commit 7c8071b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 73 deletions.
160 changes: 87 additions & 73 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,40 +342,7 @@ class RegexParser(pattern: String) {
// string anchors
consumeExpected(ch)
RegexEscaped(ch)
case 'h' | 'H' =>
// horizontal whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'),
RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
)
chars += RegexEscaped('t')
chars += RegexCharacterRange('\u2000', '\u200a')
consumeExpected(ch)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case 'v' | 'V' =>
// vertical whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
)
chars ++= Seq('n', 'f', 'r').map(RegexEscaped)
consumeExpected(ch)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case 'R' =>
// linebreak sequence
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Linebreak matcher"
val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A')))
val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent](
RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'),
RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
))
consumeExpected(ch)
RegexGroup(true, RegexChoice(l, r))
case 's' | 'S' | 'd' | 'D' | 'w' | 'W' =>
case 's' | 'S' | 'd' | 'D' | 'w' | 'W' | 'v' | 'V' | 'h' | 'H' | 'R' =>
// meta sequences
consumeExpected(ch)
RegexEscaped(ch)
Expand Down Expand Up @@ -640,7 +607,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
private def isSupportedRepetitionBase(e: RegexAST): Boolean = {
e match {
case RegexEscaped(ch) => ch match {
case 'd' | 'w' | 'h' | 'H' | 'v' | 'V' => true
case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => true
case _ => false
}

Expand Down Expand Up @@ -689,6 +656,46 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

/**
 * Builds the AST for a negated character class `[^...]` over `components`,
 * compensating for the newline handling difference between Java and cuDF.
 *
 * Per the notes that accompany this rewrite, `[^a]` matches `\r` and `\n` in
 * Java but not in cuDF, so any of those two characters that the class does
 * NOT itself exclude is added back through a non-capturing alternation.
 *
 * Examples:
 *
 * `[^a]`      => `(?:[\r\n]|[^a])`
 * `[^a\r]`    => `(?:[\n]|[^a\r])`
 * `[^a\n]`    => `(?:[\r]|[^a\n])`
 * `[^a\r\n]`  => `[^a\r\n]`
 * `[^\r\n]`   => `[^\r\n]`
 * `[^\n\n]`   => `(?:[\r]|[^\n\n])`
 *
 * @param components the members of the character class being negated
 * @return a negated [[RegexCharacterClass]] when both `\r` and `\n` are
 *         already listed in `components`, otherwise a non-capturing
 *         [[RegexGroup]] choosing between the missing newline characters and
 *         the negated class
 */
private def negateCharacterClass(components: Seq[RegexCharacterClassComponent]): RegexAST = {
  // Newline characters named by the class, either literally or via the
  // `\n` / `\r` escapes. De-duplicated so that a repeated entry such as
  // `[^\n\n]` is not mistaken for the pair `\r` + `\n`.
  val linefeedCharsInPattern = components.flatMap {
    case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch)
    case RegexEscaped(ch) if ch == 'n' => Seq('\n')
    case RegexEscaped(ch) if ch == 'r' => Seq('\r')
    case _ => Seq.empty
  }.distinct

  // Newline characters the class does not exclude; these must still match.
  val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern)

  if (negatedNewlines.isEmpty) {
    // Both `\r` and `\n` are excluded by the class itself (this includes
    // `[^\r\n]` and `[^\\r\\n]`), so the plain negated class is already right.
    RegexCharacterClass(negated = true, ListBuffer(components: _*))
  } else {
    RegexGroup(capture = false,
      RegexChoice(
        RegexCharacterClass(negated = false,
          characters = ListBuffer(negatedNewlines.map(RegexChar): _*)),
        RegexCharacterClass(negated = true, ListBuffer(components: _*))))
  }
}

private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {
regex match {
Expand Down Expand Up @@ -814,10 +821,53 @@ class CudfRegexTranspiler(mode: RegexMode) {
rewrite(RegexChar('$'), replacement, previous)
}
case 's' | 'S' =>
// whitespace characters
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'h' | 'H' =>
// horizontal whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'),
RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
)
chars += RegexEscaped('t')
chars += RegexCharacterRange('\u2000', '\u200a')
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'v' | 'V' =>
// vertical whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
)
chars ++= Seq('n', 'f', 'r').map(RegexEscaped)
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'R' =>
// linebreak sequence
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Linebreak matcher"
val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A')))
val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent](
RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'),
RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
))
RegexGroup(true, RegexChoice(l, r))
case _ =>
regex
}
Expand Down Expand Up @@ -859,43 +909,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
}

if (negated) {
// There are differences between cuDF and Java handling of newlines
// for negative character matches. The expression `[^a]` will match
// `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with
// `(?:[\r\n]|[^a])`. We also have to take into account whether any
// newline characters are included in the character range.
//
// Examples:
//
// `[^a]` => `(?:[\r\n]|[^a])`
// `[^a\r]` => `(?:[\n]|[^a])`
// `[^a\n]` => `(?:[\r]|[^a])`
// `[^a\r\n]` => `[^a]`
// `[^\r\n]` => `[^\r\n]`

val linefeedCharsInPattern = components.flatMap {
case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch)
case RegexEscaped(ch) if ch == 'n' => Seq('\n')
case RegexEscaped(ch) if ch == 'r' => Seq('\r')
case _ => Seq.empty
}

val onlyLinefeedChars = components.length == linefeedCharsInPattern.length

val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct)

if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) {
// special case for `[^\r\n]` and `[^\\r\\n]`
RegexCharacterClass(negated = true, ListBuffer(components: _*))
} else if (negatedNewlines.isEmpty) {
RegexCharacterClass(negated = true, ListBuffer(components: _*))
} else {
RegexGroup(capture = false,
RegexChoice(
RegexCharacterClass(negated = false,
characters = ListBuffer(negatedNewlines.map(RegexChar): _*)),
RegexCharacterClass(negated = true, ListBuffer(components: _*))))
}
negateCharacterClass(components)
} else {
RegexCharacterClass(negated, ListBuffer(components: _*))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z\\V",
"a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$[^\u000B\u0085\u2028\u2029\n\f\r]")
}

test("compare CPU and GPU: character range including unescaped + and -") {
Expand Down

0 comments on commit 7c8071b

Please sign in to comment.