diff --git a/docs/configs.md b/docs/configs.md
index a203e0df9ca..c11b427f4fd 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -129,6 +129,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Cos|`cos`|Cosine|true|None|
spark.rapids.sql.expression.Cosh|`cosh`|Hyperbolic cosine|true|None|
spark.rapids.sql.expression.Cot|`cot`|Cotangent|true|None|
+spark.rapids.sql.expression.CreateNamedStruct|`named_struct`, `struct`|Creates a struct with the given field names and values.|true|None|
spark.rapids.sql.expression.CurrentRow$| |Special boundary for a window frame, indicating stopping at the current row|true|None|
spark.rapids.sql.expression.DateAdd|`date_add`|Returns the date that is num_days after start_date|true|None|
spark.rapids.sql.expression.DateDiff|`datediff`|Returns the number of days from startDate to endDate|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 0f87c1e3151..809f232edd5 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -3283,6 +3283,138 @@ Accelerator support is described below.
|
+CreateNamedStruct |
+`named_struct`, `struct` |
+Creates a struct with the given field names and values. |
+None |
+project |
+name |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+value |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S* |
+S |
+S* |
+S |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+ |
+
+
+lambda |
+name |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+value |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+
+
CurrentRow$ |
|
Special boundary for a window frame, indicating stopping at the current row |
diff --git a/integration_tests/src/main/python/struct_test.py b/integration_tests/src/main/python/struct_test.py
index 604572c0f6e..004550049cd 100644
--- a/integration_tests/src/main/python/struct_test.py
+++ b/integration_tests/src/main/python/struct_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,3 +31,11 @@ def test_struct_get_item(data_gen):
'a.first',
'a.second',
'a.third'))
+
+@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
+def test_make_struct(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : binary_op_df(spark, data_gen).selectExpr(
+ 'struct(a, b)',
+ 'named_struct("foo", b, "bar", 5, "end", a)'))
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 70258460a15..1e07af4f232 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1891,6 +1891,13 @@ object GpuOverrides {
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
+ expr[CreateNamedStruct](
+ "Creates a struct with the given field names and values.",
+ CreateNamedStructCheck,
+ (in, conf, p, r) => new ExprMeta[CreateNamedStruct](in, conf, p, r) {
+ override def convertToGpu(): GpuExpression =
+ GpuCreateNamedStruct(childExprs.map(_.convertToGpu()))
+ }),
expr[StringLocate](
"Substring search operator",
ExprChecks.projectNotLambda(TypeSig.INT, TypeSig.INT,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 0cef1ddeb21..264415dbb72 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -669,6 +669,61 @@ object WindowSpecCheck extends ExprChecks {
}
}
+/**
+ * A check for CreateNamedStruct. The parameter values alternate between one type and another.
+ * If this pattern shows up again we can make this more generic at that point.
+ */
+object CreateNamedStructCheck extends ExprChecks {
+ val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
+ val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
+ val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
+ val sparkValueSig: TypeSig = TypeSig.all
+ val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
+ val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)
+
+ override def tag(meta: RapidsMeta[_, _, _]): Unit = {
+ val exprMeta = meta.asInstanceOf[BaseExprMeta[_]]
+ val context = exprMeta.context
+ if (context != ProjectExprContext) {
+ meta.willNotWorkOnGpu(s"this is not supported in the $context context")
+ } else {
+ val origExpr = exprMeta.wrapped.asInstanceOf[Expression]
+ val (nameExprs, valExprs) = origExpr.children.grouped(2).map {
+ case Seq(name, value) => (name, value)
+ }.toList.unzip
+ nameExprs.foreach { expr =>
+ nameSig.tagExprParam(meta, expr, "name")
+ }
+ valExprs.foreach { expr =>
+ valueSig.tagExprParam(meta, expr, "value")
+ }
+ if (!resultSig.isSupportedByPlugin(origExpr.dataType, meta.conf.decimalTypeEnabled)) {
+ meta.willNotWorkOnGpu(s"unsupported data types in output: ${origExpr.dataType}")
+ }
+ }
+ }
+
+ override def support(dataType: TypeEnum.Value):
+ Map[ExpressionContext, Map[String, SupportLevel]] = {
+ val nameProjectSupport = nameSig.getSupportLevel(dataType, sparkNameSig)
+ val nameLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkNameSig)
+ val valueProjectSupport = valueSig.getSupportLevel(dataType, sparkValueSig)
+ val valueLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkValueSig)
+ val resultProjectSupport = resultSig.getSupportLevel(dataType, sparkResultSig)
+ val resultLambdaSupport = TypeSig.none.getSupportLevel(dataType, sparkResultSig)
+ Map((ProjectExprContext,
+ Map(
+ ("name", nameProjectSupport),
+ ("value", valueProjectSupport),
+ ("result", resultProjectSupport))),
+ (LambdaExprContext,
+ Map(
+ ("name", nameLambdaSupport),
+ ("value", valueLambdaSupport),
+ ("result", resultLambdaSupport))))
+ }
+}
+
class CastChecks extends ExprChecks {
// Don't show this with other operators show it in a different location
override val shown: Boolean = false
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
new file mode 100644
index 00000000000..f134ae4577f
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
+import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableArray, ReallyAGpuExpression}
+
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, NamedExpression}
+import org.apache.spark.sql.types.{Metadata, StringType, StructField, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression {
+ lazy val (nameExprs, valExprs) = children.grouped(2).map {
+ case Seq(name, value) => (name, value)
+ }.toList.unzip
+
+ private lazy val names = nameExprs.map {
+ case g: GpuExpression => g.columnarEval(null)
+ case e => e.eval(EmptyRow)
+ }
+
+ override def nullable: Boolean = false
+
+ override def foldable: Boolean = valExprs.forall(_.foldable)
+
+ override lazy val dataType: StructType = {
+ val fields = names.zip(valExprs).map {
+ case (name, expr) =>
+ val metadata = expr match {
+ case ne: NamedExpression => ne.metadata
+ case _ => Metadata.empty
+ }
+ StructField(name.toString, expr.dataType, expr.nullable, metadata)
+ }
+ StructType(fields)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size % 2 != 0) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
+ } else {
+ val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
+ if (invalidNames.nonEmpty) {
+ TypeCheckResult.TypeCheckFailure(
+ s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" +
+ s" position, got: ${invalidNames.mkString(",")}")
+ } else if (!names.contains(null)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure("Field name should not be null")
+ }
+ }
+ }
+
+ // There is an alias set at `CreateStruct.create`. If there is an alias,
+ // this is the struct function explicitly called by a user and we should
+ // respect it in the SQL string as `struct(...)`.
+ override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")
+
+ override def sql: String = getTagValue(FUNC_ALIAS).map { alias =>
+ val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
+ s"$alias($childrenSQL)"
+ }.getOrElse(super.sql)
+
+ override def columnarEval(batch: ColumnarBatch): Any = {
+ // The names are only used for the type. Here we really just care about the data
+ withResource(new Array[ColumnVector](valExprs.size)) { columns =>
+ val numRows = batch.numRows()
+ valExprs.indices.foreach { index =>
+ valExprs(index).columnarEval(batch) match {
+ case cv: GpuColumnVector =>
+ columns(index) = cv.getBase
+ case other =>
+ val dt = dataType.fields(index).dataType
+ withResource(GpuScalar.from(other, dt)) { scalar =>
+ columns(index) = ColumnVector.fromScalar(scalar, numRows)
+ }
+ }
+ }
+ GpuColumnVector.from(ColumnVector.makeStruct(numRows, columns: _*), dataType)
+ }
+ }
+}
\ No newline at end of file