diff --git a/docs/configs.md b/docs/configs.md
index 208a10c09eb..bdc3cb692c2 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -128,6 +128,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Expm1|`expm1`|Euler's number e raised to a power minus 1|true|None|
spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
+spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
spark.rapids.sql.expression.GreaterThan|`>`|> operator|true|None|
spark.rapids.sql.expression.GreaterThanOrEqual|`>=`|>= operator|true|None|
spark.rapids.sql.expression.Hour|`hour`|Returns the hour component of the string/timestamp|true|None|
@@ -186,6 +187,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
+spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
spark.rapids.sql.expression.StringTrimLeft|`ltrim`|StringTrimLeft operator|true|None|
spark.rapids.sql.expression.StringTrimRight|`rtrim`|StringTrimRight operator|true|None|
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 7cc03295dc8..372f4272350 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -23,6 +23,20 @@
def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
+# Because of limitations in array support, `split` and array indexing have to be tested
+# together for now. This should be split up into separate tests once support is better.
+def test_split_with_array_index():
+ data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
+ delim = '_'
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr(
+ 'split(a, "AB")[0]',
+ 'split(a, "_")[1]',
+ 'split(a, "_")[null]',
+ 'split(a, "_")[3]',
+ 'split(a, "_")[0]',
+ 'split(a, "_")[-1]'))
+
@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
(mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
(mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index c69c13527f3..8cc1aca2319 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -16,6 +16,7 @@
package com.nvidia.spark.rapids;
+import ai.rapids.cudf.ColumnViewAccess;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Scalar;
@@ -192,12 +193,22 @@ protected static final DataType getSparkType(DType type) {
case TIMESTAMP_DAYS:
return DataTypes.DateType;
case TIMESTAMP_MICROSECONDS:
- return DataTypes.TimestampType; // TODO need to verify that the TimeUnits are correct
+ return DataTypes.TimestampType;
case STRING:
return DataTypes.StringType;
default:
throw new IllegalArgumentException(type + " is not supported by spark yet.");
+ }
+ }
+  protected static final DataType getSparkTypeFrom(ColumnViewAccess access) {
+    DType cudfType = access.getDataType();
+    if (cudfType != DType.LIST) { // non-list types map straight through getSparkType
+      return getSparkType(cudfType);
+    } else { // a list becomes an ArrayType whose element type is resolved from child 0
+      try (ColumnViewAccess elements = access.getChildColumnViewAccess(0)) {
+        return new ArrayType(getSparkTypeFrom(elements), true);
+      }
     }
   }
@@ -300,7 +311,7 @@ public static final ColumnarBatch from(Table table, int startColIndex, int until
* but not both.
*/
public static final GpuColumnVector from(ai.rapids.cudf.ColumnVector cudfCv) {
- return new GpuColumnVector(getSparkType(cudfCv.getType()), cudfCv);
+ return new GpuColumnVector(getSparkTypeFrom(cudfCv), cudfCv);
}
public static final GpuColumnVector from(Scalar scalar, int count) {
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java
index abb95ca8ce6..ab8335eea72 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java
@@ -49,7 +49,7 @@ public static ColumnarBatch from(ContiguousTable contigTable) {
try {
for (int i = 0; i < numColumns; ++i) {
ColumnVector v = table.getColumn(i);
- DataType type = getSparkType(v.getType());
+ DataType type = getSparkTypeFrom(v);
columns[i] = new GpuColumnVectorFromBuffer(type, v.incRefCount(), buffer);
}
return new ColumnarBatch(columns, (int) rows);
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 1e3403a9005..d57f3e95d21 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -336,6 +336,11 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")
+  def canRegexpBeTreatedLikeARegularString(strLit: UTF8String): Boolean = {
+    val text = strLit.toString // checked against every token in regexList
+    regexList.forall(token => !text.contains(token)) // true only when no token occurs
+  }
+
@scala.annotation.tailrec
def extractLit(exp: Expression): Option[Literal] = exp match {
case l: Literal => Some(l)
@@ -1328,6 +1333,12 @@ object GpuOverrides {
pad: Expression): GpuExpression =
GpuStringRPad(str, width, pad)
}),
+ expr[StringSplit](
+ "Splits `str` around occurrences that match `regex`",
+ (in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)),
+ expr[GetArrayItem](
+ "Gets the field at `ordinal` in the Array",
+ (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
expr[StringLocate](
"Substring search operator",
(in, conf, p, r) => new TernaryExprMeta[StringLocate](in, conf, p, r) {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
new file mode 100644
index 00000000000..bf6473fd8ba
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf.{ColumnVector, Scalar}
+import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta}
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType, IntegralType}
+
+/** Tags `GetArrayItem` for GPU support and converts it to `GpuGetArrayItem`. */
+class GpuGetArrayItemMeta(
+    expr: GetArrayItem,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _, _]],
+    rule: ConfKeysAndIncompat)
+  extends BinaryExprMeta[GetArrayItem](expr, conf, parent, rule) {
+  import GpuOverrides._
+
+  // A non-literal ordinal marks the expression as unable to run on the GPU.
+  override def tagExprForGpu(): Unit =
+    if (!isLit(expr.ordinal)) willNotWorkOnGpu("only literal ordinals are supported")
+
+  override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
+    GpuGetArrayItem(arr, ordinal)
+
+  // For now only one level of array is supported: an array is judged by its element type.
+  def isSupported(t: DataType) = isSupportedType(t match {
+    case arrayType: ArrayType => arrayType.elementType
+    case other => other
+  })
+
+  // Every type passed in has to satisfy isSupported.
+  override def areAllSupportedTypes(types: DataType*): Boolean =
+    types.forall(isSupported)
+}
+
+/**
+ * GPU counterpart of `GetArrayItem`: produces the element at `ordinal` of the array `child`.
+ *
+ * Only an array column combined with a scalar ordinal is implemented; other forms throw.
+ */
+case class GpuGetArrayItem(child: Expression, ordinal: Expression)
+  extends GpuBinaryExpression with ExpectsInputTypes with ExtractValue {
+
+  override def left: Expression = child
+  override def right: Expression = ordinal
+  override def toString: String = s"$child[$ordinal]"
+  override def sql: String = s"${child.sql}[${ordinal.sql}]"
+
+  // `ExtractValue` already type checks `child`, so only `ordinal` is constrained here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
+  override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
+  // Always nullable for now. Eventually we need something more full featured like
+  // GetArrayItemUtil.computeNullabilityFromArray
+  override def nullable: Boolean = true
+
+  override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): GpuColumnVector =
+    throw new IllegalStateException("This is not supported yet")
+
+  override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): GpuColumnVector =
+    throw new IllegalStateException("This is not supported yet")
+
+  override def doColumnar(lhs: GpuColumnVector, ordinal: Scalar): GpuColumnVector = {
+    // A null or negative ordinal yields a column of nulls with as many rows as `lhs`.
+    val usableOrdinal = ordinal.isValid && ordinal.getInt >= 0
+    if (!usableOrdinal) {
+      withResource(Scalar.fromNull(GpuColumnVector.getRapidsType(dataType))) { nullScalar =>
+        GpuColumnVector.from(ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt))
+      }
+    } else {
+      GpuColumnVector.from(lhs.getBase.extractListElement(ordinal.getInt))
+    }
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index ba7e5863226..a485b48303f 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -22,7 +22,7 @@ import ai.rapids.cudf.{ColumnVector, DType, PadSide, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, SubstringIndex}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, Predicate, StringSplit, SubstringIndex}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -768,4 +768,96 @@ case class GpuStringRPad(str: Expression, len: Expression, pad: Expression)
def this(str: Expression, len: Expression) = {
this(str, len, GpuLiteral(" ", StringType))
}
+}
+
+/** Tags `StringSplit` for GPU support and converts it to `GpuStringSplit`. */
+class GpuStringSplitMeta(
+    expr: StringSplit,
+    conf: RapidsConf,
+    parent: Option[RapidsMeta[_, _, _]],
+    rule: ConfKeysAndIncompat)
+  extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
+  import GpuOverrides._
+
+  override def tagExprForGpu(): Unit = {
+    // The pattern must be a non-null, non-empty literal without regex tokens.
+    extractLit(expr.regex).map(_.value) match {
+      case None => willNotWorkOnGpu("only literal regexp values are supported")
+      case Some(null) => willNotWorkOnGpu("null regex is not supported yet")
+      case Some(value) =>
+        val pattern = value.asInstanceOf[UTF8String]
+        if (!canRegexpBeTreatedLikeARegularString(pattern)) {
+          willNotWorkOnGpu("regular expressions are not supported yet")
+        }
+        if (pattern.numChars() == 0) {
+          willNotWorkOnGpu("An empty regex is not supported yet")
+        }
+    }
+    // GpuStringSplit reads the limit with getInt, so a null literal limit has to be rejected.
+    if (!isLit(expr.limit)) {
+      willNotWorkOnGpu("only literal limit is supported")
+    } else if (extractLit(expr.limit).exists(_.value == null)) {
+      willNotWorkOnGpu("null limit is not supported yet")
+    }
+  }
+
+  override def convertToGpu(
+      str: Expression, regexp: Expression, limit: Expression): GpuExpression =
+    GpuStringSplit(str, regexp, limit)
+
+  // For now we support all of the possible input and output types for this operator
+  override def areAllSupportedTypes(types: DataType*): Boolean = true
+}
+
+/**
+ * GPU version of `split`. Only a column `str` combined with scalar `regex` and `limit` is
+ * implemented; every other operand combination throws an IllegalStateException.
+ */
+case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
+  extends GpuTernaryExpression with ImplicitCastInputTypes {
+
+  // Two-argument form: the limit defaults to -1.
+  def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))
+
+  override def prettyName: String = "split"
+  override def children: Seq[Expression] = str :: regex :: limit :: Nil
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  override def dataType: DataType = ArrayType(StringType)
+
+  override def doColumnar(str: GpuColumnVector, regex: Scalar, limit: Scalar): GpuColumnVector =
+    GpuColumnVector.from(str.getBase.stringSplitRecord(regex, limit.getInt))
+
+  // Shared failure for the operand combinations that are not implemented.
+  private def notSupportedYet(): Nothing =
+    throw new IllegalStateException("This is not supported yet")
+
+  override def doColumnar(
+      str: GpuColumnVector,
+      regex: GpuColumnVector,
+      limit: GpuColumnVector): GpuColumnVector = notSupportedYet()
+
+  override def doColumnar(
+      str: Scalar,
+      regex: GpuColumnVector,
+      limit: GpuColumnVector): GpuColumnVector = notSupportedYet()
+
+  override def doColumnar(
+      str: Scalar,
+      regex: Scalar,
+      limit: GpuColumnVector): GpuColumnVector = notSupportedYet()
+
+  override def doColumnar(
+      str: Scalar,
+      regex: GpuColumnVector,
+      limit: Scalar): GpuColumnVector = notSupportedYet()
+
+  override def doColumnar(
+      str: GpuColumnVector,
+      regex: Scalar,
+      limit: GpuColumnVector): GpuColumnVector = notSupportedYet()
+
+  override def doColumnar(
+      str: GpuColumnVector,
+      regex: GpuColumnVector,
+      limit: Scalar): GpuColumnVector = notSupportedYet()
+}
\ No newline at end of file