From 318807ed708916c18970c1b0b269ea7d979c399c Mon Sep 17 00:00:00 2001 From: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:46:40 -0700 Subject: [PATCH] Handle escaping the dangling right ] and right } in the regexp transpiler (#9239) * Handle escaping the dangling right ] and right } automatically in the transpiler to ensure compatibility with cudf Signed-off-by: Navin Kumar * Fix a syntax error in pytest that snuck in Signed-off-by: Navin Kumar * fix scalatest failures Signed-off-by: Navin Kumar --------- Signed-off-by: Navin Kumar --- .../src/main/python/regexp_test.py | 24 +++++++++++++++++++ .../com/nvidia/spark/rapids/RegexParser.scala | 6 ++++- .../rapids/RegularExpressionParserSuite.scala | 4 ++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 3d3ce6ce4af..63a96d0fd37 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -122,6 +122,25 @@ def test_split_re_no_limit(): 'split(a, "^[o]")'), conf=_regexp_conf) +def test_split_with_dangling_brackets(): + data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|{}]{1,2}){1,7}') \ + .with_special_case('boo.and.foo') \ + .with_special_case('boo?and?foo') \ + .with_special_case('boo+and+foo') \ + .with_special_case('boo^and^foo') \ + .with_special_case('boo$and$foo') \ + .with_special_case('boo|and|foo') \ + .with_special_case('boo{and}foo') \ + .with_special_case('boo$|and$|foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[a-z]]")', + 'split(a, "[boo]]]")', + 'split(a, "[foo]}")', + 'split(a, "[foo]}}")'), + conf=_regexp_conf) + + def test_split_optimized_no_re(): data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|{}]{1,2}){1,7}') \ .with_special_case('boo.and.foo') \ @@ -134,6 +153,11 @@ def test_split_optimized_no_re(): .with_special_case('boo$|and$|foo') assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "]")', + 'split(a, "]]")', + 'split(a, "}")', + 'split(a, "}}")', + 'split(a, ",")', 'split(a, "\\\\.")', 'split(a, "\\\\?")', 'split(a, "\\\\+")', diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 985c4efda40..acaa90b0f12 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -147,6 +147,10 @@ class RegexParser(pattern: String) { parseGroup() case '[' => parseCharacterClass() + case ']' => + RegexEscaped(']') + case '}' => + RegexEscaped('}') case '\\' => parseEscapedCharacter() case '\u0000' => @@ -1857,7 +1861,7 @@ sealed case class RegexChar(ch: Char) extends RegexCharacterClassComponent { override def toRegexString: String = ch.toString } -sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ +sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent { def this(a: Char, position: Int) { this(a) this.position = Some(position) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 3c4f91d0816..00465308e6f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -52,7 +52,7 @@ class RegularExpressionParserSuite extends AnyFunSuite { test("not a quantifier") { assert(parse("{1}") === RegexSequence(ListBuffer( - RegexChar('{'), RegexChar('1'),RegexChar('}')))) + RegexChar('{'), RegexChar('1'),RegexEscaped('}')))) } test("nested repetition") { @@ -109,7 +109,7 @@ class RegularExpressionParserSuite extends AnyFunSuite { assert(parse("[a]]") === RegexSequence(ListBuffer( RegexCharacterClass(negated = false, - ListBuffer(RegexChar('a'))), RegexChar(']')))) + ListBuffer(RegexChar('a'))), RegexEscaped(']')))) } test("escaped brackets") {