Support map_filter operator (#5436)
* Support `map_filter` operator

Signed-off-by: Chong Gao <res_life@163.com>
res-life authored May 14, 2022
1 parent 76552b9 commit b865ecb
Showing 7 changed files with 148 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/configs.md
@@ -253,6 +253,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Lower"></a>spark.rapids.sql.expression.Lower|`lower`, `lcase`|String lowercase operator|false|This is not 100% compatible with the Spark version because the Unicode version used by cuDF and the JVM may differ, resulting in some corner-case characters not changing case correctly.|
<a name="sql.expression.MakeDecimal"></a>spark.rapids.sql.expression.MakeDecimal| |Create a Decimal from an unscaled long value for some aggregation optimizations|true|None|
<a name="sql.expression.MapEntries"></a>spark.rapids.sql.expression.MapEntries|`map_entries`|Returns an unordered array of all entries in the given map|true|None|
<a name="sql.expression.MapFilter"></a>spark.rapids.sql.expression.MapFilter|`map_filter`|Filters entries in a map using the function|true|None|
<a name="sql.expression.MapKeys"></a>spark.rapids.sql.expression.MapKeys|`map_keys`|Returns an unordered array containing the keys of the map|true|None|
<a name="sql.expression.MapValues"></a>spark.rapids.sql.expression.MapValues|`map_values`|Returns an unordered array containing the values of the map|true|None|
<a name="sql.expression.Md5"></a>spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None|
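For context, `map_filter` is the Spark SQL higher-order function that this new config row enables on the GPU. A minimal PySpark sketch of the kind of expression that is now accelerated (the literal map and the alias are illustrative only):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# map_filter keeps only the entries for which the lambda returns true
spark.sql(
    "SELECT map_filter(map(1, 10, 2, NULL, 3, 30), (k, v) -> isnotnull(v)) AS filtered"
).show(truncate=False)
# expected result: {1 -> 10, 3 -> 30}
```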
82 changes: 75 additions & 7 deletions docs/supported_ops.md
@@ -8654,12 +8654,12 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="2">MapKeys</td>
<td rowSpan="2">`map_keys`</td>
<td rowSpan="2">Returns an unordered array containing the keys of the map</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td rowSpan="3">MapFilter</td>
<td rowSpan="3">`map_filter`</td>
<td rowSpan="3">Filters entries in a map using the function</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>argument</td>
<td> </td>
<td> </td>
<td> </td>
@@ -8680,6 +8680,27 @@
<td> </td>
</tr>
<tr>
<td>function</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
@@ -8695,8 +8716,8 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
@@ -8727,6 +8748,53 @@
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">MapKeys</td>
<td rowSpan="2">`map_keys`</td>
<td rowSpan="2">Returns an unordered array containing the keys of the map</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">MapValues</td>
<td rowSpan="2">`map_values`</td>
<td rowSpan="2">Returns an unordered array containing the values of the map</td>
9 changes: 9 additions & 0 deletions integration_tests/src/main/python/map_test.py
@@ -453,3 +453,12 @@ def test_transform_keys_last_win_fallback(data_gen):
def test_sql_map_scalars(query):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.sql('SELECT {}'.format(query)))

@pytest.mark.parametrize('data_gen', map_gens_sample, ids=idfn)
def test_map_filter(data_gen):
    columns = ['map_filter(a, (key, value) -> isnotnull(value) )',
               'map_filter(a, (key, value) -> isnull(value) )',
               'map_filter(a, (key, value) -> isnull(key) or isnotnull(value) )',
               'map_filter(a, (key, value) -> isnotnull(key) and isnull(value) )']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(columns))
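The new test runs the four `map_filter` predicates above over the plugin's sampled map generators and compares CPU and GPU results (`unary_op_df` and `map_gens_sample` come from the plugin's test framework). As a standalone sanity check of the first predicate outside the test harness, with a made-up input map and column name `m`:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# one row with a map column; the entry with a null value should be dropped
df = spark.createDataFrame([({"a": 1, "b": None},)], ["m"])
df.selectExpr("map_filter(m, (key, value) -> isnotnull(value)) AS kept").show(truncate=False)
# expected result: {a -> 1}
```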
@@ -2957,6 +2957,22 @@ object GpuOverrides extends Logging {
          GpuTransformValues(childExprs.head.convertToGpu(), childExprs(1).convertToGpu())
        }
      }),
    expr[MapFilter](
      "Filters entries in a map using the function",
      ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
        TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
        TypeSig.MAP.nested(TypeSig.all),
        Seq(
          ParamCheck("argument",
            TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
              TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
            TypeSig.MAP.nested(TypeSig.all)),
          ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
      (in, conf, p, r) => new ExprMeta[MapFilter](in, conf, p, r) {
        override def convertToGpu(): GpuExpression = {
          GpuMapFilter(childExprs.head.convertToGpu(), childExprs(1).convertToGpu())
        }
      }),
    expr[StringLocate](
      "Substring search operator",
      ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
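With this rule registered, the plugin replaces Spark's `MapFilter` expression with `GpuMapFilter` whenever the map's key/value types fall inside the signature above and the lambda produces a boolean. A hedged sketch of how one might confirm the replacement from PySpark, assuming the RAPIDS Accelerator jar is on the classpath; `spark.rapids.sql.explain` is the plugin's existing knob for logging what will and will not run on the GPU:

```python
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.rapids.sql.explain", "ALL")
         .getOrCreate())

# with explain set to ALL, the driver log should note MapFilter being replaced by GpuMapFilter
spark.sql(
    "SELECT map_filter(map('a', 1, 'b', NULL), (k, v) -> isnotnull(v))"
).collect()
```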
@@ -530,4 +530,46 @@ case class GpuTransformValues(
}
}
}
}
}

case class GpuMapFilter(argument: Expression,
    function: Expression,
    isBound: Boolean = false,
    boundIntermediate: Seq[GpuExpression] = Seq.empty)
  extends GpuMapSimpleHigherOrderFunction {

  override def dataType: DataType = argument.dataType

  override def prettyName: String = "map_filter"

  override def bind(input: AttributeSeq): GpuExpression = {
    val (boundFunc, boundArg, boundIntermediate) = bindLambdaFunc(input)

    GpuMapFilter(boundArg, boundFunc, isBound = true, boundIntermediate)
  }

  override def columnarEval(batch: ColumnarBatch): Any = {
    withResource(GpuExpressionsUtils.columnarEvalToColumn(argument, batch)) { mapArg =>
      // `mapArg` is a list column of struct(key, value) entries
      val plainBoolCol = withResource(makeElementProjectBatch(batch, mapArg.getBase)) { cb =>
        GpuExpressionsUtils.columnarEvalToColumn(function, cb)
      }

      withResource(plainBoolCol) { plainBoolCol =>
        assert(plainBoolCol.dataType() == BooleanType,
          "the map_filter lambda must evaluate to a boolean column")
        withResource(mapArg.getBase.getListOffsetsView) { argOffsetsCv =>
          // convert the flat boolean column into a list-of-boolean column using the map's offsets
          withResource(plainBoolCol.getBase.makeListFromOffsets(mapArg.getRowCount, argOffsetsCv)) {
            listOfBoolCv =>
              // keep only the entries of each map in `mapArg` whose corresponding flag in
              // `listOfBoolCv` is true. `mapArg` contains no duplicate keys and no null keys,
              // so there is no need to call `assertNoNullKeys` or `assertNoDuplicateKeys`
              // after the extraction
              val retCv = mapArg.getBase.applyBooleanMask(listOfBoolCv)
              GpuColumnVector.from(retCv, dataType)
          }
        }
      }
    }
  }
}
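The columnar algorithm above boils down to three steps: evaluate the lambda once over every map entry to get one flat boolean per entry, regroup those booleans into per-row lists using the map column's offsets, and keep only the entries whose flag is true. A plain-Python sketch of that idea, with made-up offsets and entries rather than the cuDF API:

```python
# flat entries for three map rows, plus the offsets that delimit each row
entries = [("a", 1), ("b", None), ("c", 3), ("d", None), ("e", 5)]
offsets = [0, 2, 4, 5]  # row 0 -> entries[0:2], row 1 -> entries[2:4], row 2 -> entries[4:5]

# step 1: evaluate the predicate once per entry, producing a flat boolean "column"
mask = [value is not None for _, value in entries]

# steps 2 and 3: regroup the flat mask by the same offsets and keep the entries flagged true
filtered_rows = [
    [entry for entry, keep in zip(entries[start:end], mask[start:end]) if keep]
    for start, end in zip(offsets, offsets[1:])
]

print(filtered_rows)  # [[('a', 1)], [('c', 3)], [('e', 5)]]
```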
1 change: 1 addition & 0 deletions tools/src/main/resources/operatorsScore.csv
@@ -145,6 +145,7 @@ Logarithm,3
Lower,3
MakeDecimal,3
MapEntries,3
MapFilter,3
MapKeys,3
MapValues,3
Max,3
3 changes: 3 additions & 0 deletions tools/src/main/resources/supportedExprs.csv
@@ -301,6 +301,9 @@ MakeDecimal,S, ,None,project,input,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N
MakeDecimal,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA
MapEntries,S,`map_entries`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA
MapEntries,S,`map_entries`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA
MapFilter,S,`map_filter`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA
MapFilter,S,`map_filter`,None,project,function,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
MapFilter,S,`map_filter`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA
MapKeys,S,`map_keys`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA
MapKeys,S,`map_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA
MapValues,S,`map_values`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA
