From 34c17618f7cc2efbeff1788806bcbe77a16eed56 Mon Sep 17 00:00:00 2001
From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com>
Date: Sat, 25 Jun 2022 10:23:35 +0800
Subject: [PATCH] Add `GpuMapConcat` support for nested-type values (#5686)
* support nested values
Signed-off-by: remzi <13716567376yh@gmail.com>
* rewrite GpuMapConcat
Signed-off-by: remzi <13716567376yh@gmail.com>
* update
simplify
shrink the test size
Signed-off-by: remzi <13716567376yh@gmail.com>
* add limit
Signed-off-by: remzi <13716567376yh@gmail.com>
* remove unused import
Signed-off-by: remzi <13716567376yh@gmail.com>
* fallback to CPU when key type is nested
Signed-off-by: remzi <13716567376yh@gmail.com>
* update docs
Signed-off-by: remzi <13716567376yh@gmail.com>
* add tests
Signed-off-by: remzi <13716567376yh@gmail.com>
* update comments
Signed-off-by: remzi <13716567376yh@gmail.com>
---
docs/supported_ops.md | 4 +-
.../src/main/python/collection_ops_test.py | 4 +-
integration_tests/src/main/python/data_gen.py | 5 ++-
.../nvidia/spark/rapids/GpuOverrides.scala | 15 ++++++-
.../sql/rapids/collectionOperations.scala | 29 ++-----------
.../spark/rapids/CollectionOpSuite.scala | 43 +++++++++++++++++++
.../rapids/SparkQueryCompareTestSuite.scala | 10 +++++
7 files changed, 77 insertions(+), 33 deletions(-)
create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 66946d040ab..9364b79b787 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -8628,7 +8628,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
|
@@ -8649,7 +8649,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
|
|
diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 077066418d5..7e31c512f49 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -89,7 +89,7 @@ def test_concat_string():
f.concat(f.lit(''), f.col('b')),
f.concat(f.col('a'), f.lit(''))))
-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: three_col_df(spark, data_gen, data_gen, data_gen
@@ -100,7 +100,7 @@ def test_map_concat(data_gen):
{"spark.sql.mapKeyDedupPolicy": "LAST_WIN"}
)
-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat_with_lit(data_gen):
lit_col1 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
lit_col2 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 947f5e090a2..e9057b77d6b 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -969,7 +969,10 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
# Some map gens, but not all because of nesting
map_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
- MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
+ MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
+ MapGen(IntegerGen(False), ArrayGen(int_gen, max_length=3), max_length=3),
+ MapGen(ShortGen(False), StructGen([['child0', byte_gen], ['child1', double_gen]]), max_length=3),
+ MapGen(ByteGen(False), MapGen(FloatGen(False), date_gen, max_length=3), max_length=3)]
nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index ae82390baeb..68d22da0b3a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3100,14 +3100,25 @@ object GpuOverrides extends Logging {
}),
expr[MapConcat](
"Returns the union of all the given maps",
+ // Currently, GpuMapConcat supports nested values but not nested keys.
+ // We will add the nested key support after
+ // cuDF can fully support nested types in lists::drop_list_duplicates.
+ // Issue link: https://github.com/rapidsai/cudf/issues/11093
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
- TypeSig.NULL),
+ TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all),
repeatingParamCheck = Some(RepeatingParamCheck("input",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
- TypeSig.NULL),
+ TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)))),
(a, conf, p, r) => new ComplexTypeMergingExprMeta[MapConcat](a, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ a.dataType.keyType match {
+ case MapType(_,_,_) | ArrayType(_,_) | StructType(_) => willNotWorkOnGpu(
+ s"GpuMapConcat does not currently support the key type ${a.dataType.keyType}.")
+ case _ =>
+ }
+ }
override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuMapConcat(child)
}),
expr[ConcatWs](
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 7521103835e..05324669b29 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -18,8 +18,6 @@ package org.apache.spark.sql.rapids
import java.util.Optional
-import scala.collection.mutable.ArrayBuffer
-
import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
@@ -96,30 +94,9 @@ case class GpuMapConcat(children: Seq[Expression]) extends GpuComplexTypeMerging
// For single column concat, we pass the result of child node to avoid extra cuDF call.
case (_, 1) => children.head.columnarEval(batch)
case (dt, _) => {
- val cols = children.safeMap(columnarEvalToColumn(_, batch))
- // concatenate keys and values
- val (key_list, value_list) = withResource(cols) { cols =>
- withResource(ArrayBuffer[ColumnView]()) { keys =>
- withResource(ArrayBuffer[ColumnView]()) { values =>
- cols.foreach{ col =>
- keys.append(GpuMapUtils.getKeysAsListView(col.getBase))
- values.append(GpuMapUtils.getValuesAsListView(col.getBase))
- }
- closeOnExcept(ColumnVector.listConcatenateByRow(keys: _*)) {key_list =>
- (key_list, ColumnVector.listConcatenateByRow(values: _*))
- }
- }
- }
- }
- // build map column from concatenated keys and values
- withResource(Seq(key_list, value_list)) { case Seq(keys, values) =>
- withResource(Seq(keys.getChildColumnView(0), values.getChildColumnView(0))) {
- case Seq(k_child, v_chlid) =>
- withResource(ColumnView.makeStructView(k_child, v_chlid)) {structs =>
- withResource(keys.replaceListChild(structs)) { struct_list =>
- GpuCreateMap.createMapFromKeysValuesAsStructs(dt, struct_list)
- }
- }
+ withResource(children.safeMap(columnarEvalToColumn(_, batch).getBase())) {cols =>
+ withResource(cudf.ColumnVector.listConcatenateByRow(cols: _*)) {structs =>
+ GpuCreateMap.createMapFromKeysValuesAsStructs(dataType, structs)
}
}
}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala
new file mode 100644
index 00000000000..6d1bd7fc279
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.apache.spark.sql.functions.map_concat
+
+class CollectionOpSuite extends SparkQueryCompareTestSuite {
+ testGpuFallback(
+ "MapConcat with Array keys fall back",
+ "ProjectExec",
+ ArrayKeyMapDF,
+ execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
+ frame => {
+ import frame.sparkSession.implicits._
+ frame.select(map_concat($"col1", $"col2"))
+ }
+ }
+
+ testGpuFallback(
+ "MapConcat with Struct keys fall back",
+ "ProjectExec",
+ StructKeyMapDF,
+ execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
+ frame => {
+ import frame.sparkSession.implicits._
+ frame.select(map_concat($"col1", $"col2"))
+ }
+ }
+}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index 8a3a907af7d..6619fe0ed73 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -1780,6 +1780,16 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
).toDF("strings")
}
+ def ArrayKeyMapDF(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq((Map(List(1, 2) -> 2), (Map(List(2, 3) -> 3)))).toDF("col1", "col2")
+ }
+
+ def StructKeyMapDF(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq((Map((1, 2) -> 2), (Map((2, 3) -> 3)))).toDF("col1", "col2")
+ }
+
def nullableStringsFromCsv = {
fromCsvDf("strings.csv", StructType(Array(
StructField("strings", StringType, true),