Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for POSIX predefined character classes #5692

Merged
20 changes: 20 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,26 @@ def test_regexp_replace_word():
),
conf=_regexp_conf)

def test_predefined_character_classes():
    # Compares CPU and GPU results of regexp_replace for each supported POSIX
    # predefined character class, replacing every match with "x".
    #
    # Each generated string is built from up to two ASCII letters, then up to two
    # characters drawn from CR, LF and ASCII punctuation, then up to two digits.
    # NOTE(review): the generator pattern contains no space or tab, so \p{Blank}
    # looks like it never matches here, and it does not appear to produce a
    # literal backslash either - confirm whether that coverage is wanted.
    gen = mk_str_gen('[a-zA-Z]{0,2}[\r\n!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]{0,2}[0-9]{0,2}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, gen).selectExpr(
            'regexp_replace(a, "\\\\p{Lower}", "x")',
            'regexp_replace(a, "\\\\p{Upper}", "x")',
            'regexp_replace(a, "\\\\p{ASCII}", "x")',
            'regexp_replace(a, "\\\\p{Alpha}", "x")',
            'regexp_replace(a, "\\\\p{Digit}", "x")',
            'regexp_replace(a, "\\\\p{Alnum}", "x")',
            'regexp_replace(a, "\\\\p{Punct}", "x")',
            'regexp_replace(a, "\\\\p{Graph}", "x")',
            'regexp_replace(a, "\\\\p{Print}", "x")',
            'regexp_replace(a, "\\\\p{Blank}", "x")',
            'regexp_replace(a, "\\\\p{Cntrl}", "x")',
            'regexp_replace(a, "\\\\p{XDigit}", "x")',
            'regexp_replace(a, "\\\\p{Space}", "x")',
        ),
        conf=_regexp_conf)

def test_rlike():
gen = mk_str_gen('[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,69 @@ class RegexParser(pattern: String) {
parseHexDigit
case '0' =>
parseOctalDigit
case 'p' =>
consumeExpected(ch)
parsePredefinedClass
case other =>
throw new RegexUnsupportedException(
s"invalid or unsupported escape character '$other'", Some(pos - 1))
}
}
}

/**
 * Parses the `{Name}` part of a `\p{Name}` POSIX predefined character class and
 * expands it into an explicit, non-negated character class.
 *
 * Invoked once the `p` of the escape has been consumed. The name is the run of
 * letters following the opening brace; the closing brace is consumed before the
 * name is expanded.
 *
 * Supported names: Lower, Upper, ASCII, Alpha, Digit, Alnum, Punct, Graph, Print,
 * Blank, Cntrl, XDigit and Space.
 *
 * @return a character class holding the members of the named class
 * @throws RegexUnsupportedException if the name is not one of the supported classes
 */
private def parsePredefinedClass: RegexCharacterClass = {
  consumeExpected('{')
  // the class name is the run of letters that follows the opening brace
  val start = pos
  while (!eof() && pattern.charAt(pos).isLetter) {
    pos += 1
  }
  val className = pattern.substring(start, pos)

  // Expands a class name into its members. Composite classes are built by
  // recursing on their parts, so the order of the members is deterministic.
  def getCharacters(name: String): ListBuffer[RegexCharacterClassComponent] = {
    // Character lists from here:
    // https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
    name match {
      case "Lower" =>
        ListBuffer(RegexCharacterRange(RegexChar('a'), RegexChar('z')))
      case "Upper" =>
        ListBuffer(RegexCharacterRange(RegexChar('A'), RegexChar('Z')))
      case "ASCII" =>
        // should be \u0000-\u007f but we do not support the null terminator \u0000
        ListBuffer(RegexCharacterRange(RegexChar('\u0001'), RegexChar('\u007f')))
      case "Alpha" =>
        ListBuffer(getCharacters("Lower"), getCharacters("Upper")).flatten
      case "Digit" =>
        ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')))
      case "Alnum" =>
        ListBuffer(getCharacters("Alpha"), getCharacters("Digit")).flatten
      case "Punct" =>
        // '[' and ']' are appended as escaped components rather than plain chars.
        // NOTE(review): the backslash is added as a plain RegexChar directly
        // before '^'; confirm the transpiled class still matches a literal
        // backslash rather than reading the pair as an escaped caret.
        val res: ListBuffer[RegexCharacterClassComponent] =
          ListBuffer("!\"#$%&'()*+,-./:;<=>?@\\^_`{|}~".map(RegexChar): _*)
        res ++= ListBuffer(RegexEscaped('['), RegexEscaped(']'))
      case "Graph" =>
        ListBuffer(getCharacters("Alnum"), getCharacters("Punct")).flatten
      case "Print" =>
        // Graph plus the space character
        val res = getCharacters("Graph")
        res += RegexChar('\u0020')
      case "Blank" =>
        ListBuffer(RegexChar(' '), RegexEscaped('t'))
      case "Cntrl" =>
        // starts at 0x01 rather than 0x00, as the ASCII case above does
        ListBuffer(RegexCharacterRange(RegexChar('\u0001'), RegexChar('\u001f')),
          RegexChar('\u007f'))
      case "XDigit" =>
        ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')),
          RegexCharacterRange(RegexChar('a'), RegexChar('f')),
          RegexCharacterRange(RegexChar('A'), RegexChar('F')))
      case "Space" =>
        ListBuffer(" \t\n\u000B\f\r".map(RegexChar): _*)
      case _ =>
        throw new RegexUnsupportedException(
          s"predefined character class ${name} is not supported", Some(pos))
    }
  }

  consumeExpected('}')
  RegexCharacterClass(negated = false, characters = getCharacters(className))
}

/**
 * Tests whether `ch` is an ASCII hexadecimal digit: `0-9`, `a-f` or `A-F`.
 *
 * The decimal range is checked explicitly instead of with `Char.isDigit`,
 * because `isDigit` is also true for non-ASCII Unicode digits, which are not
 * hexadecimal digits.
 *
 * @param ch the character to test
 * @return true if `ch` is one of the 22 ASCII hexadecimal digit characters
 */
private def isHexDigit(ch: Char): Boolean =
  (ch >= '0' && ch <= '9') ||
  (ch >= 'a' && ch <= 'f') ||
  (ch >= 'A' && ch <= 'F')
Expand Down Expand Up @@ -1053,7 +1109,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
val components: Seq[RegexCharacterClassComponent] = characters
.map {
case r @ RegexChar(ch) if "^$".contains(ch) => r
case r @ RegexChar(ch) if "^$.".contains(ch) => r
case ch => rewrite(ch, replacement, None) match {
case valid: RegexCharacterClassComponent => valid
case _ =>
Expand Down Expand Up @@ -1147,6 +1203,10 @@ class CudfRegexTranspiler(mode: RegexMode) {
RegexRepetition(lineTerminatorMatcher(Set(ch), true, false),
SimpleQuantifier('?')), RegexChar('$')))))
popBackrefIfNecessary(false)
case RegexEscaped('z') =>
NVnavkumar marked this conversation as resolved.
Show resolved Hide resolved
// \Z\z or $\z transpiles to $
r(j) = RegexChar('$')
popBackrefIfNecessary(false)
case RegexEscaped('b') | RegexEscaped('B') =>
throw new RegexUnsupportedException(
"regex sequences with \\b or \\B not supported around $")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {

test("transpile \\z") {
  // (input pattern, expected transpiled pattern): \z alone, after \Z, and after $
  // are all expected to come out as a single trailing $
  val cases = Seq(
    "abc\\z" -> "abc$",
    "abc\\Z\\z" -> "abc$",
    "abc$\\z" -> "abc$")
  cases.foreach { case (javaPattern, expected) =>
    doTranspileTest(javaPattern, expected)
  }
}

test("transpile $") {
Expand All @@ -391,6 +393,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$")
}

test("transpile predefined character classes") {
  // (input pattern, expected transpiled pattern). Only a subset of the classes
  // is listed: Print is composed from the classes above it, so these cases act
  // as a sanity check of that composition.
  val cases = Seq(
    "\\p{Lower}" -> "[a-z]",
    "\\p{Alpha}" -> "[a-zA-Z]",
    "\\p{Alnum}" -> "[a-zA-Z0-9]",
    "\\p{Punct}" -> "[!\"#$%&'()*+,\\-./:;<=>?@\\^_`{|}~\\[\\]]",
    "\\p{Print}" -> "[a-zA-Z0-9!\"#$%&'()*+,\\-./:;<=>?@\\^_`{|}~\\[\\]\u0020]")
  cases.foreach { case (javaPattern, expected) =>
    doTranspileTest(javaPattern, expected)
  }
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional that only a subset are tested here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this one is mostly a sanity check since the \p{Print} transpilation recursively uses the definition of the 4 \p{} classes above it.
I tested the full set in the integration tests.


test("compare CPU and GPU: character range including unescaped + and -") {
val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]")
val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
Expand All @@ -409,9 +419,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

// Alphabet from which the fuzz-test patterns are generated. It covers regex
// metacharacters, ASCII punctuation (including '&'), a few letters and digits,
// whitespace characters, and the letters used by escapes such as \s, \d, \w,
// \B, \z and \Z.
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},-./;:!^$#%&*+?<=>@\"'~_`" +
  "abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"

// The find tests use the common alphabet unchanged.
private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON

Expand Down Expand Up @@ -750,6 +759,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
gpuReplace(cudfPattern, replaceString.get, input)
} catch {
case e: CudfException =>
println(e)
anthony-chang marked this conversation as resolved.
Show resolved Hide resolved
fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}, " +
s"original: ${toReadableString(javaPattern)}, " +
s"replacement: ${toReadableString(replaceString.get)}", e)
Expand Down