Support contains() on lists of primitives (#7039)
Closes #6944.

This commit adds a method (`contains()`) to check whether each row of a `LIST` column contains the scalar value specified as an argument. The operation returns a `BOOL8` column (with as many rows as the input `LIST`), each row indicating `true` if the value is found, `false` if not.

Output `column[i]` is set to null if any of the following holds (in line with the semantics of `array_contains()` in SQL):
  1. The search key `skey` is null
  2. The list row `lists[i]` is null
  3. The list row `lists[i]` contains at least one null *and* does not contain the search key.

This implementation currently supports the operation on lists of numerics or strings.
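
The three null rules above can be sketched on the host with `std::optional`. This is only an illustrative model of the semantics, not the GPU implementation in this commit: the names `element`, `list_row`, and `row_contains` are hypothetical, and `std::nullopt` stands in for a null key, row, or element.

```cpp
#include <optional>
#include <vector>

// Host-side model only: std::nullopt stands in for a null element, row, or key.
// These aliases and the function below are hypothetical, not part of libcudf.
using element  = std::optional<int>;
using list_row = std::optional<std::vector<element>>;

// Returns true/false when the answer is known, std::nullopt when it is null.
std::optional<bool> row_contains(list_row const& row, element const& key)
{
  if (!key.has_value()) { return std::nullopt; }  // rule 1: null search key
  if (!row.has_value()) { return std::nullopt; }  // rule 2: null list row

  bool saw_null = false;
  for (auto const& e : *row) {
    if (!e.has_value()) {
      saw_null = true;
    } else if (*e == *key) {
      return true;  // a match wins, even if the row also holds nulls
    }
  }
  // rule 3: no match was found, but a null element leaves the answer unknown
  if (saw_null) { return std::nullopt; }
  return false;
}
```

Under this model, searching `[1, null, 3]` for `3` yields `true`, while searching the same row for `7` yields null rather than `false`, because the null element could have been the key.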

Authors:
  - MithunR (@mythrocks)

Approvers:
  - AJ Schmidt (@ajschmidt8)
  - Mark Harris (@harrism)
  - David (@davidwendt)
  - Karthikeyan (@karthikeyann)

URL: #7039
mythrocks authored Jan 25, 2021
1 parent eb1336f commit b1e9e20
Showing 9 changed files with 988 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@
- PR #6907 Add `replace_null` API with `replace_policy` parameter, `fixed_width` column support
- PR #6885 Share `factorize` implementation with Index and cudf module
- PR #6775 Implement cudf.DateOffset for months
- PR #7039 Support contains() on lists of primitives

## Improvements

1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
@@ -125,6 +125,7 @@ test:
- test -f $PREFIX/include/cudf/lists/detail/copying.hpp
- test -f $PREFIX/include/cudf/lists/count_elements.hpp
- test -f $PREFIX/include/cudf/lists/extract.hpp
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/gather.hpp
- test -f $PREFIX/include/cudf/lists/lists_column_view.hpp
- test -f $PREFIX/include/cudf/merge.hpp
15 changes: 15 additions & 0 deletions cpp/include/cudf/detail/iterator.cuh
@@ -174,6 +174,21 @@ auto inline make_validity_iterator(column_device_view const& column)
validity_accessor{column});
}

/**
* @brief Constructs a constant device iterator over a scalar's validity.
*
* Dereferencing the returned iterator returns a `bool`.
*
* For `p = *(iter + i)`, `p` is the validity of the scalar.
*
* @param scalar_value The scalar to iterate
* @return auto Iterator that returns scalar validity
*/
auto inline make_validity_iterator(scalar const& scalar_value)
{
  return thrust::make_constant_iterator(scalar_value.is_valid());
}

/**
* @brief value accessor for scalar with valid data.
* The unary functor returns data of Element type of the scalar.
79 changes: 79 additions & 0 deletions cpp/include/cudf/lists/contains.hpp
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/lists/lists_column_view.hpp>

namespace cudf {
namespace lists {
/**
* @addtogroup lists_contains
* @{
* @file
*/

/**
* @brief Create a column of bool values indicating whether the specified scalar
* is an element of each row of a list column.
*
* The output column has as many elements as the input `lists` column.
* Output `column[i]` is set to true if the lists row `lists[i]` contains the value
* specified in `search_key`. Otherwise, it is set to false.
*
* Output `column[i]` is set to null if one or more of the following are true:
* 1. The search key `search_key` is null
* 2. The list row `lists[i]` is null
* 3. The list row `lists[i]` does not contain the search key, and contains at least
* one null.
*
* @param lists Lists column whose `n` rows are to be searched
* @param search_key The scalar key to be looked up in each list row
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup
*/
std::unique_ptr<column> contains(
  cudf::lists_column_view const& lists,
  cudf::scalar const& search_key,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Create a column of bool values indicating whether the list rows of the first
* column contain the corresponding values in the second column
*
* The output column has as many elements as the input `lists` column.
* Output `column[i]` is set to true if the lists row `lists[i]` contains the value
* in `search_keys[i]`. Otherwise, it is set to false.
*
* Output `column[i]` is set to null if one or more of the following are true:
* 1. The row `search_keys[i]` is null
* 2. The list row `lists[i]` is null
* 3. The list row `lists[i]` does not contain the `search_keys[i]`, and contains at least
* one null.
*
* @param lists Lists column whose `n` rows are to be searched
* @param search_keys Column of elements to be looked up in each list row
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup
*/
std::unique_ptr<column> contains(
  cudf::lists_column_view const& lists,
  cudf::column_view const& search_keys,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
} // namespace lists
} // namespace cudf
70 changes: 70 additions & 0 deletions cpp/include/cudf/lists/list_device_view.cuh
@@ -112,12 +112,82 @@ class list_device_view {
*/
CUDA_DEVICE_CALLABLE lists_column_device_view const& get_column() const { return lists_column; }

template <typename T>
struct pair_accessor;

template <typename T>
using const_pair_iterator =
thrust::transform_iterator<pair_accessor<T>, thrust::counting_iterator<cudf::size_type>>;

/**
* @brief Fetcher for a pair iterator to the first element in the list_device_view.
*
* Dereferencing the returned iterator yields a `thrust::pair<T, bool>`.
*
* If the element at index `i` is valid, then for `p = iter[i]`,
* 1. `p.first` is the value of the element at `i`
* 2. `p.second == true`
*
* If the element at index `i` is null,
* 1. `p.first` is undefined
* 2. `p.second == false`
*/
template <typename T>
CUDA_DEVICE_CALLABLE const_pair_iterator<T> pair_begin() const
{
  return const_pair_iterator<T>{thrust::counting_iterator<size_type>(0), pair_accessor<T>{*this}};
}

/**
* @brief Fetcher for a pair iterator to one position past the last element in the
* list_device_view.
*/
template <typename T>
CUDA_DEVICE_CALLABLE const_pair_iterator<T> pair_end() const
{
  return const_pair_iterator<T>{thrust::counting_iterator<size_type>(size()),
                                pair_accessor<T>{*this}};
}

private:
lists_column_device_view const& lists_column;
size_type _row_index{}; // Row index in the Lists column vector.
size_type _size{}; // Number of elements in *this* list row.

size_type begin_offset; // Offset in list_column_device_view where this list begins.

/**
* @brief pair accessor for elements in a `list_device_view`
*
* This unary functor returns a pair of:
* 1. data element at a specified index
* 2. boolean validity flag for that element
*
* @tparam T The element-type of the list row
*/
template <typename T>
struct pair_accessor {
  list_device_view const& list;

  /**
   * @brief constructor
   *
   * @param _list The `list_device_view` whose rows are being accessed.
   */
  explicit CUDA_HOST_DEVICE_CALLABLE pair_accessor(list_device_view const& _list) : list{_list} {}

  /**
   * @brief Accessor for the {data, validity} pair at the specified index
   *
   * @param i Index into the list_device_view
   * @return A pair of data element and its validity flag.
   */
  CUDA_DEVICE_CALLABLE
  thrust::pair<T, bool> operator()(cudf::size_type i) const
  {
    return {list.element<T>(i), !list.is_null(i)};
  }
};
};

} // namespace cudf
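
As a rough illustration of how the `{value, validity}` pairs documented for `pair_begin()` might be consumed, here is a host-only sketch in plain C++. It does not use `list_device_view` or Thrust: `host_list_row`, `pair_at`, `scan_result`, and `scan_row` are hypothetical stand-ins that only mimic the documented contract, namely that `.first` is meaningful only when `.second` is `true`.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical host-side stand-in for one list row. It is not cudf's
// list_device_view; it only mimics the {value, validity} pair contract.
template <typename T>
struct host_list_row {
  std::vector<T> values;    // element payloads (placeholder where null)
  std::vector<bool> valid;  // valid[i] == false marks element i as null

  std::size_t size() const { return values.size(); }

  // Host analogue of reading the pair at index i.
  std::pair<T, bool> pair_at(std::size_t i) const { return {values[i], valid[i]}; }
};

struct scan_result {
  bool found;     // some valid element equals the key
  bool has_null;  // at least one element is null
};

// Single pass over the pairs; .first is read only when .second is true.
template <typename T>
scan_result scan_row(host_list_row<T> const& row, T const& key)
{
  scan_result result{false, false};
  for (std::size_t i = 0; i < row.size(); ++i) {
    auto const p = row.pair_at(i);
    if (!p.second) {
      result.has_null = true;
    } else if (p.first == key) {
      result.found = true;
    }
  }
  return result;
}
```

Checking the validity flag before comparing keeps a placeholder payload in a null slot from being mistaken for a match, which is why the accessor returns both halves together.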
2 changes: 2 additions & 0 deletions cpp/include/doxygen_groups.h
@@ -143,6 +143,8 @@
* @defgroup lists_apis Lists
* @{
* @defgroup lists_extract Extracting
* @defgroup lists_contains Searching
* @defgroup lists_gather Gathering
* @defgroup lists_elements Counting
* @}
* @defgroup nvtext_apis NVText