Skip to content

Commit

Permalink
Add support for POSIX predefined character classes (#5692)
Browse files Browse the repository at this point in the history
* Add POSIX character classes

Signed-off-by: Anthony Chang <antchang@nvidia.com>

* Update fuzz tests with all punctuation, fix transpilation bug with \z

Signed-off-by: Anthony Chang <antchang@nvidia.com>

* Address feedback

Signed-off-by: Anthony Chang <antchang@nvidia.com>

* Remove _ from fuzz tests

Signed-off-by: Anthony Chang <antchang@nvidia.com>

* Add documentation for null character in \p{Cntrl}

Signed-off-by: Anthony Chang <antchang@nvidia.com>
  • Loading branch information
anthony-chang authored Jun 7, 2022
1 parent 7eed898 commit 99e4c30
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,9 @@ These are the known edge cases where running on the GPU will produce different r
- Regular expressions that contain an end of line anchor '$' or end of string anchor '\Z' or '\z' immediately
next to a newline or a repetition that produces zero or more results
([#5610](https://github.com/NVIDIA/spark-rapids/pull/5610))`
- The character class `\p{ASCII}` matches only `[\x01-\x7F]` as opposed to Java's definition which matches `[\x00-\x7F]`,
since null characters are not currently supported. Similarily, `\p{Cntrl}` matches only `[\x01-\x1F\x7F]` as
opposed to Java's `[\x00-\x1F\x7F]`

The following regular expression patterns are not yet supported on the GPU and will fall back to the CPU.

Expand Down
20 changes: 20 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,26 @@ def test_regexp_replace_word():
),
conf=_regexp_conf)

def test_predefined_character_classes():
gen = mk_str_gen('[a-zA-Z]{0,2}[\r\n!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]{0,2}[0-9]{0,2}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp_replace(a, "\\\\p{Lower}", "x")',
'regexp_replace(a, "\\\\p{Upper}", "x")',
'regexp_replace(a, "\\\\p{ASCII}", "x")',
'regexp_replace(a, "\\\\p{Alpha}", "x")',
'regexp_replace(a, "\\\\p{Digit}", "x")',
'regexp_replace(a, "\\\\p{Alnum}", "x")',
'regexp_replace(a, "\\\\p{Punct}", "x")',
'regexp_replace(a, "\\\\p{Graph}", "x")',
'regexp_replace(a, "\\\\p{Print}", "x")',
'regexp_replace(a, "\\\\p{Blank}", "x")',
'regexp_replace(a, "\\\\p{Cntrl}", "x")',
'regexp_replace(a, "\\\\p{XDigit}", "x")',
'regexp_replace(a, "\\\\p{Space}", "x")',
),
conf=_regexp_conf)

def test_rlike():
gen = mk_str_gen('[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,70 @@ class RegexParser(pattern: String) {
parseHexDigit
case '0' =>
parseOctalDigit
case 'p' =>
consumeExpected(ch)
parsePredefinedClass
case other =>
throw new RegexUnsupportedException(
s"invalid or unsupported escape character '$other'", Some(pos - 1))
}
}
}

private def parsePredefinedClass: RegexCharacterClass = {
consumeExpected('{')
val start = pos
while(!eof() && pattern.charAt(pos).isLetter) {
pos += 1
}
val className = pattern.substring(start, pos)
def getCharacters(className: String): ListBuffer[RegexCharacterClassComponent] = {
// Character lists from here:
// https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
className match {
case "Lower" =>
ListBuffer(RegexCharacterRange(RegexChar('a'), RegexChar('z')))
case "Upper" =>
ListBuffer(RegexCharacterRange(RegexChar('A'), RegexChar('Z')))
case "ASCII" =>
// should be \u0000-\u007f but we do not support the null terminator \u0000
ListBuffer(RegexCharacterRange(RegexChar('\u0001'), RegexChar('\u007f')))
case "Alpha" =>
ListBuffer(getCharacters("Lower"), getCharacters("Upper")).flatten
case "Digit" =>
ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')))
case "Alnum" =>
ListBuffer(getCharacters("Alpha"), getCharacters("Digit")).flatten
case "Punct" =>
val res:ListBuffer[RegexCharacterClassComponent] =
ListBuffer("!\"#$%&'()*+,-./:;<=>?@\\^_`{|}~".map(RegexChar): _*)
res ++= ListBuffer(RegexEscaped('['), RegexEscaped(']'))
case "Graph" =>
ListBuffer(getCharacters("Alnum"), getCharacters("Punct")).flatten
case "Print" =>
val res = getCharacters("Graph")
res += RegexChar('\u0020')
case "Blank" =>
ListBuffer(RegexChar(' '), RegexEscaped('t'))
case "Cntrl" =>
// should be \u0001-\u001f but we do not support the null terminator \u0000
ListBuffer(RegexCharacterRange(RegexChar('\u0001'), RegexChar('\u001f')),
RegexChar('\u007f'))
case "XDigit" =>
ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')),
RegexCharacterRange(RegexChar('a'), RegexChar('f')),
RegexCharacterRange(RegexChar('A'), RegexChar('F')))
case "Space" =>
ListBuffer(" \t\n\u000B\f\r".map(RegexChar): _*)
case _ =>
throw new RegexUnsupportedException(
s"predefined character class ${className} is not supported", Some(pos))
}
}
consumeExpected('}')
RegexCharacterClass(negated = false, characters = getCharacters(className))
}

private def isHexDigit(ch: Char): Boolean = ch.isDigit ||
(ch >= 'a' && ch <= 'f') ||
(ch >= 'A' && ch <= 'F')
Expand Down Expand Up @@ -1053,7 +1110,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
val components: Seq[RegexCharacterClassComponent] = characters
.map {
case r @ RegexChar(ch) if "^$".contains(ch) => r
case r @ RegexChar(ch) if "^$.".contains(ch) => r
case ch => rewrite(ch, replacement, None) match {
case valid: RegexCharacterClassComponent => valid
case _ =>
Expand Down Expand Up @@ -1147,6 +1204,10 @@ class CudfRegexTranspiler(mode: RegexMode) {
RegexRepetition(lineTerminatorMatcher(Set(ch), true, false),
SimpleQuantifier('?')), RegexChar('$')))))
popBackrefIfNecessary(false)
case RegexEscaped('z') =>
// \Z\z or $\z transpiles to $
r(j) = RegexChar('$')
popBackrefIfNecessary(false)
case RegexEscaped('b') | RegexEscaped('B') =>
throw new RegexUnsupportedException(
"regex sequences with \\b or \\B not supported around $")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {

test("transpile \\z") {
doTranspileTest("abc\\z", "abc$")
doTranspileTest("abc\\Z\\z", "abc$")
doTranspileTest("abc$\\z", "abc$")
}

test("transpile $") {
Expand All @@ -405,6 +407,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
}

test("transpile predefined character classes") {
doTranspileTest("\\p{Lower}", "[a-z]")
doTranspileTest("\\p{Alpha}", "[a-zA-Z]")
doTranspileTest("\\p{Alnum}", "[a-zA-Z0-9]")
doTranspileTest("\\p{Punct}", "[!\"#$%&'()*+,\\-./:;<=>?@\\^_`{|}~\\[\\]]")
doTranspileTest("\\p{Print}", "[a-zA-Z0-9!\"#$%&'()*+,\\-./:;<=>?@\\^_`{|}~\\[\\]\u0020]")
}

test("compare CPU and GPU: character range including unescaped + and -") {
val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]")
val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
Expand All @@ -423,9 +433,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

// Patterns are generated from these characters
// '&' is absent due to https://github.com/NVIDIA/spark-rapids/issues/5655
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},-./;:!^$#%&*+?<=>@\"'~`" +
"abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"

private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON

Expand Down

0 comments on commit 99e4c30

Please sign in to comment.