From 34c17618f7cc2efbeff1788806bcbe77a16eed56 Mon Sep 17 00:00:00 2001
From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com>
Date: Sat, 25 Jun 2022 10:23:35 +0800
Subject: [PATCH] Add `GpuMapConcat` support for nested-type values (#5686)

* support nested values

Signed-off-by: remzi <13716567376yh@gmail.com>

* rewrite GpuMapConcat

Signed-off-by: remzi <13716567376yh@gmail.com>

* update simplify shrink the test size

Signed-off-by: remzi <13716567376yh@gmail.com>

* add limit

Signed-off-by: remzi <13716567376yh@gmail.com>

* remove unused import

Signed-off-by: remzi <13716567376yh@gmail.com>

* fallback to CPU when key type is nested

Signed-off-by: remzi <13716567376yh@gmail.com>

* update docs

Signed-off-by: remzi <13716567376yh@gmail.com>

* add tests

Signed-off-by: remzi <13716567376yh@gmail.com>

* update comments

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 docs/supported_ops.md                         |  4 +-
 .../src/main/python/collection_ops_test.py    |  4 +-
 integration_tests/src/main/python/data_gen.py |  5 ++-
 .../nvidia/spark/rapids/GpuOverrides.scala    | 15 ++++++-
 .../sql/rapids/collectionOperations.scala     | 29 ++-----------
 .../spark/rapids/CollectionOpSuite.scala      | 43 +++++++++++++++++++
 .../rapids/SparkQueryCompareTestSuite.scala   | 10 +++++
 7 files changed, 77 insertions(+), 33 deletions(-)
 create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala

diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 66946d040ab..9364b79b787 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -8628,7 +8628,7 @@ are limited.
-PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT
+PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT
@@ -8649,7 +8649,7 @@ are limited.
-PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT
+PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT
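The two cells above are MapConcat's parameter and result entries in the generated support matrix: ARRAY, MAP, and STRUCT move out of the unsupported child types list, leaving only BINARY, CALENDAR, and UDT unsupported as nested value types. To make the behavior change concrete, here is a minimal, hypothetical driver sketch (not part of the patch; the session setup, app name, and data are illustrative) that concatenates maps whose values are arrays, which GpuMapConcat can now keep on the GPU:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.map_concat

object MapConcatNestedValues {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("map-concat-nested-values")
      // Illustrative local setup; assumes the RAPIDS Accelerator jar is on the classpath.
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Spark requires a dedup policy whenever concatenated maps may share keys.
    spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

    // Map values are arrays: after this patch GpuMapConcat handles this on the GPU.
    val df = Seq(
      (Map("a" -> Seq(1, 2)), Map("b" -> Seq(3))),
      (Map("c" -> Seq(4)), Map("c" -> Seq(5, 6)))
    ).toDF("m1", "m2")

    df.select(map_concat($"m1", $"m2")).show(truncate = false)
  }
}
```

Setting `spark.sql.mapKeyDedupPolicy` (e.g. to LAST_WIN) is Spark's own requirement for duplicate keys, which is why the integration tests below set it as well.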
diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 077066418d5..7e31c512f49 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -89,7 +89,7 @@ def test_concat_string():
             f.concat(f.lit(''), f.col('b')),
             f.concat(f.col('a'), f.lit(''))))
 
-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
 def test_map_concat(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: three_col_df(spark, data_gen, data_gen, data_gen
@@ -100,7 +100,7 @@ def test_map_concat(data_gen):
         {"spark.sql.mapKeyDedupPolicy": "LAST_WIN"}
     )
 
-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
 def test_map_concat_with_lit(data_gen):
     lit_col1 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
     lit_col2 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 947f5e090a2..e9057b77d6b 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -969,7 +969,10 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
 # Some map gens, but not all because of nesting
 map_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
                                         MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
-                                        MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
+                                        MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
+                                        MapGen(IntegerGen(False), ArrayGen(int_gen, max_length=3), max_length=3),
+                                        MapGen(ShortGen(False), StructGen([['child0', byte_gen], ['child1', double_gen]]), max_length=3),
+                                        MapGen(ByteGen(False), MapGen(FloatGen(False), date_gen, max_length=3), max_length=3)]
 
 nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index ae82390baeb..68d22da0b3a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3100,14 +3100,25 @@ object GpuOverrides extends Logging {
     }),
     expr[MapConcat](
       "Returns the union of all the given maps",
+      // Currently, GpuMapConcat supports nested values but not nested keys.
+      // We will add the nested key support after
+      // cuDF can fully support nested types in lists::drop_list_duplicates.
+      // Issue link: https://github.com/rapidsai/cudf/issues/11093
       ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
-        TypeSig.NULL),
+        TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
         TypeSig.MAP.nested(TypeSig.all),
         repeatingParamCheck = Some(RepeatingParamCheck("input",
           TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
-            TypeSig.NULL),
+            TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
           TypeSig.MAP.nested(TypeSig.all)))),
       (a, conf, p, r) => new ComplexTypeMergingExprMeta[MapConcat](a, conf, p, r) {
+        override def tagExprForGpu(): Unit = {
+          a.dataType.keyType match {
+            case MapType(_,_,_) | ArrayType(_,_) | StructType(_) => willNotWorkOnGpu(
+              s"GpuMapConcat does not currently support the key type ${a.dataType.keyType}.")
+            case _ =>
+          }
+        }
         override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuMapConcat(child)
       }),
     expr[ConcatWs](
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index 7521103835e..05324669b29 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -18,8 +18,6 @@ package org.apache.spark.sql.rapids
 
 import java.util.Optional
 
-import scala.collection.mutable.ArrayBuffer
-
 import ai.rapids.cudf
 import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
 import com.nvidia.spark.rapids._
@@ -96,30 +94,9 @@ case class GpuMapConcat(children: Seq[Expression]) extends GpuComplexTypeMerging
     // For single column concat, we pass the result of child node to avoid extra cuDF call.
     case (_, 1) => children.head.columnarEval(batch)
     case (dt, _) => {
-      val cols = children.safeMap(columnarEvalToColumn(_, batch))
-      // concatenate keys and values
-      val (key_list, value_list) = withResource(cols) { cols =>
-        withResource(ArrayBuffer[ColumnView]()) { keys =>
-          withResource(ArrayBuffer[ColumnView]()) { values =>
-            cols.foreach{ col =>
-              keys.append(GpuMapUtils.getKeysAsListView(col.getBase))
-              values.append(GpuMapUtils.getValuesAsListView(col.getBase))
-            }
-            closeOnExcept(ColumnVector.listConcatenateByRow(keys: _*)) {key_list =>
-              (key_list, ColumnVector.listConcatenateByRow(values: _*))
-            }
-          }
-        }
-      }
-      // build map column from concatenated keys and values
-      withResource(Seq(key_list, value_list)) { case Seq(keys, values) =>
-        withResource(Seq(keys.getChildColumnView(0), values.getChildColumnView(0))) {
-          case Seq(k_child, v_chlid) =>
-            withResource(ColumnView.makeStructView(k_child, v_chlid)) {structs =>
-              withResource(keys.replaceListChild(structs)) { struct_list =>
-                GpuCreateMap.createMapFromKeysValuesAsStructs(dt, struct_list)
-              }
-            }
+      withResource(children.safeMap(columnarEvalToColumn(_, batch).getBase())) {cols =>
+        withResource(cudf.ColumnVector.listConcatenateByRow(cols: _*)) {structs =>
+          GpuCreateMap.createMapFromKeysValuesAsStructs(dataType, structs)
         }
       }
     }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala
new file mode 100644
index 00000000000..6d1bd7fc279
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CollectionOpSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.apache.spark.sql.functions.map_concat
+
+class CollectionOpSuite extends SparkQueryCompareTestSuite {
+  testGpuFallback(
+    "MapConcat with Array keys fall back",
+    "ProjectExec",
+    ArrayKeyMapDF,
+    execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
+    frame => {
+      import frame.sparkSession.implicits._
+      frame.select(map_concat($"col1", $"col2"))
+    }
+  }
+
+  testGpuFallback(
+    "MapConcat with Struct keys fall back",
+    "ProjectExec",
+    StructKeyMapDF,
+    execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
+    frame => {
+      import frame.sparkSession.implicits._
+      frame.select(map_concat($"col1", $"col2"))
+    }
+  }
+}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index 8a3a907af7d..6619fe0ed73 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -1780,6 +1780,16 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
     ).toDF("strings")
   }
 
+  def ArrayKeyMapDF(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq((Map(List(1, 2) -> 2), (Map(List(2, 3) -> 3)))).toDF("col1", "col2")
+  }
+
+  def StructKeyMapDF(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq((Map((1, 2) -> 2), (Map((2, 3) -> 3)))).toDF("col1", "col2")
+  }
+
  def nullableStringsFromCsv = {
    fromCsvDf("strings.csv", StructType(Array(
      StructField("strings", StringType, true),
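For completeness, here is a standalone sketch of the fallback path that the new CollectionOpSuite tests exercise (again hypothetical and not part of the patch; session setup and names are illustrative). Maps with array-typed keys are rejected by tagExprForGpu above, so the ProjectExec is expected to stay on the CPU until rapidsai/cudf#11093 is resolved:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.map_concat

object MapConcatNestedKeyFallback {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("map-concat-nested-key-fallback")
      // Illustrative local setup; assumes the RAPIDS Accelerator jar is on the classpath.
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

    // Array-typed map keys: tagExprForGpu rejects this case, so the whole
    // projection should fall back to the CPU, mirroring ArrayKeyMapDF above.
    val df = Seq((Map(List(1, 2) -> 2), Map(List(2, 3) -> 3))).toDF("col1", "col2")
    val result = df.select(map_concat($"col1", $"col2"))

    // The physical plan should show a CPU ProjectExec rather than GpuProjectExec.
    result.explain()
    result.show(truncate = false)
  }
}
```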