transpose kernel map
chrischoy committed Dec 15, 2020
1 parent 8633896 commit 0c332d7
Showing 15 changed files with 452 additions and 192 deletions.
3 changes: 0 additions & 3 deletions MinkowskiEngine/MinkowskiSparseTensor.py
@@ -292,9 +292,6 @@ def __init__(
        else:
            # not (coordinate_map_key is None or coordinate_manager is None)
            self.D = coordinate_manager.D
-           coordinate_map_key = CoordinateMapKey(
-               convert_to_int_list(tensor_stride, self.D), ""
-           )
            self._manager = coordinate_manager

##########################
5 changes: 5 additions & 0 deletions pybind/extern.hpp
@@ -343,6 +343,11 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) {
.def(py::init<>())
.def(py::init<minkowski::CUDAKernelMapMode::Mode,
minkowski::default_types::size_type>())
.def("__repr__",
py::overload_cast<>(&manager_type::to_string, py::const_))
.def("print_coordinate_map",
py::overload_cast<minkowski::CoordinateMapKey const *>(
&manager_type::to_string, py::const_))
// TODO .def("insert", &manager_type::insert)
.def("insert_and_map", &manager_type::insert_and_map)
.def("stride",
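The two new bindings disambiguate an overloaded to_string: py::overload_cast<> with the py::const_ tag selects the zero-argument const overload for __repr__, while the pointer-taking form backs print_coordinate_map. A minimal sketch of the same pattern; the Manager stand-in type below is hypothetical, not the actual MinkowskiEngine manager:

#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

// Hypothetical stand-in with the same overload structure as manager_type.
struct Manager {
  std::string to_string() const { return "Manager"; }
  std::string to_string(int const *key) const {
    return "map@" + std::to_string(*key);
  }
};

PYBIND11_MODULE(example, m) {
  py::class_<Manager>(m, "Manager")
      .def(py::init<>())
      // py::overload_cast<Args...> selects the overload taking exactly
      // Args...; py::const_ marks the const-qualified member function.
      .def("__repr__", py::overload_cast<>(&Manager::to_string, py::const_))
      .def("print_coordinate_map",
           py::overload_cast<int const *>(&Manager::to_string, py::const_));
}

Without py::overload_cast, taking &Manager::to_string would be ambiguous; the cast is what lets one C++ name serve both Python methods.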
10 changes: 10 additions & 0 deletions src/coordinate_map.hpp
@@ -65,6 +65,16 @@ stride_coordinate(const coordinate<Itype> &src, std::vector<Itype> &dst,
}
}

+template <typename Itype, typename stride_type>
+inline void stride_coordinate(const coordinate<Itype> &src,
+                              std::vector<Itype> &dst,
+                              const stride_type stride) noexcept {
+  dst[0] = src[0];
+  for (default_types::index_type i = 0; i < dst.size() - 1; ++i) {
+    dst[i + 1] = std::floor((float)src[i + 1] / stride[i]) * stride[i];
+  }
+}
+
inline default_types::stride_type
stride_tensor_stride(const default_types::stride_type &tensor_stride,
const default_types::stride_type &stride) {
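The added stride_coordinate overload snaps each spatial component down to the nearest multiple of its stride and passes the batch index src[0] through unchanged. Because it floors instead of truncating toward zero, negative coordinates are handled correctly. A self-contained sketch of the same arithmetic, with the template machinery stripped out:

#include <cmath>
#include <cstdio>
#include <vector>

// Snap x down to the nearest multiple of s, exactly as the floor-divide
// in stride_coordinate does.
int snap(int x, int s) {
  return static_cast<int>(std::floor(static_cast<float>(x) / s)) * s;
}

int main() {
  std::vector<int> src = {0, 5, -3, 7}; // {batch, x, y, z}
  std::vector<int> stride = {4, 4, 4};
  std::vector<int> dst(src.size());
  dst[0] = src[0]; // batch index passes through
  for (size_t i = 0; i + 1 < src.size(); ++i)
    dst[i + 1] = snap(src[i + 1], stride[i]);
  for (int v : dst)
    std::printf("%d ", v); // prints: 0 4 -4 4
}

Note that -3 snaps to -4 rather than 0; plain integer division would round toward zero and put negative points in the wrong output voxel.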
58 changes: 57 additions & 1 deletion src/coordinate_map_cpu.hpp
@@ -345,7 +345,63 @@ class CoordinateMapCPU : public CoordinateMap<coordinate_type, TemplatedAllocato
return std::make_pair(in_maps, out_maps);
}

+  cpu_kernel_map stride_map(self_type const &out_coordinate_map,
+                            stride_type const &out_tensor_stride) const {
+    // generate an in-out (kernel) map that maps all input points in the same
+    // voxel to strided output voxel.
+    size_type in_size = size();
+    LOG_DEBUG("Generate stride_map with in NNZ:", in_size,
+              "out NNZ:", out_coordinate_map.size(),
+              "out_tensor_stride:", out_tensor_stride);
+    ASSERT(in_size > out_coordinate_map.size(), "Invalid out_coordinate_map");
+    cpu_in_maps in_maps = initialize_maps<cpu_in_map>(1, in_size);
+    cpu_out_maps out_maps = initialize_maps<cpu_out_map>(1, in_size);
+
+    // compute the chunk size per thread.
+    // There's a trade-off between the thread initialization overhead and the
+    // job sizes. If some jobs finish earlier than others due to imbalance in
+    // hash distribution, these threads will be idle.
+    const size_t in_map_num_elements = m_map.capacity();
+    size_t N = 2 * omp_get_max_threads();
+    const size_t stride = (in_map_num_elements + N - 1) / N;
+    N = (in_map_num_elements + stride - 1) / stride;
+    LOG_DEBUG("kernel map with", N, "chunks.");
+
+    index_type num_used = 0;
+#pragma omp parallel for
+    for (index_type n = 0; n < N; ++n) {
+      index_type curr_index_begin;
+      std::vector<coordinate_type> dst(m_coordinate_size);
+      for (auto iter_in = m_map.begin(stride * n);
+           iter_in.num_steps() <
+           std::min(stride, in_map_num_elements - n * stride);
+           ++iter_in) {
+        detail::stride_coordinate<coordinate_type>(iter_in->first, dst,
+                                                   out_tensor_stride);
+        const auto iter_out =
+            out_coordinate_map.find(coordinate<coordinate_type>(dst.data()));
+        ASSERT(iter_out != out_coordinate_map.m_map.cend(),
+               "Invalid out_coordinate_map");
+#pragma omp atomic capture
+        {
+          curr_index_begin = num_used;
+          num_used += 1;
+        }
+
+        in_maps[0][curr_index_begin] = iter_in->second;
+        out_maps[0][curr_index_begin] = iter_out->second;
+      }
+    }
+
+    return std::make_pair(move(in_maps), move(out_maps));
+  }
+
inline size_type size() const noexcept { return m_map.size(); }
+  std::string to_string() const {
+    Formatter o;
+    o << "CoordinateMapCPU:" << size() << "x" << m_coordinate_size;
+    return o.str();
+  }

using base_type::capacity;
using base_type::coordinate_size;
@@ -368,7 +424,7 @@ class CoordinateMapCPU : public CoordinateMap<coordinate_type, TemplatedAllocato
// Put if outside the loop for speed
#pragma omp parallel for
for (index_type n = 0; n < N; ++n) {
-    for (auto it = m_map.begin(stride * n);                          //
+    for (auto it = m_map.begin(stride * n); //
it.num_steps() < std::min(stride, capacity - n * stride); //
++it) {
std::copy_n(it->first.data(), m_coordinate_size,
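The CPU stride_map above walks the hash table in roughly 2 * omp_get_max_threads() contiguous chunks (two per thread, so a thread that drew a sparse chunk can pick up another) and uses #pragma omp atomic capture to hand each matched pair a unique, gap-free slot in the output arrays. A stripped-down sketch of that claim-a-slot pattern, with a dummy predicate standing in for the hash-map lookup (compile with -fopenmp):

#include <algorithm>
#include <cstdio>
#include <omp.h>
#include <vector>

int main() {
  const size_t num_elements = 1000;
  // Same chunking arithmetic as stride_map: about two chunks per thread.
  size_t N = 2 * omp_get_max_threads();
  const size_t chunk = (num_elements + N - 1) / N;
  N = (num_elements + chunk - 1) / chunk;

  std::vector<int> out(num_elements);
  int num_used = 0;
#pragma omp parallel for
  for (long n = 0; n < static_cast<long>(N); ++n) {
    const size_t begin = n * chunk;
    const size_t end = std::min(begin + chunk, num_elements);
    for (size_t i = begin; i < end; ++i) {
      if (i % 3 == 0)
        continue; // stand-in for entries that find no strided output
      int slot;
      // atomic capture: read the counter and bump it in one atomic step,
      // reserving a unique index in `out` for this iteration.
#pragma omp atomic capture
      {
        slot = num_used;
        num_used += 1;
      }
      out[slot] = static_cast<int>(i);
    }
  }
  std::printf("used %d of %zu slots\n", num_used, num_elements);
}

The slots come out in nondeterministic order, which is fine here: the kernel map only needs matching in/out entries at the same index, not a particular ordering.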
114 changes: 105 additions & 9 deletions src/coordinate_map_gpu.cu
@@ -276,8 +276,8 @@ stride_copy(coordinate_type const *__restrict__ src_coordinates, //
dst_coordinates[dst_start] = src_coordinates[src_start];
for (index_type j = 1; j < coordinate_size; ++j) {
dst_coordinates[dst_start + j] =
-        (__float2int_rd(__fdiv_rd(src_coordinates[src_start + j],
-                                  sh_stride[j - 1]))) *
+        (__float2int_rd(
+            __fdiv_rd(src_coordinates[src_start + j], sh_stride[j - 1]))) *
sh_stride[j - 1];
// (__double2int_rd(
// __ddiv_rn(src_coordinates[src_start + j], sh_stride[j - 1]))) *
@@ -389,8 +389,8 @@ count_kernel(map_type const __restrict__ in_map, //
auto const bx = blockIdx.x;
auto const x = blockDim.x * bx + tx;

-  size_type coordinate_size = kernel.coordinate_size();
-  size_type volume = kernel.volume();
+  size_type const coordinate_size = kernel.coordinate_size();
+  size_type const volume = kernel.volume();

// clang-format off
size_type *sh_size = reinterpret_cast<size_type *>(sh_all);
@@ -455,8 +455,8 @@ __global__ void preallocated_kernel_map_iteration(
auto const bx = blockIdx.x;
auto const x = blockDim.x * bx + tx;

-  size_type coordinate_size = kernel.coordinate_size();
-  size_type volume = kernel.volume();
+  size_type const coordinate_size = kernel.coordinate_size();
+  size_type const volume = kernel.volume();

// clang-format off
size_type *sh_size = reinterpret_cast<size_type *>(sh_all);
@@ -481,9 +481,6 @@ __global__ void preallocated_kernel_map_iteration(

__syncthreads();

-  if (x >= num_threads)
-    return;
-
auto const unused_key = out_map.get_unused_key();
if (x < num_threads) {
// iterate over values
@@ -595,6 +592,105 @@ CoordinateMapGPU<coordinate_type, TemplatedAllocator>::kernel_map(

namespace detail {

+template <typename coordinate_type, //
+          typename size_type,       //
+          typename index_type,      //
+          typename map_type>
+__global__ void
+stride_map_kernel(map_type const __restrict__ in_map,                      //
+                  map_type const __restrict__ out_map,                     //
+                  index_type const *const __restrict__ in_valid_map_index, //
+                  size_type const num_threads,                             //
+                  index_type const *const __restrict__ stride,             //
+                  index_type *__restrict__ p_in_maps,                      //
+                  index_type *__restrict__ p_out_maps,
+                  size_type const coordinate_size) {
+  extern __shared__ coordinate_type sh_all[];
+
+  auto const tx = threadIdx.x;
+  auto const bx = blockIdx.x;
+  auto const x = blockDim.x * bx + tx;
+
+  // clang-format off
+  size_type *sh_size = reinterpret_cast<size_type *>(sh_all);
+
+  size_type *sh_stride = sh_size;
+
+  coordinate_type *sh_coordinate = reinterpret_cast<coordinate_type *>(sh_size + coordinate_size);
+  coordinate_type *sh_tmp = sh_coordinate + tx * coordinate_size;
+  // clang-format on
+
+  for (index_type i = tx; i < coordinate_size - 1; i += blockDim.x) {
+    sh_stride[i] = stride[i];
+  }
+
+  __syncthreads();
+
+  if (x >= num_threads)
+    return;
+
+  typename map_type::value_type const &in_value =
+      in_map.data()[in_valid_map_index[x]];
+
+  sh_tmp[0] = in_value.first[0];
+  for (index_type j = 1; j < coordinate_size; ++j) {
+    sh_tmp[j] =
+        (__float2int_rd(__fdiv_rd(in_value.first[j], sh_stride[j - 1]))) *
+        sh_stride[j - 1];
+  }
+
+  auto out_iter = out_map.find(coordinate<coordinate_type>(sh_tmp));
+
+  p_in_maps[x] = in_value.second;
+  p_out_maps[x] = out_iter->second;
+}
+
+} // namespace detail
+
+template <typename coordinate_type,
+          template <typename T> class TemplatedAllocator>
+CoordinateMapGPU<coordinate_type, TemplatedAllocator>::kernel_map_type
+CoordinateMapGPU<coordinate_type, TemplatedAllocator>::stride_map(
+    self_type const &out_map, stride_type const &out_tensor_stride,
+    uint32_t thread_dim) const {
+  // Over estimate the reserve size to be size();
+  size_type const in_size = size();
+  thrust::device_vector<size_type> d_out_tensor_stride(
+      out_tensor_stride.begin(), out_tensor_stride.end());
+
+  // (THREAD * D + D) * 4
+  uint32_t const shared_memory_size_in_bytes =
+      m_coordinate_size * sizeof(index_type) +                  // stride
+      thread_dim * m_coordinate_size * sizeof(coordinate_type); // tmp
+  size_type const num_threads = in_size;
+  auto const num_blocks = GET_BLOCKS(num_threads, thread_dim);
+
+  LOG_DEBUG("num block", num_blocks);
+  LOG_DEBUG("shared_memory size", shared_memory_size_in_bytes);
+  LOG_DEBUG("threads dim", thread_dim);
+  LOG_DEBUG("num threads", num_threads);
+
+  kernel_map_type kernel_map(in_size, base_type::m_byte_allocator,
+                             false /* reserve_kernel_index */);
+  CUDA_CHECK(cudaStreamSynchronize(0));
+  LOG_DEBUG("Allocated kernel_map.");
+
+  detail::stride_map_kernel<coordinate_type, size_type, index_type, map_type>
+      <<<num_blocks, thread_dim, shared_memory_size_in_bytes>>>(
+          *m_map,                                               //
+          *out_map.m_map,                                       //
+          thrust::raw_pointer_cast(m_valid_map_index.data()),   //
+          num_threads,                                          //
+          thrust::raw_pointer_cast(d_out_tensor_stride.data()), //
+          kernel_map.in_maps.begin(),                           //
+          kernel_map.out_maps.begin(),                          //
+          m_coordinate_size);
+
+  return kernel_map;
+}
+
+namespace detail {
+
template <typename coordinate_type, //
typename size_type, //
typename index_type, //
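The new stride_map_kernel carves its dynamic shared memory into one stride vector shared by the whole block followed by a per-thread coordinate scratch buffer, which is the (THREAD * D + D) * 4 bytes the host computes, and GET_BLOCKS(num_threads, thread_dim) is the usual ceil-divide launch arithmetic. A simplified CUDA sketch of that layout and of the round-down stride arithmetic; layout_demo is a hypothetical demo kernel, not the actual map lookup:

#include <cstdio>

// Shared-memory layout, as in stride_map_kernel:
// [ stride: D ints ][ per-thread scratch: blockDim.x * D ints ]
__global__ void layout_demo(const int *stride, int D) {
  extern __shared__ int sh_all[];
  int *sh_stride = sh_all;                    // D entries, shared by the block
  int *sh_tmp = sh_all + D + threadIdx.x * D; // D entries private to a thread

  // Cooperative load: threads stripe over the D - 1 spatial strides.
  for (int i = threadIdx.x; i < D - 1; i += blockDim.x)
    sh_stride[i] = stride[i];
  __syncthreads();

  // Snap a fake coordinate into the private scratch with the same
  // round-down division the real kernel uses.
  sh_tmp[0] = 0; // batch index passes through
  for (int j = 1; j < D; ++j)
    sh_tmp[j] = __float2int_rd(__fdiv_rd(threadIdx.x + j, sh_stride[j - 1])) *
                sh_stride[j - 1];
  if (threadIdx.x == 1)
    printf("snapped[1] = %d\n", sh_tmp[1]); // 2 / 2 -> 1 -> * 2 = 2
}

int main() {
  int h_stride[3] = {2, 2, 2}, *d_stride;
  const int D = 4, threads = 64;
  cudaMalloc(&d_stride, sizeof(h_stride));
  cudaMemcpy(d_stride, h_stride, sizeof(h_stride), cudaMemcpyHostToDevice);
  // (THREAD * D + D) * sizeof(int), mirroring shared_memory_size_in_bytes.
  size_t const shmem = (D + threads * D) * sizeof(int);
  layout_demo<<<1, threads, shmem>>>(d_stride, D);
  cudaDeviceSynchronize();
  cudaFree(d_stride);
}

Keeping the per-thread scratch in shared memory avoids a local array of runtime length D in registers or spilled memory, at the cost of a shared-memory request that scales with the block size.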
8 changes: 8 additions & 0 deletions src/coordinate_map_gpu.cuh
@@ -154,10 +154,18 @@ public:
gpu_kernel_region<coordinate_type> const &kernel,
CUDAKernelMapMode::Mode kernel_map_mode,
uint32_t thread_dim = CUDA_NUM_THREADS) const;
+  kernel_map_type stride_map(self_type const &out_coordinate_map,
+                             stride_type const &out_tensor_stride,
+                             uint32_t thread_dim = CUDA_NUM_THREADS) const;

// Returns the number of elements in the coordinate map
inline size_type size() const { return m_size; }
void copy_coordinates(coordinate_type *dst_coordinate) const;
+  std::string to_string() const {
+    Formatter o;
+    o << "CoordinateMapGPU:" << size() << "x" << m_coordinate_size;
+    return o.str();
+  }

// access the coordinate data pointer
using base_type::const_coordinate_data;
1 change: 1 addition & 0 deletions src/coordinate_map_key.hpp
@@ -91,6 +91,7 @@ class CoordinateMapKey {
ASSERT(m_coordinate_size - 1 == key.first.size(),
"Invalid tensor_stride size:", key.first,
"coordinate_size:", m_coordinate_size);
LOG_DEBUG("Setting the key to ", key.first, ":", key.second);
m_key = key;
m_key_set = true;
}
(8 more changed files not shown.)

0 comments on commit 0c332d7
