From 75bf95fe9764f9fcac77a2651b841614c4a34f99 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 18 Jan 2022 12:37:16 -0700 Subject: [PATCH 1/5] Fall back to CPU if regexp_replace replacement expression contains back-references Signed-off-by: Andy Grove --- integration_tests/src/main/python/string_test.py | 12 +++++++++++- .../spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 12 +++++++++++- .../spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala | 12 +++++++++++- .../spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 12 +++++++++++- .../spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 12 +++++++++++- 5 files changed, 55 insertions(+), 5 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 65f9042b9c2..0bfcf1d255f 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -348,6 +348,16 @@ def test_re_replace(): 'REGEXP_REPLACE(a, "TEST", NULL)'), conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_re_replace_backrefs(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "(TEST)", "[$0]")', + 'REGEXP_REPLACE(a, "(TEST)", "[$1]")'), + 'RegExpReplace', + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) + def test_re_replace_null(): gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\ .with_special_case("\u0000")\ diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index a482b09232b..6de6ac64e27 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.shims.v2 +import java.util.regex.Pattern + import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} @@ -48,6 +50,14 @@ class GpuRegExpReplaceMeta( case _ => willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") } + + expr.rep match { + case Literal(s: UTF8String, DataTypes.StringType) if s != null => + val backrefPattern = Pattern.compile("\\$[0-9]") + if (backrefPattern.matcher(s.toString).find()) { + willNotWorkOnGpu("regexp_replace with back-references is not supported") + } + } } override def convertToGpu( diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala index ad608663567..78b3488e2c8 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.shims.v2 +import java.util.regex.Pattern + import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} @@ -48,6 +50,14 @@ class GpuRegExpReplaceMeta( case _ => willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") } + + expr.rep match { + case Literal(s: UTF8String, DataTypes.StringType) if s != null => + val backrefPattern = Pattern.compile("\\$[0-9]") + if (backrefPattern.matcher(s.toString).find()) { + willNotWorkOnGpu("regexp_replace with back-references is not supported") + } + } } override def convertToGpu( diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 82df7a91fd0..01ab1fa2e7c 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.shims.v2 +import java.util.regex.Pattern + import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} @@ -49,6 +51,14 @@ class GpuRegExpReplaceMeta( willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") } + expr.rep match { + case Literal(s: UTF8String, DataTypes.StringType) if s != null => + val backrefPattern = Pattern.compile("\\$[0-9]") + if (backrefPattern.matcher(s.toString).find()) { + willNotWorkOnGpu("regexp_replace with back-references is not supported") + } + } + GpuOverrides.extractLit(expr.pos).foreach { lit => if (lit.value.asInstanceOf[Int] != 1) { willNotWorkOnGpu("only a search starting position of 1 is supported") diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 82df7a91fd0..01ab1fa2e7c 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.shims.v2 +import java.util.regex.Pattern + import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} @@ -49,6 +51,14 @@ class GpuRegExpReplaceMeta( willNotWorkOnGpu(s"only non-null literal strings are supported on GPU") } + expr.rep match { + case Literal(s: UTF8String, DataTypes.StringType) if s != null => + val backrefPattern = Pattern.compile("\\$[0-9]") + if (backrefPattern.matcher(s.toString).find()) { + willNotWorkOnGpu("regexp_replace with back-references is not supported") + } + } + GpuOverrides.extractLit(expr.pos).foreach { lit => if (lit.value.asInstanceOf[Int] != 1) { willNotWorkOnGpu("only a search starting position of 1 is supported") From 257b5358b64b4927c3f78fd373bbff5294ca1212 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 18 Jan 2022 12:38:07 -0700 Subject: [PATCH 2/5] Update compatibility guide --- docs/compatibility.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/compatibility.md b/docs/compatibility.md index f37aa5ebf80..8f1e5ce537d 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -480,6 +480,7 @@ Here are some examples of regular expression patterns that are not supported on - Empty groups: `()` - Regular expressions containing null characters (unless the pattern is a simple literal string) - Hex and octal digits +- `regexp_replace` does not support back-references Work is ongoing to increase the range of regular expressions that can run on the GPU. From 5d5727b635711571305d7bd6a72fa897882223f5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 19 Jan 2022 14:02:27 -0700 Subject: [PATCH 3/5] Add support for escape characters in replacement strings Signed-off-by: Andy Grove --- .../src/main/python/string_test.py | 15 ++++++ .../shims/v2/GpuRegExpReplaceExec.scala | 18 ++++--- .../shims/v2/GpuRegExpReplaceMeta.scala | 18 ++++--- .../shims/v2/GpuRegExpReplaceExec.scala | 18 ++++--- .../shims/v2/GpuRegExpReplaceExec.scala | 18 ++++--- .../spark/sql/rapids/stringFunctions.scala | 54 +++++++++++++++++-- 6 files changed, 108 insertions(+), 33 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 98a7a5a9c12..bbe16de9e41 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -358,6 +358,21 @@ def test_re_replace_backrefs(): 'RegExpReplace', conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) +def test_re_replace_backrefs_escaped(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$0]")', + 'REGEXP_REPLACE(a, "(TEST)", "[\\\\$1]")'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) + +def test_re_replace_escaped(): + gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A")'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) + def test_re_replace_null(): gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\ .with_special_case("\u0000")\ diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 6de6ac64e27..39e09b99eca 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,12 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import java.util.regex.Pattern - import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String @@ -32,6 +30,7 @@ class GpuRegExpReplaceMeta( extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) { private var pattern: Option[String] = None + private var replacement: Option[String] = None override def tagExprForGpu(): Unit = { expr.regexp match { @@ -53,10 +52,11 @@ class GpuRegExpReplaceMeta( expr.rep match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => - val backrefPattern = Pattern.compile("\\$[0-9]") - if (backrefPattern.matcher(s.toString).find()) { + if (GpuRegExpUtils.containsBackrefs(s.toString)) { willNotWorkOnGpu("regexp_replace with back-references is not supported") } + replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString)) + case _ => } } @@ -67,8 +67,12 @@ class GpuRegExpReplaceMeta( if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { GpuStringReplace(lhs, regexp, rep) } else { - GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + (pattern, replacement) match { + case (Some(cudfPattern), Some(cudfReplacement)) => + GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement) + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } } } } diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala index 78b3488e2c8..2460e3d2e5f 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala @@ -15,12 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import java.util.regex.Pattern - import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String @@ -32,6 +30,7 @@ class GpuRegExpReplaceMeta( extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) { private var pattern: Option[String] = None + private var replacement: Option[String] = None override def tagExprForGpu(): Unit = { expr.regexp match { @@ -53,10 +52,11 @@ class GpuRegExpReplaceMeta( expr.rep match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => - val backrefPattern = Pattern.compile("\\$[0-9]") - if (backrefPattern.matcher(s.toString).find()) { + if (GpuRegExpUtils.containsBackrefs(s.toString)) { willNotWorkOnGpu("regexp_replace with back-references is not supported") } + replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString)) + case _ => } } @@ -67,8 +67,12 @@ class GpuRegExpReplaceMeta( if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { GpuStringReplace(lhs, regexp, rep) } else { - GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + (pattern, replacement) match { + case (Some(cudfPattern), Some(cudfReplacement)) => + GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement) + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } } } } diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 01ab1fa2e7c..b199a4d3a91 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,12 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import java.util.regex.Pattern - import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String @@ -32,6 +30,7 @@ class GpuRegExpReplaceMeta( extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) { private var pattern: Option[String] = None + private var replacement: Option[String] = None override def tagExprForGpu(): Unit = { expr.regexp match { @@ -53,10 +52,11 @@ class GpuRegExpReplaceMeta( expr.rep match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => - val backrefPattern = Pattern.compile("\\$[0-9]") - if (backrefPattern.matcher(s.toString).find()) { + if (GpuRegExpUtils.containsBackrefs(s.toString)) { willNotWorkOnGpu("regexp_replace with back-references is not supported") } + replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString)) + case _ => } GpuOverrides.extractLit(expr.pos).foreach { lit => @@ -78,8 +78,12 @@ class GpuRegExpReplaceMeta( if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { GpuStringReplace(subject, regexp, rep) } else { - GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + (pattern, replacement) match { + case (Some(cudfPattern), Some(cudfReplacement)) => + GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement) + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } } } } \ No newline at end of file diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 01ab1fa2e7c..b199a4d3a91 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,12 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import java.util.regex.Pattern - import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String @@ -32,6 +30,7 @@ class GpuRegExpReplaceMeta( extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) { private var pattern: Option[String] = None + private var replacement: Option[String] = None override def tagExprForGpu(): Unit = { expr.regexp match { @@ -53,10 +52,11 @@ class GpuRegExpReplaceMeta( expr.rep match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => - val backrefPattern = Pattern.compile("\\$[0-9]") - if (backrefPattern.matcher(s.toString).find()) { + if (GpuRegExpUtils.containsBackrefs(s.toString)) { willNotWorkOnGpu("regexp_replace with back-references is not supported") } + replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString)) + case _ => } GpuOverrides.extractLit(expr.pos).foreach { lit => @@ -78,8 +78,12 @@ class GpuRegExpReplaceMeta( if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { GpuStringReplace(subject, regexp, rep) } else { - GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + (pattern, replacement) match { + case (Some(cudfPattern), Some(cudfReplacement)) => + GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement) + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } } } } \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 692be261f6a..b0b0e5ae8fd 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -744,6 +744,48 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char) } } +object GpuRegExpUtils { + + /** + * Determine if a string contains back-references such as `$1` but ignoring + * if preceded by escape character. + */ + def containsBackrefs(s: String): Boolean = { + var i = 0 + while (i < s.length) { + if (s.charAt(i) == '\\') { + i += 2 + } else { + if (s.charAt(i) == '$' && i+1 < s.length) { + if (s.charAt(i+1).isDigit) { + return true + } + } + i += 1 + } + } + false + } + + /** + * We need to remove escape characters in the regexp_replace + * replacement string before passing to cuDF. + */ + def unescapeReplaceString(s: String): String = { + val b = new StringBuilder + var i = 0 + while (i < s.length) { + if (s.charAt(i) == '\\' && i+1 < s.length) { + i += 1 + } + b.append(s.charAt(i)) + i += 1 + } + b.toString + } + +} + class GpuRLikeMeta( expr: RLike, conf: RapidsConf, @@ -854,7 +896,8 @@ case class GpuRegExpReplace( srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression, - cudfRegexPattern: String) + cudfRegexPattern: String, + cudfReplacementString: String) extends GpuRegExpTernaryBase with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) @@ -863,15 +906,16 @@ case class GpuRegExpReplace( override def second: Expression = searchExpr override def third: Expression = replaceExpr - def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String) = { - this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern) + def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String, + cudfReplacementString: String) = { + this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern, cudfReplacementString) } override def doColumnar( strExpr: GpuColumnVector, searchExpr: GpuScalar, replaceExpr: GpuScalar): ColumnVector = { - strExpr.getBase.replaceRegex(cudfRegexPattern, replaceExpr.getBase) + strExpr.getBase.replaceRegex(cudfRegexPattern, Scalar.fromString(cudfReplacementString)) } } From cc5d1fa54746ced6c78d2b93ab7d8e221227fe94 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 19 Jan 2022 14:52:48 -0700 Subject: [PATCH 4/5] Use escaped characters in all regexp_replace tests --- .../spark/rapids/RegularExpressionTranspilerSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 9a2c3668cdd..5809fc2f73c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -23,6 +23,7 @@ import scala.util.{Random, Try} import ai.rapids.cudf.{ColumnVector, CudfException} import org.scalatest.FunSuite +import org.apache.spark.sql.rapids.GpuRegExpUtils import org.apache.spark.sql.types.DataTypes class RegularExpressionTranspilerSuite extends FunSuite with Arm { @@ -418,13 +419,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { result } - private val REPLACE_STRING = "_REPLACE_" + private val REPLACE_STRING = "\\_\\RE\\\\P\\L\\A\\C\\E\\_" /** cuDF replaceRe helper */ private def gpuReplace(cudfPattern: String, input: Seq[String]): Array[String] = { val result = new Array[String](input.length) + val replace = GpuRegExpUtils.unescapeReplaceString(REPLACE_STRING) withResource(ColumnVector.fromStrings(input: _*)) { cv => - withResource(GpuScalar.from(REPLACE_STRING, DataTypes.StringType)) { replace => + withResource(GpuScalar.from(replace, DataTypes.StringType)) { replace => withResource(cv.replaceRegex(cudfPattern, replace)) { c => withResource(c.copyToHost()) { hv => result.indices.foreach(i => result(i) = new String(hv.getUTF8(i))) From 8e6dae7130fdb3f600a654bbd032255274c2329e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 19 Jan 2022 16:49:51 -0700 Subject: [PATCH 5/5] Fix resource leak and add more escaped characters to test --- integration_tests/src/main/python/string_test.py | 2 +- .../scala/org/apache/spark/sql/rapids/stringFunctions.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index bbe16de9e41..bf740a64d2e 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -370,7 +370,7 @@ def test_re_replace_escaped(): gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( - 'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A")'), + 'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A\\\\t\\\\r\\\\n\\t\\r\\n")'), conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) def test_re_replace_null(): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index b0b0e5ae8fd..5bcd9826028 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -915,7 +915,9 @@ case class GpuRegExpReplace( strExpr: GpuColumnVector, searchExpr: GpuScalar, replaceExpr: GpuScalar): ColumnVector = { - strExpr.getBase.replaceRegex(cudfRegexPattern, Scalar.fromString(cudfReplacementString)) + withResource(Scalar.fromString(cudfReplacementString)) { rep => + strExpr.getBase.replaceRegex(cudfRegexPattern, rep) + } } }