From 699b88e0497923011deda09ce69a65bccd96dcea Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Mon, 7 Feb 2022 15:22:32 -0800 Subject: [PATCH 1/8] in this case, throw the appropriate exception since this is not supported by cuDF and should fallback to CPU --- .../com/nvidia/spark/rapids/RegexParser.scala | 16 ++++++++++++++++ .../rapids/RegularExpressionParserSuite.scala | 7 +++++++ .../RegularExpressionTranspilerSuite.scala | 7 +++++++ 3 files changed, 30 insertions(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 002c8b3f04b..3a5295dea65 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -610,6 +610,22 @@ class CudfRegexTranspiler(replace: Boolean) { // example: "a*+" throw new RegexUnsupportedException(nothingToRepeat) + case (RegexGroup(capture, term), SimpleQuantifier(ch)) if "+*".contains(ch) => + // example: "(3?)+" + def isSimpleRepetition(e: RegexAST):Boolean = { + e match { + case RegexRepetition(term, quantifier) => + term.isInstanceOf[RegexCharacterClassComponent] + case RegexSequence(parts) if parts.length == 1 => + isSimpleRepetition(parts.last) + case _ => false + } + } + val tr = rewrite(term) + if (isSimpleRepetition(tr)) { + throw new RegexUnsupportedException(nothingToRepeat) + } + RegexRepetition(RegexGroup(capture, tr), quantifier) case _ => RegexRepetition(rewrite(base), quantifier) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 389fe7800af..6593fcf8a61 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -140,6 +140,13 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer(RegexOctalChar("47"), RegexChar('7')))) } + test("repetition with group containing simple repetition") { + assert(parse("(3?)+") === + RegexSequence(ListBuffer(RegexRepetition(RegexGroup(capture = true, + RegexSequence(ListBuffer(RegexRepetition(RegexChar('3'), + SimpleQuantifier('?'))))),SimpleQuantifier('+'))))) + } + test("group containing choice with repetition") { assert(parse("(\t+|a)") == RegexSequence(ListBuffer( RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index cef033d53f5..b33e42dc1de 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -119,6 +119,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { ) } + test("cuDF does not support single repetition both inside and outside of capture groups") { + // see https://github.com/NVIDIA/spark-rapids/issues/4487 + val patterns = Seq("(3?)+", "(3?)*", "(3*)+") + patterns.foreach(pattern => + assertUnsupported(pattern, replace = false, "nothing to repeat")) + } + test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { From a27dddfd854c9f22d2598127b6e3191008019a6b Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Mon, 7 Feb 2022 15:49:12 -0800 Subject: [PATCH 2/8] fix indent Signed-off-by: Navin Kumar --- .../src/main/scala/com/nvidia/spark/rapids/RegexParser.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 3a5295dea65..b83d27d21b6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -623,7 +623,8 @@ class CudfRegexTranspiler(replace: Boolean) { } val tr = rewrite(term) if (isSimpleRepetition(tr)) { - throw new RegexUnsupportedException(nothingToRepeat) + // perhaps we could rewrite it here + throw new RegexUnsupportedException(nothingToRepeat) } RegexRepetition(RegexGroup(capture, tr), quantifier) case _ => From 42ba3c87e8747474b8acb3636d8daa0639888cd5 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 8 Feb 2022 17:56:43 -0800 Subject: [PATCH 3/8] Updated tests to time out to handle patterns that hang in cuDF Signed-off-by: Navin Kumar --- .../RegularExpressionTranspilerSuite.scala | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index b33e42dc1de..a51efc851dd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -15,6 +15,11 @@ */ package com.nvidia.spark.rapids +import scala.concurrent._ +import scala.concurrent.duration._ +import ExecutionContext.Implicits.global +import scala.language.postfixOps + import java.util.regex.Pattern import scala.collection.mutable.{HashSet, ListBuffer} @@ -121,7 +126,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support single repetition both inside and outside of capture groups") { // see https://github.com/NVIDIA/spark-rapids/issues/4487 - val patterns = Seq("(3?)+", "(3?)*", "(3*)+") + val patterns = Seq("(3?)+", "(3?)*", "(3*)+", "((3?))+") patterns.foreach(pattern => assertUnsupported(pattern, replace = false, "nothing to repeat")) } @@ -411,10 +416,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val cpu = cpuContains(javaPattern, input) val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(javaPattern) val gpu = try { - gpuContains(cudfPattern, input) + val result = Future { + gpuContains(cudfPattern, input) + } + Await.result(result, 1 second) } catch { case e: CudfException => fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}", e) + case e: java.util.concurrent.TimeoutException => + fail(s"cuDF timed out regexp_find: " + + s"javaPattern=${toReadableString(javaPattern)}, " + + s"cudfPattern=${toReadableString(cudfPattern)}" + ) } for (i <- input.indices) { if (cpu(i) != gpu(i)) { @@ -434,10 +447,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val cpu = cpuReplace(javaPattern, input) val cudfPattern = new CudfRegexTranspiler(replace = true).transpile(javaPattern) val gpu = try { - gpuReplace(cudfPattern, input) + val result = Future { + gpuReplace(cudfPattern, input) + } + Await.result(result, 10 second) } catch { case e: CudfException => fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}", e) + case e: java.util.concurrent.TimeoutException => + fail(s"cuDF timed out regexp_replace: " + + s"javaPattern=${toReadableString(javaPattern)}, " + + s"cudfPattern=${toReadableString(cudfPattern)}" + ) } for (i <- input.indices) { if (cpu(i) != gpu(i)) { From 2cfa4cec7ac8c42cdc140c16b983465af32bf067 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Tue, 15 Feb 2022 13:22:27 -0800 Subject: [PATCH 4/8] nail down more repetition based patterns that need to fallback to CPU in order to avoid hanging on the GPU Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 78 ++++++++++--------- .../rapids/RegularExpressionParserSuite.scala | 8 ++ .../RegularExpressionTranspilerSuite.scala | 37 +++------ 3 files changed, 62 insertions(+), 61 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 8cce0592b1e..cb94804f806 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,6 +467,39 @@ class CudfRegexTranspiler(mode: RegexMode) { cudfRegex.toRegexString } + private def isSomeRepetition(f: ListBuffer[RegexAST] => Boolean)(e: RegexAST): Boolean = { + e match { + case RegexRepetition(_, _) => true + case RegexGroup(_, term) => isRepetition(term) + case RegexSequence(parts) if f(parts) => isRepetition(parts.last) + case _ => false + } + } + + private def isRepetition = isSomeRepetition(_.nonEmpty)(_) + private def isNestedRepetition = isSomeRepetition(_.length == 1)(_) + + private def isSupportedRepetitionBase(e: RegexAST): Boolean = { + e match { + case RegexEscaped(ch) if ch != 'd' && ch != 'w' => // example: "\B?" + false + + case RegexChar(a) if "$^".contains(a) => + // example: "$*" + false + + case RegexRepetition(_, _) => + // example: "a*+" + false + + case RegexSequence(parts) => + parts.forall(x => isSupportedRepetitionBase(x)) + + case _ => true + } + } + + private def rewrite(regex: RegexAST): RegexAST = { regex match { @@ -628,37 +661,18 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "regexp_replace on GPU does not support repetition with ? or *") - case (RegexEscaped(ch), _) if ch != 'd' && ch != 'w' => - // example: "\B?" - throw new RegexUnsupportedException(nothingToRepeat) - - case (RegexChar(a), _) if "$^".contains(a) => - // example: "$*" + case (RegexGroup(_, term), SimpleQuantifier(ch)) + if "+*".contains(ch) && (!isSupportedRepetitionBase(term) || + isNestedRepetition(term)) => throw new RegexUnsupportedException(nothingToRepeat) - - case (RegexRepetition(_, _), _) => - // example: "a*+" + case (RegexGroup(_, term), QuantifierVariableLength(_,None)) + if !isSupportedRepetitionBase(term) || isNestedRepetition(term) => + // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) - - case (RegexGroup(capture, term), SimpleQuantifier(ch)) if "+*".contains(ch) => - // example: "(3?)+" - def isSimpleRepetition(e: RegexAST):Boolean = { - e match { - case RegexRepetition(term, quantifier) => - term.isInstanceOf[RegexCharacterClassComponent] - case RegexSequence(parts) if parts.length == 1 => - isSimpleRepetition(parts.last) - case _ => false - } - } - val tr = rewrite(term) - if (isSimpleRepetition(tr)) { - // perhaps we could rewrite it here - throw new RegexUnsupportedException(nothingToRepeat) - } - RegexRepetition(RegexGroup(capture, tr), quantifier) - case _ => + case _ if isSupportedRepetitionBase(base) => RegexRepetition(rewrite(base), quantifier) + case _ => + throw new RegexUnsupportedException(nothingToRepeat) } @@ -667,14 +681,6 @@ class CudfRegexTranspiler(mode: RegexMode) { val rr = rewrite(r) // cuDF does not support repetition on one side of a choice, such as "a*|a" - def isRepetition(e: RegexAST): Boolean = { - e match { - case RegexRepetition(_, _) => true - case RegexGroup(_, term) => isRepetition(term) - case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last) - case _ => false - } - } if (isRepetition(ll) || isRepetition(rr)) { throw new RegexUnsupportedException(nothingToRepeat) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 82c21631ee1..06c642ba561 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -147,6 +147,14 @@ class RegularExpressionParserSuite extends FunSuite { SimpleQuantifier('?'))))),SimpleQuantifier('+'))))) } + test("repetition with group containing escape character") { + assert(parse(raw"(\A)+") === + RegexSequence(ListBuffer(RegexRepetition(RegexGroup(capture = true, + RegexSequence(ListBuffer(RegexEscaped('A')))), + SimpleQuantifier('+')))) + ) + } + test("group containing choice with repetition") { assert(parse("(\t+|a)") == RegexSequence(ListBuffer( RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 5c5e153acfb..09b7d25426f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -15,10 +15,6 @@ */ package com.nvidia.spark.rapids -import scala.concurrent._ -import scala.concurrent.duration._ -import ExecutionContext.Implicits.global -import scala.language.postfixOps import java.util.regex.Pattern @@ -128,9 +124,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // see https://github.com/NVIDIA/spark-rapids/issues/4487 val patterns = Seq("(3?)+", "(3?)*", "(3*)+", "((3?))+") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, "nothing to repeat")) + assertUnsupported(pattern, RegexReplaceMode, "nothing to repeat")) } + test("cuDF doesn't support \\A (escaped string anchors) in some repetitions") { + val patterns = Seq(raw"(\A)+", raw"(\A){2,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") + ) + } + test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { @@ -478,18 +481,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val cpu = cpuContains(javaPattern, input) val cudfPattern = new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern) val gpu = try { - val result = Future { - gpuContains(cudfPattern, input) - } - Await.result(result, 1 second) + gpuContains(cudfPattern, input) } catch { case e: CudfException => fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}", e) - case e: java.util.concurrent.TimeoutException => - fail(s"cuDF timed out regexp_find: " + - s"javaPattern=${toReadableString(javaPattern)}, " + - s"cudfPattern=${toReadableString(cudfPattern)}" - ) } for (i <- input.indices) { if (cpu(i) != gpu(i)) { @@ -509,18 +504,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val cpu = cpuReplace(javaPattern, input) val cudfPattern = new CudfRegexTranspiler(RegexReplaceMode).transpile(javaPattern) val gpu = try { - val result = Future { - gpuReplace(cudfPattern, input) - } - Await.result(result, 10 second) + gpuReplace(cudfPattern, input) } catch { case e: CudfException => fail(s"cuDF failed to compile pattern: ${toReadableString(cudfPattern)}", e) - case e: java.util.concurrent.TimeoutException => - fail(s"cuDF timed out regexp_replace: " + - s"javaPattern=${toReadableString(javaPattern)}, " + - s"cudfPattern=${toReadableString(cudfPattern)}" - ) } for (i <- input.indices) { if (cpu(i) != gpu(i)) { @@ -656,12 +643,12 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { () => predefinedCharacterClass, () => group(depth), () => boundaryMatch, - () => sequence(depth)) + () => sequence(depth), + () => repetition(depth)) // https://github.com/NVIDIA/spark-rapids/issues/4487 val generators = if (skipKnownIssues) { baseGenerators } else { baseGenerators ++ Seq( - () => repetition(depth), // https://github.com/NVIDIA/spark-rapids/issues/4487 () => choice(depth)) // https://github.com/NVIDIA/spark-rapids/issues/4603 } generators(rr.nextInt(generators.length))() From 25861d4c306757a80a63e6db60e0683e1be02cb0 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Wed, 16 Feb 2022 07:44:56 -0800 Subject: [PATCH 5/8] Update for more edge cases with \A Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 15 ++++++++++----- .../RegularExpressionTranspilerSuite.scala | 19 ++++++++++--------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index cb94804f806..1ea2a3fd6c5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -495,6 +495,9 @@ class CudfRegexTranspiler(mode: RegexMode) { case RegexSequence(parts) => parts.forall(x => isSupportedRepetitionBase(x)) + case RegexGroup(_, term) => + isSupportedRepetitionBase(term) + case _ => true } } @@ -662,13 +665,14 @@ class CudfRegexTranspiler(mode: RegexMode) { "regexp_replace on GPU does not support repetition with ? or *") case (RegexGroup(_, term), SimpleQuantifier(ch)) - if "+*".contains(ch) && (!isSupportedRepetitionBase(term) || - isNestedRepetition(term)) => + if "+*".contains(ch) && !isSupportedRepetitionBase(term) => throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, term), QuantifierVariableLength(_,None)) - if !isSupportedRepetitionBase(term) || isNestedRepetition(term) => + if !isSupportedRepetitionBase(term) => // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) + case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => + RegexRepetition(rewrite(base), quantifier) case _ if isSupportedRepetitionBase(base) => RegexRepetition(rewrite(base), quantifier) case _ => @@ -690,8 +694,9 @@ class CudfRegexTranspiler(mode: RegexMode) { def endsWithLineAnchor(e: RegexAST): Boolean = { e match { case RegexSequence(parts) if parts.nonEmpty => - isBeginOrEndLineAnchor(parts.last) - case _ => false + endsWithLineAnchor(parts.last) + case RegexEscaped('A') => true + case _ => isBeginOrEndLineAnchor(e) } } if (endsWithLineAnchor(ll) || endsWithLineAnchor(rr)) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 09b7d25426f..9ba740ceede 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -43,7 +43,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "a^|b", "w$|b", "\n[^\r\n]x*|^3x", - "]*\\wWW$|zb" + "]*\\wWW$|zb", + "(\\A|\\05)?" ) // data is not relevant because we are checking for compilation errors val inputs = Seq("a") @@ -124,16 +125,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // see https://github.com/NVIDIA/spark-rapids/issues/4487 val patterns = Seq("(3?)+", "(3?)*", "(3*)+", "((3?))+") patterns.foreach(pattern => - assertUnsupported(pattern, RegexReplaceMode, "nothing to repeat")) + assertUnsupported(pattern, RegexFindMode, "nothing to repeat")) } - test("cuDF doesn't support \\A (escaped string anchors) in some repetitions") { - val patterns = Seq(raw"(\A)+", raw"(\A){2,}") - patterns.foreach(pattern => - assertUnsupported(pattern, RegexFindMode, "nothing to repeat") - ) - } - test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { @@ -186,6 +180,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n")) } + test("string anchor \\A will fall back to CPU in some repetitions") { + val patterns = Seq(raw"(\A)+", raw"(\A)*", raw"(\A){2,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") + ) + } + test("string anchor \\Z fall back to CPU") { for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { assertUnsupported("\\Z", mode, "string anchor \\Z is not supported") From 614aba7bb6622ad11626d6245405fbe385ba2ad6 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Wed, 16 Feb 2022 19:12:29 -0800 Subject: [PATCH 6/8] handle {0} and {0,0} case Signed-off-by: Navin Kumar --- .../com/nvidia/spark/rapids/RegexParser.scala | 5 +++++ .../RegularExpressionTranspilerSuite.scala | 20 +++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 1ea2a3fd6c5..1725a73cbad 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -664,6 +664,11 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "regexp_replace on GPU does not support repetition with ? or *") + case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0,Some(0))) + if mode != RegexFindMode => + throw new RegexUnsupportedException( + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") + case (RegexGroup(_, term), SimpleQuantifier(ch)) if "+*".contains(ch) && !isSupportedRepetitionBase(term) => throw new RegexUnsupportedException(nothingToRepeat) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 9ba740ceede..82c9b07e8a5 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -310,6 +310,22 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpReplace(patterns, inputs) } + test("regexp_replace - fall back to CPU for {0} or {0,0}") { + val patterns = Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") + ) + } + + test("regexp_split - fall back to CPU for {0} or {0,0}") { + val patterns = Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexSplitMode, + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") + ) + } + test("compare CPU and GPU: regexp find fuzz test with limited chars") { // testing with this limited set of characters finds issues much // faster than using the full ASCII set @@ -644,12 +660,12 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { () => predefinedCharacterClass, () => group(depth), () => boundaryMatch, - () => sequence(depth), - () => repetition(depth)) // https://github.com/NVIDIA/spark-rapids/issues/4487 + () => sequence(depth)) val generators = if (skipKnownIssues) { baseGenerators } else { baseGenerators ++ Seq( + () => repetition(depth), // https://github.com/NVIDIA/spark-rapids/issues/4487 () => choice(depth)) // https://github.com/NVIDIA/spark-rapids/issues/4603 } generators(rr.nextInt(generators.length))() From 90150cc0a28893c7f8897f41c10817b9f9c4c211 Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Thu, 17 Feb 2022 07:09:49 -0800 Subject: [PATCH 7/8] Add one more edge case to fallback to CPU --- .../com/nvidia/spark/rapids/RegexParser.scala | 5 +++++ .../RegularExpressionTranspilerSuite.scala | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 1725a73cbad..4e83c4e17ab 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -664,6 +664,11 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "regexp_replace on GPU does not support repetition with ? or *") + case (_, QuantifierVariableLength(0,None)) if mode == RegexReplaceMode => + // see https://github.com/NVIDIA/spark-rapids/issues/4468 + throw new RegexUnsupportedException( + "regexp_replace on GPU does not support repetition with {0,}") + case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0,Some(0))) if mode != RegexFindMode => throw new RegexUnsupportedException( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 82c9b07e8a5..7617fce7f81 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -310,6 +310,24 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpReplace(patterns, inputs) } + test("regexp_replace - character class repetition - ? and * - fall back to CPU") { + val patterns = Seq(raw"[1a-zA-Z]?", raw"[1a-zA-Z]*") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regexp_replace on GPU does not support repetition with ? or *" + ) + ) + } + + test("regexp_replace - character class repetition - {0,} - fall back to CPU") { + val patterns = Seq(raw"[1a-zA-Z]{0,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regexp_replace on GPU does not support repetition with {0,}" + ) + ) + } + test("regexp_replace - fall back to CPU for {0} or {0,0}") { val patterns = Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}") patterns.foreach(pattern => From 587557c24cc29984b7aa035077829c48d3c083fe Mon Sep 17 00:00:00 2001 From: Navin Kumar Date: Thu, 17 Feb 2022 11:44:02 -0800 Subject: [PATCH 8/8] Some style fixes and clean up Signed-off-by: Navin Kumar --- .../scala/com/nvidia/spark/rapids/RegexParser.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 4e83c4e17ab..cac002435a0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,18 +467,15 @@ class CudfRegexTranspiler(mode: RegexMode) { cudfRegex.toRegexString } - private def isSomeRepetition(f: ListBuffer[RegexAST] => Boolean)(e: RegexAST): Boolean = { + private def isRepetition(e: RegexAST): Boolean = { e match { case RegexRepetition(_, _) => true case RegexGroup(_, term) => isRepetition(term) - case RegexSequence(parts) if f(parts) => isRepetition(parts.last) + case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last) case _ => false } } - private def isRepetition = isSomeRepetition(_.nonEmpty)(_) - private def isNestedRepetition = isSomeRepetition(_.length == 1)(_) - private def isSupportedRepetitionBase(e: RegexAST): Boolean = { e match { case RegexEscaped(ch) if ch != 'd' && ch != 'w' => // example: "\B?" @@ -493,7 +490,7 @@ class CudfRegexTranspiler(mode: RegexMode) { false case RegexSequence(parts) => - parts.forall(x => isSupportedRepetitionBase(x)) + parts.forall(isSupportedRepetitionBase) case RegexGroup(_, term) => isSupportedRepetitionBase(term) @@ -679,7 +676,7 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, term), QuantifierVariableLength(_,None)) if !isSupportedRepetitionBase(term) => - // specifically this variable length repetition: \A{2,} + // specifically this variable length repetition: \A{2,} throw new RegexUnsupportedException(nothingToRepeat) case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => RegexRepetition(rewrite(base), quantifier)