diff --git a/cpp/benchmarks/common/generate_input.cu b/cpp/benchmarks/common/generate_input.cu index cf597e644aa..bb7529bb37a 100644 --- a/cpp/benchmarks/common/generate_input.cu +++ b/cpp/benchmarks/common/generate_input.cu @@ -53,6 +53,8 @@ #include #include +#include + #include #include #include @@ -247,12 +249,12 @@ struct random_value_fn()>> { sec.end(), ns.begin(), result.begin(), - [] __device__(int64_t sec_value, int64_t nanoseconds_value) { + cuda::proclaim_return_type([] __device__(int64_t sec_value, int64_t nanoseconds_value) { auto const timestamp_ns = cudf::duration_s{sec_value} + cudf::duration_ns{nanoseconds_value}; // Return value in the type's precision return T(cuda::std::chrono::duration_cast(timestamp_ns)); - }); + })); return result; } }; @@ -367,12 +369,13 @@ rmm::device_uvector sample_indices_with_run_length(cudf::size_t // This is gather. auto avg_repeated_sample_indices_iterator = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [rb = run_lens.begin(), - re = run_lens.end(), - samples_indices = samples_indices.begin()] __device__(cudf::size_type i) { - auto sample_idx = thrust::upper_bound(thrust::seq, rb, re, i) - rb; - return samples_indices[sample_idx]; - }); + cuda::proclaim_return_type( + [rb = run_lens.begin(), + re = run_lens.end(), + samples_indices = samples_indices.begin()] __device__(cudf::size_type i) { + auto sample_idx = thrust::upper_bound(thrust::seq, rb, re, i) - rb; + return samples_indices[sample_idx]; + })); rmm::device_uvector repeated_sample_indices(num_rows, cudf::get_default_stream()); thrust::copy(thrust::device, @@ -513,7 +516,7 @@ std::unique_ptr create_random_utf8_string_column(data_profile cons lengths.end(), null_mask.begin(), lengths.begin(), - [] __device__(auto) { return 0; }, + cuda::proclaim_return_type([] __device__(auto) { return 0; }), thrust::logical_not{}); auto valid_lengths = thrust::make_transform_iterator( thrust::make_zip_iterator(thrust::make_tuple(lengths.begin(), null_mask.begin())), diff --git a/cpp/include/cudf/column/column_view.hpp b/cpp/include/cudf/column/column_view.hpp index d80c720a255..134e835911f 100644 --- a/cpp/include/cudf/column/column_view.hpp +++ b/cpp/include/cudf/column/column_view.hpp @@ -478,7 +478,10 @@ class mutable_column_view : public detail::column_view_base { public: mutable_column_view() = default; - ~mutable_column_view() = default; + ~mutable_column_view(){ + // Needed so that the first instance of the implicit destructor for any TU isn't 'constructed' + // from a host+device function marking the implicit version also as host+device + }; mutable_column_view(mutable_column_view const&) = default; ///< Copy constructor mutable_column_view(mutable_column_view&&) = default; ///< Move constructor diff --git a/cpp/include/cudf/detail/null_mask.cuh b/cpp/include/cudf/detail/null_mask.cuh index 78cd3d7bcb7..ae05d4c6954 100644 --- a/cpp/include/cudf/detail/null_mask.cuh +++ b/cpp/include/cudf/detail/null_mask.cuh @@ -37,6 +37,8 @@ #include #include +#include + #include #include #include @@ -330,20 +332,21 @@ rmm::device_uvector segmented_count_bits(bitmask_type const* bitmask, // set bits from the length of the segment. auto segments_begin = thrust::make_zip_iterator(first_bit_indices_begin, last_bit_indices_begin); - auto segment_length_iterator = - thrust::transform_iterator(segments_begin, [] __device__(auto const& segment) { + auto segment_length_iterator = thrust::transform_iterator( + segments_begin, cuda::proclaim_return_type([] __device__(auto const& segment) { auto const begin = thrust::get<0>(segment); auto const end = thrust::get<1>(segment); return end - begin; - }); + })); thrust::transform(rmm::exec_policy(stream), segment_length_iterator, segment_length_iterator + num_ranges, d_bit_counts.data(), d_bit_counts.data(), - [] __device__(auto segment_size, auto segment_bit_count) { - return segment_size - segment_bit_count; - }); + cuda::proclaim_return_type( + [] __device__(auto segment_size, auto segment_bit_count) { + return segment_size - segment_bit_count; + })); } CUDF_CHECK_CUDA(stream.value()); @@ -541,12 +544,12 @@ std::pair segmented_null_mask_reduction( { auto const segments_begin = thrust::make_zip_iterator(first_bit_indices_begin, last_bit_indices_begin); - auto const segment_length_iterator = - thrust::make_transform_iterator(segments_begin, [] __device__(auto const& segment) { + auto const segment_length_iterator = thrust::make_transform_iterator( + segments_begin, cuda::proclaim_return_type([] __device__(auto const& segment) { auto const begin = thrust::get<0>(segment); auto const end = thrust::get<1>(segment); return end - begin; - }); + })); auto const num_segments = static_cast(std::distance(first_bit_indices_begin, first_bit_indices_end)); diff --git a/cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh b/cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh index 155b1ce5691..358dcca02b9 100644 --- a/cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh +++ b/cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh @@ -27,6 +27,8 @@ #include #include +#include + #include namespace cudf { @@ -311,9 +313,10 @@ std::pair, size_type> make_offsets_child_column( // using exclusive-scan technically requires count+1 input values even though // the final input value is never used. // The input iterator is wrapped here to allow the last value to be safely read. - auto map_fn = [begin, count] __device__(size_type idx) -> size_type { - return idx < count ? static_cast(begin[idx]) : size_type{0}; - }; + auto map_fn = + cuda::proclaim_return_type([begin, count] __device__(size_type idx) -> size_type { + return idx < count ? static_cast(begin[idx]) : size_type{0}; + }); auto input_itr = cudf::detail::make_counting_transform_iterator(0, map_fn); // Use the sizes-to-offsets iterator to compute the total number of elements auto const total_elements = sizes_to_offsets(input_itr, input_itr + count + 1, d_offsets, stream); diff --git a/cpp/include/cudf/detail/utilities/cast_functor.cuh b/cpp/include/cudf/detail/utilities/cast_functor.cuh new file mode 100644 index 00000000000..d5209942c8a --- /dev/null +++ b/cpp/include/cudf/detail/utilities/cast_functor.cuh @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/** + * @brief A casting functor wrapping another functor. + * @file + */ + +#include + +#include + +#include +#include + +namespace cudf { +namespace detail { + +/** + * @brief Functor that casts another functor's result to a specified type. + * + * CUB 2.0.0 reductions require that the binary operator returns the same type + * as the initial value type, so we wrap binary operators with this when used + * by CUB. + */ +template +struct cast_functor_fn { + F f; + + template + CUDF_HOST_DEVICE inline ResultType operator()(Ts&&... args) + { + return static_cast(f(std::forward(args)...)); + } +}; + +/** + * @brief Function creating a casting functor. + */ +template +inline cast_functor_fn> cast_functor(F&& f) +{ + return cast_functor_fn>{std::forward(f)}; +} + +} // namespace detail + +} // namespace cudf diff --git a/cpp/include/cudf/detail/utilities/element_argminmax.cuh b/cpp/include/cudf/detail/utilities/element_argminmax.cuh index 45b56278dba..8bd34ef0237 100644 --- a/cpp/include/cudf/detail/utilities/element_argminmax.cuh +++ b/cpp/include/cudf/detail/utilities/element_argminmax.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ struct element_argminmax_fn { bool const has_nulls; bool const arg_min; - __device__ inline auto operator()(size_type const& lhs_idx, size_type const& rhs_idx) const + __device__ inline size_type operator()(size_type const& lhs_idx, size_type const& rhs_idx) const { // The extra bounds checking is due to issue github.com/rapidsai/cudf/9156 and // github.com/NVIDIA/thrust/issues/1525 diff --git a/cpp/include/cudf/lists/detail/gather.cuh b/cpp/include/cudf/lists/detail/gather.cuh index 18fe707fd69..4484a9995c3 100644 --- a/cpp/include/cudf/lists/detail/gather.cuh +++ b/cpp/include/cudf/lists/detail/gather.cuh @@ -30,6 +30,8 @@ #include #include +#include + namespace cudf { namespace lists { namespace detail { @@ -83,12 +85,12 @@ gather_data make_gather_data(cudf::lists_column_view const& source_column, auto sizes_itr = cudf::detail::make_counting_transform_iterator( 0, - [source_column_nullmask, - source_column_offset = source_column.offset(), - gather_map, - output_count, - src_offsets, - src_size] __device__(int32_t index) -> int32_t { + cuda::proclaim_return_type([source_column_nullmask, + source_column_offset = source_column.offset(), + gather_map, + output_count, + src_offsets, + src_size] __device__(int32_t index) -> int32_t { int32_t offset_index = index < output_count ? gather_map[index] : 0; // if this is an invalid index, this will be a NULL list @@ -102,7 +104,7 @@ gather_data make_gather_data(cudf::lists_column_view const& source_column, // the length of this list return src_offsets[offset_index + 1] - src_offsets[offset_index]; - }); + })); auto [dst_offsets_c, map_size] = cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + output_count, stream, mr); diff --git a/cpp/include/cudf/lists/detail/scatter.cuh b/cpp/include/cudf/lists/detail/scatter.cuh index ff148c59a23..ea2f2bbf544 100644 --- a/cpp/include/cudf/lists/detail/scatter.cuh +++ b/cpp/include/cudf/lists/detail/scatter.cuh @@ -39,6 +39,8 @@ #include #include +#include + #include namespace cudf { @@ -62,9 +64,10 @@ rmm::device_uvector list_vector_from_column( index_begin, index_end, vector.begin(), - [label, lists_column] __device__(size_type row_index) { - return unbound_list_view{label, lists_column, row_index}; - }); + cuda::proclaim_return_type( + [label, lists_column] __device__(size_type row_index) { + return unbound_list_view{label, lists_column, row_index}; + })); return vector; } @@ -115,7 +118,8 @@ std::unique_ptr scatter_impl(rmm::device_uvector cons lists_column_view(target); // Checks that target is a list column. auto list_size_begin = thrust::make_transform_iterator( - target_vector.begin(), [] __device__(unbound_list_view l) { return l.size(); }); + target_vector.begin(), + cuda::proclaim_return_type([] __device__(unbound_list_view l) { return l.size(); })); auto offsets_column = std::get<0>(cudf::detail::make_offsets_child_column( list_size_begin, list_size_begin + target.size(), stream, mr)); diff --git a/cpp/include/cudf/reduction/detail/reduction.cuh b/cpp/include/cudf/reduction/detail/reduction.cuh index 1620635e0e3..48b65a3fc54 100644 --- a/cpp/include/cudf/reduction/detail/reduction.cuh +++ b/cpp/include/cudf/reduction/detail/reduction.cuh @@ -19,6 +19,7 @@ #include "reduction_operators.cuh" #include +#include #include #include @@ -64,7 +65,7 @@ std::unique_ptr reduce(InputIterator d_in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto const binary_op = op.get_binary_op(); + auto const binary_op = cudf::detail::cast_functor(op.get_binary_op()); auto const initial_value = init.value_or(op.template get_identity()); auto dev_result = rmm::device_scalar{initial_value, stream, mr}; @@ -124,7 +125,7 @@ std::unique_ptr reduce(InputIterator d_in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto const binary_op = op.get_binary_op(); + auto const binary_op = cudf::detail::cast_functor(op.get_binary_op()); auto const initial_value = init.value_or(op.template get_identity()); auto dev_result = rmm::device_scalar{initial_value, stream}; @@ -190,7 +191,7 @@ std::unique_ptr reduce(InputIterator d_in, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto const binary_op = op.get_binary_op(); + auto const binary_op = cudf::detail::cast_functor(op.get_binary_op()); auto const initial_value = op.template get_identity(); rmm::device_scalar intermediate_result{initial_value, stream}; diff --git a/cpp/include/cudf/reduction/detail/segmented_reduction.cuh b/cpp/include/cudf/reduction/detail/segmented_reduction.cuh index 5c2eaf8cdcb..e86506681eb 100644 --- a/cpp/include/cudf/reduction/detail/segmented_reduction.cuh +++ b/cpp/include/cudf/reduction/detail/segmented_reduction.cuh @@ -18,6 +18,8 @@ #include "reduction_operators.cuh" +#include + #include #include #include @@ -45,7 +47,7 @@ namespace detail { * @param d_offset_begin Begin iterator to segment indices * @param d_offset_end End iterator to segment indices * @param d_out Output data iterator - * @param binary_op The reduction operator + * @param op The reduction operator * @param initial_value Initial value of the reduction * @param stream CUDA stream used for device memory operations and kernel launches * @@ -61,12 +63,12 @@ void segmented_reduce(InputIterator d_in, OffsetIterator d_offset_begin, OffsetIterator d_offset_end, OutputIterator d_out, - BinaryOp binary_op, + BinaryOp op, OutputType initial_value, rmm::cuda_stream_view stream) { auto const num_segments = static_cast(std::distance(d_offset_begin, d_offset_end)) - 1; - + auto const binary_op = cudf::detail::cast_functor(op); // Allocate temporary storage size_t temp_storage_bytes = 0; cub::DeviceSegmentedReduce::Reduce(nullptr, @@ -148,8 +150,8 @@ void segmented_reduce(InputIterator d_in, using OutputType = typename thrust::iterator_value::type; using IntermediateType = typename thrust::iterator_value::type; auto num_segments = static_cast(std::distance(d_offset_begin, d_offset_end)) - 1; - auto const binary_op = op.get_binary_op(); auto const initial_value = op.template get_identity(); + auto const binary_op = cudf::detail::cast_functor(op.get_binary_op()); rmm::device_uvector intermediate_result{static_cast(num_segments), stream}; diff --git a/cpp/include/cudf/strings/detail/copy_if_else.cuh b/cpp/include/cudf/strings/detail/copy_if_else.cuh index b553b491f09..6f0b199ff12 100644 --- a/cpp/include/cudf/strings/detail/copy_if_else.cuh +++ b/cpp/include/cudf/strings/detail/copy_if_else.cuh @@ -29,6 +29,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -78,10 +80,11 @@ std::unique_ptr copy_if_else(StringIterLeft lhs_begin, auto null_mask = (null_count > 0) ? std::move(valid_mask.first) : rmm::device_buffer{}; // build offsets column - auto offsets_transformer = [lhs_begin, rhs_begin, filter_fn] __device__(size_type idx) { - auto const result = filter_fn(idx) ? lhs_begin[idx] : rhs_begin[idx]; - return result.has_value() ? result->size_bytes() : 0; - }; + auto offsets_transformer = cuda::proclaim_return_type( + [lhs_begin, rhs_begin, filter_fn] __device__(size_type idx) { + auto const result = filter_fn(idx) ? lhs_begin[idx] : rhs_begin[idx]; + return result.has_value() ? result->size_bytes() : 0; + }); auto offsets_transformer_itr = thrust::make_transform_iterator( thrust::make_counting_iterator(0), offsets_transformer); diff --git a/cpp/include/cudf/strings/detail/gather.cuh b/cpp/include/cudf/strings/detail/gather.cuh index 7cd2338cb67..1523a81d63f 100644 --- a/cpp/include/cudf/strings/detail/gather.cuh +++ b/cpp/include/cudf/strings/detail/gather.cuh @@ -35,6 +35,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -301,11 +303,13 @@ std::unique_ptr gather(strings_column_view const& strings, auto const d_in_offsets = !strings.is_empty() ? strings.offsets_begin() : nullptr; auto offsets_itr = thrust::make_transform_iterator( - begin, [d_strings = *d_strings, d_in_offsets] __device__(size_type idx) { - if (NullifyOutOfBounds && (idx < 0 || idx >= d_strings.size())) { return 0; } - if (not d_strings.is_valid(idx)) { return 0; } - return d_in_offsets[idx + 1] - d_in_offsets[idx]; - }); + begin, + cuda::proclaim_return_type( + [d_strings = *d_strings, d_in_offsets] __device__(size_type idx) { + if (NullifyOutOfBounds && (idx < 0 || idx >= d_strings.size())) { return 0; } + if (not d_strings.is_valid(idx)) { return 0; } + return d_in_offsets[idx + 1] - d_in_offsets[idx]; + })); auto [out_offsets_column, total_bytes] = cudf::detail::make_offsets_child_column(offsets_itr, offsets_itr + output_count, stream, mr); diff --git a/cpp/include/cudf/strings/detail/merge.cuh b/cpp/include/cudf/strings/detail/merge.cuh index 5f50faa158e..aef1fe93792 100644 --- a/cpp/include/cudf/strings/detail/merge.cuh +++ b/cpp/include/cudf/strings/detail/merge.cuh @@ -32,6 +32,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -73,13 +75,14 @@ std::unique_ptr merge(strings_column_view const& lhs, null_mask = cudf::detail::create_null_mask(strings_count, mask_state::ALL_VALID, stream, mr); // build offsets column - auto offsets_transformer = [d_lhs, d_rhs] __device__(auto index_pair) { - auto const [side, index] = index_pair; - if (side == side::LEFT ? d_lhs.is_null(index) : d_rhs.is_null(index)) return 0; - auto d_str = - side == side::LEFT ? d_lhs.element(index) : d_rhs.element(index); - return d_str.size_bytes(); - }; + auto offsets_transformer = + cuda::proclaim_return_type([d_lhs, d_rhs] __device__(auto index_pair) { + auto const [side, index] = index_pair; + if (side == side::LEFT ? d_lhs.is_null(index) : d_rhs.is_null(index)) return 0; + auto d_str = + side == side::LEFT ? d_lhs.element(index) : d_rhs.element(index); + return d_str.size_bytes(); + }); auto offsets_transformer_itr = thrust::make_transform_iterator(begin, offsets_transformer); auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); diff --git a/cpp/include/cudf/strings/detail/scatter.cuh b/cpp/include/cudf/strings/detail/scatter.cuh index 55dd5bda260..56eeec01715 100644 --- a/cpp/include/cudf/strings/detail/scatter.cuh +++ b/cpp/include/cudf/strings/detail/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -73,7 +75,9 @@ std::unique_ptr scatter(SourceIterator begin, // this ensures empty strings are not mapped to nulls in the make_strings_column function auto const size = thrust::distance(begin, end); auto itr = thrust::make_transform_iterator( - begin, [] __device__(string_view const sv) { return sv.empty() ? string_view{} : sv; }); + begin, cuda::proclaim_return_type([] __device__(string_view const sv) { + return sv.empty() ? string_view{} : sv; + })); // do the scatter thrust::scatter( diff --git a/cpp/include/cudf/strings/detail/strings_column_factories.cuh b/cpp/include/cudf/strings/detail/strings_column_factories.cuh index 7e608cd10f0..15b1c2bfec4 100644 --- a/cpp/include/cudf/strings/detail/strings_column_factories.cuh +++ b/cpp/include/cudf/strings/detail/strings_column_factories.cuh @@ -37,6 +37,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -79,9 +81,10 @@ std::unique_ptr make_strings_column(IndexPairIterator begin, if (strings_count == 0) return make_empty_column(type_id::STRING); // build offsets column from the strings sizes - auto offsets_transformer = [] __device__(string_index_pair item) -> size_type { - return (item.first != nullptr ? static_cast(item.second) : size_type{0}); - }; + auto offsets_transformer = + cuda::proclaim_return_type([] __device__(string_index_pair item) -> size_type { + return (item.first != nullptr ? static_cast(item.second) : size_type{0}); + }); auto offsets_transformer_itr = thrust::make_transform_iterator(begin, offsets_transformer); auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); @@ -103,9 +106,10 @@ std::unique_ptr make_strings_column(IndexPairIterator begin, auto const d_data = offsets_view.template data(); auto const d_offsets = device_span{d_data, static_cast(offsets_view.size())}; - auto const str_begin = thrust::make_transform_iterator(begin, [] __device__(auto ip) { - return string_view{ip.first, ip.second}; - }); + auto const str_begin = thrust::make_transform_iterator( + begin, cuda::proclaim_return_type([] __device__(auto ip) { + return string_view{ip.first, ip.second}; + })); return gather_chars(str_begin, thrust::make_counting_iterator(0), @@ -180,7 +184,8 @@ std::unique_ptr make_strings_column(CharIterator chars_begin, offsets_begin, offsets_end, offsets_view.data(), - [] __device__(auto offset) { return static_cast(offset); }); + cuda::proclaim_return_type( + [] __device__(auto offset) { return static_cast(offset); })); // build chars column auto chars_column = strings::detail::create_chars_child_column(bytes, stream, mr); diff --git a/cpp/include/cudf/utilities/error.hpp b/cpp/include/cudf/utilities/error.hpp index afb9275e152..bf8b87e2563 100644 --- a/cpp/include/cudf/utilities/error.hpp +++ b/cpp/include/cudf/utilities/error.hpp @@ -76,6 +76,12 @@ struct logic_error : public std::logic_error, public stacktrace_recorder { // TODO Add an error code member? This would be useful for translating an // exception to an error code in a pure-C API + + ~logic_error() + { + // Needed so that the first instance of the implicit destructor for any TU isn't 'constructed' + // from a host+device function marking the implicit version also as host+device + } }; /** * @brief Exception thrown when a CUDA error is encountered. diff --git a/cpp/src/binaryop/compiled/binary_ops.cu b/cpp/src/binaryop/compiled/binary_ops.cu index 85ab5c6d6cb..464c15dac9d 100644 --- a/cpp/src/binaryop/compiled/binary_ops.cu +++ b/cpp/src/binaryop/compiled/binary_ops.cu @@ -33,6 +33,8 @@ #include #include +#include + namespace cudf { namespace binops { namespace compiled { @@ -231,7 +233,7 @@ struct null_considering_binop { cudf::string_view const invalid_str{nullptr, 0}; // Create a compare function lambda - auto minmax_func = + auto minmax_func = cuda::proclaim_return_type( [op, invalid_str] __device__( bool lhs_valid, bool rhs_valid, cudf::string_view lhs_value, cudf::string_view rhs_value) { if (!lhs_valid && !rhs_valid) @@ -244,7 +246,7 @@ struct null_considering_binop { return lhs_value; else return rhs_value; - }; + }); // Populate output column populate_out_col( diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu index 6a32ee41e32..dd4af236ecf 100644 --- a/cpp/src/copying/contiguous_split.cu +++ b/cpp/src/copying/contiguous_split.cu @@ -45,6 +45,8 @@ #include #include +#include + #include #include @@ -1193,11 +1195,11 @@ std::unique_ptr compute_splits( thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_bufs), d_dst_buf_info, - [d_src_buf_info, - offset_stack_partition_size, - d_offset_stack, - d_indices, - num_src_bufs] __device__(std::size_t t) { + cuda::proclaim_return_type([d_src_buf_info, + offset_stack_partition_size, + d_offset_stack, + d_indices, + num_src_bufs] __device__(std::size_t t) { int const split_index = t / num_src_bufs; int const src_buf_index = t % num_src_bufs; auto const& src_info = d_src_buf_info[src_buf_index]; @@ -1264,7 +1266,7 @@ std::unique_ptr compute_splits( src_info.is_validity ? 1 : 0, src_buf_index, split_index}; - }); + })); // compute total size of each partition // key is the split index @@ -1405,9 +1407,11 @@ std::unique_ptr chunk_iteration_state::create( rmm::device_uvector d_batch_offsets(num_bufs + 1, stream, temp_mr); auto const buf_count_iter = cudf::detail::make_counting_transform_iterator( - 0, [num_bufs, num_batches = num_batches_func{batches.begin()}] __device__(size_type i) { - return i == num_bufs ? 0 : num_batches(i); - }); + 0, + cuda::proclaim_return_type( + [num_bufs, num_batches = num_batches_func{batches.begin()}] __device__(size_type i) { + return i == num_bufs ? 0 : num_batches(i); + })); thrust::exclusive_scan(rmm::exec_policy(stream, temp_mr), buf_count_iter, @@ -1631,25 +1635,26 @@ std::unique_ptr compute_batches(int num_bufs, d_dst_buf_info, d_dst_buf_info + num_bufs, batches.begin(), - [desired_batch_size = desired_batch_size] __device__( - dst_buf_info const& buf) -> thrust::pair { - // Total bytes for this incoming partition - std::size_t const bytes = - static_cast(buf.num_elements) * static_cast(buf.element_size); - - // This clause handles nested data types (e.g. list or string) that store no data in the row - // columns, only in their children. - if (bytes == 0) { return {1, 0}; } - - // The number of batches we want to subdivide this buffer into - std::size_t const num_batches = std::max( - std::size_t{1}, util::round_up_unsafe(bytes, desired_batch_size) / desired_batch_size); - - // NOTE: leaving batch size as a separate parameter for future tuning - // possibilities, even though in the current implementation it will be a - // constant. - return {num_batches, desired_batch_size}; - }); + cuda::proclaim_return_type>( + [desired_batch_size = desired_batch_size] __device__( + dst_buf_info const& buf) -> thrust::pair { + // Total bytes for this incoming partition + std::size_t const bytes = + static_cast(buf.num_elements) * static_cast(buf.element_size); + + // This clause handles nested data types (e.g. list or string) that store no data in the row + // columns, only in their children. + if (bytes == 0) { return {1, 0}; } + + // The number of batches we want to subdivide this buffer into + std::size_t const num_batches = std::max( + std::size_t{1}, util::round_up_unsafe(bytes, desired_batch_size) / desired_batch_size); + + // NOTE: leaving batch size as a separate parameter for future tuning + // possibilities, even though in the current implementation it will be a + // constant. + return {num_batches, desired_batch_size}; + })); return chunk_iteration_state::create(batches, num_bufs, @@ -1789,7 +1794,8 @@ struct contiguous_split_state { auto values = thrust::make_transform_iterator( chunk_iter_state->d_batched_dst_buf_info.begin(), - [] __device__(dst_buf_info const& info) { return info.valid_count; }); + cuda::proclaim_return_type( + [] __device__(dst_buf_info const& info) { return info.valid_count; })); thrust::reduce_by_key(rmm::exec_policy(stream, temp_mr), keys, diff --git a/cpp/src/copying/gather.cu b/cpp/src/copying/gather.cu index 267c71591d5..2083d3ed618 100644 --- a/cpp/src/copying/gather.cu +++ b/cpp/src/copying/gather.cu @@ -28,6 +28,8 @@ #include +#include + namespace cudf { namespace detail { @@ -46,7 +48,8 @@ std::unique_ptr gather(table_view const& source_table, if (neg_indices == negative_index_policy::ALLOWED) { cudf::size_type n_rows = source_table.num_rows(); - auto idx_converter = [n_rows] __device__(size_type in) { return in < 0 ? in + n_rows : in; }; + auto idx_converter = cuda::proclaim_return_type( + [n_rows] __device__(size_type in) { return in < 0 ? in + n_rows : in; }); return gather(source_table, thrust::make_transform_iterator(map_begin, idx_converter), thrust::make_transform_iterator(map_end, idx_converter), diff --git a/cpp/src/copying/reverse.cu b/cpp/src/copying/reverse.cu index fbbbc56e712..884c93e268c 100644 --- a/cpp/src/copying/reverse.cu +++ b/cpp/src/copying/reverse.cu @@ -32,6 +32,8 @@ #include #include +#include + namespace cudf { namespace detail { std::unique_ptr
reverse(table_view const& source_table, @@ -39,8 +41,10 @@ std::unique_ptr
reverse(table_view const& source_table, rmm::mr::device_memory_resource* mr) { size_type num_rows = source_table.num_rows(); - auto elements = - make_counting_transform_iterator(0, [num_rows] __device__(auto i) { return num_rows - i - 1; }); + auto elements = make_counting_transform_iterator( + 0, cuda::proclaim_return_type([num_rows] __device__(auto i) { + return num_rows - i - 1; + })); auto elements_end = elements + source_table.num_rows(); return gather(source_table, elements, elements_end, out_of_bounds_policy::DONT_CHECK, stream, mr); diff --git a/cpp/src/copying/sample.cu b/cpp/src/copying/sample.cu index f3d8d624171..e7f5522d3b3 100644 --- a/cpp/src/copying/sample.cu +++ b/cpp/src/copying/sample.cu @@ -31,6 +31,8 @@ #include #include +#include + namespace cudf { namespace detail { @@ -51,12 +53,12 @@ std::unique_ptr
sample(table_view const& input, if (n == 0) return cudf::empty_like(input); if (replacement == sample_with_replacement::TRUE) { - auto RandomGen = [seed, num_rows] __device__(auto i) { + auto RandomGen = cuda::proclaim_return_type([seed, num_rows] __device__(auto i) { thrust::default_random_engine rng(seed); thrust::uniform_int_distribution dist{0, num_rows - 1}; rng.discard(i); return dist(rng); - }; + }); auto begin = cudf::detail::make_counting_transform_iterator(0, RandomGen); diff --git a/cpp/src/copying/scatter.cu b/cpp/src/copying/scatter.cu index 879ddb5048e..8f326184012 100644 --- a/cpp/src/copying/scatter.cu +++ b/cpp/src/copying/scatter.cu @@ -43,6 +43,8 @@ #include #include +#include + namespace cudf { namespace detail { namespace { @@ -356,9 +358,11 @@ std::unique_ptr
scatter(std::vector> // > (2^31)/2, but the end result after the final (% n_rows) will fit. so we'll do the computation // using a signed 64 bit value. auto scatter_iter = thrust::make_transform_iterator( - map_begin, [n_rows = static_cast(n_rows)] __device__(size_type in) -> size_type { - return ((static_cast(in) % n_rows) + n_rows) % n_rows; - }); + map_begin, + cuda::proclaim_return_type( + [n_rows = static_cast(n_rows)] __device__(size_type in) -> size_type { + return static_cast(((static_cast(in) % n_rows) + n_rows) % n_rows); + })); // Dispatch over data type per column auto result = std::vector>(target.num_columns()); diff --git a/cpp/src/dictionary/detail/concatenate.cu b/cpp/src/dictionary/detail/concatenate.cu index 121b5bce499..024acaa872d 100644 --- a/cpp/src/dictionary/detail/concatenate.cu +++ b/cpp/src/dictionary/detail/concatenate.cu @@ -43,6 +43,8 @@ #include #include +#include + #include #include @@ -151,10 +153,11 @@ struct dispatch_compute_indices { keys_view->begin(), thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [d_offsets, d_map_to_keys, d_all_indices, indices_itr] __device__(size_type idx) { - if (d_all_indices.is_null(idx)) return 0; - return indices_itr[idx] + d_offsets[d_map_to_keys[idx]].first; - })); + cuda::proclaim_return_type( + [d_offsets, d_map_to_keys, d_all_indices, indices_itr] __device__(size_type idx) { + if (d_all_indices.is_null(idx)) return 0; + return indices_itr[idx] + d_offsets[d_map_to_keys[idx]].first; + }))); auto new_keys_view = column_device_view::create(new_keys, stream); @@ -256,10 +259,10 @@ std::unique_ptr concatenate(host_span columns, // build a vector of values to map the old indices to the concatenated keys auto children_offsets = child_offsets_fn.create_children_offsets(stream); rmm::device_uvector map_to_keys(indices_size, stream); - auto indices_itr = - cudf::detail::make_counting_transform_iterator(1, [] __device__(size_type idx) { + auto indices_itr = cudf::detail::make_counting_transform_iterator( + 1, cuda::proclaim_return_type([] __device__(size_type idx) { return offsets_pair{0, idx}; - }); + })); // the indices offsets (pair.second) are for building the map thrust::lower_bound( rmm::exec_policy(stream), diff --git a/cpp/src/filling/repeat.cu b/cpp/src/filling/repeat.cu index 677d9a09515..b3ed9743953 100644 --- a/cpp/src/filling/repeat.cu +++ b/cpp/src/filling/repeat.cu @@ -43,6 +43,8 @@ #include #include +#include + #include #include @@ -146,7 +148,8 @@ std::unique_ptr
repeat(table_view const& input_table, auto output_size = input_table.num_rows() * count; auto map_begin = cudf::detail::make_counting_transform_iterator( - 0, [count] __device__(auto i) { return i / count; }); + 0, + cuda::proclaim_return_type([count] __device__(size_type i) { return i / count; })); auto map_end = map_begin + output_size; return gather(input_table, map_begin, map_end, out_of_bounds_policy::DONT_CHECK, stream, mr); diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 195c8924c9a..32693487c32 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -54,12 +54,13 @@ #include #include +#include +#include + #include #include #include -#include - namespace cudf { namespace groupby { namespace detail { @@ -524,7 +525,8 @@ rmm::device_uvector extract_populated_keys(map_type c { rmm::device_uvector populated_keys(num_keys, stream); - auto const get_key = [] __device__(auto const& element) { return element.first; }; // first = key + auto const get_key = cuda::proclaim_return_type::key_type>( + [] __device__(auto const& element) { return element.first; }); // first = key auto const key_used = [unused = map.get_unused_key()] __device__(auto key) { return key != unused; }; diff --git a/cpp/src/groupby/sort/group_count.cu b/cpp/src/groupby/sort/group_count.cu index e7274034f55..e35b0c2b2fe 100644 --- a/cpp/src/groupby/sort/group_count.cu +++ b/cpp/src/groupby/sort/group_count.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ #include #include +#include + namespace cudf { namespace groupby { namespace detail { @@ -54,7 +56,9 @@ std::unique_ptr group_count_valid(column_view const& values, // so we need to transform it to cast it to an integer type auto bitmask_iterator = thrust::make_transform_iterator(cudf::detail::make_validity_iterator(*values_view), - [] __device__(auto b) { return static_cast(b); }); + cuda::proclaim_return_type([] __device__(auto b) { + return static_cast(b); + })); thrust::reduce_by_key(rmm::exec_policy(stream), group_labels.begin(), diff --git a/cpp/src/groupby/sort/group_nth_element.cu b/cpp/src/groupby/sort/group_nth_element.cu index 58d76a8ab43..037fa9a735c 100644 --- a/cpp/src/groupby/sort/group_nth_element.cu +++ b/cpp/src/groupby/sort/group_nth_element.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,8 @@ #include #include +#include + namespace cudf { namespace groupby { namespace detail { @@ -81,7 +83,9 @@ std::unique_ptr group_nth_element(column_view const& values, auto values_view = column_device_view::create(values, stream); auto bitmask_iterator = thrust::make_transform_iterator(cudf::detail::make_validity_iterator(*values_view), - [] __device__(auto b) { return static_cast(b); }); + cuda::proclaim_return_type([] __device__(auto b) { + return static_cast(b); + })); rmm::device_uvector intra_group_index(values.size(), stream); // intra group index for valids only. thrust::exclusive_scan_by_key(rmm::exec_policy(stream), diff --git a/cpp/src/groupby/sort/sort_helper.cu b/cpp/src/groupby/sort/sort_helper.cu index 4c87c091b34..1a696c8bc28 100644 --- a/cpp/src/groupby/sort/sort_helper.cu +++ b/cpp/src/groupby/sort/sort_helper.cu @@ -42,6 +42,8 @@ #include #include +#include + #include #include #include @@ -291,8 +293,10 @@ std::unique_ptr
sort_groupby_helper::unique_keys(rmm::cuda_stream_view st { auto idx_data = key_sort_order(stream).data(); - auto gather_map_it = thrust::make_transform_iterator( - group_offsets(stream).begin(), [idx_data] __device__(size_type i) { return idx_data[i]; }); + auto gather_map_it = + thrust::make_transform_iterator(group_offsets(stream).begin(), + cuda::proclaim_return_type( + [idx_data] __device__(size_type i) { return idx_data[i]; })); return cudf::detail::gather(_keys, gather_map_it, diff --git a/cpp/src/io/json/json_column.cu b/cpp/src/io/json/json_column.cu index 5ea29fcfd2d..056cce18a52 100644 --- a/cpp/src/io/json/json_column.cu +++ b/cpp/src/io/json/json_column.cu @@ -47,6 +47,7 @@ #include #include +#include #include #include @@ -648,11 +649,12 @@ void make_device_json_column(device_span input, auto& parent_col_ids = sorted_col_ids; // reuse sorted_col_ids auto parent_col_id = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [col_ids = col_ids.begin(), - parent_node_ids = tree.parent_node_ids.begin()] __device__(size_type node_id) { - return parent_node_ids[node_id] == parent_node_sentinel ? parent_node_sentinel - : col_ids[parent_node_ids[node_id]]; - }); + cuda::proclaim_return_type( + [col_ids = col_ids.begin(), + parent_node_ids = tree.parent_node_ids.begin()] __device__(size_type node_id) { + return parent_node_ids[node_id] == parent_node_sentinel ? parent_node_sentinel + : col_ids[parent_node_ids[node_id]]; + })); auto const list_children_end = thrust::copy_if( rmm::exec_policy(stream), thrust::make_zip_iterator(thrust::make_counting_iterator(0), parent_col_id), diff --git a/cpp/src/io/json/json_tree.cu b/cpp/src/io/json/json_tree.cu index da5b0eedfbd..9a70b987fa5 100644 --- a/cpp/src/io/json/json_tree.cu +++ b/cpp/src/io/json/json_tree.cu @@ -55,6 +55,8 @@ #include #include +#include + #include namespace cudf::io::json { @@ -274,9 +276,11 @@ tree_meta_t get_tree_representation(device_span tokens, { rmm::device_uvector token_levels(num_tokens, stream); auto const push_pop_it = thrust::make_transform_iterator( - tokens.begin(), [does_push, does_pop] __device__(PdaTokenT const token) -> size_type { - return does_push(token) - does_pop(token); - }); + tokens.begin(), + cuda::proclaim_return_type( + [does_push, does_pop] __device__(PdaTokenT const token) -> size_type { + return does_push(token) - does_pop(token); + })); thrust::exclusive_scan( rmm::exec_policy(stream), push_pop_it, push_pop_it + num_tokens, token_levels.begin()); @@ -407,13 +411,15 @@ rmm::device_uvector hash_node_type_with_field_name(device_span{}(field_name); - }; + auto const d_hasher = cuda::proclaim_return_type< + typename cudf::hashing::detail::default_hash::result_type>( + [d_input = d_input.data(), + node_range_begin = d_tree.node_range_begin.data(), + node_range_end = d_tree.node_range_end.data()] __device__(auto node_id) { + auto const field_name = cudf::string_view( + d_input + node_range_begin[node_id], node_range_end[node_id] - node_range_begin[node_id]); + return cudf::hashing::detail::default_hash{}(field_name); + }); auto const d_equal = [d_input = d_input.data(), node_range_begin = d_tree.node_range_begin.data(), node_range_end = d_tree.node_range_end.data()] __device__(auto node_id1, diff --git a/cpp/src/io/json/write_json.cu b/cpp/src/io/json/write_json.cu index f1a43baa9b0..b2017ee513f 100644 --- a/cpp/src/io/json/write_json.cu +++ b/cpp/src/io/json/write_json.cu @@ -56,6 +56,8 @@ #include #include +#include + #include #include #include @@ -300,9 +302,9 @@ std::unique_ptr struct_to_strings(table_view const& strings_columns, // if previous column was null, then we skip the value separator rmm::device_uvector d_str_separator(total_rows, stream); auto row_num = cudf::detail::make_counting_transform_iterator( - 0, [tbl = *tbl_device_view] __device__(auto idx) -> size_type { - return idx / tbl.num_columns(); - }); + 0, + cuda::proclaim_return_type([tbl = *tbl_device_view] __device__(auto idx) + -> size_type { return idx / tbl.num_columns(); })); auto validity_iterator = cudf::detail::make_counting_transform_iterator(0, validity_fn{*tbl_device_view}); thrust::exclusive_scan_by_key(rmm::exec_policy(stream), @@ -337,7 +339,9 @@ std::unique_ptr struct_to_strings(table_view const& strings_columns, auto old_offsets = strings_column_view(joined_col->view()).offsets(); rmm::device_uvector row_string_offsets(strings_columns.num_rows() + 1, stream, mr); auto const d_strview_offsets = cudf::detail::make_counting_transform_iterator( - 0, [num_strviews_per_row] __device__(size_type const i) { return i * num_strviews_per_row; }); + 0, cuda::proclaim_return_type([num_strviews_per_row] __device__(size_type const i) { + return i * num_strviews_per_row; + })); thrust::gather(rmm::exec_policy(stream), d_strview_offsets, d_strview_offsets + row_string_offsets.size(), @@ -395,11 +399,13 @@ std::unique_ptr join_list_of_strings(lists_column_view const& lists_stri rmm::device_uvector d_strview_offsets(num_offsets, stream); auto num_strings_per_list = cudf::detail::make_counting_transform_iterator( - 0, [offsets = offsets.begin(), num_offsets] __device__(size_type idx) { - if (idx + 1 >= num_offsets) return 0; - auto const length = offsets[idx + 1] - offsets[idx]; - return length == 0 ? 2 : (2 + length + length - 1); - }); + 0, + cuda::proclaim_return_type( + [offsets = offsets.begin(), num_offsets] __device__(size_type idx) { + if (idx + 1 >= num_offsets) return 0; + auto const length = offsets[idx + 1] - offsets[idx]; + return length == 0 ? 2 : (2 + length + length - 1); + })); thrust::exclusive_scan(rmm::exec_policy(stream), num_strings_per_list, num_strings_per_list + num_offsets, diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 7a1d2faa93c..c91d5959d20 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -41,6 +41,8 @@ #include #include +#include + #include namespace cudf::io::parquet::detail { @@ -311,11 +313,11 @@ int decode_page_headers(cudf::detail::hostdevice_vector& chunks // compute max bytes needed for level data auto level_bit_size = cudf::detail::make_counting_transform_iterator( - 0, [chunks = chunks.d_begin()] __device__(int i) { + 0, cuda::proclaim_return_type([chunks = chunks.d_begin()] __device__(int i) { auto c = chunks[i]; return static_cast( max(c.level_bits[level_type::REPETITION], c.level_bits[level_type::DEFINITION])); - }); + })); // max level data bit size. int const max_level_bits = thrust::reduce(rmm::exec_policy(stream), level_bit_size, diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu index deaab5995af..443ca0f5fe7 100644 --- a/cpp/src/io/text/multibyte_split.cu +++ b/cpp/src/io/text/multibyte_split.cu @@ -46,6 +46,8 @@ #include #include +#include + #include #include #include @@ -527,26 +529,28 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source global_offsets.begin(), global_offsets.end(), offsets.begin() + insert_begin, - [baseline = *first_row_offset] __device__(byte_offset global_offset) { - return static_cast(global_offset - baseline); - }); + cuda::proclaim_return_type( + [baseline = *first_row_offset] __device__(byte_offset global_offset) { + return static_cast(global_offset - baseline); + })); auto string_count = offsets.size() - 1; if (strip_delimiters) { auto it = cudf::detail::make_counting_transform_iterator( 0, - [ofs = offsets.data(), - chars = chars.data(), - delim_size = static_cast(delimiter.size()), - last_row = static_cast(string_count) - 1, - insert_end] __device__(size_type row) { - auto const begin = ofs[row]; - auto const len = ofs[row + 1] - begin; - if (row == last_row && insert_end) { - return thrust::make_pair(chars + begin, len); - } else { - return thrust::make_pair(chars + begin, std::max(0, len - delim_size)); - }; - }); + cuda::proclaim_return_type>( + [ofs = offsets.data(), + chars = chars.data(), + delim_size = static_cast(delimiter.size()), + last_row = static_cast(string_count) - 1, + insert_end] __device__(size_type row) { + auto const begin = ofs[row]; + auto const len = ofs[row + 1] - begin; + if (row == last_row && insert_end) { + return thrust::make_pair(chars + begin, len); + } else { + return thrust::make_pair(chars + begin, std::max(0, len - delim_size)); + }; + })); return cudf::strings::detail::make_strings_column(it, it + string_count, stream, mr); } else { return cudf::make_strings_column( diff --git a/cpp/src/lists/combine/concatenate_list_elements.cu b/cpp/src/lists/combine/concatenate_list_elements.cu index 99dbd55678b..26fb81a600f 100644 --- a/cpp/src/lists/combine/concatenate_list_elements.cu +++ b/cpp/src/lists/combine/concatenate_list_elements.cu @@ -40,6 +40,8 @@ #include #include +#include + namespace cudf { namespace lists { namespace detail { @@ -133,11 +135,12 @@ generate_list_offsets_and_validities(column_view const& input, // Compute output list sizes and validities. auto sizes_itr = cudf::detail::make_counting_transform_iterator( 0, - [lists_of_lists_dv = *lists_of_lists_dv_ptr, - lists_dv = *lists_dv_ptr, - d_row_offsets, - d_list_offsets, - d_validities = validities.begin()] __device__(auto const idx) { + cuda::proclaim_return_type([lists_of_lists_dv = *lists_of_lists_dv_ptr, + lists_dv = *lists_dv_ptr, + d_row_offsets, + d_list_offsets, + d_validities = + validities.begin()] __device__(auto const idx) { if (d_row_offsets[idx] == d_row_offsets[idx + 1]) { // This is a null/empty row. d_validities[idx] = static_cast(lists_of_lists_dv.is_valid(idx)); return size_type{0}; @@ -154,7 +157,7 @@ generate_list_offsets_and_validities(column_view const& input, // Compute size of the output list as sum of sizes of all lists in the current input row. return d_list_offsets[d_row_offsets[idx + 1]] - d_list_offsets[d_row_offsets[idx]]; - }); + })); // Compute offsets from sizes. auto out_offsets = std::get<0>( cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + num_rows, stream, mr)); diff --git a/cpp/src/lists/combine/concatenate_rows.cu b/cpp/src/lists/combine/concatenate_rows.cu index 49be7b5ff17..e143fae5742 100644 --- a/cpp/src/lists/combine/concatenate_rows.cu +++ b/cpp/src/lists/combine/concatenate_rows.cu @@ -33,6 +33,8 @@ #include #include +#include + namespace cudf { namespace lists { namespace detail { @@ -80,15 +82,17 @@ generate_regrouped_offsets_and_null_mask(table_device_view const& input, auto offsets = cudf::make_fixed_width_column( data_type{type_to_id()}, input.num_rows() + 1, mask_state::UNALLOCATED, stream, mr); - auto keys = thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - [num_columns = input.num_columns()] __device__( - size_t i) -> size_type { return i / num_columns; }); + auto keys = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type([num_columns = input.num_columns()] __device__( + size_t i) -> size_type { return i / num_columns; })); // generate sizes for the regrouped rows auto values = thrust::make_transform_iterator( thrust::make_counting_iterator(size_t{0}), - [input, row_null_counts = row_null_counts.data(), null_policy] __device__( - size_t i) -> size_type { + cuda::proclaim_return_type([input, + row_null_counts = row_null_counts.data(), + null_policy] __device__(size_t i) -> size_type { auto const col_index = i % input.num_columns(); auto const row_index = i / input.num_columns(); @@ -105,7 +109,7 @@ generate_regrouped_offsets_and_null_mask(table_device_view const& input, input.column(col_index).child(lists_column_view::offsets_column_index).data() + input.column(col_index).offset(); return offsets[row_index + 1] - offsets[row_index]; - }); + })); thrust::reduce_by_key(rmm::exec_policy(stream), keys, @@ -157,17 +161,19 @@ rmm::device_uvector generate_null_counts(table_device_view const& inp { rmm::device_uvector null_counts(input.num_rows(), stream); - auto keys = thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - [num_columns = input.num_columns()] __device__( - size_t i) -> size_type { return i / num_columns; }); + auto keys = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type([num_columns = input.num_columns()] __device__( + size_t i) -> size_type { return i / num_columns; })); auto null_values = thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), [input] __device__(size_t i) -> size_type { + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type([input] __device__(size_t i) -> size_type { auto const col_index = i % input.num_columns(); auto const row_index = i / input.num_columns(); auto const& col = input.column(col_index); return col.null_mask() ? (bit_is_set(col.null_mask(), row_index + col.offset()) ? 0 : 1) : 0; - }); + })); thrust::reduce_by_key(rmm::exec_policy(stream), keys, @@ -237,12 +243,13 @@ std::unique_ptr concatenate_rows(table_view const& input, return cudf::detail::valid_if( iter, iter + (input.num_rows() * input.num_columns()), - [num_rows = input.num_rows(), - num_columns = input.num_columns(), - row_null_counts = row_null_counts.data()] __device__(size_t i) -> size_type { - auto const row_index = i % num_rows; - return row_null_counts[row_index] != num_columns; - }, + cuda::proclaim_return_type( + [num_rows = input.num_rows(), + num_columns = input.num_columns(), + row_null_counts = row_null_counts.data()] __device__(size_t i) -> size_type { + auto const row_index = i % num_rows; + return row_null_counts[row_index] != num_columns; + }), stream, rmm::mr::get_current_device_resource()); } @@ -250,11 +257,12 @@ std::unique_ptr concatenate_rows(table_view const& input, return cudf::detail::valid_if( iter, iter + (input.num_rows() * input.num_columns()), - [num_rows = input.num_rows(), - row_null_counts = row_null_counts.data()] __device__(size_t i) -> size_type { - auto const row_index = i % num_rows; - return row_null_counts[row_index] == 0; - }, + cuda::proclaim_return_type( + [num_rows = input.num_rows(), + row_null_counts = row_null_counts.data()] __device__(size_t i) -> size_type { + auto const row_index = i % num_rows; + return row_null_counts[row_index] == 0; + }), stream, rmm::mr::get_current_device_resource()); }(); @@ -267,13 +275,14 @@ std::unique_ptr concatenate_rows(table_view const& input, // this we can simply swap in a new set of offsets that re-groups them. bmo auto iter = thrust::make_transform_iterator( thrust::make_counting_iterator(size_t{0}), - [num_columns = input.num_columns(), - num_rows = input.num_rows()] __device__(size_t i) -> size_type { - auto const src_col_index = i % num_columns; - auto const src_row_index = i / num_columns; - auto const concat_row_index = (src_col_index * num_rows) + src_row_index; - return concat_row_index; - }); + cuda::proclaim_return_type( + [num_columns = input.num_columns(), + num_rows = input.num_rows()] __device__(size_t i) -> size_type { + auto const src_col_index = i % num_columns; + auto const src_row_index = i / num_columns; + auto const concat_row_index = (src_col_index * num_rows) + src_row_index; + return concat_row_index; + })); auto gathered = cudf::detail::gather(table_view({*concat}), iter, iter + (input.num_columns() * input.num_rows()), diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index cd2bc493bc7..1a88844928e 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -41,6 +41,8 @@ #include #include +#include + #include namespace cudf::lists { @@ -204,9 +206,10 @@ std::unique_ptr dispatch_index_of(lists_column_view const& lists, auto const lists_cdv_ptr = column_device_view::create(lists.parent(), stream); auto const input_it = cudf::detail::make_counting_transform_iterator( size_type{0}, - [lists = cudf::detail::lists_column_device_view{*lists_cdv_ptr}] __device__(auto const idx) { - return list_device_view{lists, idx}; - }); + cuda::proclaim_return_type( + [lists = cudf::detail::lists_column_device_view{*lists_cdv_ptr}] __device__(auto const idx) { + return list_device_view{lists, idx}; + })); auto out_positions = make_numeric_column( data_type{type_to_id()}, num_rows, cudf::mask_state::UNALLOCATED, stream, mr); @@ -254,10 +257,10 @@ std::unique_ptr to_contains(std::unique_ptr&& key_positions, positions_begin, positions_begin + key_positions->size(), result->mutable_view().template begin(), - [] __device__(auto const i) { + cuda::proclaim_return_type([] __device__(auto const i) { // position == NOT_FOUND_SENTINEL: the list does not contain the search key. return i != NOT_FOUND_SENTINEL; - }); + })); auto const null_count = key_positions->null_count(); [[maybe_unused]] auto [data, null_mask, children] = key_positions->release(); @@ -346,18 +349,19 @@ std::unique_ptr contains_nulls(lists_column_view const& lists, auto const out_begin = output->mutable_view().template begin(); auto const lists_cdv_ptr = column_device_view::create(lists_cv, stream); - thrust::tabulate(rmm::exec_policy(stream), - out_begin, - out_begin + lists.size(), - [lists = cudf::detail::lists_column_device_view{*lists_cdv_ptr}] __device__( - auto const list_idx) { - auto const list = list_device_view{lists, list_idx}; - return list.is_null() || - thrust::any_of(thrust::seq, - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(list.size()), - [&list](auto const idx) { return list.is_null(idx); }); - }); + thrust::tabulate( + rmm::exec_policy(stream), + out_begin, + out_begin + lists.size(), + cuda::proclaim_return_type([lists = cudf::detail::lists_column_device_view{ + *lists_cdv_ptr}] __device__(auto const list_idx) { + auto const list = list_device_view{lists, list_idx}; + return list.is_null() || + thrust::any_of(thrust::seq, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(list.size()), + [&list](auto const idx) { return list.is_null(idx); }); + })); return output; } diff --git a/cpp/src/lists/copying/scatter_helper.cu b/cpp/src/lists/copying/scatter_helper.cu index ca5358798c0..a341028d805 100644 --- a/cpp/src/lists/copying/scatter_helper.cu +++ b/cpp/src/lists/copying/scatter_helper.cu @@ -30,6 +30,8 @@ #include #include +#include + namespace cudf { namespace lists { namespace detail { @@ -189,11 +191,11 @@ struct list_child_constructor { thrust::make_counting_iterator(0), thrust::make_counting_iterator(child_column->size()), child_column->mutable_view().begin(), - [offset_begin = list_offsets.begin(), - offset_size = list_offsets.size(), - d_list_vector = list_vector.begin(), - source_lists, - target_lists] __device__(auto index) { + cuda::proclaim_return_type([offset_begin = list_offsets.begin(), + offset_size = list_offsets.size(), + d_list_vector = list_vector.begin(), + source_lists, + target_lists] __device__(auto index) { auto const list_index_iter = thrust::upper_bound(thrust::seq, offset_begin, offset_begin + offset_size, index); auto const list_index = @@ -201,7 +203,7 @@ struct list_child_constructor { auto const intra_index = static_cast(index - offset_begin[list_index]); auto actual_list_row = d_list_vector[list_index].bind_to_column(source_lists, target_lists); return actual_list_row.template element(intra_index); - }); + })); child_column->set_null_count(child_null_mask.second); @@ -241,12 +243,12 @@ struct list_child_constructor { thrust::make_counting_iterator(0), thrust::make_counting_iterator(string_views.size()), string_views.begin(), - [offset_begin = list_offsets.begin(), - offset_size = list_offsets.size(), - d_list_vector = list_vector.begin(), - source_lists, - target_lists, - null_string_view] __device__(auto index) { + cuda::proclaim_return_type([offset_begin = list_offsets.begin(), + offset_size = list_offsets.size(), + d_list_vector = list_vector.begin(), + source_lists, + target_lists, + null_string_view] __device__(auto index) { auto const list_index_iter = thrust::upper_bound(thrust::seq, offset_begin, offset_begin + offset_size, index); auto const list_index = @@ -264,7 +266,7 @@ struct list_child_constructor { // ensure a string from an all-empty column is not mapped to the null placeholder auto const empty_string_view = string_view{}; return d_str.empty() ? empty_string_view : d_str; - }); + })); // string_views should now have been populated with source and target references. auto sv_span = cudf::device_span(string_views); @@ -308,11 +310,11 @@ struct list_child_constructor { thrust::make_counting_iterator(0), thrust::make_counting_iterator(child_list_views.size()), child_list_views.begin(), - [offset_begin = list_offsets.begin(), - offset_size = list_offsets.size(), - d_list_vector = list_vector.begin(), - source_lists, - target_lists] __device__(auto index) { + cuda::proclaim_return_type([offset_begin = list_offsets.begin(), + offset_size = list_offsets.size(), + d_list_vector = list_vector.begin(), + source_lists, + target_lists] __device__(auto index) { auto const list_index_iter = thrust::upper_bound(thrust::seq, offset_begin, offset_begin + offset_size, index); auto const list_index = @@ -331,12 +333,13 @@ struct list_child_constructor { auto size = child_lists_offsets_ptr[child_row_index + 1] - child_lists_offsets_ptr[child_row_index]; return unbound_list_view{label, child_row_index, size}; - }); + })); // child_list_views should now have been populated, with source and target references. auto begin = thrust::make_transform_iterator( - child_list_views.begin(), [] __device__(auto const& row) { return row.size(); }); + child_list_views.begin(), + cuda::proclaim_return_type([] __device__(auto const& row) { return row.size(); })); auto child_offsets = std::get<0>( cudf::detail::make_offsets_child_column(begin, begin + child_list_views.size(), stream, mr)); diff --git a/cpp/src/lists/copying/segmented_gather.cu b/cpp/src/lists/copying/segmented_gather.cu index 855ceadf33f..5439a95966b 100644 --- a/cpp/src/lists/copying/segmented_gather.cu +++ b/cpp/src/lists/copying/segmented_gather.cu @@ -24,6 +24,8 @@ #include #include +#include + #include namespace cudf { @@ -55,29 +57,31 @@ std::unique_ptr segmented_gather(lists_column_view const& value_column, }; // Calculate Flattened gather indices (value_offset[row]+sub_index - auto transformer = [values_lists_view = *value_device_view, - value_offsets, - map_begin, - gather_index_begin, - gather_index_end, - bounds_policy, - out_of_bounds] __device__(size_type index) -> size_type { - // Get each row's offset. (Each row is a list). - auto offset_idx = - thrust::upper_bound( - thrust::seq, gather_index_begin, gather_index_end, gather_index_begin[-1] + index) - - gather_index_begin; - // Get each sub_index in list in each row of gather_map. - auto sub_index = map_begin[index]; - auto list_is_null = values_lists_view.is_null(offset_idx); - auto list_size = list_is_null ? 0 : (value_offsets[offset_idx + 1] - value_offsets[offset_idx]); - auto wrapped_sub_index = sub_index < 0 ? sub_index + list_size : sub_index; - auto constexpr null_idx = cuda::std::numeric_limits::max(); - // Add sub_index to value_column offsets, to get gather indices of child of value_column - return (bounds_policy == out_of_bounds_policy::NULLIFY && out_of_bounds(sub_index, list_size)) - ? null_idx - : value_offsets[offset_idx] + wrapped_sub_index - value_offsets[0]; - }; + auto transformer = + cuda::proclaim_return_type([values_lists_view = *value_device_view, + value_offsets, + map_begin, + gather_index_begin, + gather_index_end, + bounds_policy, + out_of_bounds] __device__(size_type index) -> size_type { + // Get each row's offset. (Each row is a list). + auto offset_idx = + thrust::upper_bound( + thrust::seq, gather_index_begin, gather_index_end, gather_index_begin[-1] + index) - + gather_index_begin; + // Get each sub_index in list in each row of gather_map. + auto sub_index = map_begin[index]; + auto list_is_null = values_lists_view.is_null(offset_idx); + auto list_size = + list_is_null ? 0 : (value_offsets[offset_idx + 1] - value_offsets[offset_idx]); + auto wrapped_sub_index = sub_index < 0 ? sub_index + list_size : sub_index; + auto constexpr null_idx = cuda::std::numeric_limits::max(); + // Add sub_index to value_column offsets, to get gather indices of child of value_column + return (bounds_policy == out_of_bounds_policy::NULLIFY && out_of_bounds(sub_index, list_size)) + ? null_idx + : value_offsets[offset_idx] + wrapped_sub_index - value_offsets[0]; + }); auto child_gather_index_begin = cudf::detail::make_counting_transform_iterator(0, transformer); // Call gather on child of value_column diff --git a/cpp/src/lists/dremel.cu b/cpp/src/lists/dremel.cu index 2b1978bec80..ea539bb8247 100644 --- a/cpp/src/lists/dremel.cu +++ b/cpp/src/lists/dremel.cu @@ -34,6 +34,8 @@ #include #include +#include + namespace cudf::detail { namespace { /** @@ -338,8 +340,10 @@ dremel_data get_encoding(column_view h_col, // Scan to get distance by which each offset value is shifted due to the insertion of empties auto scan_it = cudf::detail::make_counting_transform_iterator( column_offsets[level], - [off = lcv.offsets().data(), size = lcv.offsets().size()] __device__( - auto i) -> int { return (i + 1 < size) && (off[i] == off[i + 1]); }); + cuda::proclaim_return_type([off = lcv.offsets().data(), + size = lcv.offsets().size()] __device__(auto i) -> int { + return (i + 1 < size) && (off[i] == off[i + 1]); + })); rmm::device_uvector scan_out(offset_size_at_level, stream); thrust::exclusive_scan( rmm::exec_policy(stream), scan_it, scan_it + offset_size_at_level, scan_out.begin()); @@ -375,10 +379,11 @@ dremel_data get_encoding(column_view h_col, auto [empties, empties_idx, empties_size] = get_empties(nesting_levels[level], column_offsets[level], column_ends[level]); - auto offset_transformer = [new_child_offsets = new_offsets.data(), - child_start = column_offsets[level + 1]] __device__(auto x) { - return new_child_offsets[x - child_start]; // (x - child's offset) - }; + auto offset_transformer = cuda::proclaim_return_type( + [new_child_offsets = new_offsets.data(), + child_start = column_offsets[level + 1]] __device__(auto x) { + return new_child_offsets[x - child_start]; // (x - child's offset) + }); // We will be reading from old rep_levels and writing again to rep_levels. Swap the current // rep values into temp_rep_vals so it can become the input and rep_levels can again be output. @@ -423,8 +428,10 @@ dremel_data get_encoding(column_view h_col, // level value fof an empty list auto scan_it = cudf::detail::make_counting_transform_iterator( column_offsets[level], - [off = lcv.offsets().data(), size = lcv.offsets().size()] __device__( - auto i) -> int { return (i + 1 < size) && (off[i] == off[i + 1]); }); + cuda::proclaim_return_type([off = lcv.offsets().data(), + size = lcv.offsets().size()] __device__(auto i) -> int { + return (i + 1 < size) && (off[i] == off[i + 1]); + })); rmm::device_uvector scan_out(offset_size_at_level, stream); thrust::exclusive_scan( rmm::exec_policy(stream), scan_it, scan_it + offset_size_at_level, scan_out.begin()); diff --git a/cpp/src/lists/explode.cu b/cpp/src/lists/explode.cu index 4db3254f201..cdb7857b74a 100644 --- a/cpp/src/lists/explode.cu +++ b/cpp/src/lists/explode.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,8 @@ #include #include +#include + #include #include @@ -122,7 +124,9 @@ std::unique_ptr
explode(table_view const& input_table, auto offsets = explode_col.offsets_begin(); // offsets + 1 here to skip the 0th offset, which removes a - 1 operation later. auto offsets_minus_one = thrust::make_transform_iterator( - thrust::next(offsets), [offsets] __device__(auto i) { return (i - offsets[0]) - 1; }); + thrust::next(offsets), cuda::proclaim_return_type([offsets] __device__(auto i) { + return (i - offsets[0]) - 1; + })); auto counting_iter = thrust::make_counting_iterator(0); // This looks like an off-by-one bug, but what is going on here is that we need to reduce each @@ -158,7 +162,9 @@ std::unique_ptr
explode_position(table_view const& input_table, auto offsets = explode_col.offsets_begin(); // offsets + 1 here to skip the 0th offset, which removes a - 1 operation later. auto offsets_minus_one = thrust::make_transform_iterator( - offsets + 1, [offsets] __device__(auto i) { return (i - offsets[0]) - 1; }); + offsets + 1, cuda::proclaim_return_type([offsets] __device__(auto i) { + return (i - offsets[0]) - 1; + })); auto counting_iter = thrust::make_counting_iterator(0); rmm::device_uvector pos(sliced_child.size(), stream, mr); @@ -171,16 +177,17 @@ std::unique_ptr
explode_position(table_view const& input_table, counting_iter, counting_iter + gather_map.size(), gather_map.begin(), - [position_array = pos.data(), - offsets_minus_one, - offsets, - offset_size = explode_col.size()] __device__(auto idx) -> size_type { + cuda::proclaim_return_type([position_array = pos.data(), + offsets_minus_one, + offsets, + offset_size = + explode_col.size()] __device__(auto idx) -> size_type { auto lb_idx = thrust::distance( offsets_minus_one, thrust::lower_bound(thrust::seq, offsets_minus_one, offsets_minus_one + offset_size, idx)); position_array[idx] = idx - (offsets[lb_idx] - offsets[0]); return lb_idx; - }); + })); return build_table(input_table, explode_column_idx, @@ -208,9 +215,10 @@ std::unique_ptr
explode_outer(table_view const& input_table, auto null_or_empty = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [offsets, offsets_size = explode_col.size() - 1] __device__(int idx) { - return (idx > offsets_size || (offsets[idx + 1] != offsets[idx])) ? 0 : 1; - }); + cuda::proclaim_return_type( + [offsets, offsets_size = explode_col.size() - 1] __device__(int idx) { + return (idx > offsets_size || (offsets[idx + 1] != offsets[idx])) ? 0 : 1; + })); thrust::inclusive_scan(rmm::exec_policy(stream), null_or_empty, null_or_empty + explode_col.size(), @@ -233,7 +241,9 @@ std::unique_ptr
explode_outer(table_view const& input_table, // offsets + 1 here to skip the 0th offset, which removes a - 1 operation later. auto offsets_minus_one = thrust::make_transform_iterator( - thrust::next(offsets), [offsets] __device__(auto i) { return (i - offsets[0]) - 1; }); + thrust::next(offsets), cuda::proclaim_return_type([offsets] __device__(auto i) { + return (i - offsets[0]) - 1; + })); auto fill_gather_maps = [offsets_minus_one, gather_map_p = gather_map.begin(), diff --git a/cpp/src/lists/interleave_columns.cu b/cpp/src/lists/interleave_columns.cu index e80d63939ea..90041dbd46e 100644 --- a/cpp/src/lists/interleave_columns.cu +++ b/cpp/src/lists/interleave_columns.cu @@ -38,6 +38,8 @@ #include #include +#include + namespace cudf { namespace lists { namespace detail { @@ -71,10 +73,10 @@ generate_list_offsets_and_validities(table_view const& input, thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_output_lists), d_offsets, - [num_cols, - table_dv = *table_dv_ptr, - d_validities = validities.begin(), - has_null_mask] __device__(size_type const idx) { + cuda::proclaim_return_type([num_cols, + table_dv = *table_dv_ptr, + d_validities = validities.begin(), + has_null_mask] __device__(size_type const idx) { auto const col_id = idx % num_cols; auto const list_id = idx / num_cols; auto const& lists_col = table_dv.column(col_id); @@ -83,7 +85,7 @@ generate_list_offsets_and_validities(table_view const& input, lists_col.child(lists_column_view::offsets_column_index).template data() + lists_col.offset(); return list_offsets[list_id + 1] - list_offsets[list_id]; - }); + })); // Compute offsets from sizes. thrust::exclusive_scan( @@ -110,11 +112,11 @@ std::unique_ptr concatenate_and_gather_lists(host_span([num_cols, num_input_rows] __device__(auto const idx) { auto const source_col_idx = idx % num_cols; auto const source_row_idx = idx / num_cols; return source_col_idx * num_input_rows + source_row_idx; - }); + })); // The gather API should be able to handle any data type for the input columns. auto result = cudf::detail::gather(table_view{{concatenated_col->view()}}, diff --git a/cpp/src/merge/merge.cu b/cpp/src/merge/merge.cu index ee29c207cf1..0d30230de28 100644 --- a/cpp/src/merge/merge.cu +++ b/cpp/src/merge/merge.cu @@ -50,6 +50,8 @@ #include #include +#include + #include #include @@ -401,10 +403,11 @@ struct column_merger { row_order_.begin(), row_order_.end(), merged_view.begin(), - [d_lcol, d_rcol] __device__(index_type const& index_pair) { - auto const [side, index] = index_pair; - return side == side::LEFT ? d_lcol[index] : d_rcol[index]; - }); + cuda::proclaim_return_type( + [d_lcol, d_rcol] __device__(index_type const& index_pair) { + auto const [side, index] = index_pair; + return side == side::LEFT ? d_lcol[index] : d_rcol[index]; + })); // CAVEAT: conditional call below is erroneous without // set_null_mask() call (see TODO above): @@ -477,10 +480,12 @@ std::unique_ptr column_merger::operator()( auto concatenated_list = cudf::lists::detail::concatenate(columns, stream, mr); auto const iter_gather = cudf::detail::make_counting_transform_iterator( - 0, [row_order = row_order_.data(), lsize = lcol.size()] __device__(auto const idx) { - auto const [side, index] = row_order[idx]; - return side == side::LEFT ? index : lsize + index; - }); + 0, + cuda::proclaim_return_type( + [row_order = row_order_.data(), lsize = lcol.size()] __device__(auto const idx) { + auto const [side, index] = row_order[idx]; + return side == side::LEFT ? index : lsize + index; + })); auto result = cudf::detail::gather(table_view{{concatenated_list->view()}}, iter_gather, diff --git a/cpp/src/partitioning/round_robin.cu b/cpp/src/partitioning/round_robin.cu index 32c72f61741..c615f08ff12 100644 --- a/cpp/src/partitioning/round_robin.cu +++ b/cpp/src/partitioning/round_robin.cu @@ -40,6 +40,8 @@ #include #include +#include + #include #include // for std::ceil() #include @@ -89,9 +91,10 @@ std::pair, std::vector> degenerate // iterator for partition index rotated right by start_partition positions: auto rotated_iter_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [num_partitions, start_partition] __device__(auto index) { - return (index + num_partitions - start_partition) % num_partitions; - }); + cuda::proclaim_return_type( + [num_partitions, start_partition] __device__(auto index) { + return (index + num_partitions - start_partition) % num_partitions; + })); if (num_partitions == nrows) { rmm::device_uvector partition_offsets(num_partitions, stream); @@ -131,7 +134,9 @@ std::pair, std::vector> degenerate // this composes rotated_iter transform (above) iterator with // calculating number of edges of transposed bi-graph: auto nedges_iter_begin = thrust::make_transform_iterator( - rotated_iter_begin, [nrows] __device__(auto index) { return (index < nrows ? 1 : 0); }); + rotated_iter_begin, + cuda::proclaim_return_type( + [nrows] __device__(auto index) { return (index < nrows ? 1 : 0); })); // offsets (part 2: compute partition offsets): rmm::device_uvector partition_offsets(num_partitions, stream); @@ -205,12 +210,12 @@ std::pair, std::vector> round_robin_part auto iter_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [nrows, - num_partitions, - max_partition_size, - num_partitions_max_size, - total_max_partitions_size, - delta] __device__(auto index0) { + cuda::proclaim_return_type([nrows, + num_partitions, + max_partition_size, + num_partitions_max_size, + total_max_partitions_size, + delta] __device__(auto index0) { // rotate original index right by delta positions; // this is the effect of applying start_partition: // @@ -230,7 +235,7 @@ std::pair, std::vector> round_robin_part : num_partitions_max_size + (rotated_index - total_max_partitions_size) / (max_partition_size - 1)); return num_partitions * index_within_partition + partition_index; - }); + })); auto uniq_tbl = cudf::detail::gather( input, iter_begin, iter_begin + nrows, cudf::out_of_bounds_policy::DONT_CHECK, stream, mr); diff --git a/cpp/src/quantiles/quantile.cu b/cpp/src/quantiles/quantile.cu index 4a9c2e3a902..946ebd479c5 100644 --- a/cpp/src/quantiles/quantile.cu +++ b/cpp/src/quantiles/quantile.cu @@ -39,6 +39,8 @@ #include #include +#include + #include #include @@ -95,9 +97,11 @@ struct quantile_functor { q_device.begin(), q_device.end(), d_output->template begin(), - [sorted_data, interp = interp, size = size] __device__(double q) { - return select_quantile_data(sorted_data, size, q, interp); - }); + cuda::proclaim_return_type( + [sorted_data, interp = interp, size = size] __device__(double q) { + return select_quantile_data( + sorted_data, size, q, interp); + })); } else { auto sorted_data = thrust::make_permutation_iterator( dictionary::detail::make_dictionary_iterator(*d_input), ordered_indices); @@ -105,15 +109,18 @@ struct quantile_functor { q_device.begin(), q_device.end(), d_output->template begin(), - [sorted_data, interp = interp, size = size] __device__(double q) { - return select_quantile_data(sorted_data, size, q, interp); - }); + cuda::proclaim_return_type( + [sorted_data, interp = interp, size = size] __device__(double q) { + return select_quantile_data( + sorted_data, size, q, interp); + })); } if (input.nullable()) { auto sorted_validity = thrust::make_transform_iterator( ordered_indices, - [input = *d_input] __device__(size_type idx) { return input.is_valid_nocheck(idx); }); + cuda::proclaim_return_type( + [input = *d_input] __device__(size_type idx) { return input.is_valid_nocheck(idx); })); auto [mask, null_count] = valid_if( q_device.begin(), diff --git a/cpp/src/quantiles/quantiles.cu b/cpp/src/quantiles/quantiles.cu index c6760e77403..f55e9c4cb6a 100644 --- a/cpp/src/quantiles/quantiles.cu +++ b/cpp/src/quantiles/quantiles.cu @@ -31,6 +31,8 @@ #include #include +#include + #include #include @@ -44,10 +46,11 @@ std::unique_ptr
quantiles(table_view const& input, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto quantile_idx_lookup = [sortmap, interp, size = input.num_rows()] __device__(double q) { - auto selector = [sortmap] __device__(auto idx) { return sortmap[idx]; }; - return detail::select_quantile(selector, size, q, interp); - }; + auto quantile_idx_lookup = cuda::proclaim_return_type( + [sortmap, interp, size = input.num_rows()] __device__(double q) { + auto selector = [sortmap] __device__(auto idx) { return sortmap[idx]; }; + return detail::select_quantile(selector, size, q, interp); + }); auto const q_device = cudf::detail::make_device_uvector_async(q, stream, rmm::mr::get_current_device_resource()); diff --git a/cpp/src/quantiles/tdigest/tdigest.cu b/cpp/src/quantiles/tdigest/tdigest.cu index 79a25f79f60..4764ac4d87a 100644 --- a/cpp/src/quantiles/tdigest/tdigest.cu +++ b/cpp/src/quantiles/tdigest/tdigest.cu @@ -40,6 +40,8 @@ #include #include +#include + using namespace cudf::tdigest; namespace cudf { @@ -199,12 +201,13 @@ std::unique_ptr compute_approx_percentiles(tdigest_column_view const& in rmm::mr::get_current_device_resource()); auto keys = cudf::detail::make_counting_transform_iterator( 0, - [offsets_begin = offsets.begin(), - offsets_end = offsets.end()] __device__(size_type i) { - return thrust::distance( - offsets_begin, - thrust::prev(thrust::upper_bound(thrust::seq, offsets_begin, offsets_end, i))); - }); + cuda::proclaim_return_type( + [offsets_begin = offsets.begin(), + offsets_end = offsets.end()] __device__(size_type i) { + return thrust::distance( + offsets_begin, + thrust::prev(thrust::upper_bound(thrust::seq, offsets_begin, offsets_end, i))); + })); thrust::inclusive_scan_by_key(rmm::exec_policy(stream), keys, keys + weight.size(), @@ -381,7 +384,8 @@ std::unique_ptr percentile_approx(tdigest_column_view const& input, auto [bitmask, null_count] = [stream, mr, &tdv]() { auto tdigest_is_empty = thrust::make_transform_iterator( detail::size_begin(tdv), - [] __device__(size_type tdigest_size) -> size_type { return tdigest_size == 0; }); + cuda::proclaim_return_type( + [] __device__(size_type tdigest_size) -> size_type { return tdigest_size == 0; })); auto const null_count = thrust::reduce(rmm::exec_policy(stream), tdigest_is_empty, tdigest_is_empty + tdv.size(), 0); if (null_count == 0) { diff --git a/cpp/src/quantiles/tdigest/tdigest_aggregation.cu b/cpp/src/quantiles/tdigest/tdigest_aggregation.cu index 44a13c450ab..3ccef38715b 100644 --- a/cpp/src/quantiles/tdigest/tdigest_aggregation.cu +++ b/cpp/src/quantiles/tdigest/tdigest_aggregation.cu @@ -52,6 +52,8 @@ #include #include +#include + namespace cudf { namespace tdigest { namespace detail { @@ -536,11 +538,13 @@ generate_group_cluster_info(int delta, // generate group cluster offsets (where the clusters for a given group start and end) auto group_cluster_offsets = cudf::make_numeric_column( - data_type{type_id::INT32}, num_groups + 1, mask_state::UNALLOCATED, stream, mr); + data_type{type_to_id()}, num_groups + 1, mask_state::UNALLOCATED, stream, mr); auto cluster_size = cudf::detail::make_counting_transform_iterator( - 0, [group_num_clusters = group_num_clusters.begin(), num_groups] __device__(size_type index) { - return index == num_groups ? 0 : group_num_clusters[index]; - }); + 0, + cuda::proclaim_return_type( + [group_num_clusters = group_num_clusters.begin(), num_groups] __device__(size_type index) { + return index == num_groups ? 0 : group_num_clusters[index]; + })); thrust::exclusive_scan(rmm::exec_policy(stream), cluster_size, cluster_size + num_groups + 1, @@ -584,8 +588,10 @@ std::unique_ptr build_output_column(size_type num_rows, return weights[i] == 0; }; // whether or not this particular tdigest is a stub - auto is_stub_digest = [offsets = offsets->view().begin(), is_stub_weight] __device__( - size_type i) { return is_stub_weight(offsets[i]) ? 1 : 0; }; + auto is_stub_digest = cuda::proclaim_return_type( + [offsets = offsets->view().begin(), is_stub_weight] __device__(size_type i) { + return is_stub_weight(offsets[i]) ? 1 : 0; + }); size_type const num_stubs = [&]() { if (!has_nulls) { return 0; } @@ -623,17 +629,19 @@ std::unique_ptr build_output_column(size_type num_rows, // adjust offsets. rmm::device_uvector sizes(num_rows, stream); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(0) + num_rows, - sizes.begin(), - [offsets = offsets->view().begin()] __device__(size_type i) { - return offsets[i + 1] - offsets[i]; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(0) + num_rows, + sizes.begin(), + cuda::proclaim_return_type([offsets = offsets->view().begin()] __device__( + size_type i) { return offsets[i + 1] - offsets[i]; })); auto iter = cudf::detail::make_counting_transform_iterator( - 0, [sizes = sizes.begin(), is_stub_digest, num_rows] __device__(size_type i) { - return i == num_rows || is_stub_digest(i) ? 0 : sizes[i]; - }); + 0, + cuda::proclaim_return_type( + [sizes = sizes.begin(), is_stub_digest, num_rows] __device__(size_type i) { + return i == num_rows || is_stub_digest(i) ? 0 : sizes[i]; + })); thrust::exclusive_scan(rmm::exec_policy(stream), iter, iter + num_rows + 1, @@ -651,6 +659,39 @@ std::unique_ptr build_output_column(size_type num_rows, mr); } +template +struct compute_tdigests_keys_fn { + int const delta; + double const* group_cluster_wl; + size_type const* group_cluster_offsets; + CumulativeWeight group_cumulative_weight; + + __device__ size_type operator()(size_type value_index) + { + // get group index, relative value index within the group and cumulative weight. + [[maybe_unused]] auto [group_index, relative_value_index, cumulative_weight] = + group_cumulative_weight(value_index); + + auto const num_clusters = + group_cluster_offsets[group_index + 1] - group_cluster_offsets[group_index]; + if (num_clusters == 0) { return group_cluster_offsets[group_index]; } + + // compute start of cluster weight limits for this group + double const* weight_limits = group_cluster_wl + group_cluster_offsets[group_index]; + + // local cluster index + size_type const group_cluster_index = + min(num_clusters - 1, + static_cast( + thrust::lower_bound( + thrust::seq, weight_limits, weight_limits + num_clusters, cumulative_weight) - + weight_limits)); + + // add the cluster offset to generate a globally unique key + return group_cluster_index + group_cluster_offsets[group_index]; + } +}; + /** * @brief Compute a column of tdigests. * @@ -715,32 +756,10 @@ std::unique_ptr compute_tdigests(int delta, // between the groups, so we add our group start offset. auto keys = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [delta, - group_cluster_wl = group_cluster_wl.data(), - group_cluster_offsets = group_cluster_offsets->view().begin(), - group_cumulative_weight] __device__(size_type value_index) -> size_type { - // get group index, relative value index within the group and cumulative weight. - [[maybe_unused]] auto [group_index, relative_value_index, cumulative_weight] = - group_cumulative_weight(value_index); - - auto const num_clusters = - group_cluster_offsets[group_index + 1] - group_cluster_offsets[group_index]; - if (num_clusters == 0) { return group_cluster_offsets[group_index]; } - - // compute start of cluster weight limits for this group - double const* weight_limits = group_cluster_wl + group_cluster_offsets[group_index]; - - // local cluster index - size_type const group_cluster_index = - min(num_clusters - 1, - static_cast( - thrust::lower_bound( - thrust::seq, weight_limits, weight_limits + num_clusters, cumulative_weight) - - weight_limits)); - - // add the cluster offset to generate a globally unique key - return group_cluster_index + group_cluster_offsets[group_index]; - }); + compute_tdigests_keys_fn{delta, + group_cluster_wl.data(), + group_cluster_offsets->view().begin(), + group_cumulative_weight}); // mean and weight data auto centroid_means = cudf::make_numeric_column( @@ -1205,6 +1224,11 @@ std::unique_ptr reduce_tdigest(column_view const& col, col.type(), typed_reduce_tdigest{}, sorted->get_column(0), delta, stream, mr); } +struct group_offsets_fn { + size_type const size; + CUDF_HOST_DEVICE size_type operator()(size_type i) const { return i == 0 ? 0 : size; } +}; + std::unique_ptr reduce_merge_tdigest(column_view const& input, int max_centroids, rmm::cuda_stream_view stream, @@ -1214,11 +1238,10 @@ std::unique_ptr reduce_merge_tdigest(column_view const& input, if (input.size() == 0) { return cudf::tdigest::detail::make_empty_tdigest_scalar(stream, mr); } - auto h_group_offsets = cudf::detail::make_counting_transform_iterator( - 0, [size = input.size()](size_type i) { return i == 0 ? 0 : size; }); - auto group_offsets = cudf::detail::make_counting_transform_iterator( - 0, [size = input.size()] __device__(size_type i) { return i == 0 ? 0 : size; }); - auto group_labels = thrust::make_constant_iterator(0); + auto group_offsets_ = group_offsets_fn{input.size()}; + auto h_group_offsets = cudf::detail::make_counting_transform_iterator(0, group_offsets_); + auto group_offsets = cudf::detail::make_counting_transform_iterator(0, group_offsets_); + auto group_labels = thrust::make_constant_iterator(0); return to_tdigest_scalar(merge_tdigests(tdv, h_group_offsets, group_offsets, diff --git a/cpp/src/reductions/histogram.cu b/cpp/src/reductions/histogram.cu index fa84bbeb25d..218e2e57420 100644 --- a/cpp/src/reductions/histogram.cu +++ b/cpp/src/reductions/histogram.cu @@ -26,6 +26,7 @@ #include #include +#include #include @@ -178,7 +179,9 @@ compute_row_frequencies(table_view const& input, auto const row_comp = cudf::experimental::row::equality::self_comparator(preprocessed_input); auto const pair_iter = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(size_type const i) { return cuco::make_pair(i, i); }); + size_type{0}, + cuda::proclaim_return_type>( + [] __device__(size_type const i) { return cuco::make_pair(i, i); })); // Always compare NaNs as equal. using nan_equal_comparator = diff --git a/cpp/src/reductions/nth_element.cu b/cpp/src/reductions/nth_element.cu index ef58ec3f42e..82035fa78ce 100644 --- a/cpp/src/reductions/nth_element.cu +++ b/cpp/src/reductions/nth_element.cu @@ -28,6 +28,8 @@ #include #include +#include + namespace cudf::reduction::detail { std::unique_ptr nth_element(column_view const& col, @@ -45,7 +47,9 @@ std::unique_ptr nth_element(column_view const& col, auto dcol = column_device_view::create(col, stream); auto bitmask_iterator = thrust::make_transform_iterator(cudf::detail::make_validity_iterator(*dcol), - [] __device__(auto b) { return static_cast(b); }); + cuda::proclaim_return_type([] __device__(auto b) { + return static_cast(b); + })); rmm::device_uvector null_skipped_index(col.size(), stream); // null skipped index for valids only. thrust::inclusive_scan(rmm::exec_policy(stream), diff --git a/cpp/src/reductions/scan/scan_exclusive.cu b/cpp/src/reductions/scan/scan_exclusive.cu index 3fb1fc64f61..4a96c4efeed 100644 --- a/cpp/src/reductions/scan/scan_exclusive.cu +++ b/cpp/src/reductions/scan/scan_exclusive.cu @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -28,6 +29,8 @@ #include +#include + namespace cudf { namespace detail { namespace { @@ -64,8 +67,11 @@ struct scan_dispatcher { auto identity = Op::template identity(); auto begin = make_null_replacement_iterator(*d_input, identity, input.has_nulls()); + + // CUB 2.0.0 requires that the binary operator returns the same type as the identity. + auto const binary_op = cudf::detail::cast_functor(Op{}); thrust::exclusive_scan( - rmm::exec_policy(stream), begin, begin + input.size(), output.data(), identity, Op{}); + rmm::exec_policy(stream), begin, begin + input.size(), output.data(), identity, binary_op); CUDF_CHECK_CUDA(stream.value()); return output_column; diff --git a/cpp/src/reductions/scan/scan_inclusive.cu b/cpp/src/reductions/scan/scan_inclusive.cu index 91aa1cac487..42664de8ed7 100644 --- a/cpp/src/reductions/scan/scan_inclusive.cu +++ b/cpp/src/reductions/scan/scan_inclusive.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -82,8 +83,11 @@ struct scan_functor { auto d_input = column_device_view::create(input_view, stream); auto const begin = make_null_replacement_iterator(*d_input, Op::template identity(), input_view.has_nulls()); + + // CUB 2.0.0 requires that the binary operator returns the same type as the identity. + auto const binary_op = cudf::detail::cast_functor(Op{}); thrust::inclusive_scan( - rmm::exec_policy(stream), begin, begin + input_view.size(), result.data(), Op{}); + rmm::exec_policy(stream), begin, begin + input_view.size(), result.data(), binary_op); CUDF_CHECK_CUDA(stream.value()); return output_column; diff --git a/cpp/src/reductions/segmented/simple.cuh b/cpp/src/reductions/segmented/simple.cuh index 05a871ed4fb..9bd7260e64f 100644 --- a/cpp/src/reductions/segmented/simple.cuh +++ b/cpp/src/reductions/segmented/simple.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,8 @@ #include #include +#include + #include #include @@ -76,7 +79,7 @@ std::unique_ptr simple_segmented_reduction( auto simple_op = Op{}; auto const num_segments = offsets.size() - 1; - auto const binary_op = simple_op.get_binary_op(); + auto const binary_op = cudf::detail::cast_functor(simple_op.get_binary_op()); // Cast initial value ResultType initial_value = [&] { @@ -241,7 +244,8 @@ std::unique_ptr fixed_point_segmented_reduction( d_col->begin(), d_col->end(), d_col->begin(), - [new_scale] __device__(auto fp) { return fp.rescaled(new_scale); }); + cuda::proclaim_return_type( + [new_scale] __device__(auto fp) { return fp.rescaled(new_scale); })); return new_scale; } diff --git a/cpp/src/reductions/simple.cuh b/cpp/src/reductions/simple.cuh index 9bb01c72d8d..00bf93edaf2 100644 --- a/cpp/src/reductions/simple.cuh +++ b/cpp/src/reductions/simple.cuh @@ -19,6 +19,7 @@ #include "nested_type_minmax_util.cuh" #include +#include #include #include #include @@ -314,11 +315,12 @@ struct same_element_type_dispatcher { // We will do reduction to find the ARGMIN/ARGMAX index, then return the element at that index. auto const binop_generator = cudf::reduction::detail::comparison_binop_generator::create(input, stream); + auto const binary_op = cudf::detail::cast_functor(binop_generator.binop()); auto const minmax_idx = thrust::reduce(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(input.size()), size_type{0}, - binop_generator.binop()); + binary_op); return cudf::detail::get_element(input, minmax_idx, stream, mr); } diff --git a/cpp/src/replace/clamp.cu b/cpp/src/replace/clamp.cu index 0c934533d54..6852b19af44 100644 --- a/cpp/src/replace/clamp.cu +++ b/cpp/src/replace/clamp.cu @@ -44,6 +44,8 @@ #include #include +#include + namespace cudf { namespace detail { namespace { @@ -132,19 +134,20 @@ std::enable_if_t(), std::unique_ptr> clamp auto scalar_zip_itr = thrust::make_zip_iterator(thrust::make_tuple(lo_itr, lo_replace_itr, hi_itr, hi_replace_itr)); - auto trans = [] __device__(auto element_optional, auto scalar_tuple) { - if (element_optional.has_value()) { - auto lo_optional = thrust::get<0>(scalar_tuple); - auto hi_optional = thrust::get<2>(scalar_tuple); - if (lo_optional.has_value() and (*element_optional < *lo_optional)) { - return *(thrust::get<1>(scalar_tuple)); - } else if (hi_optional.has_value() and (*element_optional > *hi_optional)) { - return *(thrust::get<3>(scalar_tuple)); + auto trans = + cuda::proclaim_return_type([] __device__(auto element_optional, auto scalar_tuple) { + if (element_optional.has_value()) { + auto lo_optional = thrust::get<0>(scalar_tuple); + auto hi_optional = thrust::get<2>(scalar_tuple); + if (lo_optional.has_value() and (*element_optional < *lo_optional)) { + return *(thrust::get<1>(scalar_tuple)); + } else if (hi_optional.has_value() and (*element_optional > *hi_optional)) { + return *(thrust::get<3>(scalar_tuple)); + } } - } - return *element_optional; - }; + return *element_optional; + }); auto input_pair_iterator = make_optional_iterator(*input_device_view, nullate::DYNAMIC{input.has_nulls()}); diff --git a/cpp/src/reshape/interleave_columns.cu b/cpp/src/reshape/interleave_columns.cu index d803d786517..deb0acb4742 100644 --- a/cpp/src/reshape/interleave_columns.cu +++ b/cpp/src/reshape/interleave_columns.cu @@ -35,6 +35,8 @@ #include #include +#include + namespace cudf { namespace detail { namespace { @@ -175,14 +177,15 @@ struct interleave_columns_impl(source_col_idx).size_bytes() - : 0; - }; + auto offsets_transformer = + cuda::proclaim_return_type([num_columns, d_table] __device__(size_type idx) { + // First compute the column and the row this item belongs to + auto source_row_idx = idx % num_columns; + auto source_col_idx = idx / num_columns; + return d_table.column(source_row_idx).is_valid(source_col_idx) + ? d_table.column(source_row_idx).element(source_col_idx).size_bytes() + : 0; + }); auto offsets_transformer_itr = thrust::make_transform_iterator( thrust::make_counting_iterator(0), offsets_transformer); auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( @@ -234,10 +237,10 @@ struct interleave_columns_impl()>> { auto index_begin = thrust::make_counting_iterator(0); auto index_end = thrust::make_counting_iterator(output_size); - auto func_value = [input = *device_input, - divisor = input.num_columns()] __device__(size_type idx) { - return input.column(idx % divisor).element(idx / divisor); - }; + auto func_value = cuda::proclaim_return_type( + [input = *device_input, divisor = input.num_columns()] __device__(size_type idx) { + return input.column(idx % divisor).element(idx / divisor); + }); if (not create_mask) { thrust::transform( diff --git a/cpp/src/rolling/detail/lead_lag_nested.cuh b/cpp/src/rolling/detail/lead_lag_nested.cuh index d2fe9fabd1b..734f7d1f565 100644 --- a/cpp/src/rolling/detail/lead_lag_nested.cuh +++ b/cpp/src/rolling/detail/lead_lag_nested.cuh @@ -34,6 +34,8 @@ #include #include +#include + #include namespace cudf::detail { @@ -141,25 +143,27 @@ std::unique_ptr compute_lead_lag_for_nested(aggregation::Kind op, thrust::make_counting_iterator(size_type{0}), thrust::make_counting_iterator(size_type{input.size()}), gather_map.begin(), - [following, input_size, null_index, row_offset] __device__(size_type i) { - // Note: grouped_*rolling_window() trims preceding/following to - // the beginning/end of the group. `rolling_window()` does not. - // Must trim _following[i] so as not to go past the column end. - auto _following = min(following[i], input_size - i - 1); - return (row_offset > _following) ? null_index : (i + row_offset); - }); + cuda::proclaim_return_type( + [following, input_size, null_index, row_offset] __device__(size_type i) { + // Note: grouped_*rolling_window() trims preceding/following to + // the beginning/end of the group. `rolling_window()` does not. + // Must trim _following[i] so as not to go past the column end. + auto _following = min(following[i], input_size - i - 1); + return (row_offset > _following) ? null_index : (i + row_offset); + })); } else { thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(size_type{0}), thrust::make_counting_iterator(size_type{input.size()}), gather_map.begin(), - [preceding, input_size, null_index, row_offset] __device__(size_type i) { - // Note: grouped_*rolling_window() trims preceding/following to - // the beginning/end of the group. `rolling_window()` does not. - // Must trim _preceding[i] so as not to go past the column start. - auto _preceding = min(preceding[i], i + 1); - return (row_offset > (_preceding - 1)) ? null_index : (i - row_offset); - }); + cuda::proclaim_return_type( + [preceding, input_size, null_index, row_offset] __device__(size_type i) { + // Note: grouped_*rolling_window() trims preceding/following to + // the beginning/end of the group. `rolling_window()` does not. + // Must trim _preceding[i] so as not to go past the column start. + auto _preceding = min(preceding[i], i + 1); + return (row_offset > (_preceding - 1)) ? null_index : (i - row_offset); + })); } auto output_with_nulls = cudf::detail::gather(table_view{std::vector{input}}, diff --git a/cpp/src/rolling/detail/rolling_collect_list.cuh b/cpp/src/rolling/detail/rolling_collect_list.cuh index 39d15ed716f..22e55561eca 100644 --- a/cpp/src/rolling/detail/rolling_collect_list.cuh +++ b/cpp/src/rolling/detail/rolling_collect_list.cuh @@ -30,6 +30,8 @@ #include #include +#include + namespace cudf { namespace detail { @@ -71,9 +73,10 @@ std::unique_ptr create_collect_offsets(size_type input_size, preceding_begin + input_size, following_begin, mutable_sizes.begin(), - [min_periods] __device__(auto const preceding, auto const following) { - return (preceding + following) < min_periods ? 0 : (preceding + following); - }); + cuda::proclaim_return_type( + [min_periods] __device__(auto const preceding, auto const following) { + return (preceding + following) < min_periods ? 0 : (preceding + following); + })); // Convert `sizes` to an offsets column, via inclusive_scan(): auto offsets_column = std::get<0>(cudf::detail::make_offsets_child_column( @@ -115,17 +118,18 @@ std::unique_ptr create_collect_gather_map(column_view const& child_offse thrust::make_counting_iterator(0), thrust::make_counting_iterator(per_row_mapping.size()), gather_map->mutable_view().template begin(), - [d_offsets = - child_offsets.template begin(), // E.g. [0, 2, 5, 8, 11, 13] - d_groups = - per_row_mapping.template begin(), // E.g. [0,0, 1,1,1, 2,2,2, 3,3,3, 4,4] - d_prev = preceding_iter] __device__(auto i) { - auto group = d_groups[i]; - auto group_start_offset = d_offsets[group]; - auto relative_index = i - group_start_offset; - - return (group - d_prev[group] + 1) + relative_index; - }); + cuda::proclaim_return_type( + [d_offsets = + child_offsets.template begin(), // E.g. [0, 2, 5, 8, 11, 13] + d_groups = + per_row_mapping.template begin(), // E.g. [0,0, 1,1,1, 2,2,2, 3,3,3, 4,4] + d_prev = preceding_iter] __device__(auto i) { + auto group = d_groups[i]; + auto group_start_offset = d_offsets[group]; + auto relative_index = i - group_start_offset; + + return (group - d_prev[group] + 1) + relative_index; + })); return gather_map; } @@ -168,14 +172,16 @@ std::unique_ptr rolling_collect_list(column_view const& input, // column boundaries. // `grouped_rolling_window()` and `time_range_based_grouped_rolling_window() do. auto preceding_begin = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), [preceding_begin_raw] __device__(auto i) { + thrust::make_counting_iterator(0), + cuda::proclaim_return_type([preceding_begin_raw] __device__(auto i) { return thrust::min(preceding_begin_raw[i], i + 1); - }); - auto following_begin = - thrust::make_transform_iterator(thrust::make_counting_iterator(0), - [following_begin_raw, size = input.size()] __device__(auto i) { - return thrust::min(following_begin_raw[i], size - i - 1); - }); + })); + auto following_begin = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type( + [following_begin_raw, size = input.size()] __device__(auto i) { + return thrust::min(following_begin_raw[i], size - i - 1); + })); // Materialize collect list's offsets. auto offsets = diff --git a/cpp/src/rolling/detail/rolling_fixed_window.cu b/cpp/src/rolling/detail/rolling_fixed_window.cu index e951db955e5..07ecf2730a0 100644 --- a/cpp/src/rolling/detail/rolling_fixed_window.cu +++ b/cpp/src/rolling/detail/rolling_fixed_window.cu @@ -23,6 +23,8 @@ #include +#include + namespace cudf::detail { // Applies a fixed-size rolling window function to the values in a column. @@ -63,14 +65,13 @@ std::unique_ptr rolling_window(column_view const& input, // E.g. If preceding_window == 2, then for a column of 5 elements, preceding_window will be: // [1, 2, 2, 2, 1] - auto const preceding_calc = [preceding_window] __device__(size_type i) { - return thrust::min(i + 1, preceding_window); - }; + auto const preceding_calc = cuda::proclaim_return_type( + [preceding_window] __device__(size_type i) { return thrust::min(i + 1, preceding_window); }); - auto const following_calc = [col_size = input.size(), - following_window] __device__(size_type i) { - return thrust::min(col_size - i - 1, following_window); - }; + auto const following_calc = cuda::proclaim_return_type( + [col_size = input.size(), following_window] __device__(size_type i) { + return thrust::min(col_size - i - 1, following_window); + }); auto const preceding_column = expand_to_column(preceding_calc, input.size(), stream); auto const following_column = expand_to_column(following_calc, input.size(), stream); diff --git a/cpp/src/rolling/detail/rolling_variable_window.cu b/cpp/src/rolling/detail/rolling_variable_window.cu index fcddabe54a4..85c5e5cb67e 100644 --- a/cpp/src/rolling/detail/rolling_variable_window.cu +++ b/cpp/src/rolling/detail/rolling_variable_window.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ #include #include +#include + namespace cudf::detail { // Applies a variable-size rolling window function to the values in a column. @@ -64,13 +66,16 @@ std::unique_ptr rolling_window(column_view const& input, // E.g. If preceding_window == [2, 2, 2, 2, 2] for a column of 5 elements, the new // preceding_window will be: [1, 2, 2, 2, 1] auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator( - 0, [preceding = preceding_window.begin()] __device__(size_type i) { - return thrust::min(i + 1, preceding[i]); - }); + 0, + cuda::proclaim_return_type( + [preceding = preceding_window.begin()] __device__(size_type i) { + return thrust::min(i + 1, preceding[i]); + })); auto const following_window_begin = cudf::detail::make_counting_transform_iterator( 0, - [col_size = input.size(), following = following_window.begin()] __device__( - size_type i) { return thrust::min(col_size - i - 1, following[i]); }); + cuda::proclaim_return_type( + [col_size = input.size(), following = following_window.begin()] __device__( + size_type i) { return thrust::min(col_size - i - 1, following[i]); })); return cudf::detail::rolling_window(input, empty_like(defaults_col)->view(), preceding_window_begin, diff --git a/cpp/src/rolling/grouped_rolling.cu b/cpp/src/rolling/grouped_rolling.cu index 7ac784bef43..aa009e47c2a 100644 --- a/cpp/src/rolling/grouped_rolling.cu +++ b/cpp/src/rolling/grouped_rolling.cu @@ -36,6 +36,8 @@ #include #include +#include + namespace cudf { std::unique_ptr grouped_rolling_window(table_view const& group_keys, column_view const& input, @@ -443,75 +445,75 @@ std::unique_ptr range_window_ASC(column_view const& input, auto [h_nulls_begin_idx, h_nulls_end_idx] = get_null_bounds_for_orderby_column(orderby_column); auto const p_orderby_device_view = cudf::column_device_view::create(orderby_column, stream); - auto const preceding_calculator = + auto const preceding_calculator = cuda::proclaim_return_type( [nulls_begin_idx = h_nulls_begin_idx, nulls_end_idx = h_nulls_end_idx, orderby_device_view = *p_orderby_device_view, preceding_window, preceding_window_is_unbounded] __device__(size_type idx) -> size_type { - if (preceding_window_is_unbounded) { - return idx + 1; // Technically `idx - 0 + 1`, - // where 0 == Group start, - // and 1 accounts for the current row - } - if (idx >= nulls_begin_idx && idx < nulls_end_idx) { - // Current row is in the null group. - // Must consider beginning of null-group as window start. - return idx - nulls_begin_idx + 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Binary search the group, excluding null group. - // If nulls_begin_idx == 0, either - // 1. NULLS FIRST ordering: Binary search starts where nulls_end_idx. - // 2. NO NULLS: Binary search starts at 0 (also nulls_end_idx). - // Otherwise, NULLS LAST ordering. Start at 0. - auto const group_start = nulls_begin_idx == 0 ? nulls_end_idx : 0; - auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, preceding_window); - - return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, - d_orderby + group_start, - d_orderby + idx, - lowest_in_window, - cudf::detail::nan_aware_less{})) + - 1; // Add 1, for `preceding` to account for current row. - }; + if (preceding_window_is_unbounded) { + return idx + 1; // Technically `idx - 0 + 1`, + // where 0 == Group start, + // and 1 accounts for the current row + } + if (idx >= nulls_begin_idx && idx < nulls_end_idx) { + // Current row is in the null group. + // Must consider beginning of null-group as window start. + return idx - nulls_begin_idx + 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Binary search the group, excluding null group. + // If nulls_begin_idx == 0, either + // 1. NULLS FIRST ordering: Binary search starts where nulls_end_idx. + // 2. NO NULLS: Binary search starts at 0 (also nulls_end_idx). + // Otherwise, NULLS LAST ordering. Start at 0. + auto const group_start = nulls_begin_idx == 0 ? nulls_end_idx : 0; + auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, preceding_window); + + return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, + d_orderby + group_start, + d_orderby + idx, + lowest_in_window, + cudf::detail::nan_aware_less{})) + + 1; // Add 1, for `preceding` to account for current row. + }); auto const preceding_column = cudf::detail::expand_to_column(preceding_calculator, input.size(), stream); - auto const following_calculator = + auto const following_calculator = cuda::proclaim_return_type( [nulls_begin_idx = h_nulls_begin_idx, nulls_end_idx = h_nulls_end_idx, num_rows = input.size(), orderby_device_view = *p_orderby_device_view, following_window, following_window_is_unbounded] __device__(size_type idx) -> size_type { - if (following_window_is_unbounded) { return num_rows - idx - 1; } - if (idx >= nulls_begin_idx && idx < nulls_end_idx) { - // Current row is in the null group. - // Window ends at the end of the null group. - return nulls_end_idx - idx - 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Binary search the group, excluding null group. - // If nulls_begin_idx == 0, either - // 1. NULLS FIRST ordering: Binary search ends at num_rows. - // 2. NO NULLS: Binary search also ends at num_rows. - // Otherwise, NULLS LAST ordering. End at nulls_begin_idx. - - auto const group_end = nulls_begin_idx == 0 ? num_rows : nulls_begin_idx; - auto const highest_in_window = compute_highest_in_window(d_orderby, idx, following_window); - - return (thrust::upper_bound(thrust::seq, - d_orderby + idx, - d_orderby + group_end, - highest_in_window, - cudf::detail::nan_aware_less{}) - - (d_orderby + idx)) - - 1; - }; + if (following_window_is_unbounded) { return num_rows - idx - 1; } + if (idx >= nulls_begin_idx && idx < nulls_end_idx) { + // Current row is in the null group. + // Window ends at the end of the null group. + return nulls_end_idx - idx - 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Binary search the group, excluding null group. + // If nulls_begin_idx == 0, either + // 1. NULLS FIRST ordering: Binary search ends at num_rows. + // 2. NO NULLS: Binary search also ends at num_rows. + // Otherwise, NULLS LAST ordering. End at nulls_begin_idx. + + auto const group_end = nulls_begin_idx == 0 ? num_rows : nulls_begin_idx; + auto const highest_in_window = compute_highest_in_window(d_orderby, idx, following_window); + + return (thrust::upper_bound(thrust::seq, + d_orderby + idx, + d_orderby + group_end, + highest_in_window, + cudf::detail::nan_aware_less{}) - + (d_orderby + idx)) - + 1; + }); auto const following_column = cudf::detail::expand_to_column(following_calculator, input.size(), stream); @@ -619,7 +621,7 @@ std::unique_ptr range_window_ASC(column_view const& input, get_null_bounds_for_orderby_column(orderby_column, group_offsets, stream); auto const p_orderby_device_view = cudf::column_device_view::create(orderby_column, stream); - auto const preceding_calculator = + auto const preceding_calculator = cuda::proclaim_return_type( [d_group_offsets = group_offsets.data(), d_group_labels = group_labels.data(), orderby_device_view = *p_orderby_device_view, @@ -627,42 +629,42 @@ std::unique_ptr range_window_ASC(column_view const& input, d_nulls_end = null_end.data(), preceding_window, preceding_window_is_unbounded] __device__(size_type idx) -> size_type { - auto const group_label = d_group_labels[idx]; - auto const group_start = d_group_offsets[group_label]; - auto const nulls_begin = d_nulls_begin[group_label]; - auto const nulls_end = d_nulls_end[group_label]; - - if (preceding_window_is_unbounded) { return idx - group_start + 1; } - - // If idx lies in the null-range, the window is the null range. - if (idx >= nulls_begin && idx < nulls_end) { - // Current row is in the null group. - // The window starts at the start of the null group. - return idx - nulls_begin + 1; - } - - auto const d_orderby = begin(orderby_device_view); - - // orderby[idx] not null. Search must exclude the null group. - // If nulls_begin == group_start, either of the following is true: - // 1. NULLS FIRST ordering: Search must begin at nulls_end. - // 2. NO NULLS: Search must begin at group_start (which also equals nulls_end.) - // Otherwise, NULLS LAST ordering. Search must start at nulls group_start. - auto const search_start = nulls_begin == group_start ? nulls_end : group_start; - auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, preceding_window); - - return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, - d_orderby + search_start, - d_orderby + idx, - lowest_in_window, - cudf::detail::nan_aware_less{})) + - 1; // Add 1, for `preceding` to account for current row. - }; + auto const group_label = d_group_labels[idx]; + auto const group_start = d_group_offsets[group_label]; + auto const nulls_begin = d_nulls_begin[group_label]; + auto const nulls_end = d_nulls_end[group_label]; + + if (preceding_window_is_unbounded) { return idx - group_start + 1; } + + // If idx lies in the null-range, the window is the null range. + if (idx >= nulls_begin && idx < nulls_end) { + // Current row is in the null group. + // The window starts at the start of the null group. + return idx - nulls_begin + 1; + } + + auto const d_orderby = begin(orderby_device_view); + + // orderby[idx] not null. Search must exclude the null group. + // If nulls_begin == group_start, either of the following is true: + // 1. NULLS FIRST ordering: Search must begin at nulls_end. + // 2. NO NULLS: Search must begin at group_start (which also equals nulls_end.) + // Otherwise, NULLS LAST ordering. Search must start at nulls group_start. + auto const search_start = nulls_begin == group_start ? nulls_end : group_start; + auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, preceding_window); + + return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, + d_orderby + search_start, + d_orderby + idx, + lowest_in_window, + cudf::detail::nan_aware_less{})) + + 1; // Add 1, for `preceding` to account for current row. + }); auto const preceding_column = cudf::detail::expand_to_column(preceding_calculator, input.size(), stream); - auto const following_calculator = + auto const following_calculator = cuda::proclaim_return_type( [d_group_offsets = group_offsets.data(), d_group_labels = group_labels.data(), orderby_device_view = *p_orderby_device_view, @@ -670,41 +672,41 @@ std::unique_ptr range_window_ASC(column_view const& input, d_nulls_end = null_end.data(), following_window, following_window_is_unbounded] __device__(size_type idx) -> size_type { - auto const group_label = d_group_labels[idx]; - auto const group_start = d_group_offsets[group_label]; - auto const group_end = - d_group_offsets[group_label + 1]; // Cannot fall off the end, since offsets - // is capped with `input.size()`. - auto const nulls_begin = d_nulls_begin[group_label]; - auto const nulls_end = d_nulls_end[group_label]; - - if (following_window_is_unbounded) { return (group_end - idx) - 1; } - - // If idx lies in the null-range, the window is the null range. - if (idx >= nulls_begin && idx < nulls_end) { - // Current row is in the null group. - // The window ends at the end of the null group. - return nulls_end - idx - 1; - } - - auto const d_orderby = begin(orderby_device_view); - - // orderby[idx] not null. Search must exclude the null group. - // If nulls_begin == group_start, either of the following is true: - // 1. NULLS FIRST ordering: Search ends at group_end. - // 2. NO NULLS: Search ends at group_end. - // Otherwise, NULLS LAST ordering. Search ends at nulls_begin. - auto const search_end = nulls_begin == group_start ? group_end : nulls_begin; - auto const highest_in_window = compute_highest_in_window(d_orderby, idx, following_window); - - return (thrust::upper_bound(thrust::seq, - d_orderby + idx, - d_orderby + search_end, - highest_in_window, - cudf::detail::nan_aware_less{}) - - (d_orderby + idx)) - - 1; - }; + auto const group_label = d_group_labels[idx]; + auto const group_start = d_group_offsets[group_label]; + auto const group_end = + d_group_offsets[group_label + 1]; // Cannot fall off the end, since offsets + // is capped with `input.size()`. + auto const nulls_begin = d_nulls_begin[group_label]; + auto const nulls_end = d_nulls_end[group_label]; + + if (following_window_is_unbounded) { return (group_end - idx) - 1; } + + // If idx lies in the null-range, the window is the null range. + if (idx >= nulls_begin && idx < nulls_end) { + // Current row is in the null group. + // The window ends at the end of the null group. + return nulls_end - idx - 1; + } + + auto const d_orderby = begin(orderby_device_view); + + // orderby[idx] not null. Search must exclude the null group. + // If nulls_begin == group_start, either of the following is true: + // 1. NULLS FIRST ordering: Search ends at group_end. + // 2. NO NULLS: Search ends at group_end. + // Otherwise, NULLS LAST ordering. Search ends at nulls_begin. + auto const search_end = nulls_begin == group_start ? group_end : nulls_begin; + auto const highest_in_window = compute_highest_in_window(d_orderby, idx, following_window); + + return (thrust::upper_bound(thrust::seq, + d_orderby + idx, + d_orderby + search_end, + highest_in_window, + cudf::detail::nan_aware_less{}) - + (d_orderby + idx)) - + 1; + }); auto const following_column = cudf::detail::expand_to_column(following_calculator, input.size(), stream); @@ -732,75 +734,75 @@ std::unique_ptr range_window_DESC(column_view const& input, auto [h_nulls_begin_idx, h_nulls_end_idx] = get_null_bounds_for_orderby_column(orderby_column); auto const p_orderby_device_view = cudf::column_device_view::create(orderby_column, stream); - auto const preceding_calculator = + auto const preceding_calculator = cuda::proclaim_return_type( [nulls_begin_idx = h_nulls_begin_idx, nulls_end_idx = h_nulls_end_idx, orderby_device_view = *p_orderby_device_view, preceding_window, preceding_window_is_unbounded] __device__(size_type idx) -> size_type { - if (preceding_window_is_unbounded) { - return idx + 1; // Technically `idx - 0 + 1`, - // where 0 == Group start, - // and 1 accounts for the current row - } - if (idx >= nulls_begin_idx && idx < nulls_end_idx) { - // Current row is in the null group. - // Must consider beginning of null-group as window start. - return idx - nulls_begin_idx + 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Binary search the group, excluding null group. - // If nulls_begin_idx == 0, either - // 1. NULLS FIRST ordering: Binary search starts where nulls_end_idx. - // 2. NO NULLS: Binary search starts at 0 (also nulls_end_idx). - // Otherwise, NULLS LAST ordering. Start at 0. - auto const group_start = nulls_begin_idx == 0 ? nulls_end_idx : 0; - auto const highest_in_window = compute_highest_in_window(d_orderby, idx, preceding_window); - - return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, - d_orderby + group_start, - d_orderby + idx, - highest_in_window, - cudf::detail::nan_aware_greater{})) + - 1; // Add 1, for `preceding` to account for current row. - }; + if (preceding_window_is_unbounded) { + return idx + 1; // Technically `idx - 0 + 1`, + // where 0 == Group start, + // and 1 accounts for the current row + } + if (idx >= nulls_begin_idx && idx < nulls_end_idx) { + // Current row is in the null group. + // Must consider beginning of null-group as window start. + return idx - nulls_begin_idx + 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Binary search the group, excluding null group. + // If nulls_begin_idx == 0, either + // 1. NULLS FIRST ordering: Binary search starts where nulls_end_idx. + // 2. NO NULLS: Binary search starts at 0 (also nulls_end_idx). + // Otherwise, NULLS LAST ordering. Start at 0. + auto const group_start = nulls_begin_idx == 0 ? nulls_end_idx : 0; + auto const highest_in_window = compute_highest_in_window(d_orderby, idx, preceding_window); + + return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, + d_orderby + group_start, + d_orderby + idx, + highest_in_window, + cudf::detail::nan_aware_greater{})) + + 1; // Add 1, for `preceding` to account for current row. + }); auto const preceding_column = cudf::detail::expand_to_column(preceding_calculator, input.size(), stream); - auto const following_calculator = + auto const following_calculator = cuda::proclaim_return_type( [nulls_begin_idx = h_nulls_begin_idx, nulls_end_idx = h_nulls_end_idx, num_rows = input.size(), orderby_device_view = *p_orderby_device_view, following_window, following_window_is_unbounded] __device__(size_type idx) -> size_type { - if (following_window_is_unbounded) { return (num_rows - idx) - 1; } - if (idx >= nulls_begin_idx && idx < nulls_end_idx) { - // Current row is in the null group. - // Window ends at the end of the null group. - return nulls_end_idx - idx - 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Search must exclude null group. - // If nulls_begin_idx = 0, either - // 1. NULLS FIRST ordering: Search ends at num_rows. - // 2. NO NULLS: Search also ends at num_rows. - // Otherwise, NULLS LAST ordering: End at nulls_begin_idx. - - auto const group_end = nulls_begin_idx == 0 ? num_rows : nulls_begin_idx; - auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, following_window); - - return (thrust::upper_bound(thrust::seq, - d_orderby + idx, - d_orderby + group_end, - lowest_in_window, - cudf::detail::nan_aware_greater{}) - - (d_orderby + idx)) - - 1; - }; + if (following_window_is_unbounded) { return (num_rows - idx) - 1; } + if (idx >= nulls_begin_idx && idx < nulls_end_idx) { + // Current row is in the null group. + // Window ends at the end of the null group. + return nulls_end_idx - idx - 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Search must exclude null group. + // If nulls_begin_idx = 0, either + // 1. NULLS FIRST ordering: Search ends at num_rows. + // 2. NO NULLS: Search also ends at num_rows. + // Otherwise, NULLS LAST ordering: End at nulls_begin_idx. + + auto const group_end = nulls_begin_idx == 0 ? num_rows : nulls_begin_idx; + auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, following_window); + + return (thrust::upper_bound(thrust::seq, + d_orderby + idx, + d_orderby + group_end, + lowest_in_window, + cudf::detail::nan_aware_greater{}) - + (d_orderby + idx)) - + 1; + }); auto const following_column = cudf::detail::expand_to_column(following_calculator, input.size(), stream); @@ -828,7 +830,7 @@ std::unique_ptr range_window_DESC(column_view const& input, get_null_bounds_for_orderby_column(orderby_column, group_offsets, stream); auto const p_orderby_device_view = cudf::column_device_view::create(orderby_column, stream); - auto const preceding_calculator = + auto const preceding_calculator = cuda::proclaim_return_type( [d_group_offsets = group_offsets.data(), d_group_labels = group_labels.data(), orderby_device_view = *p_orderby_device_view, @@ -836,41 +838,41 @@ std::unique_ptr range_window_DESC(column_view const& input, d_nulls_end = null_end.data(), preceding_window, preceding_window_is_unbounded] __device__(size_type idx) -> size_type { - auto const group_label = d_group_labels[idx]; - auto const group_start = d_group_offsets[group_label]; - auto const nulls_begin = d_nulls_begin[group_label]; - auto const nulls_end = d_nulls_end[group_label]; - - if (preceding_window_is_unbounded) { return (idx - group_start) + 1; } - - // If idx lies in the null-range, the window is the null range. - if (idx >= nulls_begin && idx < nulls_end) { - // Current row is in the null group. - // The window starts at the start of the null group. - return idx - nulls_begin + 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Search must exclude the null group. - // If nulls_begin == group_start, either of the following is true: - // 1. NULLS FIRST ordering: Search must begin at nulls_end. - // 2. NO NULLS: Search must begin at group_start (which also equals nulls_end.) - // Otherwise, NULLS LAST ordering. Search must start at nulls group_start. - auto const search_start = nulls_begin == group_start ? nulls_end : group_start; - auto const highest_in_window = compute_highest_in_window(d_orderby, idx, preceding_window); - - return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, - d_orderby + search_start, - d_orderby + idx, - highest_in_window, - cudf::detail::nan_aware_greater{})) + - 1; // Add 1, for `preceding` to account for current row. - }; + auto const group_label = d_group_labels[idx]; + auto const group_start = d_group_offsets[group_label]; + auto const nulls_begin = d_nulls_begin[group_label]; + auto const nulls_end = d_nulls_end[group_label]; + + if (preceding_window_is_unbounded) { return (idx - group_start) + 1; } + + // If idx lies in the null-range, the window is the null range. + if (idx >= nulls_begin && idx < nulls_end) { + // Current row is in the null group. + // The window starts at the start of the null group. + return idx - nulls_begin + 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Search must exclude the null group. + // If nulls_begin == group_start, either of the following is true: + // 1. NULLS FIRST ordering: Search must begin at nulls_end. + // 2. NO NULLS: Search must begin at group_start (which also equals nulls_end.) + // Otherwise, NULLS LAST ordering. Search must start at nulls group_start. + auto const search_start = nulls_begin == group_start ? nulls_end : group_start; + auto const highest_in_window = compute_highest_in_window(d_orderby, idx, preceding_window); + + return ((d_orderby + idx) - thrust::lower_bound(thrust::seq, + d_orderby + search_start, + d_orderby + idx, + highest_in_window, + cudf::detail::nan_aware_greater{})) + + 1; // Add 1, for `preceding` to account for current row. + }); auto const preceding_column = cudf::detail::expand_to_column(preceding_calculator, input.size(), stream); - auto const following_calculator = + auto const following_calculator = cuda::proclaim_return_type( [d_group_offsets = group_offsets.data(), d_group_labels = group_labels.data(), orderby_device_view = *p_orderby_device_view, @@ -878,38 +880,38 @@ std::unique_ptr range_window_DESC(column_view const& input, d_nulls_end = null_end.data(), following_window, following_window_is_unbounded] __device__(size_type idx) -> size_type { - auto const group_label = d_group_labels[idx]; - auto const group_start = d_group_offsets[group_label]; - auto const group_end = d_group_offsets[group_label + 1]; - auto const nulls_begin = d_nulls_begin[group_label]; - auto const nulls_end = d_nulls_end[group_label]; - - if (following_window_is_unbounded) { return (group_end - idx) - 1; } - - // If idx lies in the null-range, the window is the null range. - if (idx >= nulls_begin && idx < nulls_end) { - // Current row is in the null group. - // The window ends at the end of the null group. - return nulls_end - idx - 1; - } - - auto const d_orderby = begin(orderby_device_view); - // orderby[idx] not null. Search must exclude the null group. - // If nulls_begin == group_start, either of the following is true: - // 1. NULLS FIRST ordering: Search ends at group_end. - // 2. NO NULLS: Search ends at group_end. - // Otherwise, NULLS LAST ordering. Search ends at nulls_begin. - auto const search_end = nulls_begin == group_start ? group_end : nulls_begin; - auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, following_window); - - return (thrust::upper_bound(thrust::seq, - d_orderby + idx, - d_orderby + search_end, - lowest_in_window, - cudf::detail::nan_aware_greater{}) - - (d_orderby + idx)) - - 1; - }; + auto const group_label = d_group_labels[idx]; + auto const group_start = d_group_offsets[group_label]; + auto const group_end = d_group_offsets[group_label + 1]; + auto const nulls_begin = d_nulls_begin[group_label]; + auto const nulls_end = d_nulls_end[group_label]; + + if (following_window_is_unbounded) { return (group_end - idx) - 1; } + + // If idx lies in the null-range, the window is the null range. + if (idx >= nulls_begin && idx < nulls_end) { + // Current row is in the null group. + // The window ends at the end of the null group. + return nulls_end - idx - 1; + } + + auto const d_orderby = begin(orderby_device_view); + // orderby[idx] not null. Search must exclude the null group. + // If nulls_begin == group_start, either of the following is true: + // 1. NULLS FIRST ordering: Search ends at group_end. + // 2. NO NULLS: Search ends at group_end. + // Otherwise, NULLS LAST ordering. Search ends at nulls_begin. + auto const search_end = nulls_begin == group_start ? group_end : nulls_begin; + auto const lowest_in_window = compute_lowest_in_window(d_orderby, idx, following_window); + + return (thrust::upper_bound(thrust::seq, + d_orderby + idx, + d_orderby + search_end, + lowest_in_window, + cudf::detail::nan_aware_greater{}) - + (d_orderby + idx)) - + 1; + }); auto const following_column = cudf::detail::expand_to_column(following_calculator, input.size(), stream); diff --git a/cpp/src/search/contains_table.cu b/cpp/src/search/contains_table.cu index 43624ba691d..09122b37d6f 100644 --- a/cpp/src/search/contains_table.cu +++ b/cpp/src/search/contains_table.cu @@ -1,286 +1,292 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include - -#include -#include - -#include - -#include - -#include - -namespace cudf::detail { - -namespace { - -using cudf::experimental::row::lhs_index_type; -using cudf::experimental::row::rhs_index_type; - -/** - * @brief An hasher adapter wrapping both haystack hasher and needles hasher - */ -template -struct hasher_adapter { - hasher_adapter(HaystackHasher const& haystack_hasher, NeedleHasher const& needle_hasher) - : _haystack_hasher{haystack_hasher}, _needle_hasher{needle_hasher} - { - } - - __device__ constexpr auto operator()(lhs_index_type idx) const noexcept - { - return _haystack_hasher(static_cast(idx)); - } - - __device__ constexpr auto operator()(rhs_index_type idx) const noexcept - { - return _needle_hasher(static_cast(idx)); - } - - private: - HaystackHasher const _haystack_hasher; - NeedleHasher const _needle_hasher; -}; - -/** - * @brief An comparator adapter wrapping both self comparator and two table comparator - */ -template -struct comparator_adapter { - comparator_adapter(SelfEqual const& self_equal, TwoTableEqual const& two_table_equal) - : _self_equal{self_equal}, _two_table_equal{two_table_equal} - { - } - - __device__ constexpr auto operator()(lhs_index_type lhs_index, - lhs_index_type rhs_index) const noexcept - { - auto const lhs = static_cast(lhs_index); - auto const rhs = static_cast(rhs_index); - - return _self_equal(lhs, rhs); - } - - __device__ constexpr auto operator()(lhs_index_type lhs_index, - rhs_index_type rhs_index) const noexcept - { - return _two_table_equal(lhs_index, rhs_index); - } - - private: - SelfEqual const _self_equal; - TwoTableEqual const _two_table_equal; -}; - -/** - * @brief Build a row bitmask for the input table. - * - * The output bitmask will have invalid bits corresponding to the input rows having nulls (at - * any nested level) and vice versa. - * - * @param input The input table - * @param stream CUDA stream used for device memory operations and kernel launches - * @return A pair of pointer to the output bitmask and the buffer containing the bitmask - */ -std::pair build_row_bitmask(table_view const& input, - rmm::cuda_stream_view stream) -{ - auto const nullable_columns = get_nullable_columns(input); - CUDF_EXPECTS(nullable_columns.size() > 0, - "The input table has nulls thus it should have nullable columns."); - - // If there are more than one nullable column, we compute `bitmask_and` of their null masks. - // Otherwise, we have only one nullable column and can use its null mask directly. - if (nullable_columns.size() > 1) { - auto row_bitmask = - cudf::detail::bitmask_and( - table_view{nullable_columns}, stream, rmm::mr::get_current_device_resource()) - .first; - auto const row_bitmask_ptr = static_cast(row_bitmask.data()); - return std::pair(std::move(row_bitmask), row_bitmask_ptr); - } - - return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask()); -} - -/** - * @brief Invokes the given `func` with desired comparators based on the specified `compare_nans` - * parameter - * - * @tparam HasNested Flag indicating whether there are nested columns in haystack or needles - * @tparam Hasher Type of device hash function - * @tparam Func Type of the helper function doing `contains` check - * - * @param compare_nulls Control whether nulls should be compared as equal or not - * @param compare_nans Control whether floating-point NaNs values should be compared as equal or not - * @param haystack_has_nulls Flag indicating whether haystack has nulls or not - * @param has_any_nulls Flag indicating whether there are nested nulls is either haystack or needles - * @param self_equal Self table comparator - * @param two_table_equal Two table comparator - * @param d_hasher Device hash functor - * @param func The input functor to invoke - */ -template -void dispatch_nan_comparator( - null_equality compare_nulls, - nan_equality compare_nans, - bool haystack_has_nulls, - bool has_any_nulls, - cudf::experimental::row::equality::self_comparator self_equal, - cudf::experimental::row::equality::two_table_comparator two_table_equal, - Hasher const& d_hasher, - Func&& func) -{ - // Distinguish probing scheme CG sizes between nested and flat types for better performance - auto const probing_scheme = [&]() { - if constexpr (HasNested) { - return cuco::experimental::linear_probing<4, Hasher>{d_hasher}; - } else { - return cuco::experimental::linear_probing<1, Hasher>{d_hasher}; - } - }(); - - if (compare_nans == nan_equality::ALL_EQUAL) { - using nan_equal_comparator = - cudf::experimental::row::equality::nan_equal_physical_equality_comparator; - auto const d_self_equal = self_equal.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_equal_comparator{}); - auto const d_two_table_equal = two_table_equal.equal_to( - nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_equal_comparator{}); - func(d_self_equal, d_two_table_equal, probing_scheme); - } else { - using nan_unequal_comparator = cudf::experimental::row::equality::physical_equality_comparator; - auto const d_self_equal = self_equal.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_unequal_comparator{}); - auto const d_two_table_equal = two_table_equal.equal_to( - nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_unequal_comparator{}); - func(d_self_equal, d_two_table_equal, probing_scheme); - } -} - -} // namespace - -rmm::device_uvector contains(table_view const& haystack, - table_view const& needles, - null_equality compare_nulls, - nan_equality compare_nans, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - CUDF_EXPECTS(cudf::have_same_types(haystack, needles), "Column types mismatch"); - - auto const haystack_has_nulls = has_nested_nulls(haystack); - auto const needles_has_nulls = has_nested_nulls(needles); - auto const has_any_nulls = haystack_has_nulls || needles_has_nulls; - - auto const preprocessed_needles = - cudf::experimental::row::equality::preprocessed_table::create(needles, stream); - auto const preprocessed_haystack = - cudf::experimental::row::equality::preprocessed_table::create(haystack, stream); - - auto const haystack_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack); - auto const d_haystack_hasher = haystack_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); - auto const needle_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles); - auto const d_needle_hasher = needle_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); - auto const d_hasher = hasher_adapter{d_haystack_hasher, d_needle_hasher}; - - auto const self_equal = cudf::experimental::row::equality::self_comparator(preprocessed_haystack); - auto const two_table_equal = cudf::experimental::row::equality::two_table_comparator( - preprocessed_haystack, preprocessed_needles); - - // The output vector. - auto contained = rmm::device_uvector(needles.num_rows(), stream, mr); - - auto const haystack_iter = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(auto idx) { return lhs_index_type{idx}; }); - auto const needles_iter = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(auto idx) { return rhs_index_type{idx}; }); - - auto const helper_func = - [&](auto const& d_self_equal, auto const& d_two_table_equal, auto const& probing_scheme) { - auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal}; - - auto set = cuco::experimental::static_set{ - cuco::experimental::extent{compute_hash_table_size(haystack.num_rows())}, - cuco::empty_key{lhs_index_type{-1}}, - d_equal, - probing_scheme, - detail::hash_table_allocator_type{default_allocator{}, stream}, - stream.value()}; - - if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) { - auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream); - auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; - - // If the haystack table has nulls but they are compared unequal, don't insert them. - // Otherwise, it was known to cause performance issue: - // - https://github.com/rapidsai/cudf/pull/6943 - // - https://github.com/rapidsai/cudf/pull/8277 - set.insert_if_async(haystack_iter, - haystack_iter + haystack.num_rows(), - thrust::counting_iterator(0), // stencil - row_is_valid{row_bitmask_ptr}, - stream.value()); - } else { - set.insert_async(haystack_iter, haystack_iter + haystack.num_rows(), stream.value()); - } - - if (needles_has_nulls && compare_nulls == null_equality::UNEQUAL) { - auto const bitmask_buffer_and_ptr = build_row_bitmask(needles, stream); - auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; - set.contains_if_async(needles_iter, - needles_iter + needles.num_rows(), - thrust::counting_iterator(0), // stencil - row_is_valid{row_bitmask_ptr}, - contained.begin(), - stream.value()); - } else { - set.contains_async( - needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value()); - } - }; - - if (cudf::detail::has_nested_columns(haystack)) { - dispatch_nan_comparator(compare_nulls, - compare_nans, - haystack_has_nulls, - has_any_nulls, - self_equal, - two_table_equal, - d_hasher, - helper_func); - } else { - dispatch_nan_comparator(compare_nulls, - compare_nans, - haystack_has_nulls, - has_any_nulls, - self_equal, - two_table_equal, - d_hasher, - helper_func); - } - - return contained; -} - -} // namespace cudf::detail +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +#include + +namespace cudf::detail { + +namespace { + +using cudf::experimental::row::lhs_index_type; +using cudf::experimental::row::rhs_index_type; + +/** + * @brief An hasher adapter wrapping both haystack hasher and needles hasher + */ +template +struct hasher_adapter { + hasher_adapter(HaystackHasher const& haystack_hasher, NeedleHasher const& needle_hasher) + : _haystack_hasher{haystack_hasher}, _needle_hasher{needle_hasher} + { + } + + __device__ constexpr auto operator()(lhs_index_type idx) const noexcept + { + return _haystack_hasher(static_cast(idx)); + } + + __device__ constexpr auto operator()(rhs_index_type idx) const noexcept + { + return _needle_hasher(static_cast(idx)); + } + + private: + HaystackHasher const _haystack_hasher; + NeedleHasher const _needle_hasher; +}; + +/** + * @brief An comparator adapter wrapping both self comparator and two table comparator + */ +template +struct comparator_adapter { + comparator_adapter(SelfEqual const& self_equal, TwoTableEqual const& two_table_equal) + : _self_equal{self_equal}, _two_table_equal{two_table_equal} + { + } + + __device__ constexpr auto operator()(lhs_index_type lhs_index, + lhs_index_type rhs_index) const noexcept + { + auto const lhs = static_cast(lhs_index); + auto const rhs = static_cast(rhs_index); + + return _self_equal(lhs, rhs); + } + + __device__ constexpr auto operator()(lhs_index_type lhs_index, + rhs_index_type rhs_index) const noexcept + { + return _two_table_equal(lhs_index, rhs_index); + } + + private: + SelfEqual const _self_equal; + TwoTableEqual const _two_table_equal; +}; + +/** + * @brief Build a row bitmask for the input table. + * + * The output bitmask will have invalid bits corresponding to the input rows having nulls (at + * any nested level) and vice versa. + * + * @param input The input table + * @param stream CUDA stream used for device memory operations and kernel launches + * @return A pair of pointer to the output bitmask and the buffer containing the bitmask + */ +std::pair build_row_bitmask(table_view const& input, + rmm::cuda_stream_view stream) +{ + auto const nullable_columns = get_nullable_columns(input); + CUDF_EXPECTS(nullable_columns.size() > 0, + "The input table has nulls thus it should have nullable columns."); + + // If there are more than one nullable column, we compute `bitmask_and` of their null masks. + // Otherwise, we have only one nullable column and can use its null mask directly. + if (nullable_columns.size() > 1) { + auto row_bitmask = + cudf::detail::bitmask_and( + table_view{nullable_columns}, stream, rmm::mr::get_current_device_resource()) + .first; + auto const row_bitmask_ptr = static_cast(row_bitmask.data()); + return std::pair(std::move(row_bitmask), row_bitmask_ptr); + } + + return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask()); +} + +/** + * @brief Invokes the given `func` with desired comparators based on the specified `compare_nans` + * parameter + * + * @tparam HasNested Flag indicating whether there are nested columns in haystack or needles + * @tparam Hasher Type of device hash function + * @tparam Func Type of the helper function doing `contains` check + * + * @param compare_nulls Control whether nulls should be compared as equal or not + * @param compare_nans Control whether floating-point NaNs values should be compared as equal or not + * @param haystack_has_nulls Flag indicating whether haystack has nulls or not + * @param has_any_nulls Flag indicating whether there are nested nulls is either haystack or needles + * @param self_equal Self table comparator + * @param two_table_equal Two table comparator + * @param d_hasher Device hash functor + * @param func The input functor to invoke + */ +template +void dispatch_nan_comparator( + null_equality compare_nulls, + nan_equality compare_nans, + bool haystack_has_nulls, + bool has_any_nulls, + cudf::experimental::row::equality::self_comparator self_equal, + cudf::experimental::row::equality::two_table_comparator two_table_equal, + Hasher const& d_hasher, + Func&& func) +{ + // Distinguish probing scheme CG sizes between nested and flat types for better performance + auto const probing_scheme = [&]() { + if constexpr (HasNested) { + return cuco::experimental::linear_probing<4, Hasher>{d_hasher}; + } else { + return cuco::experimental::linear_probing<1, Hasher>{d_hasher}; + } + }(); + + if (compare_nans == nan_equality::ALL_EQUAL) { + using nan_equal_comparator = + cudf::experimental::row::equality::nan_equal_physical_equality_comparator; + auto const d_self_equal = self_equal.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_equal_comparator{}); + auto const d_two_table_equal = two_table_equal.equal_to( + nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_equal_comparator{}); + func(d_self_equal, d_two_table_equal, probing_scheme); + } else { + using nan_unequal_comparator = cudf::experimental::row::equality::physical_equality_comparator; + auto const d_self_equal = self_equal.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_unequal_comparator{}); + auto const d_two_table_equal = two_table_equal.equal_to( + nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_unequal_comparator{}); + func(d_self_equal, d_two_table_equal, probing_scheme); + } +} + +} // namespace + +rmm::device_uvector contains(table_view const& haystack, + table_view const& needles, + null_equality compare_nulls, + nan_equality compare_nans, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_EXPECTS(cudf::have_same_types(haystack, needles), "Column types mismatch"); + + auto const haystack_has_nulls = has_nested_nulls(haystack); + auto const needles_has_nulls = has_nested_nulls(needles); + auto const has_any_nulls = haystack_has_nulls || needles_has_nulls; + + auto const preprocessed_needles = + cudf::experimental::row::equality::preprocessed_table::create(needles, stream); + auto const preprocessed_haystack = + cudf::experimental::row::equality::preprocessed_table::create(haystack, stream); + + auto const haystack_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack); + auto const d_haystack_hasher = haystack_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); + auto const needle_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles); + auto const d_needle_hasher = needle_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); + auto const d_hasher = hasher_adapter{d_haystack_hasher, d_needle_hasher}; + + auto const self_equal = cudf::experimental::row::equality::self_comparator(preprocessed_haystack); + auto const two_table_equal = cudf::experimental::row::equality::two_table_comparator( + preprocessed_haystack, preprocessed_needles); + + // The output vector. + auto contained = rmm::device_uvector(needles.num_rows(), stream, mr); + + auto const haystack_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, cuda::proclaim_return_type([] __device__(auto idx) { + return lhs_index_type{idx}; + })); + auto const needles_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, cuda::proclaim_return_type([] __device__(auto idx) { + return rhs_index_type{idx}; + })); + + auto const helper_func = + [&](auto const& d_self_equal, auto const& d_two_table_equal, auto const& probing_scheme) { + auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal}; + + auto set = cuco::experimental::static_set{ + cuco::experimental::extent{compute_hash_table_size(haystack.num_rows())}, + cuco::empty_key{lhs_index_type{-1}}, + d_equal, + probing_scheme, + detail::hash_table_allocator_type{default_allocator{}, stream}, + stream.value()}; + + if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) { + auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream); + auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; + + // If the haystack table has nulls but they are compared unequal, don't insert them. + // Otherwise, it was known to cause performance issue: + // - https://github.com/rapidsai/cudf/pull/6943 + // - https://github.com/rapidsai/cudf/pull/8277 + set.insert_if_async(haystack_iter, + haystack_iter + haystack.num_rows(), + thrust::counting_iterator(0), // stencil + row_is_valid{row_bitmask_ptr}, + stream.value()); + } else { + set.insert_async(haystack_iter, haystack_iter + haystack.num_rows(), stream.value()); + } + + if (needles_has_nulls && compare_nulls == null_equality::UNEQUAL) { + auto const bitmask_buffer_and_ptr = build_row_bitmask(needles, stream); + auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; + set.contains_if_async(needles_iter, + needles_iter + needles.num_rows(), + thrust::counting_iterator(0), // stencil + row_is_valid{row_bitmask_ptr}, + contained.begin(), + stream.value()); + } else { + set.contains_async( + needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value()); + } + }; + + if (cudf::detail::has_nested_columns(haystack)) { + dispatch_nan_comparator(compare_nulls, + compare_nans, + haystack_has_nulls, + has_any_nulls, + self_equal, + two_table_equal, + d_hasher, + helper_func); + } else { + dispatch_nan_comparator(compare_nulls, + compare_nans, + haystack_has_nulls, + has_any_nulls, + self_equal, + two_table_equal, + d_hasher, + helper_func); + } + + return contained; +} + +} // namespace cudf::detail diff --git a/cpp/src/sort/rank.cu b/cpp/src/sort/rank.cu index 3ead8cfcbaa..9cf07f065d2 100644 --- a/cpp/src/sort/rank.cu +++ b/cpp/src/sort/rank.cu @@ -44,6 +44,9 @@ #include #include +#include +#include + namespace cudf { namespace detail { namespace { @@ -145,11 +148,14 @@ void tie_break_ranks_transform(cudf::device_span dense_rank_sor tie_sorted.begin(), thrust::equal_to{}, tie_breaker); + using TransformerReturnType = + cuda::std::decay_t>; auto sorted_tied_rank = thrust::make_transform_iterator( dense_rank_sorted.begin(), - [tied_rank = tie_sorted.begin(), transformer] __device__(auto dense_pos) { - return transformer(tied_rank[dense_pos - 1]); - }); + cuda::proclaim_return_type( + [tied_rank = tie_sorted.begin(), transformer] __device__(auto dense_pos) { + return transformer(tied_rank[dense_pos - 1]); + })); thrust::scatter(rmm::exec_policy(stream), sorted_tied_rank, sorted_tied_rank + input_size, @@ -245,14 +251,14 @@ void rank_average(cudf::device_span group_keys, cudf::detail::make_counting_transform_iterator(1, index_counter{}), sorted_order_view, rank_mutable_view.begin(), - [] __device__(auto rank_count1, auto rank_count2) { + cuda::proclaim_return_type([] __device__(auto rank_count1, auto rank_count2) { return MinCount{std::min(rank_count1.first, rank_count2.first), rank_count1.second + rank_count2.second}; - }, - [] __device__(MinCount minrank_count) { // min+(count-1)/2 + }), + cuda::proclaim_return_type([] __device__(MinCount minrank_count) { // min+(count-1)/2 return static_cast(thrust::get<0>(minrank_count)) + (static_cast(thrust::get<1>(minrank_count)) - 1) / 2.0; - }, + }), stream); } @@ -348,13 +354,14 @@ std::unique_ptr rank(column_view const& input, (null_handling == null_policy::EXCLUDE) ? input.size() - input.null_count() : input.size(); auto drs = dense_rank_sorted.data(); bool const is_dense = (method == rank_method::DENSE); - thrust::transform(rmm::exec_policy(stream), - rank_iter, - rank_iter + input.size(), - rank_iter, - [is_dense, drs, count] __device__(double r) -> double { - return is_dense ? r / drs[count - 1] : r / count; - }); + thrust::transform( + rmm::exec_policy(stream), + rank_iter, + rank_iter + input.size(), + rank_iter, + cuda::proclaim_return_type([is_dense, drs, count] __device__(double r) -> double { + return is_dense ? r / drs[count - 1] : r / count; + })); } return rank_column; } diff --git a/cpp/src/stream_compaction/distinct.cu b/cpp/src/stream_compaction/distinct.cu index cc1e3423d42..7adce5d3cbc 100644 --- a/cpp/src/stream_compaction/distinct.cu +++ b/cpp/src/stream_compaction/distinct.cu @@ -32,6 +32,8 @@ #include #include +#include + #include #include @@ -66,7 +68,9 @@ rmm::device_uvector get_distinct_indices(table_view const& input, auto const row_comp = cudf::experimental::row::equality::self_comparator(preprocessed_input); auto const pair_iter = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(size_type const i) { return cuco::make_pair(i, i); }); + size_type{0}, + cuda::proclaim_return_type>( + [] __device__(size_type const i) { return cuco::make_pair(i, i); })); auto const insert_keys = [&](auto const value_comp) { if (has_nested_columns) { diff --git a/cpp/src/strings/attributes.cu b/cpp/src/strings/attributes.cu index 8dc150998ee..de51356845c 100644 --- a/cpp/src/strings/attributes.cu +++ b/cpp/src/strings/attributes.cu @@ -42,6 +42,8 @@ #include +#include + namespace cudf { namespace strings { namespace detail { @@ -93,11 +95,11 @@ std::unique_ptr counts_fn(strings_column_view const& strings, thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings.size()), d_lengths, - [d_strings, ufn] __device__(size_type idx) { + cuda::proclaim_return_type([d_strings, ufn] __device__(size_type idx) { return d_strings.is_null(idx) ? 0 : static_cast(ufn(d_strings.element(idx))); - }); + })); results->set_null_count(strings.null_count()); // reset null count return results; } @@ -169,7 +171,8 @@ std::unique_ptr count_characters(strings_column_view const& input, { if ((input.size() == input.null_count()) || ((input.chars_size() / (input.size() - input.null_count())) < AVG_CHAR_BYTES_THRESHOLD)) { - auto ufn = [] __device__(string_view const& d_str) { return d_str.length(); }; + auto ufn = cuda::proclaim_return_type( + [] __device__(string_view const& d_str) { return d_str.length(); }); return counts_fn(input, ufn, stream, mr); } @@ -180,7 +183,8 @@ std::unique_ptr count_bytes(strings_column_view const& input, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto ufn = [] __device__(string_view const& d_str) { return d_str.size_bytes(); }; + auto ufn = cuda::proclaim_return_type( + [] __device__(string_view const& d_str) { return d_str.size_bytes(); }); return counts_fn(input, ufn, stream, mr); } @@ -228,11 +232,11 @@ std::unique_ptr code_points(strings_column_view const& input, thrust::make_counting_iterator(0), thrust::make_counting_iterator(input.size()), offsets.begin() + 1, - [d_column] __device__(size_type idx) { + cuda::proclaim_return_type([d_column] __device__(size_type idx) { size_type length = 0; if (!d_column.is_null(idx)) length = d_column.element(idx).length(); return length; - }, + }), thrust::plus()); offsets.set_element_to_zero_async(0, stream); diff --git a/cpp/src/strings/copying/copying.cu b/cpp/src/strings/copying/copying.cu index e6796c2209b..2295a80ff5b 100644 --- a/cpp/src/strings/copying/copying.cu +++ b/cpp/src/strings/copying/copying.cu @@ -26,6 +26,8 @@ #include +#include + namespace cudf { namespace strings { namespace detail { @@ -58,7 +60,8 @@ std::unique_ptr copy_slice(strings_column_view const& strings, d_offsets.begin(), d_offsets.end(), d_offsets.begin(), - [chars_offset] __device__(auto offset) { return offset - chars_offset; }); + cuda::proclaim_return_type( + [chars_offset] __device__(auto offset) { return offset - chars_offset; })); } // slice the chars child column diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index 8edcd167e5c..9af1e54fe66 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -36,6 +36,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -110,12 +112,12 @@ std::unique_ptr
extract(strings_column_view const& input, std::vector> results(groups); auto make_strings_lambda = [&](size_type column_index) { // this iterator transposes the extract results into column order - auto indices_itr = - thrust::make_permutation_iterator(indices.begin(), - cudf::detail::make_counting_transform_iterator( - 0, [column_index, groups] __device__(size_type idx) { - return (idx * groups) + column_index; - })); + auto indices_itr = thrust::make_permutation_iterator( + indices.begin(), + cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([column_index, groups] __device__(size_type idx) { + return (idx * groups) + column_index; + }))); return make_strings_column(indices_itr, indices_itr + input.size(), stream, mr); }; diff --git a/cpp/src/strings/filling/fill.cu b/cpp/src/strings/filling/fill.cu index 0c4dd538c74..997749586f0 100644 --- a/cpp/src/strings/filling/fill.cu +++ b/cpp/src/strings/filling/fill.cu @@ -31,6 +31,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -72,11 +74,12 @@ std::unique_ptr fill(strings_column_view const& strings, rmm::device_buffer& null_mask = valid_mask.first; // build offsets column - auto offsets_transformer = [d_strings, begin, end, d_value] __device__(size_type idx) { - if (((begin <= idx) && (idx < end)) ? !d_value.is_valid() : d_strings.is_null(idx)) return 0; - return ((begin <= idx) && (idx < end)) ? d_value.size() - : d_strings.element(idx).size_bytes(); - }; + auto offsets_transformer = cuda::proclaim_return_type( + [d_strings, begin, end, d_value] __device__(size_type idx) { + if (((begin <= idx) && (idx < end)) ? !d_value.is_valid() : d_strings.is_null(idx)) return 0; + return ((begin <= idx) && (idx < end)) ? d_value.size() + : d_strings.element(idx).size_bytes(); + }); auto offsets_transformer_itr = thrust::make_transform_iterator( thrust::make_counting_iterator(0), offsets_transformer); auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( diff --git a/cpp/src/strings/replace/multi.cu b/cpp/src/strings/replace/multi.cu index f80ace57c69..28736c2ca15 100644 --- a/cpp/src/strings/replace/multi.cu +++ b/cpp/src/strings/replace/multi.cu @@ -45,6 +45,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -304,7 +306,9 @@ std::unique_ptr replace_character_parallel(strings_column_view const& in auto string_indices = rmm::device_uvector(target_count, stream); auto const pos_itr = cudf::detail::make_counting_transform_iterator( - 0, [d_positions] __device__(auto idx) -> size_type { return d_positions[idx].first; }); + 0, cuda::proclaim_return_type([d_positions] __device__(auto idx) -> size_type { + return d_positions[idx].first; + })); auto pos_count = std::distance(d_positions, copy_end); thrust::upper_bound(rmm::exec_policy(stream), @@ -346,9 +350,10 @@ std::unique_ptr replace_character_parallel(strings_column_view const& in thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), counts.begin(), - [fn, d_positions, d_targets_offsets] __device__(size_type idx) -> size_type { - return fn.count_strings(idx, d_positions, d_targets_offsets); - }); + cuda::proclaim_return_type( + [fn, d_positions, d_targets_offsets] __device__(size_type idx) -> size_type { + return fn.count_strings(idx, d_positions, d_targets_offsets); + })); // create offsets from the counts auto offsets = diff --git a/cpp/src/strings/replace/replace.cu b/cpp/src/strings/replace/replace.cu index a6a14f27dec..aa955d3086e 100644 --- a/cpp/src/strings/replace/replace.cu +++ b/cpp/src/strings/replace/replace.cu @@ -45,6 +45,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -353,11 +355,12 @@ size_type filter_maxrepl_target_positions(size_type* d_target_positions, size_type max_repl_per_row, rmm::cuda_stream_view stream) { - auto pos_to_row_fn = [d_offsets_span] __device__(size_type target_pos) -> size_type { - auto upper_bound = - thrust::upper_bound(thrust::seq, d_offsets_span.begin(), d_offsets_span.end(), target_pos); - return thrust::distance(d_offsets_span.begin(), upper_bound); - }; + auto pos_to_row_fn = cuda::proclaim_return_type( + [d_offsets_span] __device__(size_type target_pos) -> size_type { + auto upper_bound = + thrust::upper_bound(thrust::seq, d_offsets_span.begin(), d_offsets_span.end(), target_pos); + return thrust::distance(d_offsets_span.begin(), upper_bound); + }); // compute the match count per row for each target position rmm::device_uvector match_counts(target_count, stream); @@ -467,15 +470,15 @@ std::unique_ptr replace_char_parallel(strings_column_view const& strings auto offsets_view = offsets_column->mutable_view(); auto delta_per_target = d_repl.size_bytes() - target_size; device_span d_target_positions_span(d_target_positions, target_count); - auto offsets_update_fn = + auto offsets_update_fn = cuda::proclaim_return_type( [d_target_positions_span, delta_per_target, chars_start] __device__(int32_t offset) -> int32_t { - // determine the number of target positions occurring before this offset - size_type const* next_target_pos_ptr = thrust::lower_bound( - thrust::seq, d_target_positions_span.begin(), d_target_positions_span.end(), offset); - size_type num_prev_targets = - thrust::distance(d_target_positions_span.data(), next_target_pos_ptr); - return offset - chars_start + delta_per_target * num_prev_targets; - }; + // determine the number of target positions occurring before this offset + size_type const* next_target_pos_ptr = thrust::lower_bound( + thrust::seq, d_target_positions_span.begin(), d_target_positions_span.end(), offset); + size_type num_prev_targets = + thrust::distance(d_target_positions_span.data(), next_target_pos_ptr); + return offset - chars_start + delta_per_target * num_prev_targets; + }); thrust::transform(rmm::exec_policy(stream), d_offsets_span.begin(), d_offsets_span.end(), @@ -720,10 +723,11 @@ std::unique_ptr replace_nulls(strings_column_view const& strings, // build offsets column auto offsets_transformer_itr = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), [d_strings, d_repl] __device__(size_type idx) { + thrust::make_counting_iterator(0), + cuda::proclaim_return_type([d_strings, d_repl] __device__(size_type idx) { return d_strings.is_null(idx) ? d_repl.size_bytes() : d_strings.element(idx).size_bytes(); - }); + })); auto [offsets_column, bytes] = cudf::detail::make_offsets_child_column( offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); auto d_offsets = offsets_column->view().data(); diff --git a/cpp/src/strings/split/split.cu b/cpp/src/strings/split/split.cu index bad7eef4523..c87c36ba3b9 100644 --- a/cpp/src/strings/split/split.cu +++ b/cpp/src/strings/split/split.cu @@ -40,6 +40,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -129,18 +131,22 @@ std::unique_ptr
split_fn(strings_column_view const& input, rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(input.size()), - [d_offsets] __device__(auto idx) -> size_type { return d_offsets[idx + 1] - d_offsets[idx]; }, + cuda::proclaim_return_type([d_offsets] __device__(auto idx) -> size_type { + return d_offsets[idx + 1] - d_offsets[idx]; + }), 0, thrust::maximum{}); // build strings columns for each token position for (size_type col = 0; col < columns_count; ++col) { auto itr = cudf::detail::make_counting_transform_iterator( - 0, [d_tokens, d_offsets, col] __device__(size_type idx) { - auto const offset = d_offsets[idx]; - auto const token_count = d_offsets[idx + 1] - offset; - return (col < token_count) ? d_tokens[offset + col] : string_index_pair{nullptr, 0}; - }); + 0, + cuda::proclaim_return_type( + [d_tokens, d_offsets, col] __device__(size_type idx) { + auto const offset = d_offsets[idx]; + auto const token_count = d_offsets[idx + 1] - offset; + return (col < token_count) ? d_tokens[offset + col] : string_index_pair{nullptr, 0}; + })); results.emplace_back(make_strings_column(itr, itr + input.size(), stream, mr)); } @@ -334,7 +340,9 @@ std::unique_ptr
whitespace_split_fn(size_type strings_count, thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_token_counts, - [tokenizer] __device__(size_type idx) { return tokenizer.count_tokens(idx); }); + cuda::proclaim_return_type([tokenizer] __device__(size_type idx) { + return tokenizer.count_tokens(idx); + })); // column count is the maximum number of tokens for any string size_type const columns_count = thrust::reduce( diff --git a/cpp/src/strings/split/split_record.cu b/cpp/src/strings/split/split_record.cu index 7a0cfb9ef41..64061aba4fd 100644 --- a/cpp/src/strings/split/split_record.cu +++ b/cpp/src/strings/split/split_record.cu @@ -35,6 +35,8 @@ #include #include +#include + namespace cudf { namespace strings { namespace detail { @@ -142,7 +144,9 @@ std::unique_ptr whitespace_split_record_fn(strings_column_view const& in { // create offsets column by counting the number of tokens per string auto sizes_itr = cudf::detail::make_counting_transform_iterator( - 0, [reader] __device__(auto idx) { return reader.count_tokens(idx); }); + 0, cuda::proclaim_return_type([reader] __device__(auto idx) { + return reader.count_tokens(idx); + })); auto [offsets, total_tokens] = cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + input.size(), stream, mr); auto d_offsets = offsets->view().template data(); diff --git a/cpp/src/text/bpe/load_merge_pairs.cu b/cpp/src/text/bpe/load_merge_pairs.cu index 6d223a7ddb7..c07d929e98a 100644 --- a/cpp/src/text/bpe/load_merge_pairs.cu +++ b/cpp/src/text/bpe/load_merge_pairs.cu @@ -33,6 +33,8 @@ #include #include +#include + namespace nvtext { namespace detail { namespace { @@ -50,7 +52,9 @@ std::unique_ptr initialize_merge_pairs_map( stream.value()); auto iter = cudf::detail::make_counting_transform_iterator( - 0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); }); + 0, + cuda::proclaim_return_type>( + [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); })); merge_pairs_map->insert_async(iter, iter + (input.size() / 2), stream.value()); @@ -70,7 +74,9 @@ std::unique_ptr initialize_mp_table_map( stream.value()); auto iter = cudf::detail::make_counting_transform_iterator( - 0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); }); + 0, + cuda::proclaim_return_type>( + [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); })); mp_table_map->insert_async(iter, iter + input.size(), stream.value()); diff --git a/cpp/src/text/generate_ngrams.cu b/cpp/src/text/generate_ngrams.cu index 5f2f4d021a4..31e2405ce88 100644 --- a/cpp/src/text/generate_ngrams.cu +++ b/cpp/src/text/generate_ngrams.cu @@ -39,6 +39,8 @@ #include #include +#include + namespace nvtext { namespace detail { namespace { @@ -219,11 +221,12 @@ std::unique_ptr generate_character_ngrams(cudf::strings_column_vie thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count + 1), ngram_offsets.begin(), - [d_strings, strings_count, ngrams] __device__(auto idx) { - if (d_strings.is_null(idx) || (idx == strings_count)) return 0; - auto const length = d_strings.element(idx).length(); - return std::max(0, static_cast(length + 1 - ngrams)); - }, + cuda::proclaim_return_type( + [d_strings, strings_count, ngrams] __device__(auto idx) { + if (d_strings.is_null(idx) || (idx == strings_count)) return 0; + auto const length = d_strings.element(idx).length(); + return std::max(0, static_cast(length + 1 - ngrams)); + }), cudf::size_type{0}, thrust::plus()); @@ -287,11 +290,13 @@ std::unique_ptr hash_character_ngrams(cudf::strings_column_view co // build offsets column by computing the number of ngrams per string auto sizes_itr = cudf::detail::make_counting_transform_iterator( - 0, [d_strings = *d_strings, ngrams] __device__(auto idx) { - if (d_strings.is_null(idx)) { return 0; } - auto const length = d_strings.element(idx).length(); - return std::max(0, static_cast(length + 1 - ngrams)); - }); + 0, + cuda::proclaim_return_type( + [d_strings = *d_strings, ngrams] __device__(auto idx) { + if (d_strings.is_null(idx)) { return 0; } + auto const length = d_strings.element(idx).length(); + return std::max(0, static_cast(length + 1 - ngrams)); + })); auto [offsets, total_ngrams] = cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + input.size(), stream, mr); auto d_offsets = offsets->view().data(); diff --git a/cpp/src/text/ngrams_tokenize.cu b/cpp/src/text/ngrams_tokenize.cu index 73d85513e95..bc5cd04eac6 100644 --- a/cpp/src/text/ngrams_tokenize.cu +++ b/cpp/src/text/ngrams_tokenize.cu @@ -39,6 +39,8 @@ #include #include +#include + #include namespace nvtext { @@ -193,10 +195,11 @@ std::unique_ptr ngrams_tokenize(cudf::strings_column_view const& s thrust::make_counting_iterator(0), thrust::make_counting_iterator(strings_count), d_ngram_offsets + 1, - [d_token_offsets, ngrams] __device__(cudf::size_type idx) { - auto token_count = d_token_offsets[idx + 1] - d_token_offsets[idx]; - return (token_count >= ngrams) ? token_count - ngrams + 1 : 0; - }, + cuda::proclaim_return_type( + [d_token_offsets, ngrams] __device__(cudf::size_type idx) { + auto token_count = d_token_offsets[idx + 1] - d_token_offsets[idx]; + return (token_count >= ngrams) ? token_count - ngrams + 1 : 0; + }), thrust::plus{}); ngram_offsets.set_element_to_zero_async(0, stream); auto const total_ngrams = ngram_offsets.back_element(stream); diff --git a/cpp/tests/utilities/column_utilities.cu b/cpp/tests/utilities/column_utilities.cu index f54ea28d9b2..3486d2102eb 100644 --- a/cpp/tests/utilities/column_utilities.cu +++ b/cpp/tests/utilities/column_utilities.cu @@ -49,6 +49,8 @@ #include #include +#include + #include #include @@ -117,17 +119,17 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, // compute total # of child row indices we will be emitting. auto row_size_iter = cudf::detail::make_counting_transform_iterator( 0, - [row_indices = row_indices.begin(), - validity = c.null_mask(), - offsets = c.offsets().begin(), - offset = c.offset()] __device__(int index) { + cuda::proclaim_return_type([row_indices = row_indices.begin(), + validity = c.null_mask(), + offsets = c.offsets().begin(), + offset = c.offset()] __device__(int index) { // both null mask and offsets data are not pre-sliced. so we need to add the column offset to // every incoming index. auto const true_index = row_indices[index] + offset; return !validity || cudf::bit_is_set(validity, true_index) ? (offsets[true_index + 1] - offsets[true_index]) : 0; - }); + })); auto const output_size = thrust::reduce(rmm::exec_policy(cudf::test::get_default_stream()), row_size_iter, row_size_iter + row_indices.size()); @@ -155,7 +157,7 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, thrust::generate(rmm::exec_policy(cudf::test::get_default_stream()), result->mutable_view().begin(), result->mutable_view().end(), - [] __device__() { return 1; }); + cuda::proclaim_return_type([] __device__() { return 1; })); // scatter the output row positions into result buffer // @@ -163,14 +165,15 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, // auto output_row_iter = cudf::detail::make_counting_transform_iterator( 0, - [row_indices = row_indices.begin(), - offsets = c.offsets().begin(), - offset = c.offset(), - first_offset = cudf::detail::get_value( - c.offsets(), c.offset(), cudf::test::get_default_stream())] __device__(int index) { - auto const true_index = row_indices[index] + offset; - return offsets[true_index] - first_offset; - }); + cuda::proclaim_return_type( + [row_indices = row_indices.begin(), + offsets = c.offsets().begin(), + offset = c.offset(), + first_offset = cudf::detail::get_value( + c.offsets(), c.offset(), cudf::test::get_default_stream())] __device__(int index) { + auto const true_index = row_indices[index] + offset; + return offsets[true_index] - first_offset; + })); thrust::scatter_if(rmm::exec_policy(cudf::test::get_default_stream()), output_row_iter, output_row_iter + row_indices.size(), @@ -188,7 +191,7 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, thrust::generate(rmm::exec_policy(cudf::test::get_default_stream()), keys->mutable_view().begin(), keys->mutable_view().end(), - [] __device__() { return 0; }); + cuda::proclaim_return_type([] __device__() { return 0; })); thrust::scatter_if(rmm::exec_policy(cudf::test::get_default_stream()), row_size_iter, row_size_iter + row_indices.size(), @@ -244,14 +247,14 @@ struct column_property_comparator { { auto validity_iter = cudf::detail::make_counting_transform_iterator( 0, - [row_indices = row_indices.begin(), - validity = c.null_mask(), - offset = c.offset()] __device__(int index) { + cuda::proclaim_return_type([row_indices = row_indices.begin(), + validity = c.null_mask(), + offset = c.offset()] __device__(int index) { // both null mask and offsets data are not pre-sliced. so we need to add the column offset // to every incoming index. auto const true_index = row_indices[index] + offset; return !validity || cudf::bit_is_set(validity, true_index) ? 0 : 1; - }); + })); return thrust::reduce(rmm::exec_policy(cudf::test::get_default_stream()), validity_iter, validity_iter + row_indices.size()); @@ -623,24 +626,28 @@ struct column_comparator_impl { lhs_l.offsets(), lhs_l.offset(), cudf::test::get_default_stream()); auto lhs_offsets = thrust::make_transform_iterator( lhs_l.offsets().begin() + lhs_l.offset(), - [lhs_shift] __device__(size_type offset) { return offset - lhs_shift; }); + cuda::proclaim_return_type( + [lhs_shift] __device__(size_type offset) { return offset - lhs_shift; })); auto lhs_valids = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [mask = lhs_l.null_mask(), offset = lhs_l.offset()] __device__(size_type index) { - return mask == nullptr ? true : cudf::bit_is_set(mask, index + offset); - }); + cuda::proclaim_return_type( + [mask = lhs_l.null_mask(), offset = lhs_l.offset()] __device__(size_type index) { + return mask == nullptr ? true : cudf::bit_is_set(mask, index + offset); + })); // right side size_type rhs_shift = cudf::detail::get_value( rhs_l.offsets(), rhs_l.offset(), cudf::test::get_default_stream()); auto rhs_offsets = thrust::make_transform_iterator( rhs_l.offsets().begin() + rhs_l.offset(), - [rhs_shift] __device__(size_type offset) { return offset - rhs_shift; }); + cuda::proclaim_return_type( + [rhs_shift] __device__(size_type offset) { return offset - rhs_shift; })); auto rhs_valids = thrust::make_transform_iterator( thrust::make_counting_iterator(0), - [mask = rhs_l.null_mask(), offset = rhs_l.offset()] __device__(size_type index) { - return mask == nullptr ? true : cudf::bit_is_set(mask, index + offset); - }); + cuda::proclaim_return_type( + [mask = rhs_l.null_mask(), offset = rhs_l.offset()] __device__(size_type index) { + return mask == nullptr ? true : cudf::bit_is_set(mask, index + offset); + })); // when checking for equivalency, we can't compare offset values directly, we can only // compare lengths of the rows, and only if valid. as a concrete example, you could have two diff --git a/java/src/main/native/src/aggregation128_utils.cu b/java/src/main/native/src/aggregation128_utils.cu index 101a2ed2e2c..d722aaa84fe 100644 --- a/java/src/main/native/src/aggregation128_utils.cu +++ b/java/src/main/native/src/aggregation128_utils.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -86,8 +87,9 @@ std::unique_ptr extract_chunk32(cudf::column_view const &in_col, c auto const in_begin = in_col.begin(); // Build an iterator for every fourth 32-bit value, i.e.: one "chunk" of a __int128_t value - thrust::transform_iterator transform_iter{thrust::counting_iterator{0}, - [] __device__(auto i) { return i * 4; }}; + thrust::transform_iterator transform_iter{ + thrust::counting_iterator{0}, + cuda::proclaim_return_type([] __device__(auto i) { return i * 4; })}; thrust::permutation_iterator stride_iter{in_begin + chunk_idx, transform_iter}; thrust::copy(rmm::exec_policy(stream), stride_iter, stride_iter + num_rows, diff --git a/java/src/main/native/src/row_conversion.cu b/java/src/main/native/src/row_conversion.cu index e5d3930843f..fd7e7bc0b31 100644 --- a/java/src/main/native/src/row_conversion.cu +++ b/java/src/main/native/src/row_conversion.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -244,11 +245,12 @@ build_string_row_offsets(table_view const &tbl, size_type fixed_width_and_validi }); // transform the row sizes to include fixed width size and alignment - thrust::transform(rmm::exec_policy(stream), d_row_sizes.begin(), d_row_sizes.end(), - d_row_sizes.begin(), [fixed_width_and_validity_size] __device__(auto row_size) { - return util::round_up_unsafe(fixed_width_and_validity_size + row_size, - JCUDF_ROW_ALIGNMENT); - }); + thrust::transform( + rmm::exec_policy(stream), d_row_sizes.begin(), d_row_sizes.end(), d_row_sizes.begin(), + cuda::proclaim_return_type([fixed_width_and_validity_size] __device__( + auto row_size) { + return util::round_up_unsafe(fixed_width_and_validity_size + row_size, JCUDF_ROW_ALIGNMENT); + })); return {std::move(d_row_sizes), std::move(d_offsets_iterators)}; } @@ -1496,9 +1498,10 @@ batch_data build_batches(size_type num_rows, RowSize row_sizes, bool all_fixed_w while (last_row_end < num_rows) { auto offset_row_sizes = thrust::make_transform_iterator( cumulative_row_sizes.begin(), - [last_row_end, cumulative_row_sizes = cumulative_row_sizes.data()] __device__(auto i) { - return i - cumulative_row_sizes[last_row_end]; - }); + cuda::proclaim_return_type( + [last_row_end, cumulative_row_sizes = cumulative_row_sizes.data()] __device__(auto i) { + return i - cumulative_row_sizes[last_row_end]; + })); auto search_start = offset_row_sizes + last_row_end; auto search_end = offset_row_sizes + num_rows; @@ -1559,14 +1562,15 @@ int compute_tile_counts(device_span const &batch_row_boundaries size_type const num_batches = batch_row_boundaries.size() - 1; device_uvector num_tiles(num_batches, stream); auto iter = thrust::make_counting_iterator(0); - thrust::transform(rmm::exec_policy(stream), iter, iter + num_batches, num_tiles.begin(), - [desired_tile_height, - batch_row_boundaries = - batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { - return util::div_rounding_up_unsafe(batch_row_boundaries[batch_index + 1] - - batch_row_boundaries[batch_index], - desired_tile_height); - }); + thrust::transform( + rmm::exec_policy(stream), iter, iter + num_batches, num_tiles.begin(), + cuda::proclaim_return_type( + [desired_tile_height, batch_row_boundaries = batch_row_boundaries.data()] __device__( + auto batch_index) -> size_type { + return util::div_rounding_up_unsafe(batch_row_boundaries[batch_index + 1] - + batch_row_boundaries[batch_index], + desired_tile_height); + })); return thrust::reduce(rmm::exec_policy(stream), num_tiles.begin(), num_tiles.end()); } @@ -1590,53 +1594,56 @@ build_tiles(device_span tiles, size_type const num_batches = batch_row_boundaries.size() - 1; device_uvector num_tiles(num_batches, stream); auto iter = thrust::make_counting_iterator(0); - thrust::transform(rmm::exec_policy(stream), iter, iter + num_batches, num_tiles.begin(), - [desired_tile_height, - batch_row_boundaries = - batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { - return util::div_rounding_up_unsafe(batch_row_boundaries[batch_index + 1] - - batch_row_boundaries[batch_index], - desired_tile_height); - }); + thrust::transform( + rmm::exec_policy(stream), iter, iter + num_batches, num_tiles.begin(), + cuda::proclaim_return_type( + [desired_tile_height, batch_row_boundaries = batch_row_boundaries.data()] __device__( + auto batch_index) -> size_type { + return util::div_rounding_up_unsafe(batch_row_boundaries[batch_index + 1] - + batch_row_boundaries[batch_index], + desired_tile_height); + })); size_type const total_tiles = thrust::reduce(rmm::exec_policy(stream), num_tiles.begin(), num_tiles.end()); device_uvector tile_starts(num_batches + 1, stream); auto tile_iter = cudf::detail::make_counting_transform_iterator( - 0, [num_tiles = num_tiles.data(), num_batches] __device__(auto i) { - return (i < num_batches) ? num_tiles[i] : 0; - }); + 0, cuda::proclaim_return_type( + [num_tiles = num_tiles.data(), num_batches] __device__(auto i) { + return (i < num_batches) ? num_tiles[i] : 0; + })); thrust::exclusive_scan(rmm::exec_policy(stream), tile_iter, tile_iter + num_batches + 1, tile_starts.begin()); // in tiles thrust::transform( rmm::exec_policy(stream), iter, iter + total_tiles, tiles.begin(), - [=, tile_starts = tile_starts.data(), - batch_row_boundaries = batch_row_boundaries.data()] __device__(size_type tile_index) { - // what batch this tile falls in - auto const batch_index_iter = - thrust::upper_bound(thrust::seq, tile_starts, tile_starts + num_batches, tile_index); - auto const batch_index = std::distance(tile_starts, batch_index_iter) - 1; - // local index within the tile - int const local_tile_index = tile_index - tile_starts[batch_index]; - // the start row for this batch. - int const batch_row_start = batch_row_boundaries[batch_index]; - // the start row for this tile - int const tile_row_start = batch_row_start + (local_tile_index * desired_tile_height); - // the end row for this tile - int const max_row = - std::min(total_number_of_rows - 1, - batch_index + 1 > num_batches ? - std::numeric_limits::max() : - static_cast(batch_row_boundaries[batch_index + 1]) - 1); - int const tile_row_end = - std::min(batch_row_start + ((local_tile_index + 1) * desired_tile_height) - 1, max_row); - - // stuff the tile - return tile_info{column_start, tile_row_start, column_end, tile_row_end, - static_cast(batch_index)}; - }); + cuda::proclaim_return_type( + [=, tile_starts = tile_starts.data(), + batch_row_boundaries = batch_row_boundaries.data()] __device__(size_type tile_index) { + // what batch this tile falls in + auto const batch_index_iter = thrust::upper_bound( + thrust::seq, tile_starts, tile_starts + num_batches, tile_index); + auto const batch_index = std::distance(tile_starts, batch_index_iter) - 1; + // local index within the tile + int const local_tile_index = tile_index - tile_starts[batch_index]; + // the start row for this batch. + int const batch_row_start = batch_row_boundaries[batch_index]; + // the start row for this tile + int const tile_row_start = batch_row_start + (local_tile_index * desired_tile_height); + // the end row for this tile + int const max_row = + std::min(total_number_of_rows - 1, + batch_index + 1 > num_batches ? + std::numeric_limits::max() : + static_cast(batch_row_boundaries[batch_index + 1]) - 1); + int const tile_row_end = std::min( + batch_row_start + ((local_tile_index + 1) * desired_tile_height) - 1, max_row); + + // stuff the tile + return tile_info{column_start, tile_row_start, column_end, tile_row_end, + static_cast(batch_index)}; + })); return total_tiles; } @@ -2151,7 +2158,8 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_batches), gpu_batch_row_boundaries.begin(), - [num_rows] __device__(auto i) { return i == 0 ? 0 : num_rows; }); + cuda::proclaim_return_type( + [num_rows] __device__(auto i) { return i == 0 ? 0 : num_rows; })); int info_count = 0; detail::determine_tiles( @@ -2222,9 +2230,10 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, std::vector string_data_col_ptrs; for (auto &col_string_lengths : string_lengths) { device_uvector output_string_offsets(num_rows + 1, stream, mr); - auto tmp = [num_rows, col_string_lengths] __device__(auto const &i) { - return i < num_rows ? col_string_lengths[i] : 0; - }; + auto tmp = cuda::proclaim_return_type( + [num_rows, col_string_lengths] __device__(auto const &i) { + return i < num_rows ? col_string_lengths[i] : 0; + }); auto bounded_iter = cudf::detail::make_counting_transform_iterator(0, tmp); thrust::exclusive_scan(rmm::exec_policy(stream), bounded_iter, bounded_iter + num_rows + 1, output_string_offsets.begin());