From 3a67f5ec803dc3a39365b682704bde8db36444b1 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 10 Jun 2022 11:49:37 -0700 Subject: [PATCH] Use `static_multimap` Signed-off-by: Nghia Truong --- cpp/src/search/contains_nested.cu | 66 +++++++++++++++---------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/cpp/src/search/contains_nested.cu b/cpp/src/search/contains_nested.cu index 7b2903250fe..5e4d41e66ea 100644 --- a/cpp/src/search/contains_nested.cu +++ b/cpp/src/search/contains_nested.cu @@ -29,6 +29,8 @@ #include #include +#include + namespace cudf::detail { bool contains_nested_element(column_view const& haystack, @@ -92,61 +94,55 @@ std::unique_ptr multi_contains_nested_elements(column_view const& haysta using cudf::experimental::row::lhs_index_type; using cudf::experimental::row::rhs_index_type; - using static_map = cuco::static_map; + using static_map = cuco::static_multimap; auto haystack_map = static_map{compute_hash_table_size(haystack.size()), - cuco::sentinel::empty_key{lhs_index_type{detail::COMPACTION_EMPTY_KEY_SENTINEL}}, + cuco::sentinel::empty_key{hash_value_type{detail::COMPACTION_EMPTY_KEY_SENTINEL}}, cuco::sentinel::empty_value{lhs_index_type{detail::COMPACTION_EMPTY_KEY_SENTINEL}}, - detail::hash_table_allocator_type{default_allocator{}, stream}, - stream.value()}; - - // Insert all indices of the elements in the haystack column into the hash map. - // As such, we will use `thrust::equal_to` as key comparator to not ignore any key. - // - // An alternative way to this is to use `self_comparator` for key comparisons, which would only - // insert unique rows of the haystack column. This would save some memory (or significant amount - // of memory if the haystack column contains many duplicate rows), however, will result in much - // more processing time due to expensive row comparisons. - { - auto const haystack_it = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(size_type const i) { - return cuco::make_pair(lhs_index_type{i}, lhs_index_type{i}); - }); + stream.value(), + detail::hash_table_allocator_type{default_allocator{}, stream}}; + { auto const hasher = cudf::experimental::row::hash::row_hasher(haystack_tv, stream); auto const d_hasher = detail::experimental::compaction_hash( hasher.device_hasher(nullate::DYNAMIC{haystack_has_nulls})); - haystack_map.insert( - haystack_it, haystack_it + haystack.size(), d_hasher, thrust::equal_to{}, stream.value()); + auto const haystack_it = cudf::detail::make_counting_transform_iterator( + size_type{0}, [d_hasher] __device__(size_type const i) { + return cuco::make_pair(d_hasher(i), lhs_index_type{i}); + }); + + haystack_map.insert(haystack_it, haystack_it + haystack.size(), stream.value()); } - // Check for existence of needles in haystack. - // During this, we will use `negative_index_comparator_adapter` to convert the existing indices - // in the hash map (i.e., indices of haystack elements) into `lhs_index_type`, and indices of the - // searching needles into `rhs_index_type` for row comparisons. { - // A reverse iterator constructed from `0` value will begin from `-1`. - // Thus, needle indices will iterate in reverse order in the range `[-1, -1-needles.size())`. - // They will be converted back to the range `[0, needles.size())` then into `rhs_index_type` - // automatically by `negative_index_hasher_adapter` and `negative_index_comparator_adapter`. - auto const needles_it = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(size_type const i) { return rhs_index_type{i}; }); - auto const hasher = cudf::experimental::row::hash::row_hasher(needles_tv, stream); auto const d_hasher = detail::experimental::compaction_hash( hasher.device_hasher(nullate::DYNAMIC{needles_has_nulls})); + auto const needles_it = cudf::detail::make_counting_transform_iterator( + size_type{0}, [d_hasher] __device__(size_type const i) { + return cuco::make_pair(d_hasher(i), rhs_index_type{i}); + }); + auto const comparator = cudf::experimental::row::equality::two_table_comparator(haystack_tv, needles_tv, stream); auto const d_eqcomp = comparator.equal_to(nullate::DYNAMIC{haystack_has_nulls || needles_has_nulls}); - haystack_map.contains( - needles_it, needles_it + needles.size(), out_begin, d_hasher, d_eqcomp, stream.value()); + haystack_map.pair_contains( + needles_it, + needles_it + needles.size(), + out_begin, + [d_eqcomp] __device__(auto const& lhs_hash_and_index, auto const& rhs_hash_and_index) { + auto const& [lhs_hash, lhs_index] = lhs_hash_and_index; + auto const& [rhs_hash, rhs_index] = rhs_hash_and_index; + return lhs_hash == rhs_hash ? d_eqcomp(lhs_index, rhs_index) : false; + }, + stream.value()); } return result;