diff --git a/.gitignore b/.gitignore index 6f34017d0a..fce749bd14 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ _skbuild/ build/ cpp/src/legate_library.cc +cpp/src/legate_library.h cpp/src/legate_library.h.eggs/ dist/ legate.raft.egg-info/ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt deleted file mode 100644 index 898b681443..0000000000 --- a/cpp/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. 
- -cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) - -# ------------- configure rapids-cmake --------------# - -include(cmake/thirdparty/fetch_rapids.cmake) -include(rapids-cmake) -include(rapids-cpm) -include(rapids-cuda) -include(rapids-export) -include(rapids-find) - -# ------------- configure project --------------# - -rapids_cuda_init_architectures(test_raft) - -project(test_raft LANGUAGES C CXX CUDA) - -# ------------- configure raft -----------------# - -rapids_cpm_init() -include(cmake/thirdparty/get_raft.cmake) - - -# -------------- add legate --------------------# - -find_package(legate_core REQUIRED) -set(BUILD_SHARED_LIBS ON) - -legate_add_cpp_subdirectory(src TARGET legate_raft EXPORT legate_raft-export) -legate_add_cffi(${CMAKE_CURRENT_SOURCE_DIR}/src/legate_raft_cffi.h TARGET legate_raft) -#legate_python_library_template(hello) -#legate_default_python_install(hello EXPORT hello-export) - - -# -------------- compile tasks ----------------- # - diff --git a/cpp/cmake/thirdparty/fetch_rapids.cmake b/cpp/cmake/thirdparty/fetch_rapids.cmake deleted file mode 100644 index 40ba83be9e..0000000000 --- a/cpp/cmake/thirdparty/fetch_rapids.cmake +++ /dev/null @@ -1,21 +0,0 @@ -# ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. 
- -# Use this variable to update RAPIDS and RAFT versions -set(RAPIDS_VERSION "23.04") - -if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake - ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) -endif() -include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake deleted file mode 100644 index 5463942adf..0000000000 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ /dev/null @@ -1,62 +0,0 @@ -# ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. 
- -# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake -set(RAFT_VERSION "${RAPIDS_VERSION}") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") - -function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - set(RAFT_COMPONENTS "") - if(PKG_COMPILE_LIBRARY) - string(APPEND RAFT_COMPONENTS " compiled") - endif() - - if(PKG_ENABLE_MNMG_DEPENDENCIES) - string(APPEND RAFT_COMPONENTS " distributed") - endif() - - #----------------------------------------------------- - # Invoke CPM find_package() - #----------------------------------------------------- - rapids_cpm_find(raft ${PKG_VERSION} - GLOBAL_TARGETS raft::raft - BUILD_EXPORT_SET raft-template-exports - INSTALL_EXPORT_SET raft-template-exports - COMPONENTS ${RAFT_COMPONENTS} - CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git - GIT_TAG ${PKG_PINNED_TAG} - SOURCE_SUBDIR cpp - OPTIONS - "BUILD_TESTS OFF" - "BUILD_BENCH OFF" - "RAFT_NVTX ${ENABLE_NVTX}" - "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" - ) -endfunction() - -# Change pinned tag here to test a commit in CI -# To use a different RAFT locally, set the CMake variable -# CPM_raft_SOURCE=/path/to/local/raft -find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} - COMPILE_LIBRARY ON - ENABLE_MNMG_DEPENDENCIES OFF - ENABLE_NVTX OFF -) diff --git a/cpp/src/CMakeLists.txt b/cpp/src/CMakeLists.txt index 8918f3ca1f..b522fdc18f 100644 --- a/cpp/src/CMakeLists.txt +++ b/cpp/src/CMakeLists.txt @@ -6,6 +6,8 @@ add_library( # RAFT API Wrappers raft/histogram.cu raft/distance.cu + raft/raft_knn.cu + raft/raft_knn_merge.cu # Legate tasks task/add.cc @@ -18,6 +20,8 @@ add_library( task/fill.cc task/find_max.cc task/histogram.cc + task/knn_task.cc + task/knn_merge_task.cc task/log.cc task/matmul.cc task/mul.cc 
diff --git a/cpp/src/legate_library.h b/cpp/src/legate_library.h deleted file mode 100644 index ee7cca43fa..0000000000 --- a/cpp/src/legate_library.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "legate.h" - -namespace legate_raft { - -struct Registry { - public: - template - static void record_variant(Args&&... args) - { - get_registrar().record_variant(std::forward(args)...); - } - static legate::TaskRegistrar& get_registrar(); -}; - -template -struct Task : public legate::LegateTask { - using Registrar = Registry; - static constexpr int TASK_ID = ID; -}; - -} diff --git a/cpp/src/legate_raft_cffi.h b/cpp/src/legate_raft_cffi.h index 5e84c268a7..06380223ac 100644 --- a/cpp/src/legate_raft_cffi.h +++ b/cpp/src/legate_raft_cffi.h @@ -13,5 +13,7 @@ enum LegateRaftOpCode { LOG, MATMUL, MUL, + RAFT_KNN, + RAFT_KNN_MERGE, SUM_OVER_AXIS, }; diff --git a/cpp/src/raft/raft_api.hpp b/cpp/src/raft/raft_api.hpp index e42fd46a54..6208db3405 100644 --- a/cpp/src/raft/raft_api.hpp +++ b/cpp/src/raft/raft_api.hpp @@ -16,5 +16,23 @@ #pragma once -void test_histogram(); -void test_distance(); \ No newline at end of file +template +void raft_knn(idx_t n_index_rows, + idx_t n_search_rows, + idx_t n_features, + idx_t k, + std::string& metric, + const value_t* index_ptr, + const value_t* search_ptr, + idx_t* indices_ptr, + float* distances_ptr); + + +template +void raft_knn_merge(size_t n_samples, + int n_parts, + int k, + const idx_t* in_ind, + const float* in_dist, + idx_t* out_ind, + float* out_dist); diff --git a/cpp/src/raft/raft_knn.cu b/cpp/src/raft/raft_knn.cu new file mode 100644 index 0000000000..d547a3d90c --- /dev/null +++ b/cpp/src/raft/raft_knn.cu @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include "raft_api.hpp" + + #include + #include + #include + #include + #include + #include + + #ifdef RAFT_COMPILED + #include + #endif + + +template +void raft_knn(idx_t n_index_rows, + idx_t n_search_rows, + idx_t n_features, + idx_t k, + std::string& metric, + const value_t* index_ptr, + const value_t* search_ptr, + idx_t* indices_ptr, + float* distances_ptr) +{ + static raft::device_resources handle; + + auto index_part = raft::make_device_matrix_view(index_ptr, n_index_rows, n_features); + auto search = raft::make_device_matrix_view(search_ptr, n_search_rows, n_features); + auto indices = raft::make_device_matrix_view(indices_ptr, n_search_rows, k); + auto distances = raft::make_device_matrix_view(distances_ptr, n_search_rows, k); + + std::vector> index; + index.push_back(index_part); + + raft::distance::DistanceType distance_type; + if (metric == "l2") { + distance_type = raft::distance::DistanceType::L2SqrtExpanded; + } else { + throw std::invalid_argument("invalid metric"); + } + raft::neighbors::brute_force::knn(handle, + index, + search, + indices, + distances, + distance_type); +} + + +template void raft_knn( + int64_t, + int64_t, + int64_t, + int64_t, + std::string&, + const float*, + const float*, + int64_t*, + float*); diff --git a/cpp/src/raft/raft_knn_merge.cu b/cpp/src/raft/raft_knn_merge.cu new file mode 100644 index 0000000000..605c413894 --- /dev/null +++ b/cpp/src/raft/raft_knn_merge.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "raft_api.hpp" + +#include +#include +#include +#include +#include +#include + +#ifdef RAFT_COMPILED +#include +#endif + + +template +void raft_knn_merge(size_t n_samples, + int n_parts, + int k, + const idx_t* in_ind, + const float* in_dist, + idx_t* out_ind, + float* out_dist) +{ + static raft::device_resources handle; + + auto in_keys = raft::device_matrix_view(in_dist, n_samples * n_parts, k); + auto in_values = raft::device_matrix_view(in_ind, n_samples * n_parts, k); + auto out_keys = raft::device_matrix_view(out_dist, n_samples, k); + auto out_values = raft::device_matrix_view(out_ind, n_samples, k); + + raft::neighbors::brute_force::knn_merge_parts(handle, + in_keys, + in_values, + out_keys, + out_values, + n_samples); + +} + + +template void raft_knn_merge( + size_t, + int, + int, + const int64_t*, + const float*, + int64_t*, + float*); diff --git a/cpp/src/task/knn_merge_task.cc b/cpp/src/task/knn_merge_task.cc new file mode 100644 index 0000000000..0e06736f73 --- /dev/null +++ b/cpp/src/task/knn_merge_task.cc @@ -0,0 +1,47 @@ +#include + +#include "../raft/raft_api.hpp" +#include "../legate_raft.h" +#include "../legate_library.h" + +namespace legate_raft { + + class RAFT_KNN_MERGE_TASK : public Task { + public: + static void gpu_variant(legate::TaskContext& context) + { + size_t n_samples = context.scalars()[0].value(); + int n_parts = context.scalars()[1].value(); + int k = 
context.scalars()[2].value(); + + auto& in_ind = context.inputs()[0]; + auto& in_dist = context.inputs()[1]; + auto& out_ind = context.outputs()[0]; + auto& out_dist = context.outputs()[1]; + + auto in_ind_read = in_ind.read_accessor().ptr(Legion::DomainPoint(0)); + auto in_dist_read = in_dist.read_accessor().ptr(Legion::DomainPoint(0)); + auto out_ind_write = out_ind.write_accessor().ptr(Legion::DomainPoint(0)); + auto out_dist_write = out_dist.write_accessor().ptr(Legion::DomainPoint(0)); + + raft_knn_merge(n_samples, + n_parts, + k, + in_ind_read, + in_dist_read, + out_ind_write, + out_dist_write); + } + }; + +} // namespace legate_raft + +namespace // unnamed +{ + + static void __attribute__((constructor)) register_tasks(void) + { + legate_raft::RAFT_KNN_MERGE_TASK::register_variants(); + } + +} // namespace \ No newline at end of file diff --git a/cpp/src/task/knn_task.cc b/cpp/src/task/knn_task.cc new file mode 100644 index 0000000000..e67e19db1b --- /dev/null +++ b/cpp/src/task/knn_task.cc @@ -0,0 +1,58 @@ +#include + +#include "../raft/raft_api.hpp" +#include "../legate_raft.h" +#include "../legate_library.h" + +namespace legate_raft { + + class RAFT_KNN_TASK : public Task { + public: + static void gpu_variant(legate::TaskContext& context) + { + int64_t k = context.scalars()[0].value(); // number of nearest neighbors + std::string metric = context.scalars()[1].value(); + + auto& index = context.inputs()[0]; + auto& search = context.inputs()[1]; + auto& indices = context.outputs()[0]; + auto& distances = context.outputs()[1]; + + int64_t n_index_rows = index.shape<2>().hi[0] + 1 - index.shape<2>().lo[0]; + int64_t n_search_rows = search.shape<2>().hi[0] + 1 - search.shape<2>().lo[0]; + int64_t n_features = index.shape<2>().hi[1] + 1; + if(search.shape<2>().hi[1] + 1 != n_features) { + throw std::invalid_argument("index and search should have the same number of features"); + } + + // The offset of the current partition from the start of the store + // is used 
to obtain the pointer to the start of the partition. + uint64_t offset = search.shape<2>().lo[0]; + auto index_read = index.read_accessor().ptr(Legion::DomainPoint(0)); + auto search_read = search.read_accessor().ptr(Legion::DomainPoint(offset)); + auto indices_write = indices.write_accessor().ptr(Legion::DomainPoint(offset)); + auto distances_write = distances.write_accessor().ptr(Legion::DomainPoint(offset)); + + raft_knn(n_index_rows, + n_search_rows, + n_features, + k, + metric, + index_read, + search_read, + indices_write, + distances_write); + } + }; + +} // namespace legate_raft + +namespace // unnamed +{ + + static void __attribute__((constructor)) register_tasks(void) + { + legate_raft::RAFT_KNN_TASK::register_variants(); + } + +} // namespace \ No newline at end of file diff --git a/legate/__init__.py b/legate/__init__.py index 551c81b25d..dca0f13159 100644 --- a/legate/__init__.py +++ b/legate/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 NVIDIA Corporation +# Copyright 2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from pkgutil import extend_path - -__path__ = extend_path(__path__, __name__) diff --git a/legate/raft/__init__.py b/legate/raft/__init__.py index 28fb076aa3..e6494a0ab2 100644 --- a/legate/raft/__init__.py +++ b/legate/raft/__init__.py @@ -1,6 +1,22 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from .array_api import add, exp, fill, log, negative, subtract, sum_over_axis from .core import as_array, as_store, convert from .multiarray import bincount, categorize, matmul, multiply +from .knn import run_knn __all__ = [ "add", @@ -15,6 +31,7 @@ "matmul", "multiply", "negative", + "run_knn", "subtract", "sum_over_axis", ] diff --git a/legate/raft/knn.py b/legate/raft/knn.py new file mode 100644 index 0000000000..aaac35b1c9 --- /dev/null +++ b/legate/raft/knn.py @@ -0,0 +1,95 @@ +# Copyright 2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import legate.core.types as types
from legate.core import Rect
from .library import user_context as context
from .library import user_lib
from .core import array_to_store, store_to_array
import numpy as np

# Metrics understood by the RAFT_KNN task. The native wrapper throws
# "invalid metric" for anything else, so reject it before launching.
_SUPPORTED_METRICS = ("l2",)


def _validate_knn_inputs(index, search, n_neighbors, metric,
                         index_batch_size, query_batch_size):
    """Raise ``ValueError`` when the arguments cannot form a valid query."""
    if index.ndim != 2 or search.ndim != 2:
        raise ValueError(
            "index and search must be 2-D arrays, got "
            f"{index.ndim}-D and {search.ndim}-D"
        )
    if index.shape[1] != search.shape[1]:
        raise ValueError(
            "index and search should have the same number of features, "
            f"got {index.shape[1]} and {search.shape[1]}"
        )
    if 0 in index.shape or 0 in search.shape:
        raise ValueError("index and search must not be empty")
    if not 1 <= n_neighbors <= index.shape[0]:
        raise ValueError(
            f"n_neighbors must be in [1, {index.shape[0]}], "
            f"got {n_neighbors}"
        )
    if metric not in _SUPPORTED_METRICS:
        raise ValueError(
            f"unsupported metric {metric!r}; "
            f"expected one of {_SUPPORTED_METRICS}"
        )
    if index_batch_size < 1 or query_batch_size < 1:
        raise ValueError(
            "index_batch_size and query_batch_size must be positive, got "
            f"{index_batch_size} and {query_batch_size}"
        )


def run_knn(index: np.ndarray,
            search: np.ndarray,
            n_neighbors: int,
            metric: str = 'l2',
            *,
            index_batch_size: int = 512,
            query_batch_size: int = 8):
    """Run a brute-force k-nearest-neighbours query through Legate tasks.

    The index and search arrays are tiled along their rows, a ``RAFT_KNN``
    task produces per-tile candidate neighbours, and a ``RAFT_KNN_MERGE``
    task reduces the candidates into the final result.

    Parameters
    ----------
    index : np.ndarray
        2-D array of points to search in, one point per row.
    search : np.ndarray
        2-D array of query points with the same number of columns as
        ``index``.
    n_neighbors : int
        Number of neighbours to return per query row.
    metric : str
        Distance metric; only ``'l2'`` is accepted.
    index_batch_size : int
        Number of index rows per tile (keyword-only, default 512).
    query_batch_size : int
        Number of query rows per tile (keyword-only, default 8).

    Returns
    -------
    distances, indices
        The arrays obtained from the output stores, which are created with
        shape ``(search.shape[0], n_neighbors)`` and dtypes float32 and
        int64 respectively.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, are empty, disagree on the number of
        features, or if ``n_neighbors``, ``metric`` or a batch size is out
        of range.

    Notes
    -----
    NOTE(review): the KNN task is launched over ``Rect((n_parts, 1))`` where
    ``n_parts`` is the number of *index* tiles, and the merge task receives
    ``query_batch_size`` as its sample count although the output stores hold
    ``search.shape[0]`` rows. Confirm that every query tile is searched
    against every index tile and that rows beyond the first query tile are
    populated.
    """
    # The native kernels are instantiated for float32 data only, so coerce
    # up front (no copy when the input already is C-contiguous float32).
    index = np.ascontiguousarray(index, dtype=np.float32)
    search = np.ascontiguousarray(search, dtype=np.float32)
    _validate_knn_inputs(index, search, n_neighbors, metric,
                         index_batch_size, query_batch_size)

    n_features = index.shape[1]

    # Setup index store
    index_store = array_to_store(index)
    index_store = index_store.partition_by_tiling(
        (index_batch_size, n_features))

    # Setup search store
    search_store = array_to_store(search)
    search_store = search_store.partition_by_tiling(
        (query_batch_size, n_features))

    # Setup buffer stores: one (query_batch_size, n_neighbors) tile of
    # candidates per index tile.
    n_parts = index_store.partition.color_shape[0]
    buffer_shape = (n_parts * query_batch_size, n_neighbors)
    buffer_tile = (query_batch_size, n_neighbors)
    indices_buffer_store = array_to_store(
        np.zeros(buffer_shape, dtype=np.int64)
    ).partition_by_tiling(buffer_tile)
    distances_buffer_store = array_to_store(
        np.zeros(buffer_shape, dtype=np.float32)
    ).partition_by_tiling(buffer_tile)

    # Run KNN task
    nn_task = context.create_manual_task(user_lib.cffi.RAFT_KNN,
                                         launch_domain=Rect((n_parts, 1)))
    nn_task.add_scalar_arg(n_neighbors, types.int64)
    nn_task.add_scalar_arg(metric, types.string)
    nn_task.add_input(index_store)
    nn_task.add_input(search_store)
    nn_task.add_output(indices_buffer_store)
    nn_task.add_output(distances_buffer_store)
    nn_task.execute()

    # Gather buffer store partitions: round-trip through a copied array so
    # the merge task sees one unpartitioned store per buffer.
    indices_buffer_array = store_to_array(indices_buffer_store.store)
    indices_buffer_array = np.array(indices_buffer_array, copy=True)
    indices_buffer_gathered = array_to_store(indices_buffer_array)
    distances_buffer_array = store_to_array(distances_buffer_store.store)
    distances_buffer_array = np.array(distances_buffer_array, copy=True)
    distances_buffer_gathered = array_to_store(distances_buffer_array)

    # Setup output stores
    n_search_rows = search.shape[0]
    indices_output = np.zeros((n_search_rows, n_neighbors), dtype=np.int64)
    distances_output = np.zeros((n_search_rows, n_neighbors),
                                dtype=np.float32)
    indices_store = array_to_store(indices_output)
    distances_store = array_to_store(distances_output)

    # Run KNN merge task
    merge_task = context.create_manual_task(user_lib.cffi.RAFT_KNN_MERGE,
                                            launch_domain=Rect((1,)))
    merge_task.add_scalar_arg(query_batch_size, types.int64)
    merge_task.add_scalar_arg(n_parts, types.int64)
    merge_task.add_scalar_arg(n_neighbors, types.int64)

    merge_task.add_input(indices_buffer_gathered)
    merge_task.add_input(distances_buffer_gathered)
    merge_task.add_output(indices_store)
    merge_task.add_output(distances_store)
    merge_task.execute()

    # Produce output array
    indices = store_to_array(indices_store)
    distances = store_to_array(distances_store)

    return distances, indices


# ---------------------------------------------------------------------------
# pytest/test_knn.py (separate new file in the same patch)
# ---------------------------------------------------------------------------
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.neighbors import NearestNeighbors
from legate.raft import run_knn


def test_knn():
    """Check run_knn against scikit-learn's NearestNeighbors reference."""
    k = 8
    metric = 'l2'
    n_features = 20
    n_index_rows = 500
    n_search_rows = 16

    # Fixed seed so a failure can be reproduced from run to run.
    X, _ = make_blobs(n_samples=n_index_rows + n_search_rows,
                      centers=5, n_features=n_features, random_state=0)
    blob_index = X[:n_index_rows].astype(np.float32)
    blob_search = X[n_index_rows:].astype(np.float32)
    nn = NearestNeighbors(n_neighbors=k)
    nn.fit(blob_index)
    ref_distances, ref_indices = nn.kneighbors(blob_search,
                                               return_distance=True)

    distances, indices = run_knn(blob_index, blob_search, k, metric)
    # Indices are integers, so they must match exactly.
    np.testing.assert_array_equal(indices, ref_indices)
    np.testing.assert_allclose(distances, ref_distances, rtol=0.001)