Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regular expression support to string_split #4714

Merged
merged 16 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ Here are some examples of regular expression patterns that are not supported on
- Line anchor `$`
- String anchor `\Z`
- String anchor `\z` is not supported by `regexp_replace`
- Line and string anchors are not supported by `string_split`
- Non-digit character class `\D`
- Non-word character class `\W`
- Word and non-word boundaries, `\b` and `\B`
Expand Down
93 changes: 90 additions & 3 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, \
assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, \
assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime
from data_gen import *
from marks import *
Expand All @@ -25,15 +27,100 @@
def mk_str_gen(pattern):
    """Build a StringGen for `pattern`, adding an empty string as a special
    case and short arbitrary strings ('.{0,10}') as a special pattern."""
    gen = StringGen(pattern)
    return gen.with_special_case('').with_special_pattern('.{0,10}')

def test_split():
def test_split_no_limit():
    """Non-regexp delimiters with the default (omitted) limit argument."""
    # NOTE: the old version of this test kept an unused `delim` local; removed.
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB")',
            'split(a, "C")',
            'split(a, "_")'))

def test_split_negative_limit():
    """Non-regexp delimiters with various negative limit values."""
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    split_exprs = [
        'split(a, "AB", -1)',
        'split(a, "C", -2)',
        'split(a, "_", -999)',
    ]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*split_exprs))

# A limit of 0 is not supported on the GPU yet, so this must fall back to CPU.
# See https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_zero_limit_fallback():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr('split(a, "AB", 0)')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        do_split,
        exist_classes='ProjectExec',
        non_exist_classes='GpuProjectExec')

# A limit of 1 is not supported on the GPU yet, so this must fall back to CPU.
# See https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_one_limit_fallback():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr('split(a, "AB", 1)')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        do_split,
        exist_classes='ProjectExec',
        non_exist_classes='GpuProjectExec')

def test_split_positive_limit():
    """Non-regexp delimiters with positive limit values of 2 or more."""
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", 2)',
            'split(a, "C", 3)',
            'split(a, "_", 999)')
    assert_gpu_and_cpu_are_equal_collect(do_split)

def test_split_re_negative_limit():
    """Regexp delimiters with negative limit values."""
    # 'boo:and:foo' matches the classic example from java.lang.String.split docs
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}').with_special_case('boo:and:foo')
    split_exprs = [
        'split(a, ":", -1)',
        'split(a, "o", -2)',
    ]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*split_exprs))
ttnghia marked this conversation as resolved.
Show resolved Hide resolved

# A limit of 0 is not supported on the GPU yet, so this must fall back to CPU.
# See https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_zero_limit_fallback():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}').with_special_case('boo:and:foo')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr(
            'split(a, ":", 0)',
            'split(a, "o", 0)')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        do_split,
        exist_classes='ProjectExec',
        non_exist_classes='GpuProjectExec')

# A limit of 1 is not supported on the GPU yet, so this must fall back to CPU.
# See https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_one_limit_fallback():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}').with_special_case('boo:and:foo')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr(
            'split(a, ":", 1)',
            'split(a, "o", 1)')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        do_split,
        exist_classes='ProjectExec',
        non_exist_classes='GpuProjectExec')

def test_split_re_positive_limit():
    """Regexp delimiters with positive limit values of 2 or more."""
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}').with_special_case('boo:and:foo')
    split_exprs = [
        'split(a, ":", 2)',
        'split(a, ":", 5)',
        'split(a, "o", 2)',
        'split(a, "o", 5)',
    ]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*split_exprs))

def test_split_re_no_limit():
    """Regexp delimiters with the default (omitted) limit argument."""
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}').with_special_case('boo:and:foo')
    def do_split(spark):
        return unary_op_df(spark, data_gen).selectExpr(
            'split(a, ":")',
            'split(a, "o")')
    assert_gpu_and_cpu_are_equal_collect(do_split)

@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
(mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
(mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,19 @@ object RegexParser {
}
}

// Identifies the cuDF operation a regular expression will be used with.
// The transpiler rejects different constructs depending on the mode (e.g.
// some anchors are only unsupported in split or replace mode).
sealed trait RegexMode
// Pattern is used for matching/finding only (e.g. rlike, regexp_extract).
object RegexFindMode extends RegexMode
// Pattern is used for replacement (regexp_replace).
object RegexReplaceMode extends RegexMode
// Pattern is used for splitting strings (string_split).
object RegexSplitMode extends RegexMode

/**
 * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
 * if this is not possible.
 *
 * @param mode the mode of operation the pattern will be used in: find (rlike),
 *             replace (regexp_replace), or split (string_split), since some
 *             constructs are only unsupported in certain modes
 */
class CudfRegexTranspiler(replace: Boolean) {
class CudfRegexTranspiler(mode: RegexMode) {

// cuDF throws a "nothing to repeat" exception for many of the edge cases that are
// rejected by the transpiler
Expand Down Expand Up @@ -467,6 +472,8 @@ class CudfRegexTranspiler(replace: Boolean) {
case '$' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4533
throw new RegexUnsupportedException("line anchor $ is not supported")
case '^' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
case _ =>
regex
}
Expand Down Expand Up @@ -494,8 +501,14 @@ class CudfRegexTranspiler(replace: Boolean) {
case 's' | 'S' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4528
throw new RegexUnsupportedException("whitespace classes are not supported")
case 'A' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
case 'Z' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\Z is not supported in split mode")
case 'z' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\z is not supported in split mode")
case 'z' =>
if (replace) {
if (mode == RegexReplaceMode) {
// see https://github.com/NVIDIA/spark-rapids/issues/4425
throw new RegexUnsupportedException(
"string anchor \\z is not supported in replace mode")
Expand Down Expand Up @@ -590,7 +603,7 @@ class CudfRegexTranspiler(replace: Boolean) {
RegexSequence(parts.map(rewrite))

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) =>
case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
// example: pattern " ?", input "] b[", replace with "X":
// java: X]XXbX[X
// cuDF: XXXX] b[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ class GpuRLikeMeta(
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(replace = false).transpile(str.toString))
pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Expand Down Expand Up @@ -946,7 +946,7 @@ class GpuRegExpExtractMeta(
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val cudfRegexPattern = new CudfRegexTranspiler(replace = false)
val cudfRegexPattern = new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern)
pattern = Some(cudfRegexPattern)
numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
Expand Down Expand Up @@ -1289,51 +1289,67 @@ class GpuStringSplitMeta(
extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
import GpuOverrides._

private var pattern: String = _
private var isRegExp = false

override def tagExprForGpu(): Unit = {
val regexp = extractLit(expr.regex)
if (regexp.isEmpty) {
willNotWorkOnGpu("only literal regexp values are supported")
} else {
val str = regexp.get.value.asInstanceOf[UTF8String]
if (str != null) {
if (RegexParser.isRegExpString(str.toString)) {
willNotWorkOnGpu("regular expressions are not supported yet")
}
if (str.numChars() == 0) {
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
willNotWorkOnGpu("An empty regex is not supported yet")
}
isRegExp = RegexParser.isRegExpString(str.toString)
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
if (isRegExp) {
try {
pattern = new CudfRegexTranspiler(RegexSplitMode).transpile(str.toString)
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}
} else {
pattern = str.toString
}
} else {
willNotWorkOnGpu("null regex is not supported yet")
}
}
if (!isLit(expr.limit)) {
willNotWorkOnGpu("only literal limit is supported")
extractLit(expr.limit) match {
case Some(Literal(n: Int, _)) =>
if (n == 0 || n == 1) {
// https://github.com/NVIDIA/spark-rapids/issues/4720
willNotWorkOnGpu("limit of 0 or 1 is not supported")
}
case _ =>
willNotWorkOnGpu("only literal limit is supported")
}
}
override def convertToGpu(
str: Expression,
regexp: Expression,
limit: Expression): GpuExpression =
GpuStringSplit(str, regexp, limit)
GpuStringSplit(str, regexp, limit, isRegExp, pattern)
}

case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
extends GpuTernaryExpression with ImplicitCastInputTypes {
case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
isRegExp: Boolean, pattern: String)
Comment on lines +1374 to +1375
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of the regex expression completely? It is now useless since we use pattern instead.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this can be achieved if you override columnarEval instead of doColumnar similar to https://github.com/NVIDIA/spark-rapids/pull/4636/files#diff-a12810882b81a4eb395c03a80951f96ec080db793ffed6755739eeb2122840ccR1432

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be switching this from a GpuTernaryExpression to a GpuBinaryExpression. I personally don't see it as a big deal either way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to stick with GpuTernaryExpression to match Spark

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't have to, right? I don't see any benefit from keeping it a TernaryExpr instead of just UnaryExpr/GpuExpr. I tried to implement GpuStringToMap to inherit GpuExpression and the evaluation function is super short: https://github.com/NVIDIA/spark-rapids/pull/4636/files#diff-a12810882b81a4eb395c03a80951f96ec080db793ffed6755739eeb2122840ccR1507-R1518

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically we have evaluated the literal delimiter pattern before calling to the Gpu override, thus we only pass in ONE input string expression.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't I still need to pass in all of the expressions though so that I can implement children() correctly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the original delimiter expression still needs to be passed in to initialize children, but it is not used anywhere in the evaluation later on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The delimiter expression already isn't used in the evaluation. It is only referenced in override def second: Expression = regex which is just used to construct children in final def children: Seq[Expression] = IndexedSeq(first, second, third).

I'm not against making the change and am curious to see what the benefits are but I would rather do this as a follow-on issue and review how similar regexp expressions are implemented since they all follow this same pattern.

extends GpuTernaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))

override def prettyName: String = "split"

override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
limit: GpuScalar): ColumnVector = {
val intLimit = limit.getValue.asInstanceOf[Int]
str.getBase.stringSplitRecord(regex.getBase, intLimit)
str.getBase.stringSplitRecord(pattern, intLimit, isRegExp)
}

override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
Expand Down
Loading