Skip to content

Commit

Permalink
Remove special cases for contains, startsWith, and endWith (NVIDIA#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
revans2 authored Jun 24, 2020
1 parent 5361ef3 commit 941485c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 56 deletions.
16 changes: 2 additions & 14 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,12 @@ def test_rtrim():
'TRIM(TRAILING NULL FROM a)',
'TRIM(TRAILING "" FROM a)'))

# Once https://github.com/NVIDIA/spark-rapids/issues/112 is fixed this should be
# deleted and the corresponding lines in the other tests should be uncommented
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/112')
def test_start_n_ends_with_xfail():
gen = mk_str_gen('[Ab\ud720]{3}A.{0,3}Z[Ab\ud720]{3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).select(
f.col('a').startswith(''),
f.col('a').endswith('')))

def test_startswith():
gen = mk_str_gen('[Ab\ud720]{3}A.{0,3}Z[Ab\ud720]{3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).select(
f.col('a').startswith('A'),
#https://github.com/NVIDIA/spark-rapids/issues/112
#f.col('a').startswith(''),
f.col('a').startswith(''),
f.col('a').startswith(None),
f.col('a').startswith('A\ud720')))

Expand All @@ -121,8 +110,7 @@ def test_endswith():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).select(
f.col('a').endswith('A'),
#https://github.com/NVIDIA/spark-rapids/issues/112
#f.col('a').endswith(''),
f.col('a').endswith(''),
f.col('a').endswith(None),
f.col('a').endswith('A\ud720')))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,8 @@ case class GpuStartsWith(left: Expression, right: Expression)

override def toString: String = s"gpustartswith($left, $right)"

def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = {
if (rhs.getJavaString.isEmpty) {
val boolScalar = Scalar.fromBool(true)
try {
GpuColumnVector.from(ColumnVector.fromScalar(boolScalar, lhs.getRowCount.toInt))
} finally {
boolScalar.close()
}
} else {
GpuColumnVector.from(lhs.getBase.startsWith(rhs))
}
}
def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector =
GpuColumnVector.from(lhs.getBase.startsWith(rhs))

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException(
Expand All @@ -206,18 +196,8 @@ case class GpuEndsWith(left: Expression, right: Expression)

override def toString: String = s"gpuendswith($left, $right)"

def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = {
if (rhs.getJavaString.isEmpty) {
val boolScalar = Scalar.fromBool(true)
try {
GpuColumnVector.from(ColumnVector.fromScalar(boolScalar, lhs.getRowCount.toInt))
} finally {
boolScalar.close()
}
} else {
GpuColumnVector.from(lhs.getBase.endsWith(rhs))
}
}
def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector =
GpuColumnVector.from(lhs.getBase.endsWith(rhs))

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException(
Expand Down Expand Up @@ -340,24 +320,8 @@ case class GpuContains(left: Expression, right: Expression) extends GpuBinaryExp

override def toString: String = s"gpucontains($left, $right)"

def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector = {
val ret = if (rhs.getJavaString.isEmpty) {
withResource(Scalar.fromBool(true)) { trueScalar =>
if (left.nullable) {
withResource(Scalar.fromBool(null)) { nullBool =>
withResource(lhs.getBase.isNull) { isNull =>
isNull.ifElse(nullBool, trueScalar)
}
}
} else {
ColumnVector.fromScalar(trueScalar, lhs.getRowCount.toInt)
}
}
} else {
lhs.getBase.stringContains(rhs)
}
GpuColumnVector.from(ret)
}
def doColumnar(lhs: GpuColumnVector, rhs: Scalar): GpuColumnVector =
GpuColumnVector.from(lhs.getBase.stringContains(rhs))

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
throw new IllegalStateException("Really should not be here, " +
Expand Down

0 comments on commit 941485c

Please sign in to comment.