spmm and minor updates
chrischoy committed Dec 15, 2020
1 parent f27bda1 commit ace1cbc
Showing 14 changed files with 360 additions and 112 deletions.
6 changes: 0 additions & 6 deletions MinkowskiEngine/MinkowskiConvolution.py
@@ -57,9 +57,6 @@ def forward(
out_coordinate_map_key = CoordinateMapKey(
in_coordinate_map_key.get_coordinate_size()
)
assert (
input_features.type() == kernel_weights.type()
), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
if not input_features.is_contiguous():
input_features = input_features.contiguous()

@@ -132,9 +129,6 @@ def forward(
out_coordinate_map_key = CoordinateMapKey(
in_coordinate_map_key.get_coordinate_size()
)
assert (
input_features.type() == kernel_weights.type()
), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
if not input_features.is_contiguous():
input_features = input_features.contiguous()

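Note: the deleted assertion enforced that input_features and kernel_weights share a dtype; its f-string message references kernel.type(), a name that does not appear in this hunk. A minimal caller-side sketch of the same guard, assuming the caller holds both tensors, might look like:

    # Hypothetical pre-check mirroring the removed assertion: keep the kernel
    # weights in the same dtype as the input features before the convolution runs.
    if input_features.dtype != kernel_weights.dtype:
        kernel_weights = kernel_weights.to(input_features.dtype)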
45 changes: 20 additions & 25 deletions MinkowskiEngine/MinkowskiSparseTensor.py
@@ -43,6 +43,7 @@
_allocator_type,
_coordinate_map_type,
)
from sparse_matrix_functions import spmm as _spmm


class SparseTensorOperationMode(Enum):
@@ -343,29 +344,23 @@ def __init__(
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
]:
N = len(features)
import ipdb; ipdb.set_trace()
# int_inverse_mapping = self.inverse_mapping.int()
COO = torch.stack(
(
self.inverse_mapping,
torch.arange(N, dtype=int, device=self.unique_index.device),
),
0,
cols = torch.arange(
N,
dtype=self.inverse_mapping.dtype,
device=self.inverse_mapping.device,
)
self.sp_mapping = torch.sparse.FloatTensor(
COO,
torch.ones(N).to(self.unique_index),
torch.Size([len(self.unique_index), len(features)]),
).to(self.unique_index)
vals = torch.ones(N, dtype=features.dtype, device=features.device)
size = torch.Size([len(self.unique_index), len(self.inverse_mapping)])
features = _spmm(self.inverse_mapping, cols, vals, size, features)
# int_inverse_mapping = self.inverse_mapping.int()
if (
self.quantization_mode
== SparseTensorQuantizationMode.UNWEIGHTED_SUM
== SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
):
features = self.sp_mapping.matmul(features)
else:
features = self.sp_mapping.matmul(
features
) / self.sp_mapping.matmul(torch.ones(len(features), 1))
nums = _spmm(
self.inverse_mapping, cols, vals, size, vals.reshape(N, 1),
)
features /= nums
else:
features = features[self.unique_index]
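In the new path, _spmm multiplies a sparse COO matrix of shape (num_unique, N), with a one at (inverse_mapping[i], i) for every input row i, against the dense feature matrix, so each output row is the sum of the input rows that quantize to the same coordinate; for UNWEIGHTED_AVERAGE a second _spmm against a column of ones gives the per-row counts used as the divisor. A minimal sketch of the same computation with stock torch.sparse ops (illustrative only, not the backend kernel used by the diff):

    import torch

    def collapse_duplicates(features, inverse_mapping, num_unique, average=True):
        # Build the (num_unique, N) selection matrix with a 1 at
        # (inverse_mapping[i], i) for every input row i.
        N = features.shape[0]
        rows = inverse_mapping.long()
        cols = torch.arange(N, device=features.device)
        vals = torch.ones(N, dtype=features.dtype, device=features.device)
        M = torch.sparse_coo_tensor(torch.stack((rows, cols)), vals, (num_unique, N))
        summed = torch.sparse.mm(M, features)            # UNWEIGHTED_SUM
        if not average:
            return summed
        counts = torch.sparse.mm(M, vals.reshape(N, 1))  # points per unique coordinate
        return summed / counts                           # UNWEIGHTED_AVERAGE

For example, with inverse_mapping = torch.tensor([0, 0, 1]) and three feature rows, output row 0 is the mean of the first two input rows and output row 1 is the third row unchanged.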

@@ -586,10 +581,10 @@ def get_device(self):

def _is_same_key(self, other):
assert isinstance(other, SparseTensor)
assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
assert (
self.coordinate_map_key == other.coordinate_map_key
), COORDS_KEY_DIFFERENT_ERROR
), COORDINATE_KEY_DIFFERENT_ERROR

# Operation overloading
def __iadd__(self, other):
@@ -622,7 +617,7 @@ def __add__(self, other):
"""
assert isinstance(other, (SparseTensor, torch.Tensor))
if isinstance(other, SparseTensor):
assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR

if self.coordinate_map_key == other.coordinate_map_key:
return SparseTensor(
@@ -661,7 +656,7 @@ def __sub__(self, other):
"""
assert isinstance(other, (SparseTensor, torch.Tensor))
if isinstance(other, SparseTensor):
assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR

if self.coordinate_map_key == other.coordinate_map_key:
return SparseTensor(
@@ -702,7 +697,7 @@ def __mul__(self, other):
"""
assert isinstance(other, (SparseTensor, torch.Tensor))
if isinstance(other, SparseTensor):
assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR

if self.coordinate_map_key == other.coordinate_map_key:
return SparseTensor(
@@ -742,7 +737,7 @@ def __truediv__(self, other):
"""
assert isinstance(other, (SparseTensor, torch.Tensor))
if isinstance(other, SparseTensor):
assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR

if self.coordinate_map_key == other.coordinate_map_key:
return SparseTensor(
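For context, the arithmetic overloads above only operate on SparseTensors that share a coordinate manager, and take the fast path when the coordinate map keys also match. A usage sketch, assuming the 0.5-style constructor keywords (features, coordinates, coordinate_manager); the exact argument names are taken from the library, not from this diff:

    import torch
    import MinkowskiEngine as ME

    coords = torch.IntTensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]])  # batch index first
    a = ME.SparseTensor(features=torch.rand(3, 4), coordinates=coords)
    # Reuse a's coordinate manager so the COORDINATE_MANAGER_DIFFERENT_ERROR
    # assertions in __add__, __sub__, __mul__, and __truediv__ are satisfied.
    b = ME.SparseTensor(
        features=torch.rand(3, 4),
        coordinates=coords,
        coordinate_manager=a.coordinate_manager,
    )
    c = a + b  # same coordinate_map_key: features are combined elementwise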
23 changes: 18 additions & 5 deletions pybind/extern.hpp
@@ -1,4 +1,6 @@
/* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
/*
* Copyright (c) 2020 NVIDIA Corporation.
* Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu).
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -189,6 +191,14 @@ at::Tensor quantization_average_features(at::Tensor in_feat, at::Tensor in_map,
int mode);
*/

#ifndef CPU_ONLY
template <typename th_int_type>
torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
torch::Tensor const &vals, int64_t const dim_i,
int64_t const dim_j, torch::Tensor const &mat2,
int64_t spmm_algorithm_id);
#endif

} // end namespace minkowski

namespace py = pybind11;
@@ -354,12 +364,15 @@ void instantiate_gpu_func(py::module &m, const std::string &dtypestr) {
&minkowski::ConvolutionTransposeBackwardGPU<coordinate_type,
TemplatedAllocator>,
py::call_guard<py::gil_scoped_release>());
}

// m.def("coo_spmm_int32", &coo_spmm<int32_t>,
// py::call_guard<py::gil_scoped_release>());
// m.def("coo_spmm_int64", &coo_spmm<int64_t>,
// py::call_guard<py::gil_scoped_release>());
void non_templated_gpu_func(py::module &m) {
m.def("coo_spmm_int32", &minkowski::coo_spmm<int32_t>,
py::call_guard<py::gil_scoped_release>());
m.def("coo_spmm_int64", &minkowski::coo_spmm<int64_t>,
py::call_guard<py::gil_scoped_release>());
}

#endif

void initialize_non_templated_classes(py::module &m) {
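The newly declared coo_spmm<th_int_type> takes the COO triplet (rows, cols, vals), the dense output shape dim_i x dim_j, the dense matrix mat2, and an spmm_algorithm_id, and is registered under coo_spmm_int32 / coo_spmm_int64 by non_templated_gpu_func. A plausible Python-side dispatch wrapper, in the spirit of the sparse_matrix_functions.spmm imported in MinkowskiSparseTensor.py (the extension module name and the default algorithm id below are assumptions, not taken from this diff):

    import torch
    import MinkowskiEngineBackend as _C  # assumed extension module name

    def spmm(rows, cols, vals, size, mat, spmm_algorithm_id=0):
        # Dispatch on the index dtype to the matching instantiation of coo_spmm,
        # following the signature declared in pybind/extern.hpp.
        if rows.dtype == torch.int32:
            return _C.coo_spmm_int32(rows, cols, vals, size[0], size[1], mat,
                                     spmm_algorithm_id)
        return _C.coo_spmm_int64(rows, cols, vals, size[0], size[1], mat,
                                 spmm_algorithm_id)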
2 changes: 2 additions & 0 deletions pybind/minkowski.cpp
@@ -59,5 +59,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

instantiate_gpu_func<int32_t, minkowski::detail::c10_allocator>(
m, std::string(""));

non_templated_gpu_func(m);
#endif
}
2 changes: 2 additions & 0 deletions pybind/minkowski.cu
@@ -59,5 +59,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

instantiate_gpu_func<int32_t, minkowski::detail::c10_allocator>(
m, std::string(""));

non_templated_gpu_func(m);
#endif
}
6 changes: 4 additions & 2 deletions setup.py
@@ -193,13 +193,15 @@ def _argparse(pattern, argv, is_flag=True):
"coordinate_map_gpu.cu",
"convolution_kernel.cu",
"convolution_transpose_gpu.cu",
"spmm.cu",
"gpu.cu",
],
["pybind/minkowski.cu"],
[],
],
}

no_debug, argv = _argparse("--nodebug", argv)
debug, argv = _argparse("--debug", argv)

USE_NINJA = os.getenv("USE_NINJA") == "0"
HERE = Path(os.path.dirname(__file__)).absolute()
@@ -220,7 +222,7 @@ def _argparse(pattern, argv, is_flag=True):

NVCC_FLAGS = [f"-ccbin={CXX}", "--extended-lambda"]

if not no_debug:
if debug:
CXX_FLAGS += ["-g", "-DDEBUG"]
NVCC_FLAGS += ["-g", "-DDEBUG"]
else:
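The flag change above inverts the default: debug symbols and -DDEBUG were previously added unless --nodebug was passed, and are now opt-in via --debug (e.g. python setup.py install --debug, command form assumed). A sketch of what the _argparse helper presumably does for such flags; its real body is not part of this diff:

    def _argparse(pattern, argv, is_flag=True):
        # For flag-style options, report whether `pattern` occurs in argv and
        # return argv with it removed, matching the call sites above.
        if is_flag:
            return pattern in argv, [arg for arg in argv if arg != pattern]
        raise NotImplementedError("non-flag options are outside this sketch")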
1 change: 1 addition & 0 deletions src/convolution_gpu.cu
@@ -105,6 +105,7 @@ at::Tensor ConvolutionForwardGPU(
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasSetStream(handle, at::cuda::getCurrentCUDAStream().stream());

LOG_DEBUG("Convolution on", out_nrows, "x", kernel.size(2));
AT_DISPATCH_FLOATING_TYPES(
in_feat.scalar_type(), "convolution_forward_gpu", [&] {
ConvolutionForwardKernelGPU<scalar_t, default_types::index_type,
2 changes: 1 addition & 1 deletion src/gpu.cuh
@@ -29,7 +29,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <cusparse_v2.h>
#include <cusparse.h>
#include <driver_types.h> // cuda driver types

#include <thrust/device_vector.h>