From 87867666d5dcd51d49f1262780fd3d9cf77a8e59 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 16 Aug 2023 17:36:34 +0200 Subject: [PATCH 1/7] Filter out infinities in radix-based select-k --- cpp/include/raft/matrix/detail/select_radix.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index edde924892..2e0082c070 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -238,6 +238,7 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. auto f = [select_min, start_bit, mask](T value, IdxT) { + if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; @@ -266,6 +267,7 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, p_filter_cnt, p_out_cnt, early_stop](T value, IdxT i) { + if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { @@ -885,6 +887,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, if (pass == 0) { auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); }; @@ -896,7 +899,8 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const int previous_start_bit = calc_start_bit(pass - 1); for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = in_buf[i]; + const T value = in_buf[i]; + if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { continue; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { From 8964ba02885a9750f5de4c360ef09337ab13bcb5 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 18 Aug 2023 20:18:13 +0200 Subject: [PATCH 2/7] Write neighbours with inf dist at zero pass and allow zero current_len at first pass --- .../raft/matrix/detail/select_radix.cuh | 70 ++++++++++++++----- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 2e0082c070..1c2958466e 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -203,6 +203,9 @@ struct alignas(128) Counter { // value are written from back to front. We need to keep count of them separately because the // number of elements that <= the k-th value might exceed k. alignas(128) IdxT out_back_cnt; + + // Number of infinities found in the zero pass. + alignas(128) IdxT bound_cnt; }; /** @@ -221,7 +224,8 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, IdxT* histogram, bool select_min, int pass, - bool early_stop) + bool early_stop, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -232,13 +236,24 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, const int start_bit = calc_start_bit(pass); const unsigned mask = calc_mask(pass); + // The last possible value (k-th cannot be further). + const auto bound = select_min ? upper_bound() : lower_bound(); if (pass == 0) { + IdxT* p_bound_cnt = &counter->bound_cnt; // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. - auto f = [select_min, start_bit, mask](T value, IdxT) { - if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } + auto f = [in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k]( + T value, IdxT i) { + if (value == bound) { + if (i < k) { + IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1}); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + return; + } int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; @@ -266,8 +281,9 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, kth_value_bits, p_filter_cnt, p_out_cnt, + bound, early_stop](T value, IdxT i) { - if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } + if (value == bound) { return; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { @@ -372,7 +388,7 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, IdxT cur = histogram[i]; // one and only one thread will satisfy this condition, so counter is written by only one thread - if (prev < k && cur >= k) { + if (prev < k && (cur >= k || i + 1 == num_buckets)) { counter->k = k - prev; // how many values still are there to find counter->len = cur - prev; // number of values in next pass typename cub::Traits::UnsignedBits bucket = i; @@ -415,7 +431,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; + IdxT pos = k - needed_num_of_kth + back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -477,7 +493,7 @@ __global__ void last_filter_kernel(const T* in, } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; + IdxT pos = k - needed_num_of_kth + back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } @@ -562,12 +578,11 @@ __global__ void radix_kernel(const T* in, current_len = counter->len; previous_len = counter->previous_len; } - if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is // handled in other way in select_k() so such case is not possible here. - const bool early_stop = (current_len == current_k); + const bool early_stop = (current_len <= current_k); const IdxT buf_len = calc_buf_len(len); // "previous_len > buf_len" means previous pass skips writing buffer @@ -604,7 +619,9 @@ __global__ void radix_kernel(const T* in, histogram, select_min, pass, - early_stop); + early_stop, + k); + if (current_len == 0) { return; } __threadfence(); bool isLastBlock = false; @@ -871,7 +888,8 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, Counter* counter, IdxT* histogram, bool select_min, - int pass) + int pass, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { @@ -884,13 +902,25 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, const int start_bit = calc_start_bit(pass); const unsigned mask = calc_mask(pass); const IdxT previous_len = counter->previous_len; + // The last possible value (k-th cannot be further). + const auto bound = select_min ? upper_bound() : lower_bound(); if (pass == 0) { - auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { - if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { return; } - int bucket = calc_bucket(value, start_bit, mask, select_min); - atomicAdd(histogram + bucket, static_cast(1)); - }; + IdxT* p_bound_cnt = &counter->bound_cnt; + auto f = + [histogram, in_idx_buf, out, out_idx, select_min, start_bit, mask, bound, p_bound_cnt, k]( + T value, IdxT i) { + if (value == bound) { + if (i < k) { + IdxT pos = k - 1 - atomicAdd(p_bound_cnt, IdxT{1}); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + return; + } + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); } else { // not use vectorized_process here because it increases #registers a lot @@ -900,7 +930,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { const T value = in_buf[i]; - if (select_min ? (value >= upper_bound()) : (value <= lower_bound())) { continue; } + if (value == bound) { continue; } const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { @@ -947,6 +977,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; + counter.bound_cnt = 0; } __syncthreads(); @@ -981,7 +1012,8 @@ __global__ void radix_topk_one_block_kernel(const T* in, &counter, histogram, select_min, - pass); + pass, + k); __syncthreads(); scan(histogram); @@ -991,7 +1023,7 @@ __global__ void radix_topk_one_block_kernel(const T* in, if (threadIdx.x == 0) { counter.previous_len = current_len; } __syncthreads(); - if (counter.len == counter.k || pass == num_passes - 1) { + if (counter.len <= counter.k || pass == num_passes - 1) { last_filter(pass == 0 ? in : out_buf, pass == 0 ? in_idx : out_idx_buf, out, From 3eb905fb93b36c75aa24919c147d601274443d3b Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 22 Aug 2023 21:25:40 +0200 Subject: [PATCH 3/7] Add noinline to filter_and_histogram_for_one_block to reduce register usage for double+uint64_t types --- .../raft/matrix/detail/select_radix.cuh | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 1c2958466e..e454747b74 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -879,17 +879,17 @@ void radix_topk(const T* in, // The following a few functions are for the one-block version, which uses single thread block for // each row of a batch. template -_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counter, - IdxT* histogram, - bool select_min, - int pass, - IdxT k) +_RAFT_DEVICE __noinline__ void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { From 44014cf104bc4c421c3a76e36db643068cc01610 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 22 Aug 2023 21:29:04 +0200 Subject: [PATCH 4/7] Add tests for inputs containing infinities --- .../raft_internal/matrix/select_k.cuh | 7 ++-- cpp/test/matrix/select_k.cu | 34 +++++++++++++++++++ cpp/test/matrix/select_k.cuh | 25 ++++++++++++-- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index b72e67580a..1d15c5fc03 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -33,6 +33,7 @@ struct params { bool use_index_input = true; bool use_same_leading_bits = false; bool use_memory_pool = true; + double frac_infinities = 0.0; }; inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& @@ -41,8 +42,10 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << ", len: " << ss.len; os << ", k: " << ss.k; os << (ss.select_min ? ", asc" : ", dsc"); - os << (ss.use_index_input ? "" : ", no-input-index"); - os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}"); + if (!ss.use_index_input) { os << ", no-input-index"; } + if (ss.use_same_leading_bits) { os << ", same-leading-bits"; } + if (ss.frac_infinities > 0) { os << ", infs: " << ss.frac_infinities; } + os << "}"; return os; } diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 63f020b420..ce4e3e867e 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -70,6 +70,28 @@ auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, tr select::params{100, 100000, 2048, false}, select::params{100, 100000, 1237, true}); +auto inputs_random_many_infs = + testing::Values(select::params{10, 100000, 1, true, false, false, true, 0.9}, + select::params{10, 100000, 16, true, false, false, true, 0.9}, + select::params{10, 100000, 64, true, false, false, true, 0.9}, + select::params{10, 100000, 128, true, false, false, true, 0.9}, + select::params{10, 100000, 256, true, false, false, true, 0.9}, + select::params{1000, 10000, 1, true, false, false, true, 0.9}, + select::params{1000, 10000, 16, true, false, false, true, 0.9}, + select::params{1000, 10000, 64, true, false, false, true, 0.9}, + select::params{1000, 10000, 128, true, false, false, true, 0.9}, + select::params{1000, 10000, 256, true, false, false, true, 0.9}, + select::params{10, 100000, 1, true, false, false, true, 0.999}, + select::params{10, 100000, 16, true, false, false, true, 0.999}, + select::params{10, 100000, 64, true, false, false, true, 0.999}, + select::params{10, 100000, 128, true, false, false, true, 0.999}, + select::params{10, 100000, 256, true, false, false, true, 0.999}, + select::params{1000, 10000, 1, true, false, false, true, 0.999}, + select::params{1000, 10000, 16, true, false, false, true, 0.999}, + select::params{1000, 10000, 64, true, false, false, true, 0.999}, + select::params{1000, 10000, 128, true, false, false, true, 0.999}, + select::params{1000, 10000, 256, true, false, false, true, 0.999}); + using ReferencedRandomFloatInt = SelectK::params_random>; TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT @@ -111,4 +133,16 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix8bits, select::Algo::kRadix11bits, select::Algo::kRadix11bitsExtraPass))); + +using ReferencedRandomFloatIntkWarpsortAsGT = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomFloatIntkWarpsortAsGT, + testing::Combine(inputs_random_many_infs, + testing::Values(select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); + } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index e0e0cad225..04c68cf021 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -227,7 +227,8 @@ struct SelectK // NOLINT // to non-deterministic nature of some implementations. auto& in_ids = ref.get_in_ids(); auto& in_dists = ref.get_in_dists(); - auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { + const auto bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + auto compare_ids = [&in_ids, &in_dists, bound](const IdxT& i, const IdxT& j) { if (i == j) return true; auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); @@ -235,7 +236,8 @@ struct SelectK // NOLINT return false; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; - if (dist_i == dist_j) return true; + // Some algorithms return invalid/zero indices for bound values. + if (dist_i == dist_j || dist_j == bound || dist_i == bound) return true; std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " << "res[" << ix_j << "] = " << dist_j << std::endl; return false; @@ -335,6 +337,12 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); +template +struct replace_with_mask { + KeyT replacement; + constexpr auto inline operator()(KeyT x, uint8_t mask) -> KeyT { return mask ? replacement : x; } +}; + template struct with_ref { template @@ -354,6 +362,19 @@ struct with_ref { rmm::device_uvector dists_d(spec.len * spec.batch_size, s); raft::random::RngState r(42); normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); + + if (spec.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(dists_d.size(), s); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, r, mask, spec.frac_infinities); + KeyT bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(dists_d.data(), dists_d.size()); + auto dists_out = make_device_vector_view(dists_d.data(), dists_d.size()); + raft::linalg::map(handle, dists_out, replace_with_mask{bound}, dists_in, mask_in); + } + update_host(dists.data(), dists_d.data(), dists_d.size(), s); s.synchronize(); } From 3b0b68d29d8e39189d86433483723c4cd23338fc Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 23 Aug 2023 08:45:52 +0200 Subject: [PATCH 5/7] Fix the last filter not ignoring the bound values --- cpp/include/raft/matrix/detail/select_radix.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index e454747b74..5efe6d7821 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -413,13 +413,15 @@ _RAFT_DEVICE void last_filter(const T* in_buf, { const auto kth_value_bits = counter->kth_value_bits; const int start_bit = calc_start_bit(pass); + const auto bound = select_min ? upper_bound() : lower_bound(); // changed in choose_bucket(); need to reload const IdxT needed_num_of_kth = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { - const T value = in_buf[i]; + const T value = in_buf[i]; + if (value == bound) { continue; } const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); @@ -475,6 +477,7 @@ __global__ void last_filter_kernel(const T* in, const IdxT needed_num_of_kth = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; + const auto bound = select_min ? upper_bound() : lower_bound(); auto f = [k, select_min, @@ -483,8 +486,10 @@ __global__ void last_filter_kernel(const T* in, p_out_cnt, p_out_back_cnt, in_idx_buf, + bound, out, out_idx](T value, IdxT i) { + if (value == bound) { return; } const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); @@ -621,7 +626,6 @@ __global__ void radix_kernel(const T* in, pass, early_stop, k); - if (current_len == 0) { return; } __threadfence(); bool isLastBlock = false; From e7ece2cec6fb9fa69275b20eb73e6fd32a9ef354 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 23 Aug 2023 09:37:12 +0200 Subject: [PATCH 6/7] Add benchmarks for the edge case of having many infinities --- cpp/bench/prims/matrix/select_k.cu | 64 +++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 1bff66cac4..992fda8a38 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,19 @@ namespace raft::matrix { using namespace raft::bench; // NOLINT +template +struct replace_with_mask { + KeyT replacement; + int64_t line_length; + int64_t spared_inputs; + constexpr auto inline operator()(int64_t offset, KeyT x, uint8_t mask) -> KeyT + { + auto i = offset % line_length; + // don't replace all the inputs, spare a few elements at the beginning of the input + return (mask && i >= spared_inputs) ? replacement : x; + } +}; + template struct selection : public fixture { explicit selection(const select::params& p) @@ -67,6 +81,21 @@ struct selection : public fixture { } } raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); + if (p.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(p.batch_size * p.len, stream); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, state, mask, p.frac_infinities); + KeyT bound = p.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(in_dists_.data(), in_dists_.size()); + auto dists_out = make_device_vector_view(in_dists_.data(), in_dists_.size()); + raft::linalg::map_offset(handle, + dists_out, + replace_with_mask{bound, int64_t(p.len), int64_t(p.k / 2)}, + dists_in, + mask_in); + } } void run_benchmark(::benchmark::State& state) override // NOLINT @@ -75,8 +104,12 @@ struct selection : public fixture { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } + if (params_.frac_infinities > 0) { label_stream << "#infs-" << params_.frac_infinities; } state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { + common::nvtx::range case_scope("%s - %s", state.name().c_str(), label_stream.str().c_str()); + int iter = 0; + loop_on_state(state, [&iter, this]() { + common::nvtx::range lap_scope("lap-", iter++); select::select_k_impl(handle, Algo, in_dists_.data(), @@ -149,6 +182,35 @@ const std::vector kInputs{ {10, 1000000, 64, true, false, true}, {10, 1000000, 128, true, false, true}, {10, 1000000, 256, true, false, true}, + + {10, 1000000, 1, true, false, false, true, 0.1}, + {10, 1000000, 16, true, false, false, true, 0.1}, + {10, 1000000, 64, true, false, false, true, 0.1}, + {10, 1000000, 128, true, false, false, true, 0.1}, + {10, 1000000, 256, true, false, false, true, 0.1}, + + {10, 1000000, 1, true, false, false, true, 0.9}, + {10, 1000000, 16, true, false, false, true, 0.9}, + {10, 1000000, 64, true, false, false, true, 0.9}, + {10, 1000000, 128, true, false, false, true, 0.9}, + {10, 1000000, 256, true, false, false, true, 0.9}, + {1000, 10000, 1, true, false, false, true, 0.9}, + {1000, 10000, 16, true, false, false, true, 0.9}, + {1000, 10000, 64, true, false, false, true, 0.9}, + {1000, 10000, 128, true, false, false, true, 0.9}, + {1000, 10000, 256, true, false, false, true, 0.9}, + + {10, 1000000, 1, true, false, false, true, 1.0}, + {10, 1000000, 16, true, false, false, true, 1.0}, + {10, 1000000, 64, true, false, false, true, 1.0}, + {10, 1000000, 128, true, false, false, true, 1.0}, + {10, 1000000, 256, true, false, false, true, 1.0}, + {1000, 10000, 1, true, false, false, true, 1.0}, + {1000, 10000, 16, true, false, false, true, 1.0}, + {1000, 10000, 64, true, false, false, true, 1.0}, + {1000, 10000, 128, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 0.999}, }; #define SELECTION_REGISTER(KeyT, IdxT, A) \ From 39a8a75c060604581484331363bba287f3b70b76 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 28 Aug 2023 13:51:59 +0200 Subject: [PATCH 7/7] Make the result testing more strict, allow setting input indices, and fix the non-set in_idx_buf in the zero-th pass of the one-block kernel --- .../raft/matrix/detail/select_radix.cuh | 2 +- cpp/test/matrix/select_k.cuh | 108 ++++++++++++++---- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 5efe6d7821..1f6357a563 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -746,7 +746,7 @@ _RAFT_HOST_DEVICE void set_buf_pointers(const T* in, { if (pass == 0) { in_buf = in; - in_idx_buf = nullptr; + in_idx_buf = in_idx; out_buf = nullptr; out_idx_buf = nullptr; } else if (pass == 1) { diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index 04c68cf021..eaabbb3357 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -49,14 +49,16 @@ auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector template struct io_simple { public: - bool not_supported = false; + bool not_supported = false; + std::optional algo = std::nullopt; io_simple(const select::params& spec, const std::vector& in_dists, + const std::optional>& in_ids, const std::vector& out_dists, const std::vector& out_ids) : in_dists_(in_dists), - in_ids_(gen_simple_ids(spec.batch_size, spec.len)), + in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), out_dists_(out_dists), out_ids_(out_ids) { @@ -78,12 +80,14 @@ template struct io_computed { public: bool not_supported = false; + select::Algo algo; io_computed(const select::params& spec, const select::Algo& algo, const std::vector& in_dists, const std::optional>& in_ids = std::nullopt) - : in_dists_(in_dists), + : algo(algo), + in_dists_(in_dists), in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), out_dists_(spec.batch_size * spec.k), out_ids_(spec.batch_size * spec.k) @@ -223,34 +227,62 @@ struct SelectK // NOLINT if (ref.not_supported || res.not_supported) { GTEST_SKIP(); } ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); - // If the dists (keys) are the same, different corresponding ids may end up in the selection due - // to non-deterministic nature of some implementations. - auto& in_ids = ref.get_in_ids(); - auto& in_dists = ref.get_in_dists(); - const auto bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); - auto compare_ids = [&in_ids, &in_dists, bound](const IdxT& i, const IdxT& j) { + // If the dists (keys) are the same, different corresponding ids may end up in the selection + // due to non-deterministic nature of some implementations. + auto compare_ids = [this](const IdxT& i, const IdxT& j) { if (i == j) return true; + auto& in_ids = ref.get_in_ids(); + auto& in_dists = ref.get_in_dists(); auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) - return false; + auto forgive_i = forgive_algo(ref.algo, i); + auto forgive_j = forgive_algo(res.algo, j); + // Some algorithms return invalid indices in special cases. + // This can be considered as TODO for us to fix. + if (static_cast(ix_i) >= in_ids.size()) return forgive_i; + if (static_cast(ix_j) >= in_ids.size()) return forgive_j; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; - // Some algorithms return invalid/zero indices for bound values. - if (dist_i == dist_j || dist_j == bound || dist_i == bound) return true; + if (dist_i == dist_j) return true; + const auto bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + if (forgive_i && dist_i == bound) return true; + if (forgive_j && dist_j == bound) return true; + // Otherwise really fail std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " << "res[" << ix_j << "] = " << dist_j << std::endl; return false; }; ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); } + + auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool + { + if (!algo.has_value()) { return false; } + switch (algo.value()) { + // not sure which algo this is. + case select::Algo::kPublicApi: return true; + // warp-sort-based algos currently return zero index for inf distances. + case select::Algo::kWarpAuto: + case select::Algo::kWarpImmediate: + case select::Algo::kWarpFiltered: + case select::Algo::kWarpDistributed: + case select::Algo::kWarpDistributedShm: return ix == 0; + // FAISS version returns a special invalid value: + case select::Algo::kFaissBlockSelect: return ix == std::numeric_limits::max(); + // Do not forgive by default + default: return false; + } + } }; template struct params_simple { - using io_t = io_simple; - using input_t = - std::tuple, std::vector, std::vector>; + using io_t = io_simple; + using input_t = std::tuple, + std::optional>, + std::vector, + std::vector>; using params_t = std::tuple; static auto read(params_t ps) -> Params @@ -261,15 +293,17 @@ struct params_simple { std::get<0>(ins), algo, io_simple( - std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins))); + std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins), std::get<4>(ins))); } }; +auto inf_f = std::numeric_limits::max(); auto inputs_simple_f = testing::Values( params_simple::input_t( {5, 5, 5, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), @@ -277,12 +311,14 @@ auto inputs_simple_f = testing::Values( {5, 5, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), params_simple::input_t( {5, 5, 5, true, false}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), @@ -290,20 +326,31 @@ auto inputs_simple_f = testing::Values( {5, 5, 3, true, false}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), params_simple::input_t( {5, 7, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, + std::nullopt, {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), - params_simple::input_t( - {1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), - params_simple::input_t( - {1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), - params_simple::input_t( - {1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), + params_simple::input_t({1, 7, 3, true, true}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {1.0, 1.0, 1.0}, + {3, 5, 6}), + params_simple::input_t({1, 7, 3, false, false}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {5.0, 4.0, 3.0}, + {2, 4, 1}), + params_simple::input_t({1, 7, 3, false, true}, + {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, + std::nullopt, + {9.0, 9.0, 9.0}, + {3, 5, 6}), params_simple::input_t( {1, 130, 5, false, true}, {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, @@ -311,6 +358,7 @@ auto inputs_simple_f = testing::Values( 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, {20, 19, 18, 17, 16}, {129, 0, 117, 116, 115}), params_simple::input_t( @@ -320,8 +368,20 @@ auto inputs_simple_f = testing::Values( 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, {20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, - {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105})); + {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105}), + params_simple::input_t( + select::params{1, 32, 31, true, true}, + {0, 1, 2, 3, inf_f, inf_f, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + std::optional{std::vector{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, + 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, + 9, 8, 7, 6, 75, 74, 3, 2, 1, 0}}, + {0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, inf_f}, + {31, 30, 29, 28, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, + 13, 12, 11, 10, 9, 8, 7, 6, 75, 74, 3, 2, 1, 0, 27})); using SimpleFloatInt = SelectK; TEST_P(SimpleFloatInt, Run) { run(); } // NOLINT