From ef245e866e07079d1de8ad040a5296e36c9850f3 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 17 Aug 2023 16:48:15 -0700 Subject: [PATCH 01/28] successful compilation --- cpp/include/raft/neighbors/cagra.cuh | 19 +- cpp/include/raft/neighbors/cagra_types.hpp | 6 + .../raft/neighbors/detail/nn_descent.cuh | 1309 +++++++++++++++++ cpp/include/raft/neighbors/nn_descent.cuh | 67 + .../raft/neighbors/nn_descent_types.hpp | 124 ++ 5 files changed, 1521 insertions(+), 4 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/nn_descent.cuh create mode 100644 cpp/include/raft/neighbors/nn_descent.cuh create mode 100644 cpp/include/raft/neighbors/nn_descent_types.hpp diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 6bb7beca55..785fb78868 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include namespace raft::neighbors::cagra { @@ -256,13 +257,23 @@ index build(raft::resources const& res, graph_degree = intermediate_degree; } - auto knn_graph = raft::make_host_matrix(dataset.extent(0), intermediate_degree); + auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); - build_knn_graph(res, dataset, knn_graph.view()); + if (params.build_algo == graph_build_algo::IVF_PQ) { + auto knn_graph = raft::make_host_matrix(dataset.extent(0), intermediate_degree); - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); + build_knn_graph(res, dataset, knn_graph.view()); - optimize(res, knn_graph.view(), cagra_graph.view()); + optimize(res, knn_graph.view(), cagra_graph.view()); + } + else { + nn_descent::index_params nn_descent_params; + nn_descent_params.intermediate_graph_degree = intermediate_degree; + nn_descent_params.graph_degree = graph_degree; + auto nn_descent_index = nn_descent::build(res, nn_descent_params, dataset); + + optimize(res, nn_descent_index.graph(), cagra_graph.view()); + } // Construct an index from dataset and optimized knn graph. return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 01d6a92235..fa293e9d5c 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -40,9 +40,15 @@ namespace raft::neighbors::cagra { * @{ */ +enum class graph_build_algo { + IVF_PQ, + NN_DESCENT +}; + struct index_params : ann::index_params { size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. size_t graph_degree = 64; // Degree of output graph. + graph_build_algo build_algo = graph_build_algo::IVF_PQ; // ANN algorithm to build knn graph }; enum class search_algo { diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh new file mode 100644 index 0000000000..3315b2e09a --- /dev/null +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -0,0 +1,1309 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "../nn_descent_types.hpp" + +#include +#include +#include + +namespace raft::neighbors::nn_descent::detail { +using DistData_t = float; +constexpr int DEGREE_ON_DEVICE{32}; +constexpr int SEGMENT_SIZE{32}; +constexpr int counter_interval{100}; +template +struct InternalID_t; + +// InternalID_t uses 1 bit for marking (new or old). +template <> +class InternalID_t { + private: + using Index_t = int; + Index_t id_{std::numeric_limits::max()}; + + public: + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ void mark_old() { + if (id_ >= 0) id_ = -id_ - 1; + } + __host__ __device__ bool operator==(const InternalID_t& other) const { + return id() == other.id(); + } +}; + +template +struct ResultItem; + +template <> +class ResultItem { + private: + using Index_t = int; + Index_t id_; + DistData_t dist_; + + public: + __host__ __device__ ResultItem() + : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; + __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) + : id_(id_with_flag), dist_(dist){}; + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ DistData_t& dist() { return dist_; } + + __host__ __device__ void mark_old() { + if (id_ >= 0) id_ = -id_ - 1; + } + + __host__ __device__ bool operator<(const ResultItem& other) const { + if (dist_ == other.dist_) return id() < other.id(); + return dist_ < other.dist_; + } + __host__ __device__ bool operator==(const ResultItem& other) const { + return id() == other.id(); + } + __host__ __device__ bool operator>=(const ResultItem& other) const { + return !(*this < other); + } + __host__ __device__ bool operator<=(const ResultItem& other) const { + return (*this == other) || (*this < other); + } + __host__ __device__ bool operator>(const ResultItem& other) const { + return !(*this <= other); + } + __host__ __device__ bool operator!=(const ResultItem& other) const { + return !(*this == other); + } +}; + +constexpr __host__ __device__ size_t div_up(const size_t a, const size_t b) { + return a / b + (a % b != 0); +} + +constexpr int to_multiple_of_32(int number) { return div_up(number, 32) * 32; } + +constexpr int WARP_SIZE = 32; +constexpr unsigned int FULL_MASK = 0xffffffff; + +template +int get_batch_size(const int it_now, const T nrow, const int batch_size) { + int it_total = div_up(nrow, batch_size); + return (it_now == it_total - 1) ? 
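+  // Illustrative numbers (not from this patch): with nrow = 1000 and
+  // batch_size = 300, it_total = div_up(1000, 300) = 4 and successive calls
+  // return 300, 300, 300, 100; only the last batch shrinks to the remainder.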
nrow - it_now * batch_size : batch_size; +} + +// for avoiding bank conflict +template +constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) { + // all "4"s are for alignment + if constexpr (std::is_same::value) { + ndim = div_up(ndim, 4) * 4; + return ndim + (ndim % 32 == 0) * 4; + } +} + +template +__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) { + ResultItem y; + y.dist() = __shfl_xor_sync(FULL_MASK, x.dist(), mask, WARP_SIZE); + y.id_with_flag() = __shfl_xor_sync(FULL_MASK, x.id_with_flag(), mask, WARP_SIZE); + return x < y == dir ? y : x; +} + +__device__ __forceinline__ int xor_swap(int x, int mask, int dir) { + int y = __shfl_xor_sync(FULL_MASK, x, mask, WARP_SIZE); + return x < y == dir ? y : x; +} + +__device__ __forceinline__ uint bfe(uint lane_id, uint pos) { + uint res; + asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); + return res; +} + +// https://en.wikipedia.org/wiki/Xorshift#xorshift* +__host__ __device__ __forceinline__ uint64_t xorshift64(uint64_t x) { + x ^= x >> 12; + x ^= x << 25; + x ^= x >> 27; + return x * 0x2545F4914F6CDD1DULL; +} + +template +__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) { + static_assert(WARP_SIZE == 32); + auto& element = *element_ptr; + element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x10, bfe(lane_id, 4)); + element = xor_swap(element, 0x08, bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 0)); + return; +} + +enum class Metric_t { + METRIC_INNER_PRODUCT = 0, + METRIC_L2 = 1, +}; + +struct BuildConfig { + size_t max_dataset_size; + size_t dataset_dim; + size_t node_degree{64}; + size_t internal_node_degree{0}; + // If internal_node_degree == 0, the value of node_degree will be assigned to it + size_t max_iterations{50}; + float termination_threshold{0.0001}; + Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT}; +}; + +template +class BloomFilter { + public: + BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) + : nrow_(nrow), + num_sets_per_list_(num_sets_per_list), + num_hashs_(num_hashs), + bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) {} + + void add(size_t list_id, Index_t key) { + if (is_cleared) { + is_cleared = false; + } + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + } + } + + bool check(size_t list_id, Index_t key) { + bool is_present = true; + uint32_t hash = hash_0(key); + size_t 
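+    // check() probes the same num_hashs_ bit positions that add() set, using
+    // the classic double-hashing scheme h_i(key) = hash_0(key) + i * hash_1(key)
+    // within this list's 512-bit set selected by key % num_sets_per_list_.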
global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + + if (!is_present) return false; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + if (!is_present) return false; + } + return true; + } + + void clear() { + if (is_cleared) return; +#pragma omp parallel for + for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { + bitsets_[i] = 0; + } + is_cleared = true; + } + + private: + uint32_t hash_0(uint32_t value) { + value *= 1103515245; + value += 12345; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + uint32_t hash_1(uint32_t value) { + value *= 1664525; + value += 1013904223; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + static constexpr int num_bits_per_set_ = 512; + bool is_cleared{true}; + std::vector bitsets_; + size_t nrow_; + size_t num_sets_per_list_; + size_t num_hashs_; +}; + +template +struct GnndGraph { + static constexpr int segment_size = 32; + InternalID_t* h_graph; + DistData_t* h_dists; + + size_t nrow; + size_t node_degree; + int num_samples; + int num_segments; + + Index_t* h_graph_new; + int2* h_list_sizes_new; + + Index_t* h_graph_old; + int2* h_list_sizes_old; + BloomFilter bloom_filter; + + GnndGraph(const GnndGraph&) = delete; + GnndGraph& operator=(const GnndGraph&) = delete; + GnndGraph(const size_t nrow, const size_t node_degree, const size_t internal_node_degree, + const size_t num_samples); + void init_random_graph(); + // Use Bloom filter to sample "new" neighbors for local joining + void sample_graph_new(InternalID_t* new_neighbors, const size_t width); + void sample_graph(bool sample_new); + void update_graph(const InternalID_t* new_neighbors, const DistData_t* new_dists, + const size_t width, std::atomic& update_counter); + void sort_lists(); + void clear(); + void dealloc(); + ~GnndGraph(); +}; + +template +class GNND { + public: + GNND(const BuildConfig& build_config); + GNND(const GNND&) = delete; + GNND& operator=(const GNND&) = delete; + + // Use delete[] to deallocate the returned graph + void build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream = 0); + void dealloc(); + ~GNND(); + using ID_t = InternalID_t; + + private: + void add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, Index_t* d_rev_graph_ptr, + int2* list_sizes, cudaStream_t stream = 0); + void local_join(cudaStream_t stream = 0); + void alloc_workspace(); + + BuildConfig build_config_; + GnndGraph graph_; + std::atomic update_counter_; + + __half* d_data_; + DistData_t* l2_norms_; + + ID_t* graph_buffer_; + DistData_t* dists_buffer_; + ID_t* graph_host_buffer_; + DistData_t* dists_host_buffer_; + + int* d_locks_; + + Index_t* h_rev_graph_new_; + // int2.x is the number of forward edges, int2.y is the number of reverse edges + int2* d_list_sizes_new_; + + Index_t* h_graph_old_; + Index_t* h_rev_graph_old_; + int2* d_list_sizes_old_; + + Index_t nrow_; + const int ndim_; +}; + +constexpr int TILE_ROW_WIDTH = 64; +constexpr int TILE_COL_WIDTH = 128; + +constexpr int NUM_SAMPLES = 32; +// For now, the max. number of samples is 32, so the sample cache size is fixed +// to 64 (32 * 2). 
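+// Each local join therefore sees up to NUM_SAMPLES forward neighbors plus up
+// to NUM_SAMPLES reverse neighbors per node, which is why MAX_NUM_BI_SAMPLES
+// below is 2 * NUM_SAMPLES.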
+constexpr int MAX_NUM_BI_SAMPLES = 64; +constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); +constexpr int BLOCK_SIZE = 512; +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; + +template +__device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec, + const int load_dims, const int padding_dims, + const int lane_id) { + if constexpr (std::is_same_v or std::is_same_v or std::is_same_v) { + constexpr int num_load_elems_per_warp = WARP_SIZE; + for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + if constexpr (std::is_same::value) { + if ((size_t)vec_buffer % sizeof(float2) == 0 && load_dims % 4 == 0 && + padding_dims % 4 == 0) { + constexpr int num_load_elems_per_warp = WARP_SIZE * 4; +#pragma unroll + for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; + if (idx_in_vec + 4 <= load_dims) { + *(float2 *)(vec_buffer + idx_in_vec) = *(float2 *)(d_vec + idx_in_vec); + } else if (idx_in_vec + 4 <= padding_dims) { + *(float2 *)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); + } + } + } else { + constexpr int num_load_elems_per_warp = WARP_SIZE; + for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + } +} + +template +__global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_data, Index_t nrow, + int dim, DistData_t *l2_norms) { + extern __shared__ char buffer[]; + __shared__ float l2_norm; + Data_t *s_vec = (Data_t *)buffer; + size_t list_id = blockIdx.x; + + load_vec(s_vec, input_data + list_id * dim, dim, dim, threadIdx.x % WARP_SIZE); + if (threadIdx.x == 0) { + l2_norm = 0; + } + __syncthreads(); + int lane_id = threadIdx.x % WARP_SIZE; + for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { + int idx = step * WARP_SIZE + lane_id; + float part_dist = 0; + if (idx < dim) { + part_dist = s_vec[idx]; + part_dist = part_dist * part_dist; + } + __syncwarp(); + for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(FULL_MASK, part_dist, offset); + } + if (lane_id == 0) { + l2_norm += part_dist; + } + __syncwarp(); + } + + for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { + int idx = step * WARP_SIZE + threadIdx.x; + if (idx < dim) { + if (l2_norms == nullptr) { + output_data[list_id * dim + idx] = (float)input_data[list_id * dim + idx] / sqrt(l2_norm); + } else { + output_data[list_id * dim + idx] = input_data[list_id * dim + idx]; + if (idx == 0) { + l2_norms[list_id] = l2_norm; + } + } + } + } +} + +template +__global__ void add_rev_edges_kernel(const Index_t *graph, Index_t *rev_graph, int num_samples, + int2 *list_sizes) { + size_t list_id = blockIdx.x; + int2 list_size = list_sizes[list_id]; + + for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { + // each node has same number (num_samples) of forward and reverse edges + size_t rev_list_id = graph[list_id * num_samples + idx]; + // there are already num_samples forward edges + int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); + if (idx_in_rev_list >= 
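+      // The atomicAdd above can race the counter past num_samples; in that
+      // case the branch below clamps it back to num_samples and this reverse
+      // edge is dropped instead of being written out of bounds.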
num_samples) { + atomicExch(&list_sizes[rev_list_id].y, num_samples); + } else { + rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; + } + } +} + +template > +__device__ void insert_to_global_graph(ResultItem elem, size_t list_id, ID_t *graph, + DistData_t *dists, int node_degree, int *locks, + bool new_new = true) { + int tx = threadIdx.x; + int lane_id = tx % WARP_SIZE; + size_t global_idx_base = list_id * node_degree; + if (elem.id() == list_id) return; + + const int num_segments = div_up(node_degree, WARP_SIZE); + + int loop_flag = 0; + do { + int segment_id = elem.id() % num_segments; + if (lane_id == 0) { + loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; + } + + loop_flag = __shfl_sync(FULL_MASK, loop_flag, 0); + + if (loop_flag == 1) { + ResultItem knn_list_frag; + int local_idx = segment_id * WARP_SIZE + lane_id; + size_t global_idx = global_idx_base + local_idx; + if (local_idx < node_degree) { + knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); + knn_list_frag.dist() = dists[global_idx]; + } + + int pos_to_insert = -1; + ResultItem prev_elem; + + prev_elem.id_with_flag() = __shfl_up_sync(FULL_MASK, knn_list_frag.id_with_flag(), 1); + prev_elem.dist() = __shfl_up_sync(FULL_MASK, knn_list_frag.dist(), 1); + + if (lane_id == 0) { + prev_elem = ResultItem{std::numeric_limits::min(), + std::numeric_limits::lowest()}; + } + if (elem > prev_elem && elem < knn_list_frag) { + pos_to_insert = segment_id * WARP_SIZE + lane_id; + } else if (elem == prev_elem || elem == knn_list_frag) { + pos_to_insert = -2; + } + uint mask = __ballot_sync(FULL_MASK, pos_to_insert >= 0); + if (mask) { + uint set_lane_id = __fns(mask, 0, 1); + pos_to_insert = __shfl_sync(FULL_MASK, pos_to_insert, set_lane_id); + } + + if (pos_to_insert >= 0) { + int local_idx = segment_id * WARP_SIZE + lane_id; + if (local_idx > pos_to_insert) { + local_idx++; + } else if (local_idx == pos_to_insert) { + graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); + dists[global_idx_base + local_idx] = elem.dist(); + local_idx++; + } + size_t global_pos = global_idx_base + local_idx; + if (local_idx < (segment_id + 1) * WARP_SIZE && local_idx < node_degree) { + graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); + dists[global_pos] = knn_list_frag.dist(); + } + } + __threadfence(); + if (loop_flag && lane_id == 0) { + atomicExch(&locks[list_id * num_segments + segment_id], 0); + } + } + } while (!loop_flag); +} + +template +__device__ ResultItem get_min_item(const Index_t id, const int idx_in_list, + const Index_t *neighbs, const DistData_t *distances, + const bool find_in_row = true) { + int lane_id = threadIdx.x % WARP_SIZE; + + static_assert(MAX_NUM_BI_SAMPLES == 64); + int idx[MAX_NUM_BI_SAMPLES / WARP_SIZE]; + float dist[MAX_NUM_BI_SAMPLES / WARP_SIZE] = {std::numeric_limits::max(), + std::numeric_limits::max()}; + idx[0] = lane_id; + idx[1] = WARP_SIZE + lane_id; + + if (neighbs[idx[0]] != id) { + dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] + : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (neighbs[idx[1]] != id) { + dist[1] = find_in_row + ? 
distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + WARP_SIZE + lane_id] + : distances[idx_in_list + (WARP_SIZE + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (dist[1] < dist[0]) { + dist[0] = dist[1]; + idx[0] = idx[1]; + } + __syncwarp(); + for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { + float other_idx = __shfl_down_sync(FULL_MASK, idx[0], offset); + float other_dist = __shfl_down_sync(FULL_MASK, dist[0], offset); + if (other_dist < dist[0]) { + dist[0] = other_dist; + idx[0] = other_idx; + } + } + + ResultItem result; + result.dist() = __shfl_sync(FULL_MASK, dist[0], 0); + result.id_with_flag() = neighbs[__shfl_sync(FULL_MASK, idx[0], 0)]; + return result; +} + +template +__device__ __forceinline__ void remove_duplicates(T *list_a, int list_a_size, T *list_b, + int list_b_size, int &unique_counter, + int execute_warp_id) { + static_assert(WARP_SIZE == 32); + if (!(threadIdx.x >= execute_warp_id * WARP_SIZE && + threadIdx.x < execute_warp_id * WARP_SIZE + WARP_SIZE)) { + return; + } + int lane_id = threadIdx.x % WARP_SIZE; + T elem = std::numeric_limits::max(); + if (lane_id < list_a_size) { + elem = list_a[lane_id]; + } + warp_bitonic_sort(&elem, lane_id); + + if (elem != std::numeric_limits::max()) { + list_a[lane_id] = elem; + } + + T elem_b = std::numeric_limits::max(); + + if (lane_id < list_b_size) { + elem_b = list_b[lane_id]; + } + __syncwarp(); + + int idx_l = 0; + int idx_r = list_a_size; + bool existed = false; + while (idx_l < idx_r) { + int idx = (idx_l + idx_r) / 2; + int elem = list_a[idx]; + if (elem == elem_b) { + existed = true; + break; + } + if (elem_b > elem) { + idx_l = idx + 1; + } else { + idx_r = idx; + } + } + if (!existed && elem_b != std::numeric_limits::max()) { + int idx = atomicAdd(&unique_counter, 1); + list_a[list_a_size + idx] = elem_b; + } +} + +template > +__global__ void __launch_bounds__(BLOCK_SIZE, 4) + local_join_kernel(const Index_t *graph_new, const Index_t *rev_graph_new, const int2 *sizes_new, + const Index_t *graph_old, const Index_t *rev_graph_old, const int2 *sizes_old, + const int width, const __half *data, const int data_dim, ID_t *graph, + DistData_t *dists, int graph_width, int *locks, DistData_t *l2_norms) { + using namespace nvcuda; + __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; + + constexpr int APAD = 8; + constexpr int BPAD = 8; + __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors + __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors + static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= + sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); + // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov + float *s_distances = (float *)&s_ov[0][0]; + int *s_unique_counter = (int *)&s_ov[0][0]; + + if (threadIdx.x == 0) { + s_unique_counter[0] = 0; + s_unique_counter[1] = 0; + } + + Index_t *new_neighbors = s_list; + Index_t *old_neighbors = s_list + MAX_NUM_BI_SAMPLES; + + size_t list_id = blockIdx.x; + int2 list_new_size2 = sizes_new[list_id]; + int list_new_size = list_new_size2.x + list_new_size2.y; + int2 list_old_size2 = sizes_old[list_id]; + int list_old_size = list_old_size2.x + list_old_size2.y; + + if (!list_new_size) return; + int tx = threadIdx.x; + + if (tx < list_new_size2.x) { + new_neighbors[tx] = graph_new[list_id * width + tx]; + } else if (tx >= list_new_size2.x && tx < list_new_size) { + new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; + } 
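+  // new_neighbors now holds list_new_size2.x forward edges followed by
+  // list_new_size2.y reverse edges; old_neighbors is filled the same way just
+  // below, and both lists are deduplicated before the tiled distance
+  // computation.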
+ + if (tx < list_old_size2.x) { + old_neighbors[tx] = graph_old[list_id * width + tx]; + } else if (tx >= list_old_size2.x && tx < list_old_size) { + old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; + } + + __syncthreads(); + + remove_duplicates(new_neighbors, list_new_size2.x, new_neighbors + list_new_size2.x, + list_new_size2.y, s_unique_counter[0], 0); + + remove_duplicates(old_neighbors, list_old_size2.x, old_neighbors + list_old_size2.x, + list_old_size2.y, s_unique_counter[1], 1); + __syncthreads(); + list_new_size = list_new_size2.x + s_unique_counter[0]; + list_old_size = list_old_size2.x + s_unique_counter[1]; + + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + constexpr int num_warps = BLOCK_SIZE / WARP_SIZE; + + int warp_id_y = warp_id / 4; + int warp_id_x = warp_id % 4; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, + TILE_COL_WIDTH, lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, + TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, + TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < div_up(list_new_size, num_warps); step++) { + int idx_in_list = step * num_warps + tx / WARP_SIZE; + if (idx_in_list >= list_new_size) continue; + auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks, + true); + } + } + + if (!list_old_size) return; + + __syncthreads(); + + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + ? 
data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; + if (TILE_COL_WIDTH < data_dim) { +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, + TILE_COL_WIDTH, lane_id); + } + } + } +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_old_size) { + size_t neighbor_id = old_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_ov[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, + TILE_COL_WIDTH, lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, + TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, + TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < div_up(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { + int idx_in_list = step * num_warps + tx / WARP_SIZE; + if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; + if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && + idx_in_list < MAX_NUM_BI_SAMPLES * 2) + continue; + ResultItem min_elem{std::numeric_limits::max(), + std::numeric_limits::max()}; + if (idx_in_list < MAX_NUM_BI_SAMPLES) { + auto temp_min_item = + get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); + if (temp_min_item.dist() < min_elem.dist()) { + min_elem = temp_min_item; + } + } else { + auto temp_min_item = get_min_item(s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, + new_neighbors, s_distances, false); + if (temp_min_item.dist() < min_elem.dist()) { + min_elem = temp_min_item; + } + } + + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks, + false); + } + } +} + +namespace { +template +int insert_to_ordered_list(InternalID_t* list, DistData_t* dist_list, const int width, + const InternalID_t neighb_id, const DistData_t dist) { + if (dist > dist_list[width - 1]) { + return width; + } + + int idx_insert = width; + for (int i = 0; i < width; i++) { + if (list[i].id() == neighb_id.id()) { + return width; + } + if (dist_list[i] > dist) { + idx_insert = i; + break; + } + } + if (idx_insert == width) return idx_insert; + + memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); + memmove(dist_list + idx_insert + 1, dist_list + idx_insert, + sizeof(*dist_list) * (width - idx_insert - 1)); + + 
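+  // The two memmoves above open a one-entry gap at idx_insert. Illustrative
+  // example (not from this patch): with width = 4, dist_list = {1, 3, 5, 7}
+  // and an incoming dist of 4, idx_insert == 2 and, after the writes below,
+  // dist_list reads {1, 3, 4, 5}; the old tail entry 7 is evicted.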
list[idx_insert] = neighb_id; + dist_list[idx_insert] = dist; + return idx_insert; +}; + +} // namespace + +template +GnndGraph::GnndGraph(const size_t nrow, const size_t node_degree, + const size_t internal_node_degree, const size_t num_samples) + : nrow(nrow), + node_degree(node_degree), + num_samples(num_samples), + bloom_filter(nrow, internal_node_degree / segment_size, 3) { + // node_degree must be a multiple of segment_size; + assert(node_degree % segment_size == 0); + assert(internal_node_degree % segment_size == 0); + + num_segments = node_degree / segment_size; + // To save the CPU memory, graph should be allocated by external function + h_graph = nullptr; + h_dists = new DistData_t[nrow * node_degree]; + + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_new, sizeof(*h_graph_new) * nrow * num_samples)); + RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_new, sizeof(*h_list_sizes_new) * nrow)); + + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old, sizeof(*h_graph_old) * nrow * num_samples)); + RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_old, sizeof(*h_list_sizes_old) * nrow)); +} + +// This is the only operation on the CPU that cannot be overlapped. +// So it should be as fast as possible. +template +void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, + const size_t width) { +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + auto list_new = h_graph_new + i * num_samples; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j].id(); + if ((size_t)new_neighb_id >= nrow) break; + if (bloom_filter.check(i, new_neighb_id)) { + continue; + } + bloom_filter.add(i, new_neighb_id); + new_neighbors[i * width + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = new_neighb_id; + if (h_list_sizes_new[i].x == num_samples) break; + } + } +} + +template +void GnndGraph::init_random_graph() { + // random sequence (range: 0~nrow) + std::vector rand_seq(nrow); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); + +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < NUM_SAMPLES; j++) { + size_t idx = i * NUM_SAMPLES + j; + Index_t id = rand_seq[idx % nrow]; + if ((size_t)id == i) { + id = rand_seq[(idx + NUM_SAMPLES) % nrow]; + } + h_graph[i * node_degree + j].id_with_flag() = id; + } + for (size_t j = NUM_SAMPLES; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = std::numeric_limits::max(); + } + for (size_t j = 0; j < node_degree; j++) { + h_dists[i * node_degree + j] = std::numeric_limits::max(); + } + } +} + +template +void GnndGraph::sample_graph(bool sample_new) { +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + h_list_sizes_old[i].x = 0; + h_list_sizes_old[i].y = 0; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + auto list = h_graph + i * node_degree; + auto list_old = h_graph_old + i * num_samples; + auto list_new = h_graph_new + i * num_samples; + for (int j = 0; j < segment_size; j++) { + for (int k = 0; k < num_segments; k++) { + auto neighbor = list[k * segment_size + j]; + if ((size_t)neighbor.id() >= nrow) continue; + if (!neighbor.is_new()) { + if (h_list_sizes_old[i].x < num_samples) { + list_old[h_list_sizes_old[i].x++] = neighbor.id(); + } + } else if (sample_new) { + if (h_list_sizes_new[i].x < num_samples) { + list[k * segment_size + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = neighbor.id(); + } + } + if (h_list_sizes_old[i].x == 
num_samples && h_list_sizes_new[i].x == num_samples) { + break; + } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { + break; + } + } + } +} + +template +void GnndGraph::update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, const size_t width, + std::atomic& update_counter) { +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j]; + auto new_dist = new_dists[i * width + j]; + if (new_dist == std::numeric_limits::max()) break; + if ((size_t)new_neighb_id.id() == i) continue; + int idx_seg = new_neighb_id.id() % num_segments; + auto list = h_graph + i * node_degree + idx_seg * segment_size; + auto dist_list = h_dists + i * node_degree + idx_seg * segment_size; + int insert_pos = + insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); + if (i % counter_interval == 0 && insert_pos != segment_size) { + update_counter++; + } + } + } +} + +template +void GnndGraph::sort_lists() { +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + std::vector> new_list; + for (size_t j = 0; j < node_degree; j++) { + new_list.emplace_back(h_dists[i * node_degree + j], h_graph[i * node_degree + j].id()); + } + std::sort(new_list.begin(), new_list.end()); + for (size_t j = 0; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; + h_dists[i * node_degree + j] = new_list[j].first; + } + } +} + +template +void GnndGraph::clear() { + bloom_filter.clear(); +} + +template +void GnndGraph::dealloc() { + delete[] h_dists; + RAFT_CUDA_TRY(cudaFreeHost(h_graph_new)); + RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_new)); + RAFT_CUDA_TRY(cudaFreeHost(h_graph_old)); + RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_old)); + assert(h_graph == nullptr); +} + +template +GnndGraph::~GnndGraph() { + +} + +template +GNND::GNND(const BuildConfig& build_config) + : build_config_(build_config), + graph_(build_config.max_dataset_size, + to_multiple_of_32(build_config.node_degree * + (build_config.node_degree <= 32 ? 1.0 : 1.3)), + to_multiple_of_32(build_config.internal_node_degree ? 
build_config.internal_node_degree + : build_config.node_degree), + NUM_SAMPLES), + nrow_(build_config.max_dataset_size), + ndim_(build_config.dataset_dim) { + static_assert(NUM_SAMPLES <= 32); + alloc_workspace(); +}; + +template +void GNND::alloc_workspace() { + RAFT_CUDA_TRY(cudaMalloc(&d_data_, sizeof(__half) * nrow_ * ndim_)); + RAFT_CUDA_TRY(cudaMallocHost(&graph_host_buffer_, + sizeof(*graph_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + RAFT_CUDA_TRY(cudaMallocHost(&dists_host_buffer_, + sizeof(*dists_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + RAFT_CUDA_TRY( + cudaMalloc(&dists_buffer_, sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + thrust::fill(thrust::device, dists_buffer_, dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + RAFT_CUDA_TRY( + cudaMalloc(&graph_buffer_, sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + thrust::fill(thrust::device, reinterpret_cast(graph_buffer_), + reinterpret_cast(graph_buffer_) + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaMalloc(&d_locks_, sizeof(*d_locks_) * nrow_)); + RAFT_CUDA_TRY( + cudaMallocHost(&h_rev_graph_new_, sizeof(*h_rev_graph_new_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old_, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY( + cudaMallocHost(&h_rev_graph_old_, sizeof(*h_rev_graph_old_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_new_, sizeof(*d_list_sizes_new_) * nrow_)); + RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_old_, sizeof(*d_list_sizes_old_) * nrow_)); + + if (build_config_.metric_type == Metric_t::METRIC_L2) { + RAFT_CUDA_TRY(cudaMalloc(&l2_norms_, sizeof(*l2_norms_) * nrow_)); + } else { + l2_norms_ = nullptr; + } +} + +template +void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, int2* list_sizes, + cudaStream_t stream) { + add_rev_edges_kernel<<>>(graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, + list_sizes); + RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, d_rev_graph_ptr, + sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES, + cudaMemcpyDeviceToHost, stream)); +} + +template +void GNND::local_join(cudaStream_t stream) { + thrust::fill(thrust::device.on(stream), dists_buffer_, + dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + local_join_kernel<<>>( + graph_.h_graph_new, h_rev_graph_new_, d_list_sizes_new_, h_graph_old_, h_rev_graph_old_, + d_list_sizes_old_, NUM_SAMPLES, d_data_, ndim_, graph_buffer_, dists_buffer_, + DEGREE_ON_DEVICE, d_locks_, l2_norms_); +} + +template +void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream) { + nrow_ = nrow; + graph_.h_graph = (InternalID_t*)output_graph; + + cudaPointerAttributes data_ptr_attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); + if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { + RAFT_CUDA_TRY(cudaHostRegister(const_cast*>(data), sizeof(Data_t) * nrow * build_config_.dataset_dim, + cudaHostRegisterDefault)); + } + + preprocess_data_kernel<<< + nrow_, WARP_SIZE, sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, + stream>>>(data, d_data_, nrow_, build_config_.dataset_dim, l2_norms_); + + thrust::fill(thrust::device.on(stream), (Index_t*)graph_buffer_, + (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + + graph_.clear(); + graph_.init_random_graph(); + graph_.sample_graph(true); + + auto update_and_sample = [&](bool 
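+  // (update_and_sample runs on a separate host thread each iteration, so the
+  //  CPU-side graph update and sampling overlap the GPU local join launched
+  //  below)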
update_graph) {
+    if (update_graph) {
+      update_counter_ = 0;
+      graph_.update_graph(graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE,
+                          update_counter_);
+      if (update_counter_ < build_config_.termination_threshold * nrow_ *
+                              build_config_.dataset_dim / counter_interval) {
+        update_counter_ = -1;
+      }
+    }
+    graph_.sample_graph(false);
+  };
+
+  for (size_t it = 0; it < build_config_.max_iterations; it++) {
+    RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, graph_.h_list_sizes_new,
+                                  sizeof(*d_list_sizes_new_) * nrow_,
+                                  cudaMemcpyHostToDevice, stream));
+    RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, graph_.h_graph_old,
+                                  sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES,
+                                  cudaMemcpyHostToHost, stream));
+    RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, graph_.h_list_sizes_old,
+                                  sizeof(*d_list_sizes_old_) * nrow_,
+                                  cudaMemcpyHostToDevice, stream));
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+
+    std::thread update_and_sample_thread(update_and_sample, it);
+
+    std::cout << "# GNND iteration: " << it + 1 << "/" << build_config_.max_iterations << "\r";
+    std::fflush(stdout);
+
+    // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it
+    // contains some information for local_join.
+    static_assert(DEGREE_ON_DEVICE * sizeof(*dists_buffer_) >=
+                  NUM_SAMPLES * sizeof(*graph_buffer_));
+    add_reverse_edges(graph_.h_graph_new, h_rev_graph_new_, (Index_t*)dists_buffer_,
+                      d_list_sizes_new_, stream);
+    add_reverse_edges(h_graph_old_, h_rev_graph_old_, (Index_t*)dists_buffer_,
+                      d_list_sizes_old_, stream);
+
+    local_join(stream);
+
+    update_and_sample_thread.join();
+
+    if (update_counter_ == -1) {
+      break;
+    }
+
+    RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, graph_buffer_,
+                                  sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE,
+                                  cudaMemcpyDeviceToHost, stream));
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+    RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, dists_buffer_,
+                                  sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE,
+                                  cudaMemcpyDeviceToHost, stream));
+    graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE);
+  }
+
+  graph_.update_graph(graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, update_counter_);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+
+  graph_.sort_lists();
+
+  // Reuse graph_.h_dists as the buffer for shrinking the lists in the graph
+  static_assert(sizeof(decltype(*graph_.h_dists)) >= sizeof(Index_t));
+  Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists;
+
+#pragma omp parallel for
+  for (size_t i = 0; i < (size_t)nrow_; i++) {
+    for (size_t j = 0; j < build_config_.node_degree; j++) {
+      size_t idx = i * graph_.node_degree + j;
+      Index_t id = graph_.h_graph[idx].id();
+      if (id < nrow_) {
+        graph_shrink_buffer[i * build_config_.node_degree + j] = id;
+      } else {
+        graph_shrink_buffer[i * build_config_.node_degree + j] = xorshift64(idx) % nrow_;
+      }
+    }
+  }
+  graph_.h_graph = nullptr;
+
+#pragma omp parallel for
+  for (size_t i = 0; i < (size_t)nrow_; i++) {
+    for (size_t j = 0; j < build_config_.node_degree; j++) {
+      output_graph[i * build_config_.node_degree + j] =
+        graph_shrink_buffer[i * build_config_.node_degree + j];
+    }
+  }
+
+  graph_.dealloc();
+}
+
+template
+void GNND::dealloc() {
+  RAFT_CUDA_TRY(cudaFree(d_data_));
+  RAFT_CUDA_TRY(cudaFreeHost(graph_host_buffer_));
+  RAFT_CUDA_TRY(cudaFreeHost(dists_host_buffer_));
+  RAFT_CUDA_TRY(cudaFree(dists_buffer_));
+  RAFT_CUDA_TRY(cudaFree(graph_buffer_));
+  RAFT_CUDA_TRY(cudaFree(d_locks_));
+  RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_new_));
RAFT_CUDA_TRY(cudaFreeHost(h_graph_old_)); + RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_old_)); + RAFT_CUDA_TRY(cudaFree(d_list_sizes_new_)); + RAFT_CUDA_TRY(cudaFree(d_list_sizes_old_)); + RAFT_CUDA_TRY(cudaFree(l2_norms_)); +} + +template +GNND::~GNND() { +} + +template , memory_type::host>> +index build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) { + RAFT_EXPECTS(dataset.size() >= std::numeric_limits::max() - 1, + "The dataset_size for GNND should be less than %d", + std::numeric_limits::max() - 1); + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + if (intermediate_degree >= static_cast(dataset.extent(0))) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = dataset.extent(0) - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + index idx{res, dataset.extent(0), static_cast(graph_degree)}; + + BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), + .dataset_dim = static_cast(dataset.extent(1)), + .node_degree = graph_degree, + .internal_node_degree = intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold, + .metric_type = Metric_t::METRIC_L2}; + + GNND nnd(build_config); + nnd.build(dataset.data_handle(), dataset.extent(0), idx.int_graph().data_handle(), resource::get_cuda_stream(res)); + nnd.dealloc(); + + return idx; +} + +} diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh new file mode 100644 index 0000000000..25fe9675bc --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "detail/nn_descent.cuh" + +#include +#include + +namespace raft::neighbors::nn_descent { + +/** + * @defgroup nn-descent CUDA ANN Graph-based gradient descent nearest neighbor + * @{ + */ + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * @tparam T + * @tparam IdxT + * @param res raft::resources + * @param params nn_descent::index_params + * @param dataset raft::device_matrix_view + * @return index + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) { + return detail::build(res, params, dataset); +} + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * @tparam T + * @tparam IdxT + * @param res raft::resources + * @param params nn_descent::index_params + * @param dataset raft::host_matrix_view + * @return index + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset) { + return detail::build(res, params, dataset); +} + +/** @} */ // end group cagra + +} \ No newline at end of file diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp new file mode 100644 index 0000000000..a793857773 --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::nn_descent { +/** + * @ingroup nn_descent + * @{ + */ + +struct index_params : ann::index_params { + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t graph_degree = 64; // Degree of output graph. + size_t max_iterations = 50; // Number of nn-descent iterations. + float termination_threshold = 0.0001; // Termination threshold of nn-descent. +}; + +/** + * @brief nn-descent Index + * + * @tparam IdxT dtype to be used for constructing knn-graph + */ +template +struct index : ann::index { +public: + /** + * @brief Construct a new index object + * + * This constructor creates an nn-descent index which is a knn-graph in host memory. + * The type of the knn-graph is a dense raft::host_matrix and dimensions are + * (n_rows, n_cols). + * + * @param res raft::resources + * @param n_rows number of rows in knn-graph + * @param n_cols number of cols in knn-graph + */ + index(raft::resources const& res, int64_t n_rows, int64_t n_cols) : + ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + int_graph_{raft::make_host_matrix(n_rows, n_cols)}, + graph_{raft::make_host_matrix(0, 0)} { } + + /** Distance metric used for clustering. */ + [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType + { + return metric_; + } + + // /** Total length of the index (number of vectors). 
*/ + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return graph_.view().extent(0); + } + + /** Graph degree */ + [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t + { + return graph_.view().extent(1); + } + + /** neighborhood graph [size, graph-degree] */ + [[nodiscard]] inline auto graph() noexcept + -> host_matrix_view + { + if constexpr (std::is_same_v or std::is_same_v) { + return raft::make_host_matrix_view( + reinterpret_cast(int_graph_.data_handle()), + int_graph_.extent(0), + int_graph_.extent(1)); + } + else { + graph_ = raft::make_host_matrix(int_graph_.extent(0), int_graph_.extent(1)); + std::copy(graph_.data_handle(), graph_.data_handle() + graph_.size(), int_graph_.data_handle()); + return graph_.view(); + } + } + + /** int type graph */ + [[nodiscard]] inline auto int_graph() noexcept + -> host_matrix_view { + return int_graph_.view(); + } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + +private: + raft::resources const& res_; + raft::distance::DistanceType metric_; + raft::host_matrix int_graph_; // nn-descent only supports int IdxT graphs + raft::host_matrix graph_; // graph to return for non-int IdxT +}; + +/** @} */ + +} \ No newline at end of file From 0c1a6fe135f42a7b5190e35a55e01d1047d837ff Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 17 Aug 2023 18:34:47 -0700 Subject: [PATCH 02/28] nn-descent tests stuck indefinitely --- cpp/include/raft/neighbors/cagra_types.hpp | 1 + cpp/include/raft/neighbors/detail/nn_descent.cuh | 6 ++++-- cpp/include/raft/neighbors/nn_descent.cuh | 2 +- cpp/include/raft/neighbors/nn_descent_types.hpp | 2 +- cpp/test/neighbors/ann_cagra.cuh | 8 ++++++++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index fa293e9d5c..935d15edb8 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -355,6 +355,7 @@ struct index : ann::index { namespace raft::neighbors::experimental::cagra { using raft::neighbors::cagra::hash_mode; using raft::neighbors::cagra::index; +using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::index_params; using raft::neighbors::cagra::search_algo; using raft::neighbors::cagra::search_params; diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 3315b2e09a..20455f1725 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1129,6 +1129,7 @@ void GNND::local_join(cudaStream_t stream) { template void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream) { + cudaStreamSynchronize(stream); nrow_ = nrow; graph_.h_graph = (InternalID_t*)output_graph; @@ -1269,8 +1270,8 @@ template build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { - RAFT_EXPECTS(dataset.size() >= std::numeric_limits::max() - 1, - "The dataset_size for GNND should be less than %d", + RAFT_EXPECTS(dataset.size() < std::numeric_limits::max() - 1, + "The dataset size for GNND should be less than %d", std::numeric_limits::max() - 1); size_t intermediate_degree = params.intermediate_graph_degree; size_t 
graph_degree = params.graph_degree; @@ -1300,6 +1301,7 @@ index build(raft::resources const& res, .metric_type = Metric_t::METRIC_L2}; GNND nnd(build_config); + std::cout << "graph dim: " << idx.int_graph().extent(0) << ", " << idx.int_graph().extent(1) << std::endl; nnd.build(dataset.data_handle(), dataset.extent(0), idx.int_graph().data_handle(), resource::get_cuda_stream(res)); nnd.dealloc(); diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index 25fe9675bc..408d590aca 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -64,4 +64,4 @@ index build(raft::resources const& res, /** @} */ // end group cagra -} \ No newline at end of file +} diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index a793857773..8a8e97497e 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -121,4 +121,4 @@ struct index : ann::index { /** @} */ -} \ No newline at end of file +} diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 89cb070afc..c10d7c1299 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -130,6 +130,7 @@ struct AnnCagraInputs { int n_rows; int dim; int k; + graph_build_algo build_algo; search_algo algo; int max_queries; int team_size; @@ -198,6 +199,7 @@ class AnnCagraTest : public ::testing::TestWithParam { cagra::index_params index_params; index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is // not used for knn_graph building. + index_params.build_algo = ps.build_algo; cagra::search_params search_params; search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; @@ -372,6 +374,7 @@ inline std::vector generate_inputs() {1000}, {1, 8, 17}, {1, 16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, // query size {0}, @@ -386,6 +389,7 @@ inline std::vector generate_inputs() {1000}, {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim {16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, @@ -400,6 +404,7 @@ inline std::vector generate_inputs() {1000}, {64}, {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0, 4, 8, 16, 32}, // team_size @@ -415,6 +420,7 @@ inline std::vector generate_inputs() {1000}, {64}, {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, // team_size @@ -430,6 +436,7 @@ inline std::vector generate_inputs() {10000, 20000}, {32}, {10}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, // team_size @@ -445,6 +452,7 @@ inline std::vector generate_inputs() {10000, 20000}, {32}, {10}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, // team_size From 7ad4c8afefc64306a63ff72531ed7bdba49b7596 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Mon, 21 Aug 2023 14:24:21 +0000 Subject: [PATCH 03/28] Fix the bug of unexpected hang --- .../raft/neighbors/detail/nn_descent.cuh | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 20455f1725..2272a61638 100644 --- 
a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1062,8 +1062,7 @@ template GNND::GNND(const BuildConfig& build_config) : build_config_(build_config), graph_(build_config.max_dataset_size, - to_multiple_of_32(build_config.node_degree * - (build_config.node_degree <= 32 ? 1.0 : 1.3)), + to_multiple_of_32(build_config.node_degree), to_multiple_of_32(build_config.internal_node_degree ? build_config.internal_node_degree : build_config.node_degree), NUM_SAMPLES), @@ -1090,6 +1089,7 @@ void GNND::alloc_workspace() { reinterpret_cast(graph_buffer_) + (size_t)nrow_ * DEGREE_ON_DEVICE, std::numeric_limits::max()); RAFT_CUDA_TRY(cudaMalloc(&d_locks_, sizeof(*d_locks_) * nrow_)); + thrust::fill(thrust::device, d_locks_, d_locks_ + nrow_, 0); RAFT_CUDA_TRY( cudaMallocHost(&h_rev_graph_new_, sizeof(*h_rev_graph_new_) * nrow_ * NUM_SAMPLES)); RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old_, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES)); @@ -1239,12 +1239,11 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out graph_shrink_buffer[i * build_config_.node_degree + j]; } } - - graph_.dealloc(); } template void GNND::dealloc() { + graph_.dealloc(); RAFT_CUDA_TRY(cudaFree(d_data_)); RAFT_CUDA_TRY(cudaFreeHost(graph_host_buffer_)); RAFT_CUDA_TRY(cudaFreeHost(dists_host_buffer_)); @@ -1289,22 +1288,32 @@ index build(raft::resources const& res, intermediate_degree); graph_degree = intermediate_degree; } - - index idx{res, dataset.extent(0), static_cast(graph_degree)}; + // The elements in each knn-list are partitioned into different buckets, and we need more buckets + // to mitigate bucket collisions. `intermediate_degree` is OK to larger than extended_graph_degree. + size_t extended_graph_degree = to_multiple_of_32(graph_degree * (graph_degree <= 32 ? 
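+  // Worked example with illustrative sizes (not from this patch):
+  // graph_degree = 64 gives 64 * 1.3 = 83.2, truncated to 83, which
+  // to_multiple_of_32 rounds up to an extended_graph_degree of 96.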
1.0 : 1.3)); + index int_idx{res, dataset.extent(0), static_cast(extended_graph_degree)}; BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), .dataset_dim = static_cast(dataset.extent(1)), - .node_degree = graph_degree, + .node_degree = extended_graph_degree, .internal_node_degree = intermediate_degree, .max_iterations = params.max_iterations, .termination_threshold = params.termination_threshold, .metric_type = Metric_t::METRIC_L2}; GNND nnd(build_config); - std::cout << "graph dim: " << idx.int_graph().extent(0) << ", " << idx.int_graph().extent(1) << std::endl; - nnd.build(dataset.data_handle(), dataset.extent(0), idx.int_graph().data_handle(), resource::get_cuda_stream(res)); + std::cout << "Intermediate graph dim: " << int_idx.int_graph().extent(0) << ", " << int_idx.int_graph().extent(1) << std::endl; + nnd.build(dataset.data_handle(), dataset.extent(0), int_idx.int_graph().data_handle(), resource::get_cuda_stream(res)); nnd.dealloc(); - + index idx{res, dataset.extent(0), static_cast(graph_degree)}; +#pragma omp parallel for + for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { + for (size_t j = 0; j < graph_degree; j++) { + auto graph = idx.int_graph().data_handle(); + auto int_graph = int_idx.int_graph().data_handle(); + graph[i * graph_degree + j] = int_graph[i * extended_graph_degree + j]; + } + } return idx; } From 4ccf3a752fdbb06276acb2d41142f5f3309603c7 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Tue, 22 Aug 2023 09:50:26 +0000 Subject: [PATCH 04/28] Fix bugs that cause unit tests to fail --- .../raft/neighbors/detail/nn_descent.cuh | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 2272a61638..3c6a595f5c 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include "../nn_descent_types.hpp" @@ -372,7 +373,7 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec const int lane_id) { if constexpr (std::is_same_v or std::is_same_v or std::is_same_v) { constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -381,12 +382,12 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec } } } - if constexpr (std::is_same::value) { - if ((size_t)vec_buffer % sizeof(float2) == 0 && load_dims % 4 == 0 && - padding_dims % 4 == 0) { + if constexpr (std::is_same_v) { + if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && + load_dims % 4 == 0 && padding_dims % 4 == 0) { constexpr int num_load_elems_per_warp = WARP_SIZE * 4; #pragma unroll - for (int step = 0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; if (idx_in_vec + 4 <= load_dims) { *(float2 *)(vec_buffer + idx_in_vec) = *(float2 *)(d_vec + idx_in_vec); @@ -396,7 +397,7 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec } } else { constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 
0; step < div_up(load_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -408,15 +409,15 @@ __device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec } } -template -__global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_data, Index_t nrow, - int dim, DistData_t *l2_norms) { +template +__global__ void preprocess_data_kernel(const Data_t* input_data, __half* output_data, int dim, + DistData_t* l2_norms, size_t list_offset = 0) { extern __shared__ char buffer[]; __shared__ float l2_norm; Data_t *s_vec = (Data_t *)buffer; - size_t list_id = blockIdx.x; + size_t list_id = list_offset + blockIdx.x; - load_vec(s_vec, input_data + list_id * dim, dim, dim, threadIdx.x % WARP_SIZE); + load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE); if (threadIdx.x == 0) { l2_norm = 0; } @@ -443,9 +444,10 @@ __global__ void preprocess_data_kernel(const Data_t *input_data, __half *output_ int idx = step * WARP_SIZE + threadIdx.x; if (idx < dim) { if (l2_norms == nullptr) { - output_data[list_id * dim + idx] = (float)input_data[list_id * dim + idx] / sqrt(l2_norm); + output_data[list_id * dim + idx] = + (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); } else { - output_data[list_id * dim + idx] = input_data[list_id * dim + idx]; + output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; if (idx == 0) { l2_norms[list_id] = l2_norm; } @@ -475,8 +477,7 @@ __global__ void add_rev_edges_kernel(const Index_t *graph, Index_t *rev_graph, i template > __device__ void insert_to_global_graph(ResultItem elem, size_t list_id, ID_t *graph, - DistData_t *dists, int node_degree, int *locks, - bool new_new = true) { + DistData_t *dists, int node_degree, int *locks) { int tx = threadIdx.x; int lane_id = tx % WARP_SIZE; size_t global_idx_base = list_id * node_degree; @@ -760,8 +761,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4) if (idx_in_list >= list_new_size) continue; auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks, - true); + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); } } @@ -851,8 +851,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4) } if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks, - false); + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); } } } @@ -1113,7 +1112,7 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev list_sizes); RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, d_rev_graph_ptr, sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES, - cudaMemcpyDeviceToHost, stream)); + cudaMemcpyDefault, stream)); } template @@ -1136,14 +1135,31 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - RAFT_CUDA_TRY(cudaHostRegister(const_cast*>(data), sizeof(Data_t) * nrow * build_config_.dataset_dim, - cudaHostRegisterDefault)); + typename std::remove_const::type* input_data; + size_t batch_size = 100000; + 
RAFT_CUDA_TRY(cudaMallocAsync(&input_data, + sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream)); + for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { + size_t list_offset = step * batch_size; + size_t num_lists = + step != div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; + RAFT_CUDA_TRY(cudaMemcpyAsync( + input_data, data + list_offset * build_config_.dataset_dim, + sizeof(Data_t) * num_lists * build_config_.dataset_dim, cudaMemcpyDefault, stream)); + preprocess_data_kernel<<>>(input_data, d_data_, build_config_.dataset_dim, + l2_norms_, list_offset); + } + RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream)); + } else { + preprocess_data_kernel<<< + nrow_, WARP_SIZE, + sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, stream>>>( + data, d_data_, build_config_.dataset_dim, l2_norms_); } - preprocess_data_kernel<<< - nrow_, WARP_SIZE, sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, - stream>>>(data, d_data_, nrow_, build_config_.dataset_dim, l2_norms_); - thrust::fill(thrust::device.on(stream), (Index_t*)graph_buffer_, (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, std::numeric_limits::max()); @@ -1168,13 +1184,13 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out for (size_t it = 0; it < build_config_.max_iterations; it++) { RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, graph_.h_list_sizes_new, sizeof(*d_list_sizes_new_) * nrow_, - cudaMemcpyHostToDevice, stream)); + cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, graph_.h_graph_old, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES, - cudaMemcpyHostToHost, stream)); + cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, graph_.h_list_sizes_old, sizeof(*d_list_sizes_old_) * nrow_, - cudaMemcpyHostToDevice, stream)); + cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); std::thread update_and_sample_thread(update_and_sample, it); @@ -1201,11 +1217,11 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, graph_buffer_, sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDeviceToHost, stream)); + cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, dists_buffer_, sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDeviceToHost, stream)); + cudaMemcpyDefault, stream)); graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE); } From 40e1cf0695257cfd6918d4c727b239e8bb04eb9d Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Tue, 22 Aug 2023 10:57:31 +0000 Subject: [PATCH 05/28] Fix duplicate nodes issue --- cpp/include/raft/neighbors/detail/nn_descent.cuh | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 3c6a595f5c..ad629d414b 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -944,18 +944,13 @@ void GnndGraph::init_random_graph() { #pragma omp parallel for for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < NUM_SAMPLES; j++) { - size_t idx = i * NUM_SAMPLES + j; + for (size_t j = 0; j < node_degree; j++) { + size_t idx = i * node_degree + j; Index_t id = rand_seq[idx % nrow]; if ((size_t)id == i) { - id = rand_seq[(idx + NUM_SAMPLES) % nrow]; + id = rand_seq[(idx + node_degree) % nrow]; } 
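// A sketch of why this fills each row without duplicates (assuming rand_seq is a
// random permutation of [0, nrow) and 2 * node_degree <= nrow): the node_degree
// consecutive values idx = i * node_degree + j index distinct entries of the
// permutation, and the self-edge fallback rand_seq[(idx + node_degree) % nrow]
// lands outside that window, so every slot of row i receives a distinct id.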
h_graph[i * node_degree + j].id_with_flag() = id; - } - for (size_t j = NUM_SAMPLES; j < node_degree; j++) { - h_graph[i * node_degree + j].id_with_flag() = std::numeric_limits::max(); - } - for (size_t j = 0; j < node_degree; j++) { h_dists[i * node_degree + j] = std::numeric_limits::max(); } } From 33f5ebc4455f0f2eaf2e4526ae7fd68c15f9ffa4 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 22 Aug 2023 17:58:52 -0700 Subject: [PATCH 06/28] passing tests --- cpp/include/raft/neighbors/cagra.cuh | 28 +- .../neighbors/detail/cagra/cagra_build.cuh | 27 + .../raft/neighbors/detail/nn_descent.cuh | 2170 +++++++++-------- .../raft/neighbors/nn_descent_types.hpp | 57 +- cpp/test/neighbors/ann_cagra.cuh | 135 +- 5 files changed, 1273 insertions(+), 1144 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 785fb78868..a544f7044c 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,7 +25,6 @@ #include #include #include -#include #include namespace raft::neighbors::cagra { @@ -53,8 +52,8 @@ namespace raft::neighbors::cagra { * @code{.cpp} * using namespace raft::neighbors; * // use default index parameters - * cagra::index_params build_params; - * cagra::search_params search_params + * ivf_pq::index_params build_params; + * ivf_pq::search_params search_params * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); @@ -96,6 +95,18 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } +template , memory_type::device>> +nn_descent::index build_knn_graph( + raft::resources const& res, + mdspan, row_major, accessor> dataset, + std::optional build_params = std::nullopt) +{ + return detail::build_knn_graph(res, dataset, build_params); +} + /** * @brief Sort a KNN graph index. 
* Preprocessing step for `cagra::optimize`: If a KNN graph is not built using @@ -265,12 +276,11 @@ index build(raft::resources const& res, build_knn_graph(res, dataset, knn_graph.view()); optimize(res, knn_graph.view(), cagra_graph.view()); - } - else { - nn_descent::index_params nn_descent_params; - nn_descent_params.intermediate_graph_degree = intermediate_degree; - nn_descent_params.graph_degree = graph_degree; - auto nn_descent_index = nn_descent::build(res, nn_descent_params, dataset); + } else { + auto nn_descent_params = std::make_optional(); + nn_descent_params->intermediate_graph_degree = intermediate_degree; + nn_descent_params->graph_degree = intermediate_degree; + auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); } diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d19d7e7904..d67394806a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -34,6 +34,7 @@ #include #include #include +#include #include namespace raft::neighbors::cagra::detail { @@ -238,4 +239,30 @@ void build_knn_graph(raft::resources const& res, if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); } +template +nn_descent::index build_knn_graph( + raft::resources const& res, + mdspan, row_major, accessor> dataset, + std::optional build_params = std::nullopt) +{ + if (!build_params) { build_params = std::make_optional(); } + + auto nn_descent_idx = nn_descent::build(res, *build_params, dataset); + + using internal_IdxT = typename std::make_unsigned::type; + using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; + using g_accessor_internal = + host_device_accessor, g_accessor::mem_type>; + + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(nn_descent_idx.graph().data_handle()), + nn_descent_idx.graph().extent(0), + nn_descent_idx.graph().extent(1)); + + graph::sort_knn_graph(res, dataset, knn_graph_internal); + + return nn_descent_idx; +} + } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index ad629d414b..74d27daac1 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -22,9 +22,9 @@ #include #include +#include #include #include -#include #include "../nn_descent_types.hpp" @@ -43,23 +43,26 @@ struct InternalID_t; // InternalID_t uses 1 bit for marking (new or old). 
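// The "1 bit" is the sign bit: non-negative ids are new, and mark_old() maps
// id -> -id - 1, which is its own inverse, so id() recovers the original value in
// either state. For example (illustrative): id_with_flag() = 5 gives is_new() == true
// and id() == 5; after mark_old(), id_ == -6, is_new() == false, and id() is still 5.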
template <> class InternalID_t { - private: - using Index_t = int; - Index_t id_{std::numeric_limits::max()}; - - public: - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ void mark_old() { - if (id_ >= 0) id_ = -id_ - 1; - } - __host__ __device__ bool operator==(const InternalID_t& other) const { - return id() == other.id(); - } + private: + using Index_t = int; + Index_t id_{std::numeric_limits::max()}; + + public: + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } + __host__ __device__ bool operator==(const InternalID_t& other) const + { + return id() == other.id(); + } }; template @@ -67,291 +70,319 @@ struct ResultItem; template <> class ResultItem { - private: - using Index_t = int; - Index_t id_; - DistData_t dist_; - - public: - __host__ __device__ ResultItem() - : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; - __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) - : id_(id_with_flag), dist_(dist){}; - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ DistData_t& dist() { return dist_; } + private: + using Index_t = int; + Index_t id_; + DistData_t dist_; + + public: + __host__ __device__ ResultItem() + : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; + __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) + : id_(id_with_flag), dist_(dist){}; + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ DistData_t& dist() { return dist_; } - __host__ __device__ void mark_old() { - if (id_ >= 0) id_ = -id_ - 1; - } + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } - __host__ __device__ bool operator<(const ResultItem& other) const { - if (dist_ == other.dist_) return id() < other.id(); - return dist_ < other.dist_; - } - __host__ __device__ bool operator==(const ResultItem& other) const { - return id() == other.id(); - } - __host__ __device__ bool operator>=(const ResultItem& other) const { - return !(*this < other); - } - __host__ __device__ bool operator<=(const ResultItem& other) const { - return (*this == other) || (*this < other); - } - __host__ __device__ bool operator>(const ResultItem& other) const { - return !(*this <= other); - } - __host__ __device__ bool operator!=(const ResultItem& other) const { - return !(*this == other); - } + __host__ __device__ bool operator<(const ResultItem& other) const + { + if (dist_ == other.dist_) return id() < other.id(); + return dist_ < other.dist_; + } + __host__ __device__ bool operator==(const ResultItem& other) const + { + return id() == other.id(); + } + __host__ __device__ bool operator>=(const ResultItem& other) const + { + return !(*this < other); + } + __host__ __device__ bool operator<=(const ResultItem& 
other) const + { + return (*this == other) || (*this < other); + } + __host__ __device__ bool operator>(const ResultItem& other) const + { + return !(*this <= other); + } + __host__ __device__ bool operator!=(const ResultItem& other) const + { + return !(*this == other); + } }; -constexpr __host__ __device__ size_t div_up(const size_t a, const size_t b) { - return a / b + (a % b != 0); +constexpr __host__ __device__ size_t div_up(const size_t a, const size_t b) +{ + return a / b + (a % b != 0); } constexpr int to_multiple_of_32(int number) { return div_up(number, 32) * 32; } -constexpr int WARP_SIZE = 32; +constexpr int WARP_SIZE = 32; constexpr unsigned int FULL_MASK = 0xffffffff; template -int get_batch_size(const int it_now, const T nrow, const int batch_size) { - int it_total = div_up(nrow, batch_size); - return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; +int get_batch_size(const int it_now, const T nrow, const int batch_size) +{ + int it_total = div_up(nrow, batch_size); + return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; } // for avoiding bank conflict template -constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) { - // all "4"s are for alignment - if constexpr (std::is_same::value) { - ndim = div_up(ndim, 4) * 4; - return ndim + (ndim % 32 == 0) * 4; - } +constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) +{ + // all "4"s are for alignment + if constexpr (std::is_same::value) { + ndim = div_up(ndim, 4) * 4; + return ndim + (ndim % 32 == 0) * 4; + } } template -__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) { - ResultItem y; - y.dist() = __shfl_xor_sync(FULL_MASK, x.dist(), mask, WARP_SIZE); - y.id_with_flag() = __shfl_xor_sync(FULL_MASK, x.id_with_flag(), mask, WARP_SIZE); - return x < y == dir ? y : x; +__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) +{ + ResultItem y; + y.dist() = __shfl_xor_sync(FULL_MASK, x.dist(), mask, WARP_SIZE); + y.id_with_flag() = __shfl_xor_sync(FULL_MASK, x.id_with_flag(), mask, WARP_SIZE); + return x < y == dir ? y : x; } -__device__ __forceinline__ int xor_swap(int x, int mask, int dir) { - int y = __shfl_xor_sync(FULL_MASK, x, mask, WARP_SIZE); - return x < y == dir ? y : x; +__device__ __forceinline__ int xor_swap(int x, int mask, int dir) +{ + int y = __shfl_xor_sync(FULL_MASK, x, mask, WARP_SIZE); + return x < y == dir ? 
y : x; } -__device__ __forceinline__ uint bfe(uint lane_id, uint pos) { - uint res; - asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); - return res; +__device__ __forceinline__ uint bfe(uint lane_id, uint pos) +{ + uint res; + asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); + return res; } // https://en.wikipedia.org/wiki/Xorshift#xorshift* -__host__ __device__ __forceinline__ uint64_t xorshift64(uint64_t x) { - x ^= x >> 12; - x ^= x << 25; - x ^= x >> 27; - return x * 0x2545F4914F6CDD1DULL; +__host__ __device__ __forceinline__ uint64_t xorshift64(uint64_t x) +{ + x ^= x >> 12; + x ^= x << 25; + x ^= x >> 27; + return x * 0x2545F4914F6CDD1DULL; } template -__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) { - static_assert(WARP_SIZE == 32); - auto& element = *element_ptr; - element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x10, bfe(lane_id, 4)); - element = xor_swap(element, 0x08, bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 0)); - return; +__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) +{ + static_assert(WARP_SIZE == 32); + auto& element = *element_ptr; + element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x10, bfe(lane_id, 4)); + element = xor_swap(element, 0x08, bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 0)); + return; } enum class Metric_t { - METRIC_INNER_PRODUCT = 0, - METRIC_L2 = 1, + METRIC_INNER_PRODUCT = 0, + METRIC_L2 = 1, }; struct BuildConfig { - size_t max_dataset_size; - size_t dataset_dim; - size_t node_degree{64}; - size_t internal_node_degree{0}; - // If internal_node_degree == 0, the value of node_degree will be assigned to it - size_t max_iterations{50}; - float termination_threshold{0.0001}; - Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT}; + size_t max_dataset_size; + size_t dataset_dim; + size_t node_degree{64}; 
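+  // The GNND constructor rounds this up to a multiple of 32 (to_multiple_of_32),
+  // since knn lists are processed in warp-sized segments; see insert_to_global_graph.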
+ size_t internal_node_degree{0}; + // If internal_node_degree == 0, the value of node_degree will be assigned to it + size_t max_iterations{50}; + float termination_threshold{0.0001}; + Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT}; }; template class BloomFilter { - public: - BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) - : nrow_(nrow), - num_sets_per_list_(num_sets_per_list), - num_hashs_(num_hashs), - bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) {} - - void add(size_t list_id, Index_t key) { - if (is_cleared) { - is_cleared = false; - } - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - } + public: + BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) + : nrow_(nrow), + num_sets_per_list_(num_sets_per_list), + num_hashs_(num_hashs), + bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) + { + } + + void add(size_t list_id, Index_t key) + { + if (is_cleared) { is_cleared = false; } + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; } + } - bool check(size_t list_id, Index_t key) { - bool is_present = true; - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - - if (!is_present) return false; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - if (!is_present) return false; - } - return true; + bool check(size_t list_id, Index_t key) + { + bool is_present = true; + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + + if (!is_present) return false; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + if (!is_present) return false; } + return true; + } - void clear() { - if (is_cleared) return; + void clear() + { + if (is_cleared) return; #pragma omp parallel for - for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { - bitsets_[i] = 0; - } - is_cleared = true; + for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { + bitsets_[i] = 0; } + is_cleared = true; + } - private: - uint32_t hash_0(uint32_t value) { - value *= 1103515245; - value += 12345; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } + private: + uint32_t hash_0(uint32_t value) + { + value *= 1103515245; + value += 12345; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } - uint32_t hash_1(uint32_t value) { - value *= 1664525; - value += 1013904223; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } + 
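+  // add() and check() derive probe i as hash_0(key) + i * hash_1(key) by
+  // accumulating hash_1 per step (classic double hashing), so these two mixers
+  // serve all num_hashs_ probes into a 512-bit set.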
uint32_t hash_1(uint32_t value) + { + value *= 1664525; + value += 1013904223; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } - static constexpr int num_bits_per_set_ = 512; - bool is_cleared{true}; - std::vector bitsets_; - size_t nrow_; - size_t num_sets_per_list_; - size_t num_hashs_; + static constexpr int num_bits_per_set_ = 512; + bool is_cleared{true}; + std::vector bitsets_; + size_t nrow_; + size_t num_sets_per_list_; + size_t num_hashs_; }; template struct GnndGraph { - static constexpr int segment_size = 32; - InternalID_t* h_graph; - DistData_t* h_dists; - - size_t nrow; - size_t node_degree; - int num_samples; - int num_segments; - - Index_t* h_graph_new; - int2* h_list_sizes_new; - - Index_t* h_graph_old; - int2* h_list_sizes_old; - BloomFilter bloom_filter; - - GnndGraph(const GnndGraph&) = delete; - GnndGraph& operator=(const GnndGraph&) = delete; - GnndGraph(const size_t nrow, const size_t node_degree, const size_t internal_node_degree, - const size_t num_samples); - void init_random_graph(); - // Use Bloom filter to sample "new" neighbors for local joining - void sample_graph_new(InternalID_t* new_neighbors, const size_t width); - void sample_graph(bool sample_new); - void update_graph(const InternalID_t* new_neighbors, const DistData_t* new_dists, - const size_t width, std::atomic& update_counter); - void sort_lists(); - void clear(); - void dealloc(); - ~GnndGraph(); + static constexpr int segment_size = 32; + InternalID_t* h_graph; + DistData_t* h_dists; + + size_t nrow; + size_t node_degree; + int num_samples; + int num_segments; + + Index_t* h_graph_new; + int2* h_list_sizes_new; + + Index_t* h_graph_old; + int2* h_list_sizes_old; + BloomFilter bloom_filter; + + GnndGraph(const GnndGraph&) = delete; + GnndGraph& operator=(const GnndGraph&) = delete; + GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples); + void init_random_graph(); + // Use Bloom filter to sample "new" neighbors for local joining + void sample_graph_new(InternalID_t* new_neighbors, const size_t width); + void sample_graph(bool sample_new); + void update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter); + void sort_lists(); + void clear(); + void dealloc(); + ~GnndGraph(); }; template class GNND { - public: - GNND(const BuildConfig& build_config); - GNND(const GNND&) = delete; - GNND& operator=(const GNND&) = delete; - - // Use delete[] to deallocate the returned graph - void build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream = 0); - void dealloc(); - ~GNND(); - using ID_t = InternalID_t; - - private: - void add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, Index_t* d_rev_graph_ptr, - int2* list_sizes, cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0); - void alloc_workspace(); - - BuildConfig build_config_; - GnndGraph graph_; - std::atomic update_counter_; - - __half* d_data_; - DistData_t* l2_norms_; - - ID_t* graph_buffer_; - DistData_t* dists_buffer_; - ID_t* graph_host_buffer_; - DistData_t* dists_host_buffer_; - - int* d_locks_; - - Index_t* h_rev_graph_new_; - // int2.x is the number of forward edges, int2.y is the number of reverse edges - int2* d_list_sizes_new_; - - Index_t* h_graph_old_; - Index_t* h_rev_graph_old_; - int2* d_list_sizes_old_; - - Index_t nrow_; - const int ndim_; + public: + GNND(const BuildConfig& build_config); + 
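// Intended lifecycle, as used by build() in this file: construct with a
// BuildConfig, call build() once to produce the knn graph, then dealloc() to
// release the workspaces; copy and assignment are deleted below.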
GNND(const GNND&) = delete; + GNND& operator=(const GNND&) = delete; + + // Use delete[] to deallocate the returned graph + void build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream = 0); + void dealloc(); + ~GNND(); + using ID_t = InternalID_t; + + private: + void add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream = 0); + void local_join(cudaStream_t stream = 0); + void alloc_workspace(); + + BuildConfig build_config_; + GnndGraph graph_; + std::atomic update_counter_; + + __half* d_data_; + DistData_t* l2_norms_; + + ID_t* graph_buffer_; + DistData_t* dists_buffer_; + ID_t* graph_host_buffer_; + DistData_t* dists_host_buffer_; + + int* d_locks_; + + Index_t* h_rev_graph_new_; + // int2.x is the number of forward edges, int2.y is the number of reverse edges + int2* d_list_sizes_new_; + + Index_t* h_graph_old_; + Index_t* h_rev_graph_old_; + int2* d_list_sizes_old_; + + Index_t nrow_; + const int ndim_; }; constexpr int TILE_ROW_WIDTH = 64; @@ -360,917 +391,983 @@ constexpr int TILE_COL_WIDTH = 128; constexpr int NUM_SAMPLES = 32; // For now, the max. number of samples is 32, so the sample cache size is fixed // to 64 (32 * 2). -constexpr int MAX_NUM_BI_SAMPLES = 64; +constexpr int MAX_NUM_BI_SAMPLES = 64; constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); -constexpr int BLOCK_SIZE = 512; -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 16; +constexpr int BLOCK_SIZE = 512; +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; template -__device__ __forceinline__ void load_vec(Data_t *vec_buffer, const Data_t *d_vec, - const int load_dims, const int padding_dims, - const int lane_id) { - if constexpr (std::is_same_v or std::is_same_v or std::is_same_v) { - constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } +__device__ __forceinline__ void load_vec(Data_t* vec_buffer, + const Data_t* d_vec, + const int load_dims, + const int padding_dims, + const int lane_id) +{ + if constexpr (std::is_same_v or std::is_same_v or + std::is_same_v) { + constexpr int num_load_elems_per_warp = WARP_SIZE; + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } } - if constexpr (std::is_same_v) { - if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && - load_dims % 4 == 0 && padding_dims % 4 == 0) { - constexpr int num_load_elems_per_warp = WARP_SIZE * 4; + } + if constexpr (std::is_same_v) { + if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && + load_dims % 4 == 0 && padding_dims % 4 == 0) { + constexpr int num_load_elems_per_warp = WARP_SIZE * 4; #pragma unroll - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { - int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; - if (idx_in_vec + 4 <= load_dims) { - *(float2 *)(vec_buffer + idx_in_vec) = *(float2 *)(d_vec + idx_in_vec); - } else if (idx_in_vec + 4 <= padding_dims) { - *(float2 *)(vec_buffer + 
idx_in_vec) = float2({0.0f, 0.0f}); - } - } - } else { - constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; + if (idx_in_vec + 4 <= load_dims) { + *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); + } else if (idx_in_vec + 4 <= padding_dims) { + *(float2*)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); + } + } + } else { + constexpr int num_load_elems_per_warp = WARP_SIZE; + for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; } + } } + } } template -__global__ void preprocess_data_kernel(const Data_t* input_data, __half* output_data, int dim, - DistData_t* l2_norms, size_t list_offset = 0) { - extern __shared__ char buffer[]; - __shared__ float l2_norm; - Data_t *s_vec = (Data_t *)buffer; - size_t list_id = list_offset + blockIdx.x; - - load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE); - if (threadIdx.x == 0) { - l2_norm = 0; +__global__ void preprocess_data_kernel(const Data_t* input_data, + __half* output_data, + int dim, + DistData_t* l2_norms, + size_t list_offset = 0) +{ + extern __shared__ char buffer[]; + __shared__ float l2_norm; + Data_t* s_vec = (Data_t*)buffer; + size_t list_id = list_offset + blockIdx.x; + + load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE); + if (threadIdx.x == 0) { l2_norm = 0; } + __syncthreads(); + int lane_id = threadIdx.x % WARP_SIZE; + for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { + int idx = step * WARP_SIZE + lane_id; + float part_dist = 0; + if (idx < dim) { + part_dist = s_vec[idx]; + part_dist = part_dist * part_dist; } - __syncthreads(); - int lane_id = threadIdx.x % WARP_SIZE; - for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { - int idx = step * WARP_SIZE + lane_id; - float part_dist = 0; - if (idx < dim) { - part_dist = s_vec[idx]; - part_dist = part_dist * part_dist; - } - __syncwarp(); - for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { - part_dist += __shfl_down_sync(FULL_MASK, part_dist, offset); - } - if (lane_id == 0) { - l2_norm += part_dist; - } - __syncwarp(); + __syncwarp(); + for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(FULL_MASK, part_dist, offset); } + if (lane_id == 0) { l2_norm += part_dist; } + __syncwarp(); + } - for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { - int idx = step * WARP_SIZE + threadIdx.x; - if (idx < dim) { - if (l2_norms == nullptr) { - output_data[list_id * dim + idx] = - (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); - } else { - output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; - if (idx == 0) { - l2_norms[list_id] = l2_norm; - } - } - } + for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { + int idx = step * WARP_SIZE + threadIdx.x; + if (idx < dim) { + if (l2_norms == nullptr) { + output_data[list_id * dim + idx] = + (float)input_data[(size_t)blockIdx.x * dim + 
idx] / sqrt(l2_norm); + } else { + output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; + if (idx == 0) { l2_norms[list_id] = l2_norm; } + } } + } } template -__global__ void add_rev_edges_kernel(const Index_t *graph, Index_t *rev_graph, int num_samples, - int2 *list_sizes) { - size_t list_id = blockIdx.x; - int2 list_size = list_sizes[list_id]; - - for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { - // each node has same number (num_samples) of forward and reverse edges - size_t rev_list_id = graph[list_id * num_samples + idx]; - // there are already num_samples forward edges - int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); - if (idx_in_rev_list >= num_samples) { - atomicExch(&list_sizes[rev_list_id].y, num_samples); - } else { - rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; - } +__global__ void add_rev_edges_kernel(const Index_t* graph, + Index_t* rev_graph, + int num_samples, + int2* list_sizes) +{ + size_t list_id = blockIdx.x; + int2 list_size = list_sizes[list_id]; + + for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { + // each node has same number (num_samples) of forward and reverse edges + size_t rev_list_id = graph[list_id * num_samples + idx]; + // there are already num_samples forward edges + int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); + if (idx_in_rev_list >= num_samples) { + atomicExch(&list_sizes[rev_list_id].y, num_samples); + } else { + rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; } + } } template > -__device__ void insert_to_global_graph(ResultItem elem, size_t list_id, ID_t *graph, - DistData_t *dists, int node_degree, int *locks) { - int tx = threadIdx.x; - int lane_id = tx % WARP_SIZE; - size_t global_idx_base = list_id * node_degree; - if (elem.id() == list_id) return; - - const int num_segments = div_up(node_degree, WARP_SIZE); - - int loop_flag = 0; - do { - int segment_id = elem.id() % num_segments; - if (lane_id == 0) { - loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; - } +__device__ void insert_to_global_graph(ResultItem elem, + size_t list_id, + ID_t* graph, + DistData_t* dists, + int node_degree, + int* locks) +{ + int tx = threadIdx.x; + int lane_id = tx % WARP_SIZE; + size_t global_idx_base = list_id * node_degree; + if (elem.id() == list_id) return; + + const int num_segments = div_up(node_degree, WARP_SIZE); + + int loop_flag = 0; + do { + int segment_id = elem.id() % num_segments; + if (lane_id == 0) { + loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; + } - loop_flag = __shfl_sync(FULL_MASK, loop_flag, 0); - - if (loop_flag == 1) { - ResultItem knn_list_frag; - int local_idx = segment_id * WARP_SIZE + lane_id; - size_t global_idx = global_idx_base + local_idx; - if (local_idx < node_degree) { - knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); - knn_list_frag.dist() = dists[global_idx]; - } - - int pos_to_insert = -1; - ResultItem prev_elem; - - prev_elem.id_with_flag() = __shfl_up_sync(FULL_MASK, knn_list_frag.id_with_flag(), 1); - prev_elem.dist() = __shfl_up_sync(FULL_MASK, knn_list_frag.dist(), 1); - - if (lane_id == 0) { - prev_elem = ResultItem{std::numeric_limits::min(), - std::numeric_limits::lowest()}; - } - if (elem > prev_elem && elem < knn_list_frag) { - pos_to_insert = segment_id * WARP_SIZE + lane_id; - } else if (elem == prev_elem || elem == knn_list_frag) { - pos_to_insert = -2; - } - uint mask = __ballot_sync(FULL_MASK, 
pos_to_insert >= 0); - if (mask) { - uint set_lane_id = __fns(mask, 0, 1); - pos_to_insert = __shfl_sync(FULL_MASK, pos_to_insert, set_lane_id); - } - - if (pos_to_insert >= 0) { - int local_idx = segment_id * WARP_SIZE + lane_id; - if (local_idx > pos_to_insert) { - local_idx++; - } else if (local_idx == pos_to_insert) { - graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); - dists[global_idx_base + local_idx] = elem.dist(); - local_idx++; - } - size_t global_pos = global_idx_base + local_idx; - if (local_idx < (segment_id + 1) * WARP_SIZE && local_idx < node_degree) { - graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); - dists[global_pos] = knn_list_frag.dist(); - } - } - __threadfence(); - if (loop_flag && lane_id == 0) { - atomicExch(&locks[list_id * num_segments + segment_id], 0); - } - } - } while (!loop_flag); -} + loop_flag = __shfl_sync(FULL_MASK, loop_flag, 0); -template -__device__ ResultItem get_min_item(const Index_t id, const int idx_in_list, - const Index_t *neighbs, const DistData_t *distances, - const bool find_in_row = true) { - int lane_id = threadIdx.x % WARP_SIZE; - - static_assert(MAX_NUM_BI_SAMPLES == 64); - int idx[MAX_NUM_BI_SAMPLES / WARP_SIZE]; - float dist[MAX_NUM_BI_SAMPLES / WARP_SIZE] = {std::numeric_limits::max(), - std::numeric_limits::max()}; - idx[0] = lane_id; - idx[1] = WARP_SIZE + lane_id; - - if (neighbs[idx[0]] != id) { - dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] - : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; - } + if (loop_flag == 1) { + ResultItem knn_list_frag; + int local_idx = segment_id * WARP_SIZE + lane_id; + size_t global_idx = global_idx_base + local_idx; + if (local_idx < node_degree) { + knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); + knn_list_frag.dist() = dists[global_idx]; + } - if (neighbs[idx[1]] != id) { - dist[1] = find_in_row - ? 
distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + WARP_SIZE + lane_id] - : distances[idx_in_list + (WARP_SIZE + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; - } + int pos_to_insert = -1; + ResultItem prev_elem; - if (dist[1] < dist[0]) { - dist[0] = dist[1]; - idx[0] = idx[1]; - } - __syncwarp(); - for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { - float other_idx = __shfl_down_sync(FULL_MASK, idx[0], offset); - float other_dist = __shfl_down_sync(FULL_MASK, dist[0], offset); - if (other_dist < dist[0]) { - dist[0] = other_dist; - idx[0] = other_idx; - } - } + prev_elem.id_with_flag() = __shfl_up_sync(FULL_MASK, knn_list_frag.id_with_flag(), 1); + prev_elem.dist() = __shfl_up_sync(FULL_MASK, knn_list_frag.dist(), 1); - ResultItem result; - result.dist() = __shfl_sync(FULL_MASK, dist[0], 0); - result.id_with_flag() = neighbs[__shfl_sync(FULL_MASK, idx[0], 0)]; - return result; -} + if (lane_id == 0) { + prev_elem = ResultItem{std::numeric_limits::min(), + std::numeric_limits::lowest()}; + } + if (elem > prev_elem && elem < knn_list_frag) { + pos_to_insert = segment_id * WARP_SIZE + lane_id; + } else if (elem == prev_elem || elem == knn_list_frag) { + pos_to_insert = -2; + } + uint mask = __ballot_sync(FULL_MASK, pos_to_insert >= 0); + if (mask) { + uint set_lane_id = __fns(mask, 0, 1); + pos_to_insert = __shfl_sync(FULL_MASK, pos_to_insert, set_lane_id); + } -template -__device__ __forceinline__ void remove_duplicates(T *list_a, int list_a_size, T *list_b, - int list_b_size, int &unique_counter, - int execute_warp_id) { - static_assert(WARP_SIZE == 32); - if (!(threadIdx.x >= execute_warp_id * WARP_SIZE && - threadIdx.x < execute_warp_id * WARP_SIZE + WARP_SIZE)) { - return; - } - int lane_id = threadIdx.x % WARP_SIZE; - T elem = std::numeric_limits::max(); - if (lane_id < list_a_size) { - elem = list_a[lane_id]; + if (pos_to_insert >= 0) { + int local_idx = segment_id * WARP_SIZE + lane_id; + if (local_idx > pos_to_insert) { + local_idx++; + } else if (local_idx == pos_to_insert) { + graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); + dists[global_idx_base + local_idx] = elem.dist(); + local_idx++; + } + size_t global_pos = global_idx_base + local_idx; + if (local_idx < (segment_id + 1) * WARP_SIZE && local_idx < node_degree) { + graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); + dists[global_pos] = knn_list_frag.dist(); + } + } + __threadfence(); + if (loop_flag && lane_id == 0) { atomicExch(&locks[list_id * num_segments + segment_id], 0); } } - warp_bitonic_sort(&elem, lane_id); + } while (!loop_flag); +} - if (elem != std::numeric_limits::max()) { - list_a[lane_id] = elem; - } +template +__device__ ResultItem get_min_item(const Index_t id, + const int idx_in_list, + const Index_t* neighbs, + const DistData_t* distances, + const bool find_in_row = true) +{ + int lane_id = threadIdx.x % WARP_SIZE; + + static_assert(MAX_NUM_BI_SAMPLES == 64); + int idx[MAX_NUM_BI_SAMPLES / WARP_SIZE]; + float dist[MAX_NUM_BI_SAMPLES / WARP_SIZE] = {std::numeric_limits::max(), + std::numeric_limits::max()}; + idx[0] = lane_id; + idx[1] = WARP_SIZE + lane_id; + + if (neighbs[idx[0]] != id) { + dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] + : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; + } - T elem_b = std::numeric_limits::max(); + if (neighbs[idx[1]] != id) { + dist[1] = find_in_row + ? 
distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + WARP_SIZE + lane_id] + : distances[idx_in_list + (WARP_SIZE + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; + } - if (lane_id < list_b_size) { - elem_b = list_b[lane_id]; + if (dist[1] < dist[0]) { + dist[0] = dist[1]; + idx[0] = idx[1]; + } + __syncwarp(); + for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { + float other_idx = __shfl_down_sync(FULL_MASK, idx[0], offset); + float other_dist = __shfl_down_sync(FULL_MASK, dist[0], offset); + if (other_dist < dist[0]) { + dist[0] = other_dist; + idx[0] = other_idx; } - __syncwarp(); + } - int idx_l = 0; - int idx_r = list_a_size; - bool existed = false; - while (idx_l < idx_r) { - int idx = (idx_l + idx_r) / 2; - int elem = list_a[idx]; - if (elem == elem_b) { - existed = true; - break; - } - if (elem_b > elem) { - idx_l = idx + 1; - } else { - idx_r = idx; - } + ResultItem result; + result.dist() = __shfl_sync(FULL_MASK, dist[0], 0); + result.id_with_flag() = neighbs[__shfl_sync(FULL_MASK, idx[0], 0)]; + return result; +} + +template +__device__ __forceinline__ void remove_duplicates( + T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) +{ + static_assert(WARP_SIZE == 32); + if (!(threadIdx.x >= execute_warp_id * WARP_SIZE && + threadIdx.x < execute_warp_id * WARP_SIZE + WARP_SIZE)) { + return; + } + int lane_id = threadIdx.x % WARP_SIZE; + T elem = std::numeric_limits::max(); + if (lane_id < list_a_size) { elem = list_a[lane_id]; } + warp_bitonic_sort(&elem, lane_id); + + if (elem != std::numeric_limits::max()) { list_a[lane_id] = elem; } + + T elem_b = std::numeric_limits::max(); + + if (lane_id < list_b_size) { elem_b = list_b[lane_id]; } + __syncwarp(); + + int idx_l = 0; + int idx_r = list_a_size; + bool existed = false; + while (idx_l < idx_r) { + int idx = (idx_l + idx_r) / 2; + int elem = list_a[idx]; + if (elem == elem_b) { + existed = true; + break; } - if (!existed && elem_b != std::numeric_limits::max()) { - int idx = atomicAdd(&unique_counter, 1); - list_a[list_a_size + idx] = elem_b; + if (elem_b > elem) { + idx_l = idx + 1; + } else { + idx_r = idx; } + } + if (!existed && elem_b != std::numeric_limits::max()) { + int idx = atomicAdd(&unique_counter, 1); + list_a[list_a_size + idx] = elem_b; + } } template > -__global__ void __launch_bounds__(BLOCK_SIZE, 4) - local_join_kernel(const Index_t *graph_new, const Index_t *rev_graph_new, const int2 *sizes_new, - const Index_t *graph_old, const Index_t *rev_graph_old, const int2 *sizes_old, - const int width, const __half *data, const int data_dim, ID_t *graph, - DistData_t *dists, int graph_width, int *locks, DistData_t *l2_norms) { - using namespace nvcuda; - __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; - - constexpr int APAD = 8; - constexpr int BPAD = 8; - __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors - __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors - static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= - sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); - // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov - float *s_distances = (float *)&s_ov[0][0]; - int *s_unique_counter = (int *)&s_ov[0][0]; - - if (threadIdx.x == 0) { - s_unique_counter[0] = 0; - s_unique_counter[1] = 0; - } - - Index_t *new_neighbors = s_list; - Index_t *old_neighbors = s_list + MAX_NUM_BI_SAMPLES; - - size_t list_id = blockIdx.x; - int2 list_new_size2 = 
sizes_new[list_id]; - int list_new_size = list_new_size2.x + list_new_size2.y; - int2 list_old_size2 = sizes_old[list_id]; - int list_old_size = list_old_size2.x + list_old_size2.y; +__global__ void __launch_bounds__(BLOCK_SIZE, 4) local_join_kernel(const Index_t* graph_new, + const Index_t* rev_graph_new, + const int2* sizes_new, + const Index_t* graph_old, + const Index_t* rev_graph_old, + const int2* sizes_old, + const int width, + const __half* data, + const int data_dim, + ID_t* graph, + DistData_t* dists, + int graph_width, + int* locks, + DistData_t* l2_norms) +{ + using namespace nvcuda; + __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; + + constexpr int APAD = 8; + constexpr int BPAD = 8; + __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors + __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors + static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= + sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); + // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov + float* s_distances = (float*)&s_ov[0][0]; + int* s_unique_counter = (int*)&s_ov[0][0]; + + if (threadIdx.x == 0) { + s_unique_counter[0] = 0; + s_unique_counter[1] = 0; + } - if (!list_new_size) return; - int tx = threadIdx.x; + Index_t* new_neighbors = s_list; + Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES; - if (tx < list_new_size2.x) { - new_neighbors[tx] = graph_new[list_id * width + tx]; - } else if (tx >= list_new_size2.x && tx < list_new_size) { - new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; - } + size_t list_id = blockIdx.x; + int2 list_new_size2 = sizes_new[list_id]; + int list_new_size = list_new_size2.x + list_new_size2.y; + int2 list_old_size2 = sizes_old[list_id]; + int list_old_size = list_old_size2.x + list_old_size2.y; - if (tx < list_old_size2.x) { - old_neighbors[tx] = graph_old[list_id * width + tx]; - } else if (tx >= list_old_size2.x && tx < list_old_size) { - old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; - } + if (!list_new_size) return; + int tx = threadIdx.x; - __syncthreads(); + if (tx < list_new_size2.x) { + new_neighbors[tx] = graph_new[list_id * width + tx]; + } else if (tx >= list_new_size2.x && tx < list_new_size) { + new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; + } - remove_duplicates(new_neighbors, list_new_size2.x, new_neighbors + list_new_size2.x, - list_new_size2.y, s_unique_counter[0], 0); + if (tx < list_old_size2.x) { + old_neighbors[tx] = graph_old[list_id * width + tx]; + } else if (tx >= list_old_size2.x && tx < list_old_size) { + old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; + } - remove_duplicates(old_neighbors, list_old_size2.x, old_neighbors + list_old_size2.x, - list_old_size2.y, s_unique_counter[1], 1); - __syncthreads(); - list_new_size = list_new_size2.x + s_unique_counter[0]; - list_old_size = list_old_size2.x + s_unique_counter[1]; - - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - constexpr int num_warps = BLOCK_SIZE / WARP_SIZE; - - int warp_id_y = warp_id / 4; - int warp_id_x = warp_id % 4; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) - ? 
data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; + __syncthreads(); + + remove_duplicates(new_neighbors, + list_new_size2.x, + new_neighbors + list_new_size2.x, + list_new_size2.y, + s_unique_counter[0], + 0); + + remove_duplicates(old_neighbors, + list_old_size2.x, + old_neighbors + list_old_size2.x, + list_old_size2.y, + s_unique_counter[1], + 1); + __syncthreads(); + list_new_size = list_new_size2.x + s_unique_counter[0]; + list_old_size = list_old_size2.x + s_unique_counter[1]; + + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + constexpr int num_warps = BLOCK_SIZE / WARP_SIZE; + + int warp_id_y = warp_id / 4; + int warp_id_x = warp_id % 4; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; #pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, - TILE_COL_WIDTH, lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, - TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, - TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, c_frag, - SKEWED_MAX_NUM_BI_SAMPLES, wmma::mem_row_major); __syncthreads(); - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { - if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; - } else { - s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; - } - } else { - s_distances[i] = std::numeric_limits::max(); - } + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); } - __syncthreads(); + } - for (int step = 0; step < div_up(list_new_size, num_warps); step++) { - int idx_in_list = step * num_warps + tx / WARP_SIZE; - if (idx_in_list >= list_new_size) continue; - auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * 
SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < div_up(list_new_size, num_warps); step++) { + int idx_in_list = step * num_warps + tx / WARP_SIZE; + if (idx_in_list >= list_new_size) continue; + auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); } + } - if (!list_old_size) return; - - __syncthreads(); + if (!list_old_size) return; - wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) - ? data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; - if (TILE_COL_WIDTH < data_dim) { + __syncthreads(); + + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; + if (TILE_COL_WIDTH < data_dim) { #pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, - TILE_COL_WIDTH, lane_id); - } - } + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); } + } + } #pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_old_size) { - size_t neighbor_id = old_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_ov[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, - TILE_COL_WIDTH, lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, - TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, - TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_old_size) { + size_t neighbor_id = old_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_ov[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, c_frag, - 
SKEWED_MAX_NUM_BI_SAMPLES, wmma::mem_row_major); __syncthreads(); - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { - if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; - } else { - s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; - } - } else { - s_distances[i] = std::numeric_limits::max(); - } + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); } - __syncthreads(); + } - for (int step = 0; step < div_up(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { - int idx_in_list = step * num_warps + tx / WARP_SIZE; - if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; - if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && - idx_in_list < MAX_NUM_BI_SAMPLES * 2) - continue; - ResultItem min_elem{std::numeric_limits::max(), - std::numeric_limits::max()}; - if (idx_in_list < MAX_NUM_BI_SAMPLES) { - auto temp_min_item = - get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); - if (temp_min_item.dist() < min_elem.dist()) { - min_elem = temp_min_item; - } - } else { - auto temp_min_item = get_min_item(s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, - new_neighbors, s_distances, false); - if (temp_min_item.dist() < min_elem.dist()) { - min_elem = temp_min_item; - } - } + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < div_up(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { + int idx_in_list = step * num_warps + tx / WARP_SIZE; + if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; + if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) + continue; + ResultItem min_elem{std::numeric_limits::max(), + std::numeric_limits::max()}; + if (idx_in_list < MAX_NUM_BI_SAMPLES) { + auto temp_min_item = + get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } else { + auto temp_min_item = get_min_item( + s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, 
s_list[idx_in_list], graph, dists, graph_width, locks); } + } } namespace { template -int insert_to_ordered_list(InternalID_t* list, DistData_t* dist_list, const int width, - const InternalID_t neighb_id, const DistData_t dist) { - if (dist > dist_list[width - 1]) { - return width; - } - - int idx_insert = width; - for (int i = 0; i < width; i++) { - if (list[i].id() == neighb_id.id()) { - return width; - } - if (dist_list[i] > dist) { - idx_insert = i; - break; - } +int insert_to_ordered_list(InternalID_t* list, + DistData_t* dist_list, + const int width, + const InternalID_t neighb_id, + const DistData_t dist) +{ + if (dist > dist_list[width - 1]) { return width; } + + int idx_insert = width; + for (int i = 0; i < width; i++) { + if (list[i].id() == neighb_id.id()) { return width; } + if (dist_list[i] > dist) { + idx_insert = i; + break; } - if (idx_insert == width) return idx_insert; + } + if (idx_insert == width) return idx_insert; - memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); - memmove(dist_list + idx_insert + 1, dist_list + idx_insert, - sizeof(*dist_list) * (width - idx_insert - 1)); + memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); + memmove(dist_list + idx_insert + 1, + dist_list + idx_insert, + sizeof(*dist_list) * (width - idx_insert - 1)); - list[idx_insert] = neighb_id; - dist_list[idx_insert] = dist; - return idx_insert; + list[idx_insert] = neighb_id; + dist_list[idx_insert] = dist; + return idx_insert; }; } // namespace template -GnndGraph::GnndGraph(const size_t nrow, const size_t node_degree, - const size_t internal_node_degree, const size_t num_samples) - : nrow(nrow), - node_degree(node_degree), - num_samples(num_samples), - bloom_filter(nrow, internal_node_degree / segment_size, 3) { - // node_degree must be a multiple of segment_size; - assert(node_degree % segment_size == 0); - assert(internal_node_degree % segment_size == 0); - - num_segments = node_degree / segment_size; - // To save the CPU memory, graph should be allocated by external function - h_graph = nullptr; - h_dists = new DistData_t[nrow * node_degree]; - - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_new, sizeof(*h_graph_new) * nrow * num_samples)); - RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_new, sizeof(*h_list_sizes_new) * nrow)); - - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old, sizeof(*h_graph_old) * nrow * num_samples)); - RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_old, sizeof(*h_list_sizes_old) * nrow)); +GnndGraph::GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples) + : nrow(nrow), + node_degree(node_degree), + num_samples(num_samples), + bloom_filter(nrow, internal_node_degree / segment_size, 3) +{ + // node_degree must be a multiple of segment_size; + assert(node_degree % segment_size == 0); + assert(internal_node_degree % segment_size == 0); + + num_segments = node_degree / segment_size; + // To save the CPU memory, graph should be allocated by external function + h_graph = nullptr; + h_dists = new DistData_t[nrow * node_degree]; + + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_new, sizeof(*h_graph_new) * nrow * num_samples)); + RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_new, sizeof(*h_list_sizes_new) * nrow)); + + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old, sizeof(*h_graph_old) * nrow * num_samples)); + RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_old, sizeof(*h_list_sizes_old) * nrow)); } // This is the only operation on the CPU that cannot be 
overlapped. // So it should be as fast as possible. template -void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, - const size_t width) { +void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, const size_t width) +{ #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - auto list_new = h_graph_new + i * num_samples; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j].id(); - if ((size_t)new_neighb_id >= nrow) break; - if (bloom_filter.check(i, new_neighb_id)) { - continue; - } - bloom_filter.add(i, new_neighb_id); - new_neighbors[i * width + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = new_neighb_id; - if (h_list_sizes_new[i].x == num_samples) break; - } + for (size_t i = 0; i < nrow; i++) { + auto list_new = h_graph_new + i * num_samples; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j].id(); + if ((size_t)new_neighb_id >= nrow) break; + if (bloom_filter.check(i, new_neighb_id)) { continue; } + bloom_filter.add(i, new_neighb_id); + new_neighbors[i * width + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = new_neighb_id; + if (h_list_sizes_new[i].x == num_samples) break; } + } } template -void GnndGraph::init_random_graph() { - // random sequence (range: 0~nrow) - std::vector rand_seq(nrow); - std::iota(rand_seq.begin(), rand_seq.end(), 0); - std::random_shuffle(rand_seq.begin(), rand_seq.end()); +void GnndGraph::init_random_graph() +{ + // random sequence (range: 0~nrow) + std::vector rand_seq(nrow); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < node_degree; j++) { - size_t idx = i * node_degree + j; - Index_t id = rand_seq[idx % nrow]; - if ((size_t)id == i) { - id = rand_seq[(idx + node_degree) % nrow]; - } - h_graph[i * node_degree + j].id_with_flag() = id; - h_dists[i * node_degree + j] = std::numeric_limits::max(); - } + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < node_degree; j++) { + size_t idx = i * node_degree + j; + Index_t id = rand_seq[idx % nrow]; + if ((size_t)id == i) { id = rand_seq[(idx + node_degree) % nrow]; } + h_graph[i * node_degree + j].id_with_flag() = id; + h_dists[i * node_degree + j] = std::numeric_limits::max(); } + } } template -void GnndGraph::sample_graph(bool sample_new) { +void GnndGraph::sample_graph(bool sample_new) +{ #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - h_list_sizes_old[i].x = 0; - h_list_sizes_old[i].y = 0; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - auto list = h_graph + i * node_degree; - auto list_old = h_graph_old + i * num_samples; - auto list_new = h_graph_new + i * num_samples; - for (int j = 0; j < segment_size; j++) { - for (int k = 0; k < num_segments; k++) { - auto neighbor = list[k * segment_size + j]; - if ((size_t)neighbor.id() >= nrow) continue; - if (!neighbor.is_new()) { - if (h_list_sizes_old[i].x < num_samples) { - list_old[h_list_sizes_old[i].x++] = neighbor.id(); - } - } else if (sample_new) { - if (h_list_sizes_new[i].x < num_samples) { - list[k * segment_size + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = neighbor.id(); - } - } - if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { - break; - } - } - if (h_list_sizes_old[i].x == num_samples && 
h_list_sizes_new[i].x == num_samples) { - break; - } + for (size_t i = 0; i < nrow; i++) { + h_list_sizes_old[i].x = 0; + h_list_sizes_old[i].y = 0; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + auto list = h_graph + i * node_degree; + auto list_old = h_graph_old + i * num_samples; + auto list_new = h_graph_new + i * num_samples; + for (int j = 0; j < segment_size; j++) { + for (int k = 0; k < num_segments; k++) { + auto neighbor = list[k * segment_size + j]; + if ((size_t)neighbor.id() >= nrow) continue; + if (!neighbor.is_new()) { + if (h_list_sizes_old[i].x < num_samples) { + list_old[h_list_sizes_old[i].x++] = neighbor.id(); + } + } else if (sample_new) { + if (h_list_sizes_new[i].x < num_samples) { + list[k * segment_size + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = neighbor.id(); + } } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } } + } } template void GnndGraph::update_graph(const InternalID_t* new_neighbors, - const DistData_t* new_dists, const size_t width, - std::atomic& update_counter) { + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter) +{ #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j]; - auto new_dist = new_dists[i * width + j]; - if (new_dist == std::numeric_limits::max()) break; - if ((size_t)new_neighb_id.id() == i) continue; - int idx_seg = new_neighb_id.id() % num_segments; - auto list = h_graph + i * node_degree + idx_seg * segment_size; - auto dist_list = h_dists + i * node_degree + idx_seg * segment_size; - int insert_pos = - insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); - if (i % counter_interval == 0 && insert_pos != segment_size) { - update_counter++; - } - } + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j]; + auto new_dist = new_dists[i * width + j]; + if (new_dist == std::numeric_limits::max()) break; + if ((size_t)new_neighb_id.id() == i) continue; + int idx_seg = new_neighb_id.id() % num_segments; + auto list = h_graph + i * node_degree + idx_seg * segment_size; + auto dist_list = h_dists + i * node_degree + idx_seg * segment_size; + int insert_pos = + insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); + if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } } + } } template -void GnndGraph::sort_lists() { +void GnndGraph::sort_lists() +{ #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - std::vector> new_list; - for (size_t j = 0; j < node_degree; j++) { - new_list.emplace_back(h_dists[i * node_degree + j], h_graph[i * node_degree + j].id()); - } - std::sort(new_list.begin(), new_list.end()); - for (size_t j = 0; j < node_degree; j++) { - h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; - h_dists[i * node_degree + j] = new_list[j].first; - } + for (size_t i = 0; i < nrow; i++) { + std::vector> new_list; + for (size_t j = 0; j < node_degree; j++) { + new_list.emplace_back(h_dists[i * node_degree + j], h_graph[i * node_degree + j].id()); + } + std::sort(new_list.begin(), new_list.end()); + for (size_t j = 0; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; + h_dists[i * node_degree + j] = new_list[j].first; } + } } 
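The routines above rely on one segmented-list invariant: each row of h_graph holds node_degree neighbors arranged as num_segments independently sorted runs of segment_size entries, and a neighbor id is always routed to the run id % num_segments (see update_graph), so an insertion scans and shifts at most segment_size entries, and a duplicate check never has to leave that run. A minimal host-side sketch of that routing follows; the Candidate struct and insert_segmented helper are illustrative stand-ins, not symbols from this patch, and unfilled slots are assumed to hold the maximum float, matching how h_dists is initialized.

#include <cstring>
#include <vector>

struct Candidate {
  int id;
  float dist;
};

// Route candidate c into the sorted run picked by c.id % num_segments, mirroring
// the segment selection in update_graph and the ordered, duplicate-rejecting
// insert in insert_to_ordered_list.
void insert_segmented(std::vector<Candidate>& row, int segment_size, int num_segments, Candidate c)
{
  Candidate* run = row.data() + (c.id % num_segments) * segment_size;
  if (c.dist > run[segment_size - 1].dist) { return; }  // worse than the worst kept entry
  int pos = segment_size;
  for (int i = 0; i < segment_size; i++) {
    if (run[i].id == c.id) { return; }  // this id can only ever live in this run
    if (pos == segment_size && run[i].dist > c.dist) { pos = i; }
  }
  if (pos == segment_size) { return; }
  std::memmove(run + pos + 1, run + pos, sizeof(Candidate) * (segment_size - pos - 1));
  run[pos] = c;
}

sort_lists above then collapses the per-run ordering into one globally sorted row once the descent iterations finish.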
template -void GnndGraph::clear() { - bloom_filter.clear(); +void GnndGraph::clear() +{ + bloom_filter.clear(); } template -void GnndGraph::dealloc() { - delete[] h_dists; - RAFT_CUDA_TRY(cudaFreeHost(h_graph_new)); - RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_new)); - RAFT_CUDA_TRY(cudaFreeHost(h_graph_old)); - RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_old)); - assert(h_graph == nullptr); +void GnndGraph::dealloc() +{ + delete[] h_dists; + RAFT_CUDA_TRY(cudaFreeHost(h_graph_new)); + RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_new)); + RAFT_CUDA_TRY(cudaFreeHost(h_graph_old)); + RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_old)); + assert(h_graph == nullptr); } template -GnndGraph::~GnndGraph() { - +GnndGraph::~GnndGraph() +{ } template GNND::GNND(const BuildConfig& build_config) - : build_config_(build_config), - graph_(build_config.max_dataset_size, - to_multiple_of_32(build_config.node_degree), - to_multiple_of_32(build_config.internal_node_degree ? build_config.internal_node_degree - : build_config.node_degree), - NUM_SAMPLES), - nrow_(build_config.max_dataset_size), - ndim_(build_config.dataset_dim) { - static_assert(NUM_SAMPLES <= 32); - alloc_workspace(); + : build_config_(build_config), + graph_(build_config.max_dataset_size, + to_multiple_of_32(build_config.node_degree), + to_multiple_of_32(build_config.internal_node_degree ? build_config.internal_node_degree + : build_config.node_degree), + NUM_SAMPLES), + nrow_(build_config.max_dataset_size), + ndim_(build_config.dataset_dim) +{ + static_assert(NUM_SAMPLES <= 32); + alloc_workspace(); }; template -void GNND::alloc_workspace() { - RAFT_CUDA_TRY(cudaMalloc(&d_data_, sizeof(__half) * nrow_ * ndim_)); - RAFT_CUDA_TRY(cudaMallocHost(&graph_host_buffer_, - sizeof(*graph_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - RAFT_CUDA_TRY(cudaMallocHost(&dists_host_buffer_, - sizeof(*dists_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - RAFT_CUDA_TRY( - cudaMalloc(&dists_buffer_, sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - thrust::fill(thrust::device, dists_buffer_, dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, - std::numeric_limits::max()); - RAFT_CUDA_TRY( - cudaMalloc(&graph_buffer_, sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - thrust::fill(thrust::device, reinterpret_cast(graph_buffer_), - reinterpret_cast(graph_buffer_) + (size_t)nrow_ * DEGREE_ON_DEVICE, - std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaMalloc(&d_locks_, sizeof(*d_locks_) * nrow_)); - thrust::fill(thrust::device, d_locks_, d_locks_ + nrow_, 0); - RAFT_CUDA_TRY( - cudaMallocHost(&h_rev_graph_new_, sizeof(*h_rev_graph_new_) * nrow_ * NUM_SAMPLES)); - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old_, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES)); - RAFT_CUDA_TRY( - cudaMallocHost(&h_rev_graph_old_, sizeof(*h_rev_graph_old_) * nrow_ * NUM_SAMPLES)); - RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_new_, sizeof(*d_list_sizes_new_) * nrow_)); - RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_old_, sizeof(*d_list_sizes_old_) * nrow_)); - - if (build_config_.metric_type == Metric_t::METRIC_L2) { - RAFT_CUDA_TRY(cudaMalloc(&l2_norms_, sizeof(*l2_norms_) * nrow_)); - } else { - l2_norms_ = nullptr; - } +void GNND::alloc_workspace() +{ + RAFT_CUDA_TRY(cudaMalloc(&d_data_, sizeof(__half) * nrow_ * ndim_)); + RAFT_CUDA_TRY( + cudaMallocHost(&graph_host_buffer_, sizeof(*graph_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + RAFT_CUDA_TRY( + cudaMallocHost(&dists_host_buffer_, sizeof(*dists_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + RAFT_CUDA_TRY(cudaMalloc(&dists_buffer_, sizeof(*dists_buffer_) * 
nrow_ * DEGREE_ON_DEVICE)); + thrust::fill(thrust::device, + dists_buffer_, + dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaMalloc(&graph_buffer_, sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE)); + thrust::fill(thrust::device, + reinterpret_cast(graph_buffer_), + reinterpret_cast(graph_buffer_) + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaMalloc(&d_locks_, sizeof(*d_locks_) * nrow_)); + thrust::fill(thrust::device, d_locks_, d_locks_ + nrow_, 0); + RAFT_CUDA_TRY(cudaMallocHost(&h_rev_graph_new_, sizeof(*h_rev_graph_new_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old_, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY(cudaMallocHost(&h_rev_graph_old_, sizeof(*h_rev_graph_old_) * nrow_ * NUM_SAMPLES)); + RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_new_, sizeof(*d_list_sizes_new_) * nrow_)); + RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_old_, sizeof(*d_list_sizes_old_) * nrow_)); + + if (build_config_.metric_type == Metric_t::METRIC_L2) { + RAFT_CUDA_TRY(cudaMalloc(&l2_norms_, sizeof(*l2_norms_) * nrow_)); + } else { + l2_norms_ = nullptr; + } } template -void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, int2* list_sizes, - cudaStream_t stream) { - add_rev_edges_kernel<<>>(graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, - list_sizes); - RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, d_rev_graph_ptr, - sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES, - cudaMemcpyDefault, stream)); +void GNND::add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream) +{ + add_rev_edges_kernel<<>>( + graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); + RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, + d_rev_graph_ptr, + sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES, + cudaMemcpyDefault, + stream)); } template -void GNND::local_join(cudaStream_t stream) { - thrust::fill(thrust::device.on(stream), dists_buffer_, - dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, - std::numeric_limits::max()); - local_join_kernel<<>>( - graph_.h_graph_new, h_rev_graph_new_, d_list_sizes_new_, h_graph_old_, h_rev_graph_old_, - d_list_sizes_old_, NUM_SAMPLES, d_data_, ndim_, graph_buffer_, dists_buffer_, - DEGREE_ON_DEVICE, d_locks_, l2_norms_); +void GNND::local_join(cudaStream_t stream) +{ + thrust::fill(thrust::device.on(stream), + dists_buffer_, + dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + local_join_kernel<<>>(graph_.h_graph_new, + h_rev_graph_new_, + d_list_sizes_new_, + h_graph_old_, + h_rev_graph_old_, + d_list_sizes_old_, + NUM_SAMPLES, + d_data_, + ndim_, + graph_buffer_, + dists_buffer_, + DEGREE_ON_DEVICE, + d_locks_, + l2_norms_); } template -void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream) { - cudaStreamSynchronize(stream); - nrow_ = nrow; - graph_.h_graph = (InternalID_t*)output_graph; - - cudaPointerAttributes data_ptr_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); - if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - typename std::remove_const::type* input_data; - size_t batch_size = 100000; - RAFT_CUDA_TRY(cudaMallocAsync(&input_data, - sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream)); - for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { - size_t list_offset = step * batch_size; - size_t num_lists = - step != 
div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; - RAFT_CUDA_TRY(cudaMemcpyAsync( - input_data, data + list_offset * build_config_.dataset_dim, - sizeof(Data_t) * num_lists * build_config_.dataset_dim, cudaMemcpyDefault, stream)); - preprocess_data_kernel<<>>(input_data, d_data_, build_config_.dataset_dim, - l2_norms_, list_offset); - } - RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream)); - } else { - preprocess_data_kernel<<< - nrow_, WARP_SIZE, - sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, stream>>>( - data, d_data_, build_config_.dataset_dim, l2_norms_); +void GNND::build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + cudaStream_t stream) +{ + cudaStreamSynchronize(stream); + nrow_ = nrow; + graph_.h_graph = (InternalID_t*)output_graph; + + cudaPointerAttributes data_ptr_attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); + if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { + typename std::remove_const::type* input_data; + size_t batch_size = 100000; + RAFT_CUDA_TRY(cudaMallocAsync( + &input_data, sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream)); + for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { + size_t list_offset = step * batch_size; + size_t num_lists = step != div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; + RAFT_CUDA_TRY(cudaMemcpyAsync(input_data, + data + list_offset * build_config_.dataset_dim, + sizeof(Data_t) * num_lists * build_config_.dataset_dim, + cudaMemcpyDefault, + stream)); + preprocess_data_kernel<<>>( + input_data, d_data_, build_config_.dataset_dim, l2_norms_, list_offset); } + RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream)); + } else { + preprocess_data_kernel<<>>(data, d_data_, build_config_.dataset_dim, l2_norms_); + } - thrust::fill(thrust::device.on(stream), (Index_t*)graph_buffer_, - (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, - std::numeric_limits::max()); - - graph_.clear(); - graph_.init_random_graph(); - graph_.sample_graph(true); - - auto update_and_sample = [&](bool update_graph) { - if (update_graph) { - update_counter_ = 0; - graph_.update_graph(graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, - update_counter_); - if (update_counter_ < build_config_.termination_threshold * nrow_ * - build_config_.dataset_dim / counter_interval) { - update_counter_ = -1; - } - } - graph_.sample_graph(false); - }; - - for (size_t it = 0; it < build_config_.max_iterations; it++) { - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, graph_.h_list_sizes_new, - sizeof(*d_list_sizes_new_) * nrow_, - cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, graph_.h_graph_old, - sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES, - cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, graph_.h_list_sizes_old, - sizeof(*d_list_sizes_old_) * nrow_, - cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - - std::thread update_and_sample_thread(update_and_sample, it); - - std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r"; - std::fflush(stdout); - - // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it - // contains some information for local_join.
- static_assert(DEGREE_ON_DEVICE * sizeof(*dists_buffer_) >= - NUM_SAMPLES * sizeof(*graph_buffer_)); - add_reverse_edges(graph_.h_graph_new, h_rev_graph_new_, (Index_t*)dists_buffer_, - d_list_sizes_new_, stream); - add_reverse_edges(h_graph_old_, h_rev_graph_old_, (Index_t*)dists_buffer_, - d_list_sizes_old_, stream); - - local_join(stream); - - update_and_sample_thread.join(); - - if (update_counter_ == -1) { - break; - } - - RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, graph_buffer_, - sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, dists_buffer_, - sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDefault, stream)); - graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE); + thrust::fill(thrust::device.on(stream), + (Index_t*)graph_buffer_, + (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + std::numeric_limits::max()); + + graph_.clear(); + graph_.init_random_graph(); + graph_.sample_graph(true); + + auto update_and_sample = [&](bool update_graph) { + if (update_graph) { + update_counter_ = 0; + graph_.update_graph( + graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, update_counter_); + if (update_counter_ < build_config_.termination_threshold * nrow_ * + build_config_.dataset_dim / counter_interval) { + update_counter_ = -1; + } } + graph_.sample_graph(false); + }; + + for (size_t it = 0; it < build_config_.max_iterations; it++) { + RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, + graph_.h_list_sizes_new, + sizeof(*d_list_sizes_new_) * nrow_, + cudaMemcpyDefault, + stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, + graph_.h_graph_old, + sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES, + cudaMemcpyDefault, + stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, + graph_.h_list_sizes_old, + sizeof(*d_list_sizes_old_) * nrow_, + cudaMemcpyDefault, + stream)); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + std::thread update_and_sample_thread(update_and_sample, it); + + std::cout << "# GNND iteration: " << it + 1 << "/" << build_config_.max_iterations << "\r"; + std::fflush(stdout); + + // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it + // contains some information for local_join.
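+ // (This reuse is only safe while a row's NUM_SAMPLES reverse-edge ids fit inside its
+ // DEGREE_ON_DEVICE distance slots, which is exactly what the static_assert below checks.)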
+ static_assert(DEGREE_ON_DEVICE * sizeof(*dists_buffer_) >= + NUM_SAMPLES * sizeof(*graph_buffer_)); + add_reverse_edges( + graph_.h_graph_new, h_rev_graph_new_, (Index_t*)dists_buffer_, d_list_sizes_new_, stream); + add_reverse_edges( + h_graph_old_, h_rev_graph_old_, (Index_t*)dists_buffer_, d_list_sizes_old_, stream); + + local_join(stream); + + update_and_sample_thread.join(); + + if (update_counter_ == -1) { break; } - RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, + graph_buffer_, + sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE, + cudaMemcpyDefault, + stream)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, + dists_buffer_, + sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE, + cudaMemcpyDefault, + stream)); + graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE); + } + + graph_.update_graph(graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, update_counter_); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - graph_.sort_lists(); + graph_.sort_lists(); - // Reuse graph_.h_dists as the buffer for shrink the lists in graph - static_assert(sizeof(decltype(*graph_.h_dists)) >= sizeof(Index_t)); - Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists; + // Reuse graph_.h_dists as the buffer for shrinking the lists in the graph + static_assert(sizeof(decltype(*graph_.h_dists)) >= sizeof(Index_t)); + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists; #pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - size_t idx = i * graph_.node_degree + j; - Index_t id = graph_.h_graph[idx].id(); - if (id < nrow_) { - graph_shrink_buffer[i * build_config_.node_degree + j] = id; - } else { - graph_shrink_buffer[i * build_config_.node_degree + j] = xorshift64(idx) % nrow_; - } - } + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + size_t idx = i * graph_.node_degree + j; + Index_t id = graph_.h_graph[idx].id(); + if (id < nrow_) { + graph_shrink_buffer[i * build_config_.node_degree + j] = id; + } else { + graph_shrink_buffer[i * build_config_.node_degree + j] = xorshift64(idx) % nrow_; + } } - graph_.h_graph = nullptr; + } + graph_.h_graph = nullptr; #pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - output_graph[i * build_config_.node_degree + j] = - graph_shrink_buffer[i * build_config_.node_degree + j]; - } + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + output_graph[i * build_config_.node_degree + j] = + graph_shrink_buffer[i * build_config_.node_degree + j]; } + } } template -void GNND::dealloc() { - graph_.dealloc(); - RAFT_CUDA_TRY(cudaFree(d_data_)); - RAFT_CUDA_TRY(cudaFreeHost(graph_host_buffer_)); - RAFT_CUDA_TRY(cudaFreeHost(dists_host_buffer_)); - RAFT_CUDA_TRY(cudaFree(dists_buffer_)); - RAFT_CUDA_TRY(cudaFree(graph_buffer_)); - RAFT_CUDA_TRY(cudaFree(d_locks_)); - RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_new_)); - RAFT_CUDA_TRY(cudaFreeHost(h_graph_old_)); - RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_old_)); - RAFT_CUDA_TRY(cudaFree(d_list_sizes_new_)); - RAFT_CUDA_TRY(cudaFree(d_list_sizes_old_)); - RAFT_CUDA_TRY(cudaFree(l2_norms_)); +void GNND::dealloc() +{ + graph_.dealloc(); + RAFT_CUDA_TRY(cudaFree(d_data_)); + 
RAFT_CUDA_TRY(cudaFreeHost(graph_host_buffer_)); + RAFT_CUDA_TRY(cudaFreeHost(dists_host_buffer_)); + RAFT_CUDA_TRY(cudaFree(dists_buffer_)); + RAFT_CUDA_TRY(cudaFree(graph_buffer_)); + RAFT_CUDA_TRY(cudaFree(d_locks_)); + RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_new_)); + RAFT_CUDA_TRY(cudaFreeHost(h_graph_old_)); + RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_old_)); + RAFT_CUDA_TRY(cudaFree(d_list_sizes_new_)); + RAFT_CUDA_TRY(cudaFree(d_list_sizes_old_)); + RAFT_CUDA_TRY(cudaFree(l2_norms_)); } template -GNND::~GNND() { +GNND::~GNND() +{ } template , memory_type::host>> index build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset) { - RAFT_EXPECTS(dataset.size() < std::numeric_limits::max() - 1, - "The dataset size for GNND should be less than %d", - std::numeric_limits::max() - 1); + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + RAFT_EXPECTS(dataset.size() < std::numeric_limits::max() - 1, + "The dataset size for GNND should be less than %d", + std::numeric_limits::max() - 1); size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; if (intermediate_degree >= static_cast(dataset.extent(0))) { @@ -1300,32 +1398,40 @@ index build(raft::resources const& res, graph_degree = intermediate_degree; } // The elements in each knn-list are partitioned into different buckets, and we need more buckets - // to mitigate bucket collisions. `intermediate_degree` is OK to larger than extended_graph_degree. + // to mitigate bucket collisions. `intermediate_degree` is OK to be larger than + // extended_graph_degree. size_t extended_graph_degree = to_multiple_of_32(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3)); - index int_idx{res, dataset.extent(0), static_cast(extended_graph_degree)}; + size_t extended_intermediate_degree = + to_multiple_of_32(intermediate_degree * (graph_degree <= 32 ? 
1.0 : 1.3)); + + auto int_graph = raft::make_host_matrix( + dataset.extent(0), static_cast(extended_graph_degree)); - BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), - .dataset_dim = static_cast(dataset.extent(1)), - .node_degree = extended_graph_degree, - .internal_node_degree = intermediate_degree, - .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold, - .metric_type = Metric_t::METRIC_L2}; + BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), + .dataset_dim = static_cast(dataset.extent(1)), + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold, + .metric_type = Metric_t::METRIC_L2}; GNND nnd(build_config); - std::cout << "Intermediate graph dim: " << int_idx.int_graph().extent(0) << ", " << int_idx.int_graph().extent(1) << std::endl; - nnd.build(dataset.data_handle(), dataset.extent(0), int_idx.int_graph().data_handle(), resource::get_cuda_stream(res)); + std::cout << "Intermediate graph dim: " << int_graph.extent(0) << ", " << int_graph.extent(1) + << std::endl; + nnd.build(dataset.data_handle(), + dataset.extent(0), + int_graph.data_handle(), + resource::get_cuda_stream(res)); nnd.dealloc(); index idx{res, dataset.extent(0), static_cast(graph_degree)}; #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { - for (size_t j = 0; j < graph_degree; j++) { - auto graph = idx.int_graph().data_handle(); - auto int_graph = int_idx.int_graph().data_handle(); - graph[i * graph_degree + j] = int_graph[i * extended_graph_degree + j]; - } + for (size_t j = 0; j < graph_degree; j++) { + auto graph = idx.graph().data_handle(); + graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; + } } return idx; } -} +} // namespace raft::neighbors::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 8a8e97497e..ae4a70ba50 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -32,37 +32,38 @@ namespace raft::neighbors::nn_descent { */ struct index_params : ann::index_params { - size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. - size_t graph_degree = 64; // Degree of output graph. - size_t max_iterations = 50; // Number of nn-descent iterations. - float termination_threshold = 0.0001; // Termination threshold of nn-descent. + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t graph_degree = 64; // Degree of output graph. + size_t max_iterations = 50; // Number of nn-descent iterations. + float termination_threshold = 0.0001; // Termination threshold of nn-descent. }; /** * @brief nn-descent Index - * + * * @tparam IdxT dtype to be used for constructing knn-graph */ template struct index : ann::index { -public: + public: /** * @brief Construct a new index object - * + * * This constructor creates an nn-descent index which is a knn-graph in host memory. * The type of the knn-graph is a dense raft::host_matrix and dimensions are * (n_rows, n_cols). 
- * + * * @param res raft::resources * @param n_rows number of rows in knn-graph * @param n_cols number of cols in knn-graph */ - index(raft::resources const& res, int64_t n_rows, int64_t n_cols) : - ann::index(), - res_{res}, - metric_{raft::distance::DistanceType::L2Expanded}, - int_graph_{raft::make_host_matrix(n_rows, n_cols)}, - graph_{raft::make_host_matrix(0, 0)} { } + index(raft::resources const& res, int64_t n_rows, int64_t n_cols) + : ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + graph_{raft::make_host_matrix(n_rows, n_cols)} + { + } /** Distance metric used for clustering. */ [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType @@ -83,28 +84,11 @@ struct index : ann::index { } /** neighborhood graph [size, graph-degree] */ - [[nodiscard]] inline auto graph() noexcept - -> host_matrix_view + [[nodiscard]] inline auto graph() noexcept -> host_matrix_view { - if constexpr (std::is_same_v or std::is_same_v) { - return raft::make_host_matrix_view( - reinterpret_cast(int_graph_.data_handle()), - int_graph_.extent(0), - int_graph_.extent(1)); - } - else { - graph_ = raft::make_host_matrix(int_graph_.extent(0), int_graph_.extent(1)); - std::copy(graph_.data_handle(), graph_.data_handle() + graph_.size(), int_graph_.data_handle()); - return graph_.view(); - } + return graph_.view(); } - /** int type graph */ - [[nodiscard]] inline auto int_graph() noexcept - -> host_matrix_view { - return int_graph_.view(); - } - // Don't allow copying the index for performance reasons (try avoiding copying data) index(const index&) = delete; index(index&&) = default; @@ -112,13 +96,12 @@ struct index : ann::index { auto operator=(index&&) -> index& = default; ~index() = default; -private: + private: raft::resources const& res_; raft::distance::DistanceType metric_; - raft::host_matrix int_graph_; // nn-descent only supports int IdxT graphs - raft::host_matrix graph_; // graph to return for non-int IdxT + raft::host_matrix graph_; // graph to return for non-int IdxT }; /** @} */ -} +} // namespace raft::neighbors::nn_descent diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index c10d7c1299..1f82d6771b 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -144,12 +144,13 @@ struct AnnCagraInputs { inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) { - std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries << ", itopk_size=" << p.itopk_size << ", search_width=" << p.search_width - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") << '}' - << std::endl; + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? 
", host" : ", device") + << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; return os; } @@ -322,11 +323,29 @@ class AnnCagraSortTest : public ::testing::TestWithParam { auto knn_graph = raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); - if (ps.host_dataset) { - cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + if (ps.build_algo == graph_build_algo::IVF_PQ) { + if (ps.host_dataset) { + cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + } else { + cagra::build_knn_graph(handle_, database_view, knn_graph.view()); + } } else { - cagra::build_knn_graph(handle_, database_view, knn_graph.view()); - }; + auto nn_descent_idx_params = std::make_optional(); + nn_descent_idx_params->graph_degree = index_params.intermediate_graph_degree; + nn_descent_idx_params->intermediate_graph_degree = index_params.intermediate_graph_degree; + + if (ps.host_dataset) { + auto nn_descent_idx = + cagra::build_knn_graph(handle_, database_host_view, nn_descent_idx_params); + std::memcpy( + knn_graph.data_handle(), nn_descent_idx.graph().data_handle(), knn_graph.size()); + } else { + auto nn_descent_idx = + cagra::build_knn_graph(handle_, database_host_view, nn_descent_idx_params); + std::memcpy( + knn_graph.data_handle(), nn_descent_idx.graph().data_handle(), knn_graph.size()); + } + } handle_.sync_stream(); ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view())); @@ -373,7 +392,7 @@ inline std::vector generate_inputs() {100}, {1000}, {1, 8, 17}, - {1, 16}, // k + {1, 16}, // k {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, // query size @@ -399,68 +418,52 @@ inline std::vector generate_inputs() {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0, 4, 8, 16, 32}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {32, 64, 128, 256, 512, 768}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0, 4, 8, 16, 32}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {32, 64, 128, 256, 512, 768}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - 
raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {10000, 20000}, + {32}, + {10}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; From 6ac9186e12670e536f410b1f898d12c6e838ba20 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Wed, 30 Aug 2023 04:21:55 -0700 Subject: [PATCH 07/28] Fix duplicate nodes issue --- .../raft/neighbors/detail/nn_descent.cuh | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 6ef651d37f..b3c64a33ac 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -927,11 +927,12 @@ int insert_to_ordered_list(InternalID_t* list, if (dist > dist_list[width - 1]) { return width; } int idx_insert = width; + bool position_found = false; for (int i = 0; i < width; i++) { if (list[i].id() == neighb_id.id()) { return width; } - if (dist_list[i] > dist) { + if (!position_found && dist_list[i] > dist) { idx_insert = i; - break; + position_found = true; } } if (idx_insert == width) return idx_insert; @@ -1000,19 +1001,27 @@ void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, template void GnndGraph::init_random_graph() { - // random sequence (range: 0~nrow) - std::vector rand_seq(nrow); - std::iota(rand_seq.begin(), rand_seq.end(), 0); - std::random_shuffle(rand_seq.begin(), rand_seq.end()); + for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { + // random sequence (range: 0~nrow) + // segment_x stores neighbors which id % num_segments == x + std::vector rand_seq(nrow / num_segments); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < node_degree; j++) { - size_t idx = i * node_degree + j; - Index_t id = rand_seq[idx % nrow]; - if ((size_t)id == i) { id = rand_seq[(idx + node_degree) % nrow]; } - h_graph[i * node_degree + j].id_with_flag() = id; - h_dists[i * node_degree + j] = std::numeric_limits::max(); + for (size_t i = 0; i < nrow; i++) { + size_t base_idx = i * node_degree + seg_idx * segment_size; + auto h_neighbor_list = h_graph + base_idx; + auto h_dist_list = h_dists + base_idx; + for (size_t j = 0; j < static_cast(segment_size); j++) { + size_t idx = base_idx + j; + Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; + if ((size_t)id == i) { + id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; + } + h_neighbor_list[j].id_with_flag() = id; + h_dist_list[j] = std::numeric_limits::max(); + } } } } @@ -1064,9 +1073,9 @@ void GnndGraph::update_graph(const InternalID_t* new_neighbors auto new_dist = new_dists[i * width + j]; if (new_dist == std::numeric_limits::max()) break; if ((size_t)new_neighb_id.id() == i) continue; - int idx_seg = new_neighb_id.id() % num_segments; - auto list = h_graph + i * node_degree + idx_seg * segment_size; - auto dist_list = h_dists + i 
* node_degree + idx_seg * segment_size; + int seg_idx = new_neighb_id.id() % num_segments; + auto list = h_graph + i * node_degree + seg_idx * segment_size; + auto dist_list = h_dists + i * node_degree + seg_idx * segment_size; int insert_pos = insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } From 508050f53247632cd03cb535b6f4c32ea593d7fd Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Wed, 30 Aug 2023 04:47:18 -0700 Subject: [PATCH 08/28] Fix IMA in sort_knn_graph --- cpp/include/raft/neighbors/detail/cagra/graph_core.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 0558d7ea39..f535098df1 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -244,7 +244,7 @@ void sort_knn_graph(raft::resources const& res, const uint32_t input_graph_degree = knn_graph.extent(1); IdxT* const input_graph_ptr = knn_graph.data_handle(); - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); // // Sorting kNN graph From 0496bd955e684610aef90e767122b198c5e15307 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 30 Aug 2023 06:16:28 -0700 Subject: [PATCH 09/28] temp benchmark --- cpp/include/raft/neighbors/cagra.cuh | 8 +++++--- cpp/include/raft/neighbors/cagra_types.hpp | 13 +++++-------- cpp/include/raft/neighbors/nn_descent_types.hpp | 2 +- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index a544f7044c..14b3c5d2e0 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -277,9 +277,11 @@ index build(raft::resources const& res, optimize(res, knn_graph.view(), cagra_graph.view()); } else { - auto nn_descent_params = std::make_optional(); - nn_descent_params->intermediate_graph_degree = intermediate_degree; - nn_descent_params->graph_degree = intermediate_degree; + auto nn_descent_params = std::make_optional(); + // nn_descent_params->intermediate_graph_degree = intermediate_degree; + // nn_descent_params->graph_degree = intermediate_degree; + nn_descent_params->intermediate_graph_degree = 64; + nn_descent_params->graph_degree = 96; auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 935d15edb8..2c8daa362b 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -40,15 +40,12 @@ namespace raft::neighbors::cagra { * @{ */ -enum class graph_build_algo { - IVF_PQ, - NN_DESCENT -}; +enum class graph_build_algo { IVF_PQ, NN_DESCENT }; struct index_params : ann::index_params { - size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. - size_t graph_degree = 64; // Degree of output graph. - graph_build_algo build_algo = graph_build_algo::IVF_PQ; // ANN algorithm to build knn graph + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t graph_degree = 32; // Degree of output graph. 
+ graph_build_algo build_algo = graph_build_algo::NN_DESCENT; // ANN algorithm to build knn graph }; enum class search_algo { @@ -353,9 +350,9 @@ struct index : ann::index { // TODO: Remove deprecated experimental namespace in 23.12 release namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::hash_mode; using raft::neighbors::cagra::index; -using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::index_params; using raft::neighbors::cagra::search_algo; using raft::neighbors::cagra::search_params; diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index ae4a70ba50..4df03fc643 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -34,7 +34,7 @@ namespace raft::neighbors::nn_descent { struct index_params : ann::index_params { size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. size_t graph_degree = 64; // Degree of output graph. - size_t max_iterations = 50; // Number of nn-descent iterations. + size_t max_iterations = 15; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. }; From 0e96d405c14b761d66a2a66bb474cbbc51485f31 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 31 Aug 2023 13:03:52 -0700 Subject: [PATCH 10/28] Revert "temp benchmark" This reverts commit 0496bd955e684610aef90e767122b198c5e15307. --- cpp/include/raft/neighbors/cagra.cuh | 8 +++----- cpp/include/raft/neighbors/cagra_types.hpp | 13 ++++++++----- cpp/include/raft/neighbors/nn_descent_types.hpp | 2 +- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 14b3c5d2e0..a544f7044c 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -277,11 +277,9 @@ index build(raft::resources const& res, optimize(res, knn_graph.view(), cagra_graph.view()); } else { - auto nn_descent_params = std::make_optional(); - // nn_descent_params->intermediate_graph_degree = intermediate_degree; - // nn_descent_params->graph_degree = intermediate_degree; - nn_descent_params->intermediate_graph_degree = 64; - nn_descent_params->graph_degree = 96; + auto nn_descent_params = std::make_optional(); + nn_descent_params->intermediate_graph_degree = intermediate_degree; + nn_descent_params->graph_degree = intermediate_degree; auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 2c8daa362b..935d15edb8 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -40,12 +40,15 @@ namespace raft::neighbors::cagra { * @{ */ -enum class graph_build_algo { IVF_PQ, NN_DESCENT }; +enum class graph_build_algo { + IVF_PQ, + NN_DESCENT +}; struct index_params : ann::index_params { - size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. - size_t graph_degree = 32; // Degree of output graph. - graph_build_algo build_algo = graph_build_algo::NN_DESCENT; // ANN algorithm to build knn graph + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t graph_degree = 64; // Degree of output graph. 
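+ // e.g. opting into the new path: index_params p; p.build_algo = graph_build_algo::NN_DESCENT;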
+ graph_build_algo build_algo = graph_build_algo::IVF_PQ; // ANN algorithm to build knn graph }; enum class search_algo { @@ -350,9 +353,9 @@ struct index : ann::index { // TODO: Remove deprecated experimental namespace in 23.12 release namespace raft::neighbors::experimental::cagra { -using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::hash_mode; using raft::neighbors::cagra::index; +using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::index_params; using raft::neighbors::cagra::search_algo; using raft::neighbors::cagra::search_params; diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 4df03fc643..ae4a70ba50 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -34,7 +34,7 @@ namespace raft::neighbors::nn_descent { struct index_params : ann::index_params { size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. size_t graph_degree = 64; // Degree of output graph. - size_t max_iterations = 15; // Number of nn-descent iterations. + size_t max_iterations = 50; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. }; From 94682d8edf4f5f29fb1f44a14f7a521aa68f7d97 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 31 Aug 2023 13:05:46 -0700 Subject: [PATCH 11/28] remove explicit sort from nn-descent --- cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d67394806a..a790a5da84 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -260,8 +260,6 @@ nn_descent::index build_knn_graph( nn_descent_idx.graph().extent(0), nn_descent_idx.graph().extent(1)); - graph::sort_knn_graph(res, dataset, knn_graph_internal); - return nn_descent_idx; } From 7bf3ad630b79c1face63330c78090137ea6dc59c Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 31 Aug 2023 17:15:28 -0700 Subject: [PATCH 12/28] Use RAFT types --- cpp/include/raft/neighbors/cagra.cuh | 2 +- .../neighbors/detail/cagra/cagra_build.cuh | 2 + .../raft/neighbors/detail/nn_descent.cuh | 307 +++++++++--------- 3 files changed, 151 insertions(+), 160 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index a544f7044c..c95facb118 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -278,8 +278,8 @@ index build(raft::resources const& res, optimize(res, knn_graph.view(), cagra_graph.view()); } else { auto nn_descent_params = std::make_optional(); - nn_descent_params->intermediate_graph_degree = intermediate_degree; nn_descent_params->graph_degree = intermediate_degree; + nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index a790a5da84..d67394806a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -260,6 +260,8 @@ nn_descent::index build_knn_graph( nn_descent_idx.graph().extent(0), 
nn_descent_idx.graph().extent(1)); + graph::sort_knn_graph(res, dataset, knn_graph_internal); + return nn_descent_idx; } diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index b3c64a33ac..b740284cc4 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -25,14 +25,24 @@ #include #include #include +#include +#include +#include #include "../nn_descent_types.hpp" +#include +#include #include #include #include namespace raft::neighbors::nn_descent::detail { + +using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; +template +using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; + using DistData_t = float; constexpr int DEGREE_ON_DEVICE{32}; constexpr int SEGMENT_SIZE{32}; @@ -303,18 +313,19 @@ template struct GnndGraph { static constexpr int segment_size = 32; InternalID_t* h_graph; - DistData_t* h_dists; size_t nrow; size_t node_degree; int num_samples; int num_segments; - Index_t* h_graph_new; - int2* h_list_sizes_new; + raft::host_matrix h_dists; + + thrust::host_vector> h_graph_new; + thrust::host_vector> h_list_sizes_new; - Index_t* h_graph_old; - int2* h_list_sizes_old; + thrust::host_vector> h_graph_old; + thrust::host_vector> h_list_sizes_old; BloomFilter bloom_filter; GnndGraph(const GnndGraph&) = delete; @@ -333,19 +344,17 @@ struct GnndGraph { std::atomic& update_counter); void sort_lists(); void clear(); - void dealloc(); ~GnndGraph(); }; template class GNND { public: - GNND(const BuildConfig& build_config); + GNND(raft::resources const& res, const BuildConfig& build_config); GNND(const GNND&) = delete; GNND& operator=(const GNND&) = delete; - // Use delete[] to deallocate the returned graph - void build(Data_t* data, const Index_t nrow, Index_t* output_graph, cudaStream_t stream = 0); + void build(Data_t* data, const Index_t nrow, Index_t* output_graph); void dealloc(); ~GNND(); using ID_t = InternalID_t; @@ -357,32 +366,34 @@ class GNND { int2* list_sizes, cudaStream_t stream = 0); void local_join(cudaStream_t stream = 0); - void alloc_workspace(); + + raft::resources const& res; BuildConfig build_config_; GnndGraph graph_; std::atomic update_counter_; - __half* d_data_; - DistData_t* l2_norms_; + Index_t nrow_; + const int ndim_; + + raft::device_matrix<__half, Index_t, raft::row_major> d_data_; + raft::device_vector l2_norms_; + + raft::device_matrix graph_buffer_; + raft::device_matrix dists_buffer_; - ID_t* graph_buffer_; - DistData_t* dists_buffer_; - ID_t* graph_host_buffer_; - DistData_t* dists_host_buffer_; + thrust::host_vector> graph_host_buffer_; + thrust::host_vector> dists_host_buffer_; - int* d_locks_; + raft::device_vector d_locks_; - Index_t* h_rev_graph_new_; + thrust::host_vector> h_rev_graph_new_; + thrust::host_vector> h_graph_old_; + thrust::host_vector> h_rev_graph_old_; // int2.x is the number of forward edges, int2.y is the number of reverse edges - int2* d_list_sizes_new_; - Index_t* h_graph_old_; - Index_t* h_rev_graph_old_; int2* d_list_sizes_old_; - - Index_t nrow_; - const int ndim_; + int2* d_list_sizes_new_; }; constexpr int TILE_ROW_WIDTH = 64; @@ -926,12 +937,12 @@ int insert_to_ordered_list(InternalID_t* list, { if (dist > dist_list[width - 1]) { return width; } - int idx_insert = width; + int idx_insert = width; bool position_found = false; for (int i = 0; i < width; i++) { if (list[i].id() == neighb_id.id()) { return width; } if (!position_found && dist_list[i] > dist) { - 
idx_insert = i; + idx_insert = i; position_found = true; } } @@ -957,7 +968,12 @@ GnndGraph::GnndGraph(const size_t nrow, : nrow(nrow), node_degree(node_degree), num_samples(num_samples), - bloom_filter(nrow, internal_node_degree / segment_size, 3) + bloom_filter(nrow, internal_node_degree / segment_size, 3), + h_dists{raft::make_host_matrix(nrow, node_degree)}, + h_graph_new{nrow * num_samples}, + h_list_sizes_new{nrow}, + h_graph_old{nrow * num_samples}, + h_list_sizes_old{nrow} { // node_degree must be a multiple of segment_size; assert(node_degree % segment_size == 0); @@ -966,13 +982,6 @@ GnndGraph::GnndGraph(const size_t nrow, num_segments = node_degree / segment_size; // To save the CPU memory, graph should be allocated by external function h_graph = nullptr; - h_dists = new DistData_t[nrow * node_degree]; - - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_new, sizeof(*h_graph_new) * nrow * num_samples)); - RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_new, sizeof(*h_list_sizes_new) * nrow)); - - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old, sizeof(*h_graph_old) * nrow * num_samples)); - RAFT_CUDA_TRY(cudaMallocHost(&h_list_sizes_old, sizeof(*h_list_sizes_old) * nrow)); } // This is the only operation on the CPU that cannot be overlapped. @@ -982,7 +991,7 @@ void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, { #pragma omp parallel for for (size_t i = 0; i < nrow; i++) { - auto list_new = h_graph_new + i * num_samples; + auto list_new = h_graph_new.data() + i * num_samples; h_list_sizes_new[i].x = 0; h_list_sizes_new[i].y = 0; @@ -1010,9 +1019,9 @@ void GnndGraph::init_random_graph() #pragma omp parallel for for (size_t i = 0; i < nrow; i++) { - size_t base_idx = i * node_degree + seg_idx * segment_size; + size_t base_idx = i * node_degree + seg_idx * segment_size; auto h_neighbor_list = h_graph + base_idx; - auto h_dist_list = h_dists + base_idx; + auto h_dist_list = h_dists.data_handle() + base_idx; for (size_t j = 0; j < static_cast(segment_size); j++) { size_t idx = base_idx + j; Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; @@ -1037,8 +1046,8 @@ void GnndGraph::sample_graph(bool sample_new) h_list_sizes_new[i].y = 0; auto list = h_graph + i * node_degree; - auto list_old = h_graph_old + i * num_samples; - auto list_new = h_graph_new + i * num_samples; + auto list_old = h_graph_old.data() + i * num_samples; + auto list_new = h_graph_new.data() + i * num_samples; for (int j = 0; j < segment_size; j++) { for (int k = 0; k < num_segments; k++) { auto neighbor = list[k * segment_size + j]; @@ -1075,7 +1084,7 @@ void GnndGraph::update_graph(const InternalID_t* new_neighbors if ((size_t)new_neighb_id.id() == i) continue; int seg_idx = new_neighb_id.id() % num_segments; auto list = h_graph + i * node_degree + seg_idx * segment_size; - auto dist_list = h_dists + i * node_degree + seg_idx * segment_size; + auto dist_list = h_dists.data_handle() + i * node_degree + seg_idx * segment_size; int insert_pos = insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } @@ -1090,12 +1099,13 @@ void GnndGraph::sort_lists() for (size_t i = 0; i < nrow; i++) { std::vector> new_list; for (size_t j = 0; j < node_degree; j++) { - new_list.emplace_back(h_dists[i * node_degree + j], h_graph[i * node_degree + j].id()); + new_list.emplace_back(h_dists.data_handle()[i * node_degree + j], + h_graph[i * node_degree + j].id()); } std::sort(new_list.begin(), new_list.end()); for (size_t j 
= 0; j < node_degree; j++) { h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; - h_dists[i * node_degree + j] = new_list[j].first; + h_dists.data_handle()[i * node_degree + j] = new_list[j].first; } } } @@ -1106,69 +1116,52 @@ void GnndGraph::clear() bloom_filter.clear(); } -template -void GnndGraph::dealloc() -{ - delete[] h_dists; - RAFT_CUDA_TRY(cudaFreeHost(h_graph_new)); - RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_new)); - RAFT_CUDA_TRY(cudaFreeHost(h_graph_old)); - RAFT_CUDA_TRY(cudaFreeHost(h_list_sizes_old)); - assert(h_graph == nullptr); -} - template GnndGraph::~GnndGraph() { + assert(h_graph == nullptr); } template -GNND::GNND(const BuildConfig& build_config) - : build_config_(build_config), +GNND::GNND(raft::resources const& res, const BuildConfig& build_config) + : res(res), + build_config_(build_config), graph_(build_config.max_dataset_size, to_multiple_of_32(build_config.node_degree), to_multiple_of_32(build_config.internal_node_degree ? build_config.internal_node_degree : build_config.node_degree), NUM_SAMPLES), nrow_(build_config.max_dataset_size), - ndim_(build_config.dataset_dim) + ndim_(build_config.dataset_dim), + d_data_{raft::make_device_matrix<__half, Index_t, raft::row_major>( + res, nrow_, build_config.dataset_dim)}, + l2_norms_{raft::make_device_vector(res, nrow_)}, + graph_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + dists_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + graph_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + dists_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + d_locks_{raft::make_device_vector(res, nrow_)}, + h_rev_graph_new_{static_cast(nrow_ * NUM_SAMPLES)}, + h_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + h_rev_graph_old_{static_cast(nrow_ * NUM_SAMPLES)} { static_assert(NUM_SAMPLES <= 32); - alloc_workspace(); -}; -template -void GNND::alloc_workspace() -{ - RAFT_CUDA_TRY(cudaMalloc(&d_data_, sizeof(__half) * nrow_ * ndim_)); - RAFT_CUDA_TRY( - cudaMallocHost(&graph_host_buffer_, sizeof(*graph_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - RAFT_CUDA_TRY( - cudaMallocHost(&dists_host_buffer_, sizeof(*dists_host_buffer_) * nrow_ * DEGREE_ON_DEVICE)); - RAFT_CUDA_TRY(cudaMalloc(&dists_buffer_, sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE)); thrust::fill(thrust::device, - dists_buffer_, - dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaMalloc(&graph_buffer_, sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE)); thrust::fill(thrust::device, - reinterpret_cast(graph_buffer_), - reinterpret_cast(graph_buffer_) + (size_t)nrow_ * DEGREE_ON_DEVICE, + reinterpret_cast(graph_buffer_.data_handle()), + reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaMalloc(&d_locks_, sizeof(*d_locks_) * nrow_)); - thrust::fill(thrust::device, d_locks_, d_locks_ + nrow_, 0); - RAFT_CUDA_TRY(cudaMallocHost(&h_rev_graph_new_, sizeof(*h_rev_graph_new_) * nrow_ * NUM_SAMPLES)); - RAFT_CUDA_TRY(cudaMallocHost(&h_graph_old_, sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES)); - RAFT_CUDA_TRY(cudaMallocHost(&h_rev_graph_old_, sizeof(*h_rev_graph_old_) * nrow_ * NUM_SAMPLES)); + thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); + RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_new_, sizeof(*d_list_sizes_new_) * nrow_)); RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_old_, 
sizeof(*d_list_sizes_old_) * nrow_)); - - if (build_config_.metric_type == Metric_t::METRIC_L2) { - RAFT_CUDA_TRY(cudaMalloc(&l2_norms_, sizeof(*l2_norms_) * nrow_)); - } else { - l2_norms_ = nullptr; - } -} +}; template void GNND::add_reverse_edges(Index_t* graph_ptr, @@ -1190,46 +1183,44 @@ template void GNND::local_join(cudaStream_t stream) { thrust::fill(thrust::device.on(stream), - dists_buffer_, - dists_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), std::numeric_limits::max()); - local_join_kernel<<>>(graph_.h_graph_new, - h_rev_graph_new_, - d_list_sizes_new_, - h_graph_old_, - h_rev_graph_old_, - d_list_sizes_old_, - NUM_SAMPLES, - d_data_, - ndim_, - graph_buffer_, - dists_buffer_, - DEGREE_ON_DEVICE, - d_locks_, - l2_norms_); + local_join_kernel<<>>( + thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + d_list_sizes_new_, + thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + d_list_sizes_old_, + NUM_SAMPLES, + d_data_.data_handle(), + ndim_, + graph_buffer_.data_handle(), + dists_buffer_.data_handle(), + DEGREE_ON_DEVICE, + d_locks_.data_handle(), + l2_norms_.data_handle()); } template -void GNND::build(Data_t* data, - const Index_t nrow, - Index_t* output_graph, - cudaStream_t stream) +void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) { - nrow_ = nrow; - graph_.h_graph = (InternalID_t*)output_graph; + cudaStream_t stream = raft::resource::get_cuda_stream(res); + nrow_ = nrow; + graph_.h_graph = (InternalID_t*)output_graph; cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - std::cout << "HERE AS EXPECTED" << std::endl; - typename std::remove_const::type* input_data; size_t batch_size = 100000; - RAFT_CUDA_TRY(cudaMallocAsync( - &input_data, sizeof(Data_t) * batch_size * build_config_.dataset_dim, stream)); + using input_t = typename std::remove_const::type; + auto input_data = raft::make_device_matrix( + res, batch_size, build_config_.dataset_dim); for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { size_t list_offset = step * batch_size; size_t num_lists = step != div_up(nrow_, batch_size) - 1 ? 
batch_size : nrow_ - list_offset; - RAFT_CUDA_TRY(cudaMemcpyAsync(input_data, + RAFT_CUDA_TRY(cudaMemcpyAsync(input_data.data_handle(), data + list_offset * build_config_.dataset_dim, sizeof(Data_t) * num_lists * build_config_.dataset_dim, cudaMemcpyDefault, @@ -1238,21 +1229,24 @@ void GNND::build(Data_t* data, WARP_SIZE, sizeof(Data_t) * div_up(build_config_.dataset_dim, WARP_SIZE) * WARP_SIZE, - stream>>>( - input_data, d_data_, build_config_.dataset_dim, l2_norms_, list_offset); + stream>>>(input_data.data_handle(), + d_data_.data_handle(), + build_config_.dataset_dim, + l2_norms_.data_handle(), + list_offset); } - RAFT_CUDA_TRY(cudaFreeAsync(input_data, stream)); } else { preprocess_data_kernel<<>>(data, d_data_, build_config_.dataset_dim, l2_norms_); + stream>>>( + data, d_data_.data_handle(), build_config_.dataset_dim, l2_norms_.data_handle()); } thrust::fill(thrust::device.on(stream), - (Index_t*)graph_buffer_, - (Index_t*)graph_buffer_ + (size_t)nrow_ * DEGREE_ON_DEVICE, + (Index_t*)graph_buffer_.data_handle(), + (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), std::numeric_limits::max()); graph_.clear(); @@ -1262,8 +1256,10 @@ void GNND::build(Data_t* data, auto update_and_sample = [&](bool update_graph) { if (update_graph) { update_counter_ = 0; - graph_.update_graph( - graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, update_counter_); + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); if (update_counter_ < build_config_.termination_threshold * nrow_ * build_config_.dataset_dim / counter_interval) { update_counter_ = -1; @@ -1274,17 +1270,17 @@ void GNND::build(Data_t* data, for (size_t it = 0; it < build_config_.max_iterations; it++) { RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, - graph_.h_list_sizes_new, + thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), sizeof(*d_list_sizes_new_) * nrow_, cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(h_graph_old_, - graph_.h_graph_old, - sizeof(*h_graph_old_) * nrow_ * NUM_SAMPLES, + RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(graph_.h_graph_old.data()), + sizeof(*h_graph_old_.data()) * nrow_ * NUM_SAMPLES, cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, - graph_.h_list_sizes_old, + thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), sizeof(*d_list_sizes_old_) * nrow_, cudaMemcpyDefault, stream)); @@ -1297,12 +1293,18 @@ void GNND::build(Data_t* data, // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it // contains some information for local_join. 
- static_assert(DEGREE_ON_DEVICE * sizeof(*dists_buffer_) >= - NUM_SAMPLES * sizeof(*graph_buffer_)); - add_reverse_edges( - graph_.h_graph_new, h_rev_graph_new_, (Index_t*)dists_buffer_, d_list_sizes_new_, stream); - add_reverse_edges( - h_graph_old_, h_rev_graph_old_, (Index_t*)dists_buffer_, d_list_sizes_old_, stream); + static_assert(DEGREE_ON_DEVICE * sizeof(*(dists_buffer_.data_handle())) >= + NUM_SAMPLES * sizeof(*(graph_buffer_.data_handle()))); + add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_new_, + stream); + add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_old_, + stream); local_join(stream); @@ -1310,28 +1312,31 @@ void GNND::build(Data_t* data, if (update_counter_ == -1) { break; } - RAFT_CUDA_TRY(cudaMemcpyAsync(graph_host_buffer_, - graph_buffer_, - sizeof(*graph_buffer_) * nrow_ * DEGREE_ON_DEVICE, + RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(graph_host_buffer_.data()), + graph_buffer_.data_handle(), + (sizeof(*graph_buffer_.data_handle())) * nrow_ * DEGREE_ON_DEVICE, cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(dists_host_buffer_, - dists_buffer_, - sizeof(*dists_buffer_) * nrow_ * DEGREE_ON_DEVICE, + RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(dists_host_buffer_.data()), + dists_buffer_.data_handle(), + sizeof(*(dists_buffer_.data_handle())) * nrow_ * DEGREE_ON_DEVICE, cudaMemcpyDefault, stream)); - graph_.sample_graph_new(graph_host_buffer_, DEGREE_ON_DEVICE); + graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); } - graph_.update_graph(graph_host_buffer_, dists_host_buffer_, DEGREE_ON_DEVICE, update_counter_); + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); graph_.sort_lists(); // Reuse graph_.h_dists as the buffer for shrink the lists in graph - static_assert(sizeof(decltype(*graph_.h_dists)) >= sizeof(Index_t)); - Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists; + static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); #pragma omp parallel for for (size_t i = 0; i < (size_t)nrow_; i++) { @@ -1359,19 +1364,8 @@ void GNND::build(Data_t* data, template void GNND::dealloc() { - graph_.dealloc(); - RAFT_CUDA_TRY(cudaFree(d_data_)); - RAFT_CUDA_TRY(cudaFreeHost(graph_host_buffer_)); - RAFT_CUDA_TRY(cudaFreeHost(dists_host_buffer_)); - RAFT_CUDA_TRY(cudaFree(dists_buffer_)); - RAFT_CUDA_TRY(cudaFree(graph_buffer_)); - RAFT_CUDA_TRY(cudaFree(d_locks_)); - RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_new_)); - RAFT_CUDA_TRY(cudaFreeHost(h_graph_old_)); - RAFT_CUDA_TRY(cudaFreeHost(h_rev_graph_old_)); RAFT_CUDA_TRY(cudaFree(d_list_sizes_new_)); RAFT_CUDA_TRY(cudaFree(d_list_sizes_old_)); - RAFT_CUDA_TRY(cudaFree(l2_norms_)); } template @@ -1411,7 +1405,7 @@ index build(raft::resources const& res, // extended_graph_degree. size_t extended_graph_degree = to_multiple_of_32(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3)); size_t extended_intermediate_degree = - to_multiple_of_32(intermediate_degree * (graph_degree <= 32 ? 
1.0 : 1.3)); + to_multiple_of_32(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3)); auto int_graph = raft::make_host_matrix( dataset.extent(0), static_cast(extended_graph_degree)); @@ -1424,13 +1418,8 @@ index build(raft::resources const& res, .termination_threshold = params.termination_threshold, .metric_type = Metric_t::METRIC_L2}; - std::cout << "Intermediate graph dim: " << int_graph.extent(0) << ", " << int_graph.extent(1) - << std::endl; - GNND nnd(build_config); - nnd.build(dataset.data_handle(), - dataset.extent(0), - int_graph.data_handle(), - resource::get_cuda_stream(res)); + GNND nnd(res, build_config); + nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); nnd.dealloc(); index idx{res, dataset.extent(0), static_cast(graph_degree)}; #pragma omp parallel for From 60d78051c22c61e796834195dcbb8d7fb00d6f96 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 1 Sep 2023 13:45:46 -0700 Subject: [PATCH 13/28] using RAFT types --- .../raft/neighbors/detail/nn_descent.cuh | 46 +++++++------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index b740284cc4..fa59d8494e 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -23,6 +23,9 @@ #include #include #include + +#include + #include #include #include @@ -355,8 +358,7 @@ class GNND { GNND& operator=(const GNND&) = delete; void build(Data_t* data, const Index_t nrow, Index_t* output_graph); - void dealloc(); - ~GNND(); + ~GNND() = default; using ID_t = InternalID_t; private: @@ -392,8 +394,8 @@ class GNND { thrust::host_vector> h_rev_graph_old_; // int2.x is the number of forward edges, int2.y is the number of reverse edges - int2* d_list_sizes_old_; - int2* d_list_sizes_new_; + raft::device_vector d_list_sizes_new_; + raft::device_vector d_list_sizes_old_; }; constexpr int TILE_ROW_WIDTH = 64; @@ -1145,7 +1147,9 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build d_locks_{raft::make_device_vector(res, nrow_)}, h_rev_graph_new_{static_cast(nrow_ * NUM_SAMPLES)}, h_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, - h_rev_graph_old_{static_cast(nrow_ * NUM_SAMPLES)} + h_rev_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + d_list_sizes_new_{raft::make_device_vector(res, nrow_)}, + d_list_sizes_old_{raft::make_device_vector(res, nrow_)} { static_assert(NUM_SAMPLES <= 32); @@ -1158,9 +1162,6 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), std::numeric_limits::max()); thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); - - RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_new_, sizeof(*d_list_sizes_new_) * nrow_)); - RAFT_CUDA_TRY(cudaMalloc(&d_list_sizes_old_, sizeof(*d_list_sizes_old_) * nrow_)); }; template @@ -1189,10 +1190,10 @@ void GNND::local_join(cudaStream_t stream) local_join_kernel<<>>( thrust::raw_pointer_cast(graph_.h_graph_new.data()), thrust::raw_pointer_cast(h_rev_graph_new_.data()), - d_list_sizes_new_, + d_list_sizes_new_.data_handle(), thrust::raw_pointer_cast(h_graph_old_.data()), thrust::raw_pointer_cast(h_rev_graph_old_.data()), - d_list_sizes_old_, + d_list_sizes_old_.data_handle(), NUM_SAMPLES, d_data_.data_handle(), ndim_, @@ -1269,9 +1270,9 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out }; for (size_t it = 0; it < build_config_.max_iterations; 
it++) { - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_, + RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_.data_handle(), thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), - sizeof(*d_list_sizes_new_) * nrow_, + sizeof(*(d_list_sizes_new_.data_handle())) * nrow_, cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(h_graph_old_.data()), @@ -1279,9 +1280,9 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out sizeof(*h_graph_old_.data()) * nrow_ * NUM_SAMPLES, cudaMemcpyDefault, stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_, + RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_.data_handle(), thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), - sizeof(*d_list_sizes_old_) * nrow_, + sizeof(*(d_list_sizes_old_.data_handle())) * nrow_, cudaMemcpyDefault, stream)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); @@ -1298,12 +1299,12 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), thrust::raw_pointer_cast(h_rev_graph_new_.data()), (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_new_, + d_list_sizes_new_.data_handle(), stream); add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), thrust::raw_pointer_cast(h_rev_graph_old_.data()), (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_old_, + d_list_sizes_old_.data_handle(), stream); local_join(stream); @@ -1361,18 +1362,6 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out } } -template -void GNND::dealloc() -{ - RAFT_CUDA_TRY(cudaFree(d_list_sizes_new_)); - RAFT_CUDA_TRY(cudaFree(d_list_sizes_old_)); -} - -template -GNND::~GNND() -{ -} - template build(raft::resources const& res, GNND nnd(res, build_config); nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); - nnd.dealloc(); index idx{res, dataset.extent(0), static_cast(graph_degree)}; #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { From 5eb5690476a61888f1adbe3278a91866839e6486 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 1 Sep 2023 14:11:14 -0700 Subject: [PATCH 14/28] remove explicit cuda copies and stream syncs --- .../raft/neighbors/detail/nn_descent.cuh | 69 ++++++++----------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index fa59d8494e..08a58195ff 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -39,6 +39,7 @@ #include #include #include +#include namespace raft::neighbors::nn_descent::detail { @@ -1173,11 +1174,8 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, { add_rev_edges_kernel<<>>( graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); - RAFT_CUDA_TRY(cudaMemcpyAsync(h_rev_graph_ptr, - d_rev_graph_ptr, - sizeof(*h_rev_graph_ptr) * nrow_ * NUM_SAMPLES, - cudaMemcpyDefault, - stream)); + raft::copy( + h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); } template @@ -1221,11 +1219,10 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { size_t list_offset = step * batch_size; size_t num_lists = step != div_up(nrow_, batch_size) - 1 ? 
batch_size : nrow_ - list_offset; - RAFT_CUDA_TRY(cudaMemcpyAsync(input_data.data_handle(), - data + list_offset * build_config_.dataset_dim, - sizeof(Data_t) * num_lists * build_config_.dataset_dim, - cudaMemcpyDefault, - stream)); + raft::copy(input_data.data_handle(), + data + list_offset * build_config_.dataset_dim, + num_lists * build_config_.dataset_dim, + raft::resource::get_cuda_stream(res)); preprocess_data_kernel<<::build(Data_t* data, const Index_t nrow, Index_t* out }; for (size_t it = 0; it < build_config_.max_iterations; it++) { - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_new_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), - sizeof(*(d_list_sizes_new_.data_handle())) * nrow_, - cudaMemcpyDefault, - stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(graph_.h_graph_old.data()), - sizeof(*h_graph_old_.data()) * nrow_ * NUM_SAMPLES, - cudaMemcpyDefault, - stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(d_list_sizes_old_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), - sizeof(*(d_list_sizes_old_.data_handle())) * nrow_, - cudaMemcpyDefault, - stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + raft::copy(d_list_sizes_new_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::copy(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(graph_.h_graph_old.data()), + nrow_ * NUM_SAMPLES, + raft::resource::get_cuda_stream(res)); + raft::copy(d_list_sizes_old_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); std::thread update_and_sample_thread(update_and_sample, it); @@ -1312,18 +1306,16 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out update_and_sample_thread.join(); if (update_counter_ == -1) { break; } + raft::copy(thrust::raw_pointer_cast(graph_host_buffer_.data()), + graph_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + raft::copy(thrust::raw_pointer_cast(dists_host_buffer_.data()), + dists_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); - RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(graph_host_buffer_.data()), - graph_buffer_.data_handle(), - (sizeof(*graph_buffer_.data_handle())) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDefault, - stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(dists_host_buffer_.data()), - dists_buffer_.data_handle(), - sizeof(*(dists_buffer_.data_handle())) * nrow_ * DEGREE_ON_DEVICE, - cudaMemcpyDefault, - stream)); graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); } @@ -1331,8 +1323,7 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out thrust::raw_pointer_cast(dists_host_buffer_.data()), DEGREE_ON_DEVICE, update_counter_); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - + raft::resource::sync_stream(res); graph_.sort_lists(); // Reuse graph_.h_dists as the buffer for shrink the lists in graph From 21ac440031b56ae1fff8c7752281843bde624ea3 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 1 Sep 2023 14:30:51 -0700 Subject: [PATCH 15/28] experimental namespace, docs update+code examples --- cpp/include/raft/neighbors/cagra.cuh | 48 ++++++++++++-- 
.../neighbors/detail/cagra/cagra_build.cuh | 10 +-- .../raft/neighbors/detail/nn_descent.cuh | 4 +- cpp/include/raft/neighbors/nn_descent.cuh | 64 ++++++++++++++----- .../raft/neighbors/nn_descent_types.hpp | 4 +- 5 files changed, 100 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index c95facb118..5d94bb2ce0 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -35,10 +35,9 @@ namespace raft::neighbors::cagra { */ /** - * @brief Build a kNN graph. + * @brief Build a kNN graph using IVF-PQ. * * The kNN graph is the first building block for CAGRA index. - * This function uses the IVF-PQ method to build a kNN graph. * * The output is a dense matrix that stores the neighbor indices for each point in the dataset. * Each point has the same number of neighbors. @@ -95,14 +94,51 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } +/** + * @brief Build a kNN graph using NN-descent. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params build_params; + * build_params.graph_degree = 128; + * // create knn graph + * auto nn_descent_index = cagra::build_knn_graph(res, dataset, build_params); + * auto optimized_graph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, nn_descent_index.graph(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric, dataset, + * optimized_graph.view()); + * @endcode + * + * @tparam DataT data element type + * @tparam IdxT type of the dataset vector indices + * @tparam accessor host or device accessor_type for the dataset + * @param res raft::resources + * @param dataset raft::host/device_matrix_view + * @param build_params raft::neighbors::nn_descent::index_params + * @return raft::neighbors::nn_descent::index + */ template , memory_type::device>> -nn_descent::index build_knn_graph( +experimental::nn_descent::index build_knn_graph( raft::resources const& res, mdspan, row_major, accessor> dataset, - std::optional build_params = std::nullopt) + std::optional build_params = std::nullopt) { return detail::build_knn_graph(res, dataset, build_params); } @@ -277,8 +313,8 @@ index build(raft::resources const& res, optimize(res, knn_graph.view(), cagra_graph.view()); } else { - auto nn_descent_params = std::make_optional(); - nn_descent_params->graph_degree = intermediate_degree; + auto nn_descent_params = std::make_optional(); + nn_descent_params->graph_degree = intermediate_degree; nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d67394806a..d572285178 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ 
b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -240,14 +240,16 @@ void build_knn_graph(raft::resources const& res, } template -nn_descent::index build_knn_graph( +experimental::nn_descent::index build_knn_graph( raft::resources const& res, mdspan, row_major, accessor> dataset, - std::optional build_params = std::nullopt) + std::optional build_params = std::nullopt) { - if (!build_params) { build_params = std::make_optional(); } + if (!build_params) { + build_params = std::make_optional(); + } - auto nn_descent_idx = nn_descent::build(res, *build_params, dataset); + auto nn_descent_idx = experimental::nn_descent::build(res, *build_params, dataset); using internal_IdxT = typename std::make_unsigned::type; using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 08a58195ff..997e1bdea4 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -41,7 +41,7 @@ #include #include -namespace raft::neighbors::nn_descent::detail { +namespace raft::neighbors::experimental::nn_descent::detail { using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; template @@ -1411,4 +1411,4 @@ index build(raft::resources const& res, return idx; } -} // namespace raft::neighbors::nn_descent::detail +} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index 408d590aca..abdaf4771d 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -21,7 +21,7 @@ #include #include -namespace raft::neighbors::nn_descent { +namespace raft::neighbors::experimental::nn_descent { /** * @defgroup nn-descent CUDA ANN Graph-based gradient descent nearest neighbor @@ -30,38 +30,70 @@ namespace raft::neighbors::nn_descent { /** * @brief Build nn-descent Index with dataset in device memory - * - * @tparam T - * @tparam IdxT + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type + * @tparam IdxT index-type * @param res raft::resources * @param params nn_descent::index_params * @param dataset raft::device_matrix_view - * @return index + * @return index */ template index build(raft::resources const& res, index_params const& params, - raft::device_matrix_view dataset) { - return detail::build(res, params, dataset); + raft::device_matrix_view dataset) +{ + return detail::build(res, params, dataset); } /** * @brief Build nn-descent Index with dataset in host memory - * - * @tparam T - * @tparam IdxT + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::build(res, index_params, dataset); + * // 
index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type + * @tparam IdxT index-type * @param res raft::resources * @param params nn_descent::index_params * @param dataset raft::host_matrix_view - * @return index + * @return index */ template index build(raft::resources const& res, - index_params const& params, - raft::host_matrix_view dataset) { - return detail::build(res, params, dataset); + index_params const& params, + raft::host_matrix_view dataset) +{ + return detail::build(res, params, dataset); } -/** @} */ // end group cagra +/** @} */ // end group nn-descent -} +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index ae4a70ba50..85f9336263 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -25,7 +25,7 @@ #include #include -namespace raft::neighbors::nn_descent { +namespace raft::neighbors::experimental::nn_descent { /** * @ingroup nn_descent * @{ @@ -104,4 +104,4 @@ struct index : ann::index { /** @} */ -} // namespace raft::neighbors::nn_descent +} // namespace raft::neighbors::experimental::nn_descent From 28135a82d7dfe9706edb5f9969b0ad382be86211 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 1 Sep 2023 15:05:17 -0700 Subject: [PATCH 16/28] add graph_build_algo to bench-ann --- cpp/bench/ann/src/raft/raft_benchmark.cu | 7 +++++++ docs/source/ann_benchmarks_param_tuning.md | 1 + 2 files changed, 8 insertions(+) diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index aa25d1532f..f576681720 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -139,6 +139,13 @@ void parse_build_param(const nlohmann::json& conf, if (conf.contains("intermediate_graph_degree")) { param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); } + if (conf.contains("graph_build_algo")) { + if (conf.at("graph_build_algo") == "IVF_PQ") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (conf.at("graph_build_algo") == "NN_DESCENT") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } + } } template diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index 020c2d5ad9..b625c11990 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -42,6 +42,7 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g |-----------|----------------|----------|---------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. | | `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. | +| `graph_build_algo` | `build_param` | N | string | "IVF_PQ" | Algorithm to use for building the kNN graph. Possible values: {"IVF_PQ", "NN_DESCENT"} | | `itopk` | `search_param` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. 
| | `search_width` | `search_param` | N | Positive Integer >0 | 1 | Number of graph nodes to select as the starting point for the search in each iteration. | | `max_iterations` | `search_param` | N | Integer >=0 | 0 | Upper limit of search iterations. Auto select when 0. | From b0344c727e5734f6715c28b33fa9bcb725dfa854 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 5 Sep 2023 17:44:41 -0700 Subject: [PATCH 17/28] add tests --- cpp/test/CMakeLists.txt | 16 ++ cpp/test/neighbors/ann_nn_descent.cuh | 161 ++++++++++++++++++ .../ann_nn_descent/test_float_int64_t.cu | 28 +++ .../ann_nn_descent/test_float_uint32_t.cu | 28 +++ .../ann_nn_descent/test_int8_t_uint32_t.cu | 28 +++ .../ann_nn_descent/test_uint8_t_uint32_t.cu | 28 +++ cpp/test/neighbors/ann_utils.cuh | 41 +++++ 7 files changed, 330 insertions(+) create mode 100644 cpp/test/neighbors/ann_nn_descent.cuh create mode 100644 cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index db4c59c807..69d950d10d 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -379,6 +379,22 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_NN_DESCENT_TEST + PATH + test/neighbors/ann_nn_descent/test_float_uint32_t.cu + test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu + test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu + test/neighbors/ann_nn_descent/test_float_int64_t.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_SELECTION_TEST PATH test/neighbors/selection.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 50 diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh new file mode 100644 index 0000000000..0eeee65529 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent { + +struct AnnNNDescentInputs { + int n_rows; + int dim; + int graph_degree; + raft::distance::DistanceType metric; + bool host_dataset; + double min_recall; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? 
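Following up the parameter table from the docs hunk above: a hedged sketch of a `build_param` entry and the string-to-enum mapping it triggers, mirroring the `parse_build_param` hunk in this commit. The JSON literal is illustrative, not taken from a shipped config, and the standalone wrapper is an assumption for the sake of a runnable example.

```cpp
#include <nlohmann/json.hpp>
#include <raft/neighbors/cagra_types.hpp>

int main()
{
  // Illustrative "build_param" block, shaped like the configs the benchmark
  // tool consumes; the key names follow the table above.
  auto const conf = nlohmann::json::parse(R"({
    "graph_degree": 64,
    "intermediate_graph_degree": 128,
    "graph_build_algo": "NN_DESCENT"
  })");

  raft::neighbors::cagra::index_params param;
  if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); }
  if (conf.contains("intermediate_graph_degree")) {
    param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
  }
  // Same string-to-enum mapping as parse_build_param() in the hunk above.
  if (conf.contains("graph_build_algo")) {
    if (conf.at("graph_build_algo") == "IVF_PQ") {
      param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
    } else if (conf.at("graph_build_algo") == "NN_DESCENT") {
      param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
    }
  }
  return 0;
}
```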
", host" : ", device") << std::endl; return os; } + +template +class AnnNNDescentTest : public ::testing::TestWithParam { + public: + AnnNNDescentTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + protected: + void testNNDescent() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector indices_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build(handle_, index_params, database_host_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + } else { + auto index = nn_descent::build(handle_, index_params, database_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + }; + } + resource::sync_stream(handle_); + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_recall( + indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentInputs ps; + rmm::device_uvector database; +}; + +const std::vector inputs = raft::util::itertools::product( + {1000, 2000}, // n_rows + {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim + {32, 64}, // graph_degree + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.92}); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu new file 
mode 100644 index 0000000000..92a47a6a77 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestF_I64; +TEST_P(AnnNNDescentTestF_I64, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_I64, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu new file mode 100644 index 0000000000..13bff6ac90 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestF_U32; +TEST_P(AnnNNDescentTestF_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu new file mode 100644 index 0000000000..5895303e09 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestI8_U32; +TEST_P(AnnNNDescentTestI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..a034e84074 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestUI8_U32; +TEST_P(AnnNNDescentTestUI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 0e54e29c01..fabc21f508 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -123,6 +123,47 @@ struct idx_dist_pair { idx_dist_pair(IdxT x, DistT y, CompareDist op) : idx(x), dist(y), eq_compare(op) {} }; +template +auto eval_recall(const std::vector& expected_idx, + const std::vector& actual_idx, + size_t rows, + size_t cols, + double eps, + double min_recall) -> testing::AssertionResult +{ + size_t match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + if (act_idx == exp_idx) { + match_count++; + break; + } + } + } + } + double actual_recall = static_cast(match_count) / static_cast(total_count); + double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); + RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", + actual_recall, + match_count, + total_count, + std::abs(error_margin * 100.0), + error_margin < 0 ? "above" : "below", + eps); + if (actual_recall < min_recall - eps) { + return testing::AssertionFailure() + << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" + << min_recall << "); eps = " << eps << ". 
"; + } + return testing::AssertionSuccess(); +} + template auto eval_neighbours(const std::vector& expected_idx, const std::vector& actual_idx, From 3f3d9659be9b9e1790c2eda490442bf6b33dee90 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 5 Sep 2023 17:45:06 -0700 Subject: [PATCH 18/28] add arch guards for using wmma --- cpp/include/raft/neighbors/cagra.cuh | 7 +++++++ cpp/include/raft/neighbors/nn_descent.cuh | 2 ++ 2 files changed, 9 insertions(+) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 5d94bb2ce0..34f7106c27 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -21,6 +21,7 @@ #include "detail/cagra/graph_core.cuh" #include +#include #include #include #include @@ -94,6 +95,7 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } +#if (__CUDA_ARCH__ >= 700) /** * @brief Build a kNN graph using NN-descent. * @@ -142,6 +144,7 @@ experimental::nn_descent::index build_knn_graph( { return detail::build_knn_graph(res, dataset, build_params); } +#endif /** * @brief Sort a KNN graph index. @@ -313,12 +316,16 @@ index build(raft::resources const& res, optimize(res, knn_graph.view(), cagra_graph.view()); } else { +#if (__CUDA_ARCH__ >= 700) auto nn_descent_params = std::make_optional(); nn_descent_params->graph_degree = intermediate_degree; nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); optimize(res, nn_descent_index.graph(), cagra_graph.view()); +#else + THROW("Cannot run CAGRA with graph_build_algo::NN_Descent for CUDA_ARCH<700"); +#endif } // Construct an index from dataset and optimized knn graph. diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index abdaf4771d..ab7675abe7 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -28,6 +28,7 @@ namespace raft::neighbors::experimental::nn_descent { * @{ */ +#if (__CUDA_ARCH__ >= 700) /** * @brief Build nn-descent Index with dataset in device memory * @@ -93,6 +94,7 @@ index build(raft::resources const& res, { return detail::build(res, params, dataset); } +#endif /** @} */ // end group nn-descent From 832d056f20d34389e65c0f3a0131d46c1a8236f2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 5 Sep 2023 18:10:18 -0700 Subject: [PATCH 19/28] Revert "add arch guards for using wmma" This reverts commit 3f3d9659be9b9e1790c2eda490442bf6b33dee90. --- cpp/include/raft/neighbors/cagra.cuh | 7 ------- cpp/include/raft/neighbors/nn_descent.cuh | 2 -- 2 files changed, 9 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 34f7106c27..5d94bb2ce0 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -21,7 +21,6 @@ #include "detail/cagra/graph_core.cuh" #include -#include #include #include #include @@ -95,7 +94,6 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } -#if (__CUDA_ARCH__ >= 700) /** * @brief Build a kNN graph using NN-descent. * @@ -144,7 +142,6 @@ experimental::nn_descent::index build_knn_graph( { return detail::build_knn_graph(res, dataset, build_params); } -#endif /** * @brief Sort a KNN graph index. 
@@ -316,16 +313,12 @@ index build(raft::resources const& res,
     optimize(res, knn_graph.view(), cagra_graph.view());
   } else {
-#if (__CUDA_ARCH__ >= 700)
     auto nn_descent_params = std::make_optional();
     nn_descent_params->graph_degree              = intermediate_degree;
     nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree;
     auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params);
 
     optimize(res, nn_descent_index.graph(), cagra_graph.view());
-#else
-    THROW("Cannot run CAGRA with graph_build_algo::NN_Descent for CUDA_ARCH<700");
-#endif
   }
 
   // Construct an index from dataset and optimized knn graph.
diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh
index ab7675abe7..abdaf4771d 100644
--- a/cpp/include/raft/neighbors/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/nn_descent.cuh
@@ -28,7 +28,6 @@ namespace raft::neighbors::experimental::nn_descent {
  * @{
  */
 
-#if (__CUDA_ARCH__ >= 700)
 /**
  * @brief Build nn-descent Index with dataset in device memory
  *
@@ -94,7 +93,6 @@ index build(raft::resources const& res,
 {
   return detail::build(res, params, dataset);
 }
-#endif
 
 /** @} */  // end group nn-descent
 

From f60db9dbdaff75547ffdcf095f35357992c7628c Mon Sep 17 00:00:00 2001
From: divyegala
Date: Tue, 5 Sep 2023 18:29:14 -0700
Subject: [PATCH 20/28] correctly add arch guards using raft::util::arch

---
 .../raft/neighbors/detail/nn_descent.cuh      | 23 +++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh
index 997e1bdea4..3419346e5f 100644
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -35,9 +35,11 @@
 #include "../nn_descent_types.hpp"
 
 #include
+#include
 #include
 #include
 #include
+#include  // raft::util::arch::SM_*
 #include
 #include
 
@@ -706,6 +708,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4) local_join_kernel(const Index_t
                                                                    int* locks,
                                                                    DistData_t* l2_norms)
 {
+#if (__CUDA_ARCH__ >= 700)
   using namespace nvcuda;
   __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2];
 
@@ -928,6 +931,7 @@ __global__ void __launch_bounds__(BLOCK_SIZE, 4) local_join_kernel(const Index_t
       insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
     }
   }
+#endif
 }
 
 namespace {
@@ -1205,6 +1209,8 @@ void GNND::local_join(cudaStream_t stream)
 template
 void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph)
 {
+  using input_t = typename std::remove_const::type;
+
   cudaStream_t stream = raft::resource::get_cuda_stream(res);
   nrow_               = nrow;
   graph_.h_graph      = (InternalID_t*)output_graph;
@@ -1213,7 +1219,6 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
   if (data_ptr_attr.type == cudaMemoryTypeUnregistered) {
     size_t batch_size = 100000;
-    using input_t     = typename std::remove_const::type;
     auto input_data   = raft::make_device_matrix(
       res, batch_size, build_config_.dataset_dim);
     for (size_t step = 0; step < div_up(nrow_, batch_size); step++) {
@@ -1301,7 +1306,21 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out
                                                       d_list_sizes_old_.data_handle(),
                                                       stream);
 
-  local_join(stream);
+  // Tensor operations from `mma.h` are guarded with architecture
+  // __CUDA_ARCH__ >= 700. Since RAFT supports compilation for ARCH 600,
+  // we need to ensure that `local_join_kernel` (which uses tensor operations)
+  // is not only not compiled for older architectures, but also that a runtime
+  // error is presented to the user
+  auto kernel       = preprocess_data_kernel;
+  void* kernel_ptr  = reinterpret_cast(kernel);
+  auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr);
+  auto wmma_range =
+    raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future());
+
+  if (wmma_range.contains(runtime_arch)) {
+    local_join(stream);
+  } else {
+    THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700");
+  }
 
   update_and_sample_thread.join();
 

From 86f18bb5edd3936d931757736988d3d3dfd35cc6 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Wed, 6 Sep 2023 09:19:10 -0700
Subject: [PATCH 21/28] fix launch bounds for arches 750,860

---
 .../raft/neighbors/detail/nn_descent.cuh      | 36 +++++++++++--------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh
index 3419346e5f..2ac5535250 100644
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -693,20 +693,28 @@ __device__ __forceinline__ void remove_duplicates(
   }
 }
 
 template >
-__global__ void __launch_bounds__(BLOCK_SIZE, 4) local_join_kernel(const Index_t* graph_new,
-                                                                   const Index_t* rev_graph_new,
-                                                                   const int2* sizes_new,
-                                                                   const Index_t* graph_old,
-                                                                   const Index_t* rev_graph_old,
-                                                                   const int2* sizes_old,
-                                                                   const int width,
-                                                                   const __half* data,
-                                                                   const int data_dim,
-                                                                   ID_t* graph,
-                                                                   DistData_t* dists,
-                                                                   int graph_width,
-                                                                   int* locks,
-                                                                   DistData_t* l2_norms)
+__global__ void
+#ifdef __CUDA_ARCH__
+#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860
+__launch_bounds__(BLOCK_SIZE)
+#else
+__launch_bounds__(BLOCK_SIZE, 4)
+#endif
+#endif
+  local_join_kernel(const Index_t* graph_new,
+                    const Index_t* rev_graph_new,
+                    const int2* sizes_new,
+                    const Index_t* graph_old,
+                    const Index_t* rev_graph_old,
+                    const int2* sizes_old,
+                    const int width,
+                    const __half* data,
+                    const int data_dim,
+                    ID_t* graph,
+                    DistData_t* dists,
+                    int graph_width,
+                    int* locks,
+                    DistData_t* l2_norms)
 {
 #if (__CUDA_ARCH__ >= 700)
   using namespace nvcuda;

From 69b7ba7b5a8e910a4ddb3a001254093d57f9a5f3 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Wed, 6 Sep 2023 11:22:22 -0700
Subject: [PATCH 22/28] add comment explaining launch bounds changes for archs

---
 cpp/include/raft/neighbors/detail/nn_descent.cuh | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh
index 2ac5535250..96bb7f944f 100644
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -692,6 +692,12 @@ __device__ __forceinline__ void remove_duplicates(
   }
 }
 
+// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4
+// Per
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications,
+// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
+// For architectures 750 and 860, the values for MAX_RESIDENT_THREAD_PER_SM
+// are 1024 and 1536 respectively, which means the bounds don't work anymore
 template >
 __global__ void
 #ifdef __CUDA_ARCH__

From a44e4a42260f276fbfeb9bc5beeb55535f9c4a56 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Tue, 12 Sep 2023 19:26:04 -0700
Subject: [PATCH 23/28] first batch of review addressing

---
 build.sh
| 2 +- cpp/include/raft/neighbors/cagra.cuh | 24 ++- cpp/include/raft/neighbors/cagra_types.hpp | 4 + .../raft/neighbors/detail/nn_descent.cuh | 193 +++++++++--------- cpp/include/raft/neighbors/nn_descent.cuh | 30 +-- .../raft/neighbors/nn_descent_types.hpp | 25 ++- cpp/test/neighbors/ann_utils.cuh | 2 + 7 files changed, 152 insertions(+), 128 deletions(-) diff --git a/build.sh b/build.sh index 071820ba93..32582738c3 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 5d94bb2ce0..a9d73eb84b 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -39,7 +39,7 @@ namespace raft::neighbors::cagra { * * The kNN graph is the first building block for CAGRA index. * - * The output is a dense matrix that stores the neighbor indices for each pont in the dataset. + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. * Each point has the same number of neighbors. * * See [cagra::build](#cagra::build) for an alternative method. @@ -99,7 +99,7 @@ void build_knn_graph(raft::resources const& res, * * The kNN graph is the first building block for CAGRA index. * - * The output is a dense matrix that stores the neighbor indices for each pont in the dataset. + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. * Each point has the same number of neighbors. * * See [cagra::build](#cagra::build) for an alternative method. 
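Before the review changes continue, the occupancy arithmetic behind the launch-bounds fix a couple of commits above is worth spelling out. `__launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor)` promises the compiler that `512 * 4 = 2048` threads can be resident on one SM at a time. Compute capability 7.0 allows 2048 resident threads per SM, but 7.5 allows only 1024 and 8.6 only 1536, so the promise cannot be honored on those architectures and the second argument has to be dropped there. A compilable sketch of the same per-architecture split (hypothetical kernel, not part of the patch):

    // launch_bounds_sketch.cu -- mirrors the 750/860 special case from the patch
    constexpr int BLOCK_SIZE = 512;

    #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860)
    // 512 * 4 = 2048 resident threads cannot fit on sm_75 (1024) or sm_86 (1536),
    // so only the block size is pinned and the compiler picks blocks per SM.
    #define SKETCH_LB __launch_bounds__(BLOCK_SIZE)
    #else
    #define SKETCH_LB __launch_bounds__(BLOCK_SIZE, 4)  // 512 * 4 == 2048 fits on sm_70
    #endif

    __global__ void SKETCH_LB sketch_kernel(float* out) { out[threadIdx.x] = 0.0f; }

    int main()
    {
      float* out;
      cudaMalloc(&out, BLOCK_SIZE * sizeof(float));
      sketch_kernel<<<1, BLOCK_SIZE>>>(out);
      cudaDeviceSynchronize();
      cudaFree(out);
      return 0;
    }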
@@ -116,7 +116,7 @@ void build_knn_graph(raft::resources const& res,
  * build_params.graph_degree = 128;
  * // create knn graph
  * auto nn_descent_index = cagra::build_knn_graph(res, dataset, build_params);
- * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64);
+ * auto optimized_graph = raft::make_host_matrix(dataset.extent(0), 64);
  * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view());
  * // Construct an index from dataset and optimized knn_graph
  * auto index = cagra::index(res, build_params.metric(), dataset,
@@ -126,10 +126,13 @@ void build_knn_graph(raft::resources const& res,
  * @tparam DataT data element type
  * @tparam IdxT type of the dataset vector indices
  * @tparam accessor host or device accessor_type for the dataset
- * @param res raft::resources
- * @param dataset raft::host/device_matrix_view
- * @param build_params raft::neighbors::nn_descent::index_params
- * @return raft::neighbors::nn_descent::index
+ * @param res raft::resources is an object managing resources
+ * @param dataset input raft::host/device_matrix_view that can be located
+ * in host or device memory
+ * @param build_params an instance of experimental::nn_descent::index_params that are parameters
+ * to run the nn-descent algorithm
+ * @return experimental::nn_descent::index index containing all-neighbors knn graph in host
+ * memory
  */
 template
 build(raft::resources const& res,
     optimize(res, knn_graph.view(), cagra_graph.view());
   } else {
-    auto nn_descent_params = std::make_optional();
-    nn_descent_params->graph_degree              = intermediate_degree;
-    nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree;
+    // Use nn-descent to build CAGRA knn graph
+    auto nn_descent_params = experimental::nn_descent::index_params();
+    nn_descent_params.graph_degree              = intermediate_degree;
+    nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree;
     auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params);
 
     optimize(res, nn_descent_index.graph(), cagra_graph.view());
diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp
index c076c3b051..fc47433877 100644
--- a/cpp/include/raft/neighbors/cagra_types.hpp
+++ b/cpp/include/raft/neighbors/cagra_types.hpp
@@ -40,6 +40,10 @@ namespace raft::neighbors::cagra {
  * @{
  */
 
+/**
+ * @brief ANN algorithm used by CAGRA to build knn graph
+ *
+ */
 enum class graph_build_algo { IVF_PQ, NN_DESCENT };
 
 struct index_params : ann::index_params {
diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh
index 96bb7f944f..b1a51381e0 100644
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -39,9 +39,12 @@
 #include
 #include
 #include
+#include
 #include  // raft::util::arch::SM_*
+#include
 #include
 #include
+#include
 
 namespace raft::neighbors::experimental::nn_descent::detail {
 
@@ -137,20 +140,12 @@ class ResultItem {
   }
 };
 
-constexpr __host__ __device__ size_t div_up(const size_t a, const size_t b)
-{
-  return a / b + (a % b != 0);
-}
-
-constexpr int to_multiple_of_32(int number) { return div_up(number, 32) * 32; }
-
-constexpr int WARP_SIZE = 32;
-constexpr unsigned int FULL_MASK = 0xffffffff;
+using align32 = raft::Pow2<32>;
 
 template
 int get_batch_size(const int it_now, const T nrow, const int batch_size)
 {
-  int it_total = div_up(nrow, batch_size);
+  int it_total = ceildiv(nrow, batch_size);
   return (it_now == it_total - 1) ?
nrow - it_now * batch_size : batch_size; } @@ -160,7 +155,7 @@ constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) { // all "4"s are for alignment if constexpr (std::is_same::value) { - ndim = div_up(ndim, 4) * 4; + ndim = ceildiv(ndim, 4) * 4; return ndim + (ndim % 32 == 0) * 4; } } @@ -169,14 +164,15 @@ template __device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) { ResultItem y; - y.dist() = __shfl_xor_sync(FULL_MASK, x.dist(), mask, WARP_SIZE); - y.id_with_flag() = __shfl_xor_sync(FULL_MASK, x.id_with_flag(), mask, WARP_SIZE); + y.dist() = __shfl_xor_sync(raft::warp_full_mask(), x.dist(), mask, raft::warp_size()); + y.id_with_flag() = + __shfl_xor_sync(raft::warp_full_mask(), x.id_with_flag(), mask, raft::warp_size()); return x < y == dir ? y : x; } __device__ __forceinline__ int xor_swap(int x, int mask, int dir) { - int y = __shfl_xor_sync(FULL_MASK, x, mask, WARP_SIZE); + int y = __shfl_xor_sync(raft::warp_full_mask(), x, mask, raft::warp_size()); return x < y == dir ? y : x; } @@ -187,19 +183,10 @@ __device__ __forceinline__ uint bfe(uint lane_id, uint pos) return res; } -// https://en.wikipedia.org/wiki/Xorshift#xorshift* -__host__ __device__ __forceinline__ uint64_t xorshift64(uint64_t x) -{ - x ^= x >> 12; - x ^= x << 25; - x ^= x >> 27; - return x * 0x2545F4914F6CDD1DULL; -} - template __device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) { - static_assert(WARP_SIZE == 32); + static_assert(raft::warp_size() == 32); auto& element = *element_ptr; element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); @@ -423,8 +410,8 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, { if constexpr (std::is_same_v or std::is_same_v or std::is_same_v) { - constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -436,9 +423,9 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, if constexpr (std::is_same_v) { if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && load_dims % 4 == 0 && padding_dims % 4 == 0) { - constexpr int num_load_elems_per_warp = WARP_SIZE * 4; + constexpr int num_load_elems_per_warp = raft::warp_size() * 4; #pragma unroll - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; if (idx_in_vec + 4 <= load_dims) { *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); @@ -447,8 +434,8 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, } } } else { - constexpr int num_load_elems_per_warp = WARP_SIZE; - for (int step = 0; step < div_up(padding_dims, num_load_elems_per_warp); step++) { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -472,27 +459,27 @@ __global__ void preprocess_data_kernel(const Data_t* input_data, Data_t* s_vec = 
(Data_t*)buffer; size_t list_id = list_offset + blockIdx.x; - load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % WARP_SIZE); + load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); if (threadIdx.x == 0) { l2_norm = 0; } __syncthreads(); - int lane_id = threadIdx.x % WARP_SIZE; - for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { - int idx = step * WARP_SIZE + lane_id; + int lane_id = threadIdx.x % raft::warp_size(); + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + lane_id; float part_dist = 0; if (idx < dim) { part_dist = s_vec[idx]; part_dist = part_dist * part_dist; } __syncwarp(); - for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { - part_dist += __shfl_down_sync(FULL_MASK, part_dist, offset); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); } if (lane_id == 0) { l2_norm += part_dist; } __syncwarp(); } - for (int step = 0; step < div_up(dim, WARP_SIZE); step++) { - int idx = step * WARP_SIZE + threadIdx.x; + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + threadIdx.x; if (idx < dim) { if (l2_norms == nullptr) { output_data[list_id * dim + idx] = @@ -536,11 +523,11 @@ __device__ void insert_to_global_graph(ResultItem elem, int* locks) { int tx = threadIdx.x; - int lane_id = tx % WARP_SIZE; + int lane_id = tx % raft::warp_size(); size_t global_idx_base = list_id * node_degree; if (elem.id() == list_id) return; - const int num_segments = div_up(node_degree, WARP_SIZE); + const int num_segments = ceildiv(node_degree, raft::warp_size()); int loop_flag = 0; do { @@ -549,11 +536,11 @@ __device__ void insert_to_global_graph(ResultItem elem, loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; } - loop_flag = __shfl_sync(FULL_MASK, loop_flag, 0); + loop_flag = __shfl_sync(raft::warp_full_mask(), loop_flag, 0); if (loop_flag == 1) { ResultItem knn_list_frag; - int local_idx = segment_id * WARP_SIZE + lane_id; + int local_idx = segment_id * raft::warp_size() + lane_id; size_t global_idx = global_idx_base + local_idx; if (local_idx < node_degree) { knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); @@ -563,26 +550,27 @@ __device__ void insert_to_global_graph(ResultItem elem, int pos_to_insert = -1; ResultItem prev_elem; - prev_elem.id_with_flag() = __shfl_up_sync(FULL_MASK, knn_list_frag.id_with_flag(), 1); - prev_elem.dist() = __shfl_up_sync(FULL_MASK, knn_list_frag.dist(), 1); + prev_elem.id_with_flag() = + __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.id_with_flag(), 1); + prev_elem.dist() = __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.dist(), 1); if (lane_id == 0) { prev_elem = ResultItem{std::numeric_limits::min(), std::numeric_limits::lowest()}; } if (elem > prev_elem && elem < knn_list_frag) { - pos_to_insert = segment_id * WARP_SIZE + lane_id; + pos_to_insert = segment_id * raft::warp_size() + lane_id; } else if (elem == prev_elem || elem == knn_list_frag) { pos_to_insert = -2; } - uint mask = __ballot_sync(FULL_MASK, pos_to_insert >= 0); + uint mask = __ballot_sync(raft::warp_full_mask(), pos_to_insert >= 0); if (mask) { uint set_lane_id = __fns(mask, 0, 1); - pos_to_insert = __shfl_sync(FULL_MASK, pos_to_insert, set_lane_id); + pos_to_insert = __shfl_sync(raft::warp_full_mask(), pos_to_insert, set_lane_id); } if (pos_to_insert >= 0) { - int 
local_idx = segment_id * WARP_SIZE + lane_id; + int local_idx = segment_id * raft::warp_size() + lane_id; if (local_idx > pos_to_insert) { local_idx++; } else if (local_idx == pos_to_insert) { @@ -591,7 +579,7 @@ __device__ void insert_to_global_graph(ResultItem elem, local_idx++; } size_t global_pos = global_idx_base + local_idx; - if (local_idx < (segment_id + 1) * WARP_SIZE && local_idx < node_degree) { + if (local_idx < (segment_id + 1) * raft::warp_size() && local_idx < node_degree) { graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); dists[global_pos] = knn_list_frag.dist(); } @@ -609,14 +597,14 @@ __device__ ResultItem get_min_item(const Index_t id, const DistData_t* distances, const bool find_in_row = true) { - int lane_id = threadIdx.x % WARP_SIZE; + int lane_id = threadIdx.x % raft::warp_size(); static_assert(MAX_NUM_BI_SAMPLES == 64); - int idx[MAX_NUM_BI_SAMPLES / WARP_SIZE]; - float dist[MAX_NUM_BI_SAMPLES / WARP_SIZE] = {std::numeric_limits::max(), - std::numeric_limits::max()}; - idx[0] = lane_id; - idx[1] = WARP_SIZE + lane_id; + int idx[MAX_NUM_BI_SAMPLES / raft::warp_size()]; + float dist[MAX_NUM_BI_SAMPLES / raft::warp_size()] = {std::numeric_limits::max(), + std::numeric_limits::max()}; + idx[0] = lane_id; + idx[1] = raft::warp_size() + lane_id; if (neighbs[idx[0]] != id) { dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] @@ -624,9 +612,10 @@ __device__ ResultItem get_min_item(const Index_t id, } if (neighbs[idx[1]] != id) { - dist[1] = find_in_row - ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + WARP_SIZE + lane_id] - : distances[idx_in_list + (WARP_SIZE + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; + dist[1] = + find_in_row + ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + raft::warp_size() + lane_id] + : distances[idx_in_list + (raft::warp_size() + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; } if (dist[1] < dist[0]) { @@ -634,9 +623,9 @@ __device__ ResultItem get_min_item(const Index_t id, idx[0] = idx[1]; } __syncwarp(); - for (int offset = WARP_SIZE >> 1; offset >= 1; offset >>= 1) { - float other_idx = __shfl_down_sync(FULL_MASK, idx[0], offset); - float other_dist = __shfl_down_sync(FULL_MASK, dist[0], offset); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + float other_idx = __shfl_down_sync(raft::warp_full_mask(), idx[0], offset); + float other_dist = __shfl_down_sync(raft::warp_full_mask(), dist[0], offset); if (other_dist < dist[0]) { dist[0] = other_dist; idx[0] = other_idx; @@ -644,8 +633,8 @@ __device__ ResultItem get_min_item(const Index_t id, } ResultItem result; - result.dist() = __shfl_sync(FULL_MASK, dist[0], 0); - result.id_with_flag() = neighbs[__shfl_sync(FULL_MASK, idx[0], 0)]; + result.dist() = __shfl_sync(raft::warp_full_mask(), dist[0], 0); + result.id_with_flag() = neighbs[__shfl_sync(raft::warp_full_mask(), idx[0], 0)]; return result; } @@ -653,12 +642,12 @@ template __device__ __forceinline__ void remove_duplicates( T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) { - static_assert(WARP_SIZE == 32); - if (!(threadIdx.x >= execute_warp_id * WARP_SIZE && - threadIdx.x < execute_warp_id * WARP_SIZE + WARP_SIZE)) { + static_assert(raft::warp_size() == 32); + if (!(threadIdx.x >= execute_warp_id * raft::warp_size() && + threadIdx.x < execute_warp_id * raft::warp_size() + raft::warp_size())) { return; } - int lane_id = threadIdx.x % WARP_SIZE; + int lane_id = threadIdx.x % raft::warp_size(); T elem = 
std::numeric_limits::max(); if (lane_id < list_a_size) { elem = list_a[lane_id]; } warp_bitonic_sort(&elem, lane_id); @@ -784,9 +773,9 @@ __launch_bounds__(BLOCK_SIZE, 4) list_new_size = list_new_size2.x + s_unique_counter[0]; list_old_size = list_old_size2.x + s_unique_counter[1]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - constexpr int num_warps = BLOCK_SIZE / WARP_SIZE; + int warp_id = threadIdx.x / raft::warp_size(); + int lane_id = threadIdx.x % raft::warp_size(); + constexpr int num_warps = BLOCK_SIZE / raft::warp_size(); int warp_id_y = warp_id / 4; int warp_id_x = warp_id % 4; @@ -795,8 +784,8 @@ __launch_bounds__(BLOCK_SIZE, 4) wmma::fragment b_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) ? data_dim - step * TILE_COL_WIDTH : TILE_COL_WIDTH; #pragma unroll @@ -845,8 +834,8 @@ __launch_bounds__(BLOCK_SIZE, 4) } __syncthreads(); - for (int step = 0; step < div_up(list_new_size, num_warps); step++) { - int idx_in_list = step * num_warps + tx / WARP_SIZE; + for (int step = 0; step < ceildiv(list_new_size, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); if (idx_in_list >= list_new_size) continue; auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); if (min_elem.id() < gridDim.x) { @@ -859,8 +848,8 @@ __launch_bounds__(BLOCK_SIZE, 4) __syncthreads(); wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < div_up(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == div_up(data_dim, TILE_COL_WIDTH) - 1) + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) ? data_dim - step * TILE_COL_WIDTH : TILE_COL_WIDTH; if (TILE_COL_WIDTH < data_dim) { @@ -924,8 +913,8 @@ __launch_bounds__(BLOCK_SIZE, 4) } __syncthreads(); - for (int step = 0; step < div_up(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { - int idx_in_list = step * num_warps + tx / WARP_SIZE; + for (int step = 0; step < ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) continue; @@ -1148,9 +1137,9 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build : res(res), build_config_(build_config), graph_(build_config.max_dataset_size, - to_multiple_of_32(build_config.node_degree), - to_multiple_of_32(build_config.internal_node_degree ? build_config.internal_node_degree - : build_config.node_degree), + align32::roundUp(build_config.node_degree), + align32::roundUp(build_config.internal_node_degree ? 
build_config.internal_node_degree + : build_config.node_degree), NUM_SAMPLES), nrow_(build_config.max_dataset_size), ndim_(build_config.dataset_dim), @@ -1190,7 +1179,7 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, int2* list_sizes, cudaStream_t stream) { - add_rev_edges_kernel<<>>( + add_rev_edges_kernel<<>>( graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); raft::copy( h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); @@ -1232,20 +1221,22 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - size_t batch_size = 100000; - auto input_data = raft::make_device_matrix( + int batch_size = 100000; + auto input_data = raft::make_device_matrix( res, batch_size, build_config_.dataset_dim); - for (size_t step = 0; step < div_up(nrow_, batch_size); step++) { - size_t list_offset = step * batch_size; - size_t num_lists = step != div_up(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; + for (int step = 0; step < ceildiv(nrow_, batch_size); step++) { + int list_offset = step * batch_size; + int num_lists = step != ceildiv(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; raft::copy(input_data.data_handle(), data + list_offset * build_config_.dataset_dim, num_lists * build_config_.dataset_dim, raft::resource::get_cuda_stream(res)); preprocess_data_kernel<<(raft::warp_size())) * + raft::warp_size(), stream>>>(input_data.data_handle(), d_data_.data_handle(), build_config_.dataset_dim, @@ -1253,12 +1244,12 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out list_offset); } } else { - preprocess_data_kernel<<>>( - data, d_data_.data_handle(), build_config_.dataset_dim, l2_norms_.data_handle()); + preprocess_data_kernel<<< + nrow_, + raft::warp_size(), + sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) * + raft::warp_size(), + stream>>>(data, d_data_.data_handle(), build_config_.dataset_dim, l2_norms_.data_handle()); } thrust::fill(thrust::device.on(stream), @@ -1371,7 +1362,8 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out if (id < nrow_) { graph_shrink_buffer[i * build_config_.node_degree + j] = id; } else { - graph_shrink_buffer[i * build_config_.node_degree + j] = xorshift64(idx) % nrow_; + graph_shrink_buffer[i * build_config_.node_degree + j] = + raft::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; } } } @@ -1416,9 +1408,10 @@ index build(raft::resources const& res, // The elements in each knn-list are partitioned into different buckets, and we need more buckets // to mitigate bucket collisions. `intermediate_degree` is OK to larger than // extended_graph_degree. - size_t extended_graph_degree = to_multiple_of_32(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3)); - size_t extended_intermediate_degree = - to_multiple_of_32(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3)); + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 
1.0 : 1.3)));
 
   auto int_graph = raft::make_host_matrix(
     dataset.extent(0), static_cast(extended_graph_degree));
diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh
index abdaf4771d..b20d407223 100644
--- a/cpp/include/raft/neighbors/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/nn_descent.cuh
@@ -24,7 +24,7 @@
 namespace raft::neighbors::experimental::nn_descent {
 
 /**
- * @defgroup nn-descent CUDA ANN Graph-based gradient descent nearest neighbor
+ * @defgroup nn-descent CUDA gradient descent nearest neighbor
  * @{
  */
 
@@ -46,12 +46,14 @@ namespace raft::neighbors::experimental::nn_descent {
  * // dataset
  * @endcode
  *
- * @tparam T data-type
- * @tparam IdxT index-type
- * @param res raft::resources
- * @param params nn_descent::index_params
- * @param dataset raft::device_matrix_view
- * @return index
+ * @tparam T data-type of the input dataset
+ * @tparam IdxT data-type for the output index
+ * @param res raft::resources is an object managing resources
+ * @param params an instance of nn_descent::index_params that are parameters
+ * to run the nn-descent algorithm
+ * @param dataset raft::device_matrix_view input dataset expected to be located
+ * in device memory
+ * @return index index containing all-neighbors knn graph in host memory
  */
 template
 index build(raft::resources const& res,
@@ -79,12 +81,14 @@ index build(raft::resources const& res,
  * // dataset
  * @endcode
  *
- * @tparam T data-type
- * @tparam IdxT index-type
- * @param res raft::resources
- * @param params nn_descent::index_params
- * @param dataset raft::host_matrix_view
- * @return index
+ * @tparam T data-type of the input dataset
+ * @tparam IdxT data-type for the output index
+ * @param res raft::resources is an object managing resources
+ * @param params an instance of nn_descent::index_params that are parameters
+ * to run the nn-descent algorithm
+ * @param dataset raft::host_matrix_view input dataset expected to be located
+ * in host memory
+ * @return index index containing all-neighbors knn graph in host memory
  */
 template
 index build(raft::resources const& res,
diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp
index 85f9336263..ef169faf1b 100644
--- a/cpp/include/raft/neighbors/nn_descent_types.hpp
+++ b/cpp/include/raft/neighbors/nn_descent_types.hpp
@@ -31,15 +31,32 @@ namespace raft::neighbors::experimental::nn_descent {
  * @{
 */
 
+/**
+ * @brief Parameters used to build an nn-descent index
+ *
+ * `graph_degree`: For an input dataset of dimensions (N, D),
+ * determines the final dimensions of the all-neighbors knn graph
+ * which turns out to be of dimensions (N, graph_degree)
+ * `intermediate_graph_degree`: Internally, nn-descent builds an
+ * all-neighbors knn graph of dimensions (N, intermediate_graph_degree)
+ * before selecting the final `graph_degree` neighbors. It's recommended
+ * that `intermediate_graph_degree` >= 1.5 * graph_degree
+ * `max_iterations`: The number of iterations that nn-descent will refine
+ * the graph for. More iterations produce a better quality graph at the cost of performance
+ * `termination_threshold`: The delta at which nn-descent will terminate its iterations
+ *
+ */
 struct index_params : ann::index_params {
-  size_t intermediate_graph_degree = 128;  // Degree of input graph for pruning.
   size_t graph_degree              = 64;   // Degree of output graph.
-  size_t max_iterations            = 50;   // Number of nn-descent iterations.
+  size_t intermediate_graph_degree = 128;  // Degree of input graph for pruning.
+  size_t max_iterations            = 20;   // Number of nn-descent iterations.
   float termination_threshold      = 0.0001;  // Termination threshold of nn-descent.
 };
 
 /**
- * @brief nn-descent Index
+ * @brief nn-descent index
+ * The index contains an all-neighbors graph of the input dataset
+ * stored in host memory of dimensions (n_rows, n_cols)
  *
  * @tparam IdxT dtype to be used for constructing knn-graph
  */
@@ -53,7 +70,7 @@ struct index : ann::index {
    * The type of the knn-graph is a dense raft::host_matrix and dimensions are
    * (n_rows, n_cols).
    *
-   * @param res raft::resources
+   * @param res raft::resources is an object managing resources
    * @param n_rows number of rows in knn-graph
    * @param n_cols number of cols in knn-graph
    */
diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh
index fabc21f508..be60ec5b6d 100644
--- a/cpp/test/neighbors/ann_utils.cuh
+++ b/cpp/test/neighbors/ann_utils.cuh
@@ -164,6 +164,8 @@ auto eval_recall(const std::vector& expected_idx,
   return testing::AssertionSuccess();
 }
 
+/** same as eval_recall, but in case indices do not match,
+ * then check distances as well, and accept the match if the actual dist is equal to the expected dist */
 template
 auto eval_neighbours(const std::vector& expected_idx,
                      const std::vector& actual_idx,

From 4f0e425466e0092e86dfb070bde232551092e4b2 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Thu, 14 Sep 2023 13:12:54 -0700
Subject: [PATCH 24/28] use batch load iterator

---
 .../raft/neighbors/detail/nn_descent.cuh      | 32 ++++++-------------
 1 file changed, 9 insertions(+), 23 deletions(-)

diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh
index b1a51381e0..da22d9ff2d 100644
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -40,6 +40,7 @@
 #include
 #include
 #include
+#include
 #include  // raft::util::arch::SM_*
 #include
 #include
@@ -206,11 +207,6 @@ __device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane
   return;
 }
 
-enum class Metric_t {
-  METRIC_INNER_PRODUCT = 0,
-  METRIC_L2            = 1,
-};
-
 struct BuildConfig {
   size_t max_dataset_size;
   size_t dataset_dim;
@@ -219,7 +215,6 @@ struct BuildConfig {
   // If internal_node_degree == 0, the value of node_degree will be assigned to it
   size_t max_iterations{50};
   float termination_threshold{0.0001};
-  Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT};
 };
 
 template
@@ -1212,8 +1207,6 @@ void GNND::local_join(cudaStream_t stream)
 template
 void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph)
 {
-  using input_t = typename std::remove_const::type;
-
   cudaStream_t stream = raft::resource::get_cuda_stream(res);
   nrow_               = nrow;
   graph_.h_graph      = (InternalID_t*)output_graph;
@@ -1221,27 +1214,21 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out
   cudaPointerAttributes data_ptr_attr;
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
   if (data_ptr_attr.type == cudaMemoryTypeUnregistered) {
-    int batch_size  = 100000;
-    auto input_data = raft::make_device_matrix(
-      res, batch_size, build_config_.dataset_dim);
-    for (int step = 0; step < ceildiv(nrow_, batch_size); step++) {
-      int list_offset = step * batch_size;
-      int num_lists   = step != ceildiv(nrow_, batch_size) - 1 ?
batch_size : nrow_ - list_offset; - raft::copy(input_data.data_handle(), - data + list_offset * build_config_.dataset_dim, - num_lists * build_config_.dataset_dim, - raft::resource::get_cuda_stream(res)); - preprocess_data_kernel<<(nrow_), build_config_.dataset_dim, batch_size, stream}; + for (auto const& batch : vec_batches) { + preprocess_data_kernel<<(raft::warp_size())) * raft::warp_size(), - stream>>>(input_data.data_handle(), + stream>>>(batch.data(), d_data_.data_handle(), build_config_.dataset_dim, l2_norms_.data_handle(), - list_offset); + batch.offset()); } } else { preprocess_data_kernel<<< @@ -1421,8 +1408,7 @@ index build(raft::resources const& res, .node_degree = extended_graph_degree, .internal_node_degree = extended_intermediate_degree, .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold, - .metric_type = Metric_t::METRIC_L2}; + .termination_threshold = params.termination_threshold}; GNND nnd(res, build_config); nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); From c55ae4e0e5b895d98fd7e1f52515af300a8c895a Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Sep 2023 14:16:24 -0700 Subject: [PATCH 25/28] add nn-descent to python cagra --- .../raft/neighbors/detail/nn_descent.cuh | 2 ++ .../pylibraft/neighbors/cagra/cagra.pyx | 23 ++++++++++--------- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 5 ++++ python/pylibraft/pylibraft/test/test_cagra.py | 12 +++++++++- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index da22d9ff2d..adcd734cc5 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1207,6 +1207,8 @@ void GNND::local_join(cudaStream_t stream) template void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) { + using input_t = typename std::remove_const::type; + cudaStream_t stream = raft::resource::get_cuda_stream(res); nrow_ = nrow; graph_.h_graph = (InternalID_t*)output_graph; diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index e0c59a5ed3..c11d933b27 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -104,11 +104,13 @@ cdef class IndexParams: graph_degree : int, default = 64 - add_data_on_build : bool, default = True - After training the coarse and fine quantizers, we will populate - the index with the dataset if add_data_on_build == True, otherwise - the index is left empty, and the extend method can be used - to add new vectors to the index. + build_algo: string denoting the graph building algorithm to use, + default = "ivf_pq" + Valid values for algo: ["ivf_pq", "nn_descent"], where + - ivf_pq will use the IVF-PQ algorithm for building the knn graph + - nn_descent (experimental) will use the NN-Descent algorithm for + building the knn graph. It is expected to be generally + faster than ivf_pq. 
""" cdef c_cagra.index_params params @@ -116,12 +118,15 @@ cdef class IndexParams: metric="sqeuclidean", intermediate_graph_degree=128, graph_degree=64, - add_data_on_build=True): + build_algo="ivf_pq"): self.params.metric = _get_metric(metric) self.params.metric_arg = 0 self.params.intermediate_graph_degree = intermediate_graph_degree self.params.graph_degree = graph_degree - self.params.add_data_on_build = add_data_on_build + if build_algo == "ivf_pq": + self.params.build_algo = c_cagra.graph_build_algo.IVF_PQ + elif build_algo == "nn_descent": + self.params.build_algo = c_cagra.graph_build_algo.NN_DESCENT @property def metric(self): @@ -135,10 +140,6 @@ cdef class IndexParams: def graph_degree(self): return self.params.graph_degree - @property - def add_data_on_build(self): - return self.params.add_data_on_build - cdef class Index: cdef readonly bool trained diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 0c683bcd9b..7e22f274e9 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -51,9 +51,14 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( cdef extern from "raft/neighbors/cagra_types.hpp" \ namespace "raft::neighbors::cagra" nogil: + ctypedef enum graph_build_algo: + IVF_PQ "raft::neighbors::cagra::graph_build_algo::IVF_PQ", + NN_DESCENT "raft::neighbors::cagra::graph_build_algo::NN_DESCENT" + cpdef cppclass index_params(ann_index_params): size_t intermediate_graph_degree size_t graph_degree + graph_build_algo build_algo ctypedef enum search_algo: SINGLE_CTA "raft::neighbors::cagra::search_algo::SINGLE_CTA", diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 74e9f53b91..f74fc5ae62 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -52,6 +52,7 @@ def run_cagra_build_search_test( metric="euclidean", intermediate_graph_degree=128, graph_degree=64, + build_algo="ivf_pq", array_type="device", compare=True, inplace=True, @@ -67,6 +68,7 @@ def run_cagra_build_search_test( metric=metric, intermediate_graph_degree=intermediate_graph_degree, graph_degree=graph_degree, + build_algo=build_algo, ) if array_type == "device": @@ -139,13 +141,17 @@ def run_cagra_build_search_test( @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("array_type", ["device", "host"]) -def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): +@pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) +def test_cagra_dataset_dtype_host_device( + dtype, array_type, inplace, build_algo +): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. 
run_cagra_build_search_test( dtype=dtype, inplace=inplace, array_type=array_type, + build_algo=build_algo, ) @@ -158,6 +164,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 1, "metric": "euclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 32, @@ -165,6 +172,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": False, "k": 5, "metric": "sqeuclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 128, @@ -172,6 +180,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 10, "metric": "inner_product", + "build_algo": "nn_descent", }, ], ) @@ -184,6 +193,7 @@ def test_cagra_index_params(params): graph_degree=params["graph_degree"], intermediate_graph_degree=params["intermediate_graph_degree"], compare=False, + build_algo=params["build_algo"], ) From 4344666526526625796ccce453aa823bdbf988e1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 21 Sep 2023 13:09:44 -0700 Subject: [PATCH 26/28] more review updates --- .../raft/neighbors/detail/nn_descent.cuh | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index adcd734cc5..c122c257a2 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -442,6 +442,7 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, } } +/** Calculate L2 norm, and cast data to __half */ template __global__ void preprocess_data_kernel(const Data_t* input_data, __half* output_data, @@ -1215,30 +1216,21 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); - if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - size_t batch_size = 100000; - raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{ - data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; - for (auto const& batch : vec_batches) { - preprocess_data_kernel<<(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_.data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset()); - } - } else { + size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 
100000 : nrow_;
+
+  raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{
+    data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream};
+  for (auto const& batch : vec_batches) {
+    preprocess_data_kernel<<<
+      batch.size(),
+      raft::warp_size(),
+      sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) *
+        raft::warp_size(),
+      stream>>>(batch.data(),
+                d_data_.data_handle(),
+                build_config_.dataset_dim,
+                l2_norms_.data_handle(),
+                batch.offset());
+  }
 
   thrust::fill(thrust::device.on(stream),

From 76b520ad1cfd1f9aa59b65fda2d7647934b79ce2 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Fri, 22 Sep 2023 14:35:42 -0700
Subject: [PATCH 27/28] address more review comments

---
 cpp/include/raft/neighbors/cagra.cuh          | 43 ++++-----
 cpp/include/raft/neighbors/cagra_types.hpp    |  7 +-
 .../neighbors/detail/cagra/cagra_build.cuh    | 17 ++--
 .../raft/neighbors/detail/nn_descent.cuh      | 41 ++++++++-
 cpp/include/raft/neighbors/nn_descent.cuh     | 88 +++++++++++++++++--
 .../raft/neighbors/nn_descent_types.hpp       | 31 ++++++-
 cpp/test/CMakeLists.txt                       |  1 -
 cpp/test/neighbors/ann_cagra.cuh              | 18 ++--
 cpp/test/neighbors/ann_nn_descent.cuh         |  7 --
 .../ann_nn_descent/test_float_int64_t.cu      | 28 ------
 10 files changed, 189 insertions(+), 92 deletions(-)
 delete mode 100644 cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu

diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index a9d73eb84b..794dd3e86a 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -69,7 +69,7 @@ namespace raft::neighbors::cagra {
  * @param[in] res raft resources
  * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
  * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree]
- * @param[in] refine_rate refinement rate for ivf-pq search
+ * @param[in] refine_rate (optional) refinement rate for ivf-pq search
  * @param[in] build_params (optional) ivf_pq index building parameters for knn graph
  * @param[in] search_params (optional) ivf_pq search parameters
  */
@@ -114,8 +114,9 @@ void build_knn_graph(raft::resources const& res,
  * // use default index parameters
  * nn_descent::index_params build_params;
  * build_params.graph_degree = 128;
+ * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
  * // create knn graph
- * auto nn_descent_index = cagra::build_knn_graph(res, dataset, build_params);
+ * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params);
  * auto optimized_graph = raft::make_host_matrix(dataset.extent(0), 64);
- * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view());
+ * cagra::optimize(res, knn_graph.view(), optimized_graph.view());
  * // Construct an index from dataset and optimized knn_graph
  * auto index = cagra::index(res, build_params.metric(), dataset,
@@ -126,24 +127,23 @@ void build_knn_graph(raft::resources const& res,
  * @tparam DataT data element type
  * @tparam IdxT type of the dataset vector indices
  * @tparam accessor host or device accessor_type for the dataset
- * @param res raft::resources is an object managing resources
- * @param dataset input raft::host/device_matrix_view that can be located
+ * @param[in] res raft::resources is an object managing resources
+ * @param[in] dataset input raft::host/device_matrix_view that can be located
 * in host or device memory
+ * @param[out] knn_graph a host matrix view to store the
output knn graph [n_rows, graph_degree] + * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters * to run the nn-descent algorithm - * @return experimental::nn_descent::index index containing all-neighbors knn graph in host - * memory */ template , memory_type::device>> -experimental::nn_descent::index build_knn_graph( - raft::resources const& res, - mdspan, row_major, accessor> dataset, - std::optional build_params = std::nullopt) +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) { - return detail::build_knn_graph(res, dataset, build_params); + detail::build_knn_graph(res, dataset, knn_graph, build_params); } /** @@ -307,24 +307,27 @@ index build(raft::resources const& res, graph_degree = intermediate_degree; } - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); + std::optional> knn_graph( + raft::make_host_matrix(dataset.extent(0), intermediate_degree)); if (params.build_algo == graph_build_algo::IVF_PQ) { - auto knn_graph = raft::make_host_matrix(dataset.extent(0), intermediate_degree); + build_knn_graph(res, dataset, knn_graph->view()); - build_knn_graph(res, dataset, knn_graph.view()); - - optimize(res, knn_graph.view(), cagra_graph.view()); } else { // Use nn-descent to build CAGRA knn graph auto nn_descent_params = experimental::nn_descent::index_params(); nn_descent_params.graph_degree = intermediate_degree; nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree; - auto nn_descent_index = build_knn_graph(res, dataset, nn_descent_params); - - optimize(res, nn_descent_index.graph(), cagra_graph.view()); + build_knn_graph(res, dataset, knn_graph->view(), nn_descent_params); } + auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); + + optimize(res, knn_graph->view(), cagra_graph.view()); + + // free intermediate graph before trying to create the index + knn_graph.reset(); + // Construct an index from dataset and optimized knn graph. return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index fc47433877..2c6d2c1bb9 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -44,7 +44,12 @@ namespace raft::neighbors::cagra { * @brief ANN algorithm used by CAGRA to build knn graph * */ -enum class graph_build_algo { IVF_PQ, NN_DESCENT }; +enum class graph_build_algo { + /* Use IVF-PQ to build all-neighbors knn graph */ + IVF_PQ, + /* Experimental, use NN-Descent to build all-neighbors knn graph */ + NN_DESCENT +}; struct index_params : ann::index_params { /** Degree of input graph for pruning. 
*/ diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d572285178..ecc047743e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -240,16 +240,13 @@ void build_knn_graph(raft::resources const& res, } template -experimental::nn_descent::index build_knn_graph( - raft::resources const& res, - mdspan, row_major, accessor> dataset, - std::optional build_params = std::nullopt) +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) { - if (!build_params) { - build_params = std::make_optional(); - } - - auto nn_descent_idx = experimental::nn_descent::build(res, *build_params, dataset); + auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); + experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); using internal_IdxT = typename std::make_unsigned::type; using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; @@ -263,8 +260,6 @@ experimental::nn_descent::index build_knn_graph( nn_descent_idx.graph().extent(1)); graph::sort_knn_graph(res, dataset, knn_graph_internal); - - return nn_descent_idx; } } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index c122c257a2..3e4d0409bd 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -177,6 +177,7 @@ __device__ __forceinline__ int xor_swap(int x, int mask, int dir) return x < y == dir ? y : x; } +// TODO: Move to RAFT utils https://github.com/rapidsai/raft/issues/1827 __device__ __forceinline__ uint bfe(uint lane_id, uint pos) { uint res; @@ -323,6 +324,7 @@ struct GnndGraph { const size_t internal_node_degree, const size_t num_samples); void init_random_graph(); + // TODO: Create a generic bloom filter utility https://github.com/rapidsai/raft/issues/1827 // Use Bloom filter to sample "new" neighbors for local joining void sample_graph_new(InternalID_t* new_neighbors, const size_t width); void sample_graph(bool sample_new); @@ -369,6 +371,7 @@ class GNND { raft::device_matrix graph_buffer_; raft::device_matrix dists_buffer_; + // TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827 thrust::host_vector> graph_host_buffer_; thrust::host_vector> dists_host_buffer_; @@ -442,6 +445,7 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, } } +// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 /** Calculate L2 norm, and cast data to __half */ template __global__ void preprocess_data_kernel(const Data_t* input_data, @@ -1363,15 +1367,17 @@ template , memory_type::host>> -index build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset) +void build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + index& idx) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, "The dataset size for GNND should be less than %d", std::numeric_limits::max() - 1); size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; + if (intermediate_degree >= static_cast(dataset.extent(0))) { RAFT_LOG_WARN( "Intermediate graph degree cannot be larger than 
dataset size, reducing it to %lu", @@ -1386,6 +1392,7 @@ index build(raft::resources const& res, intermediate_degree); graph_degree = intermediate_degree; } + // The elements in each knn-list are partitioned into different buckets, and we need more buckets // to mitigate bucket collisions. `intermediate_degree` is OK to larger than // extended_graph_degree. @@ -1406,7 +1413,7 @@ index build(raft::resources const& res, GNND nnd(res, build_config); nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); - index idx{res, dataset.extent(0), static_cast(graph_degree)}; + #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { for (size_t j = 0; j < graph_degree; j++) { @@ -1414,6 +1421,32 @@ index build(raft::resources const& res, graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; } } +} + +template , memory_type::host>> +index build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + index idx{res, dataset.extent(0), static_cast(graph_degree)}; + + build(res, params, dataset, idx); + return idx; } diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index b20d407223..ceb5ae5643 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -48,10 +48,10 @@ namespace raft::neighbors::experimental::nn_descent { * * @tparam T data-type of the input dataset * @tparam IdxT data-type for the output index - * @param res raft::resources is an object mangaging resources - * @param params an instance of nn_descent::index_params that are parameters + * @param[in] res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters * to run the nn-descent algorithm - * @param dataset raft::device_matrix_view input dataset expected to be located + * @param[in] dataset raft::device_matrix_view input dataset expected to be located * in device memory * @return index index containing all-neighbors knn graph in host memory */ @@ -63,6 +63,45 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto knn_graph = raft::make_host_matrix(N, D); + * auto index = nn_descent::index{res, knn_graph.view()}; + * cagra::build(res, index_params, dataset, index); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the 
+ * @param[in] dataset raft::device_matrix_view input dataset expected to be located
+ *   in device memory
+ * @param[out] idx raft::neighbors::experimental::nn_descent::index containing all-neighbors knn graph
+ *   in host memory
+ */
+template <typename T, typename IdxT = uint32_t>
+void build(raft::resources const& res,
+           index_params const& params,
+           raft::device_matrix_view<const T, int64_t, row_major> dataset,
+           index<IdxT>& idx)
+{
+  detail::build<T, IdxT>(res, params, dataset, idx);
+}
+
 /**
  * @brief Build nn-descent Index with dataset in host memory
  *
@@ -84,9 +123,9 @@ index<IdxT> build(raft::resources const& res,
  * @tparam T data-type of the input dataset
  * @tparam IdxT data-type for the output index
  * @param res raft::resources is an object mangaging resources
- * @param params an instance of nn_descent::index_params that are parameters
+ * @param[in] params an instance of nn_descent::index_params that are parameters
  *   to run the nn-descent algorithm
- * @param dataset raft::host_matrix_view input dataset expected to be located
+ * @param[in] dataset raft::host_matrix_view input dataset expected to be located
  *   in host memory
  * @return index<IdxT> index containing all-neighbors knn graph in host memory
  */
@@ -98,6 +137,45 @@ index<IdxT> build(raft::resources const& res,
   return detail::build<T, IdxT>(res, params, dataset);
 }
 
+/**
+ * @brief Build nn-descent Index with dataset in host memory
+ *
+ * The following distance metrics are supported:
+ * - L2
+ *
+ * Usage example:
+ * @code{.cpp}
+ *   using namespace raft::neighbors::experimental;
+ *   // use default index parameters
+ *   nn_descent::index_params index_params;
+ *   // create and fill the index from a [N, D] raft::host_matrix_view dataset
+ *   auto knn_graph = raft::make_host_matrix<uint32_t, int64_t>(N, index_params.graph_degree);
+ *   auto index = nn_descent::index{res, knn_graph.view()};
+ *   nn_descent::build(res, index_params, dataset, index);
+ *   // index.graph() provides a raft::host_matrix_view of an
+ *   // all-neighbors knn graph of dimensions [N, k] of the input
+ *   // dataset
+ * @endcode
+ *
+ * @tparam T data-type of the input dataset
+ * @tparam IdxT data-type for the output index
+ * @param[in] res raft::resources is an object managing resources
+ * @param[in] params an instance of nn_descent::index_params that are parameters
+ *   to run the nn-descent algorithm
+ * @param[in] dataset raft::host_matrix_view input dataset expected to be located
+ *   in host memory
+ * @param[out] idx raft::neighbors::experimental::nn_descent::index containing all-neighbors knn graph
+ *   in host memory
+ */
+template <typename T, typename IdxT = uint32_t>
+void build(raft::resources const& res,
+           index_params const& params,
+           raft::host_matrix_view<const T, int64_t, row_major> dataset,
+           index<IdxT>& idx)
+{
+  detail::build<T, IdxT>(res, params, dataset, idx);
+}
+
 /** @} */  // end group nn-descent
 
 } // namespace raft::neighbors::experimental::nn_descent

diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp
index ef169faf1b..64e464c618 100644
--- a/cpp/include/raft/neighbors/nn_descent_types.hpp
+++ b/cpp/include/raft/neighbors/nn_descent_types.hpp
@@ -78,7 +78,28 @@ struct index : ann::index {
     : ann::index(),
       res_{res},
       metric_{raft::distance::DistanceType::L2Expanded},
-      graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(n_rows, n_cols)}
+      graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(n_rows, n_cols)},
+      graph_view_{graph_.view()}
+  {
+  }
+
+  /**
+   * @brief Construct a new index object
+   *
+   * This constructor creates an nn-descent index using a user-allocated host-memory knn-graph.
+   * The type of the knn-graph is a dense raft::host_matrix and dimensions are
+   * (n_rows, n_cols).
+   *
+   * @param res raft::resources is an object managing resources
+   * @param graph_view raft::host_matrix_view for storing knn-graph
+   */
+  index(raft::resources const& res,
+        raft::host_matrix_view<IdxT, int64_t, row_major> graph_view)
+    : ann::index(),
+      res_{res},
+      metric_{raft::distance::DistanceType::L2Expanded},
+      graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
+      graph_view_{graph_view}
   {
   }
 
@@ -91,19 +112,19 @@ struct index : ann::index {
   /** Total length of the index (number of vectors). */
   [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
   {
-    return graph_.view().extent(0);
+    return graph_view_.extent(0);
   }
 
   /** Graph degree */
   [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
   {
-    return graph_.view().extent(1);
+    return graph_view_.extent(1);
   }
 
   /** neighborhood graph [size, graph-degree] */
   [[nodiscard]] inline auto graph() noexcept -> host_matrix_view<IdxT, int64_t, row_major>
   {
-    return graph_.view();
+    return graph_view_;
   }
 
   // Don't allow copying the index for performance reasons (try avoiding copying data)
@@ -117,6 +138,8 @@ struct index : ann::index {
   raft::resources const& res_;
   raft::distance::DistanceType metric_;
   raft::host_matrix<IdxT, int64_t, row_major> graph_;  // graph to return for non-int IdxT
+  raft::host_matrix_view<IdxT, int64_t, row_major>
+    graph_view_;  // view of graph for user-provided matrix
 };
 
 /** @} */

diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 69d950d10d..71de21e64a 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -386,7 +386,6 @@ if(BUILD_TESTS)
     test/neighbors/ann_nn_descent/test_float_uint32_t.cu
     test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu
     test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu
-    test/neighbors/ann_nn_descent/test_float_int64_t.cu
     LIB
     EXPLICIT_INSTANTIATE_ONLY
     GPUS

diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh
index 734599057b..af8bae21f2 100644
--- a/cpp/test/neighbors/ann_cagra.cuh
+++ b/cpp/test/neighbors/ann_cagra.cuh
@@ -333,20 +333,16 @@ class AnnCagraSortTest : public ::testing::TestWithParam {
       cagra::build_knn_graph(handle_, database_view, knn_graph.view());
     }
   } else {
-    auto nn_descent_idx_params = std::make_optional<nn_descent::index_params>();
-    nn_descent_idx_params->graph_degree = index_params.intermediate_graph_degree;
-    nn_descent_idx_params->intermediate_graph_degree = index_params.intermediate_graph_degree;
+    auto nn_descent_idx_params = nn_descent::index_params{};
+    nn_descent_idx_params.graph_degree = index_params.intermediate_graph_degree;
+    nn_descent_idx_params.intermediate_graph_degree = index_params.intermediate_graph_degree;
     if (ps.host_dataset) {
-      auto nn_descent_idx =
-        cagra::build_knn_graph(handle_, database_host_view, nn_descent_idx_params);
-      std::memcpy(
-        knn_graph.data_handle(), nn_descent_idx.graph().data_handle(), knn_graph.size());
+      cagra::build_knn_graph(
+        handle_, database_host_view, knn_graph.view(), nn_descent_idx_params);
     } else {
-      auto nn_descent_idx =
-        cagra::build_knn_graph(handle_, database_host_view, nn_descent_idx_params);
-      std::memcpy(
-        knn_graph.data_handle(), nn_descent_idx.graph().data_handle(), knn_graph.size());
+      cagra::build_knn_graph(
+        handle_, database_view, knn_graph.view(), nn_descent_idx_params);
     }
   }

diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh
index 0eeee65529..948323cf6e 100644
--- a/cpp/test/neighbors/ann_nn_descent.cuh
+++ b/cpp/test/neighbors/ann_nn_descent.cuh
@@ -112,13 +112,6 @@ class AnnNNDescentTest : public ::testing::TestWithParam {
       resource::sync_stream(handle_);
     }
 
-    // for (int i = 0; i < min(ps.n_queries, 10); i++) {
-    //   // std::cout << "query " << i << std::end;
-    //   print_vector("T", indices_naive.data() + i * ps.k, ps.k, std::cout);
-    //   print_vector("C", indices_Cagra.data() + i * ps.k, ps.k, std::cout);
-    //   print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout);
-    //   print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout);
-    // }
     double min_recall = ps.min_recall;
     EXPECT_TRUE(eval_recall(
       indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall));

diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu
deleted file mode 100644
index 92a47a6a77..0000000000
--- a/cpp/test/neighbors/ann_nn_descent/test_float_int64_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-
-#include "../ann_nn_descent.cuh"
-
-namespace raft::neighbors::experimental::nn_descent {
-
-typedef AnnNNDescentTest<float, float, std::int64_t> AnnNNDescentTestF_I64;
-TEST_P(AnnNNDescentTestF_I64, AnnCagra) { this->testNNDescent(); }
-
-INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_I64, ::testing::ValuesIn(inputs));
-
-} // namespace raft::neighbors::experimental::nn_descent

From dffa67df03a0f467c1c0f3e00919e91da200f33a Mon Sep 17 00:00:00 2001
From: divyegala
Date: Mon, 25 Sep 2023 20:11:27 -0700
Subject: [PATCH 28/28] fix compiler error

---
 cpp/test/neighbors/ann_cagra.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh
index ce4da699ae..343afd04ec 100644
--- a/cpp/test/neighbors/ann_cagra.cuh
+++ b/cpp/test/neighbors/ann_cagra.cuh
@@ -350,7 +350,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam {
       cagra::build_knn_graph(handle_, database_view, knn_graph.view());
     }
   } else {
-    auto nn_descent_idx_params = nn_descent::index_params{};
+    auto nn_descent_idx_params = experimental::nn_descent::index_params{};
     nn_descent_idx_params.graph_degree = index_params.intermediate_graph_degree;
     nn_descent_idx_params.intermediate_graph_degree = index_params.intermediate_graph_degree;
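
Note on the API added by this series: the out-parameter `build` overloads are meant to be
paired with the new view-based `index` constructor, so nn-descent can write its knn graph
directly into memory the caller owns. Below is a minimal sketch of that flow, assuming the
template signatures reconstructed above; the dataset view shape and the helper name
`build_into_user_graph` are illustrative only, not part of the patch.

#include <cstdint>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/nn_descent.cuh>

// Build an all-neighbors knn graph directly into caller-owned host memory.
void build_into_user_graph(raft::resources const& res,
                           raft::device_matrix_view<const float, int64_t> dataset)
{
  using namespace raft::neighbors::experimental;

  nn_descent::index_params params;  // graph_degree defaults to 64

  // Caller-allocated [n_rows, graph_degree] graph; the index keeps only a
  // non-owning view of it, and its owning host_matrix member stays at (0, 0).
  auto knn_graph = raft::make_host_matrix<uint32_t, int64_t>(
    dataset.extent(0), static_cast<int64_t>(params.graph_degree));
  nn_descent::index<uint32_t> idx{res, knn_graph.view()};

  // Out-parameter overload: fills the caller's graph in place.
  nn_descent::build<float, uint32_t>(res, params, dataset, idx);

  // idx.graph() and knn_graph.view() now alias the same host memory.
}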
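
The point of splitting `graph_` and `graph_view_` inside `nn_descent::index` shows up in the
test changes above: `cagra::build_knn_graph` can now pass its preallocated `knn_graph.view()`
straight through, so the `std::memcpy` from a temporary index into the CAGRA graph is deleted
rather than merely moved. Since `size()`, `graph_degree()`, and `graph()` all read through
`graph_view_`, which points at either the index-owned matrix or the user's, both construction
paths behave identically downstream.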