Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support contains() on lists of primitives #7039

Merged
merged 34 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dd01ba1
Support contains() on lists of primitives
mythrocks Dec 17, 2020
4558261
Support contains() on lists of primitives
mythrocks Dec 18, 2020
a87ffb0
Support contains() on lists of primitives
mythrocks Dec 18, 2020
21aae8a
Support contains() on lists of primitives
mythrocks Dec 18, 2020
4e5b819
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
752379a
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
551c014
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
36ffde8
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
2e7d725
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Dec 19, 2020
80510cf
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 4, 2021
f9bb045
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 5, 2021
f880b46
Support contains() on lists of primitives:
mythrocks Jan 5, 2021
ce92cc5
Support contains() on lists of primitives
mythrocks Jan 5, 2021
9c92bf8
Support contains() on lists of primitives
mythrocks Jan 6, 2021
542e39a
Support contains() on lists of primitives:
mythrocks Jan 6, 2021
0a60a12
Support contains() on lists of primitives
mythrocks Jan 7, 2021
63d1a4a
Chrono support in lists::contains()
mythrocks Jan 7, 2021
ed9269f
Added validity iterator for scalars.
mythrocks Jan 7, 2021
2d469b7
Collapsed construct_null_mask() single function
mythrocks Jan 7, 2021
a398bb6
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 7, 2021
6ebed08
Collapsed contains() implementation to one impl function
mythrocks Jan 8, 2021
e01e860
Renamed search_key_has_all_nulls
mythrocks Jan 13, 2021
21d06e9
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 14, 2021
792bd86
Fixed behaviour for lists containing nulls:
mythrocks Jan 15, 2021
bb800a7
Cleaned up SFINAE is_supported().
mythrocks Jan 17, 2021
1210f83
Remove names for unused function/template parameters.
mythrocks Jan 19, 2021
c77d1a1
Code formatting.
mythrocks Jan 19, 2021
21bd39d
Fixed documentation for make_validity_iterator(scalar)
mythrocks Jan 19, 2021
793863a
Added Doxygen directives.
mythrocks Jan 21, 2021
b0ad781
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 21, 2021
4225a62
Move namespace directives to file level.
mythrocks Jan 25, 2021
f7d4bad
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 25, 2021
d6ddee0
Fix doxygen_groups.h lists_contains group
mythrocks Jan 25, 2021
b3b5b6c
Move tests to namespace.
mythrocks Jan 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- PR #6907 Add `replace_null` API with `replace_policy` parameter, `fixed_width` column support

- PR #6775 Implement cudf.DateOffset for months
- PR #7039 Support contains() on lists of primitives

## Improvements

Expand Down
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ test:
- test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp
- test -f $PREFIX/include/cudf/lists/detail/copying.hpp
- test -f $PREFIX/include/cudf/lists/extract.hpp
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/lists_column_view.hpp
- test -f $PREFIX/include/cudf/merge.hpp
- test -f $PREFIX/include/cudf/null_mask.hpp
Expand Down
71 changes: 71 additions & 0 deletions cpp/include/cudf/lists/contains.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/lists/lists_column_view.hpp>

namespace cudf {
namespace lists {

mythrocks marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Create a column of bool values indicating whether the specified scalar
* is an element of each row of a list column.
*
* The output column has as many elements as the input `lists` column.
* Output `column[i]` is set to true if the lists row `lists[i]` contains the value
* specified in skey. Otherwise, it is set to false.
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
*
* Output `column[i]` is set to null if even one of the following holds true:
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
* 1. The search key `skey` is null
* 2. The list row `lists[i]` is null
* 3. The list row `lists[i]` contains even *one* null
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
*
* @param lists Lists column whose `n` rows are to be searched
* @param skey The scalar key to be looked up in each list row
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup
*/
std::unique_ptr<column> contains(
cudf::lists_column_view const& lists,
cudf::scalar const& skey,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Create a column of bool values indicating whether the list rows of the first
* column contain the corresponding values in the second column
*
* The output column has as many elements as the input `lists` column.
* Output `column[i]` is set to true if the lists row `lists[i]` contains the value
* in `skey[i]`. Otherwise, it is set to false.
*
* Output `column[i]` is set to null if even one of the following holds true:
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
* 1. The row `skey[i]` is null
* 2. The list row `lists[i]` is null
* 3. The list row `lists[i]` contains even *one* null
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
*
* @param lists Lists column whose `n` rows are to be searched
* @param skeys Column of elements to be looked up in each list row
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup
*/
std::unique_ptr<column> contains(
cudf::lists_column_view const& lists,
cudf::column_view const& skeys,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace lists
} // namespace cudf
266 changes: 266 additions & 0 deletions cpp/src/lists/contains.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <thrust/logical.h>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/valid_if.cuh>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/list_device_view.cuh>
#include <cudf/lists/lists_column_device_view.cuh>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/utilities/type_dispatcher.hpp>
#include <rmm/exec_policy.hpp>
#include <type_traits>

namespace cudf {
namespace lists {

namespace {

auto CUDA_HOST_DEVICE_CALLABLE counting_iter(size_type n)
{
return thrust::make_counting_iterator(n);
}
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

std::pair<rmm::device_buffer, size_type> construct_null_mask(
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
cudf::detail::lists_column_device_view const& d_lists,
cudf::scalar const& skey,
bool input_column_has_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using namespace cudf;
using namespace cudf::detail;

if (skey.is_valid(stream) && !input_column_has_nulls) {
return std::make_pair(rmm::device_buffer{0, stream, mr}, size_type{0});
}

if (!skey.is_valid(stream)) {
return std::make_pair(cudf::create_null_mask(d_lists.size(), mask_state::ALL_NULL, mr),
d_lists.size());
}

return cudf::detail::valid_if(
counting_iter(0), counting_iter(d_lists.size()), [d_lists] __device__(auto const& row_index) {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
auto list = cudf::list_device_view(d_lists, row_index);
if (list.is_null()) { return false; }
return thrust::none_of(thrust::seq,
counting_iter(0),
counting_iter(list.size()),
[&list] __device__(auto const& i) { return list.is_null(i); });
});
}

std::pair<rmm::device_buffer, size_type> construct_null_mask(
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
cudf::detail::lists_column_device_view const& d_lists,
cudf::column_device_view const& d_skeys,
bool input_column_has_nulls,
bool skeys_column_has_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using namespace cudf;
using namespace cudf::detail;

if (!skeys_column_has_nulls && !input_column_has_nulls) {
return std::make_pair(rmm::device_buffer{0, stream, mr}, size_type{0});
}

return cudf::detail::valid_if(counting_iter(0),
counting_iter(d_lists.size()),
[d_lists, d_skeys] __device__(auto const& row_index) {
if (d_skeys.is_null(row_index)) { return false; }

auto list = cudf::list_device_view(d_lists, row_index);

if (list.is_null()) { return false; }

return thrust::none_of(
thrust::seq,
counting_iter(0),
counting_iter(list.size()),
[&list] __device__(auto const& i) { return list.is_null(i); });
});
}

struct lookup_functor {
template <typename T, typename... Args>
std::enable_if_t<!cudf::is_numeric<T>() && !std::is_same<T, cudf::string_view>::value, void>
operator()(Args&&...) const
{
CUDF_FAIL("lists::contains() is only supported on numeric types and strings.");
}

template <typename T>
std::enable_if_t<cudf::is_numeric<T>() || std::is_same<T, cudf::string_view>::value, void>
operator()(cudf::detail::lists_column_device_view const& d_lists,
cudf::scalar const& skey,
cudf::mutable_column_device_view output_bools,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
assert(skey.is_valid() && "skey should have been checked for nulls by this point.");

auto h_scalar = static_cast<cudf::scalar_type_t<T> const&>(skey);
auto d_scalar = cudf::get_scalar_device_view(h_scalar);

thrust::transform(rmm::exec_policy(stream),
counting_iter(0),
counting_iter(d_lists.size()),
output_bools.data<bool>(),
[d_lists, d_scalar] __device__(auto row_index) {
auto list = cudf::list_device_view(d_lists, row_index);
if (list.is_null()) { return false; }
for (size_type i{0}; i < list.size(); ++i) {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
if (list.is_null(i)) { return false; }
auto list_element = list.template element<T>(i);
if (list_element == d_scalar.value()) { return true; }
}
return false;
});
}

template <typename T>
std::enable_if_t<cudf::is_numeric<T>() || std::is_same<T, cudf::string_view>::value, void>
operator()(cudf::detail::lists_column_device_view const& d_lists,
cudf::column_device_view const& d_skeys,
cudf::mutable_column_device_view output_bools,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
thrust::transform(rmm::exec_policy(stream),
counting_iter(0),
counting_iter(d_lists.size()),
output_bools.data<bool>(),
[d_lists, d_skeys] __device__(auto row_index) {
if (d_skeys.is_null(row_index)) { return false; }
auto list = cudf::list_device_view(d_lists, row_index);
if (list.is_null()) { return false; }
auto skey = d_skeys.template element<T>(row_index);
for (size_type i{0}; i < list.size(); ++i) {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
if (list.is_null(i)) { return false; }
auto list_element = list.template element<T>(i);
if (list_element == skey) { return true; }
}
return false;
});
}
};

} // namespace

namespace detail {

std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
cudf::scalar const& skey,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using namespace cudf;
using namespace cudf::detail;

CUDF_EXPECTS(!cudf::is_nested(lists.child().type()),
"Nested types not supported in lists::contains()");
CUDF_EXPECTS(lists.child().type().id() == skey.type().id(),
"Type of search key does not match list column element type.");
CUDF_EXPECTS(skey.type().id() != type_id::EMPTY, "Type cannot be empty.");

auto const device_view = column_device_view::create(lists.parent(), stream);
auto const d_lists = lists_column_device_view(*device_view);

rmm::device_buffer null_mask;
size_type num_nulls;

std::tie(null_mask, num_nulls) =
construct_null_mask(d_lists, skey, lists.has_nulls() || lists.child().has_nulls(), stream, mr);

auto ret_bools = make_fixed_width_column(
data_type{type_id::BOOL8}, lists.size(), std::move(null_mask), num_nulls, stream, mr);

if (skey.is_valid()) {
auto ret_bools_mutable_device_view =
mutable_column_device_view::create(ret_bools->mutable_view(), stream);

cudf::type_dispatcher(
skey.type(), lookup_functor{}, d_lists, skey, *ret_bools_mutable_device_view, stream, mr);
}

return ret_bools;
}

std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
cudf::column_view const& skeys,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using namespace cudf;
using namespace cudf::detail;

CUDF_EXPECTS(!cudf::is_nested(lists.child().type()),
"Nested types not supported in lists::contains()");
CUDF_EXPECTS(lists.child().type().id() == skeys.type().id(),
"Type of search key does not match list column element type.");
CUDF_EXPECTS(skeys.size() == lists.size(), "Number of search keys must match list column size.");
CUDF_EXPECTS(skeys.type().id() != type_id::EMPTY, "Type cannot be empty.");

auto const device_view = column_device_view::create(lists.parent(), stream);
auto const d_lists = lists_column_device_view(*device_view);
auto const d_skeys = column_device_view::create(skeys, stream);

rmm::device_buffer null_mask;
size_type num_nulls;

std::tie(null_mask, num_nulls) =
construct_null_mask(d_lists,
*d_skeys,
lists.has_nulls() || lists.child().has_nulls(),
skeys.has_nulls(),
stream,
mr);

auto ret_bools = make_fixed_width_column(
data_type{type_id::BOOL8}, lists.size(), std::move(null_mask), num_nulls, stream, mr);

auto ret_bools_mutable_device_view =
mutable_column_device_view::create(ret_bools->mutable_view(), stream);

cudf::type_dispatcher(
skeys.type(), lookup_functor{}, d_lists, *d_skeys, *ret_bools_mutable_device_view, stream, mr);

return ret_bools;
}

} // namespace detail

std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
cudf::scalar const& skey,
rmm::mr::device_memory_resource* mr)
{
return detail::contains(lists, skey, rmm::cuda_stream_default, mr);
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
cudf::column_view const& skeys,
rmm::mr::device_memory_resource* mr)
{
return detail::contains(lists, skeys, rmm::cuda_stream_default, mr);
}

} // namespace lists
} // namespace cudf
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,8 @@ ConfigureTest(AST_TEST "${AST_TEST_SRC}")
# - lists tests ----------------------------------------------------------------------------------

set(LISTS_TEST_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/lists/extract_tests.cpp")
"${CMAKE_CURRENT_SOURCE_DIR}/lists/extract_tests.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/lists/contains_tests.cpp")

ConfigureTest(LISTS_TEST "${LISTS_TEST_SRC}")

Expand Down
Loading