Add GpuMapConcat support for nested-type values #5686

Merged
merged 10 commits into from
Jun 25, 2022
Changes from 3 commits
4 changes: 2 additions & 2 deletions docs/supported_ops.md
@@ -8628,7 +8628,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
+<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
@@ -8649,7 +8649,7 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
+<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -89,7 +89,7 @@ def test_concat_string():
f.concat(f.lit(''), f.col('b')),
f.concat(f.col('a'), f.lit(''))))

-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: three_col_df(spark, data_gen, data_gen, data_gen
@@ -100,7 +100,7 @@ def test_map_concat(data_gen):
{"spark.sql.mapKeyDedupPolicy": "LAST_WIN"}
)

-@pytest.mark.parametrize('data_gen', all_basic_map_gens + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
def test_map_concat_with_lit(data_gen):
lit_col1 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
lit_col2 = f.lit(gen_scalar(data_gen)).cast(data_gen.data_type)
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/data_gen.py
@@ -969,7 +969,10 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
# Some map gens, but not all because of nesting
map_gens_sample = all_basic_map_gens + [MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
-MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
+MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
+MapGen(IntegerGen(False), ArrayGen(int_gen)),
+MapGen(BooleanGen(False), StructGen([['child0', byte_gen], ['child1', double_gen]]), max_length=10),
+MapGen(ByteGen(False), MapGen(FloatGen(False), date_gen, max_length=10), max_length=10)]

nested_gens_sample = array_gens_sample + struct_gens_sample_with_decimal128 + map_gens_sample + decimal_128_map_gens

@@ -3104,11 +3104,11 @@ object GpuOverrides extends Logging {
expr[MapConcat](
"Returns the union of all the given maps",
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
-TypeSig.NULL),
+TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all),
repeatingParamCheck = Some(RepeatingParamCheck("input",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
-TypeSig.NULL),
+TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
Collaborator

The problem here is that the checks don't distinguish between keys and values in a map. This check is adding nested support for both keys and values, but the code can only support nested types for values, not keys.
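
To illustrate the concern, here is a minimal plain-Python sketch. It is not the plugin's `TypeSig` API; the type names and both helper functions are hypothetical and exist only to show why one shared signature for keys and values over-admits once nested types are added.

```python
# Hypothetical model of a map type check; not the plugin's TypeSig API.
BASIC = {"int", "long", "string", "decimal128", "null"}
NESTED = {"array", "struct", "map"}


def single_sig_supported(key_type, value_type, allowed):
    """One shared signature: keys and values are checked against the same set."""
    return key_type in allowed and value_type in allowed


def split_sig_supported(key_type, value_type, key_allowed, value_allowed):
    """Separate signatures: keys and values each have their own allowed set."""
    return key_type in key_allowed and value_type in value_allowed


# Widening the shared set to allow nested values also admits nested keys.
shared = BASIC | NESTED
print(single_sig_supported("array", "int", shared))                # True
print(split_sig_supported("array", "int", BASIC, BASIC | NESTED))  # False
print(split_sig_supported("int", "array", BASIC, BASIC | NESTED))  # True
```

With separate sets, a nested value type is accepted while a nested key type is rejected, which matches what the code can actually handle.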

This comment was marked as resolved.

Collaborator Author

Updated!

Collaborator Author

@HaoYang670, Jun 14, 2022

I don't know how we could test the nested key type cases, since MapGen does not support ArrayGen for keys (a Python list is unhashable).
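
The Python limitation mentioned here can be reproduced without any Spark or test-framework code. The helper name `try_as_key` below is made up for this sketch; it only shows that a list cannot be a dict key while a tuple can.

```python
# A Python dict key must be hashable; a list is not, a tuple is.
def try_as_key(key):
    """Return True if `key` can be used as a dict key, False otherwise."""
    try:
        {key: "value"}
        return True
    except TypeError:
        return False


print(try_as_key([1, 2, 3]))  # False: list is unhashable
print(try_as_key((1, 2, 3)))  # True: tuple is hashable
```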

Collaborator

Can we add a psNote to the param checks so it is clearly documented that keys don't support arrays? As for the tests, once we do support arrays for map keys we can either write the tests in Scala, or work around the limitation by creating a list of key/value structs, converting it to a map, doing the computation, and then converting the resulting map back to a list of key/value structs. I think a Scala test is preferable, but either would work.
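
The list-of-structs workaround can be sketched in plain Python, assuming a map is represented as a list of `(key, value)` entries. The function name `concat_entries_last_win` is hypothetical, and the sketch does not call Spark. It assumes LAST_WIN semantics in which a later value overwrites an earlier one while the entry keeps its first position.

```python
def concat_entries_last_win(*entry_lists):
    """Concatenate lists of (key, value) entries, keeping the last value per key.

    Keys are compared by equality rather than hashed, so list keys work.
    """
    result = []
    for entries in entry_lists:
        for key, value in entries:
            for i, (existing_key, _) in enumerate(result):
                if existing_key == key:
                    result[i] = (key, value)  # later value wins, position kept
                    break
            else:
                # No matching key found: append a new entry.
                result.append((key, value))
    return result


left = [([1, 2], "a"), ([3], "b")]
right = [([1, 2], "c"), ([4], "d")]
print(concat_entries_last_win(left, right))
```

Because the entries never pass through a Python dict, array-typed keys can be generated and compared on the test side.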

Collaborator Author

Updated! I have also added two cases to test the fallback.

TypeSig.MAP.nested(TypeSig.all)))),
(a, conf, p, r) => new ComplexTypeMergingExprMeta[MapConcat](a, conf, p, r) {
override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuMapConcat(child)
@@ -103,30 +103,9 @@ case class GpuMapConcat(children: Seq[Expression]) extends GpuComplexTypeMerging
// For single column concat, we pass the result of child node to avoid extra cuDF call.
case (_, 1) => children.head.columnarEval(batch)
case (dt, _) => {
-val cols = children.safeMap(columnarEvalToColumn(_, batch))
-// concatenate keys and values
-val (key_list, value_list) = withResource(cols) { cols =>
-withResource(ArrayBuffer[ColumnView]()) { keys =>
-withResource(ArrayBuffer[ColumnView]()) { values =>
-cols.foreach{ col =>
-keys.append(GpuMapUtils.getKeysAsListView(col.getBase))
-values.append(GpuMapUtils.getValuesAsListView(col.getBase))
-}
-closeOnExcept(ColumnVector.listConcatenateByRow(keys: _*)) {key_list =>
-(key_list, ColumnVector.listConcatenateByRow(values: _*))
-}
-}
-}
-}
-// build map column from concatenated keys and values
-withResource(Seq(key_list, value_list)) { case Seq(keys, values) =>
-withResource(Seq(keys.getChildColumnView(0), values.getChildColumnView(0))) {
-case Seq(k_child, v_chlid) =>
-withResource(ColumnView.makeStructView(k_child, v_chlid)) {structs =>
-withResource(keys.replaceListChild(structs)) { struct_list =>
-GpuCreateMap.createMapFromKeysValuesAsStructs(dt, struct_list)
-}
-}
+withResource(children.safeMap(columnarEvalToColumn(_, batch).getBase())) {cols =>
+withResource(cudf.ColumnVector.listConcatenateByRow(cols: _*)) {structs =>
+GpuCreateMap.createMapFromKeysValuesAsStructs(dataType, structs)
}
}
}
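
The simplified code path treats each map column as a list of key/value structs and concatenates the columns row by row before building the map. A plain-Python model of that row-wise concatenation is sketched below. The function is a stand-in written for illustration, not the cuDF call, and its null handling (a null in any input row makes the output row null) is an assumption of this sketch.

```python
def list_concatenate_by_row(*columns):
    """Concatenate list columns row by row.

    Each column is a list of rows; each row is a list of (key, value)
    entries or None. In this sketch a None in any input row makes the
    output row None (an assumption about null handling).
    """
    if not columns:
        return []
    num_rows = len(columns[0])
    if any(len(col) != num_rows for col in columns):
        raise ValueError("all columns must have the same number of rows")
    out = []
    for rows in zip(*columns):
        if any(row is None for row in rows):
            out.append(None)
        else:
            # Flatten the entries of this row across all input columns.
            out.append([entry for row in rows for entry in row])
    return out


col_a = [[("k1", [1])], None, []]
col_b = [[("k2", [2, 3])], [("k3", [4])], [("k4", [])]]
print(list_concatenate_by_row(col_a, col_b))
```

Because the entries are moved as whole structs, the value side can be any nested type without the keys and values being split apart and rejoined.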