Support running GetArrayStructFields on GPU[databricks] #4875

Merged: 5 commits, Mar 5, 2022
1 change: 1 addition & 0 deletions docs/configs.md
@@ -199,6 +199,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Floor"></a>spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
<a name="sql.expression.FromUnixTime"></a>spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
<a name="sql.expression.GetArrayItem"></a>spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
<a name="sql.expression.GetArrayStructFields"></a>spark.rapids.sql.expression.GetArrayStructFields| |Extracts the `ordinal`-th fields of all array elements for the data with the type of array of struct|true|None|
<a name="sql.expression.GetJsonObject"></a>spark.rapids.sql.expression.GetJsonObject|`get_json_object`|Extracts a json object from path|true|None|
<a name="sql.expression.GetMapValue"></a>spark.rapids.sql.expression.GetMapValue| |Gets Value from a Map based on a key|true|None|
<a name="sql.expression.GetStructField"></a>spark.rapids.sql.expression.GetStructField| |Gets the named field of the struct|true|None|
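For orientation, here is a minimal spark-shell sketch (not part of the diff; the column and field names are illustrative) of the pattern the new expression covers: extracting one struct field across every element of an array column.

```scala
// Assumes a spark-shell session, where a SparkSession `spark` and
// `import spark.implicits._` are already in scope.
val df = Seq(Seq((1, "a"), (2, "b"))).toDF("arr")   // arr: array<struct<_1:int,_2:string>>

// Field extraction on an array of structs resolves to GetArrayStructFields and
// returns that field of every element as a new array: here [1, 2].
df.selectExpr("arr._1").show()
```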
47 changes: 47 additions & 0 deletions docs/supported_ops.md
@@ -5745,6 +5745,53 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="2">GetArrayStructFields</td>
<td rowSpan="2"> </td>
<td rowSpan="2">Extracts the `ordinal`-th fields of all array elements for the data with the type of array of struct</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">GetJsonObject</td>
<td rowSpan="3">`get_json_object`</td>
<td rowSpan="3">Extracts a json object from path</td>
9 changes: 9 additions & 0 deletions integration_tests/src/main/python/array_test.py
@@ -259,3 +259,12 @@ def test_array_max(data_gen):
def test_sql_array_scalars(query):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.sql('SELECT {}'.format(query)))


@pytest.mark.parametrize('data_gen', all_basic_gens + nested_gens_sample, ids=idfn)
def test_get_array_struct_fields(data_gen):
    array_struct_gen = ArrayGen(
        StructGen([['child0', data_gen], ['child1', int_gen]]),
        max_length=6)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, array_struct_gen).selectExpr('a.child0'))
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/data_gen.py
@@ -932,6 +932,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
    MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
    MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]

nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens

ansi_enabled_conf = {'spark.sql.ansi.enabled': 'true'}
no_nans_conf = {'spark.rapids.sql.hasNans': 'false'}

@@ -3365,7 +3365,20 @@ object GpuOverrides extends Logging {
TypeSig.STRING, TypeSig.STRING + TypeSig.BINARY),
(a, conf, p, r) => new UnaryExprMeta[OctetLength](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuOctetLength(child)
})
      }),
    expr[GetArrayStructFields](
      "Extracts the `ordinal`-th fields of all array elements from an input array of structs",
      ExprChecks.unaryProject(
        TypeSig.ARRAY.nested(TypeSig.commonCudfTypesWithNested),
        TypeSig.ARRAY.nested(TypeSig.all),
        // Allow all supported types in the children type signature of the nested struct,
        // even though only a struct child is valid for the array here: TypeSig can
        // express only one level of nesting.
        TypeSig.ARRAY.nested(TypeSig.commonCudfTypesWithNested),
        TypeSig.ARRAY.nested(TypeSig.all)),
      (e, conf, p, r) => new GpuGetArrayStructFieldsMeta(e, conf, p, r)
    )
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

// Shim expressions should be last to allow overrides with shim-specific versions
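A note on the four type signatures in the GetArrayStructFields entry above (a sketch of my reading, following the inline comment): a `TypeSig` names a top-level type plus the set of types allowed anywhere beneath it, so it cannot state that the array's direct child must be a struct. That structural requirement is already guaranteed when Catalyst's `ExtractValue` resolves `a.field` on an array-of-struct input.

```scala
import com.nvidia.spark.rapids.TypeSig

// What the GPU accepts: an array whose descendants may be any common cudf type,
// including nested ones. This also admits, e.g., array<int>; the "direct child is
// a struct" constraint comes from Catalyst's resolution, not from this signature.
val gpuInput: TypeSig = TypeSig.ARRAY.nested(TypeSig.commonCudfTypesWithNested)

// The corresponding Spark-side signature for the same slot, as passed to
// ExprChecks.unaryProject above.
val sparkInput: TypeSig = TypeSig.ARRAY.nested(TypeSig.all)
```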
@@ -668,6 +668,12 @@ object TypeSig {
TIMESTAMP + STRING + DECIMAL_128 + NULL + BINARY + CALENDAR + ARRAY + STRUCT +
UDT).nested()

  /**
   * commonCudfTypes plus decimal, null and nested types.
   */
  val commonCudfTypesWithNested: TypeSig = (commonCudfTypes + DECIMAL_128 + NULL +
    ARRAY + STRUCT + MAP).nested()

  /**
   * Different types of Pandas UDF support different sets of output types. Please refer to
   * https://github.com/apache/spark/blob/master/python/pyspark/sql/udf.py#L98
@@ -17,15 +17,15 @@
package org.apache.spark.sql.rapids

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, DataTypeUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, DataTypeUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuListUtils, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, UnaryExprMeta}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimUnaryExpression}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetArrayStructFields, GetMapValue, ImplicitCastInputTypes, NullIntolerant}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String

@@ -267,3 +267,45 @@ case class GpuArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}

class GpuGetArrayStructFieldsMeta(
    expr: GetArrayStructFields,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) {

  def convertToGpu(child: Expression): GpuExpression =
    GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull)
}

/**
 * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all
 * array elements and returns them as a new array. For example, extracting the first field of
 * `array(struct(1, "a"), struct(2, "b"))` yields `array(1, 2)`.
 *
 * No type checking is needed here since it is handled by `ExtractValue`.
 */
case class GpuGetArrayStructFields(
    child: Expression,
    field: StructField,
    ordinal: Int,
    numFields: Int,
    containsNull: Boolean) extends GpuUnaryExpression with ExtractValue with NullIntolerant {

  override def dataType: DataType = ArrayType(field.dataType, containsNull)
  override def toString: String = s"$child.${field.name}"
  override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"

  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
    val base = input.getBase
    // The only child of a cudf LIST column is its data column: here the STRUCT view,
    // whose `ordinal`-th child holds the flattened field values to extract.
    val fieldView = withResource(base.getChildColumnView(0)) { structView =>
      structView.getChildColumnView(ordinal)
    }
    // Re-wrap the extracted field in a LIST view that reuses the original offsets
    // and validity, then materialize it into a standalone column vector.
    val listView = withResource(fieldView) { _ =>
      GpuListUtils.replaceListDataColumnAsView(base, fieldView)
    }
    withResource(listView) { _ =>
      listView.copyToColumnVector()
    }
  }
}
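Finally, a quick end-to-end check one could run (a hypothetical spark-shell session, assuming the plugin jar is on the classpath; the config keys are the ones documented in docs/configs.md above): with the plugin enabled, the projection should plan onto the GPU instead of falling back to the CPU for this expression.

```scala
spark.conf.set("spark.rapids.sql.enabled", "true")
// Enabled by default after this change; set explicitly here only for clarity.
spark.conf.set("spark.rapids.sql.expression.GetArrayStructFields", "true")

val df = Seq(Seq((10, "x"), (20, "y"))).toDF("arr")
// Expect the physical plan to show a GPU projection using GpuGetArrayStructFields
// for `arr._1`, rather than a CPU fallback.
df.selectExpr("arr._1").explain()
```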