Skip to content

Commit

Permalink
Merge pull request #1 from cjnolet/knn-work
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Apr 18, 2023
2 parents 98c5ece + 8c59346 commit 80d836d
Show file tree
Hide file tree
Showing 16 changed files with 409 additions and 162 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ __pycache__/
_skbuild/
build/
cpp/src/legate_library.cc
cpp/src/legate_library.h
cpp/src/legate_library.h.eggs/
dist/
legate.raft.egg-info/
Expand Down
49 changes: 0 additions & 49 deletions cpp/CMakeLists.txt

This file was deleted.

21 changes: 0 additions & 21 deletions cpp/cmake/thirdparty/fetch_rapids.cmake

This file was deleted.

62 changes: 0 additions & 62 deletions cpp/cmake/thirdparty/get_raft.cmake

This file was deleted.

4 changes: 4 additions & 0 deletions cpp/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ add_library(
# RAFT API Wrappers
raft/histogram.cu
raft/distance.cu
raft/raft_knn.cu
raft/raft_knn_merge.cu

# Legate tasks
task/add.cc
Expand All @@ -18,6 +20,8 @@ add_library(
task/fill.cc
task/find_max.cc
task/histogram.cc
task/knn_task.cc
task/knn_merge_task.cc
task/log.cc
task/matmul.cc
task/mul.cc
Expand Down
23 changes: 0 additions & 23 deletions cpp/src/legate_library.h

This file was deleted.

2 changes: 2 additions & 0 deletions cpp/src/legate_raft_cffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ enum LegateRaftOpCode {
LOG,
MATMUL,
MUL,
RAFT_KNN,
RAFT_KNN_MERGE,
SUM_OVER_AXIS,
};
22 changes: 20 additions & 2 deletions cpp/src/raft/raft_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,23 @@

#pragma once

void test_histogram();
void test_distance();
template<typename idx_t, typename value_t>
void raft_knn(idx_t n_index_rows,
idx_t n_search_rows,
idx_t n_features,
idx_t k,
std::string& metric,
const value_t* index_ptr,
const value_t* search_ptr,
idx_t* indices_ptr,
float* distances_ptr);


template<typename idx_t>
void raft_knn_merge(size_t n_samples,
int n_parts,
int k,
const idx_t* in_ind,
const float* in_dist,
idx_t* out_ind,
float* out_dist);
76 changes: 76 additions & 0 deletions cpp/src/raft/raft_knn.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "raft_api.hpp"

#include <cstdint>
#include <raft/core/device_resources.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force.cuh>

#ifdef RAFT_COMPILED
#include <raft/distance/specializations.cuh>
#endif


template<typename idx_t, typename value_t>
void raft_knn(idx_t n_index_rows,
idx_t n_search_rows,
idx_t n_features,
idx_t k,
std::string& metric,
const value_t* index_ptr,
const value_t* search_ptr,
idx_t* indices_ptr,
float* distances_ptr)
{
static raft::device_resources handle;

auto index_part = raft::make_device_matrix_view<const value_t, idx_t, raft::row_major>(index_ptr, n_index_rows, n_features);
auto search = raft::make_device_matrix_view<const value_t, idx_t, raft::row_major>(search_ptr, n_search_rows, n_features);
auto indices = raft::make_device_matrix_view<idx_t, idx_t, raft::row_major>(indices_ptr, n_search_rows, k);
auto distances = raft::make_device_matrix_view<value_t, idx_t, raft::row_major>(distances_ptr, n_search_rows, k);

std::vector<raft::device_matrix_view<const value_t, idx_t, raft::row_major>> index;
index.push_back(index_part);

raft::distance::DistanceType distance_type;
if (metric == "l2") {
distance_type = raft::distance::DistanceType::L2SqrtExpanded;
} else {
throw std::invalid_argument("invalid metric");
}
raft::neighbors::brute_force::knn(handle,
index,
search,
indices,
distances,
distance_type);
}


template void raft_knn(
int64_t,
int64_t,
int64_t,
int64_t,
std::string&,
const float*,
const float*,
int64_t*,
float*);
64 changes: 64 additions & 0 deletions cpp/src/raft/raft_knn_merge.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "raft_api.hpp"

#include <cstdint>
#include <raft/core/device_resources.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force.cuh>

#ifdef RAFT_COMPILED
#include <raft/distance/specializations.cuh>
#endif


template<typename idx_t>
void raft_knn_merge(size_t n_samples,
int n_parts,
int k,
const idx_t* in_ind,
const float* in_dist,
idx_t* out_ind,
float* out_dist)
{
static raft::device_resources handle;

auto in_keys = raft::device_matrix_view<const float, idx_t, raft::row_major>(in_dist, n_samples * n_parts, k);
auto in_values = raft::device_matrix_view<const idx_t, idx_t, raft::row_major>(in_ind, n_samples * n_parts, k);
auto out_keys = raft::device_matrix_view<float, idx_t, raft::row_major>(out_dist, n_samples, k);
auto out_values = raft::device_matrix_view<idx_t, idx_t, raft::row_major>(out_ind, n_samples, k);

raft::neighbors::brute_force::knn_merge_parts(handle,
in_keys,
in_values,
out_keys,
out_values,
n_samples);

}


template void raft_knn_merge(
size_t,
int,
int,
const int64_t*,
const float*,
int64_t*,
float*);
Loading

0 comments on commit 80d836d

Please sign in to comment.