Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor collect_set to use cudf::distinct and cudf::lists::distinct #11228

Merged
merged 53 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
1d7e8e0
Add new implementation and test files
ttnghia Jun 24, 2022
51b80db
Fix compile error
ttnghia Jun 24, 2022
08a76ad
Rename function
ttnghia Jun 27, 2022
16101f7
Implement `cudf::detail::stable_distinct` and `lists::distinct`
ttnghia Jun 27, 2022
5ec13d6
Rewrite doxygen
ttnghia Jun 27, 2022
6c5b738
Rename variable
ttnghia Jun 27, 2022
5b70eee
Rewrite comment
ttnghia Jun 27, 2022
238248d
Rename files
ttnghia Jun 27, 2022
ba6bf6b
Implement float tests
ttnghia Jun 27, 2022
3845c95
Implement string tests
ttnghia Jun 27, 2022
507c82d
Implement tests for `ListDistinctTypedTest`
ttnghia Jun 28, 2022
2cb8347
Complete the remaining tests
ttnghia Jun 28, 2022
7efdea0
Merge branch 'branch-22.08' into add_lists_distinct
ttnghia Jun 28, 2022
4388637
Rewrite doxygen
ttnghia Jun 28, 2022
4dd5e74
Misc
ttnghia Jun 28, 2022
3b0760c
Misc
ttnghia Jun 28, 2022
9730b70
Rewrite test
ttnghia Jun 28, 2022
9bd9b6f
Fix doxygen
ttnghia Jun 28, 2022
790a482
Fix header
ttnghia Jun 28, 2022
1c58baa
Rewrite doxygen
ttnghia Jun 28, 2022
d493c4f
Rewrite doxygen and fix headers
ttnghia Jun 28, 2022
d090d2a
Fix iterator type
ttnghia Jun 30, 2022
ee51822
Rewrite doxygen
ttnghia Jun 30, 2022
ccdd6f0
Add empty lines
ttnghia Jun 30, 2022
034ee2a
Merge branch 'branch-22.08' into add_lists_distinct
ttnghia Jun 30, 2022
b1231a2
Update default stream
ttnghia Jun 30, 2022
af91b80
Merge branch 'branch-22.08' into add_lists_distinct
ttnghia Jul 5, 2022
86c9ba8
Merge branch 'branch-22.08' into add_lists_distinct
ttnghia Jul 8, 2022
99d70b1
Handle empty input
ttnghia Jul 8, 2022
cf965f6
Merge branch 'add_lists_distinct' into refactor_collect_set
ttnghia Jul 8, 2022
dbc483c
Replace `lists::drop_list_duplicates` by `lists::distinct`
ttnghia Jul 8, 2022
450e509
Fix merge set tests
ttnghia Jul 8, 2022
5ab24f3
Fix collect set tests
ttnghia Jul 8, 2022
e4fc6c4
Merge branch 'branch-22.08' into refactor_collect_set
ttnghia Jul 11, 2022
2ed173c
Fix `collect_ops_tests`
ttnghia Jul 11, 2022
b0ee7c6
Optimize `collect_ops`
ttnghia Jul 11, 2022
c5c6972
Merge branch 'branch-22.08' into refactor_collect_set
ttnghia Jul 11, 2022
d434dcb
Fix tests
ttnghia Jul 11, 2022
d8e2753
Fix copyright year
ttnghia Jul 12, 2022
bf61f5a
Merge branch 'fix_rolling_tests' into refactor_collect_set
ttnghia Jul 12, 2022
7444e2c
Merge branch 'branch-22.08' into refactor_collect_set
ttnghia Jul 12, 2022
697802d
Fix C++ collect_set tests
ttnghia Jul 12, 2022
27d7d83
Fix TableTest for collect set
ttnghia Jul 12, 2022
ea48965
Fix ReductionTest for collect set
ttnghia Jul 12, 2022
37d60cf
Misc
ttnghia Jul 13, 2022
1f10a16
Optimize `collect_set` in reduction
ttnghia Jul 13, 2022
eee0bbb
Rewrite `collect_set_tests`
ttnghia Jul 13, 2022
23c44e4
Merge branch 'branch-22.08' into refactor_collect_set
ttnghia Jul 13, 2022
b55ba30
Misc
ttnghia Jul 13, 2022
73287ec
Add extra blank line
ttnghia Jul 13, 2022
47d3298
Use `sort_by_key`
ttnghia Jul 13, 2022
a23356c
Add/remove comments
ttnghia Jul 13, 2022
b1f1890
Merge branch 'branch-22.08' into refactor_collect_set
ttnghia Jul 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <cudf/detail/unary.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/lists/detail/stream_compaction.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -99,7 +99,7 @@ void aggregate_result_functor::operator()<aggregation::SUM>(aggregation const& a
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::PRODUCT>(aggregation const& agg)
Expand All @@ -111,7 +111,7 @@ void aggregate_result_functor::operator()<aggregation::PRODUCT>(aggregation cons
agg,
detail::group_product(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::ARGMAX>(aggregation const& agg)
Expand All @@ -126,7 +126,7 @@ void aggregate_result_functor::operator()<aggregation::ARGMAX>(aggregation const
helper.key_sort_order(stream),
stream,
mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::ARGMIN>(aggregation const& agg)
Expand All @@ -141,7 +141,7 @@ void aggregate_result_functor::operator()<aggregation::ARGMIN>(aggregation const
helper.key_sort_order(stream),
stream,
mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& agg)
Expand Down Expand Up @@ -181,7 +181,7 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
}();

cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
Expand Down Expand Up @@ -221,7 +221,7 @@ void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& a
}();

cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MEAN>(aggregation const& agg)
Expand All @@ -248,7 +248,7 @@ void aggregate_result_functor::operator()<aggregation::MEAN>(aggregation const&
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::M2>(aggregation const& agg)
Expand All @@ -263,7 +263,7 @@ void aggregate_result_functor::operator()<aggregation::M2>(aggregation const& ag
values,
agg,
detail::group_m2(get_grouped_values(), mean_result, helper.group_labels(stream), stream, mr));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::VARIANCE>(aggregation const& agg)
Expand All @@ -286,7 +286,7 @@ void aggregate_result_functor::operator()<aggregation::VARIANCE>(aggregation con
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::STD>(aggregation const& agg)
Expand All @@ -300,7 +300,7 @@ void aggregate_result_functor::operator()<aggregation::STD>(aggregation const& a

auto result = cudf::detail::unary_operation(var_result, unary_operator::SQRT, stream, mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::QUANTILE>(aggregation const& agg)
Expand All @@ -321,7 +321,7 @@ void aggregate_result_functor::operator()<aggregation::QUANTILE>(aggregation con
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::MEDIAN>(aggregation const& agg)
Expand All @@ -341,7 +341,7 @@ void aggregate_result_functor::operator()<aggregation::MEDIAN>(aggregation const
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::NUNIQUE>(aggregation const& agg)
Expand All @@ -358,7 +358,7 @@ void aggregate_result_functor::operator()<aggregation::NUNIQUE>(aggregation cons
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation const& agg)
Expand Down Expand Up @@ -404,7 +404,7 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation
stream,
mr);
cache.add_result(values, agg, std::move(result));
};
}

template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
Expand All @@ -426,9 +426,9 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
cache.add_result(
values,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};
lists::detail::distinct(
lists_column_view{collect_result->view()}, nulls_equal, nans_equal, stream, mr));
}

/**
* @brief Perform merging for the lists that correspond to the same key value.
Expand All @@ -455,7 +455,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation
agg,
detail::group_merge_lists(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};
}

/**
* @brief Perform merging for the lists corresponding to the same key value, then drop duplicate
Expand All @@ -473,13 +473,13 @@ void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation
* column for this aggregation.
*
* Firstly, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to
* the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to
* the same key) into intermediate lists, then it calls `lists::distinct` on them to
* remove duplicate list entries. As such, the input (partial results) to this aggregation should be
* generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily
* removing duplicate entries for the partial results.
*
* Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality`
* are needed for calling to `lists::drop_list_duplicates`.
* are needed for calling `lists::distinct`.
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
*/
template <>
void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation const& agg)
Expand All @@ -494,12 +494,12 @@ void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation c
auto const& merge_sets_agg = dynamic_cast<cudf::detail::merge_sets_aggregation const&>(agg);
cache.add_result(values,
agg,
lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()),
merge_sets_agg._nulls_equal,
merge_sets_agg._nans_equal,
stream,
mr));
};
lists::detail::distinct(lists_column_view{merged_result->view()},
merge_sets_agg._nulls_equal,
merge_sets_agg._nans_equal,
stream,
mr));
}

/**
* @brief Perform merging for the M2 values that correspond to the same key value.
Expand Down Expand Up @@ -528,7 +528,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_M2>(aggregation con
agg,
detail::group_merge_m2(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};
}

/**
* @brief Creates column views with only valid elements in both input column views
Expand Down Expand Up @@ -600,7 +600,7 @@ void aggregate_result_functor::operator()<aggregation::COVARIANCE>(aggregation c
cov_agg._ddof,
stream,
mr));
};
}

/**
* @brief Perform correlation between two child columns of non-nullable struct column.
Expand Down Expand Up @@ -710,7 +710,7 @@ void aggregate_result_functor::operator()<aggregation::TDIGEST>(aggregation cons
max_centroids,
stream,
mr));
};
}

/**
* @brief Generate a merged tdigest column from a grouped set of input tdigest columns.
Expand Down Expand Up @@ -752,7 +752,7 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
max_centroids,
stream,
mr));
};
}

} // namespace detail

Expand Down
41 changes: 20 additions & 21 deletions cpp/src/reductions/collect_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,14 @@
#include <cudf/detail/copy_if.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/reduction_functions.hpp>
#include <cudf/lists/drop_list_duplicates.hpp>
#include <cudf/lists/lists_column_factories.hpp>
#include <cudf/detail/stream_compaction.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_factories.hpp>

namespace cudf {
namespace reduction {

/**
 * @brief Build a new list scalar holding the elements of the input list scalar with duplicate
 * entries dropped.
 *
 * The input scalar is first materialized as a lists column, that column is passed to
 * `lists::drop_list_duplicates`, and the sliced child of the result is wrapped back into a list
 * scalar.
 *
 * @param scalar The list scalar whose elements are de-duplicated
 * @param nulls_equal Forwarded to `lists::drop_list_duplicates`
 * @param nans_equal Forwarded to `lists::drop_list_duplicates`
 * @param stream Stream passed to the column construction, child slicing and scalar creation
 * @param mr Memory resource passed to the column construction, de-duplication and scalar creation
 * @return A scalar created by `make_list_scalar` from the de-duplicated child
 */
std::unique_ptr<scalar> drop_duplicates(list_scalar const& scalar,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Wrap the scalar in a lists column so the column-based de-duplication API can be applied.
// The literal `1` is presumably the number of rows to create -- TODO confirm against
// `make_lists_column_from_scalar`.
auto list_wrapper = lists::detail::make_lists_column_from_scalar(scalar, 1, stream, mr);
auto lcw = lists_column_view(list_wrapper->view());
// NOTE(review): unlike the surrounding calls, this call is not given `stream`; confirm which
// stream `lists::drop_list_duplicates` runs on and whether that is intended.
auto no_dup_wrapper = lists::drop_list_duplicates(lcw, nulls_equal, nans_equal, mr);
// Take the (sliced) child of the de-duplicated lists column and turn it back into a list scalar.
auto no_dup = lists_column_view(no_dup_wrapper->view()).get_sliced_child(stream);
return make_list_scalar(no_dup, stream, mr);
}

std::unique_ptr<scalar> collect_list(column_view const& col,
null_policy null_handling,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -72,9 +58,16 @@ std::unique_ptr<scalar> collect_set(column_view const& col,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto scalar = collect_list(col, null_handling, stream, mr);
auto ls = dynamic_cast<list_scalar*>(scalar.get());
return drop_duplicates(*ls, nulls_equal, nans_equal, stream, mr);
auto scalar = collect_list(col, null_handling, stream, mr);
auto ls = dynamic_cast<list_scalar*>(scalar.get());
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
auto distinct_table = detail::distinct(table_view{{ls->view()}},
std::vector<size_type>{0},
duplicate_keep_option::KEEP_ANY,
nulls_equal,
nans_equal,
stream,
mr);
return std::make_unique<list_scalar>(std::move(distinct_table->get_column(0)), true, stream, mr);
}

std::unique_ptr<scalar> merge_sets(lists_column_view const& col,
Expand All @@ -83,9 +76,15 @@ std::unique_ptr<scalar> merge_sets(lists_column_view const& col,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto flatten_col = col.get_sliced_child(stream);
auto scalar = std::make_unique<list_scalar>(flatten_col, true, stream, mr);
return drop_duplicates(*scalar, nulls_equal, nans_equal, stream, mr);
auto flatten_col = col.get_sliced_child(stream);
auto distinct_table = detail::distinct(table_view{{flatten_col}},
std::vector<size_type>{0},
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
duplicate_keep_option::KEEP_ANY,
nulls_equal,
nans_equal,
stream,
mr);
return std::make_unique<list_scalar>(std::move(distinct_table->get_column(0)), true, stream, mr);
}

} // namespace reduction
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#include <cudf/detail/utilities/device_operators.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/lists/detail/stream_compaction.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/utilities/error.hpp>
Expand Down Expand Up @@ -929,8 +929,8 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
stream,
rmm::mr::get_current_device_resource());

result = lists::detail::drop_list_duplicates(
lists_column_view(collected_list->view()), agg._nulls_equal, agg._nans_equal, stream, mr);
result = lists::detail::distinct(
lists_column_view{collected_list->view()}, agg._nulls_equal, agg._nans_equal, stream, mr);
}

// perform the element-wise square root operation on result of VARIANCE
Expand Down
Loading