Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move nullable index iterator to indexalator factory #7399

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion cpp/include/cudf/detail/indexalator.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include <cudf/column/column_view.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/scalar/scalar.hpp>
#include <cudf/utilities/traits.hpp>

Expand Down Expand Up @@ -480,6 +481,58 @@ struct indexalator_factory {
{
return type_dispatcher(indices.type(), output_indexalator_fn{}, indices);
}

/**
 * @brief An index accessor that returns a validity flag along with the index value.
 *
 * This is suitable as a `pair_iterator` for calling functions like `copy_if_else`.
 *
 * Calling the accessor with a position `i` yields `{iter[i], valid}`:
 *  - `has_nulls == false`: `valid` is always `true` and the bitmask is never read.
 *  - `has_nulls == true` with a bitmask (column input): `valid` is bit `i + offset` of the mask.
 *  - `has_nulls == true` without a bitmask (null scalar input): `valid` is always `false`.
 */
struct nullable_index_accessor {
  input_indexalator iter;           ///< Type-erased reader for the index values
  bitmask_type const* null_mask{};  ///< Validity bitmask; stays nullptr for scalar input
  size_type const offset{};         ///< Bit offset added to the position when reading the mask
  bool const has_nulls{};           ///< Whether any element may be null

  /**
   * @brief Create an accessor from a column_view.
   *
   * @throw cudf::logic_error if `has_nulls` is true but `col` is not nullable.
   *
   * @param col Column holding the index values (and, optionally, a null mask)
   * @param has_nulls Set to true to read validity from the null mask of `col`
   */
  nullable_index_accessor(column_view const& col, bool has_nulls = false)
    : null_mask{col.null_mask()}, offset{col.offset()}, has_nulls{has_nulls}
  {
    if (has_nulls) { CUDF_EXPECTS(col.nullable(), "Unexpected non-nullable column."); }
    iter = make_input_iterator(col);
  }

  /**
   * @brief Create an accessor from a scalar.
   *
   * A scalar has no bitmask, so `null_mask` keeps its default `nullptr` value and
   * `has_nulls` alone records whether the scalar is null.
   *
   * @param input Scalar holding the index value
   */
  nullable_index_accessor(scalar const& input) : has_nulls{!input.is_valid()}
  {
    iter = indexalator_factory::make_input_iterator(input);
  }

  /**
   * @brief Return the index at position `i` paired with its validity.
   *
   * @param i Position to read
   * @return `{index, is_valid}` pair
   */
  __device__ thrust::pair<size_type, bool> operator()(size_type i) const
  {
    if (!has_nulls) { return {iter[i], true}; }
    // A null scalar sets has_nulls without providing a bitmask. The mask must not be
    // dereferenced in that case; the single value is simply reported as invalid.
    if (null_mask == nullptr) { return {iter[i], false}; }
    return {iter[i], bit_is_set(null_mask, i + offset)};
  }
};

/**
 * @brief Build a pair iterator over the indices of a column.
 *
 * Positions are counted from 0 and each one is passed through a
 * `nullable_index_accessor`, which reads the null mask only when
 * `col.has_nulls()` is true.
 *
 * @param col Column holding the index values
 * @return Transform iterator produced by `make_counting_transform_iterator`
 */
static auto make_input_pair_iterator(column_view const& col)
{
  bool const any_nulls = col.has_nulls();
  auto accessor        = nullable_index_accessor{col, any_nulls};
  return make_counting_transform_iterator(0, accessor);
}

/**
 * @brief Build a pair iterator that repeats the index held by a scalar.
 *
 * The underlying constant iterator always supplies position 0, so every
 * dereference passes 0 to the `nullable_index_accessor` built from `input`.
 *
 * @param input Scalar holding the index value
 * @return Transform iterator over a constant iterator
 */
static auto make_input_pair_iterator(scalar const& input)
{
  auto zero_position = thrust::make_constant_iterator<size_type>(0);
  auto accessor      = nullable_index_accessor{input};
  return thrust::make_transform_iterator(zero_position, accessor);
}
};

} // namespace detail
Expand Down
71 changes: 10 additions & 61 deletions cpp/src/dictionary/replace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,66 +27,11 @@

#include <rmm/cuda_stream_view.hpp>

#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>

namespace cudf {
namespace dictionary {
namespace detail {
namespace {

/**
 * @brief Functor that pairs an index value with a validity flag.
 *
 * Wrapped in a transform iterator, it provides the `pair_iterator` handed to `copy_if_else`.
 *
 * @tparam has_nulls If true, validity is read from the null mask; if false, every
 *         element is reported as valid and the mask is never consulted.
 */
template <bool has_nulls = false>
struct nullable_index_accessor {
  cudf::detail::input_indexalator iter;
  bitmask_type const* null_mask{};
  size_type const offset{};

  /**
   * @brief Build the accessor over the indices stored in a column_view.
   *
   * Throws via `CUDF_EXPECTS` when `has_nulls` is true and `col` is not nullable.
   */
  nullable_index_accessor(column_view const& col) : null_mask{col.null_mask()}, offset{col.offset()}
  {
    if (has_nulls) { CUDF_EXPECTS(col.nullable(), "Unexpected non-nullable column."); }
    iter = cudf::detail::indexalator_factory::make_input_iterator(col);
  }

  /**
   * @brief Build the accessor over the single index stored in a scalar.
   *
   * `null_mask` and `offset` keep their default-initialized values.
   */
  nullable_index_accessor(scalar const& input)
  {
    iter = cudf::detail::indexalator_factory::make_input_iterator(input);
  }

  /**
   * @brief Return `{index, validity}` for position `i`.
   */
  __device__ thrust::pair<size_type, bool> operator()(size_type i) const
  {
    // Without nulls every element is valid; the bitmask is not touched.
    if (!has_nulls) { return {iter[i], true}; }
    return {iter[i], bit_is_set(null_mask, i + offset)};
  }
};

/**
 * @brief Build a pair iterator over a column's indices using `nullable_index_accessor`.
 *
 * @tparam has_nulls Forwarded to `nullable_index_accessor` to select whether the
 *         null mask is read.
 * @param col Column holding the index values
 */
template <bool has_nulls>
auto make_nullable_index_iterator(column_view const& col)
{
  auto accessor = nullable_index_accessor<has_nulls>{col};
  return cudf::detail::make_counting_transform_iterator(0, accessor);
}

/**
 * @brief Build a pair iterator that repeats the index held by a scalar.
 *
 * The constant iterator always supplies position 0 to a
 * `nullable_index_accessor<false>`, so every element is reported as valid.
 *
 * @param input Scalar holding the index value
 */
auto make_scalar_iterator(scalar const& input)
{
  auto zero_position = thrust::make_constant_iterator<size_type>(0);
  auto accessor      = nullable_index_accessor<false>{input};
  return thrust::make_transform_iterator(zero_position, accessor);
}

/**
* @brief This utility uses `copy_if_else` to replace null entries using the input bitmask as a
* predicate.
Expand Down Expand Up @@ -116,7 +61,7 @@ std::unique_ptr<column> replace_indices(column_view const& input,
using Element = typename thrust::
tuple_element<0, typename thrust::iterator_traits<ReplacementIter>::value_type>::type;

auto input_pair_iterator = make_nullable_index_iterator<true>(input);
auto input_pair_iterator = cudf::detail::indexalator_factory::make_input_pair_iterator(input);

return cudf::detail::copy_if_else(true,
input_pair_iterator,
Expand Down Expand Up @@ -151,11 +96,12 @@ std::unique_ptr<column> replace_nulls(dictionary_column_view const& input,
auto const input_indices =
dictionary_column_view(matched.front()->view()).get_indices_annotated();
auto const repl_indices = dictionary_column_view(matched.back()->view()).get_indices_annotated();

auto new_indices =
repl_indices.has_nulls()
? replace_indices(input_indices, make_nullable_index_iterator<true>(repl_indices), stream, mr)
: replace_indices(
input_indices, make_nullable_index_iterator<false>(repl_indices), stream, mr);
replace_indices(input_indices,
cudf::detail::indexalator_factory::make_input_pair_iterator(repl_indices),
stream,
mr);

return make_dictionary_column(
std::move(matched.front()->release().children.back()), std::move(new_indices), stream, mr);
Expand Down Expand Up @@ -185,7 +131,10 @@ std::unique_ptr<column> replace_nulls(dictionary_column_view const& input,
// now build the new indices by doing replace-null on the updated indices
auto const input_indices = input_view.get_indices_annotated();
auto new_indices =
replace_indices(input_indices, make_scalar_iterator(*scalar_index), stream, mr);
replace_indices(input_indices,
cudf::detail::indexalator_factory::make_input_pair_iterator(*scalar_index),
stream,
mr);
new_indices->set_null_mask(rmm::device_buffer{0, stream, mr}, 0);

return make_dictionary_column(
Expand Down