Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with character class immediately following a string anchor #5527

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 87 additions & 73 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,40 +342,7 @@ class RegexParser(pattern: String) {
// string anchors
consumeExpected(ch)
RegexEscaped(ch)
case 'h' | 'H' =>
// horizontal whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'),
RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
)
chars += RegexEscaped('t')
chars += RegexCharacterRange('\u2000', '\u200a')
consumeExpected(ch)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case 'v' | 'V' =>
// vertical whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
)
chars ++= Seq('n', 'f', 'r').map(RegexEscaped)
consumeExpected(ch)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case 'R' =>
// linebreak sequence
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Linebreak matcher"
val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A')))
val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent](
RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'),
RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
))
consumeExpected(ch)
RegexGroup(true, RegexChoice(l, r))
case 's' | 'S' | 'd' | 'D' | 'w' | 'W' =>
case 's' | 'S' | 'd' | 'D' | 'w' | 'W' | 'v' | 'V' | 'h' | 'H' | 'R' =>
// meta sequences
consumeExpected(ch)
RegexEscaped(ch)
Expand Down Expand Up @@ -647,7 +614,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
private def isSupportedRepetitionBase(e: RegexAST): Boolean = {
e match {
case RegexEscaped(ch) => ch match {
case 'd' | 'w' | 'h' | 'H' | 'v' | 'V' => true
case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => true
case _ => false
}

Expand Down Expand Up @@ -696,6 +663,46 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

/**
 * Build the cuDF-compatible equivalent of the negated character class
 * `[^components]`.
 *
 * There are differences between cuDF and Java handling of newlines for
 * negative character matches. The expression `[^a]` will match `\r` and `\n`
 * in Java but not in cuDF, so `[^a]` is replaced with `(?:[\r\n]|[^a])`.
 * A newline character that is itself a member of the class must remain
 * excluded, so it is left out of the added alternative.
 *
 * Examples:
 *
 * `[^a]`     => `(?:[\r\n]|[^a])`
 * `[^a\r]`   => `(?:[\n]|[^a\r])`
 * `[^a\n]`   => `(?:[\r]|[^a\n])`
 * `[^a\r\n]` => `[^a\r\n]`
 * `[^\r\n]`  => `[^\r\n]`
 * `[^\n\n]`  => `(?:[\r]|[^\n\n])`
 *
 * @param components the members of the character class being negated
 * @return the plain negated class when both `\r` and `\n` are members of it,
 *         otherwise a non-capturing group that also matches whichever of
 *         `\r` / `\n` the class does not exclude
 */
private def negateCharacterClass(components: Seq[RegexCharacterClassComponent]): RegexAST = {
  // Distinct newline characters named by the class, either literally or via an
  // escape. De-duplicating matters: a class such as `[^\n\n]` excludes only
  // `\n`, so `\r` still has to be matched explicitly.
  val excludedNewlines = components.collect {
    case RegexChar(ch) if ch == '\n' || ch == '\r' => ch
    case RegexEscaped('n') => '\n'
    case RegexEscaped('r') => '\r'
  }.distinct

  // Newlines that Java would match with this negated class but cuDF would not
  val negatedNewlines = Seq('\r', '\n').diff(excludedNewlines)

  if (negatedNewlines.isEmpty) {
    // Both `\r` and `\n` are excluded by the class itself (this covers
    // `[^\r\n]` and `[^\\r\\n]`), so no extra alternative is required
    RegexCharacterClass(negated = true, ListBuffer(components: _*))
  } else {
    RegexGroup(capture = false,
      RegexChoice(
        RegexCharacterClass(negated = false,
          characters = ListBuffer(negatedNewlines.map(RegexChar): _*)),
        RegexCharacterClass(negated = true, ListBuffer(components: _*))))
  }
}

private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {
regex match {
Expand Down Expand Up @@ -821,10 +828,53 @@ class CudfRegexTranspiler(mode: RegexMode) {
rewrite(RegexChar('$'), replacement, previous)
}
case 's' | 'S' =>
// whitespace characters
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'h' | 'H' =>
// horizontal whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'),
RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
)
chars += RegexEscaped('t')
chars += RegexCharacterRange('\u2000', '\u200a')
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'v' | 'V' =>
// vertical whitespace
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Predefined character classes"
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
)
chars ++= Seq('n', 'f', 'r').map(RegexEscaped)
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
RegexCharacterClass(negated = false, characters = chars)
}
case 'R' =>
// linebreak sequence
// see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
// under "Linebreak matcher"
val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A')))
val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent](
RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'),
RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
))
RegexGroup(true, RegexChoice(l, r))
case _ =>
regex
}
Expand Down Expand Up @@ -859,43 +909,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
})

if (negated) {
// There are differences between cuDF and Java handling of newlines
// for negative character matches. The expression `[^a]` will match
// `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with
// `(?:[\r\n]|[^a])`. We also have to take into account whether any
// newline characters are included in the character range.
//
// Examples:
//
// `[^a]` => `(?:[\r\n]|[^a])`
// `[^a\r]` => `(?:[\n]|[^a])`
// `[^a\n]` => `(?:[\r]|[^a])`
// `[^a\r\n]` => `[^a]`
// `[^\r\n]` => `[^\r\n]`

val linefeedCharsInPattern = components.flatMap {
case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch)
case RegexEscaped(ch) if ch == 'n' => Seq('\n')
case RegexEscaped(ch) if ch == 'r' => Seq('\r')
case _ => Seq.empty
}

val onlyLinefeedChars = components.length == linefeedCharsInPattern.length

val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct)

if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) {
// special case for `[^\r\n]` and `[^\\r\\n]`
RegexCharacterClass(negated = true, ListBuffer(components: _*))
} else if (negatedNewlines.isEmpty) {
RegexCharacterClass(negated = true, ListBuffer(components: _*))
} else {
RegexGroup(capture = false,
RegexChoice(
RegexCharacterClass(negated = false,
characters = ListBuffer(negatedNewlines.map(RegexChar): _*)),
RegexCharacterClass(negated = true, ListBuffer(components: _*))))
}
negateCharacterClass(components)
} else {
RegexCharacterClass(negated, ListBuffer(components: _*))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
doTranspileTest("a\\Z\\V",
"a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$[^\u000B\u0085\u2028\u2029\n\f\r]")
}

test("compare CPU and GPU: character range including unescaped + and -") {
Expand Down