diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 09ca02032bf..cac002435a0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -467,6 +467,39 @@ class CudfRegexTranspiler(mode: RegexMode) { cudfRegex.toRegexString } + private def isRepetition(e: RegexAST): Boolean = { + e match { + case RegexRepetition(_, _) => true + case RegexGroup(_, term) => isRepetition(term) + case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last) + case _ => false + } + } + + private def isSupportedRepetitionBase(e: RegexAST): Boolean = { + e match { + case RegexEscaped(ch) if ch != 'd' && ch != 'w' => // example: "\B?" + false + + case RegexChar(a) if "$^".contains(a) => + // example: "$*" + false + + case RegexRepetition(_, _) => + // example: "a*+" + false + + case RegexSequence(parts) => + parts.forall(isSupportedRepetitionBase) + + case RegexGroup(_, term) => + isSupportedRepetitionBase(term) + + case _ => true + } + } + + private def rewrite(regex: RegexAST): RegexAST = { regex match { @@ -628,20 +661,29 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "regexp_replace on GPU does not support repetition with ? or *") - case (RegexEscaped(ch), _) if ch != 'd' && ch != 'w' => - // example: "\B?" - throw new RegexUnsupportedException(nothingToRepeat) + case (_, QuantifierVariableLength(0,None)) if mode == RegexReplaceMode => + // see https://github.com/NVIDIA/spark-rapids/issues/4468 + throw new RegexUnsupportedException( + "regexp_replace on GPU does not support repetition with {0,}") - case (RegexChar(a), _) if "$^".contains(a) => - // example: "$*" - throw new RegexUnsupportedException(nothingToRepeat) + case (_, QuantifierFixedLength(0)) | (_, QuantifierVariableLength(0,Some(0))) + if mode != RegexFindMode => + throw new RegexUnsupportedException( + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") - case (RegexRepetition(_, _), _) => - // example: "a*+" + case (RegexGroup(_, term), SimpleQuantifier(ch)) + if "+*".contains(ch) && !isSupportedRepetitionBase(term) => throw new RegexUnsupportedException(nothingToRepeat) - - case _ => + case (RegexGroup(_, term), QuantifierVariableLength(_,None)) + if !isSupportedRepetitionBase(term) => + // specifically this variable length repetition: \A{2,} + throw new RegexUnsupportedException(nothingToRepeat) + case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => + RegexRepetition(rewrite(base), quantifier) + case _ if isSupportedRepetitionBase(base) => RegexRepetition(rewrite(base), quantifier) + case _ => + throw new RegexUnsupportedException(nothingToRepeat) } @@ -650,14 +692,6 @@ class CudfRegexTranspiler(mode: RegexMode) { val rr = rewrite(r) // cuDF does not support repetition on one side of a choice, such as "a*|a" - def isRepetition(e: RegexAST): Boolean = { - e match { - case RegexRepetition(_, _) => true - case RegexGroup(_, term) => isRepetition(term) - case RegexSequence(parts) if parts.nonEmpty => isRepetition(parts.last) - case _ => false - } - } if (isRepetition(ll) || isRepetition(rr)) { throw new RegexUnsupportedException(nothingToRepeat) } @@ -667,8 +701,9 @@ class CudfRegexTranspiler(mode: RegexMode) { def endsWithLineAnchor(e: RegexAST): Boolean = { e match { case RegexSequence(parts) if parts.nonEmpty => - isBeginOrEndLineAnchor(parts.last) - case _ => false + endsWithLineAnchor(parts.last) + case RegexEscaped('A') => true + case _ => isBeginOrEndLineAnchor(e) } } if (endsWithLineAnchor(ll) || endsWithLineAnchor(rr)) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 49e0e0f759f..06c642ba561 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -140,6 +140,21 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer(RegexOctalChar("047"), RegexChar('7')))) } + test("repetition with group containing simple repetition") { + assert(parse("(3?)+") === + RegexSequence(ListBuffer(RegexRepetition(RegexGroup(capture = true, + RegexSequence(ListBuffer(RegexRepetition(RegexChar('3'), + SimpleQuantifier('?'))))),SimpleQuantifier('+'))))) + } + + test("repetition with group containing escape character") { + assert(parse(raw"(\A)+") === + RegexSequence(ListBuffer(RegexRepetition(RegexGroup(capture = true, + RegexSequence(ListBuffer(RegexEscaped('A')))), + SimpleQuantifier('+')))) + ) + } + test("group containing choice with repetition") { assert(parse("(\t+|a)") == RegexSequence(ListBuffer( RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index f050eeddfdd..7617fce7f81 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -15,6 +15,7 @@ */ package com.nvidia.spark.rapids + import java.util.regex.Pattern import scala.collection.mutable.{HashSet, ListBuffer} @@ -42,7 +43,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "a^|b", "w$|b", "\n[^\r\n]x*|^3x", - "]*\\wWW$|zb" + "]*\\wWW$|zb", + "(\\A|\\05)?" ) // data is not relevant because we are checking for compilation errors val inputs = Seq("a") @@ -119,6 +121,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { ) } + test("cuDF does not support single repetition both inside and outside of capture groups") { + // see https://github.com/NVIDIA/spark-rapids/issues/4487 + val patterns = Seq("(3?)+", "(3?)*", "(3*)+", "((3?))+") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, "nothing to repeat")) + } + test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { @@ -171,6 +180,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "\ntest", "test\n", "\ntest\n")) } + test("string anchor \\A will fall back to CPU in some repetitions") { + val patterns = Seq(raw"(\A)+", raw"(\A)*", raw"(\A){2,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") + ) + } + test("string anchor \\Z fall back to CPU") { for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { assertUnsupported("\\Z", mode, "string anchor \\Z is not supported") @@ -294,6 +310,40 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpReplace(patterns, inputs) } + test("regexp_replace - character class repetition - ? and * - fall back to CPU") { + val patterns = Seq(raw"[1a-zA-Z]?", raw"[1a-zA-Z]*") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regexp_replace on GPU does not support repetition with ? or *" + ) + ) + } + + test("regexp_replace - character class repetition - {0,} - fall back to CPU") { + val patterns = Seq(raw"[1a-zA-Z]{0,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regexp_replace on GPU does not support repetition with {0,}" + ) + ) + } + + test("regexp_replace - fall back to CPU for {0} or {0,0}") { + val patterns = Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexReplaceMode, + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") + ) + } + + test("regexp_split - fall back to CPU for {0} or {0,0}") { + val patterns = Seq("a{0}", raw"\02{0}", "a{0,0}", raw"\02{0,0}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexSplitMode, + "regex_replace and regex_split on GPU do not support repetition with {0} or {0,0}") + ) + } + test("compare CPU and GPU: regexp find fuzz test with limited chars") { // testing with this limited set of characters finds issues much // faster than using the full ASCII set