Skip to content

Commit

Permalink
merge main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Apr 18, 2023
2 parents db803cb + 98c5ece commit 8c59346
Show file tree
Hide file tree
Showing 40 changed files with 2,032 additions and 82 deletions.
11 changes: 7 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
build/
.eggs/
.idea
__pycache__/
_skbuild/
build/
cpp/src/legate_library.cc
cpp/src/legate_library.h
cpp/src/legate_library.h.eggs/
dist/
legate.raft.egg-info/
legate/raft/__pycache__/
legate/raft/library.py
legate/raft/install_info.py
legate/raft/library.py
pytest/__pycache__
cpp/src/legate_library.cc
cpp/src/legate_library.h
30 changes: 30 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
# Copyright (c) 2023, NVIDIA CORPORATION.

repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
exclude: 'legate/naive_bayes/library\.py$'
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: 'legate/naive_bayes/library\.py$'
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
- id: codespell
additional_dependencies: [tomli]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
- id: nbstripout

default_language_version:
python: python3
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mamba install -c conda-forge -c rapidsai-nightly -c nvidia libraft==23.04*

- Can potentially copy in `bincount` from legate.core or the `raft::stats::histogram` or other RAFT primitives.
- maybe we could try to use pylibraft as a legate task if they make progress on exposing legate tasks through Python UDFs?
- Interested in an end-to-end prototype that uses the underlying indices array from the legate.sparse csr w/ bincount/histogram primitive.

### K-NN

Expand Down
22 changes: 19 additions & 3 deletions cpp/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,29 @@ legate_cpp_library_template(legate_raft TEMPLATE_SOURCES)
add_library(
legate_raft

# RAFT API wrappers
# RAFT API Wrappers
raft/histogram.cu
raft/distance.cu
raft/raft_knn.cu
raft/raft_knn_merge.cu

# Legate tasks
legate_tasks/knn_task.cc
legate_tasks/knn_merge_task.cc
task/add.cc
# task/compute_1nn.cc
task/add_constant.cc
task/bincount.cc
task/categorize.cc
task/convert.cc
task/exp.cc
task/fill.cc
task/find_max.cc
task/histogram.cc
task/knn_task.cc
task/knn_merge_task.cc
task/log.cc
task/matmul.cc
task/mul.cc
task/sum_over_axis.cc

# Library templates
${TEMPLATE_SOURCES}
Expand Down
20 changes: 17 additions & 3 deletions cpp/src/legate_raft_cffi.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
enum LegateRaftOpCode {
_OP_CODE_BASE = 0,
RAFT_KNN = 1,
RAFT_KNN_MERGE = 2,
};
ADD,
ADD_CONSTANT,
BINCOUNT,
CATEGORIZE,
CONVERT,
EXP,
FILL,
FIND_MAX,
FUSED_1NN,
HISTOGRAM,
LOG,
MATMUL,
MUL,
RAFT_KNN,
RAFT_KNN_MERGE,
SUM_OVER_AXIS,
};
27 changes: 0 additions & 27 deletions cpp/src/legate_tasks/histogram.cc

This file was deleted.

72 changes: 72 additions & 0 deletions cpp/src/task/add.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include <cmath>

#include "legate_library.h"
#include "legate_raft_cffi.h"

#include "core/utilities/dispatch.h"

namespace legate_raft {

namespace {

struct add_fn {
template <legate::LegateTypeCode CODE, int32_t DIM>
void operator()(legate::Store& output, legate::Store& x1, legate::Store& x2)
{
using VAL = legate::legate_type_of<CODE>;

auto shape = x1.shape<DIM>();

if (shape.empty()) return;

auto x1_acc = x1.read_accessor<VAL, DIM>();
auto x2_acc = x2.read_accessor<VAL, DIM>();
auto output_acc = output.write_accessor<VAL, DIM>();

for (legate::PointInRectIterator<DIM> it(shape, false /*fortran order*/); it.valid(); ++it) {
auto p = *it;
output_acc[p] = x1_acc[p] + x2_acc[p];
}
}
};

} // namespace

class AddTask : public Task<AddTask, ADD> {
public:
static void cpu_variant(legate::TaskContext& context)
{
auto& input1 = context.inputs()[0];
auto& input2 = context.inputs()[1];
auto& output = context.outputs()[0];

legate::double_dispatch(input1.dim(), input1.code(), add_fn{}, output, input1, input2);
}
};

} // namespace legate_raft

namespace {

static void __attribute__((constructor)) register_tasks()
{
legate_raft::AddTask::register_variants();
}

} // namespace
69 changes: 69 additions & 0 deletions cpp/src/task/add_constant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "legate_library.h"
#include "legate_raft_cffi.h"

#include "core/utilities/dispatch.h"

namespace legate_raft {

namespace {

struct add_constant_fn {
template <legate::LegateTypeCode CODE, int32_t DIM>
void operator()(legate::Store& output, legate::Store& input, legate::Scalar& value)
{
using VAL = legate::legate_type_of<CODE>;

auto shape = input.shape<DIM>();

if (shape.empty()) return;

auto input_acc = input.read_accessor<VAL, DIM>();
auto output_acc = output.write_accessor<VAL, DIM>();

for (legate::PointInRectIterator<DIM> it(shape, false /*fortran order*/); it.valid(); ++it) {
auto p = *it;
output_acc[p] = input_acc[p] + value.value<VAL>();
}
}
};

} // namespace

class AddConstantTask : public Task<AddConstantTask, ADD_CONSTANT> {
public:
static void cpu_variant(legate::TaskContext& context)
{
auto& input = context.inputs()[0];
auto& value = context.scalars()[0];
auto& output = context.outputs()[0];

legate::double_dispatch(input.dim(), input.code(), add_constant_fn{}, output, input, value);
}
};

} // namespace legate_raft

namespace {

static void __attribute__((constructor)) register_tasks()
{
legate_raft::AddConstantTask::register_variants();
}

} // namespace
55 changes: 55 additions & 0 deletions cpp/src/task/bincount.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/* Copyright 2023 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "legate_library.h"
#include "legate_raft_cffi.h"

#include "core/utilities/dispatch.h"

namespace legate_raft {

class BincountTask : public Task<BincountTask, BINCOUNT> {
public:
static void cpu_variant(legate::TaskContext& context)
{
auto& input = context.inputs()[0];
auto& output = context.reductions()[0];

auto in_shape = input.shape<1>();
auto out_shape = output.shape<1>();

auto in_acc = input.read_accessor<uint64_t, 1>();
auto out_acc = output.reduce_accessor<legate::SumReduction<uint64_t>, true, 1>();

for (legate::PointInRectIterator<1> it(in_shape); it.valid(); ++it) {
auto& value = in_acc[*it];
legate::Point<1> pos_reduce(static_cast<int64_t>(value));

if (out_shape.contains(pos_reduce)) out_acc.reduce(pos_reduce, 1);
}
}
};

} // namespace legate_raft

namespace {

static void __attribute__((constructor)) register_tasks()
{
legate_raft::BincountTask::register_variants();
}

} // namespace
Loading

0 comments on commit 8c59346

Please sign in to comment.