Skip to content

Commit

Permalink
distributed index + answered part of the reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Apr 18, 2023
1 parent 646452c commit db803cb
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 39 deletions.
2 changes: 2 additions & 0 deletions cpp/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ add_library(

# RAFT API wrappers
raft/raft_knn.cu
raft/raft_knn_merge.cu

# Legate tasks
legate_tasks/knn_task.cc
legate_tasks/knn_merge_task.cc

# Library templates
${TEMPLATE_SOURCES}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/legate_raft_cffi.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
enum LegateRaftOpCode {
_OP_CODE_BASE = 0,
RAFT_KNN_OP = 1,
RAFT_KNN = 1,
RAFT_KNN_MERGE = 2,
};
2 changes: 1 addition & 1 deletion cpp/src/legate_tasks/compute_1nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ namespace // unnamed
legate_raft::Compute1NNTask::register_variants();
}

} // namespace
} // namespace
2 changes: 1 addition & 1 deletion cpp/src/legate_tasks/histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ namespace // unnamed
legate_raft::HistogramTask::register_variants();
}

} // namespace
} // namespace
47 changes: 47 additions & 0 deletions cpp/src/legate_tasks/knn_merge_task.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include <stdexcept>

#include "../raft/raft_api.hpp"
#include "../legate_raft.h"
#include "../legate_library.h"

namespace legate_raft {

class RAFT_KNN_MERGE_TASK : public Task<RAFT_KNN_MERGE_TASK, RAFT_KNN_MERGE> {
public:
static void gpu_variant(legate::TaskContext& context)
{
size_t n_samples = context.scalars()[0].value<int64_t>();
int n_parts = context.scalars()[1].value<int64_t>();
int k = context.scalars()[2].value<int64_t>();

auto& in_ind = context.inputs()[0];
auto& in_dist = context.inputs()[1];
auto& out_ind = context.outputs()[0];
auto& out_dist = context.outputs()[1];

auto in_ind_read = in_ind.read_accessor<int64_t, 2>().ptr(Legion::DomainPoint(0));
auto in_dist_read = in_dist.read_accessor<float, 2>().ptr(Legion::DomainPoint(0));
auto out_ind_write = out_ind.write_accessor<int64_t, 2>().ptr(Legion::DomainPoint(0));
auto out_dist_write = out_dist.write_accessor<float, 2>().ptr(Legion::DomainPoint(0));

raft_knn_merge(n_samples,
n_parts,
k,
in_ind_read,
in_dist_read,
out_ind_write,
out_dist_write);
}
};

} // namespace legate_raft

namespace // unnamed
{

static void __attribute__((constructor)) register_tasks(void)
{
legate_raft::RAFT_KNN_MERGE_TASK::register_variants();
}

} // namespace
6 changes: 4 additions & 2 deletions cpp/src/legate_tasks/knn_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

namespace legate_raft {

class RAFT_KNN_TASK : public Task<RAFT_KNN_TASK, RAFT_KNN_OP> {
class RAFT_KNN_TASK : public Task<RAFT_KNN_TASK, RAFT_KNN> {
public:
static void gpu_variant(legate::TaskContext& context)
{
int64_t k = context.scalars()[0].value<int64_t>();
int64_t k = context.scalars()[0].value<int64_t>(); // number of nearest neighbors
std::string metric = context.scalars()[1].value<std::string>();

auto& index = context.inputs()[0];
Expand All @@ -25,6 +25,8 @@ namespace legate_raft {
throw std::invalid_argument("index and search should have the same number of features");
}

// The offset of the current partition from the start of the store
// is used to obtain the pointer to the start of the partition.
uint64_t offset = search.shape<2>().lo[0];
auto index_read = index.read_accessor<float, 2>().ptr(Legion::DomainPoint(0));
auto search_read = search.read_accessor<float, 2>().ptr(Legion::DomainPoint(offset));
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/raft/raft_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,13 @@ void raft_knn(idx_t n_index_rows,
const value_t* search_ptr,
idx_t* indices_ptr,
float* distances_ptr);


template<typename idx_t>
void raft_knn_merge(size_t n_samples,
int n_parts,
int k,
const idx_t* in_ind,
const float* in_dist,
idx_t* out_ind,
float* out_dist);
12 changes: 6 additions & 6 deletions cpp/src/raft/raft_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void raft_knn(idx_t n_index_rows,
idx_t* indices_ptr,
float* distances_ptr)
{
raft::device_resources handle;
static raft::device_resources handle;

auto index_part = raft::make_device_matrix_view<const value_t, idx_t, raft::row_major>(index_ptr, n_index_rows, n_features);
auto search = raft::make_device_matrix_view<const value_t, idx_t, raft::row_major>(search_ptr, n_search_rows, n_features);
Expand All @@ -56,11 +56,11 @@ void raft_knn(idx_t n_index_rows,
throw std::invalid_argument("invalid metric");
}
raft::neighbors::brute_force::knn(handle,
index,
search,
indices,
distances,
distance_type);
index,
search,
indices,
distances,
distance_type);
}


Expand Down
64 changes: 64 additions & 0 deletions cpp/src/raft/raft_knn_merge.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "raft_api.hpp"

#include <cstdint>
#include <raft/core/device_resources.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force.cuh>

#ifdef RAFT_COMPILED
#include <raft/distance/specializations.cuh>
#endif


template<typename idx_t>
void raft_knn_merge(size_t n_samples,
int n_parts,
int k,
const idx_t* in_ind,
const float* in_dist,
idx_t* out_ind,
float* out_dist)
{
static raft::device_resources handle;

auto in_keys = raft::device_matrix_view<const float, idx_t, raft::row_major>(in_dist, n_samples * n_parts, k);
auto in_values = raft::device_matrix_view<const idx_t, idx_t, raft::row_major>(in_ind, n_samples * n_parts, k);
auto out_keys = raft::device_matrix_view<float, idx_t, raft::row_major>(out_dist, n_samples, k);
auto out_values = raft::device_matrix_view<idx_t, idx_t, raft::row_major>(out_ind, n_samples, k);

raft::neighbors::brute_force::knn_merge_parts(handle,
in_keys,
in_values,
out_keys,
out_values,
n_samples);

}


template void raft_knn_merge(
size_t,
int,
int,
const int64_t*,
const float*,
int64_t*,
float*);
6 changes: 1 addition & 5 deletions legate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 NVIDIA Corporation
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pkgutil import extend_path

__path__ = extend_path(__path__, __name__)
19 changes: 18 additions & 1 deletion legate/raft/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
from .knn import run_knn
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .knn import run_knn

__all__ = ["run_knn"]
15 changes: 15 additions & 0 deletions legate/raft/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
import pyarrow as pa
from legate.core import Store
Expand Down
96 changes: 74 additions & 22 deletions legate/raft/knn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import legate.core.types as types
from legate.core import Rect
from .library import user_context as context
Expand All @@ -6,36 +21,73 @@
import numpy as np


def run_knn(index, search, n_neighbors, metric='l2'):
batch_size = 8
def run_knn(index: np.ndarray,
search: np.ndarray,
n_neighbors: int,
metric : str ='l2'):
index_batch_size = 512
query_batch_size = 8
n_features = index.shape[1]

# Setup input partitions
# Setup index store
index_store = array_to_store(index)
index_store = index_store.partition_by_tiling((index_batch_size, n_features))

# Setup search store
search_store = array_to_store(search)
n_features = search.shape[1]
search_store = search_store.partition_by_tiling((batch_size, n_features))
search_store = search_store.partition_by_tiling((query_batch_size, n_features))

# Setup buffer stores
n_parts = index_store.partition.color_shape[0]
buffer_size = n_parts * query_batch_size
indices_buffer_array = np.zeros((buffer_size, n_neighbors), dtype=np.int64)
distances_buffer_array = np.zeros((buffer_size, n_neighbors), dtype=np.float32)
indices_buffer_store = array_to_store(indices_buffer_array)
distances_buffer_store = array_to_store(distances_buffer_array)
indices_buffer_store = indices_buffer_store.partition_by_tiling((query_batch_size, n_neighbors))
distances_buffer_store = distances_buffer_store.partition_by_tiling((query_batch_size, n_neighbors))

# Run KNN task
nn_task = context.create_manual_task(user_lib.cffi.RAFT_KNN,
launch_domain=Rect((n_parts, 1)))
nn_task.add_scalar_arg(n_neighbors, types.int64)
nn_task.add_scalar_arg(metric, types.string)
nn_task.add_input(index_store)
nn_task.add_input(search_store)
nn_task.add_output(indices_buffer_store)
nn_task.add_output(distances_buffer_store)
nn_task.execute()

# Setup output partitions
# Gather buffer store partitions
indices_buffer_array = store_to_array(indices_buffer_store.store)
indices_buffer_array = np.array(indices_buffer_array, copy=True)
indices_buffer_gathered = array_to_store(indices_buffer_array)
distances_buffer_array = store_to_array(distances_buffer_store.store)
distances_buffer_array = np.array(distances_buffer_array, copy=True)
distances_buffer_gathered = array_to_store(distances_buffer_array)

# Setup output stores
n_search_rows = search.shape[0]
indices_output = np.zeros((n_search_rows, n_neighbors), dtype=np.int64)
distances_output = np.zeros((n_search_rows, n_neighbors), dtype=np.float32)
indices_store = array_to_store(indices_output)
distances_store = array_to_store(distances_output)
indices_store = indices_store.partition_by_tiling((batch_size, n_neighbors))
distances_store = distances_store.partition_by_tiling((batch_size, n_neighbors))

launch_shape = search_store.partition.color_shape
task = context.create_manual_task(user_lib.cffi.RAFT_KNN_OP,
launch_domain=Rect(launch_shape))
task.add_scalar_arg(n_neighbors, types.int64)
task.add_scalar_arg(metric, types.string)
task.add_input(index_store)
task.add_input(search_store)
task.add_output(indices_store)
task.add_output(distances_store)
task.execute()

indices = store_to_array(indices_store.store)
distances = store_to_array(distances_store.store)

# Run KNN merge task
merge_task = context.create_manual_task(user_lib.cffi.RAFT_KNN_MERGE,
launch_domain=Rect((1,)))
merge_task.add_scalar_arg(query_batch_size, types.int64)
merge_task.add_scalar_arg(n_parts, types.int64)
merge_task.add_scalar_arg(n_neighbors, types.int64)

merge_task.add_input(indices_buffer_gathered)
merge_task.add_input(distances_buffer_gathered)
merge_task.add_output(indices_store)
merge_task.add_output(distances_store)
merge_task.execute()

# Produce output array
indices = store_to_array(indices_store)
distances = store_to_array(distances_store)

return distances, indices

0 comments on commit db803cb

Please sign in to comment.