[REVIEW] Migrating cuml comms to raft::comms #7

Merged · 51 commits · Jun 3, 2020
Commits
fd38dbb
Initial commit of raft comms
cjnolet May 7, 2020
8a6e054
Removing cuml verbiage
cjnolet May 7, 2020
2bfc0dc
Renaming to comms_t, removing logs for now, adding to handle
cjnolet May 8, 2020
f129710
Looking like it's building!
cjnolet May 8, 2020
a0c13af
Adding NCCL and UCX build for tests
cjnolet May 8, 2020
7ad2053
Merge branch 'fea-ext-migrate-cumlHandle_impl-into-raft' into fea-ext…
cjnolet May 8, 2020
cb19482
Consolidating injection functions
cjnolet May 8, 2020
1d66637
Merge branch 'fea-ext-migrate-cumlHandle_impl-into-raft' into fea-ext…
cjnolet May 12, 2020
1548177
Making progress on Python. Need to refactor the comms interface a li…
cjnolet May 12, 2020
5d37e9e
Cython is building!
cjnolet May 13, 2020
bd1b876
The comms tests pass!!!
cjnolet May 13, 2020
98496d3
Fixing flake8 style
cjnolet May 13, 2020
fb52f79
Running clang format and removing getDataType
cjnolet May 13, 2020
d666314
Merge remote-tracking branch 'github/branch-0.14' into fea-ext-comms
cjnolet May 13, 2020
46a27eb
Cleaning up
cjnolet May 13, 2020
fc6ba2e
Fixing python style
cjnolet May 13, 2020
842f533
Fixing cpp style
cjnolet May 13, 2020
43f0d69
Adding copyright headers
cjnolet May 13, 2020
ac6e699
Adding init py for tests
cjnolet May 13, 2020
c442d87
Adding license headers and consistent namespacing
cjnolet May 13, 2020
32e63c1
More cleanup
cjnolet May 13, 2020
d377347
Cleaning up raft.dask.common.Comms
cjnolet May 13, 2020
e1b4ea7
Ignoring raft egg artifacts
cjnolet May 13, 2020
736696b
Cleaning up raft dask utils
cjnolet May 13, 2020
55d9dfd
More cleanup, copyright headers, and docs
cjnolet May 13, 2020
6eb0325
Removing the last of references to cuml
cjnolet May 13, 2020
b65e20d
Fixing python style
cjnolet May 13, 2020
7a845cb
Fixing c++ style
cjnolet May 13, 2020
c269300
Updating changelog
cjnolet May 13, 2020
d4aa5c5
Testing non-ucx cluster for pytests
cjnolet May 13, 2020
1ee1363
Implementing review feedback
cjnolet May 18, 2020
c165a36
More review feedback
cjnolet May 19, 2020
12f3db7
Fixing style
cjnolet May 19, 2020
aae4625
FIX Use relative imports
dantegd May 19, 2020
bc39321
Adding compile-time templates for comms_t to make interaction more st…
cjnolet May 19, 2020
e9c0995
Merge pull request #1 from dantegd/fea-ext-comms-pr1
cjnolet May 19, 2020
3150fbd
Using std::this_thread::yield instead of pthread_yield()
cjnolet May 19, 2020
5628ad2
Adding python tests for collective functions
cjnolet May 19, 2020
488d0d5
Running cpp style
cjnolet May 19, 2020
3d362d0
Updating tabbing for pytests
cjnolet May 19, 2020
417a4bd
Following clang tidy standards
cjnolet May 19, 2020
cb92349
Moving get_type out of comms_t
cjnolet May 19, 2020
6e9025d
More review feedback
cjnolet May 26, 2020
28f8101
Running cpp style check
cjnolet May 26, 2020
1a18553
Nccl red op
cjnolet May 27, 2020
a22dd62
Raising an exception to get around gcc issue
cjnolet May 27, 2020
fbd12aa
Using static for functions for now
cjnolet May 27, 2020
cc71ccb
Fixing style
cjnolet May 27, 2020
dc123b2
Fixing more relative imports
cjnolet May 28, 2020
60597f3
final updates based on feedback
cjnolet Jun 3, 2020
5bbb8be
Merge branch 'branch-0.15' into fea-ext-comms
afender Jun 3, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -23,6 +23,7 @@ log
.ipynb_checkpoints
.DS_Store
dask-worker-space/
*.egg-info/
## eclipse
.project
.cproject
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@
## New Features
- Initial RAFT version
- PR #3: defining raft::handle_t, device_buffer, host_buffer, allocator classes
- PR #7: Migrating cuml comms -> raft comms_t

## Improvements

1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -161,6 +161,7 @@ set(CMAKE_CUDA_FLAGS
# - dependencies -------------------------------------------------------------

include(cmake/Dependencies.cmake)
include(cmake/comms.cmake)

##############################################################################
# - include paths ------------------------------------------------------------
35 changes: 35 additions & 0 deletions cpp/cmake/comms.cmake
@@ -0,0 +1,35 @@
#
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(comms LANGUAGES CXX CUDA)

if(NOT NCCL_PATH)
find_package(NCCL REQUIRED)
else()
message("-- Manually set NCCL PATH to ${NCCL_PATH}")
set(NCCL_INCLUDE_DIRS ${NCCL_PATH}/include)
set(NCCL_LIBRARIES ${NCCL_PATH}/lib/libnccl.so)
endif(NOT NCCL_PATH)
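
# An illustrative configure-time override (a sketch; assumes NCCL was
# installed under /opt/nccl rather than a default prefix):
#   cmake -DNCCL_PATH=/opt/nccl <path-to-source>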

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(UCX)
include_directories(${UCX_INCLUDE_DIRS})

include_directories( ${NCCL_INCLUDE_DIRS} )
list(APPEND RAFT_LINK_LIBRARIES ${NCCL_LIBRARIES})
329 changes: 329 additions & 0 deletions cpp/include/raft/comms/comms.hpp
@@ -0,0 +1,329 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cudart_utils.h>
#include <memory>

namespace raft {
namespace comms {

typedef unsigned int request_t;
enum class datatype_t {
CHAR,
UINT8,
INT32,
UINT32,
INT64,
UINT64,
FLOAT32,
FLOAT64
};
enum class op_t { SUM, PROD, MIN, MAX };

/**
* The resulting status of distributed stream synchronization
*/
enum class status_t {
SUCCESS, // Synchronization successful
ERROR, // An error occurred querying sync status
ABORT // A failure occurred in sync, queued operations aborted
};

template <typename value_t>
constexpr datatype_t get_type();

template <>
constexpr datatype_t get_type<char>() {
return datatype_t::CHAR;
}

template <>
constexpr datatype_t get_type<uint8_t>() {
return datatype_t::UINT8;
}

template <>
constexpr datatype_t get_type<int>() {
return datatype_t::INT32;
}

template <>
constexpr datatype_t get_type<uint32_t>() {
return datatype_t::UINT32;
}

template <>
constexpr datatype_t get_type<int64_t>() {
return datatype_t::INT64;
}

template <>
constexpr datatype_t get_type<uint64_t>() {
return datatype_t::UINT64;
}

template <>
constexpr datatype_t get_type<float>() {
return datatype_t::FLOAT32;
}

template <>
constexpr datatype_t get_type<double>() {
return datatype_t::FLOAT64;
}

class comms_iface {
public:
virtual ~comms_iface();

virtual int get_size() const = 0;
virtual int get_rank() const = 0;

virtual std::unique_ptr<comms_iface> comm_split(int color, int key) const = 0;
virtual void barrier() const = 0;

virtual status_t sync_stream(cudaStream_t stream) const = 0;

virtual void isend(const void* buf, size_t size, int dest, int tag,
request_t* request) const = 0;

virtual void irecv(void* buf, size_t size, int source, int tag,
request_t* request) const = 0;

virtual void waitall(int count, request_t array_of_requests[]) const = 0;

virtual void allreduce(const void* sendbuff, void* recvbuff, size_t count,
datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;

virtual void bcast(void* buff, size_t count, datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void reduce(const void* sendbuff, void* recvbuff, size_t count,
datatype_t datatype, op_t op, int root,
cudaStream_t stream) const = 0;

virtual void allgather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, cudaStream_t stream) const = 0;

virtual void allgatherv(const void* sendbuf, void* recvbuf,
const size_t recvcounts[], const int displs[],
datatype_t datatype, cudaStream_t stream) const = 0;

virtual void reducescatter(const void* sendbuff, void* recvbuff,
size_t recvcount, datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;
};

class comms_t {
public:
comms_t(std::unique_ptr<comms_iface> impl) : impl_(impl.release()) {
ASSERT(nullptr != impl_.get(), "ERROR: Invalid comms_iface used!");
}

/**
* Returns the size of the communicator clique
*/
int get_size() const { return impl_->get_size(); }

/**
* Returns the local rank
*/
int get_rank() const { return impl_->get_rank(); }

/**
* Splits the current communicator clique into sub-cliques matching
* the given color and key
*
* @param color ranks w/ the same color are placed in the same communicator
* @param key controls rank assignment
*/
std::unique_ptr<comms_iface> comm_split(int color, int key) const {
return impl_->comm_split(color, key);
}
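
// A brief sketch (assuming a constructed comms_t named `comms`): even and odd
// ranks split into separate sub-cliques, ordered within each by current rank.
//
//   auto subcomm = comms.comm_split(comms.get_rank() % 2, comms.get_rank());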

/**
* Performs a collective barrier synchronization
*/
void barrier() const { impl_->barrier(); }

/**
* Some collective communication implementations (e.g. NCCL) may use asynchronous
* collectives that must be synchronized explicitly. Always synchronize through this
* method, rather than `cudaStreamSynchronize()`, so that failures can propagate
* across the clique instead of causing deadlocks.
*
* @param stream the cuda stream to sync collective operations on
*/
status_t sync_stream(cudaStream_t stream) const {
return impl_->sync_stream(stream);
}

/**
* Performs an asynchronous point-to-point send
* @tparam value_t the type of data to send
* @param buf pointer to array of data to send
* @param size number of elements in buf
* @param dest destination rank
* @param tag a tag to use for the receiver to filter
* @param request pointer to hold returned request_t object.
* This will be used in `waitall()` to synchronize until the message is delivered (or fails).
*/
template <typename value_t>
void isend(const value_t* buf, size_t size, int dest, int tag,
request_t* request) const {
impl_->isend(static_cast<const void*>(buf), size * sizeof(value_t), dest,
tag, request);
}

/**
* Performs an asynchronous point-to-point receive
* @tparam value_t the type of data to be received
* @param buf pointer to (initialized) array that will hold received data
* @param size number of elements in buf
* @param source source rank
* @param tag a tag to use for message filtering
* @param request pointer to hold returned request_t object.
* This will be used in `waitall()` to synchronize until the message is delivered (or fails).
*/
template <typename value_t>
void irecv(value_t* buf, size_t size, int source, int tag,
request_t* request) const {
impl_->irecv(static_cast<void*>(buf), size * sizeof(value_t), source, tag,
request);
}

/**
* Synchronize on an array of request_t objects returned from isend/irecv
* @param count number of requests to synchronize on
* @param array_of_requests an array of request_t objects returned from isend/irecv
*/
void waitall(int count, request_t array_of_requests[]) const {
impl_->waitall(count, array_of_requests);
}
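
// A minimal ring-exchange sketch (assuming `comms` is a constructed comms_t
// and `sendbuf`/`recvbuf` are caller-owned device buffers of `n` floats):
//
//   int rank = comms.get_rank(), size = comms.get_size();
//   request_t reqs[2];
//   comms.isend(sendbuf, n, (rank + 1) % size, 0, &reqs[0]);
//   comms.irecv(recvbuf, n, (rank + size - 1) % size, 0, &reqs[1]);
//   comms.waitall(2, reqs);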

/**
* Perform an allreduce collective
* @tparam value_t datatype of underlying buffers
* @param sendbuff data to reduce
* @param recvbuff buffer to hold the reduced result
* @param count number of elements in sendbuff
* @param op reduction operation to perform
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void allreduce(const value_t* sendbuff, value_t* recvbuff, size_t count,
op_t op, cudaStream_t stream) const {
impl_->allreduce(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), count, get_type<value_t>(),
op, stream);
}

/**
* Broadcast data from one rank to the rest
* @tparam value_t datatype of underlying buffers
* @param buff buffer to send
* @param count number of elements in buff
* @param root the rank initiating the broadcast
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void bcast(value_t* buff, size_t count, int root, cudaStream_t stream) const {
impl_->bcast(static_cast<void*>(buff), count, get_type<value_t>(), root,
stream);
}

/**
* Reduce data from many ranks down to a single rank
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to reduce
* @param recvbuff buffer containing reduced data (only needs to be initialized on root)
* @param count number of elements in sendbuff
* @param op reduction operation to perform
* @param root rank to store the results
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void reduce(const value_t* sendbuff, value_t* recvbuff, size_t count, op_t op,
int root, cudaStream_t stream) const {
impl_->reduce(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), count, get_type<value_t>(), op,
root, stream);
}

/**
* Gathers data from each rank onto all ranks
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to gather
* @param recvbuff buffer containing gathered data from all ranks
* @param sendcount number of elements in send buffer
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void allgather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount,
cudaStream_t stream) const {
impl_->allgather(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), sendcount,
get_type<value_t>(), stream);
}

/**
* Gathers data from all ranks and delivers the combined data to all ranks
* @tparam value_t datatype of underlying buffers
* @param sendbuf buffer containing data to send
* @param recvbuf buffer that will hold the gathered data
* @param recvcounts array (of length num_ranks) containing the number of elements
* to receive from each rank
* @param displs array (of length num_ranks) specifying the displacement (relative to recvbuf)
* at which to place the incoming data from each rank
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void allgatherv(const value_t* sendbuf, value_t* recvbuf,
const size_t recvcounts[], const int displs[],
cudaStream_t stream) const {
impl_->allgatherv(static_cast<const void*>(sendbuf),
static_cast<void*>(recvbuf), recvcounts, displs,
get_type<value_t>(), stream);
}
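
// A short sketch of building recvcounts/displs (assuming every rank knows
// each rank's contribution via a host array `counts` of length
// comms.get_size(), and `sendbuf`/`recvbuf` are caller-owned device buffers):
//
//   std::vector<size_t> recvcounts(counts, counts + comms.get_size());
//   std::vector<int> displs(comms.get_size(), 0);
//   for (int i = 1; i < comms.get_size(); ++i)
//     displs[i] = displs[i - 1] + static_cast<int>(recvcounts[i - 1]);
//   comms.allgatherv(sendbuf, recvbuf, recvcounts.data(), displs.data(), stream);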

/**
* Reduces data from all ranks then scatters the result across ranks
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to send (size recvcount * num_ranks)
* @param recvbuff buffer containing received data
* @param recvcount number of elements each rank receives into recvbuff
* @param op reduction operation to perform
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void reducescatter(const value_t* sendbuff, value_t* recvbuff,
size_t recvcount, op_t op, cudaStream_t stream) const {
impl_->reducescatter(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), recvcount,
get_type<value_t>(), op, stream);
}

private:
std::unique_ptr<comms_iface> impl_;
};

comms_iface::~comms_iface() {}

} // namespace comms
} // namespace raft
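
A minimal usage sketch of the synchronization pattern described in the sync_stream docs above, assuming a comms_t has already been constructed from a concrete backend and that sendbuf/recvbuf are valid device pointers:

#include <raft/comms/comms.hpp>

// Sum `n` floats across all ranks, surfacing collective failures explicitly.
void sum_across_ranks(const raft::comms::comms_t& comms, const float* sendbuf,
                      float* recvbuf, size_t n, cudaStream_t stream) {
  comms.allreduce(sendbuf, recvbuf, n, raft::comms::op_t::SUM, stream);

  // Synchronize through the comms layer rather than cudaStreamSynchronize()
  // so a failed or aborted collective is reported instead of deadlocking.
  auto status = comms.sync_stream(stream);
  ASSERT(status == raft::comms::status_t::SUCCESS,
         "allreduce failed to synchronize");
}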