From baeb07d5cd8d1578d083c81a1a89cd0381cb224c Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 15 Apr 2022 22:46:24 -0700 Subject: [PATCH] Add strong index type. --- .../cudf/detail/utilities/strong_index.hpp | 117 ++++++++++++++++++ cpp/include/cudf/table/row_operators.cuh | 26 ++-- cpp/src/groupby/sort/group_nunique.cu | 15 ++- cpp/src/transform/one_hot_encode.cu | 7 +- 4 files changed, 148 insertions(+), 17 deletions(-) create mode 100644 cpp/include/cudf/detail/utilities/strong_index.hpp diff --git a/cpp/include/cudf/detail/utilities/strong_index.hpp b/cpp/include/cudf/detail/utilities/strong_index.hpp new file mode 100644 index 00000000000..6655c6a1e5f --- /dev/null +++ b/cpp/include/cudf/detail/utilities/strong_index.hpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { + +namespace detail { + +enum class index_type_name { + LEFT, + RIGHT, +}; + +template +struct strong_index { + public: + constexpr explicit strong_index(size_type index) : _index(index) {} + + constexpr explicit operator size_type() const { return _index; } + + constexpr size_type value() const { return _index; } + + constexpr strong_index operator+(size_type const& v) const { return strong_index(_index + v); } + constexpr strong_index operator-(size_type const& v) const { return strong_index(_index - v); } + constexpr strong_index operator*(size_type const& v) const { return strong_index(_index * v); } + + constexpr bool operator==(size_type v) const { return _index == v; } + constexpr bool operator!=(size_type v) const { return _index != v; } + constexpr bool operator<=(size_type v) const { return _index <= v; } + constexpr bool operator>=(size_type v) const { return _index >= v; } + constexpr bool operator<(size_type v) const { return _index < v; } + constexpr bool operator>(size_type v) const { return _index > v; } + + constexpr strong_index& operator=(strong_index const& s) + { + _index = s._index; + return *this; + } + constexpr strong_index& operator=(size_type const& i) + { + _index = i; + return *this; + } + constexpr strong_index& operator+=(size_type const& v) + { + _index += v; + return *this; + } + constexpr strong_index& operator-=(size_type const& v) + { + _index -= v; + return *this; + } + constexpr strong_index& operator*=(size_type const& v) + { + _index *= v; + return *this; + } + + constexpr strong_index& operator++() + { + ++_index; + return *this; + } + constexpr strong_index operator++(int) + { + strong_index tmp(*this); + ++_index; + return tmp; + } + constexpr strong_index& operator--() + { + --_index; + return *this; + } + constexpr strong_index operator--(int) + { + strong_index tmp(*this); + --_index; + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, strong_index s) + { + return os << s._index; + } + friend std::istream& operator>>(std::istream& is, strong_index& s) + { + return is >> s._index; + } + + private: + size_type _index; +}; + +} // namespace detail + +using lhs_index_type = detail::strong_index; +using rhs_index_type = detail::strong_index; + +} // namespace cudf diff --git a/cpp/include/cudf/table/row_operators.cuh b/cpp/include/cudf/table/row_operators.cuh index 4eca03a800c..5bad736c580 100644 --- a/cpp/include/cudf/table/row_operators.cuh +++ b/cpp/include/cudf/table/row_operators.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -198,12 +199,12 @@ class element_equality_comparator { */ template ()>* = nullptr> - __device__ bool operator()(size_type lhs_element_index, - size_type rhs_element_index) const noexcept + __device__ bool operator()(cudf::lhs_index_type lhs_element_index, + cudf::rhs_index_type rhs_element_index) const noexcept { if (nulls) { - bool const lhs_is_null{lhs.is_null(lhs_element_index)}; - bool const rhs_is_null{rhs.is_null(rhs_element_index)}; + bool const lhs_is_null{lhs.is_null(lhs_element_index.value())}; + bool const rhs_is_null{rhs.is_null(rhs_element_index.value())}; if (lhs_is_null and rhs_is_null) { return nulls_are_equal == null_equality::EQUAL; } else if (lhs_is_null != rhs_is_null) { @@ -211,17 +212,24 @@ class element_equality_comparator { } } - return equality_compare(lhs.element(lhs_element_index), - rhs.element(rhs_element_index)); + return equality_compare(lhs.element(lhs_element_index.value()), + rhs.element(rhs_element_index.value())); } template ()>* = nullptr> - __device__ bool operator()(size_type lhs_element_index, size_type rhs_element_index) + __device__ bool operator()(cudf::lhs_index_type lhs_element_index, + cudf::rhs_index_type rhs_element_index) const noexcept { CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); } + __device__ bool operator()(cudf::rhs_index_type rhs_element_index, + cudf::lhs_index_type lhs_element_index) const noexcept + { + return operator()(lhs_element_index, rhs_element_index); + } + private: column_device_view lhs; column_device_view rhs; @@ -246,8 +254,8 @@ class row_equality_comparator { auto equal_elements = [=](column_device_view l, column_device_view r) { return cudf::type_dispatcher(l.type(), element_equality_comparator{nulls, l, r, nulls_are_equal}, - lhs_row_index, - rhs_row_index); + cudf::lhs_index_type(lhs_row_index), + cudf::rhs_index_type(rhs_row_index)); }; return thrust::equal(thrust::seq, lhs.begin(), lhs.end(), rhs.begin(), equal_elements); diff --git a/cpp/src/groupby/sort/group_nunique.cu b/cpp/src/groupby/sort/group_nunique.cu index 478060cbd16..a29e3cf238c 100644 --- a/cpp/src/groupby/sort/group_nunique.cu +++ b/cpp/src/groupby/sort/group_nunique.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -62,9 +63,11 @@ struct nunique_functor { group_labels = group_labels.data()] __device__(auto i) -> size_type { bool is_input_countable = (null_handling == null_policy::INCLUDE || v.is_valid_nocheck(i)); - bool is_unique = is_input_countable && - (group_offsets[group_labels[i]] == i || // first element or - (not equal.operator()(i, i - 1))); // new unique value in sorted + bool is_unique = + is_input_countable && + (group_offsets[group_labels[i]] == i || // first element or + (not equal.operator()(cudf::lhs_index_type(i), + cudf::rhs_index_type(i - 1)))); // new unique value in sorted return static_cast(is_unique); }); @@ -82,8 +85,10 @@ struct nunique_functor { equal, group_offsets = group_offsets.data(), group_labels = group_labels.data()] __device__(auto i) -> size_type { - bool is_unique = group_offsets[group_labels[i]] == i || // first element or - (not equal.operator()(i, i - 1)); // new unique value in sorted + bool is_unique = + group_offsets[group_labels[i]] == i || // first element or + (not equal.operator()(cudf::lhs_index_type(i), + cudf::rhs_index_type(i - 1))); // new unique value in sorted return static_cast(is_unique); }); thrust::reduce_by_key(rmm::exec_policy(stream), diff --git a/cpp/src/transform/one_hot_encode.cu b/cpp/src/transform/one_hot_encode.cu index 16aee349bb5..10567d21d09 100644 --- a/cpp/src/transform/one_hot_encode.cu +++ b/cpp/src/transform/one_hot_encode.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -47,8 +48,8 @@ struct one_hot_encode_functor { bool __device__ operator()(size_type i) { - size_type const element_index = i % _input_size; - size_type const category_index = i / _input_size; + cudf::lhs_index_type const element_index(i % _input_size); + cudf::rhs_index_type const category_index(i / _input_size); return _equality_comparator.template operator()(element_index, category_index); }