Filter out infinities in radix-based select-k #1742

Closed
Changes from all commits
28 commits
8786766
Filter out infinities in radix-based select-k
achirkin Aug 16, 2023
8964ba0
Write neighbours with inf dist at zero pass and allow zero current_le…
achirkin Aug 18, 2023
35296d4
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 22, 2023
3eb905f
Add noinline to filter_and_histogram_for_one_block to reduce register…
achirkin Aug 22, 2023
44014cf
Add tests for inputs containing infinities
achirkin Aug 22, 2023
3b0b68d
Fix the last filter not ignoring the bound values
achirkin Aug 23, 2023
e7ece2c
Add benchmarks for the edge case of having many infinities
achirkin Aug 23, 2023
d956724
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 23, 2023
28f5004
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 24, 2023
09ab49b
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 24, 2023
39a8a75
Make the result testing more strict, allow setting input indices, and…
achirkin Aug 28, 2023
32e8bf6
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 28, 2023
6a1e670
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 29, 2023
01f2342
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 30, 2023
276ce93
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 30, 2023
e70b392
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 30, 2023
07a629c
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 30, 2023
e7189bb
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 30, 2023
08a692c
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 31, 2023
2ca00fd
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Aug 31, 2023
645d8ec
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 1, 2023
0591fe8
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 5, 2023
8cf4da0
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 7, 2023
dcf6d68
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 8, 2023
b587374
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
cjnolet Sep 8, 2023
b170f02
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 9, 2023
87551e7
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 13, 2023
67e8d5a
Merge branch 'branch-23.10' into enh-select-k-radix-handle-infinities
achirkin Sep 18, 2023
94 changes: 67 additions & 27 deletions cpp/include/raft/matrix/detail/select_radix.cuh
@@ -203,6 +203,9 @@ struct alignas(128) Counter {
// value are written from back to front. We need to keep count of them separately because the
// number of elements that <= the k-th value might exceed k.
alignas(128) IdxT out_back_cnt;

// Number of infinities found in the zero pass.
alignas(128) IdxT bound_cnt;
};

/**
@@ -221,7 +224,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
IdxT* histogram,
bool select_min,
int pass,
bool early_stop)
bool early_stop,
IdxT k)
{
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
__shared__ IdxT histogram_smem[num_buckets];
@@ -232,12 +236,24 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,

const int start_bit = calc_start_bit<T, BitsPerPass>(pass);
const unsigned mask = calc_mask<T, BitsPerPass>(pass);
// The last possible value (k-th cannot be further).
const auto bound = select_min ? upper_bound<T>() : lower_bound<T>();

if (pass == 0) {
IdxT* p_bound_cnt = &counter->bound_cnt;
// Passed to vectorized_process, this function executes in all blocks in parallel,
// i.e. the work is split along the input (both in batches and in chunks of a single row).
// Later, the histograms are merged using atomicAdd.
auto f = [select_min, start_bit, mask](T value, IdxT) {
auto f = [in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k](
T value, IdxT i) {
if (value == bound) {
if (i < k) {
IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1});
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
return;
}
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram_smem + bucket, static_cast<IdxT>(1));
};
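The pass-0 lambda above parks each bound value it finds at out[k - 1 - bound_cnt], filling the output tail backwards so the front slots stay free for finite winners. A toy host-side sketch of that arithmetic (an illustration only, not raft code):

#include <cstdio>

int main()
{
  const int k = 5;
  int bound_cnt = 0;                       // plays the role of counter->bound_cnt
  for (int j = 0; j < 3; ++j) {            // suppose three bound values are encountered
    const int pos = k - 1 - bound_cnt++;   // same formula as the device lambda above
    std::printf("bound value #%d -> out[%d]\n", j, pos);
  }
  // fills out[4], out[3], out[2]; out[0..1] stay free for finite winners
  return 0;
}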
@@ -265,7 +281,9 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf,
kth_value_bits,
p_filter_cnt,
p_out_cnt,
bound,
early_stop](T value, IdxT i) {
if (value == bound) { return; }
const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if (previous_bits == kth_value_bits) {
@@ -370,7 +388,7 @@ _RAFT_DEVICE void choose_bucket(Counter<T, IdxT>* counter,
IdxT cur = histogram[i];

// one and only one thread will satisfy this condition, so counter is written by only one thread
if (prev < k && cur >= k) {
if (prev < k && (cur >= k || i + 1 == num_buckets)) {
counter->k = k - prev; // how many values still are there to find
counter->len = cur - prev; // number of values in next pass
typename cub::Traits<T>::UnsignedBits bucket = i;
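The added "i + 1 == num_buckets" clause matters because bound values no longer enter the histogram: when fewer than k finite values remain, the running total never reaches k and the last bucket has to be chosen as a fallback. A toy host-side sketch (with raw counts and the prefix sum computed inline, whereas the kernel scans the histogram beforehand; illustration only, not raft code):

#include <cstdio>

int main()
{
  const int k = 5;
  const int num_buckets = 4;
  const int histogram[num_buckets] = {1, 2, 0, 0};  // only 3 finite values survived pass 0
  int prev = 0;
  for (int i = 0; i < num_buckets; ++i) {
    const int cur = prev + histogram[i];            // inclusive running total
    if (prev < k && (cur >= k || i + 1 == num_buckets)) {
      // without the new clause, no bucket would satisfy cur >= k here
      std::printf("chosen bucket %d, values still to find %d, next-pass len %d\n",
                  i, k - prev, cur - prev);
      break;
    }
    prev = cur;
  }
  return 0;
}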
@@ -395,13 +413,15 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
{
const auto kth_value_bits = counter->kth_value_bits;
const int start_bit = calc_start_bit<T, BitsPerPass>(pass);
const auto bound = select_min ? upper_bound<T>() : lower_bound<T>();

// changed in choose_bucket(); need to reload
const IdxT needed_num_of_kth = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) {
const T value = in_buf[i];
const T value = in_buf[i];
if (value == bound) { continue; }
const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if (bits < kth_value_bits) {
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
@@ -413,7 +433,7 @@ } else if (bits == kth_value_bits) {
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < needed_num_of_kth) {
IdxT pos = k - 1 - back_pos;
IdxT pos = k - needed_num_of_kth + back_pos;
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
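The back-fill position for values tied with the k-th bit pattern changes from k - 1 - back_pos to k - needed_num_of_kth + back_pos: the same tail slots are covered, but they are now filled in ascending order starting right after the strictly-better values. A toy host-side check of the two formulas (illustration only, not raft code):

#include <cstdio>

int main()
{
  const int k = 8, needed_num_of_kth = 3;
  for (int back_pos = 0; back_pos < needed_num_of_kth; ++back_pos) {
    const int pos_old = k - 1 - back_pos;                   // 7, 6, 5
    const int pos_new = k - needed_num_of_kth + back_pos;   // 5, 6, 7
    std::printf("tie #%d: old slot %d, new slot %d\n", back_pos, pos_old, pos_new);
  }
  // both variants cover slots 5..7; the new formula fills them in ascending order
  return 0;
}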
@@ -457,6 +477,7 @@ __global__ void last_filter_kernel(const T* in,
const IdxT needed_num_of_kth = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
const auto bound = select_min ? upper_bound<T>() : lower_bound<T>();

auto f = [k,
select_min,
@@ -465,8 +486,10 @@
p_out_cnt,
p_out_back_cnt,
in_idx_buf,
bound,
out,
out_idx](T value, IdxT i) {
if (value == bound) { return; }
const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit;
if (bits < kth_value_bits) {
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
@@ -475,7 +498,7 @@ } else if (bits == kth_value_bits) {
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < needed_num_of_kth) {
IdxT pos = k - 1 - back_pos;
IdxT pos = k - needed_num_of_kth + back_pos;
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
@@ -560,12 +583,11 @@ __global__ void radix_kernel(const T* in,
current_len = counter->len;
previous_len = counter->previous_len;
}
if (current_len == 0) { return; }

// When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle
// correctly the case that pass=0 and early_stop=true. However, this special case of k=len is
// handled in another way in select_k(), so such a case is not possible here.
const bool early_stop = (current_len == current_k);
const bool early_stop = (current_len <= current_k);
const IdxT buf_len = calc_buf_len<T>(len);

// "previous_len > buf_len" means previous pass skips writing buffer
@@ -602,7 +624,8 @@ __global__ void radix_kernel(const T* in,
histogram,
select_min,
pass,
early_stop);
early_stop,
k);
__threadfence();

bool isLastBlock = false;
@@ -723,7 +746,7 @@ _RAFT_HOST_DEVICE void set_buf_pointers(const T* in,
{
if (pass == 0) {
in_buf = in;
in_idx_buf = nullptr;
in_idx_buf = in_idx;
out_buf = nullptr;
out_idx_buf = nullptr;
} else if (pass == 1) {
@@ -860,16 +883,17 @@ void radix_topk(const T* in,
// The following a few functions are for the one-block version, which uses single thread block for
// each row of a batch.
template <typename T, typename IdxT, int BitsPerPass>
_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
const IdxT* in_idx_buf,
T* out_buf,
IdxT* out_idx_buf,
T* out,
IdxT* out_idx,
Counter<T, IdxT>* counter,
IdxT* histogram,
bool select_min,
int pass)
_RAFT_DEVICE __noinline__ void filter_and_histogram_for_one_block(const T* in_buf,
const IdxT* in_idx_buf,
T* out_buf,
IdxT* out_idx_buf,
T* out,
IdxT* out_idx,
Counter<T, IdxT>* counter,
IdxT* histogram,
bool select_min,
int pass,
IdxT k)
{
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) {
@@ -882,12 +906,25 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
const int start_bit = calc_start_bit<T, BitsPerPass>(pass);
const unsigned mask = calc_mask<T, BitsPerPass>(pass);
const IdxT previous_len = counter->previous_len;
// The last possible value (k-th cannot be further).
const auto bound = select_min ? upper_bound<T>() : lower_bound<T>();

if (pass == 0) {
auto f = [histogram, select_min, start_bit, mask](T value, IdxT) {
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
};
IdxT* p_bound_cnt = &counter->bound_cnt;
auto f =
[histogram, in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k](
T value, IdxT i) {
if (value == bound) {
if (i < k) {
IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1});
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
}
return;
}
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
};
vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f);
} else {
// not use vectorized_process here because it increases #registers a lot
@@ -896,7 +933,8 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
const int previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);

for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
const T value = in_buf[i];
const T value = in_buf[i];
if (value == bound) { continue; }
const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if (previous_bits == kth_value_bits) {
@@ -943,6 +981,7 @@ __global__ void radix_topk_one_block_kernel(const T* in,
counter.kth_value_bits = 0;
counter.out_cnt = 0;
counter.out_back_cnt = 0;
counter.bound_cnt = 0;
}
__syncthreads();

@@ -977,7 +1016,8 @@ __global__ void radix_topk_one_block_kernel(const T* in,
&counter,
histogram,
select_min,
pass);
pass,
k);
__syncthreads();

scan<IdxT, BitsPerPass, BlockSize>(histogram);
@@ -987,7 +1027,7 @@ if (threadIdx.x == 0) { counter.previous_len = current_len; }
if (threadIdx.x == 0) { counter.previous_len = current_len; }
__syncthreads();

if (counter.len == counter.k || pass == num_passes - 1) {
if (counter.len <= counter.k || pass == num_passes - 1) {
last_filter<T, IdxT, BitsPerPass>(pass == 0 ? in : out_buf,
pass == 0 ? in_idx : out_idx_buf,
out,