Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

regexp_replace with back-references should fall back to CPU #4556

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ Here are some examples of regular expression patterns that are not supported on
- Empty groups: `()`
- Regular expressions containing null characters (unless the pattern is a simple literal string)
- Hex and octal digits
- `regexp_replace` does not support back-references

Work is ongoing to increase the range of regular expressions that can run on the GPU.

Expand Down
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,31 @@ def test_re_replace():
'REGEXP_REPLACE(a, "TEST", NULL)'),
conf={'spark.rapids.sql.expression.RegExpReplace': 'true'})

@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_re_replace_backrefs():
    """REGEXP_REPLACE with unescaped `$0` / `$1` in the replacement string.

    The replacement contains back-references, so the query is checked with
    assert_gpu_fallback_collect against the 'RegExpReplace' expression rather
    than with a plain GPU-vs-CPU comparison.
    """
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')

    def replace_with_backrefs(spark):
        # one projection per back-reference form: whole match ($0) and group 1 ($1)
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "(TEST)", "[$0]")',
            'REGEXP_REPLACE(a, "(TEST)", "[$1]")')

    assert_gpu_fallback_collect(
        replace_with_backrefs,
        'RegExpReplace',
        conf={'spark.rapids.sql.expression.RegExpReplace': 'true'})

def test_re_replace_backrefs_escaped():
    """REGEXP_REPLACE where the `$0` / `$1` in the replacement are backslash-escaped.

    Escaped dollar signs are not treated as back-references by the plugin's
    tagging code, so GPU and CPU results are compared directly.
    """
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')

    def replace_with_escaped_dollar(spark):
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "(TEST)", "[\\\\$0]")',
            'REGEXP_REPLACE(a, "(TEST)", "[\\\\$1]")')

    assert_gpu_and_cpu_are_equal_collect(
        replace_with_escaped_dollar,
        conf={'spark.rapids.sql.expression.RegExpReplace': 'true'})

def test_re_replace_escaped():
    """REGEXP_REPLACE with backslash escapes (but no `$`) in the replacement string.

    Compares GPU and CPU output to check that escape characters in the
    replacement are handled the same way on both.
    """
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')

    def replace_with_escapes(spark):
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A")')

    assert_gpu_and_cpu_are_equal_collect(
        replace_with_escapes,
        conf={'spark.rapids.sql.expression.RegExpReplace': 'true'})

def test_re_replace_null():
gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\
.with_special_case("\u0000")\
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -48,6 +49,15 @@ class GpuRegExpReplaceMeta(
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}
}

override def convertToGpu(
Expand All @@ -57,8 +67,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(lhs, regexp, rep)
} else {
GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -48,6 +49,15 @@ class GpuRegExpReplaceMeta(
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}
}

override def convertToGpu(
Expand All @@ -57,8 +67,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(lhs, regexp, rep)
} else {
GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -49,6 +50,15 @@ class GpuRegExpReplaceMeta(
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}

GpuOverrides.extractLit(expr.pos).foreach { lit =>
if (lit.value.asInstanceOf[Int] != 1) {
willNotWorkOnGpu("only a search starting position of 1 is supported")
Expand All @@ -68,8 +78,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(subject, regexp, rep)
} else {
GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
// The GPU regex path needs both the cuDF pattern and the unescaped
// replacement string (the latter is recorded in tagExprForGpu); if either
// is missing the expression was not tagged for this path.
(pattern, replacement) match {
  case (Some(cudfPattern), Some(cudfReplacement)) =>
    // `subject` is the input expression used by the other branches of this
    // method; `lhs` is not a name used in this shim's convertToGpu.
    GpuRegExpReplace(subject, regexp, rep, cudfPattern, cudfReplacement)
  case _ =>
    throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -49,6 +50,15 @@ class GpuRegExpReplaceMeta(
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}

GpuOverrides.extractLit(expr.pos).foreach { lit =>
if (lit.value.asInstanceOf[Int] != 1) {
willNotWorkOnGpu("only a search starting position of 1 is supported")
Expand All @@ -68,8 +78,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(subject, regexp, rep)
} else {
GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
// The GPU regex path needs both the cuDF pattern and the unescaped
// replacement string (the latter is recorded in tagExprForGpu); if either
// is missing the expression was not tagged for this path.
(pattern, replacement) match {
  case (Some(cudfPattern), Some(cudfReplacement)) =>
    // `subject` is the input expression used by the other branches of this
    // method; `lhs` is not a name used in this shim's convertToGpu.
    GpuRegExpReplace(subject, regexp, rep, cudfPattern, cudfReplacement)
  case _ =>
    throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -744,6 +744,48 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char)
}
}

object GpuRegExpUtils {

  /**
   * Determine if a `regexp_replace` replacement string contains a group
   * back-reference that is not preceded by the escape character `\`.
   *
   * Two forms are recognised, matching the replacement syntax of
   * `java.util.regex.Matcher`: a numbered reference such as `$1`, and the
   * start of a named reference such as `${name}`.
   *
   * @param s the replacement string exactly as supplied to `regexp_replace`
   * @return true if an unescaped `$` is immediately followed by a digit or `{`
   */
  def containsBackrefs(s: String): Boolean = {
    var found = false
    var i = 0
    while (!found && i < s.length) {
      if (s.charAt(i) == '\\') {
        // skip the escape character together with the character it escapes
        i += 2
      } else {
        if (s.charAt(i) == '$' && i + 1 < s.length) {
          val next = s.charAt(i + 1)
          // `$<digit>` is a numbered group, `${` opens a named group reference
          found = next.isDigit || next == '{'
        }
        i += 1
      }
    }
    found
  }

  /**
   * Remove escape characters from a `regexp_replace` replacement string
   * before it is passed to cuDF.
   *
   * Every `\` that is followed by another character is dropped and the
   * following character is kept verbatim. A `\` that is the very last
   * character of the string has nothing to escape and is kept as-is.
   *
   * @param s the replacement string exactly as supplied to `regexp_replace`
   * @return the replacement string with escape characters removed
   */
  def unescapeReplaceString(s: String): String = {
    val b = new StringBuilder
    var i = 0
    while (i < s.length) {
      if (s.charAt(i) == '\\' && i + 1 < s.length) {
        // step over the escape so that the escaped character is appended below
        i += 1
      }
      b.append(s.charAt(i))
      i += 1
    }
    b.toString
  }

}

class GpuRLikeMeta(
expr: RLike,
conf: RapidsConf,
Expand Down Expand Up @@ -854,7 +896,8 @@ case class GpuRegExpReplace(
srcExpr: Expression,
searchExpr: Expression,
replaceExpr: Expression,
cudfRegexPattern: String)
cudfRegexPattern: String,
cudfReplacementString: String)
extends GpuRegExpTernaryBase with ImplicitCastInputTypes {

override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
Expand All @@ -863,15 +906,16 @@ case class GpuRegExpReplace(
override def second: Expression = searchExpr
override def third: Expression = replaceExpr

def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String) = {
this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern)
def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String,
cudfReplacementString: String) = {
this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern, cudfReplacementString)
}

/**
 * Replace every match of `cudfRegexPattern` in the input column with
 * `cudfReplacementString`.
 *
 * The `searchExpr` and `replaceExpr` scalars are not used here: the pattern
 * and replacement applied are the pre-computed strings held by this
 * expression.
 *
 * @param strExpr the input string column
 * @param searchExpr the original search expression value (unused)
 * @param replaceExpr the original replacement expression value (unused)
 * @return the column returned by cuDF's `replaceRegex`
 */
override def doColumnar(
    strExpr: GpuColumnVector,
    searchExpr: GpuScalar,
    replaceExpr: GpuScalar): ColumnVector = {
  val replacement = Scalar.fromString(cudfReplacementString)
  try {
    strExpr.getBase.replaceRegex(cudfRegexPattern, replacement)
  } finally {
    // the scalar is created per call, so it must be closed here to avoid leaking it
    replacement.close()
  }
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.util.{Random, Try}
import ai.rapids.cudf.{ColumnVector, CudfException}
import org.scalatest.FunSuite

import org.apache.spark.sql.rapids.GpuRegExpUtils
import org.apache.spark.sql.types.DataTypes

class RegularExpressionTranspilerSuite extends FunSuite with Arm {
Expand Down Expand Up @@ -418,13 +419,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
result
}

private val REPLACE_STRING = "_REPLACE_"
private val REPLACE_STRING = "\\_\\RE\\\\P\\L\\A\\C\\E\\_"

/** cuDF replaceRe helper */
private def gpuReplace(cudfPattern: String, input: Seq[String]): Array[String] = {
val result = new Array[String](input.length)
val replace = GpuRegExpUtils.unescapeReplaceString(REPLACE_STRING)
withResource(ColumnVector.fromStrings(input: _*)) { cv =>
withResource(GpuScalar.from(REPLACE_STRING, DataTypes.StringType)) { replace =>
withResource(GpuScalar.from(replace, DataTypes.StringType)) { replace =>
withResource(cv.replaceRegex(cudfPattern, replace)) { c =>
withResource(c.copyToHost()) { hv =>
result.indices.foreach(i => result(i) = new String(hv.getUTF8(i)))
Expand Down