Skip to content

Commit

Permalink
Add GpuMapConcat support for nested-type values (#5686)
Browse files Browse the repository at this point in the history
* support nested values

Signed-off-by: remzi <13716567376yh@gmail.com>

* rewrite GpuMapConcat

Signed-off-by: remzi <13716567376yh@gmail.com>

* update
simplify
shrink the test size

Signed-off-by: remzi <13716567376yh@gmail.com>

* add limit

Signed-off-by: remzi <13716567376yh@gmail.com>

* remove unused import

Signed-off-by: remzi <13716567376yh@gmail.com>

* fallback to CPU when key type is nested

Signed-off-by: remzi <13716567376yh@gmail.com>

* update docs

Signed-off-by: remzi <13716567376yh@gmail.com>

* add tests

Signed-off-by: remzi <13716567376yh@gmail.com>

* update comments

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 authored Jun 25, 2022
1 parent ed8fc87 commit 34c1761
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 33 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -8628,7 +8628,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
Expand All @@ -8649,7 +8649,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/collection_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_concat_string():
f.concat(f.lit(''), f.col('b')),
f.concat(f.col('a'), f.lit(''))))

@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: three_col_df(spark, data_gen, data_gen, data_gen
Expand All @@ -100,7 +100,7 @@ def test_map_concat(data_gen):
{"spark.sql.mapKeyDedupPolicy": "LAST_WIN"}
)

@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat_with_lit(data_gen):
lit_col1 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
lit_col2 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
Expand Down
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,10 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
# Some map gens, but not all because of nesting
map_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
MapGen(IntegerGen(False), ArrayGen(int_gen, max_length=3), max_length=3),
MapGen(ShortGen(False), StructGen([['child0', byte_gen], ['child1', double_gen]]), max_length=3),
MapGen(ByteGen(False), MapGen(FloatGen(False), date_gen, max_length=3), max_length=3)]

nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3100,14 +3100,25 @@ object GpuOverrides extends Logging {
}),
expr[MapConcat](
"Returns the union of all the given maps",
// Currently, GpuMapConcat supports nested values but not nested keys.
// We will add the nested key support after
// cuDF can fully support nested types in lists::drop_list_duplicates.
// Issue link: https://github.com/rapidsai/cudf/issues/11093
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL),
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all),
repeatingParamCheck = Some(RepeatingParamCheck("input",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL),
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)))),
(a, conf, p, r) => new ComplexTypeMergingExprMeta[MapConcat](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.dataType.keyType match {
case MapType(_,_,_) | ArrayType(_,_) | StructType(_) => willNotWorkOnGpu(
s"GpuMapConcat does not currently support the key type ${a.dataType.keyType}.")
case _ =>
}
}
override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuMapConcat(child)
}),
expr[ConcatWs](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package org.apache.spark.sql.rapids

import java.util.Optional

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
Expand Down Expand Up @@ -96,30 +94,9 @@ case class GpuMapConcat(children: Seq[Expression]) extends GpuComplexTypeMerging
// For single column concat, we pass the result of child node to avoid extra cuDF call.
case (_, 1) => children.head.columnarEval(batch)
case (dt, _) => {
val cols = children.safeMap(columnarEvalToColumn(_, batch))
// concatenate keys and values
val (key_list, value_list) = withResource(cols) { cols =>
withResource(ArrayBuffer[ColumnView]()) { keys =>
withResource(ArrayBuffer[ColumnView]()) { values =>
cols.foreach{ col =>
keys.append(GpuMapUtils.getKeysAsListView(col.getBase))
values.append(GpuMapUtils.getValuesAsListView(col.getBase))
}
closeOnExcept(ColumnVector.listConcatenateByRow(keys: _*)) {key_list =>
(key_list, ColumnVector.listConcatenateByRow(values: _*))
}
}
}
}
// build map column from concatenated keys and values
withResource(Seq(key_list, value_list)) { case Seq(keys, values) =>
withResource(Seq(keys.getChildColumnView(0), values.getChildColumnView(0))) {
case Seq(k_child, v_chlid) =>
withResource(ColumnView.makeStructView(k_child, v_chlid)) {structs =>
withResource(keys.replaceListChild(structs)) { struct_list =>
GpuCreateMap.createMapFromKeysValuesAsStructs(dt, struct_list)
}
}
withResource(children.safeMap(columnarEvalToColumn(_, batch).getBase())) {cols =>
withResource(cudf.ColumnVector.listConcatenateByRow(cols: _*)) {structs =>
GpuCreateMap.createMapFromKeysValuesAsStructs(dataType, structs)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import org.apache.spark.sql.functions.map_concat

class CollectionOpSuite extends SparkQueryCompareTestSuite {
testGpuFallback(
"MapConcat with Array keys fall back",
"ProjectExec",
ArrayKeyMapDF,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => {
import frame.sparkSession.implicits._
frame.select(map_concat($"col1", $"col2"))
}
}

testGpuFallback(
"MapConcat with Struct keys fall back",
"ProjectExec",
StructKeyMapDF,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
frame => {
import frame.sparkSession.implicits._
frame.select(map_concat($"col1", $"col2"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,16 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
).toDF("strings")
}

def ArrayKeyMapDF(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq((Map(List(1, 2) -> 2), (Map(List(2, 3) -> 3)))).toDF("col1", "col2")
}

def StructKeyMapDF(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq((Map((1, 2) -> 2), (Map((2, 3) -> 3)))).toDF("col1", "col2")
}

def nullableStringsFromCsv = {
fromCsvDf("strings.csv", StructType(Array(
StructField("strings", StringType, true),
Expand Down

0 comments on commit 34c1761

Please sign in to comment.