diff --git a/docs/configs.md b/docs/configs.md
index 4720050e5f8..d667b03f0d3 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -183,7 +183,9 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.SpecifiedWindowFrame| |specification of the width of the group (or "frame") of input rows around which a window function is evaluated|true|None|
spark.rapids.sql.expression.Sqrt|`sqrt`|square root|true|None|
spark.rapids.sql.expression.StartsWith| |Starts With|true|None|
+spark.rapids.sql.expression.StringLPad|`lpad`|Pad a string on the left|true|None|
spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
+spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
spark.rapids.sql.expression.StringTrimLeft|`ltrim`|StringTrimLeft operator|true|None|
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index cdd637ca122..b8230dde452 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -160,7 +160,7 @@ def with_special_pattern(self, pattern, flags=0, charset=sre_yield.CHARSET, weig
length = int(len(strs))
except OverflowError:
length = _MAX_CHOICES
- return self.with_special_case(lambda rand : strs[rand.randint(0, length)], weight=weight)
+ return self.with_special_case(lambda rand : strs[rand.randrange(0, length)], weight=weight)
def start(self, rand):
strs = self.base_strs
@@ -168,7 +168,7 @@ def start(self, rand):
length = int(len(strs))
except OverflowError:
length = _MAX_CHOICES
- self._start(rand, lambda : strs[rand.randint(0, length)])
+ self._start(rand, lambda : strs[rand.randrange(0, length)])
_BYTE_MIN = -(1 << 7)
_BYTE_MAX = (1 << 7) - 1
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 743ced11ceb..7cc03295dc8 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -35,6 +35,28 @@ def test_substring_index(data_gen,delim):
f.substring_index(f.col('a'), delim, -1),
f.substring_index(f.col('a'), delim, -4)))
+# ONLY LITERAL WIDTH AND PAD ARE SUPPORTED
+def test_lpad():
+ gen = mk_str_gen('.{0,5}')
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, gen).selectExpr(
+ 'LPAD(a, 2, " ")',
+ 'LPAD(a, NULL, " ")',
+ 'LPAD(a, 5, NULL)',
+ 'LPAD(a, 5, "G")',
+ 'LPAD(a, -1, "G")'))
+
+# ONLY LITERAL WIDTH AND PAD ARE SUPPORTED
+def test_rpad():
+ gen = mk_str_gen('.{0,5}')
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, gen).selectExpr(
+ 'RPAD(a, 2, " ")',
+ 'RPAD(a, NULL, " ")',
+ 'RPAD(a, 5, NULL)',
+ 'RPAD(a, 5, "G")',
+ 'RPAD(a, -1, "G")'))
+
# ONLY LITERAL SEARCH PARAMS ARE SUPPORTED
def test_position():
gen = mk_str_gen('.{0,3}Z_Z.{0,3}A.{0,3}')
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 8de6e9ffeac..81fd95f9286 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1282,6 +1282,50 @@ object GpuOverrides {
override def convertToGpu(child: Expression): GpuExpression = GpuLower(child)
})
.incompat(CASE_MODIFICATION_INCOMPAT),
+ expr[StringLPad](
+ "Pad a string on the left",
+ (in, conf, p, r) => new TernaryExprMeta[StringLPad](in, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ if (!isLit(in.len)) {
+ willNotWorkOnGpu("only literal length is supported")
+ }
+
+ val padLit = extractLit(in.pad)
+ if (padLit.isEmpty) {
+ willNotWorkOnGpu("only literal pad is supported")
+ } else if (padLit.get.value != null &&
+ padLit.get.value.asInstanceOf[UTF8String].toString.length != 1) {
+ willNotWorkOnGpu("only a single character is supported for pad")
+ }
+ }
+ override def convertToGpu(
+ str: Expression,
+ width: Expression,
+ pad: Expression): GpuExpression =
+ GpuStringLPad(str, width, pad)
+ }),
+ expr[StringRPad](
+ "Pad a string on the right",
+ (in, conf, p, r) => new TernaryExprMeta[StringRPad](in, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ if (!isLit(in.len)) {
+ willNotWorkOnGpu("only literal length is supported")
+ }
+
+ val padLit = extractLit(in.pad)
+ if (padLit.isEmpty) {
+ willNotWorkOnGpu("only literal pad is supported")
+ } else if (padLit.get.value != null &&
+ padLit.get.value.asInstanceOf[UTF8String].toString.length != 1) {
+ willNotWorkOnGpu("only a single character is supported for pad")
+ }
+ }
+ override def convertToGpu(
+ str: Expression,
+ width: Expression,
+ pad: Expression): GpuExpression =
+ GpuStringRPad(str, width, pad)
+ }),
expr[StringLocate](
"Substring search operator",
(in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 4805733cee4..1fbfe8cb571 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids
import scala.collection.mutable.ArrayBuffer
-import ai.rapids.cudf.{ColumnVector, Scalar, Table}
+import ai.rapids.cudf.{ColumnVector, DType, PadSide, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -696,3 +696,76 @@ case class GpuSubstringIndex(strExpr: Expression,
"Internal Error: this version of substring index is not supported")
}
+trait BasePad extends GpuTernaryExpression with ImplicitCastInputTypes with NullIntolerant {
+ val str: Expression
+ val len: Expression
+ val pad: Expression
+ val direction: PadSide
+
+ override def children: Seq[Expression] = str :: len :: pad :: Nil
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
+
+ override def doColumnar(str: GpuColumnVector, len: Scalar, pad: Scalar): GpuColumnVector = {
+ if (len.isValid && pad.isValid) {
+ val l = math.max(0, len.getInt)
+ withResource(str.getBase.pad(l, direction, pad.getJavaString)) { padded =>
+ GpuColumnVector.from(padded.substring(0, l))
+ }
+ } else {
+ withResource(Scalar.fromNull(DType.STRING)) { ns =>
+ GpuColumnVector.from(ColumnVector.fromScalar(ns, str.getRowCount.toInt))
+ }
+ }
+ }
+
+ override def doColumnar(
+ str: GpuColumnVector,
+ len: GpuColumnVector,
+ pad: GpuColumnVector): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(
+ str: Scalar,
+ len: GpuColumnVector,
+ pad: GpuColumnVector): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(str: Scalar, len: Scalar, pad: GpuColumnVector): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(str: Scalar, len: GpuColumnVector, pad: Scalar): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(
+ str: GpuColumnVector,
+ len: Scalar,
+ pad: GpuColumnVector): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(
+ str: GpuColumnVector,
+ len: GpuColumnVector,
+ pad: Scalar): GpuColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+}
+
+case class GpuStringLPad(str: Expression, len: Expression, pad: Expression)
+ extends BasePad {
+ val direction = PadSide.LEFT
+ override def prettyName: String = "lpad"
+
+ def this(str: Expression, len: Expression) = {
+ this(str, len, GpuLiteral(" ", StringType))
+ }
+}
+
+case class GpuStringRPad(str: Expression, len: Expression, pad: Expression)
+ extends BasePad {
+ val direction = PadSide.RIGHT
+ override def prettyName: String = "rpad"
+
+ def this(str: Expression, len: Expression) = {
+ this(str, len, GpuLiteral(" ", StringType))
+ }
+}
\ No newline at end of file