From 21de93590196578a9857dd1a393d3c705ea1ceaf Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 19 Nov 2020 10:39:04 -0600 Subject: [PATCH 1/2] Add in support for GetStructField Signed-off-by: Robert (Bobby) Evans --- docs/configs.md | 1 + .../src/main/python/struct_test.py | 32 +++++++++++ .../nvidia/spark/rapids/GpuOverrides.scala | 14 +++++ .../sql/rapids/complexTypeExtractors.scala | 53 +++++++++++++++++-- 4 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 integration_tests/src/main/python/struct_test.py diff --git a/docs/configs.md b/docs/configs.md index 37636958d67..5d2256035f3 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -143,6 +143,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None| spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None| spark.rapids.sql.expression.GetMapValue| |Gets Value from a Map based on a key|true|None| +spark.rapids.sql.expression.GetStructField| |Gets the named field of the struct|true|None| spark.rapids.sql.expression.GreaterThan|`>`|> operator|true|None| spark.rapids.sql.expression.GreaterThanOrEqual|`>=`|>= operator|true|None| spark.rapids.sql.expression.Greatest|`greatest`|Returns the greatest value of all parameters, skipping null values|true|None| diff --git a/integration_tests/src/main/python/struct_test.py b/integration_tests/src/main/python/struct_test.py new file mode 100644 index 00000000000..e53f3214df1 --- /dev/null +++ b/integration_tests/src/main/python/struct_test.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect +from data_gen import * +from marks import incompat +from pyspark.sql.types import * +import pyspark.sql.functions as f + +@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]), + StructGen([["first", short_gen], ["second", int_gen], ["third", long_gen]]), + StructGen([["first", double_gen], ["second", date_gen], ["third", timestamp_gen]]), + StructGen([["first", string_gen], ["second", ArrayGen(byte_gen)], ["third", simple_string_to_string_map_gen]])], ids=idfn) +def test_struct_get_item(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'a.first', + 'a.second', + 'a.third')) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 6b492c4896c..0e72575530a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1609,6 +1609,20 @@ object GpuOverrides { expr[StringSplit]( "Splits `str` around occurrences that match `regex`", (in, conf, p, r) => new GpuStringSplitMeta(in, conf, p, r)), + + expr[GetStructField]( + "Gets the named field of the struct", + (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) { + override def convertToGpu(arr: Expression): GpuExpression = + GpuGetStructField(arr, expr.ordinal, expr.name) + + override def isSupportedType(t: DataType): Boolean = + GpuOverrides.isSupportedType(t, + allowArray = true, + allowStruct = true, + allowMaps = true, + allowNesting = true) + }), expr[GetArrayItem]( "Gets the field at `ordinal` in the Array", (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 927c77269e6..b5ca69eac76 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -17,12 +17,57 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} -import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.{BinaryExprMeta, ConfKeysAndIncompat, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant} -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType, IntegralType, MapType, StringType, StructType} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType, IntegralType, MapType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) + extends UnaryExpression with GpuExpression with ExtractValue { + + lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] + + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + + override def toString: String = { + val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal" + s"$child.${name.getOrElse(fieldName)}" + } + + override def sql: String = + child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" + + override def columnarEval(batch: ColumnarBatch): Any = { + val input = child.columnarEval(batch) + val dt = dataType + try { + input match { + case cv: GpuColumnVector => + withResource(cv.getBase.getChildColumnView(ordinal)) { view => + GpuColumnVector.from(view.copyToColumnVector(), dt) + } + case null => null + case ir: InternalRow => + // Literal struct values are not currently supported, but just in case... + val tmp = ir.get(ordinal, dt) + withResource(GpuScalar.from(tmp, dt)) { scalar => + GpuColumnVector.from(scalar, batch.numRows(), dt) + } + } + } finally { + input match { + case ac: AutoCloseable => ac.close() + case _ => // NOOP + } + } + } +} class GpuGetArrayItemMeta( expr: GetArrayItem, From c04fa8cee7bf0eac691422ff5ca3c04d556296db Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 19 Nov 2020 15:02:59 -0600 Subject: [PATCH 2/2] Addressed review comments --- .../src/main/scala/com/nvidia/spark/rapids/Arm.scala | 12 ++++++++++++ .../spark/sql/rapids/complexTypeExtractors.scala | 12 +++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 77e4b492336..65a73d8a3ba 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -49,6 +49,18 @@ trait Arm { } } + /** Executes the provided code block and then closes the value if it is AutoCloseable */ + def withResourceIfAllowed[T, V](r: T)(block: T => V): V = { + try { + block(r) + } finally { + r match { + case c: AutoCloseable => c.close() + case _ => //NOOP + } + } + } + /** Executes the provided code block, closing the resource only if an exception occurs */ def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V = { try { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index b5ca69eac76..e8236d27eff 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, Dat import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with GpuExpression with ExtractValue { + extends UnaryExpression with GpuExpression with ExtractValue with NullIntolerant { lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] @@ -44,9 +44,8 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" override def columnarEval(batch: ColumnarBatch): Any = { - val input = child.columnarEval(batch) - val dt = dataType - try { + withResourceIfAllowed(child.columnarEval(batch)) { input => + val dt = dataType input match { case cv: GpuColumnVector => withResource(cv.getBase.getChildColumnView(ordinal)) { view => @@ -60,11 +59,6 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin GpuColumnVector.from(scalar, batch.numRows(), dt) } } - } finally { - input match { - case ac: AutoCloseable => ac.close() - case _ => // NOOP - } } } }