From e0a5ae4c7ebf7c0b37f12f954d4daf38f89c02e2 Mon Sep 17 00:00:00 2001 From: davidwendt Date: Wed, 17 Feb 2021 09:36:13 -0500 Subject: [PATCH] Move nullable index iterator to indexalator factory --- cpp/include/cudf/detail/indexalator.cuh | 55 ++++++++++++++++++- cpp/src/dictionary/replace.cu | 71 ++++--------------------- 2 files changed, 64 insertions(+), 62 deletions(-) diff --git a/cpp/include/cudf/detail/indexalator.cuh b/cpp/include/cudf/detail/indexalator.cuh index b346336959c..8568bd68bfd 100644 --- a/cpp/include/cudf/detail/indexalator.cuh +++ b/cpp/include/cudf/detail/indexalator.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -480,6 +481,58 @@ struct indexalator_factory { { return type_dispatcher(indices.type(), output_indexalator_fn{}, indices); } + + /** + * @brief An index accessor that returns a validity flag along with the index value. + * + * This is suitable as a `pair_iterator` for calling functions like `copy_if_else`. + */ + struct nullable_index_accessor { + input_indexalator iter; + bitmask_type const* null_mask{}; + size_type const offset{}; + bool const has_nulls{}; + + /** + * @brief Create an accessor from a column_view. + */ + nullable_index_accessor(column_view const& col, bool has_nulls = false) + : null_mask{col.null_mask()}, offset{col.offset()}, has_nulls{has_nulls} + { + if (has_nulls) { CUDF_EXPECTS(col.nullable(), "Unexpected non-nullable column."); } + iter = make_input_iterator(col); + } + + /** + * @brief Create an accessor from a scalar. + */ + nullable_index_accessor(scalar const& input) : has_nulls{!input.is_valid()} + { + iter = indexalator_factory::make_input_iterator(input); + } + + __device__ thrust::pair operator()(size_type i) const + { + return {iter[i], (has_nulls ? bit_is_set(null_mask, i + offset) : true)}; + } + }; + + /** + * @brief Create an index iterator with a nullable index accessor. + */ + static auto make_input_pair_iterator(column_view const& col) + { + return make_counting_transform_iterator(0, nullable_index_accessor{col, col.has_nulls()}); + } + + /** + * @brief Create an index iterator with a nullable index accessor for a scalar. + */ + static auto make_input_pair_iterator(scalar const& input) + { + return thrust::make_transform_iterator(thrust::make_constant_iterator(0), + nullable_index_accessor{input}); + } }; } // namespace detail diff --git a/cpp/src/dictionary/replace.cu b/cpp/src/dictionary/replace.cu index a062c04169e..f8f1d01b4a5 100644 --- a/cpp/src/dictionary/replace.cu +++ b/cpp/src/dictionary/replace.cu @@ -27,66 +27,11 @@ #include -#include -#include - namespace cudf { namespace dictionary { namespace detail { namespace { -/** - * @brief An index accessor that returns a validity flag along with the index value. - * - * This is used to make a `pair_iterator` for calling `copy_if_else`. - */ -template -struct nullable_index_accessor { - cudf::detail::input_indexalator iter; - bitmask_type const* null_mask{}; - size_type const offset{}; - - /** - * @brief Create an accessor from a column_view. - */ - nullable_index_accessor(column_view const& col) : null_mask{col.null_mask()}, offset{col.offset()} - { - if (has_nulls) { CUDF_EXPECTS(col.nullable(), "Unexpected non-nullable column."); } - iter = cudf::detail::indexalator_factory::make_input_iterator(col); - } - - /** - * @brief Create an accessor from a scalar. - */ - nullable_index_accessor(scalar const& input) - { - iter = cudf::detail::indexalator_factory::make_input_iterator(input); - } - - __device__ thrust::pair operator()(size_type i) const - { - return {iter[i], (has_nulls ? bit_is_set(null_mask, i + offset) : true)}; - } -}; - -/** - * @brief Create an index iterator with a nullable index accessor. - */ -template -auto make_nullable_index_iterator(column_view const& col) -{ - return cudf::detail::make_counting_transform_iterator(0, nullable_index_accessor{col}); -} - -/** - * @brief Create an index iterator with a nullable index accessor for a scalar. - */ -auto make_scalar_iterator(scalar const& input) -{ - return thrust::make_transform_iterator(thrust::make_constant_iterator(0), - nullable_index_accessor{input}); -} - /** * @brief This utility uses `copy_if_else` to replace null entries using the input bitmask as a * predicate. @@ -116,7 +61,7 @@ std::unique_ptr replace_indices(column_view const& input, using Element = typename thrust:: tuple_element<0, typename thrust::iterator_traits::value_type>::type; - auto input_pair_iterator = make_nullable_index_iterator(input); + auto input_pair_iterator = cudf::detail::indexalator_factory::make_input_pair_iterator(input); return cudf::detail::copy_if_else(true, input_pair_iterator, @@ -151,11 +96,12 @@ std::unique_ptr replace_nulls(dictionary_column_view const& input, auto const input_indices = dictionary_column_view(matched.front()->view()).get_indices_annotated(); auto const repl_indices = dictionary_column_view(matched.back()->view()).get_indices_annotated(); + auto new_indices = - repl_indices.has_nulls() - ? replace_indices(input_indices, make_nullable_index_iterator(repl_indices), stream, mr) - : replace_indices( - input_indices, make_nullable_index_iterator(repl_indices), stream, mr); + replace_indices(input_indices, + cudf::detail::indexalator_factory::make_input_pair_iterator(repl_indices), + stream, + mr); return make_dictionary_column( std::move(matched.front()->release().children.back()), std::move(new_indices), stream, mr); @@ -185,7 +131,10 @@ std::unique_ptr replace_nulls(dictionary_column_view const& input, // now build the new indices by doing replace-null on the updated indices auto const input_indices = input_view.get_indices_annotated(); auto new_indices = - replace_indices(input_indices, make_scalar_iterator(*scalar_index), stream, mr); + replace_indices(input_indices, + cudf::detail::indexalator_factory::make_input_pair_iterator(*scalar_index), + stream, + mr); new_indices->set_null_mask(rmm::device_buffer{0, stream, mr}, 0); return make_dictionary_column(