diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index af3a4c124b..976d7003dd 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -21,6 +21,7 @@ cd ${package_dir} case "${RAPIDS_CUDA_VERSION}" in 12.*) EXCLUDE_ARGS=( + --exclude "libcuvs.so" --exclude "libcublas.so.12" --exclude "libcublasLt.so.12" --exclude "libcufft.so.11" @@ -32,12 +33,14 @@ case "${RAPIDS_CUDA_VERSION}" in EXTRA_CMAKE_ARGS=";-DUSE_CUDA_MATH_WHEELS=ON" ;; 11.*) - EXCLUDE_ARGS=() + EXCLUDE_ARGS=( + --exclude "libcuvs.so" + ) EXTRA_CMAKE_ARGS=";-DUSE_CUDA_MATH_WHEELS=OFF" ;; esac -SKBUILD_CMAKE_ARGS="-DDETECT_CONDA_ENV=OFF;-DDISABLE_DEPRECATION_WARNINGS=ON;-DCPM_cumlprims_mg_SOURCE=${GITHUB_WORKSPACE}/cumlprims_mg/${EXTRA_CMAKE_ARGS}" \ +SKBUILD_CMAKE_ARGS="-DDETECT_CONDA_ENV=OFF;-DDISABLE_DEPRECATION_WARNINGS=ON;-DCPM_cumlprims_mg_SOURCE=${GITHUB_WORKSPACE}/cumlprims_mg/;-DUSE_CUVS_WHEEL=ON${EXTRA_CMAKE_ARGS}" \ python -m pip wheel . \ -w dist \ -vvv \ diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index e234e1401f..c37ba94238 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -40,11 +40,13 @@ echo "${NEXT_FULL_TAG}" > VERSION DEPENDENCIES=( cudf cuml + cuvs dask-cuda dask-cudf libcuml libcuml-tests libcumlprims + libcuvs libraft-headers libraft librmm diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 406789fffc..7bef5fcb11 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -14,6 +14,7 @@ dependencies: - cudatoolkit - cudf==24.10.*,>=0.0.0a0 - cupy>=12.0.0 +- cuvs==24.10.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0 - dask-cuda==24.10.*,>=0.0.0a0 @@ -39,8 +40,8 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 -- libraft==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - nbsphinx - ninja diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 28c9197192..a016e3ccef 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -16,6 +16,7 @@ dependencies: - cuda-version=12.5 - cudf==24.10.*,>=0.0.0a0 - cupy>=12.0.0 +- cuvs==24.10.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0 - dask-cuda==24.10.*,>=0.0.0a0 @@ -36,8 +37,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 -- libraft==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - nbsphinx - ninja diff --git a/conda/environments/clang_tidy_cuda-118_arch-x86_64.yaml b/conda/environments/clang_tidy_cuda-118_arch-x86_64.yaml index f332c206d9..bec05e19f3 100644 --- a/conda/environments/clang_tidy_cuda-118_arch-x86_64.yaml +++ b/conda/environments/clang_tidy_cuda-118_arch-x86_64.yaml @@ -27,8 +27,8 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 -- libraft==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - ninja - nvcc_linux-64=11.8 diff --git a/conda/environments/cpp_all_cuda-118_arch-x86_64.yaml b/conda/environments/cpp_all_cuda-118_arch-x86_64.yaml index 66291a21ec..f70c53c16c 100644 --- a/conda/environments/cpp_all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/cpp_all_cuda-118_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 -- libraft==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - ninja - nvcc_linux-64=11.8 diff --git a/conda/environments/cpp_all_cuda-125_arch-x86_64.yaml b/conda/environments/cpp_all_cuda-125_arch-x86_64.yaml index 90bdefa75e..2210fe6e8b 100644 --- a/conda/environments/cpp_all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/cpp_all_cuda-125_arch-x86_64.yaml @@ -22,8 +22,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 -- libraft==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - ninja - spdlog>=1.14.1,<1.15 diff --git a/conda/recipes/libcuml/meta.yaml b/conda/recipes/libcuml/meta.yaml index ea1b935f01..f4a65c50f7 100644 --- a/conda/recipes/libcuml/meta.yaml +++ b/conda/recipes/libcuml/meta.yaml @@ -70,7 +70,7 @@ requirements: {% endif %} - fmt {{ fmt_version }} - libcumlprims ={{ minor_version }} - - libraft ={{ minor_version }} + - libcuvs ={{ minor_version }} - libraft-headers ={{ minor_version }} - librmm ={{ minor_version }} - spdlog {{ spdlog_version }} @@ -116,7 +116,7 @@ outputs: - libcusparse {% endif %} - libcumlprims ={{ minor_version }} - - libraft ={{ minor_version }} + - libcuvs ={{ minor_version }} - librmm ={{ minor_version }} - treelite {{ treelite_version }} about: diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7451bde50b..f400b244c8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -64,8 +64,7 @@ option(SINGLEGPU "Disable all mnmg components and comms libraries" OFF) option(USE_CCACHE "Cache build artifacts with ccache" OFF) option(CUDA_STATIC_RUNTIME "Statically link the CUDA runtime" OFF) option(CUDA_STATIC_MATH_LIBRARIES "Statically link the CUDA math libraries" OFF) -option(CUML_USE_RAFT_STATIC "Build and statically link the RAFT libraries" OFF) -option(CUML_RAFT_COMPILED "Use libraft shared library" ON) +option(CUML_USE_CUVS_STATIC "Build and statically link the CUVS library" OFF) option(CUML_USE_TREELITE_STATIC "Build and statically link the treelite library" OFF) option(CUML_EXPORT_TREELITE_LINKAGE "Whether to publicly or privately link treelite to libcuml++" OFF) option(CUML_USE_CUMLPRIMS_MG_STATIC "Build and statically link the cumlprims_mg library" OFF) @@ -78,6 +77,7 @@ option(CUML_EXCLUDE_RAFT_FROM_ALL "Exclude RAFT targets from cuML's 'all' target option(CUML_EXCLUDE_TREELITE_FROM_ALL "Exclude Treelite targets from cuML's 'all' target" OFF) option(CUML_EXCLUDE_CUMLPRIMS_MG_FROM_ALL "Exclude cumlprims_mg targets from cuML's 'all' target" OFF) option(CUML_RAFT_CLONE_ON_PIN "Explicitly clone RAFT branch when pinned to non-feature branch" ON) +option(CUML_CUVS_CLONE_ON_PIN "Explicitly clone CUVS branch when pinned to non-feature branch" ON) message(VERBOSE "CUML_CPP: Building libcuml_c shared library. Contains the cuML C API: ${BUILD_CUML_C_LIBRARY}") message(VERBOSE "CUML_CPP: Building libcuml shared library: ${BUILD_CUML_CPP_LIBRARY}") @@ -98,7 +98,7 @@ message(VERBOSE "CUML_CPP: Disabling all mnmg components and comms libraries: ${ message(VERBOSE "CUML_CPP: Cache build artifacts with ccache: ${USE_CCACHE}") message(VERBOSE "CUML_CPP: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTIME}") message(VERBOSE "CUML_CPP: Statically link the CUDA math libraries: ${CUDA_STATIC_MATH_LIBRARIES}") -message(VERBOSE "CUML_CPP: Build and statically link RAFT libraries: ${CUML_USE_RAFT_STATIC}") +message(VERBOSE "CUML_CPP: Build and statically link CUVS libraries: ${CUML_USE_CUVS_STATIC}") message(VERBOSE "CUML_CPP: Build and statically link Treelite library: ${CUML_USE_TREELITE_STATIC}") set(CUML_ALGORITHMS "ALL" CACHE STRING "Experimental: Choose which algorithms are built into libcuml++.so. Can specify individual algorithms or groups in a semicolon-separated list.") @@ -228,6 +228,7 @@ endif() include(cmake/thirdparty/get_cccl.cmake) include(cmake/thirdparty/get_rmm.cmake) include(cmake/thirdparty/get_raft.cmake) +include(cmake/thirdparty/get_cuvs.cmake) if(LINK_TREELITE) include(cmake/thirdparty/get_treelite.cmake) @@ -442,18 +443,6 @@ if(BUILD_CUML_CPP_LIBRARY) src/metrics/kl_divergence.cu src/metrics/mutual_info_score.cu src/metrics/pairwise_distance.cu - src/metrics/pairwise_distance_canberra.cu - src/metrics/pairwise_distance_chebyshev.cu - src/metrics/pairwise_distance_correlation.cu - src/metrics/pairwise_distance_cosine.cu - src/metrics/pairwise_distance_euclidean.cu - src/metrics/pairwise_distance_hamming.cu - src/metrics/pairwise_distance_hellinger.cu - src/metrics/pairwise_distance_jensen_shannon.cu - src/metrics/pairwise_distance_kl_divergence.cu - src/metrics/pairwise_distance_l1.cu - src/metrics/pairwise_distance_minkowski.cu - src/metrics/pairwise_distance_russell_rao.cu src/metrics/r2_score.cu src/metrics/rand_index.cu src/metrics/silhouette_score.cu @@ -635,7 +624,7 @@ if(BUILD_CUML_CPP_LIBRARY) ) target_link_libraries(${CUML_CPP_TARGET} - PUBLIC rmm::rmm + PUBLIC rmm::rmm ${CUVS_LIB} ${_cuml_cpp_public_libs} PRIVATE ${_cuml_cpp_private_libs} ) diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 4f8c312717..237d8e0c6e 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -50,7 +50,6 @@ if(BUILD_CUML_BENCH) benchmark::benchmark ${TREELITE_LIBS} raft::raft - raft::compiled ) target_include_directories(${CUML_CPP_BENCH_TARGET} diff --git a/cpp/bench/sg/kmeans.cu b/cpp/bench/sg/kmeans.cu index 37eeed6346..7fb44a20fb 100644 --- a/cpp/bench/sg/kmeans.cu +++ b/cpp/bench/sg/kmeans.cu @@ -92,7 +92,7 @@ std::vector getInputs() p.kmeans.max_iter = 300; p.kmeans.tol = 1e-4; p.kmeans.verbosity = RAFT_LEVEL_INFO; - p.kmeans.metric = raft::distance::DistanceType::L2Expanded; + p.kmeans.metric = cuvs::distance::DistanceType::L2Expanded; p.kmeans.rng_state = raft::random::RngState(p.blobs.seed); p.kmeans.inertia_check = true; std::vector> rowcols = { diff --git a/cpp/cmake/thirdparty/get_cuvs.cmake b/cpp/cmake/thirdparty/get_cuvs.cmake new file mode 100644 index 0000000000..1d319206ba --- /dev/null +++ b/cpp/cmake/thirdparty/get_cuvs.cmake @@ -0,0 +1,77 @@ +#============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================= + +set(CUML_MIN_VERSION_cuvs "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}.00") +set(CUML_BRANCH_VERSION_cuvs "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}") + +function(find_and_configure_cuvs) + set(oneValueArgs VERSION FORK PINNED_TAG EXCLUDE_FROM_ALL USE_CUVS_STATIC COMPILE_LIBRARY CLONE_ON_PIN) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_cuvs}") + message(STATUS "CUML: CUVS pinned tag found: ${PKG_PINNED_TAG}. Cloning cuvs locally.") + set(CPM_DOWNLOAD_cuvs ON) + elseif(PKG_USE_CUVS_STATIC AND (NOT CPM_cuvs_SOURCE)) + message(STATUS "CUML: Cloning cuvs locally to build static libraries.") + set(CPM_DOWNLOAD_cuvs ON) + else() + message(STATUS "Not cloning cuvs locally") + endif() + + if(PKG_USE_CUVS_STATIC) + set(CUVS_LIB cuvs::cuvs_static PARENT_SCOPE) + else() + set(CUVS_LIB cuvs::cuvs PARENT_SCOPE) + endif() + + rapids_cpm_find(cuvs ${PKG_VERSION} + GLOBAL_TARGETS cuvs::cuvs + BUILD_EXPORT_SET cuml-exports + INSTALL_EXPORT_SET cuml-exports + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/cuvs.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + ) + + if(cuvs_ADDED) + message(VERBOSE "CUML: Using CUVS located in ${cuvs_SOURCE_DIR}") + else() + message(VERBOSE "CUML: Using CUVS located in ${cuvs_DIR}") + endif() + + +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different CUVS locally, set the CMake variable +# CPM_cuvs_SOURCE=/path/to/local/cuvs +find_and_configure_cuvs(VERSION ${CUML_MIN_VERSION_cuvs} + FORK rapidsai + PINNED_TAG branch-${CUML_BRANCH_VERSION_cuvs} + EXCLUDE_FROM_ALL ${CUML_EXCLUDE_CUVS_FROM_ALL} + # When PINNED_TAG above doesn't match cuml, + # force local cuvs clone in build directory + # even if it's already installed. + CLONE_ON_PIN ${CUML_CUVS_CLONE_ON_PIN} + COMPILE_LIBRARY ${CUML_CUVS_COMPILED} + USE_CUVS_STATIC ${CUML_USE_CUVS_STATIC} + ) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 7bc860eed8..4f260fcb93 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -36,16 +36,6 @@ function(find_and_configure_raft) string(APPEND RAFT_COMPONENTS " distributed") endif() - if(PKG_COMPILE_LIBRARY) - if(NOT PKG_USE_RAFT_STATIC) - string(APPEND RAFT_COMPONENTS " compiled") - set(RAFT_COMPILED_LIB raft::compiled PARENT_SCOPE) - else() - string(APPEND RAFT_COMPONENTS " compiled_static") - set(RAFT_COMPILED_LIB raft::compiled_static PARENT_SCOPE) - endif() - endif() - # We need to set this each time so that on subsequent calls to cmake # the raft-config.cmake re-evaluates the RAFT_NVTX value set(RAFT_NVTX ${PKG_NVTX}) @@ -66,7 +56,7 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_BENCH OFF" "BUILD_CAGRA_HNSWLIB OFF" - "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + "RAFT_COMPILE_LIBRARY OFF" ) if(raft_ADDED) diff --git a/cpp/examples/kmeans/kmeans_example.cpp b/cpp/examples/kmeans/kmeans_example.cpp index 4e0a39e5bc..18433d8f16 100644 --- a/cpp/examples/kmeans/kmeans_example.cpp +++ b/cpp/examples/kmeans/kmeans_example.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) params.max_iter = 300; params.tol = 0.05; } - params.metric = raft::distance::DistanceType::L2SqrtExpanded; + params.metric = cuvs::distance::DistanceType::L2SqrtExpanded; params.init = ML::kmeans::KMeansParams::InitMethod::Random; // Inputs copied from kmeans_test.cu diff --git a/cpp/include/cuml/cluster/kmeans.hpp b/cpp/include/cuml/cluster/kmeans.hpp index b62059945e..f075e49843 100644 --- a/cpp/include/cuml/cluster/kmeans.hpp +++ b/cpp/include/cuml/cluster/kmeans.hpp @@ -18,7 +18,7 @@ #include -#include +#include namespace raft { class handle_t; @@ -28,7 +28,7 @@ namespace ML { namespace kmeans { -using KMeansParams = raft::cluster::KMeansParams; +using KMeansParams = cuvs::cluster::kmeans::params; /** * @brief Compute k-means clustering and predicts cluster index for each sample diff --git a/cpp/include/cuml/cluster/kmeans_mg.hpp b/cpp/include/cuml/cluster/kmeans_mg.hpp index 722534e248..368d3c828c 100644 --- a/cpp/include/cuml/cluster/kmeans_mg.hpp +++ b/cpp/include/cuml/cluster/kmeans_mg.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,7 +48,7 @@ namespace opg { * @param[out] n_iter Number of iterations run. */ -void fit(const raft::handle_t& handle, +void fit(const raft::resources& handle, const KMeansParams& params, const float* X, int n_samples, @@ -58,7 +58,7 @@ void fit(const raft::handle_t& handle, float& inertia, int& n_iter); -void fit(const raft::handle_t& handle, +void fit(const raft::resources& handle, const KMeansParams& params, const double* X, int n_samples, @@ -68,7 +68,7 @@ void fit(const raft::handle_t& handle, double& inertia, int& n_iter); -void fit(const raft::handle_t& handle, +void fit(const raft::resources& handle, const KMeansParams& params, const float* X, int64_t n_samples, @@ -78,7 +78,7 @@ void fit(const raft::handle_t& handle, float& inertia, int64_t& n_iter); -void fit(const raft::handle_t& handle, +void fit(const raft::resources& handle, const KMeansParams& params, const double* X, int64_t n_samples, diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index 7927ec63a6..43150cf976 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,11 @@ #pragma once #include -#include #include +#include // MetricProcessor + +#include +#include namespace raft { class handle_t; @@ -46,6 +49,8 @@ namespace ML { * default * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. + * @param[in] translations translation ids for indices when index rows represent + * non-contiguous partitions */ void brute_force_knn(const raft::handle_t& handle, std::vector& input, @@ -59,7 +64,8 @@ void brute_force_knn(const raft::handle_t& handle, bool rowMajorIndex = false, bool rowMajorQuery = false, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metric_arg = 2.0f); + float metric_arg = 2.0f, + std::vector* translations = nullptr); void rbc_build_index(const raft::handle_t& handle, raft::spatial::knn::BallCoverIndex& index); @@ -71,6 +77,36 @@ void rbc_knn_query(const raft::handle_t& handle, uint32_t n_search_items, int64_t* out_inds, float* out_dists); + +struct knnIndex { + raft::distance::DistanceType metric; + float metricArg; + int nprobe; + std::unique_ptr> metric_processor; + + std::unique_ptr> ivf_flat; + std::unique_ptr> ivf_pq; + + int device; +}; + +struct knnIndexParam { + virtual ~knnIndexParam() {} +}; + +struct IVFParam : knnIndexParam { + int nlist; + int nprobe; +}; + +struct IVFFlatParam : IVFParam {}; + +struct IVFPQParam : IVFParam { + int M; + int n_bits; + bool usePrecomputedTables; +}; + /** * @brief Flat C++ API function to build an approximate nearest neighbors index * from an index array and a set of parameters. @@ -85,8 +121,8 @@ void rbc_knn_query(const raft::handle_t& handle, * @param[in] D the dimensionality of the index array */ void approx_knn_build_index(raft::handle_t& handle, - raft::spatial::knn::knnIndex* index, - raft::spatial::knn::knnIndexParam* params, + knnIndex* index, + knnIndexParam* params, raft::distance::DistanceType metric, float metricArg, float* index_array, @@ -109,7 +145,7 @@ void approx_knn_build_index(raft::handle_t& handle, void approx_knn_search(raft::handle_t& handle, float* distances, int64_t* indices, - raft::spatial::knn::knnIndex* index, + knnIndex* index, int k, float* query_array, int n); diff --git a/cpp/src/hdbscan/detail/soft_clustering.cuh b/cpp/src/hdbscan/detail/soft_clustering.cuh index 19f6e40754..5370ad2dfb 100644 --- a/cpp/src/hdbscan/detail/soft_clustering.cuh +++ b/cpp/src/hdbscan/detail/soft_clustering.cuh @@ -24,7 +24,6 @@ #include #include -#include #include #include #include @@ -43,6 +42,8 @@ #include #include +#include + #include #include #include @@ -88,45 +89,14 @@ void dist_membership_vector(const raft::handle_t& handle, value_idx samples_per_batch = min((value_idx)batch_size, (value_idx)n_queries - batch_offset); rmm::device_uvector dist(samples_per_batch * n_exemplars, stream); - // compute the distances using raft API - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - raft::distance:: - distance( - handle, - query + batch_offset * n, - exemplars_dense.data(), - dist.data(), - samples_per_batch, - n_exemplars, - n, - true); - break; - case raft::distance::DistanceType::L1: - raft::distance::distance( - handle, - query + batch_offset * n, - exemplars_dense.data(), - dist.data(), - samples_per_batch, - n_exemplars, - n, - true); - break; - case raft::distance::DistanceType::CosineExpanded: - raft::distance:: - distance( - handle, - query + batch_offset * n, - exemplars_dense.data(), - dist.data(), - samples_per_batch, - n_exemplars, - n, - true); - break; - default: RAFT_EXPECTS(false, "Incorrect metric passed!"); - } + // compute the distances using the CUVS API + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view( + query + batch_offset * n, samples_per_batch, n), + raft::make_device_matrix_view(exemplars_dense.data(), n_exemplars, n), + raft::make_device_matrix_view(dist.data(), samples_per_batch, n_exemplars), + static_cast(metric)); // compute the minimum distances to exemplars of each cluster value_idx n_elements = samples_per_batch * n_selected_clusters; diff --git a/cpp/src/hdbscan/runner.h b/cpp/src/hdbscan/runner.h index c79148eed2..2f3e554a20 100644 --- a/cpp/src/hdbscan/runner.h +++ b/cpp/src/hdbscan/runner.h @@ -24,8 +24,8 @@ #include #include -#include -#include +#include // build_dendrogram_host +#include // build_sorted_mst #include #include #include @@ -40,6 +40,8 @@ #include #include +#include + namespace ML { namespace HDBSCAN { @@ -174,16 +176,15 @@ void build_linkage(const raft::handle_t& handle, raft::sparse::COO mutual_reachability_coo(stream, (params.min_samples + 1) * m * 2); - detail::Reachability::mutual_reachability_graph(handle, - X, - (size_t)m, - (size_t)n, - metric, - params.min_samples + 1, - params.alpha, - mutual_reachability_indptr.data(), - core_dists, - mutual_reachability_coo); + cuvs::neighbors::reachability::mutual_reachability_graph( + handle, + raft::make_device_matrix_view(X, m, n), + params.min_samples + 1, + raft::make_device_vector_view(mutual_reachability_indptr.data(), m + 1), + raft::make_device_vector_view(core_dists, m), + mutual_reachability_coo, + static_cast(metric), + params.alpha); /** * Construct MST sorted by weights diff --git a/cpp/src/kmeans/kmeans_fit_predict.cu b/cpp/src/kmeans/kmeans_fit_predict.cu index f982e6b26a..8e37ea4534 100644 --- a/cpp/src/kmeans/kmeans_fit_predict.cu +++ b/cpp/src/kmeans/kmeans_fit_predict.cu @@ -14,17 +14,17 @@ * limitations under the License. */ -#include -#include #include +#include + namespace ML { namespace kmeans { // -------------------------- fit_predict --------------------------------// template void fit_predict_impl(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const value_t* X, idx_t n_samples, idx_t n_features, @@ -45,12 +45,12 @@ void fit_predict_impl(const raft::handle_t& handle, auto inertia_view = raft::make_host_scalar_view(&inertia); auto n_iter_view = raft::make_host_scalar_view(&n_iter); - raft::cluster::kmeans_fit_predict( + cuvs::cluster::kmeans::fit_predict( handle, params, X_view, sw, centroids_opt, rLabels, inertia_view, n_iter_view); } void fit_predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* X, int n_samples, int n_features, @@ -65,7 +65,7 @@ void fit_predict(const raft::handle_t& handle, } void fit_predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* X, int n_samples, int n_features, @@ -80,7 +80,7 @@ void fit_predict(const raft::handle_t& handle, } void fit_predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* X, int64_t n_samples, int64_t n_features, @@ -95,7 +95,7 @@ void fit_predict(const raft::handle_t& handle, } void fit_predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* X, int64_t n_samples, int64_t n_features, diff --git a/cpp/src/kmeans/kmeans_mg.cu b/cpp/src/kmeans/kmeans_mg.cu index 0d135ebf99..c6804cbcaa 100644 --- a/cpp/src/kmeans/kmeans_mg.cu +++ b/cpp/src/kmeans/kmeans_mg.cu @@ -14,20 +14,16 @@ * limitations under the License. */ -#include "kmeans_mg_impl.cuh" - #include -#include - namespace ML { namespace kmeans { namespace opg { // ----------------------------- fit ---------------------------------// -void fit(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, +void fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, const float* X, int n_samples, int n_features, @@ -36,14 +32,23 @@ void fit(const raft::handle_t& handle, float& inertia, int& n_iter) { - const raft::handle_t& h = handle; + std::optional> sample_weight_view; + if (sample_weight != NULL) { + sample_weight_view = raft::make_device_vector_view(sample_weight, n_samples); + } - raft::stream_syncer _(h); - impl::fit(h, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::fit( + handle, + params, + raft::make_device_matrix_view(X, n_samples, n_features), + sample_weight_view, + raft::make_device_matrix_view(centroids, params.n_clusters, n_features), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); } -void fit(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, +void fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, const double* X, int n_samples, int n_features, @@ -52,13 +57,23 @@ void fit(const raft::handle_t& handle, double& inertia, int& n_iter) { - const raft::handle_t& h = handle; - raft::stream_syncer _(h); - impl::fit(h, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); + std::optional> sample_weight_view; + if (sample_weight != NULL) { + sample_weight_view = raft::make_device_vector_view(sample_weight, n_samples); + } + + cuvs::cluster::kmeans::fit( + handle, + params, + raft::make_device_matrix_view(X, n_samples, n_features), + sample_weight_view, + raft::make_device_matrix_view(centroids, params.n_clusters, n_features), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); } -void fit(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, +void fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, const float* X, int64_t n_samples, int64_t n_features, @@ -67,14 +82,24 @@ void fit(const raft::handle_t& handle, float& inertia, int64_t& n_iter) { - const raft::handle_t& h = handle; + std::optional> sample_weight_view; + if (sample_weight != NULL) { + sample_weight_view = + raft::make_device_vector_view(sample_weight, n_samples); + } - raft::stream_syncer _(h); - impl::fit(h, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::fit( + handle, + params, + raft::make_device_matrix_view(X, n_samples, n_features), + sample_weight_view, + raft::make_device_matrix_view(centroids, params.n_clusters, n_features), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); } -void fit(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, +void fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, const double* X, int64_t n_samples, int64_t n_features, @@ -83,11 +108,21 @@ void fit(const raft::handle_t& handle, double& inertia, int64_t& n_iter) { - const raft::handle_t& h = handle; - raft::stream_syncer _(h); - impl::fit(h, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); -} + std::optional> sample_weight_view; + if (sample_weight != NULL) { + sample_weight_view = + raft::make_device_vector_view(sample_weight, n_samples); + } + cuvs::cluster::kmeans::fit( + handle, + params, + raft::make_device_matrix_view(X, n_samples, n_features), + sample_weight_view, + raft::make_device_matrix_view(centroids, params.n_clusters, n_features), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); +} }; // end namespace opg }; // end namespace kmeans }; // end namespace ML diff --git a/cpp/src/kmeans/kmeans_mg_impl.cuh b/cpp/src/kmeans/kmeans_mg_impl.cuh deleted file mode 100644 index 6f08130632..0000000000 --- a/cpp/src/kmeans/kmeans_mg_impl.cuh +++ /dev/null @@ -1,818 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace ML { - -#define CUML_LOG_KMEANS(handle, fmt, ...) \ - do { \ - bool isRoot = true; \ - if (handle.comms_initialized()) { \ - const auto& comm = handle.get_comms(); \ - const int my_rank = comm.get_rank(); \ - isRoot = my_rank == 0; \ - } \ - if (isRoot) { CUML_LOG_DEBUG(fmt, ##__VA_ARGS__); } \ - } while (0) - -namespace kmeans { -namespace opg { -namespace impl { - -#define KMEANS_COMM_ROOT 0 - -static raft::cluster::kmeans::KMeansParams default_params; - -// Selects 'n_clusters' samples randomly from X -template -void initRandom(const raft::handle_t& handle, - const raft::cluster::kmeans::KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids) -{ - const auto& comm = handle.get_comms(); - cudaStream_t stream = handle.get_stream(); - auto n_local_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - - const int my_rank = comm.get_rank(); - const int n_ranks = comm.get_size(); - - std::vector nCentroidsSampledByRank(n_ranks, 0); - std::vector nCentroidsElementsToReceiveFromRank(n_ranks, 0); - - const int nranks_reqd = std::min(n_ranks, n_clusters); - ASSERT(KMEANS_COMM_ROOT < nranks_reqd, "KMEANS_COMM_ROOT must be in [0, %d)\n", nranks_reqd); - - for (int rank = 0; rank < nranks_reqd; ++rank) { - int nCentroidsSampledInRank = n_clusters / nranks_reqd; - if (rank == KMEANS_COMM_ROOT) { - nCentroidsSampledInRank += n_clusters - nCentroidsSampledInRank * nranks_reqd; - } - nCentroidsSampledByRank[rank] = nCentroidsSampledInRank; - nCentroidsElementsToReceiveFromRank[rank] = nCentroidsSampledInRank * n_features; - } - - auto nCentroidsSampledInRank = nCentroidsSampledByRank[my_rank]; - ASSERT((IndexT)nCentroidsSampledInRank <= (IndexT)n_local_samples, - "# random samples requested from rank-%d is larger than the available " - "samples at the rank (requested is %lu, available is %lu)", - my_rank, - (size_t)nCentroidsSampledInRank, - (size_t)n_local_samples); - - auto centroidsSampledInRank = - raft::make_device_matrix(handle, nCentroidsSampledInRank, n_features); - - raft::cluster::kmeans::shuffle_and_gather( - handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed); - - std::vector displs(n_ranks); - thrust::exclusive_scan(thrust::host, - nCentroidsElementsToReceiveFromRank.begin(), - nCentroidsElementsToReceiveFromRank.end(), - displs.begin()); - - // gather centroids from all ranks - comm.allgatherv(centroidsSampledInRank.data_handle(), // sendbuff - centroids.data_handle(), // recvbuff - nCentroidsElementsToReceiveFromRank.data(), // recvcount - displs.data(), - stream); -} - -/* - * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: psi = phi_X (C) - * 3: for O( log(psi) ) times do - * 4: C' = sample each point x in X independently with probability - * p_x = l * ( d^2(x, C) / phi_X (C) ) - * 5: C = C U C' - * 6: end for - * 7: For x in C, set w_x to be the number of points in X closer to x than any - * other point in C - * 8: Recluster the weighted points in C into k clusters - */ -template -void initKMeansPlusPlus(const raft::handle_t& handle, - const raft::cluster::kmeans::KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - const auto& comm = handle.get_comms(); - cudaStream_t stream = handle.get_stream(); - const int my_rank = comm.get_rank(); - const int n_rank = comm.get_size(); - - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - raft::random::RngState rng(params.rng_state.seed, raft::random::GeneratorType::GenPhilox); - - // <<<< Step-1 >>> : C <- sample a point uniformly at random from X - // 1.1 - Select a rank r' at random from the available n_rank ranks with a - // probability of 1/n_rank [Note - with same seed all rank selects - // the same r' which avoids a call to comm] - // 1.2 - Rank r' samples a point uniformly at random from the local dataset - // X which will be used as the initial centroid for kmeans++ - // 1.3 - Communicate the initial centroid chosen by rank-r' to all other - // ranks - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_rank - 1); - int rp = dis(gen); - - // buffer to flag the sample that is chosen as initial centroids - std::vector h_isSampleCentroid(n_samples); - std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - - auto initialCentroid = raft::make_device_matrix(handle, 1, n_features); - CUML_LOG_KMEANS( - handle, "@Rank-%d : KMeans|| : initial centroid is sampled at rank-%d\n", my_rank, rp); - - // 1.2 - Rank r' samples a point uniformly at random from the local dataset - // X which will be used as the initial centroid for kmeans++ - if (my_rank == rp) { - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - int cIdx = dis(gen); - auto centroidsView = raft::make_device_matrix_view( - X.data_handle() + cIdx * n_features, 1, n_features); - - raft::copy( - initialCentroid.data_handle(), centroidsView.data_handle(), centroidsView.size(), stream); - - h_isSampleCentroid[cIdx] = 1; - } - - // 1.3 - Communicate the initial centroid chosen by rank-r' to all other ranks - comm.bcast(initialCentroid.data_handle(), initialCentroid.size(), rp, stream); - - // device buffer to flag the sample that is chosen as initial centroid - auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - - raft::copy( - isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); - - rmm::device_uvector centroidsBuf(0, stream); - - // reset buffer to store the chosen centroid - centroidsBuf.resize(initialCentroid.size(), stream); - raft::copy(centroidsBuf.begin(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - auto potentialCentroids = raft::make_device_matrix_view( - centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - auto minClusterDistance = raft::make_device_vector(handle, n_samples); - auto uniformRands = raft::make_device_vector(handle, n_samples); - - // <<< Step-2 >>>: psi <- phi_X (C) - auto clusterCost = raft::make_device_scalar(handle, 0); - - raft::cluster::kmeans::min_cluster_distance(handle, - X, - potentialCentroids, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // compute partial cluster cost from the samples in rank - raft::cluster::kmeans::cluster_cost( - handle, - minClusterDistance.view(), - workspace, - clusterCost.view(), - cuda::proclaim_return_type( - [] __device__(const DataT& a, const DataT& b) { return a + b; })); - - // compute total cluster cost by accumulating the partial cost from all the - // ranks - comm.allreduce( - clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); - - DataT psi = 0; - raft::copy(&psi, clusterCost.data_handle(), 1, stream); - - // <<< End of Step-2 >>> - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result from " - "a failed rank"); - - // Scalable kmeans++ paper claims 8 rounds is sufficient - int niter = std::min(8, (int)ceil(log(psi))); - CUML_LOG_KMEANS(handle, - "@Rank-%d:KMeans|| :phi - %f, max # of iterations for kmeans++ loop - " - "%d\n", - my_rank, - psi, - niter); - - // <<<< Step-3 >>> : for O( log(psi) ) times do - for (int iter = 0; iter < niter; ++iter) { - CUML_LOG_KMEANS(handle, - "@Rank-%d:KMeans|| - Iteration %d: # potential centroids sampled - " - "%d\n", - my_rank, - iter, - potentialCentroids.extent(0)); - - raft::cluster::kmeans::min_cluster_distance(handle, - X, - potentialCentroids, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - raft::cluster::kmeans::cluster_cost( - handle, - minClusterDistance.view(), - workspace, - clusterCost.view(), - cuda::proclaim_return_type( - [] __device__(const DataT& a, const DataT& b) { return a + b; })); - comm.allreduce( - clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); - raft::copy(&psi, clusterCost.data_handle(), 1, stream); - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - - // <<<< Step-4 >>> : Sample each point x in X independently and identify new - // potentialCentroids - raft::random::uniform( - handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); - raft::cluster::kmeans::SamplingOp select_op(psi, - params.oversampling_factor, - n_clusters, - uniformRands.data_handle(), - isSampleCentroid.data_handle()); - - rmm::device_uvector inRankCp(0, stream); - raft::cluster::kmeans::sample_centroids(handle, - X, - minClusterDistance.view(), - isSampleCentroid.view(), - select_op, - inRankCp, - workspace); - /// <<<< End of Step-4 >>>> - - int* nPtsSampledByRank; - RAFT_CUDA_TRY(cudaMallocHost(&nPtsSampledByRank, n_rank * sizeof(int))); - - /// <<<< Step-5 >>> : C = C U C' - // append the data in Cp from all ranks to the buffer holding the - // potentialCentroids - // RAFT_CUDA_TRY(cudaMemsetAsync(nPtsSampledByRank, 0, n_rank * sizeof(int), stream)); - std::fill(nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); - nPtsSampledByRank[my_rank] = inRankCp.size() / n_features; - comm.allgather(&(nPtsSampledByRank[my_rank]), nPtsSampledByRank, 1, stream); - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - - auto nPtsSampled = - thrust::reduce(thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); - - // gather centroids from all ranks - std::vector sizes(n_rank); - thrust::transform( - thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, sizes.begin(), [&](int val) { - return val * n_features; - }); - - RAFT_CUDA_TRY_NO_THROW(cudaFreeHost(nPtsSampledByRank)); - - std::vector displs(n_rank); - thrust::exclusive_scan(thrust::host, sizes.begin(), sizes.end(), displs.begin()); - - centroidsBuf.resize(centroidsBuf.size() + nPtsSampled * n_features, stream); - comm.allgatherv(inRankCp.data(), - centroidsBuf.end() - nPtsSampled * n_features, - sizes.data(), - displs.data(), - stream); - - auto tot_centroids = potentialCentroids.extent(0) + nPtsSampled; - potentialCentroids = - raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); - /// <<<< End of Step-5 >>> - } /// <<<< Step-6 >>> - - CUML_LOG_KMEANS(handle, - "@Rank-%d:KMeans||: # potential centroids sampled - %d\n", - my_rank, - potentialCentroids.extent(0)); - - if ((IndexT)potentialCentroids.extent(0) > (IndexT)n_clusters) { - // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X - // temporary buffer to store the sample count per cluster, destructor - // releases the resource - - auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); - - raft::cluster::kmeans::count_samples_in_cluster( - handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); - - // merge the local histogram from all ranks - comm.allreduce(weight.data_handle(), // sendbuff - weight.data_handle(), // recvbuff - weight.size(), // count - raft::comms::op_t::SUM, - stream); - - // <<< end of Step-7 >>> - - // Step-8: Recluster the weighted points in C into k clusters - // Note - reclustering step is duplicated across all ranks and with the same - // seed they should generate the same potentialCentroids - auto const_centroids = raft::make_device_matrix_view( - potentialCentroids.data_handle(), potentialCentroids.extent(0), potentialCentroids.extent(1)); - raft::cluster::kmeans::init_plus_plus( - handle, params, const_centroids, centroidsRawData, workspace); - - auto inertia = raft::make_host_scalar(0); - auto n_iter = raft::make_host_scalar(0); - auto weight_view = - raft::make_device_vector_view(weight.data_handle(), weight.extent(0)); - raft::cluster::kmeans::KMeansParams params_copy = params; - params_copy.rng_state = default_params.rng_state; - - raft::cluster::kmeans::fit_main(handle, - params_copy, - const_centroids, - weight_view, - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); - - } else if ((IndexT)potentialCentroids.extent(0) < (IndexT)n_clusters) { - // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.extent(0); - CUML_LOG_KMEANS(handle, - "[Warning!] KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)\n", - n_clusters, - potentialCentroids.extent(0), - n_random_clusters); - - // generate `n_random_clusters` centroids - raft::cluster::kmeans::KMeansParams rand_params = params; - rand_params.rng_state = default_params.rng_state; - rand_params.init = raft::cluster::kmeans::KMeansParams::InitMethod::Random; - rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); - - // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - - } else { - // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } -} - -template -void checkWeights(const raft::handle_t& handle, - rmm::device_uvector& workspace, - raft::device_vector_view weight) -{ - cudaStream_t stream = handle.get_stream(); - rmm::device_scalar wt_aggr(stream); - - const auto& comm = handle.get_comms(); - - auto n_samples = weight.extent(0); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); - - comm.allreduce(wt_aggr.data(), // sendbuff - wt_aggr.data(), // recvbuff - 1, // count - raft::comms::op_t::SUM, - stream); - DataT wt_sum = wt_aggr.value(stream); - handle.sync_stream(stream); - - if (wt_sum != n_samples) { - CUML_LOG_KMEANS(handle, - "[Warning!] KMeans: normalizing the user provided sample weights to " - "sum up to %d samples", - n_samples); - - DataT scale = n_samples / wt_sum; - raft::linalg::unaryOp( - weight.data_handle(), - weight.data_handle(), - weight.size(), - cuda::proclaim_return_type([=] __device__(const DataT& wt) { return wt * scale; }), - stream); - } -} - -template -void fit(const raft::handle_t& handle, - const raft::cluster::kmeans::KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - const auto& comm = handle.get_comms(); - cudaStream_t stream = handle.get_stream(); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store the weights per cluster, destructor releases - // the resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - CUML_LOG_KMEANS(handle, - "KMeans.fit: Iteration-%d: fitting the model using the initialize " - "cluster centers\n", - n_iter[0]); - - auto const_centroids = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - raft::cluster::kmeans::min_cluster_and_distance(handle, - X, - const_centroids, - minClusterAndDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of - // cub::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - raft::cluster::kmeans::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - workspace.resize(n_samples, stream); - - // Calculates weighted sum of all the samples assigned to cluster-i and - // store the result in newCentroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - itr, - weight.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - static_cast(n_clusters), - newCentroids.data_handle(), - stream); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(weight.data_handle(), - itr, - wtInCluster.data_handle(), - (IndexT)1, - (IndexT)weight.extent(0), - (IndexT)n_clusters, - stream); - - // merge the local histogram from all ranks - comm.allreduce(wtInCluster.data_handle(), // sendbuff - wtInCluster.data_handle(), // recvbuff - wtInCluster.size(), // count - raft::comms::op_t::SUM, - stream); - - // reduces newCentroids from all ranks - comm.allreduce(newCentroids.data_handle(), // sendbuff - newCentroids.data_handle(), // recvbuff - newCentroids.size(), // count - raft::comms::op_t::SUM, - stream); - - // Computes newCentroids[i] = newCentroids[i]/wtInCluster[i] where - // newCentroids[n_clusters x n_features] - 2D array, newCentroids[i] has - // sum of all the samples assigned to cluster-i - // wtInCluster[n_clusters] - 1D array, wtInCluster[i] contains # of - // samples in cluster-i. - // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 - - raft::linalg::matrixVectorOp( - newCentroids.data_handle(), - newCentroids.data_handle(), - wtInCluster.data_handle(), - newCentroids.extent(1), - newCentroids.extent(0), - true, - false, - cuda::proclaim_return_type([=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - }), - stream); - - // copy the centroids[i] to newCentroids[i] when wtInCluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(wtInCluster.data_handle()); - raft::matrix::gather_if( - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - itr_wt, - itr_wt, - wtInCluster.extent(0), - newCentroids.data_handle(), - cuda::proclaim_return_type( - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the # of samples in the cluster is 0 - if (map.value == 0) - return true; - else - return false; - }), - cuda::proclaim_return_type( - [=] __device__(raft::KeyValuePair map) { // map - return map.key; - }), - stream); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, 1); - raft::linalg::mapThenSumReduce( - sqrdNorm.data_handle(), - newCentroids.size(), - cuda::proclaim_return_type([=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }), - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); - - raft::copy(centroids.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); - - bool done = false; - if (params.inertia_check) { - rmm::device_scalar> clusterCostD(stream); - - // calculate cluster cost phi_x(C) - raft::cluster::kmeans::cluster_cost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); - - // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); - - DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers\n"); - - if (n_iter[0] > 0) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - handle.sync_stream(stream); - if (sqrdNormError < params.tol) done = true; - - if (done) { - CUML_LOG_KMEANS( - handle, "Threshold triggered after %d iterations. Terminating early.\n", n_iter[0]); - break; - } - } -} - -template -void fit(const raft::handle_t& handle, - const raft::cluster::kmeans::KMeansParams& params, - const DataT* X, - const IndexT n_local_samples, - const IndexT n_features, - const DataT* sample_weight, - DataT* centroids, - DataT& inertia, - IndexT& n_iter) -{ - cudaStream_t stream = handle.get_stream(); - - ASSERT(n_local_samples > 0, "# of samples must be > 0"); - ASSERT(params.oversampling_factor > 0, - "oversampling factor must be > 0 (requested %d)", - (int)params.oversampling_factor); - ASSERT(is_device_or_managed_type(X), "input data must be device accessible"); - - auto n_clusters = params.n_clusters; - auto data = raft::make_device_matrix_view(X, n_local_samples, n_features); - auto weight = raft::make_device_vector(handle, n_local_samples); - if (sample_weight != nullptr) { - raft::copy(weight.data_handle(), sample_weight, n_local_samples, stream); - } else { - thrust::fill( - handle.get_thrust_policy(), weight.data_handle(), weight.data_handle() + weight.size(), 1); - } - - // underlying expandable storage that holds centroids data - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); - - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - - // check if weights sum up to n_samples - checkWeights(handle, workspace, weight.view()); - - if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::Random) { - // initializing with random samples from input dataset - CUML_LOG_KMEANS(handle, - "KMeans.fit: initialize cluster centers by randomly choosing from the " - "input data.\n"); - initRandom(handle, params, data, centroidsRawData.view()); - } else if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - CUML_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); - initKMeansPlusPlus(handle, params, data, centroidsRawData.view(), workspace); - } else if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::Array) { - CUML_LOG_KMEANS(handle, - "KMeans.fit: initialize cluster centers from the ndarray array input " - "passed to init argument.\n"); - - ASSERT(centroids != nullptr, - "centroids array is null (require a valid array of centroids for " - "the requested initialization method)"); - - raft::copy(centroidsRawData.data_handle(), centroids, params.n_clusters * n_features, stream); - } else { - THROW("unknown initialization method to select initial centers"); - } - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - fit(handle, - params, - data, - weight.view(), - centroidsRawData.view(), - inertiaView, - n_iterView, - workspace); - - raft::copy(centroids, centroidsRawData.data_handle(), params.n_clusters * n_features, stream); - - CUML_LOG_KMEANS(handle, - "KMeans.fit: async call returned (fit could still be running on the " - "device)\n"); -} - -}; // end namespace impl -}; // end namespace opg -}; // end namespace kmeans -}; // end namespace ML diff --git a/cpp/src/kmeans/kmeans_predict.cu b/cpp/src/kmeans/kmeans_predict.cu index c0e79757f5..b65af8c21a 100644 --- a/cpp/src/kmeans/kmeans_predict.cu +++ b/cpp/src/kmeans/kmeans_predict.cu @@ -14,10 +14,10 @@ * limitations under the License. */ -#include -#include #include +#include + namespace ML { namespace kmeans { @@ -25,7 +25,7 @@ namespace kmeans { template void predict_impl(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const value_t* centroids, const value_t* X, idx_t n_samples, @@ -45,12 +45,12 @@ void predict_impl(const raft::handle_t& handle, auto rLabels = raft::make_device_vector_view(labels, n_samples); auto inertia_view = raft::make_host_scalar_view(&inertia); - raft::cluster::kmeans_predict( + cuvs::cluster::kmeans::predict( handle, params, X_view, sw, centroids_view, rLabels, normalize_weights, inertia_view); } void predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* centroids, const float* X, int n_samples, @@ -73,7 +73,7 @@ void predict(const raft::handle_t& handle, } void predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* centroids, const double* X, int n_samples, @@ -96,7 +96,7 @@ void predict(const raft::handle_t& handle, } void predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* centroids, const float* X, int64_t n_samples, @@ -119,7 +119,7 @@ void predict(const raft::handle_t& handle, } void predict(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* centroids, const double* X, int64_t n_samples, diff --git a/cpp/src/kmeans/kmeans_transform.cu b/cpp/src/kmeans/kmeans_transform.cu index 2cef14f787..93a9851cf9 100644 --- a/cpp/src/kmeans/kmeans_transform.cu +++ b/cpp/src/kmeans/kmeans_transform.cu @@ -14,17 +14,17 @@ * limitations under the License. */ -#include -#include #include +#include + namespace ML { namespace kmeans { // ----------------------------- transform ---------------------------------// template void transform_impl(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const value_t* centroids, const value_t* X, idx_t n_samples, @@ -36,11 +36,11 @@ void transform_impl(const raft::handle_t& handle, raft::make_device_matrix_view(centroids, params.n_clusters, n_features); auto rX_new = raft::make_device_matrix_view(X_new, n_samples, n_features); - raft::cluster::kmeans::transform(handle, params, X_view, centroids_view, rX_new); + cuvs::cluster::kmeans::transform(handle, params, X_view, centroids_view, rX_new); } void transform(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* centroids, const float* X, int n_samples, @@ -51,7 +51,7 @@ void transform(const raft::handle_t& handle, } void transform(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* centroids, const double* X, int n_samples, @@ -62,7 +62,7 @@ void transform(const raft::handle_t& handle, } void transform(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const float* centroids, const float* X, int64_t n_samples, @@ -73,7 +73,7 @@ void transform(const raft::handle_t& handle, } void transform(const raft::handle_t& handle, - const raft::cluster::KMeansParams& params, + const cuvs::cluster::kmeans::params& params, const double* centroids, const double* X, int64_t n_samples, diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index 3d5329abc6..c08cd47de2 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -28,6 +29,7 @@ #include +#include #include #include @@ -36,7 +38,6 @@ #include namespace ML { - void brute_force_knn(const raft::handle_t& handle, std::vector& input, std::vector& sizes, @@ -49,24 +50,108 @@ void brute_force_knn(const raft::handle_t& handle, bool rowMajorIndex, bool rowMajorQuery, raft::distance::DistanceType metric, - float metric_arg) + float metric_arg, + std::vector* translations) { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); - raft::spatial::knn::brute_force_knn(handle, - input, - sizes, - D, - search_items, - n, - res_I, - res_D, - k, - rowMajorIndex, - rowMajorQuery, - nullptr, - metric, - metric_arg); + // The cuvs api doesn't support having multiple input values to search against. + auto userStream = raft::resource::get_cuda_stream(handle); + + ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); + + std::vector* id_ranges; + if (translations == nullptr) { + // If we don't have explicit translations + // for offsets of the indices, build them + // from the local partitions + id_ranges = new std::vector(); + int64_t total_n = 0; + for (size_t i = 0; i < input.size(); i++) { + id_ranges->push_back(total_n); + total_n += sizes[i]; + } + } else { + // otherwise, use the given translations + id_ranges = translations; + } + + rmm::device_uvector trans(id_ranges->size(), userStream); + raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); + + rmm::device_uvector all_D(0, userStream); + rmm::device_uvector all_I(0, userStream); + + float* out_D = res_D; + int64_t* out_I = res_I; + + if (input.size() > 1) { + all_D.resize(input.size() * k * n, userStream); + all_I.resize(input.size() * k * n, userStream); + + out_D = all_D.data(); + out_I = all_I.data(); + } + + // Make other streams from pool wait on main stream + raft::resource::wait_stream_pool_on_stream(handle); + + for (size_t i = 0; i < input.size(); i++) { + float* out_d_ptr = out_D + (i * k * n); + int64_t* out_i_ptr = out_I + (i * k * n); + + auto stream = raft::resource::get_next_usable_stream(handle, i); + auto current_handle = raft::device_resources(stream); + + // build the brute_force index (precalculates norms etc) + std::optional> idx; + if (rowMajorIndex) { + idx = cuvs::neighbors::brute_force::build( + current_handle, + raft::make_device_matrix_view(input[i], sizes[i], D), + static_cast(metric), + metric_arg); + + } else { + idx = cuvs::neighbors::brute_force::build( + current_handle, + raft::make_device_matrix_view(input[i], sizes[i], D), + static_cast(metric), + metric_arg); + } + + // query the index + if (rowMajorQuery) { + cuvs::neighbors::brute_force::search( + current_handle, + *idx, + raft::make_device_matrix_view(search_items, n, D), + raft::make_device_matrix_view(out_i_ptr, n, k), + raft::make_device_matrix_view(out_d_ptr, n, k)); + } else { + cuvs::neighbors::brute_force::search( + current_handle, + *idx, + raft::make_device_matrix_view(search_items, n, D), + raft::make_device_matrix_view(out_i_ptr, n, k), + raft::make_device_matrix_view(out_d_ptr, n, k)); + } + } + + // Sync internal streams if used. We don't need to + // sync the user stream because we'll already have + // fully serial execution. + raft::resource::sync_stream_pool(handle); + + if (input.size() > 1 || translations != nullptr) { + // This is necessary for proper index translations. If there are + // no translations or partitions to combine, it can be skipped. + // TODO: sort out where this knn_merge_parts should live + raft::spatial::knn::knn_merge_parts( + out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); + } + + if (translations == nullptr) delete id_ranges; } void rbc_build_index(const raft::handle_t& handle, @@ -83,32 +168,119 @@ void rbc_knn_query(const raft::handle_t& handle, int64_t* out_inds, float* out_dists) { + // TODO: we're using this from raft in header only mode, decide if we should split out to a + // separate instantiation here raft::spatial::knn::rbc_knn_query( handle, index, k, search_items, n_search_items, out_inds, out_dists); } void approx_knn_build_index(raft::handle_t& handle, - raft::spatial::knn::knnIndex* index, - raft::spatial::knn::knnIndexParam* params, + knnIndex* index, + knnIndexParam* params, raft::distance::DistanceType metric, float metricArg, float* index_array, int n, int D) { - raft::spatial::knn::approx_knn_build_index( - handle, index, params, metric, metricArg, index_array, n, D); + index->metric = metric; + index->metricArg = metricArg; + + auto ivf_ft_pams = dynamic_cast(params); + auto ivf_pq_pams = dynamic_cast(params); + + index->metric_processor = raft::spatial::knn::create_processor( + metric, n, D, 0, false, raft::resource::get_cuda_stream(handle)); + // For cosine/correlation distance, the metric processor translates distance + // to inner product via pre/post processing - pass the translated metric to + // ANN index + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + metric = index->metric = raft::distance::DistanceType::InnerProduct; + } + index->metric_processor->preprocess(index_array); + auto index_view = raft::make_device_matrix_view(index_array, n, D); + + if (ivf_ft_pams) { + index->nprobe = ivf_ft_pams->nprobe; + cuvs::neighbors::ivf_flat::index_params params; + params.metric = static_cast(metric); + params.metric_arg = metricArg; + params.n_lists = ivf_ft_pams->nlist; + + index->ivf_flat = std::make_unique>( + cuvs::neighbors::ivf_flat::build(handle, params, index_view)); + } else if (ivf_pq_pams) { + index->nprobe = ivf_pq_pams->nprobe; + cuvs::neighbors::ivf_pq::index_params params; + params.metric = static_cast(metric); + params.metric_arg = metricArg; + params.n_lists = ivf_pq_pams->nlist; + params.pq_bits = ivf_pq_pams->n_bits; + params.pq_dim = ivf_pq_pams->M; + // TODO: handle ivf_pq_pams.usePrecomputedTables ? + + index->ivf_pq = std::make_unique>( + cuvs::neighbors::ivf_pq::build(handle, params, index_view)); + } else { + RAFT_FAIL("Unrecognized index type."); + } + + index->metric_processor->revert(index_array); } void approx_knn_search(raft::handle_t& handle, float* distances, int64_t* indices, - raft::spatial::knn::knnIndex* index, + knnIndex* index, int k, float* query_array, int n) { - raft::spatial::knn::approx_knn_search(handle, distances, indices, index, k, query_array, n); + index->metric_processor->preprocess(query_array); + index->metric_processor->set_num_queries(k); + + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); + + if (index->ivf_flat) { + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_flat->dim()); + cuvs::neighbors::ivf_flat::search_params params; + params.n_probes = index->nprobe; + + cuvs::neighbors::ivf_flat::search( + handle, params, *index->ivf_flat, query_view, indices_view, distances_view); + } else if (index->ivf_pq) { + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + cuvs::neighbors::ivf_pq::search_params params; + params.n_probes = index->nprobe; + + cuvs::neighbors::ivf_pq::search( + handle, params, *index->ivf_pq, query_view, indices_view, distances_view); + } else { + RAFT_FAIL("The model is not trained"); + } + + index->metric_processor->revert(query_array); + + // perform post-processing to show the real distances + if (index->metric == raft::distance::DistanceType::L2SqrtExpanded || + index->metric == raft::distance::DistanceType::L2SqrtUnexpanded || + index->metric == raft::distance::DistanceType::LpUnexpanded) { + /** + * post-processing + */ + float p = 0.5; // standard l2 + if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; + raft::linalg::unaryOp(distances, + distances, + n * k, + raft::pow_const_op(p), + raft::resource::get_cuda_stream(handle)); + } + index->metric_processor->postprocess(distances); } void knn_classify(raft::handle_t& handle, diff --git a/cpp/src/knn/knn_opg_common.cuh b/cpp/src/knn/knn_opg_common.cuh index 188244d643..bcf3fe81ae 100644 --- a/cpp/src/knn/knn_opg_common.cuh +++ b/cpp/src/knn/knn_opg_common.cuh @@ -426,7 +426,7 @@ void perform_local_knn(opg_knn_param& params, size_t query_size) { std::vector ptrs(params.idx_data->size()); - std::vector sizes(params.idx_data->size()); + std::vector sizes(params.idx_data->size()); for (std::size_t cur_idx = 0; cur_idx < params.idx_data->size(); cur_idx++) { ptrs[cur_idx] = params.idx_data->at(cur_idx)->ptr; @@ -443,20 +443,20 @@ void perform_local_knn(opg_knn_param& params, // ID ranges need to be offset by each local partition's // starting indices. - raft::spatial::knn::brute_force_knn( - handle, - ptrs, - sizes, - params.idx_desc->N, - query, - query_size, - work.res_I.data(), - work.res_D.data(), - params.k, - params.rowMajorIndex, - params.rowMajorQuery, - &start_indices_long, - raft::distance::DistanceType::L2SqrtExpanded); + brute_force_knn(handle, + ptrs, + sizes, + params.idx_desc->N, + query, + query_size, + work.res_I.data(), + work.res_D.data(), + params.k, + params.rowMajorIndex, + params.rowMajorQuery, + raft::distance::DistanceType::L2SqrtExpanded, + 2.0f, + &start_indices_long); handle.sync_stream(handle.get_stream()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 4cb9fb60d1..1f56ae8bf9 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -15,25 +15,14 @@ * limitations under the License. */ -#include "pairwise_distance_canberra.cuh" -#include "pairwise_distance_chebyshev.cuh" -#include "pairwise_distance_correlation.cuh" -#include "pairwise_distance_cosine.cuh" -#include "pairwise_distance_euclidean.cuh" -#include "pairwise_distance_hamming.cuh" -#include "pairwise_distance_hellinger.cuh" -#include "pairwise_distance_jensen_shannon.cuh" -#include "pairwise_distance_kl_divergence.cuh" -#include "pairwise_distance_l1.cuh" -#include "pairwise_distance_minkowski.cuh" -#include "pairwise_distance_russell_rao.cuh" - #include #include #include #include +#include + namespace ML { namespace Metrics { @@ -48,48 +37,23 @@ void pairwise_distance(const raft::handle_t& handle, bool isRowMajor, double metric_arg) { - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::JensenShannon: - pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::KLDivergence: - pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; + if (isRowMajor) { + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_matrix_view(y, n, k), + raft::make_device_matrix_view(dist, m, n), + static_cast(metric), + metric_arg); + } else { + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_matrix_view(y, n, k), + raft::make_device_matrix_view(dist, m, n), + static_cast(metric), + metric_arg); + } } void pairwise_distance(const raft::handle_t& handle, @@ -103,48 +67,23 @@ void pairwise_distance(const raft::handle_t& handle, bool isRowMajor, float metric_arg) { - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::JensenShannon: - pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::KLDivergence: - pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; + if (isRowMajor) { + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_matrix_view(y, n, k), + raft::make_device_matrix_view(dist, m, n), + static_cast(metric), + metric_arg); + } else { + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_matrix_view(y, n, k), + raft::make_device_matrix_view(dist, m, n), + static_cast(metric), + metric_arg); + } } template diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu deleted file mode 100644 index dab8f84d30..0000000000 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ /dev/null @@ -1,59 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_canberra.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_canberra(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_canberra(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_canberra.cuh b/cpp/src/metrics/pairwise_distance_canberra.cuh deleted file mode 100644 index 76244fae8b..0000000000 --- a/cpp/src/metrics/pairwise_distance_canberra.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_canberra(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_canberra(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cu b/cpp/src/metrics/pairwise_distance_chebyshev.cu deleted file mode 100644 index b5e99de16c..0000000000 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cu +++ /dev/null @@ -1,58 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_chebyshev.cuh" - -#include -#include - -#include -namespace ML { - -namespace Metrics { -void pairwise_distance_chebyshev(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_chebyshev(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cuh b/cpp/src/metrics/pairwise_distance_chebyshev.cuh deleted file mode 100644 index d2219a5f84..0000000000 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cuh +++ /dev/null @@ -1,46 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_chebyshev(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_chebyshev(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu deleted file mode 100644 index 36fe78ec7e..0000000000 --- a/cpp/src/metrics/pairwise_distance_correlation.cu +++ /dev/null @@ -1,61 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_correlation.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_correlation(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_correlation(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh deleted file mode 100644 index e34c2224e7..0000000000 --- a/cpp/src/metrics/pairwise_distance_correlation.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_correlation(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_correlation(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_cosine.cu b/cpp/src/metrics/pairwise_distance_cosine.cu deleted file mode 100644 index e207b8ad06..0000000000 --- a/cpp/src/metrics/pairwise_distance_cosine.cu +++ /dev/null @@ -1,60 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_cosine.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_cosine(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_cosine(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_cosine.cuh b/cpp/src/metrics/pairwise_distance_cosine.cuh deleted file mode 100644 index 0024ed52ac..0000000000 --- a/cpp/src/metrics/pairwise_distance_cosine.cuh +++ /dev/null @@ -1,46 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_cosine(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_cosine(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_euclidean.cu b/cpp/src/metrics/pairwise_distance_euclidean.cu deleted file mode 100644 index 56cb609067..0000000000 --- a/cpp/src/metrics/pairwise_distance_euclidean.cu +++ /dev/null @@ -1,102 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_euclidean.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_euclidean(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2Unexpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } -} - -void pairwise_distance_euclidean(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2Unexpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_euclidean.cuh b/cpp/src/metrics/pairwise_distance_euclidean.cuh deleted file mode 100644 index 65a3f41683..0000000000 --- a/cpp/src/metrics/pairwise_distance_euclidean.cuh +++ /dev/null @@ -1,46 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_euclidean(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_euclidean(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu deleted file mode 100644 index 080faffe96..0000000000 --- a/cpp/src/metrics/pairwise_distance_hamming.cu +++ /dev/null @@ -1,61 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_hamming.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cuh b/cpp/src/metrics/pairwise_distance_hamming.cuh deleted file mode 100644 index 7c89b6b58f..0000000000 --- a/cpp/src/metrics/pairwise_distance_hamming.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cu b/cpp/src/metrics/pairwise_distance_hellinger.cu deleted file mode 100644 index 6cb09e0d93..0000000000 --- a/cpp/src/metrics/pairwise_distance_hellinger.cu +++ /dev/null @@ -1,60 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_hellinger.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_hellinger(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - // Call the distance function - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_hellinger(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cuh b/cpp/src/metrics/pairwise_distance_hellinger.cuh deleted file mode 100644 index 94d0089223..0000000000 --- a/cpp/src/metrics/pairwise_distance_hellinger.cuh +++ /dev/null @@ -1,45 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_hellinger(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_hellinger(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu deleted file mode 100644 index cf25a8cb5e..0000000000 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu +++ /dev/null @@ -1,58 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_jensen_shannon.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh deleted file mode 100644 index c4ba9a218e..0000000000 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu deleted file mode 100644 index c47fb7adae..0000000000 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cu +++ /dev/null @@ -1,57 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_kl_divergence.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh deleted file mode 100644 index f3866de89d..0000000000 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_l1.cu b/cpp/src/metrics/pairwise_distance_l1.cu deleted file mode 100644 index 3d3ebf625b..0000000000 --- a/cpp/src/metrics/pairwise_distance_l1.cu +++ /dev/null @@ -1,57 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_l1.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_l1(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_l1(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_l1.cuh b/cpp/src/metrics/pairwise_distance_l1.cuh deleted file mode 100644 index 311cc186e7..0000000000 --- a/cpp/src/metrics/pairwise_distance_l1.cuh +++ /dev/null @@ -1,46 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_l1(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_l1(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cu b/cpp/src/metrics/pairwise_distance_minkowski.cu deleted file mode 100644 index 91146955e9..0000000000 --- a/cpp/src/metrics/pairwise_distance_minkowski.cu +++ /dev/null @@ -1,57 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_minkowski.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_minkowski(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor, metric_arg); -} - -void pairwise_distance_minkowski(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance::distance( - handle, x, y, dist, m, n, k, isRowMajor, metric_arg); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cuh b/cpp/src/metrics/pairwise_distance_minkowski.cuh deleted file mode 100644 index 0fe15de554..0000000000 --- a/cpp/src/metrics/pairwise_distance_minkowski.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_minkowski(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_minkowski(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu deleted file mode 100644 index efeb66d13f..0000000000 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cu +++ /dev/null @@ -1,59 +0,0 @@ - -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pairwise_distance_russell_rao.cuh" - -#include -#include - -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_russell_rao(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg) -{ - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -void pairwise_distance_russell_rao(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg) -{ - raft::distance:: - distance( - handle, x, y, dist, m, n, k, isRowMajor); -} - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cuh b/cpp/src/metrics/pairwise_distance_russell_rao.cuh deleted file mode 100644 index 574a5116df..0000000000 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cuh +++ /dev/null @@ -1,47 +0,0 @@ - -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace ML { - -namespace Metrics { -void pairwise_distance_russell_rao(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - bool isRowMajor, - double metric_arg); - -void pairwise_distance_russell_rao(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - bool isRowMajor, - float metric_arg); - -} // namespace Metrics -} // namespace ML diff --git a/cpp/src/metrics/silhouette_score.cu b/cpp/src/metrics/silhouette_score.cu index 883a9eabb3..cf4fdeb6fd 100644 --- a/cpp/src/metrics/silhouette_score.cu +++ b/cpp/src/metrics/silhouette_score.cu @@ -18,8 +18,9 @@ #include #include -#include -#include + +#include +#include namespace ML { @@ -33,9 +34,18 @@ double silhouette_score(const raft::handle_t& handle, double* silScores, raft::distance::DistanceType metric) { - return raft::stats::silhouette_score( - handle, y, nRows, nCols, labels, nLabels, silScores, handle.get_stream(), metric); -} + std::optional> silhouette_score_per_sample; + if (silScores != NULL) { + silhouette_score_per_sample = raft::make_device_vector_view(silScores, nRows); + } + return cuvs::stats::silhouette_score( + handle, + raft::make_device_matrix_view(y, nRows, nCols), + raft::make_device_vector_view(labels, nRows), + silhouette_score_per_sample, + nLabels, + static_cast(metric)); +} } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/silhouette_score_batched_double.cu b/cpp/src/metrics/silhouette_score_batched_double.cu index 961170a63f..3188ce0bc4 100644 --- a/cpp/src/metrics/silhouette_score_batched_double.cu +++ b/cpp/src/metrics/silhouette_score_batched_double.cu @@ -18,11 +18,11 @@ #include #include -#include -#include -namespace ML { +#include +#include +namespace ML { namespace Metrics { namespace Batched { @@ -36,11 +36,21 @@ double silhouette_score(const raft::handle_t& handle, int chunk, raft::distance::DistanceType metric) { - return raft::stats::silhouette_score_batched( - handle, X, n_rows, n_cols, y, n_labels, scores, chunk, metric); + std::optional> silhouette_score_per_sample; + if (scores != NULL) { + silhouette_score_per_sample = raft::make_device_vector_view(scores, n_rows); + } + + return cuvs::stats::silhouette_score_batched( + handle, + raft::make_device_matrix_view(X, n_rows, n_cols), + raft::make_device_vector_view(y, n_rows), + silhouette_score_per_sample, + n_labels, + chunk, + static_cast(metric)); } } // namespace Batched - } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/silhouette_score_batched_float.cu b/cpp/src/metrics/silhouette_score_batched_float.cu index d24129f465..0245375657 100644 --- a/cpp/src/metrics/silhouette_score_batched_float.cu +++ b/cpp/src/metrics/silhouette_score_batched_float.cu @@ -19,12 +19,11 @@ #include #include -#include -namespace ML { +#include +namespace ML { namespace Metrics { - namespace Batched { float silhouette_score(const raft::handle_t& handle, @@ -37,10 +36,20 @@ float silhouette_score(const raft::handle_t& handle, int chunk, raft::distance::DistanceType metric) { - return raft::stats::silhouette_score_batched( - handle, X, n_rows, n_cols, y, n_labels, scores, chunk, metric); + std::optional> silhouette_score_per_sample; + if (scores != NULL) { + silhouette_score_per_sample = raft::make_device_vector_view(scores, n_rows); + } + + return cuvs::stats::silhouette_score_batched( + handle, + raft::make_device_matrix_view(X, n_rows, n_cols), + raft::make_device_vector_view(y, n_rows), + silhouette_score_per_sample, + n_labels, + chunk, + static_cast(metric)); } - } // namespace Batched } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/trustworthiness.cu b/cpp/src/metrics/trustworthiness.cu index 7b5f80a010..724ef43ddd 100644 --- a/cpp/src/metrics/trustworthiness.cu +++ b/cpp/src/metrics/trustworthiness.cu @@ -17,8 +17,8 @@ #include #include -#include -#include + +#include namespace ML { namespace Metrics { @@ -47,8 +47,13 @@ double trustworthiness_score(const raft::handle_t& h, int n_neighbors, int batchSize) { - return raft::stats::trustworthiness_score( - h, X, X_embedded, n, m, d, n_neighbors, batchSize); + return cuvs::stats::trustworthiness_score( + h, + raft::make_device_matrix_view(X, n, m), + raft::make_device_matrix_view(X_embedded, n, d), + n_neighbors, + static_cast(distance_type), + batchSize); } template double trustworthiness_score( diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 4963f42eee..31be50942c 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -37,13 +37,14 @@ #include #include +#include #include namespace ML { namespace TSNE { /** - * @brief Uses FAISS's KNN to find the top n_neighbors. This speeds up the attractive forces. + * @brief Uses CUVS's KNN to find the top n_neighbors. This speeds up the attractive forces. * @param[in] input: dense/sparse manifold input * @param[out] indices: The output indices from KNN. * @param[out] distances: The output sorted distances from KNN. @@ -70,32 +71,18 @@ void get_distances(const raft::handle_t& handle, { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) // #861 - - std::vector input_vec = {input.X}; - std::vector sizes_vec = {input.n}; - - /** - * std::vector &input, std::vector &sizes, - IntType D, float *search_items, IntType n, int64_t *res_I, - float *res_D, IntType k, - std::shared_ptr allocator, - cudaStream_t userStream, - */ - - raft::spatial::knn::brute_force_knn(handle, - input_vec, - sizes_vec, - input.d, - input.X, - input.n, - k_graph.knn_indices, - k_graph.knn_dists, - k_graph.n_neighbors, - false, - false, - nullptr, - metric, - p); + auto k = k_graph.n_neighbors; + auto X_view = + raft::make_device_matrix_view(input.X, input.n, input.d); + auto idx = cuvs::neighbors::brute_force::build( + handle, X_view, static_cast(metric), p); + + cuvs::neighbors::brute_force::search( + handle, + idx, + X_view, + raft::make_device_matrix_view(k_graph.knn_indices, input.n, k), + raft::make_device_matrix_view(k_graph.knn_dists, input.n, k)); } // dense, int32 indices diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 92c717afcd..fa55397659 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -36,6 +36,8 @@ #include #include +#include + #include namespace NNDescent = raft::neighbors::experimental::nn_descent; @@ -92,26 +94,20 @@ inline void launcher(const raft::handle_t& handle, cudaStream_t stream) { if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = inputsA.X; - sizes[0] = inputsA.n; - - raft::spatial::knn::brute_force_knn(handle, - ptrs, - sizes, - inputsA.d, - inputsB.X, - inputsB.n, - out.knn_indices, - out.knn_dists, - n_neighbors, - true, - true, - static_cast*>(nullptr), - params->metric, - params->p); + auto idx = cuvs::neighbors::brute_force::build( + handle, + raft::make_device_matrix_view(inputsA.X, inputsA.n, inputsA.d), + static_cast(params->metric), + params->p); + + cuvs::neighbors::brute_force::search( + handle, + idx, + raft::make_device_matrix_view(inputsB.X, inputsB.n, inputsB.d), + raft::make_device_matrix_view(out.knn_indices, inputsB.n, n_neighbors), + raft::make_device_matrix_view(out.knn_dists, inputsB.n, n_neighbors)); } else { // nn_descent + // TODO: use nndescent from cuvs RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 2a04100cdf..46bd275fcc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -57,7 +57,6 @@ function(ConfigureTest) $<$:CUDA::cufft${_ctk_static_suffix_cufft}> rmm::rmm raft::raft - $<$:raft::compiled> GTest::gtest GTest::gtest_main GTest::gmock diff --git a/cpp/test/prims/knn_classify.cu b/cpp/test/prims/knn_classify.cu index 56bc6a245d..75d59e6fda 100644 --- a/cpp/test/prims/knn_classify.cu +++ b/cpp/test/prims/knn_classify.cu @@ -24,6 +24,7 @@ #include +#include #include #include @@ -73,20 +74,17 @@ class KNNClassifyTest : public ::testing::TestWithParam { auto n_classes = raft::label::getUniquelabels(unique_labels, train_labels.data(), params.rows, stream); - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = train_samples.data(); - sizes[0] = params.rows; - - raft::spatial::knn::brute_force_knn(handle, - ptrs, - sizes, - params.cols, - train_samples.data(), - params.rows, - knn_indices.data(), - knn_dists.data(), - params.k); + auto train_view = raft::make_device_matrix_view( + train_samples.data(), params.rows, params.cols); + auto idx = cuvs::neighbors::brute_force::build( + handle, train_view, cuvs::distance::DistanceType::L2Unexpanded); + + cuvs::neighbors::brute_force::search( + handle, + idx, + train_view, + raft::make_device_matrix_view(knn_indices.data(), params.rows, params.k), + raft::make_device_matrix_view(knn_dists.data(), params.rows, params.k)); std::vector y; y.push_back(train_labels.data()); diff --git a/cpp/test/prims/knn_regression.cu b/cpp/test/prims/knn_regression.cu index 7c29c8ea1e..07ae30dfd5 100644 --- a/cpp/test/prims/knn_regression.cu +++ b/cpp/test/prims/knn_regression.cu @@ -29,6 +29,7 @@ #include #include +#include #include #include @@ -99,20 +100,18 @@ class KNNRegressionTest : public ::testing::TestWithParam { { generate_data(train_samples.data(), train_labels.data(), params.rows, params.cols, stream); - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = train_samples.data(); - sizes[0] = params.rows; - - raft::spatial::knn::brute_force_knn(handle, - ptrs, - sizes, - params.cols, - train_samples.data(), - params.rows, - knn_indices.data(), - knn_dists.data(), - params.k); + auto train_view = raft::make_device_matrix_view( + train_samples.data(), params.rows, params.cols); + + auto idx = cuvs::neighbors::brute_force::build( + handle, train_view, cuvs::distance::DistanceType::L2Unexpanded); + + cuvs::neighbors::brute_force::search( + handle, + idx, + train_view, + raft::make_device_matrix_view(knn_indices.data(), params.rows, params.k), + raft::make_device_matrix_view(knn_dists.data(), params.rows, params.k)); std::vector y; y.push_back(train_labels.data()); diff --git a/cpp/test/sg/hdbscan_test.cu b/cpp/test/sg/hdbscan_test.cu index 08705f5a8b..a7ce69b1bc 100644 --- a/cpp/test/sg/hdbscan_test.cu +++ b/cpp/test/sg/hdbscan_test.cu @@ -19,7 +19,7 @@ #include -#include +#include // build_dendrogram_host #include #include #include @@ -34,6 +34,7 @@ #include #include +#include #include #include #include diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index 821437fb0e..2980a79394 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -126,20 +126,17 @@ class UMAPParametrizableTest : public ::testing::Test { knn_indices = knn_indices_b->data(); knn_dists = knn_dists_b->data(); - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = X; - sizes[0] = n_samples; - - raft::spatial::knn::brute_force_knn(handle, - ptrs, - sizes, - n_features, - X, - n_samples, - knn_indices, - knn_dists, - umap_params.n_neighbors); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + auto idx = cuvs::neighbors::brute_force::build( + handle, X_view, cuvs::distance::DistanceType::L2Unexpanded); + + cuvs::neighbors::brute_force::search(handle, + idx, + X_view, + raft::make_device_matrix_view( + knn_indices, n_samples, umap_params.n_neighbors), + raft::make_device_matrix_view( + knn_dists, n_samples, umap_params.n_neighbors)); handle.sync_stream(stream); } diff --git a/dependencies.yaml b/dependencies.yaml index 687c0bd9aa..0108588363 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -142,7 +142,7 @@ dependencies: - cxx-compiler - fmt>=11.0.2,<12 - libcumlprims==24.10.*,>=0.0.0a0 - - libraft==24.10.*,>=0.0.0a0 + - libcuvs==24.10.*,>=0.0.0a0 - libraft-headers==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - spdlog>=1.14.1,<1.15 @@ -183,6 +183,7 @@ dependencies: - &treelite treelite==4.3.0 - output_types: conda packages: + - &cuvs_unsuffixed cuvs==24.10.*,>=0.0.0a0 - &pylibraft_unsuffixed pylibraft==24.10.*,>=0.0.0a0 - &rmm_unsuffixed rmm==24.10.*,>=0.0.0a0 - output_types: requirements @@ -211,16 +212,19 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: + - cuvs-cu12==24.10.*,>=0.0.0a0 - pylibraft-cu12==24.10.*,>=0.0.0a0 - rmm-cu12==24.10.*,>=0.0.0a0 - matrix: cuda: "11.*" cuda_suffixed: "true" packages: + - cuvs-cu11==24.10.*,>=0.0.0a0 - pylibraft-cu11==24.10.*,>=0.0.0a0 - rmm-cu11==24.10.*,>=0.0.0a0 - matrix: packages: + - *cuvs_unsuffixed - *pylibraft_unsuffixed - *rmm_unsuffixed @@ -260,6 +264,7 @@ dependencies: packages: - cudf-cu12==24.10.*,>=0.0.0a0 - &cupy_pyproject_cu12 cupy-cuda12x>=12.0.0 + - cuvs-cu12==24.10.*,>=0.0.0a0 - dask-cudf-cu12==24.10.*,>=0.0.0a0 - pylibraft-cu12==24.10.*,>=0.0.0a0 - raft-dask-cu12==24.10.*,>=0.0.0a0 @@ -272,6 +277,7 @@ dependencies: # NOTE: cupy still has a "-cuda12x" suffix here, because it's suffixed # in DLFW builds - *cupy_pyproject_cu12 + - *cuvs_unsuffixed - *dask_cudf_unsuffixed - *pylibraft_unsuffixed - *raft_dask_unsuffixed @@ -282,6 +288,7 @@ dependencies: packages: &py_run_packages_cu11 - cudf-cu11==24.10.*,>=0.0.0a0 - &cupy_pyproject_cu11 cupy-cuda11x>=12.0.0 + - cuvs-cu11==24.10.*,>=0.0.0a0 - dask-cudf-cu11==24.10.*,>=0.0.0a0 - pylibraft-cu11==24.10.*,>=0.0.0a0 - raft-dask-cu11==24.10.*,>=0.0.0a0 @@ -294,6 +301,7 @@ dependencies: # NOTE: cupy still has a "-cuda11x" suffix here, because it's suffixed # in DLFW builds - *cupy_pyproject_cu11 + - *cuvs_unsuffixed - *dask_cudf_unsuffixed - *pylibraft_unsuffixed - *raft_dask_unsuffixed @@ -302,6 +310,7 @@ dependencies: packages: - *cudf_unsuffixed - *cupy_pyproject_cu11 + - *cuvs_unsuffixed - *dask_cudf_unsuffixed - *pylibraft_unsuffixed - *raft_dask_unsuffixed diff --git a/python/cuml/CMakeLists.txt b/python/cuml/CMakeLists.txt index 224525ee58..221b5ebf75 100644 --- a/python/cuml/CMakeLists.txt +++ b/python/cuml/CMakeLists.txt @@ -39,6 +39,7 @@ option(CUML_UNIVERSAL "Build all cuML Python components." ON) option(FIND_CUML_CPP "Search for existing CUML C++ installations before defaulting to local files" OFF) option(SINGLEGPU "Disable all mnmg components and comms libraries" OFF) option(USE_CUDA_MATH_WHEELS "Use the CUDA math wheels instead of the system libraries" OFF) +option(USE_CUVS_WHEEL "Use the cuVS wheel" OFF) set(CUML_RAFT_CLONE_ON_PIN OFF) @@ -88,7 +89,7 @@ if(NOT CUML_CPU) # Statically link dependencies if building wheels set(CUDA_STATIC_RUNTIME ON) - set(CUML_USE_RAFT_STATIC ON) + set(CUML_USE_CUVS_STATIC ON) set(CUML_USE_FAISS_STATIC ON) set(CUML_USE_TREELITE_STATIC ON) set(CUML_USE_CUMLPRIMS_MG_STATIC ON) @@ -110,7 +111,7 @@ if(NOT CUML_CPU) add_subdirectory(${CUML_CPP_SRC} cuml-cpp EXCLUDE_FROM_ALL) if(NOT CUDA_STATIC_MATH_LIBRARIES AND USE_CUDA_MATH_WHEELS) - set_property(TARGET ${CUML_CPP_TARGET} PROPERTY INSTALL_RPATH + set(rpaths "$ORIGIN/../nvidia/cublas/lib" "$ORIGIN/../nvidia/cufft/lib" "$ORIGIN/../nvidia/curand/lib" @@ -118,6 +119,12 @@ if(NOT CUML_CPU) "$ORIGIN/../nvidia/cusparse/lib" "$ORIGIN/../nvidia/nvjitlink/lib" ) + set_property(TARGET ${CUML_CPP_TARGET} PROPERTY INSTALL_RPATH ${rpaths} APPEND) + endif() + + if(USE_CUVS_WHEEL) + set(rpaths "$ORIGIN/../cuvs") + set_property(TARGET ${CUML_CPP_TARGET} PROPERTY INSTALL_RPATH ${rpaths} APPEND) endif() set(cython_lib_dir cuml) @@ -205,3 +212,7 @@ add_subdirectory(cuml/experimental/linear_model) if(DEFINED cython_lib_dir) rapids_cython_add_rpath_entries(TARGET cuml PATHS "${cython_lib_dir}") endif() + +if(USE_CUVS_WHEEL) + rapids_cython_add_rpath_entries(TARGET cuml PATHS cuvs) +endif() diff --git a/python/cuml/cuml/cluster/cpp/kmeans.pxd b/python/cuml/cuml/cluster/cpp/kmeans.pxd index 53b4c44d1d..fa3db02c6e 100644 --- a/python/cuml/cuml/cluster/cpp/kmeans.pxd +++ b/python/cuml/cuml/cluster/cpp/kmeans.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,28 +27,10 @@ from libcpp cimport bool from cuml.metrics.distance_type cimport DistanceType from cuml.common.rng_state cimport RngState -cdef extern from "cuml/cluster/kmeans.hpp" namespace \ - "ML::kmeans::KMeansParams": - enum InitMethod: - KMeansPlusPlus, Random, Array - - cdef struct KMeansParams: - int n_clusters, - InitMethod init - int max_iter, - double tol, - int verbosity, - RngState rng_state, - DistanceType metric, - int n_init, - double oversampling_factor, - int batch_samples, - int batch_centroids, - bool inertia_check +from cuml.cluster.kmeans_utils cimport params as KMeansParams cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans": - cdef void fit_predict(handle_t& handle, KMeansParams& params, const float *X, diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index 3d6be3abf2..fde067c14c 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -33,9 +33,10 @@ IF GPUBUILD == 1: from cuml.cluster.cpp.kmeans cimport fit_predict as cpp_fit_predict from cuml.cluster.cpp.kmeans cimport predict as cpp_predict from cuml.cluster.cpp.kmeans cimport transform as cpp_transform - from cuml.cluster.cpp.kmeans cimport KMeansParams from cuml.metrics.distance_type cimport DistanceType - from cuml.cluster.kmeans_utils cimport * + from cuml.cluster.kmeans_utils cimport params as KMeansParams + from cuml.cluster.kmeans_utils cimport KMeansPlusPlus, Random, Array + from cuml.cluster.kmeans_utils cimport DistanceType as CuvsDistanceType from cuml.internals.array import CumlArray from cuml.common.array_descriptor import CumlArrayDescriptor @@ -207,7 +208,7 @@ class KMeans(UniversalBase, params.tol = self.tol params.verbosity = self.verbose params.rng_state.seed = self.random_state - params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 + params.metric = CuvsDistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 params.batch_samples = self.max_samples_per_batch params.oversampling_factor = self.oversampling_factor params.n_init = self.n_init @@ -609,7 +610,8 @@ class KMeans(UniversalBase, # distance metric as L2-norm/euclidean distance: @todo - support other metrics # noqa: E501 cdef KMeansParams* params = \ self._get_kmeans_params() - params.metric = DistanceType.L2SqrtExpanded + + params.metric = CuvsDistanceType.L2Expanded int_dtype = np.int32 if self.labels_.dtype == np.int32 else np.int64 diff --git a/python/cuml/cuml/cluster/kmeans_mg.pyx b/python/cuml/cuml/cluster/kmeans_mg.pyx index 51d415036b..cf0a2967c4 100644 --- a/python/cuml/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cuml/cluster/kmeans_mg.pyx @@ -31,7 +31,7 @@ from pylibraft.common.handle cimport handle_t from cuml.common import input_to_cuml_array from cuml.cluster import KMeans -from cuml.cluster.kmeans_utils cimport * +from cuml.cluster.kmeans_utils cimport params as KMeansParams cdef extern from "cuml/cluster/kmeans_mg.hpp" \ diff --git a/python/cuml/cuml/cluster/kmeans_utils.pxd b/python/cuml/cuml/cluster/kmeans_utils.pxd index efbe27dcd7..17d58a49be 100644 --- a/python/cuml/cuml/cluster/kmeans_utils.pxd +++ b/python/cuml/cuml/cluster/kmeans_utils.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,39 @@ import ctypes from libcpp cimport bool -from cuml.metrics.distance_type cimport DistanceType from cuml.common.rng_state cimport RngState +cdef extern from "cuvs/distance/distance.hpp" namespace \ + "cuvs::distance": + ctypedef enum DistanceType: + L2Expanded "cuvs::distance::DistanceType::L2Expanded" + L2SqrtExpanded "cuvs::distance::DistanceType::L2SqrtExpanded" + CosineExpanded "cuvs::distance::DistanceType::CosineExpanded" + L1 "cuvs::distance::DistanceType::L1" + L2Unexpanded "cuvs::distance::DistanceType::L2Unexpanded" + L2SqrtUnexpanded "cuvs::distance::DistanceType::L2SqrtUnexpanded" + InnerProduct "cuvs::distance::DistanceType::InnerProduct" + Linf "cuvs::distance::DistanceType::Linf" + Canberra "cuvs::distance::DistanceType::Canberra" + LpUnexpanded "cuvs::distance::DistanceType::LpUnexpanded" + CorrelationExpanded "cuvs::distance::DistanceType::CorrelationExpanded" + JaccardExpanded "cuvs::distance::DistanceType::JaccardExpanded" + HellingerExpanded "cuvs::distance::DistanceType::HellingerExpanded" + Haversine "cuvs::distance::DistanceType::Haversine" + BrayCurtis "cuvs::distance::DistanceType::BrayCurtis" + JensenShannon "cuvs::distance::DistanceType::JensenShannon" + HammingUnexpanded "cuvs::distance::DistanceType::HammingUnexpanded" + KLDivergence "cuvs::distance::DistanceType::KLDivergence" + RusselRaoExpanded "cuvs::distance::DistanceType::RusselRaoExpanded" + DiceExpanded "cuvs::distance::DistanceType::DiceExpanded" + cdef extern from "cuml/cluster/kmeans.hpp" namespace \ - "ML::kmeans::KMeansParams": + "cuvs::cluster::kmeans::params": enum InitMethod: KMeansPlusPlus, Random, Array - - cdef struct KMeansParams: +cdef extern from "cuvs/cluster/kmeans.hpp" namespace \ + "cuvs::cluster::kmeans": + cdef struct params: int n_clusters, InitMethod init int max_iter, diff --git a/python/cuml/cuml/neighbors/ann.pxd b/python/cuml/cuml/neighbors/ann.pxd index 8819794b8f..f98e26b7ce 100644 --- a/python/cuml/cuml/neighbors/ann.pxd +++ b/python/cuml/cuml/neighbors/ann.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ from libc.stdint cimport uintptr_t from libcpp cimport bool -cdef extern from "raft/spatial/knn/ann_common.h" \ - namespace "raft::spatial::knn": +cdef extern from "cuml/neighbors/knn.hpp" \ + namespace "ML": cdef cppclass knnIndex: pass @@ -30,15 +30,6 @@ cdef extern from "raft/spatial/knn/ann_common.h" \ cdef cppclass knnIndexParam: pass - ctypedef enum QuantizerType: - QT_8bit, - QT_4bit, - QT_8bit_uniform, - QT_4bit_uniform, - QT_fp16, - QT_8bit_direct, - QT_6bit - cdef cppclass IVFParam(knnIndexParam): int nlist int nprobe diff --git a/python/cuml/cuml/tests/test_nearest_neighbors.py b/python/cuml/cuml/tests/test_nearest_neighbors.py index 9f5764a7e9..aa612b7763 100644 --- a/python/cuml/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/cuml/tests/test_nearest_neighbors.py @@ -573,12 +573,14 @@ def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): X[:query_rows, :], n_neighbors=n_neighbors ) - assert len(brute_d[brute_d != rbc_d]) == 0 + cp.testing.assert_allclose(brute_d, rbc_d, atol=1e-3, rtol=1e-3) # All the distances match so allow a couple mismatched indices # through from potential non-determinism in exact matching # distances - assert len(brute_i[brute_i != rbc_i]) <= 3 + assert ( + len(brute_i[brute_i != rbc_i]) <= 3 if distance != "haversine" else 10 + ) @pytest.mark.parametrize("metric", valid_metrics_sparse()) diff --git a/python/cuml/pyproject.toml b/python/cuml/pyproject.toml index 228cb92b5c..899baf535d 100644 --- a/python/cuml/pyproject.toml +++ b/python/cuml/pyproject.toml @@ -82,6 +82,7 @@ requires-python = ">=3.10" dependencies = [ "cudf==24.10.*,>=0.0.0a0", "cupy-cuda11x>=12.0.0", + "cuvs==24.10.*,>=0.0.0a0", "dask-cuda==24.10.*,>=0.0.0a0", "dask-cudf==24.10.*,>=0.0.0a0", "joblib>=0.11", @@ -164,6 +165,7 @@ matrix-entry = "cuda_suffixed=true;use_cuda_wheels=true" requires = [ "cmake>=3.26.4,!=3.30.0", "cuda-python", + "cuvs==24.10.*,>=0.0.0a0", "cython>=3.0.0", "ninja", "pylibraft==24.10.*,>=0.0.0a0",