Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement groupby collect_set #7420

Merged
merged 24 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e15dbf0
Rename aggregation::Kind::COLLECT to aggregation::Kind::COLLECT_LIST.…
ttnghia Feb 22, 2021
73a0fa3
Rename functions `*_collect_*` into `*_collect_list_*` functions
ttnghia Feb 22, 2021
8814fcd
Merge branch 'branch-0.19' into collect_set
ttnghia Mar 12, 2021
1ec959f
Add collect_set type and factory function
ttnghia Mar 12, 2021
401290a
Update copyright year
ttnghia Mar 12, 2021
8ceda44
Initially implement tests for collect_set
ttnghia Mar 16, 2021
35ac84c
Rewrite tests for groupby collect_set
ttnghia Mar 16, 2021
931e01c
Implement sort-based groupby collect_set
ttnghia Mar 16, 2021
9db7c58
Add detail API for drop_list_duplicates that accepts stream parameter
ttnghia Mar 16, 2021
f522627
Expose the detail::drop_list_duplicates function to use in other places
ttnghia Mar 16, 2021
2af1db7
Use the detail::drop_list_duplicate function in groupby collect_set
ttnghia Mar 16, 2021
4099d17
Fix default parameters for the make_collect_set_aggregation() function
ttnghia Mar 16, 2021
765b2ef
Remove rolling collect_set test
ttnghia Mar 16, 2021
10abd8e
Implement tests for groupby collect_set
ttnghia Mar 17, 2021
864d878
Add groupby/collect_set_test.cpp to CMakeList.txt
ttnghia Mar 17, 2021
6887816
Update copyright year in header
ttnghia Mar 17, 2021
96c0e27
Remove CUDF_TEST_PROGRAM_MAIN() from test file
ttnghia Mar 17, 2021
483e752
Fix style check
ttnghia Mar 17, 2021
c0dae84
Update tests for groupby collect_set
ttnghia Mar 17, 2021
43731ba
Fix java binding and python binding for collect_list, and also add a …
ttnghia Mar 18, 2021
d91d53c
Reverse name for Kind::COLLECT enum in Python binding
ttnghia Mar 18, 2021
b511cb5
Fix python binding for collect API
ttnghia Mar 18, 2021
782dc9d
Merge branch 'branch-0.19' into collect_set
ttnghia Mar 23, 2021
918eeb7
Rename tests unit
ttnghia Mar 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ test:
- test -f $PREFIX/include/cudf/join.hpp
- test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp
- test -f $PREFIX/include/cudf/lists/detail/copying.hpp
- test -f $PREFIX/include/cudf/lists/detail/drop_list_duplicates.hpp
- test -f $PREFIX/include/cudf/lists/detail/sorting.hpp
- test -f $PREFIX/include/cudf/lists/count_elements.hpp
- test -f $PREFIX/include/cudf/lists/drop_list_duplicates.hpp
Expand Down
28 changes: 23 additions & 5 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -74,7 +74,8 @@ class aggregation {
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
COLLECT, ///< collect values into a list
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
Expand Down Expand Up @@ -205,18 +206,35 @@ std::unique_ptr<aggregation> make_nth_element_aggregation(
std::unique_ptr<aggregation> make_row_number_aggregation();

/**
* @brief Factory to create a COLLECT aggregation
* @brief Factory to create a COLLECT_LIST aggregation
*
* `COLLECT` returns a list column of all included elements in the group/series.
* `COLLECT_LIST` returns a list column of all included elements in the group/series.
*
* If `null_handling` is set to `EXCLUDE`, null elements are dropped from each
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls in list elements.
*/
std::unique_ptr<aggregation> make_collect_aggregation(
std::unique_ptr<aggregation> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
* @brief Factory to create a COLLECT_SET aggregation
*
* `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each
* list, the duplicated entries are dropped out such that each entry appears only once.
*
* If `null_handling` is set to `EXCLUDE`, null elements are dropped from each
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);

Expand Down
48 changes: 41 additions & 7 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -320,11 +320,11 @@ struct udf_aggregation final : derived_aggregation<udf_aggregation> {
};

/**
* @brief Derived aggregation class for specifying COLLECT aggregation
* @brief Derived aggregation class for specifying COLLECT_LIST aggregation
*/
struct collect_list_aggregation final : derived_aggregation<nunique_aggregation> {
explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
: derived_aggregation{COLLECT}, _null_handling{null_handling}
: derived_aggregation{COLLECT_LIST}, _null_handling{null_handling}
{
}
null_policy _null_handling; ///< include or exclude nulls
Expand All @@ -340,6 +340,32 @@ struct collect_list_aggregation final : derived_aggregation<nunique_aggregation>
size_t hash_impl() const { return std::hash<int>{}(static_cast<int>(_null_handling)); }
};

/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
struct collect_set_aggregation final : derived_aggregation<collect_set_aggregation> {
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL)
: derived_aggregation{COLLECT_SET}, _null_handling{null_handling}, _null_equal(null_equal)
{
}
null_policy _null_handling; ///< include or exclude nulls
null_equality _null_equal; ///< whether to consider nulls as equal values

protected:
friend class derived_aggregation<collect_set_aggregation>;

bool operator==(collect_set_aggregation const& other) const
{
return _null_handling == other._null_handling && _null_equal == other._null_equal;
}

size_t hash_impl() const
{
return std::hash<int>{}(static_cast<int>(_null_handling) ^ static_cast<int>(_null_equal));
}
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
Expand Down Expand Up @@ -514,9 +540,15 @@ struct target_type_impl<Source, aggregation::ROW_NUMBER> {
using type = cudf::size_type;
};

// Always use list for COLLECT
// Always use list for COLLECT_LIST
template <typename Source>
struct target_type_impl<Source, aggregation::COLLECT_LIST> {
using type = cudf::list_view;
};

// Always use list for COLLECT_SET
template <typename Source>
struct target_type_impl<Source, aggregation::COLLECT> {
struct target_type_impl<Source, aggregation::COLLECT_SET> {
using type = cudf::list_view;
};

Expand Down Expand Up @@ -617,8 +649,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::NTH_ELEMENT>(std::forward<Ts>(args)...);
case aggregation::ROW_NUMBER:
return f.template operator()<aggregation::ROW_NUMBER>(std::forward<Ts>(args)...);
case aggregation::COLLECT:
return f.template operator()<aggregation::COLLECT>(std::forward<Ts>(args)...);
case aggregation::COLLECT_LIST:
return f.template operator()<aggregation::COLLECT_LIST>(std::forward<Ts>(args)...);
case aggregation::COLLECT_SET:
return f.template operator()<aggregation::COLLECT_SET>(std::forward<Ts>(args)...);
case aggregation::LEAD:
return f.template operator()<aggregation::LEAD>(std::forward<Ts>(args)...);
case aggregation::LAG:
Expand Down
38 changes: 38 additions & 0 deletions cpp/include/cudf/lists/detail/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/lists/lists_column_view.hpp>

#include <rmm/cuda_stream_view.hpp>

namespace cudf {
namespace lists {
namespace detail {

/**
* @copydoc cudf::lists::drop_list_duplicates
*
* @param stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
} // namespace detail
} // namespace lists
} // namespace cudf
12 changes: 9 additions & 3 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -125,11 +125,17 @@ std::unique_ptr<aggregation> make_row_number_aggregation()
{
return std::make_unique<aggregation>(aggregation::ROW_NUMBER);
}
/// Factory to create a COLLECT aggregation
std::unique_ptr<aggregation> make_collect_aggregation(null_policy null_handling)
/// Factory to create a COLLECT_LIST aggregation
std::unique_ptr<aggregation> make_collect_list_aggregation(null_policy null_handling)
{
return std::make_unique<detail::collect_list_aggregation>(null_handling);
}
/// Factory to create a COLLECT_SET aggregation
std::unique_ptr<aggregation> make_collect_set_aggregation(null_policy null_handling,
null_equality null_equal)
{
return std::make_unique<detail::collect_set_aggregation>(null_handling, null_equal);
}
/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset)
{
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
29 changes: 24 additions & 5 deletions cpp/src/groupby/sort/groupby.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possible merge conflict with PR #7387

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then either this, or that PR need to be merged first.

*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,16 +19,15 @@

#include <cudf/aggregation.hpp>
#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/aggregation/result_cache.hpp>
#include <cudf/detail/binaryop.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/groupby.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/unary.hpp>
#include <cudf/groupby.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -65,6 +64,7 @@ struct store_result_functor {
template <aggregation::Kind k>
void operator()(aggregation const& agg)
{
CUDF_FAIL("Unsupported aggregation.");
}

private:
Expand Down Expand Up @@ -401,12 +401,12 @@ void store_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation cons
}

template <>
void store_result_functor::operator()<aggregation::COLLECT>(aggregation const& agg)
void store_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation const& agg)
{
auto null_handling =
static_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT aggregation.");
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");

if (cache.has_result(col_idx, agg)) return;

Expand All @@ -416,6 +416,25 @@ void store_result_functor::operator()<aggregation::COLLECT>(aggregation const& a
cache.add_result(col_idx, agg, std::move(result));
};

template <>
void store_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
{
auto const null_handling =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_SET aggregation.");

if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
auto const nulls_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_equal;
cache.add_result(col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, stream, mr));
};
} // namespace detail

// Sort-based groupby
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/lists/drop_list_duplicates.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ void generate_offsets(size_type num_entries,
return offsets[i - prefix_sum_empty_lists[i]];
});
}
} // anonymous namespace

/**
* @copydoc cudf::lists::drop_list_duplicates
*
Expand Down Expand Up @@ -276,7 +278,6 @@ std::unique_ptr<column> drop_list_duplicates(lists_column_view const& lists_colu
cudf::detail::copy_bitmask(lists_column.parent(), stream, mr));
}

} // anonymous namespace
} // namespace detail

/**
Expand Down
19 changes: 10 additions & 9 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ template <typename InputType,
std::enable_if_t<!std::is_same<InputType, cudf::string_view>::value and
!(op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL ||
op == aggregation::ROW_NUMBER || op == aggregation::LEAD ||
op == aggregation::LAG || op == aggregation::COLLECT)>* = nullptr>
op == aggregation::LAG || op == aggregation::COLLECT_LIST)>* = nullptr>
bool __device__ process_rolling_window(column_device_view input,
column_device_view ignored_default_outputs,
mutable_column_device_view output,
Expand Down Expand Up @@ -814,7 +814,7 @@ struct rolling_window_launcher {
typename PrecedingWindowIterator,
typename FollowingWindowIterator>
std::enable_if_t<!(op == aggregation::MEAN || op == aggregation::LEAD || op == aggregation::LAG ||
op == aggregation::COLLECT),
op == aggregation::COLLECT_LIST),
std::unique_ptr<column>>
operator()(column_view const& input,
column_view const& default_outputs,
Expand Down Expand Up @@ -897,11 +897,11 @@ struct rolling_window_launcher {
}

/**
* @brief Creates the offsets child of the result of the `COLLECT` window aggregation
* @brief Creates the offsets child of the result of the `COLLECT_LIST` window aggregation
*
* Given the input column, the preceding/following window bounds, and `min_periods`,
* the sizes of each list row may be computed. These values can then be used to
* calculate the offsets for the result of `COLLECT`.
* calculate the offsets for the result of `COLLECT_LIST`.
*
* Note: If `min_periods` exceeds the number of observations for a window, the size
* is set to `0` (since the result is `null`).
Expand Down Expand Up @@ -945,7 +945,7 @@ struct rolling_window_launcher {
}

/**
* @brief Generate mapping of each row in the COLLECT result's child column
* @brief Generate mapping of each row in the COLLECT_LIST result's child column
* to the index of the row it belongs to.
*
* If
Expand Down Expand Up @@ -1030,7 +1030,7 @@ struct rolling_window_launcher {

/**
* @brief Create gather map to generate the child column of the result of
* the `COLLECT` window aggregation.
* the `COLLECT_LIST` window aggregation.
*/
template <typename PrecedingIter>
std::unique_ptr<column> create_collect_gather_map(column_view const& child_offsets,
Expand Down Expand Up @@ -1064,7 +1064,7 @@ struct rolling_window_launcher {
}

/**
* @brief Count null entries in result of COLLECT.
* @brief Count null entries in result of COLLECT_LIST.
*/
size_type count_child_nulls(column_view const& input,
std::unique_ptr<column> const& gather_map,
Expand Down Expand Up @@ -1139,7 +1139,7 @@ struct rolling_window_launcher {
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr<column>> operator()(
std::enable_if_t<(op == aggregation::COLLECT_LIST), std::unique_ptr<column>> operator()(
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
Expand All @@ -1150,7 +1150,7 @@ struct rolling_window_launcher {
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(default_outputs.is_empty(),
"COLLECT window function does not support default values.");
"COLLECT_LIST window function does not support default values.");

if (input.is_empty()) return empty_like(input);

Expand Down Expand Up @@ -1370,6 +1370,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
auto input_col = cudf::is_dictionary(input.type())
? dictionary_column_view(input).get_indices_annotated()
: input;

auto output = cudf::type_dispatcher(input_col.type(),
dispatch_rolling{},
input_col,
Expand Down
Loading