First version of dispatcher (pytorch#8713)
smessmer committed Jun 25, 2018
1 parent 2b926aa commit cca2476
Showing 19 changed files with 1,412 additions and 3 deletions.
4 changes: 4 additions & 0 deletions caffe2/core/dispatch/CMakeLists.txt
@@ -1,8 +1,12 @@
set(LIB_SOURCES
DeviceId.cpp
Dispatcher.cpp
DispatchKey.cpp
DispatchTable.cpp
KernelRegistration.cpp
LayoutId.cpp
OpSchema.cpp
OpSchemaRegistration.cpp
TensorTypeId.cpp
TensorTypeIdRegistration.cpp
)
2 changes: 2 additions & 0 deletions caffe2/core/dispatch/DeviceId.h
@@ -2,6 +2,7 @@

#include <functional>
#include <iostream>
#include "caffe2/utils/C++17.h"

namespace c10 {

@@ -19,6 +20,7 @@ inline std::ostream& operator<<(std::ostream& stream, DeviceTypeId device_type_id) {
case DeviceTypeId::CUDA: return stream << "DeviceTypeId(CUDA)";
case DeviceTypeId::UNDEFINED: return stream << "DeviceTypeId(UNDEFINED)";
}
throw std::logic_error("Unknown DeviceTypeId: " + guts::to_string(static_cast<int>(device_type_id)));
}

}
16 changes: 16 additions & 0 deletions caffe2/core/dispatch/DispatchKey.h
@@ -6,6 +6,7 @@

#include <vector>
#include <functional>
#include <sstream>
#include "caffe2/utils/Array.h"

namespace c10 {
@@ -21,6 +22,9 @@ struct TensorParameterDispatchKey final {
inline constexpr bool operator==(const TensorParameterDispatchKey& lhs, const TensorParameterDispatchKey& rhs) {
return lhs.deviceTypeId == rhs.deviceTypeId && lhs.layoutId == rhs.layoutId && lhs.dataType == rhs.dataType;
}
inline std::ostream& operator<<(std::ostream& stream, const TensorParameterDispatchKey& key) {
return stream << "TensorKey(" << key.deviceTypeId << ", " << key.layoutId.value() << ", " << key.dataType << ")";
}
} // namespace details
} // namespace c10

@@ -58,6 +62,18 @@ inline constexpr bool operator==(const DispatchKey<num_dispatch_args> &lhs, const DispatchKey<num_dispatch_args> &rhs) {
// TODO: Use AVX instructions to perform this equality test more quickly
return lhs.argTypes == rhs.argTypes;
}
template<size_t num_dispatch_args>
inline std::ostream& operator<<(std::ostream& stream, const DispatchKey<num_dispatch_args>& key) {
stream << "DispatchKey(";
if (num_dispatch_args > 0) {
stream << "DispatchKey(" << key.argTypes[0];
for (size_t i = 1; i < num_dispatch_args; ++i) {
stream << ", " << key.argTypes[i];
}
stream << ")";
}
return stream << ")";
}

} // namespace c10

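A side note on the operator<< just added: with the loop above, a two-argument key prints as DispatchKey(TensorKey(...), TensorKey(...)). The comma-separated streaming of a fixed-size array is a standard pattern; here is a minimal self-contained sketch of it (illustration only, not part of this commit):

#include <array>
#include <cstddef>
#include <iostream>

// Stream the first element, then ", element" for each remaining one,
// mirroring the printing loop in DispatchKey.h above.
template<class T, std::size_t N>
std::ostream& print_array(std::ostream& stream, const std::array<T, N>& arr) {
  stream << "(";
  if (N > 0) {
    stream << arr[0];
    for (std::size_t i = 1; i < N; ++i) {
      stream << ", " << arr[i];
    }
  }
  return stream << ")";
}

int main() {
  std::array<int, 3> key{{1, 2, 3}};
  print_array(std::cout, key) << std::endl;  // prints "(1, 2, 3)"
}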
1 change: 1 addition & 0 deletions caffe2/core/dispatch/DispatchTable.cpp
@@ -0,0 +1 @@
#include "caffe2/core/dispatch/DispatchTable.h"
154 changes: 154 additions & 0 deletions caffe2/core/dispatch/DispatchTable.h
@@ -0,0 +1,154 @@
#pragma once

#include "caffe2/utils/flat_hash_map/flat_hash_map.h"
#include "caffe2/utils/Metaprogramming.h"
#include "caffe2/core/dispatch/OpSchema.h"

#include <type_traits>
#include <array>
#include <unordered_map>
#include <iostream>
#include <sstream>
#include <mutex>
#include <cassert>

namespace c10 {

namespace details {

/// A thread-safe hash table mapping dispatch keys to kernel implementations.
template<class Key>
class ThreadsafeOperatorTable_ final {
public:
// TODO The current implementation below does not yet have the correctness
// characteristics we need. It's worth spelling out exactly what we need:
//
// - We need LOCK FREE read access to the table (as per the performance benchmark
// at https://fb.quip.com/hvz3AGnx8MQ8)
//
// - We need to support writes which are possibly concurrent with reads, occurring when
// a dynamic library is loaded or unloaded.
//
// - We probably can require that dynamic library loads/unloads be synchronized (so
// there are never two concurrent loads).

template<class Key_>
void emplace(Key_&& key, void* value) {
// TODO Locking
//std::unique_lock<std::shared_timed_mutex> lock(mutex_);

auto result = map_.emplace(std::forward<Key_>(key), value);
if (!result.second) {
std::ostringstream msg;
msg << "Tried to register conflicting kernels to the dispatcher: " << key;
throw std::logic_error(msg.str());
}
}

void erase(const Key& key) {
// TODO Locking
//std::unique_lock<std::shared_timed_mutex> lock(mutex_);

size_t num_removed = map_.erase(key);
assert(num_removed <= 1); // This is not a multi-map
if (num_removed == 0) {
throw std::logic_error("Tried to deregister a kernel that isn't registered.");
}
}

void* lookup(const Key& key) const {
// TODO A lock is needed here, but it hurts read performance. Find a better way.
// std::shared_lock<std::shared_timed_mutex> lock(mutex_);
auto found = map_.find(key);
if (found == map_.end()) {
return nullptr;
} else {
return found->second;
}
}

private:
ska::flat_hash_map<Key, void*> map_;
// TODO Figure out how to get fast locking in C++11 (use boost::shared_timed_mutex? folly::SharedMutex? LR pattern?)
//mutable std::shared_timed_mutex mutex_;
};
} // namespace details

/**
* Per-operator dispatch table.
*
* Given an operator specified by 'OpSchemaDef', this class records a dispatch table for
* various kernels provided for this operator. For example, if we consider the operator
* add(Tensor, Tensor), the dispatch table for this operator may contain implementations
* for various dynamic tensor types, such as (CPUFloatTensor, CPUFloatTensor),
* (CUDAFloatTensor, CUDAFloatTensor), etc.
*
* @tparam OpSchemaDef The operator signature this dispatch table encodes.
*/
// TODO: Support dispatch for meta-operators (which apply to all dynamic types)
template<class OpSchemaDef>
class DispatchTable final {
private:
using Schema = OpSchema<OpSchemaDef>;

public:
DispatchTable(): kernels_() {}

/**
* Register a kernel in the table at some dispatch key.
* @param func Concrete kernel function implementation to register
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
void registerKernel(typename Schema::signature::func_type* func, typename Schema::dispatch::dispatch_key_type dispatch_key) {
kernels_.emplace(std::move(dispatch_key), reinterpret_cast<void*>(func));
}

/**
* Deregister the kernel for some dispatch key.
*
* @param dispatch_key Dispatch key to unregister.
*/
// TODO: This isn't going to work so well when we get more complicated override patterns!
// In this case, an operator will show up in multiple slots, and erasing them one-by-one
// is probably not such a good idea.
void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
kernels_.erase(dispatch_key);
}

/**
* Perform a dynamic dispatch on this table.
*
* @tparam Args Perfect forwarding template arguments to the dispatch
* @param args Arguments to invoke the function with
* @return Returned value of the operator
*/
template<class... Args>
typename Schema::signature::return_type call(Args&&... args) const {
// TODO Better error message, but need to take care that reference arguments match non-reference arguments and so on.
// static_assert(std::is_same<typename Schema::return_type (Args...), typename Schema::func_type>::value, "Argument types don't match operator signature");
auto kernel_func = lookupKernelFunc_(args...);
return kernel_func(std::forward<Args>(args)...);
}

private:
template<class... Args>
typename Schema::signature::func_type* lookupKernelFunc_(const Args&... args) const {
auto dispatch_key = Schema::dispatch::dispatch_key(args...);
void* found = kernels_.lookup(dispatch_key);
if (found == nullptr) {
// TODO Better error message - include op name and dispatch key (i.e. argument types)
throw std::logic_error(std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'");
}
return reinterpret_cast<typename Schema::signature::func_type*>(found);
}

details::ThreadsafeOperatorTable_<typename Schema::dispatch::dispatch_key_type> kernels_;
};

} // namespace c10

/*
* Use this to access the dispatch table singleton for a given op schema.
* It has one definition per op schema def, in a .cpp file, because we
* can't rely on the one-definition rule.
*/
template<class OpSchemaDef> c10::DispatchTable<OpSchemaDef>& c10_dispatch_table();
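To illustrate the comment above: each op schema is expected to provide exactly one definition of this accessor in a .cpp file. A minimal sketch of what that might look like, assuming a hypothetical schema MyAdd (the OpSchema machinery lives in OpSchema.h, which is not part of this excerpt, so the schema body here is an assumption, not part of this commit):

// Hypothetical example, in some my_add_op.cpp:
struct MyAdd final {
  using Signature = float(float, float);  // hypothetical operator signature
};

template<>
c10::DispatchTable<MyAdd>& c10_dispatch_table<MyAdd>() {
  // One function-local static per schema: exactly one table definition
  // lives in this .cpp file, side-stepping the one-definition-rule concern.
  static c10::DispatchTable<MyAdd> table;
  return table;
}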
1 change: 1 addition & 0 deletions caffe2/core/dispatch/Dispatcher.cpp
@@ -0,0 +1 @@
#include "caffe2/core/dispatch/Dispatcher.h"
60 changes: 60 additions & 0 deletions caffe2/core/dispatch/Dispatcher.h
@@ -0,0 +1,60 @@
#pragma once

#include "caffe2/core/dispatch/DispatchTable.h"

namespace c10 {

/**
* Top-level interface for dispatching operator calls via the dynamic dispatcher.
*/
template<class OpSchemaDef>
class Dispatcher final {
public:
// Implementation note: this class abstracts over the fact that we have per-operator
// dispatch tables. This could be easily adjusted to have a single global hash
// table.

/**
* Register a kernel in the dispatch table for some operator schema.
*
* @tparam OpSchemaDef Operator schema to register this kernel for (mandatory)
* @tparam Args Perfect-forwarding args to c10::DispatchTable::registerKernel (inferred)
* @param args Perfect-forwarding args to c10::DispatchTable::registerKernel
* @return void
*/
template<class... Args>
static void registerKernel(Args&&... args) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.registerKernel(std::forward<Args>(args)...);
}

/**
* Remove a kernel from the dispatch table for some operator schema.
*
* @tparam OpSchemaDef Operator schema to deregister from (mandatory)
* @tparam Args Perfect-forwarding args to c10::DispatchTable::deregisterKernel (inferred)
* @param args Perfect-forwarding args to c10::DispatchTable::deregisterKernel
* @return void
*/
template<class... Args>
static void deregisterKernel(Args&&... args) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.deregisterKernel(std::forward<Args>(args)...);
}

/**
* Perform a dynamic dispatch to some operator.
*
* @tparam OpSchemaDef Operator schema to dispatch with (mandatory)
* @tparam Args Perfect-forwarding args to c10::DispatchTable::call (inferred)
* @param args Perfect-forwarding args to c10::DispatchTable::call
* @return Return type of this operator
*/
template<class... Args>
static typename OpSchema<OpSchemaDef>::signature::return_type call(Args&&... args) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.call(std::forward<Args>(args)...);
}
};

} // namespace c10
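Putting the pieces together, callers only touch Dispatcher. A hedged usage sketch continuing the hypothetical MyAdd schema from above; my_add_cpu and cpu_float_key are made-up names, and in the real code the dispatch key is computed from the arguments by Schema::dispatch::dispatch_key:

// Hypothetical usage, not part of this commit.
float my_add_cpu(float a, float b) { return a + b; }  // a concrete kernel

void example() {
  // Assumed: cpu_float_key is a value of MyAdd's dispatch_key_type.
  // Registration forwards to DispatchTable::registerKernel.
  c10::Dispatcher<MyAdd>::registerKernel(&my_add_cpu, cpu_float_key);

  // Computes the key from the arguments, looks up the kernel in the
  // MyAdd table, and invokes it (forwards to DispatchTable::call).
  float result = c10::Dispatcher<MyAdd>::call(1.0f, 2.0f);
  (void)result;

  // Forwards to DispatchTable::deregisterKernel, e.g. on library unload.
  c10::Dispatcher<MyAdd>::deregisterKernel(cpu_float_key);
}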
1 change: 1 addition & 0 deletions caffe2/core/dispatch/KernelRegistration.cpp
@@ -0,0 +1 @@
#include "caffe2/core/dispatch/KernelRegistration.h"