diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index c60e8db3524..86ba30d9c83 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -444,7 +444,17 @@ class CudfRegexTranspiler(replace: Boolean) { // this is a bit extreme and it would be good to replace with finer-grained // rules throw new RegexUnsupportedException("regexp_replace on GPU does not support ^ or $") - + case '$' => + RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = false, + characters = ListBuffer(RegexChar('\r'))), + SimpleQuantifier('?')), + RegexRepetition( + RegexCharacterClass(negated = false, + characters = ListBuffer(RegexChar('\n'))), + SimpleQuantifier('?')), + RegexChar('$'))) case _ => regex } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 9bd859f6562..240ba7d6eab 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids -import java.util.regex.Pattern +import java.util.regex.{Matcher, Pattern} import scala.collection.mutable.ListBuffer import scala.util.{Random, Try} @@ -132,13 +132,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertUnsupported(pattern, "nothing to repeat")) } - ignore("known issue - multiline difference between CPU and GPU") { - // see https://github.com/rapidsai/cudf/issues/9620 + test("end of line anchor with strings ending in valid newline") { val pattern = "2$" - // this matches "2" but not "2\n" on the GPU assertCpuGpuMatchesRegexpFind(Seq(pattern), Seq("2", "2\n", "2\r", "2\r\n")) } + test("end of line anchor with strings ending in invalid newline") { + val pattern = "2$" + assertCpuGpuMatchesRegexpFind(Seq(pattern), Seq("2\n\r")) + } + test("dot matches CR on GPU but not on CPU") { // see https://github.com/rapidsai/cudf/issues/9619 val pattern = "1." @@ -187,8 +190,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { ")" + "$" // end of line - // input and output should be identical - doTranspileTest(VALID_FLOAT_REGEX, VALID_FLOAT_REGEX) + // input and output should be identical except for '$' being replaced with '[\r]?[\n]?$' + doTranspileTest(VALID_FLOAT_REGEX, + VALID_FLOAT_REGEX.replaceAll("\\$", + Matcher.quoteReplacement("[\r]?[\n]?$"))) } test("transpile complex regex 2") { @@ -197,9 +202,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { "(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$" // input and output should be identical except for `.` being replaced with `[^\r\n]` + // and '$' being replaced with '[\r]?[\n]?$' doTranspileTest(TIMESTAMP_TRUNCATE_REGEX, - TIMESTAMP_TRUNCATE_REGEX.replaceAll("\\.", "[^\r\n]")) - + TIMESTAMP_TRUNCATE_REGEX + .replaceAll("\\.", "[^\r\n]") + .replaceAll("\\$", Matcher.quoteReplacement("[\r]?[\n]?$"))) } test("compare CPU and GPU: character range including unescaped + and -") { @@ -257,7 +264,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // LF has been excluded due to known issues val chars = (0x00 to 0x7F) .map(_.toChar) - .filterNot(_ == '\n') doFuzzTest(Some(chars.mkString), replace = true) } @@ -272,8 +278,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { options = FuzzerOptions(validChars, maxStringLen = 12)) val data = Range(0, 1000) - // remove trailing newlines as workaround for https://github.com/rapidsai/cudf/issues/9620 - .map(_ => removeTrailingNewlines(r.nextString())) + .map(_ => r.nextString()) // generate patterns that are valid on both CPU and GPU val patterns = ListBuffer[String]() @@ -291,14 +296,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } - private def removeTrailingNewlines(input: String): String = { - var s = input - while (s.endsWith("\r") || s.endsWith("\n")) { - s = s.substring(0, s.length - 1) - } - s - } - private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = { for (javaPattern <- javaPatterns) { val cpu = cpuContains(javaPattern, input)