Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port NN-descent algorithm to use in cagra::build() #1748

Merged
merged 40 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ef245e8
successful compilation
divyegala Aug 17, 2023
0c1a6fe
nn-descent tests stuck indefinitely
divyegala Aug 18, 2023
7ad4c8a
Fix the bug of unexpected hang
RayWang96 Aug 21, 2023
558f849
Merge pull request #2 from RayWang96/port-nn-descent
divyegala Aug 21, 2023
4ccf3a7
Fix bugs that cause unit tests to fail
RayWang96 Aug 22, 2023
40e1cf0
Fix duplicate nodes issue
RayWang96 Aug 22, 2023
1f1f32d
Merge pull request #3 from RayWang96/port-nn-descent
divyegala Aug 22, 2023
33f5ebc
passing tests
divyegala Aug 23, 2023
56d3b93
merging upstream
divyegala Aug 24, 2023
6ac9186
Fix duplicate nodes issue
RayWang96 Aug 30, 2023
508050f
Fix IMA in sort_knn_graph
RayWang96 Aug 30, 2023
cfba7ab
Merge pull request #4 from RayWang96/port-nn-descent
divyegala Aug 30, 2023
0496bd9
temp benchmark
divyegala Aug 30, 2023
0e96d40
Revert "temp benchmark"
divyegala Aug 31, 2023
94682d8
remove explicit sort from nn-descent
divyegala Aug 31, 2023
7bf3ad6
Use RAFT types
divyegala Sep 1, 2023
60d7805
using RAFT types
divyegala Sep 1, 2023
5eb5690
remove explicit cuda copies and stream syncs
divyegala Sep 1, 2023
21ac440
experimental namespace, docs update+code examples
divyegala Sep 1, 2023
7a5bd71
merging upstream
divyegala Sep 1, 2023
28135a8
add graph_build_algo to bench-ann
divyegala Sep 1, 2023
b0344c7
add tests
divyegala Sep 6, 2023
3f3d965
add arch guards for using wmma
divyegala Sep 6, 2023
832d056
Revert "add arch guards for using wmma"
divyegala Sep 6, 2023
f60db9d
correctly add arch guards using raft::util::arch
divyegala Sep 6, 2023
5038f5a
Merge remote-tracking branch 'upstream/branch-23.10' into port-nn-des…
divyegala Sep 6, 2023
86f18bb
fix launch bounds for arches 750,860
divyegala Sep 6, 2023
69b7ba7
add comment explaining launch bounds changes for archs
divyegala Sep 6, 2023
a44e4a4
first batch of review addressing
divyegala Sep 13, 2023
aa4f6cb
merging upstream
divyegala Sep 13, 2023
4f0e425
use batch load iterator
divyegala Sep 14, 2023
c55ae4e
add nn-descent to python cagra
divyegala Sep 14, 2023
93d7f6e
Merge branch 'branch-23.10' into port-nn-descent
cjnolet Sep 20, 2023
4344666
more review updates
divyegala Sep 21, 2023
e769764
Merge remote-tracking branch 'upstream/branch-23.10' into port-nn-des…
divyegala Sep 21, 2023
76b520a
address more review comments
divyegala Sep 22, 2023
f53810d
Merge remote-tracking branch 'upstream/branch-23.10' into port-nn-des…
divyegala Sep 22, 2023
78284a1
Merge branch 'branch-23.10' into port-nn-descent
cjnolet Sep 25, 2023
dffa67d
fix compiler error
divyegala Sep 26, 2023
a0df0c6
Merge remote-tracking branch 'upstream/branch-23.10' into port-nn-des…
divyegala Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ void parse_build_param(const nlohmann::json& conf,
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
if (conf.contains("graph_build_algo")) {
if (conf.at("graph_build_algo") == "IVF_PQ") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
} else if (conf.at("graph_build_algo") == "NN_DESCENT") {
param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
}

template <typename T, typename IdxT>
Expand Down
73 changes: 65 additions & 8 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ namespace raft::neighbors::cagra {
*/

/**
* @brief Build a kNN graph.
* @brief Build a kNN graph using IVF-PQ.
*
* The kNN graph is the first building block for CAGRA index.
* This function uses the IVF-PQ method to build a kNN graph.
*
* The output is a dense matrix that stores the neighbor indices for each pont in the dataset.
* Each point has the same number of neighbors.
Expand All @@ -52,8 +51,8 @@ namespace raft::neighbors::cagra {
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params build_params;
* cagra::search_params search_params
* ivf_pq::index_params build_params;
* ivf_pq::search_params search_params
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
Expand Down Expand Up @@ -95,6 +94,55 @@ void build_knn_graph(raft::resources const& res,
res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
}

/**
* @brief Build a kNN graph using NN-descent.
*
* The kNN graph is the first building block for CAGRA index.
*
* The output is a dense matrix that stores the neighbor indices for each pont in the dataset.
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* Each point has the same number of neighbors.
*
* See [cagra::build](#cagra::build) for an alternative method.
*
* The following distance metrics are supported:
* - L2Expanded
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* using namespace raft::neighbors::experimental;
* // use default index parameters
* nn_descent::index_params build_params;
* build_params.graph_degree = 128;
* // create knn graph
* auto nn_descent_index = cagra::build_knn_graph(res, dataset, build_params);
* auto optimized_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* @endcode
*
* @tparam DataT data element type
* @tparam IdxT type of the dataset vector indices
* @tparam accessor host or device accessor_type for the dataset
* @param res raft::resources
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @param dataset raft::host/device_matrix_view
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @param build_params raft::neighbors::nn_descent::index_params
* @return raft::neighbors::nn_descent::index<IdxT>
*/
template <typename DataT,
typename IdxT = uint32_t,
typename accessor =
host_device_accessor<std::experimental::default_accessor<DataT>, memory_type::device>>
experimental::nn_descent::index<IdxT> build_knn_graph(
divyegala marked this conversation as resolved.
Show resolved Hide resolved
raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
std::optional<experimental::nn_descent::index_params> build_params = std::nullopt)
{
return detail::build_knn_graph<DataT, IdxT>(res, dataset, build_params);
}

/**
* @brief Sort a KNN graph index.
* Preprocessing step for `cagra::optimize`: If a KNN graph is not built using
Expand Down Expand Up @@ -256,13 +304,22 @@ index<T, IdxT> build(raft::resources const& res,
graph_degree = intermediate_degree;
}

auto knn_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree);
auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);

build_knn_graph(res, dataset, knn_graph.view());
if (params.build_algo == graph_build_algo::IVF_PQ) {
auto knn_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree);

auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);
build_knn_graph(res, dataset, knn_graph.view());

optimize<IdxT>(res, knn_graph.view(), cagra_graph.view());
optimize<IdxT>(res, knn_graph.view(), cagra_graph.view());
} else {
auto nn_descent_params = std::make_optional<experimental::nn_descent::index_params>();
nn_descent_params->graph_degree = intermediate_degree;
nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree;
auto nn_descent_index = build_knn_graph<T, IdxT>(res, dataset, nn_descent_params);

optimize<IdxT>(res, nn_descent_index.graph(), cagra_graph.view());
}

// Construct an index from dataset and optimized knn graph.
return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ namespace raft::neighbors::cagra {
* @{
*/

enum class graph_build_algo { IVF_PQ, NN_DESCENT };
divyegala marked this conversation as resolved.
Show resolved Hide resolved

struct index_params : ann::index_params {
/** Degree of input graph for pruning. */
size_t intermediate_graph_degree = 128;
/** Degree of output graph. */
size_t graph_degree = 64;
/** ANN algorithm to build knn graph. */
graph_build_algo build_algo = graph_build_algo::IVF_PQ;
};

enum class search_algo {
Expand Down Expand Up @@ -351,6 +355,7 @@ struct index : ann::index {

// TODO: Remove deprecated experimental namespace in 23.12 release
namespace raft::neighbors::experimental::cagra {
using raft::neighbors::cagra::graph_build_algo;
using raft::neighbors::cagra::hash_mode;
using raft::neighbors::cagra::index;
using raft::neighbors::cagra::index_params;
Expand Down
29 changes: 29 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <raft/neighbors/detail/refine.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/neighbors/refine.cuh>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -238,4 +239,32 @@ void build_knn_graph(raft::resources const& res,
if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph");
}

template <typename DataT, typename IdxT, typename accessor>
experimental::nn_descent::index<IdxT> build_knn_graph(
raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
std::optional<experimental::nn_descent::index_params> build_params = std::nullopt)
{
if (!build_params) {
build_params = std::make_optional<experimental::nn_descent::index_params>();
divyegala marked this conversation as resolved.
Show resolved Hide resolved
}

auto nn_descent_idx = experimental::nn_descent::build<DataT, IdxT>(res, *build_params, dataset);

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type;
using g_accessor_internal =
host_device_accessor<std::experimental::default_accessor<internal_IdxT>, g_accessor::mem_type>;

auto knn_graph_internal =
mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
reinterpret_cast<internal_IdxT*>(nn_descent_idx.graph().data_handle()),
nn_descent_idx.graph().extent(0),
nn_descent_idx.graph().extent(1));

graph::sort_knn_graph(res, dataset, knn_graph_internal);
divyegala marked this conversation as resolved.
Show resolved Hide resolved

return nn_descent_idx;
}

} // namespace raft::neighbors::cagra::detail
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void sort_knn_graph(raft::resources const& res,
const uint32_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

auto d_input_graph = raft::make_device_matrix<IdxT, IdxT>(res, graph_size, input_graph_degree);
auto d_input_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);

//
// Sorting kNN graph
Expand Down
Loading