Skip to content

Commit

Permalink
Add Naive Bayes example prototype.
Browse files Browse the repository at this point in the history
  • Loading branch information
csadorf committed Apr 17, 2023
1 parent dd587cf commit afd215e
Show file tree
Hide file tree
Showing 41 changed files with 2,472 additions and 159 deletions.
14 changes: 12 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
build/
.eggs/
.idea

__pycache__/
_skbuild/
build/
cpp/src/legate_library.cc
cpp/src/legate_library.h.eggs/
dist/
legate.raft.egg-info/
legate/raft/__pycache__/
legate/raft/install_info.py
legate/raft/library.py
pytest/__pycache__
30 changes: 30 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
# Copyright (c) 2023, NVIDIA CORPORATION.

repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
exclude: 'legate/naive_bayes/library\.py$'
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: 'legate/naive_bayes/library\.py$'
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
- id: codespell
additional_dependencies: [tomli]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
- id: nbstripout

default_language_version:
python: python3
46 changes: 46 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR)

# ------------- configure rapids-cmake --------------#

include(cmake/thirdparty/fetch_rapids.cmake)
include(rapids-cmake)
include(rapids-cpm)
include(rapids-cuda)
include(rapids-export)
include(rapids-find)

# ------------- configure project -------------- #

rapids_cuda_init_architectures(legate_raft)
project(legate_raft LANGUAGES C CXX CUDA)

# ------------- configure raft ----------------- #

rapids_cpm_init()
include(cmake/thirdparty/get_raft.cmake)

# -------------- add requirements -------------- #

find_package(legate_core REQUIRED)
set(BUILD_SHARED_LIBS ON)

# -------------- compile tasks ----------------- #

# C++ layer
legate_add_cpp_subdirectory(cpp/src TARGET legate_raft EXPORT legate_raft-export)

# Python layer
add_subdirectory(legate)
18 changes: 16 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Abort script on first error
set -e

INSTALL_PREFIX=${INSTALL_PREFIX:=${PREFIX:=${CONDA_PREFIX}}}

PARALLEL_LEVEL=${PARALLEL_LEVEL:=`nproc`}

BUILD_TYPE=Release
Expand All @@ -24,6 +26,13 @@ fi

if [ "$1" == "clean" ]; then
rm -rf cpp/build
rm -rf dist legate.raft.egg-info
rm cpp/src/legate_library.cc
rm cpp/src/legate_library.h
python setup.py clean --all
rm legate/raft/install_info.py
rm legate/raft/library.py
rm -rf pytest/__pycache__
exit 0
fi

Expand All @@ -35,7 +44,12 @@ cmake \
-DRAFT_NVTX=OFF \
-DCMAKE_CUDA_ARCHITECTURES="NATIVE" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \
${EXTRA_CMAKE_ARGS} \
../
../../

cmake --build . -j${PARALLEL_LEVEL}
cmake --install . --prefix ${INSTALL_PREFIX}

cmake --build . -j${PARALLEL_LEVEL}
cd ../..
python setup.py install
21 changes: 21 additions & 0 deletions cmake/thirdparty/fetch_rapids.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# Use this variable to update RAPIDS and RAFT versions
set(RAPIDS_VERSION "23.04")

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake
${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake)
endif()
include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake)
62 changes: 62 additions & 0 deletions cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake
set(RAFT_VERSION "${RAPIDS_VERSION}")
set(RAFT_FORK "rapidsai")
set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}")

function(find_and_configure_raft)
set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

set(RAFT_COMPONENTS "")
if(PKG_COMPILE_LIBRARY)
string(APPEND RAFT_COMPONENTS " compiled")
endif()

if(PKG_ENABLE_MNMG_DEPENDENCIES)
string(APPEND RAFT_COMPONENTS " distributed")
endif()

#-----------------------------------------------------
# Invoke CPM find_package()
#-----------------------------------------------------
rapids_cpm_find(raft ${PKG_VERSION}
GLOBAL_TARGETS raft::raft
BUILD_EXPORT_SET raft-template-exports
INSTALL_EXPORT_SET raft-template-exports
COMPONENTS ${RAFT_COMPONENTS}
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
GIT_TAG ${PKG_PINNED_TAG}
SOURCE_SUBDIR cpp
OPTIONS
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
"RAFT_NVTX ${ENABLE_NVTX}"
"RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}"
)
endfunction()

# Change pinned tag here to test a commit in CI
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
FORK ${RAFT_FORK}
PINNED_TAG ${RAFT_PINNED_TAG}
COMPILE_LIBRARY ON
ENABLE_MNMG_DEPENDENCIES OFF
ENABLE_NVTX OFF
)
42 changes: 0 additions & 42 deletions cpp/legate_raft/install_info.py

This file was deleted.

Loading

0 comments on commit afd215e

Please sign in to comment.