diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 0e3c5392475..743ced11ceb 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -95,23 +95,12 @@ def test_rtrim(): 'TRIM(TRAILING NULL FROM a)', 'TRIM(TRAILING "" FROM a)')) -# Once https://github.com/NVIDIA/spark-rapids/issues/112 is fixed this should be -# deleted and the corresponding lines in the other tests should be uncommented -@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/112') -def test_start_n_ends_with_xfail(): - gen = mk_str_gen('[Ab\ud720]{3}A.{0,3}Z[Ab\ud720]{3}') - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen).select( - f.col('a').startswith(''), - f.col('a').endswith(''))) - def test_startswith(): gen = mk_str_gen('[Ab\ud720]{3}A.{0,3}Z[Ab\ud720]{3}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).select( f.col('a').startswith('A'), - #https://github.com/NVIDIA/spark-rapids/issues/112 - #f.col('a').startswith(''), + f.col('a').startswith(''), f.col('a').startswith(None), f.col('a').startswith('A\ud720'))) @@ -121,8 +110,7 @@ def test_endswith(): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).select( f.col('a').endswith('A'), - #https://github.com/NVIDIA/spark-rapids/issues/112 - #f.col('a').endswith(''), + f.col('a').endswith(''), f.col('a').endswith(None), f.col('a').endswith('A\ud720'))) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 4db5d92a851..4805733cee4 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -171,18 +171,8 @@ case class GpuStartsWith(left: Expression, right: Expression) override def toString: String = s"gpustartswith($left, $right)" - def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = { - if (rhs.getJavaString.isEmpty) { - val boolScalar = Scalar.fromBool(true) - try { - GpuColumnVector.from(ColumnVector.fromScalar(boolScalar, lhs.getRowCount.toInt)) - } finally { - boolScalar.close() - } - } else { - GpuColumnVector.from(lhs.getBase.startsWith(rhs)) - } - } + def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = + GpuColumnVector.from(lhs.getBase.startsWith(rhs)) override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector = throw new IllegalStateException( @@ -206,18 +196,8 @@ case class GpuEndsWith(left: Expression, right: Expression) override def toString: String = s"gpuendswith($left, $right)" - def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = { - if (rhs.getJavaString.isEmpty) { - val boolScalar = Scalar.fromBool(true) - try { - GpuColumnVector.from(ColumnVector.fromScalar(boolScalar, lhs.getRowCount.toInt)) - } finally { - boolScalar.close() - } - } else { - GpuColumnVector.from(lhs.getBase.endsWith(rhs)) - } - } + def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = + GpuColumnVector.from(lhs.getBase.endsWith(rhs)) override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector = throw new IllegalStateException( @@ -340,24 +320,8 @@ case class GpuContains(left: Expression, right: Expression) extends GpuBinaryExp override def toString: String = s"gpucontains($left, $right)" - def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = { - val ret = if (rhs.getJavaString.isEmpty) { - withResource(Scalar.fromBool(true)) { trueScalar => - if (left.nullable) { - withResource(Scalar.fromBool(null)) { nullBool => - withResource(lhs.getBase.isNull) { isNull => - isNull.ifElse(nullBool, trueScalar) - } - } - } else { - ColumnVector.fromScalar(trueScalar, lhs.getRowCount.toInt) - } - } - } else { - lhs.getBase.stringContains(rhs) - } - GpuColumnVector.from(ret) - } + def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = + GpuColumnVector.from(lhs.getBase.stringContains(rhs)) override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector = throw new IllegalStateException("Really should not be here, " +