diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 54f46b86b97..363821a835c 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -35,7 +35,8 @@ def test_split_no_limit():
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB")',
             'split(a, "C")',
-            'split(a, "_")'))
+            'split(a, "_")'),
+        conf=_regexp_conf)
 
 def test_split_negative_limit():
     data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
@@ -43,7 +44,8 @@ def test_split_negative_limit():
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB", -1)',
             'split(a, "C", -2)',
-            'split(a, "_", -999)'))
+            'split(a, "_", -999)'),
+        conf=_regexp_conf)
 
 # https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
@@ -52,6 +54,7 @@ def test_split_zero_limit_fallback():
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB", 0)'),
+        conf=_regexp_conf,
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
@@ -62,6 +65,7 @@ def test_split_one_limit_fallback():
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB", 1)'),
+        conf=_regexp_conf,
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
@@ -84,7 +88,8 @@ def test_split_re_negative_limit():
             'split(a, "[^o]", -1)',
             'split(a, "[o]{1,2}", -1)',
             'split(a, "[bf]", -1)',
-            'split(a, "[o]", -2)'))
+            'split(a, "[o]", -2)'),
+        conf=_regexp_conf)
 
 # https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
@@ -123,7 +128,8 @@ def test_split_re_positive_limit():
             'split(a, "[^o]", 55)',
             'split(a, "[o]{1,2}", 999)',
             'split(a, "[bf]", 2)',
-            'split(a, "[o]", 5)'))
+            'split(a, "[o]", 5)'),
+        conf=_regexp_conf)
 
 def test_split_re_no_limit():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
@@ -136,7 +142,8 @@ def test_split_re_no_limit():
             'split(a, "[^o]")',
             'split(a, "[o]{1,2}")',
             'split(a, "[bf]")',
-            'split(a, "[o]")'))
+            'split(a, "[o]")'),
+        conf=_regexp_conf)
 
 def test_split_optimized_no_re():
     data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|{}]{1,2}){1,7}') \
@@ -158,9 +165,8 @@ def test_split_optimized_no_re():
             'split(a, "\\\\|")',
             'split(a, "\\\\{")',
             'split(a, "\\\\}")',
-            'split(a, "\\\\$\\\\|")',
-            )
-        )
+            'split(a, "\\\\$\\\\|")'),
+        conf=_regexp_conf)
 
 def test_split_optimized_no_re_combined():
     data_gen = mk_str_gen('([bf]o{0,2}[AZ.?+\\^$|{}]{1,2}){1,7}') \
@@ -180,10 +186,53 @@ def test_split_optimized_no_re_combined():
             'split(a, "A\\\\$Z")',
             'split(a, "A\\\\|Z")',
             'split(a, "\\\\{Z")',
-            'split(a, "\\\\}Z")',
-        )
+            'split(a, "\\\\}Z")'),
+        conf=_regexp_conf)
+
+def test_split_regexp_disabled_no_fallback():
+    conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
+    data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|&_]{1,2}){1,7}') \
+        .with_special_case('boo.and.foo') \
+        .with_special_case('boo?and?foo') \
+        .with_special_case('boo+and+foo') \
+        .with_special_case('boo^and^foo') \
+        .with_special_case('boo$and$foo') \
+        .with_special_case('boo|and|foo') \
+        .with_special_case('boo&and&foo') \
+        .with_special_case('boo_and_foo')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "\\\\.")',
+            'split(a, "\\\\?")',
+            'split(a, "\\\\+")',
+            'split(a, "\\\\^")',
+            'split(a, "\\\\$")',
+            'split(a, "\\\\|")',
+            'split(a, "&")',
+            'split(a, "_")',
+        ), conf
     )
+
+@allow_non_gpu('ProjectExec', 'StringSplit')
+def test_split_regexp_disabled_fallback():
+    conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
+    assert_gpu_sql_fallback_collect(
+        lambda spark : unary_op_df(spark, data_gen),
+        'StringSplit',
+        'string_split_table',
+        'select ' +
+        'split(a, "[:]", 2), ' +
+        'split(a, "[o:]", 5), ' +
+        'split(a, "[^:]", 2), ' +
+        'split(a, "[^o]", 55), ' +
+        'split(a, "[o]{1,2}", 999), ' +
+        'split(a, "[bf]", 2), ' +
+        'split(a, "[o]", 5) from string_split_table',
+        conf)
+
 
 @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
     (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
     (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 2f4114497cf..8b5ce42c899 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1414,10 +1414,22 @@ class GpuStringSplitMeta(
   extends StringSplitRegExpMeta[StringSplit](expr, conf, parent, rule) {
   import GpuOverrides._
 
-  private var delimInfo: Option[(String, Boolean)] = None
+  private var pattern = ""
+  private var isRegExp = false
 
   override def tagExprForGpu(): Unit = {
-    delimInfo = checkRegExp(expr.regex)
+    checkRegExp(expr.regex) match {
+      case Some((p, isRe)) =>
+        pattern = p
+        isRegExp = isRe
+      case _ => throwUncheckedDelimiterException()
+    }
+
+    // if this is a valid regular expression, then we should check the configuration
+    // to see whether to run this on the GPU
+    if (isRegExp) {
+      GpuRegExpUtils.tagForRegExpEnabled(this)
+    }
 
     extractLit(expr.limit) match {
       case Some(Literal(n: Int, _)) =>
@@ -1434,8 +1446,7 @@ class GpuStringSplitMeta(
       str: Expression,
       regexp: Expression,
       limit: Expression): GpuExpression = {
-    val delim: (String, Boolean) = delimInfo.getOrElse(throwUncheckedDelimiterException())
-    GpuStringSplit(str, regexp, limit, delim._1, delim._2)
+    GpuStringSplit(str, regexp, limit, pattern, isRegExp)
   }
 }
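
Note on the test-side conf: the updated tests above pass conf=_regexp_conf, a module-level dict defined earlier in string_test.py and not shown in this diff. A minimal sketch of what it is assumed to contain, based on the spark.rapids.sql.regexp.enabled key that the two new tests flip to 'false':

    # Assumed definition (not part of this diff): opts the split tests in to the
    # GPU regular-expression support that GpuRegExpUtils.tagForRegExpEnabled checks.
    _regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' }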
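
To reproduce the new fallback behavior outside the test harness, a hedged sketch, assuming a PySpark session with the RAPIDS Accelerator on the classpath (the session setup below is illustrative; only the spark.rapids.sql.regexp.enabled key and the fallback behavior come from this diff):

    from pyspark.sql import SparkSession

    # Illustrative setup: assumes the RAPIDS Accelerator jar is on the classpath.
    spark = (SparkSession.builder
        .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
        .config("spark.rapids.sql.enabled", "true")
        .config("spark.rapids.sql.regexp.enabled", "false")  # regexp splits -> CPU
        .getOrCreate())

    df = spark.createDataFrame([("boo:and:foo",)], ["a"])
    # "[:]" is a regular expression, so with regexp support disabled this split
    # should plan as ProjectExec (CPU fallback) rather than GpuProjectExec.
    df.selectExpr('split(a, "[:]", 2)').explain()
    # "_" is a plain single-character delimiter, so this split can stay on the GPU.
    df.selectExpr('split(a, "_")').explain()

With the flag off, only patterns that GpuStringSplitMeta can treat as simple delimiters (the cases exercised by test_split_regexp_disabled_no_fallback) remain on the GPU; anything checkRegExp reports as a real regular expression is now tagged off the GPU by GpuRegExpUtils.tagForRegExpEnabled.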