Skip to content

Commit

Permalink
[Feat] add repeat, sparsity, eval_n_elements APIs to bitset
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Sep 16, 2024
1 parent ac53a0f commit 59a3898
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 10 deletions.
98 changes: 98 additions & 0 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <raft/util/device_atomics.cuh>
#include <raft/util/popc.cuh>

#include <rmm/device_scalar.hpp>

#include <thrust/for_each.h>

namespace raft::core {
Expand Down Expand Up @@ -60,6 +62,102 @@ _RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index
}
}

template <typename bitset_t, typename index_t>
struct bitset_copy_functor {
const bitset_t* bitset_ptr;
bitset_t* output_device_ptr;
index_t valid_bits;
index_t bits_per_element;
index_t total_bits;

bitset_copy_functor(const bitset_t* _bitset_ptr,
bitset_t* _output_device_ptr,
index_t _valid_bits,
index_t _bits_per_element,
index_t _total_bits)
: bitset_ptr(_bitset_ptr),
output_device_ptr(_output_device_ptr),
valid_bits(_valid_bits),
bits_per_element(_bits_per_element),
total_bits(_total_bits)
{
}

__device__ void operator()(index_t i)
{
if (i < total_bits) {
index_t src_bit_index = i % valid_bits;
index_t dst_bit_index = i;

index_t src_element_index = src_bit_index / bits_per_element;
index_t src_bit_offset = src_bit_index % bits_per_element;

index_t dst_element_index = dst_bit_index / bits_per_element;
index_t dst_bit_offset = dst_bit_index % bits_per_element;

bitset_t src_element = bitset_ptr[src_element_index];
bitset_t src_bit = (src_element >> src_bit_offset) & 1;

if (src_bit) {
atomicOr(output_device_ptr + dst_element_index, bitset_t(1) << dst_bit_offset);
} else {
atomicAnd(output_device_ptr + dst_element_index, ~(bitset_t(1) << dst_bit_offset));
}
}
}
};

template <typename bitset_t, typename index_t>
void bitset_view<bitset_t, index_t>::repeat(const raft::resources& res,
index_t times,
bitset_t* output_device_ptr) const
{
auto thrust_policy = raft::resource::get_thrust_policy(res);
constexpr index_t bits_per_element = sizeof(bitset_t) * 8;

if (bitset_len_ % bits_per_element == 0) {
index_t num_elements_to_copy = bitset_len_ / bits_per_element;

for (index_t i = 0; i < times; ++i) {
raft::copy(output_device_ptr + i * num_elements_to_copy,
bitset_ptr_,
num_elements_to_copy,
raft::resource::get_cuda_stream(res));
}
} else {
index_t valid_bits = bitset_len_;
index_t total_bits = valid_bits * times;
index_t output_row_elements = (total_bits + bits_per_element - 1) / bits_per_element;
thrust::for_each_n(thrust_policy,
thrust::counting_iterator<index_t>(0),
total_bits,
bitset_copy_functor<bitset_t, index_t>(
bitset_ptr_, output_device_ptr, valid_bits, bits_per_element, total_bits));
}
}

template <typename bitset_t, typename index_t>
double bitset_view<bitset_t, index_t>::sparsity(const raft::resources& res) const
{
index_t nnz_h = 0;
index_t size_h = this->size();
auto stream = raft::resource::get_cuda_stream(res);

if (0 == size_h) { return static_cast<double>(1.0); }

rmm::device_scalar<index_t> nnz(0, stream);

auto vector_view = raft::make_device_vector_view<const bitset_t, index_t>(data(), n_elements());
auto nnz_view = raft::make_device_scalar_view<index_t>(nnz.data());
auto size_view = raft::make_host_scalar_view<index_t>(&size_h);

raft::popc(res, vector_view, size_view, nnz_view);
raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(res, stream);
return static_cast<double>((1.0 * (size_h - nnz_h)) / (1.0 * size_h));
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
Expand Down
53 changes: 53 additions & 0 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

#include <cmath>

namespace raft::core {
/**
* @defgroup bitset Bitset
Expand Down Expand Up @@ -104,6 +106,57 @@ struct bitset_view {
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, n_elements());
}

/**
* @brief Repeats the bitset data and copies it to the output device pointer.
*
* This function takes the original bitset data stored in the device memory
* and repeats it a specified number of times into a new location in the device memory.
* The bits are copied bit-by-bit to ensure that even if the number of bits (bitset_len_)
* is not a multiple of the bitset element size (e.g., 32 for uint32_t), the bits are
* tightly packed without any gaps between rows.
*
* @param res RAFT resources for managing CUDA streams and execution policies.
* @param times Number of times the bitset data should be repeated in the output.
* @param output_device_ptr Device pointer where the repeated bitset data will be stored.
*
* The caller must ensure that the output device pointer has enough memory allocated
* to hold `times * bitset_len` bits, where `bitset_len` is the number of bits in the original
* bitset. This function uses Thrust parallel algorithms to efficiently perform the operation on
* the GPU.
*/
void repeat(const raft::resources& res, index_t times, bitset_t* output_device_ptr) const;

/**
* @brief Calculate the sparsity (fraction of 0s) of the bitset.
*
* This function computes the sparsity of the bitset, defined as the ratio of unset bits (0s)
* to the total number of bits in the set. If the total number of bits is zero, the function
* returns 1.0, indicating the set is fully sparse.
*
* @param res RAFT resources for managing CUDA streams and execution policies.
* @return double The sparsity of the bitset, i.e., the fraction of unset bits.
*
* This API will synchronize on the stream of `res`.
*/
double sparsity(const raft::resources& res) const;

/**
* @brief Calculates the number of `bitset_t` elements required to store a bitset.
*
* This function computes the number of `bitset_t` elements needed to store a bitset, ensuring
* that all bits are accounted for. If the bitset length is not a multiple of the `bitset_t` size
* (in bits), the calculation rounds up to include the remaining bits in an additional `bitset_t`
* element.
*
* @param bitset_len The total length of the bitset in bits.
* @return size_t The number of `bitset_t` elements required to store the bitset.
*/
static inline size_t eval_n_elements(size_t bitset_len)
{
const size_t bits_per_element = sizeof(bitset_t) * 8;
return (bitset_len + bits_per_element - 1) / bits_per_element;
}

private:
bitset_t* bitset_ptr_;
index_t bitset_len_;
Expand Down
102 changes: 92 additions & 10 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,12 +32,13 @@ struct test_spec_bitset {
uint64_t bitset_len;
uint64_t mask_len;
uint64_t query_len;
uint64_t repeat_times;
};

auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream&
{
os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len
<< ", query_len: " << ss.query_len << "}";
<< ", query_len: " << ss.query_len << ", repeat_times: " << ss.repeat_times << "}";
return os;
}

Expand Down Expand Up @@ -80,20 +81,68 @@ void flip_cpu_bitset(std::vector<bitset_t>& bitset)
}
}

template <typename bitset_t>
void repeat_cpu_bitset(std::vector<bitset_t>& input,
size_t input_bits,
size_t repeat,
std::vector<bitset_t>& output)
{
const size_t output_bits = input_bits * repeat;
const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8);

std::memset(output.data(), 0, output_units * sizeof(bitset_t));

size_t output_bit_index = 0;

for (size_t r = 0; r < repeat; ++r) {
for (size_t i = 0; i < input_bits; ++i) {
size_t input_unit_index = i / (sizeof(bitset_t) * 8);
size_t input_bit_offset = i % (sizeof(bitset_t) * 8);
bool bit = (input[input_unit_index] >> input_bit_offset) & 1;

size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8);
size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8);

output[output_unit_index] |= (static_cast<bitset_t>(bit) << output_bit_offset);

++output_bit_index;
}
}
}

template <typename bitset_t>
double sparsity_cpu_bitset(std::vector<bitset_t>& data, size_t total_bits)
{
size_t one_count = 0;
for (size_t i = 0; i < total_bits; ++i) {
size_t unit_index = i / (sizeof(bitset_t) * 8);
size_t bit_offset = i % (sizeof(bitset_t) * 8);
bool bit = (data[unit_index] >> bit_offset) & 1;
if (bit == 1) { ++one_count; }
}
return static_cast<double>((total_bits - one_count) / (1.0 * total_bits));
}

template <typename bitset_t, typename index_t>
class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
protected:
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
const test_spec_bitset spec;
std::vector<bitset_t> bitset_result;
std::vector<bitset_t> bitset_ref;
std::vector<bitset_t> bitset_repeat_ref;
std::vector<bitset_t> bitset_repeat_result;
raft::resources res;

public:
explicit BitsetTest()
: spec(testing::TestWithParam<test_spec_bitset>::GetParam()),
bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))),
bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size)))
bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))),
bitset_repeat_ref(
raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size))),
bitset_repeat_result(
raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size)))
{
}

Expand Down Expand Up @@ -145,6 +194,37 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// test sparsity, repeat and eval_n_elements
if constexpr (std::is_same_v<bitset_t, uint32_t> || std::is_same_v<bitset_t, uint64_t>) {
auto my_bitset_view = my_bitset.view();
auto sparsity_result = my_bitset_view.sparsity(res);
auto sparsity_ref = sparsity_cpu_bitset(bitset_ref, size_t(spec.bitset_len));
ASSERT_EQ(sparsity_result, sparsity_ref);

auto eval_n_elements =
bitset_view<bitset_t, index_t>::eval_n_elements(spec.bitset_len * spec.repeat_times);
ASSERT_EQ(bitset_repeat_ref.size(), eval_n_elements);

auto repeat_device = raft::make_device_vector<bitset_t, index_t>(res, eval_n_elements);
RAFT_CUDA_TRY(cudaMemsetAsync(
repeat_device.data_handle(), 0, eval_n_elements * sizeof(bitset_t), stream));
repeat_cpu_bitset(
bitset_ref, size_t(spec.bitset_len), size_t(spec.repeat_times), bitset_repeat_ref);

my_bitset_view.repeat(res, index_t(spec.repeat_times), repeat_device.data_handle());

ASSERT_EQ(bitset_repeat_ref.size(), repeat_device.size());
update_host(
bitset_repeat_result.data(), repeat_device.data_handle(), repeat_device.size(), stream);
ASSERT_EQ(bitset_repeat_ref.size(), bitset_repeat_result.size());
ASSERT_TRUE(hostVecMatch(bitset_repeat_ref, bitset_repeat_result, raft::Compare<bitset_t>()));

// recheck the sparsity after repeat
sparsity_result =
sparsity_cpu_bitset(bitset_repeat_result, size_t(spec.bitset_len * spec.repeat_times));
ASSERT_EQ(sparsity_result, sparsity_ref);
}

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
Expand All @@ -167,13 +247,15 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
}
};

auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10},
test_spec_bitset{100, 30, 10},
test_spec_bitset{1024, 55, 100},
test_spec_bitset{10000, 1000, 1000},
test_spec_bitset{1 << 15, 1 << 3, 1 << 12},
test_spec_bitset{1 << 15, 1 << 24, 1 << 13},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14});
auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10, 101},
test_spec_bitset{100, 30, 10, 13},
test_spec_bitset{1024, 55, 100, 1},
test_spec_bitset{10000, 1000, 1000, 100},
test_spec_bitset{1 << 15, 1 << 3, 1 << 12, 5},
test_spec_bitset{1 << 15, 1 << 24, 1 << 13, 3},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 3},
test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 201},
test_spec_bitset{10000000, 1 << 23, 1 << 14, 401});

using Uint16_32 = BitsetTest<uint16_t, uint32_t>;
TEST_P(Uint16_32, Run) { run(); }
Expand Down

0 comments on commit 59a3898

Please sign in to comment.