Skip to content

Commit

Permalink
Add strong index type.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Apr 20, 2022
1 parent c8c7271 commit baeb07d
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 17 deletions.
117 changes: 117 additions & 0 deletions cpp/include/cudf/detail/utilities/strong_index.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/types.hpp>

namespace cudf {

namespace detail {

enum class index_type_name {
LEFT,
RIGHT,
};

template <index_type_name IndexName>
struct strong_index {
public:
constexpr explicit strong_index(size_type index) : _index(index) {}

constexpr explicit operator size_type() const { return _index; }

constexpr size_type value() const { return _index; }

constexpr strong_index operator+(size_type const& v) const { return strong_index(_index + v); }
constexpr strong_index operator-(size_type const& v) const { return strong_index(_index - v); }
constexpr strong_index operator*(size_type const& v) const { return strong_index(_index * v); }

constexpr bool operator==(size_type v) const { return _index == v; }
constexpr bool operator!=(size_type v) const { return _index != v; }
constexpr bool operator<=(size_type v) const { return _index <= v; }
constexpr bool operator>=(size_type v) const { return _index >= v; }
constexpr bool operator<(size_type v) const { return _index < v; }
constexpr bool operator>(size_type v) const { return _index > v; }

constexpr strong_index& operator=(strong_index const& s)
{
_index = s._index;
return *this;
}
constexpr strong_index& operator=(size_type const& i)
{
_index = i;
return *this;
}
constexpr strong_index& operator+=(size_type const& v)
{
_index += v;
return *this;
}
constexpr strong_index& operator-=(size_type const& v)
{
_index -= v;
return *this;
}
constexpr strong_index& operator*=(size_type const& v)
{
_index *= v;
return *this;
}

constexpr strong_index& operator++()
{
++_index;
return *this;
}
constexpr strong_index operator++(int)
{
strong_index tmp(*this);
++_index;
return tmp;
}
constexpr strong_index& operator--()
{
--_index;
return *this;
}
constexpr strong_index operator--(int)
{
strong_index tmp(*this);
--_index;
return tmp;
}

friend std::ostream& operator<<(std::ostream& os, strong_index<IndexName> s)
{
return os << s._index;
}
friend std::istream& operator>>(std::istream& is, strong_index<IndexName>& s)
{
return is >> s._index;
}

private:
size_type _index;
};

} // namespace detail

using lhs_index_type = detail::strong_index<detail::index_type_name::LEFT>;
using rhs_index_type = detail::strong_index<detail::index_type_name::RIGHT>;

} // namespace cudf
26 changes: 17 additions & 9 deletions cpp/include/cudf/table/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cudf/detail/hashing.hpp>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/detail/utilities/hash_functions.cuh>
#include <cudf/detail/utilities/strong_index.hpp>
#include <cudf/sorting.hpp>
#include <cudf/table/table_device_view.cuh>
#include <cudf/utilities/traits.hpp>
Expand Down Expand Up @@ -198,30 +199,37 @@ class element_equality_comparator {
*/
template <typename Element,
std::enable_if_t<cudf::is_equality_comparable<Element, Element>()>* = nullptr>
__device__ bool operator()(size_type lhs_element_index,
size_type rhs_element_index) const noexcept
__device__ bool operator()(cudf::lhs_index_type lhs_element_index,
cudf::rhs_index_type rhs_element_index) const noexcept
{
if (nulls) {
bool const lhs_is_null{lhs.is_null(lhs_element_index)};
bool const rhs_is_null{rhs.is_null(rhs_element_index)};
bool const lhs_is_null{lhs.is_null(lhs_element_index.value())};
bool const rhs_is_null{rhs.is_null(rhs_element_index.value())};
if (lhs_is_null and rhs_is_null) {
return nulls_are_equal == null_equality::EQUAL;
} else if (lhs_is_null != rhs_is_null) {
return false;
}
}

return equality_compare(lhs.element<Element>(lhs_element_index),
rhs.element<Element>(rhs_element_index));
return equality_compare(lhs.element<Element>(lhs_element_index.value()),
rhs.element<Element>(rhs_element_index.value()));
}

template <typename Element,
std::enable_if_t<not cudf::is_equality_comparable<Element, Element>()>* = nullptr>
__device__ bool operator()(size_type lhs_element_index, size_type rhs_element_index)
__device__ bool operator()(cudf::lhs_index_type lhs_element_index,
cudf::rhs_index_type rhs_element_index) const noexcept
{
CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types.");
}

__device__ bool operator()(cudf::rhs_index_type rhs_element_index,
cudf::lhs_index_type lhs_element_index) const noexcept
{
return operator()(lhs_element_index, rhs_element_index);
}

private:
column_device_view lhs;
column_device_view rhs;
Expand All @@ -246,8 +254,8 @@ class row_equality_comparator {
auto equal_elements = [=](column_device_view l, column_device_view r) {
return cudf::type_dispatcher(l.type(),
element_equality_comparator{nulls, l, r, nulls_are_equal},
lhs_row_index,
rhs_row_index);
cudf::lhs_index_type(lhs_row_index),
cudf::rhs_index_type(rhs_row_index));
};

return thrust::equal(thrust::seq, lhs.begin(), lhs.end(), rhs.begin(), equal_elements);
Expand Down
15 changes: 10 additions & 5 deletions cpp/src/groupby/sort/group_nunique.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cudf/aggregation.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/utilities/strong_index.hpp>
#include <cudf/table/row_operators.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>
Expand Down Expand Up @@ -62,9 +63,11 @@ struct nunique_functor {
group_labels = group_labels.data()] __device__(auto i) -> size_type {
bool is_input_countable =
(null_handling == null_policy::INCLUDE || v.is_valid_nocheck(i));
bool is_unique = is_input_countable &&
(group_offsets[group_labels[i]] == i || // first element or
(not equal.operator()<T>(i, i - 1))); // new unique value in sorted
bool is_unique =
is_input_countable &&
(group_offsets[group_labels[i]] == i || // first element or
(not equal.operator()<T>(cudf::lhs_index_type(i),
cudf::rhs_index_type(i - 1)))); // new unique value in sorted
return static_cast<size_type>(is_unique);
});

Expand All @@ -82,8 +85,10 @@ struct nunique_functor {
equal,
group_offsets = group_offsets.data(),
group_labels = group_labels.data()] __device__(auto i) -> size_type {
bool is_unique = group_offsets[group_labels[i]] == i || // first element or
(not equal.operator()<T>(i, i - 1)); // new unique value in sorted
bool is_unique =
group_offsets[group_labels[i]] == i || // first element or
(not equal.operator()<T>(cudf::lhs_index_type(i),
cudf::rhs_index_type(i - 1))); // new unique value in sorted
return static_cast<size_type>(is_unique);
});
thrust::reduce_by_key(rmm::exec_policy(stream),
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/transform/one_hot_encode.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
#include <cudf/detail/copy.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/strong_index.hpp>
#include <cudf/table/row_operators.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
Expand Down Expand Up @@ -47,8 +48,8 @@ struct one_hot_encode_functor {

bool __device__ operator()(size_type i)
{
size_type const element_index = i % _input_size;
size_type const category_index = i / _input_size;
cudf::lhs_index_type const element_index(i % _input_size);
cudf::rhs_index_type const category_index(i / _input_size);
return _equality_comparator.template operator()<InputType>(element_index, category_index);
}

Expand Down

0 comments on commit baeb07d

Please sign in to comment.