diff --git a/MinkowskiEngine/MinkowskiConvolution.py b/MinkowskiEngine/MinkowskiConvolution.py
index ec712fc7..9eb2fca6 100644
--- a/MinkowskiEngine/MinkowskiConvolution.py
+++ b/MinkowskiEngine/MinkowskiConvolution.py
@@ -57,9 +57,6 @@ def forward(
         out_coordinate_map_key = CoordinateMapKey(
             in_coordinate_map_key.get_coordinate_size()
         )
-        assert (
-            input_features.type() == kernel_weights.type()
-        ), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
 
         if not input_features.is_contiguous():
             input_features = input_features.contiguous()
@@ -132,9 +129,6 @@ def forward(
         out_coordinate_map_key = CoordinateMapKey(
             in_coordinate_map_key.get_coordinate_size()
         )
-        assert (
-            input_features.type() == kernel_weights.type()
-        ), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
 
         if not input_features.is_contiguous():
             input_features = input_features.contiguous()
diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py
index 39ab03fd..fb973092 100644
--- a/MinkowskiEngine/MinkowskiSparseTensor.py
+++ b/MinkowskiEngine/MinkowskiSparseTensor.py
@@ -43,6 +43,7 @@
     _allocator_type,
     _coordinate_map_type,
 )
+from sparse_matrix_functions import spmm as _spmm
 
 
 class SparseTensorOperationMode(Enum):
@@ -343,29 +344,23 @@ def __init__(
             SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
         ]:
             N = len(features)
-            import ipdb; ipdb.set_trace()
-            # int_inverse_mapping = self.inverse_mapping.int()
-            COO = torch.stack(
-                (
-                    self.inverse_mapping,
-                    torch.arange(N, dtype=int, device=self.unique_index.device),
-                ),
-                0,
+            cols = torch.arange(
+                N,
+                dtype=self.inverse_mapping.dtype,
+                device=self.inverse_mapping.device,
             )
-            self.sp_mapping = torch.sparse.FloatTensor(
-                COO,
-                torch.ones(N).to(self.unique_index),
-                torch.Size([len(self.unique_index), len(features)]),
-            ).to(self.unique_index)
+            vals = torch.ones(N, dtype=features.dtype, device=features.device)
+            size = torch.Size([len(self.unique_index), len(self.inverse_mapping)])
+            features = _spmm(self.inverse_mapping, cols, vals, size, features)
+            # int_inverse_mapping = self.inverse_mapping.int()
             if (
                 self.quantization_mode
-                == SparseTensorQuantizationMode.UNWEIGHTED_SUM
+                == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
             ):
-                features = self.sp_mapping.matmul(features)
-            else:
-                features = self.sp_mapping.matmul(
-                    features
-                ) / self.sp_mapping.matmul(torch.ones(len(features), 1))
+                nums = _spmm(
+                    self.inverse_mapping, cols, vals, size, vals.reshape(N, 1),
+                )
+                features /= nums
         else:
             features = features[self.unique_index]
@@ -586,10 +581,10 @@ def get_device(self):
 
     def _is_same_key(self, other):
         assert isinstance(other, SparseTensor)
-        assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
+        assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
         assert (
             self.coordinate_map_key == other.coordinate_map_key
-        ), COORDS_KEY_DIFFERENT_ERROR
+        ), COORDINATE_KEY_DIFFERENT_ERROR
 
     # Operation overloading
     def __iadd__(self, other):
@@ -622,7 +617,7 @@ def __add__(self, other):
         """
         assert isinstance(other, (SparseTensor, torch.Tensor))
         if isinstance(other, SparseTensor):
-            assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
+            assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
 
             if self.coordinate_map_key == other.coordinate_map_key:
                 return SparseTensor(
@@ -661,7 +656,7 @@ def __sub__(self, other):
         """
         assert isinstance(other, (SparseTensor, torch.Tensor))
         if isinstance(other, SparseTensor):
-            assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
+            assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
 
             if self.coordinate_map_key == other.coordinate_map_key:
                 return SparseTensor(
@@ -702,7 +697,7 @@ def __mul__(self, other):
         """
         assert isinstance(other, (SparseTensor, torch.Tensor))
         if isinstance(other, SparseTensor):
-            assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
+            assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
 
             if self.coordinate_map_key == other.coordinate_map_key:
                 return SparseTensor(
@@ -742,7 +737,7 @@ def __truediv__(self, other):
         """
        assert isinstance(other, (SparseTensor, torch.Tensor))
        if isinstance(other, SparseTensor):
-            assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR
+            assert self._manager == other._manager, COORDINATE_MANAGER_DIFFERENT_ERROR
 
            if self.coordinate_map_key == other.coordinate_map_key:
                return SparseTensor(
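The quantization path above replaces the ad-hoc torch.sparse.FloatTensor construction (and the stray ipdb breakpoint) with the engine's own spmm: rows come from inverse_mapping, columns are arange(N), and the values are ones, so the product sums duplicate features per output coordinate and, for UNWEIGHTED_AVERAGE, divides by the per-coordinate counts. The following plain-PyTorch sketch shows the same math; it is illustrative only, and torch.sparse stands in here for the CUDA coo_spmm kernel added later in this patch.

# Reference sketch of the quantization math (not part of the patch).
import torch

features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
inverse_mapping = torch.tensor([0, 1, 0])    # points 0 and 2 fall into voxel 0
N = len(features)
num_unique = int(inverse_mapping.max()) + 1

# COO matrix M of shape [num_unique, N] with M[inverse_mapping[i], i] = 1
rows = inverse_mapping
cols = torch.arange(N)
vals = torch.ones(N)
M = torch.sparse_coo_tensor(torch.stack((rows, cols)), vals, (num_unique, N))

summed = torch.sparse.mm(M, features)            # UNWEIGHTED_SUM
counts = torch.sparse.mm(M, torch.ones(N, 1))    # points per voxel
averaged = summed / counts                       # UNWEIGHTED_AVERAGE
print(averaged)                                  # [[3., 4.], [3., 4.]]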
diff --git a/pybind/extern.hpp b/pybind/extern.hpp
index 4084d9cc..7c2e96dd 100644
--- a/pybind/extern.hpp
+++ b/pybind/extern.hpp
@@ -1,4 +1,6 @@
-/* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
+/*
+ * Copyright (c) 2020 NVIDIA Corporation.
+ * Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu).
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -189,6 +191,14 @@ at::Tensor quantization_average_features(at::Tensor in_feat, at::Tensor in_map,
                                          int mode);
 */
 
+#ifndef CPU_ONLY
+template <typename th_int_type>
+torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
+                       torch::Tensor const &vals, int64_t const dim_i,
+                       int64_t const dim_j, torch::Tensor const &mat2,
+                       int64_t spmm_algorithm_id);
+#endif
+
 } // end namespace minkowski
 
 namespace py = pybind11;
@@ -354,12 +364,15 @@ void instantiate_gpu_func(py::module &m, const std::string &dtypestr) {
         &minkowski::ConvolutionTransposeBackwardGPU,
         py::call_guard<py::gil_scoped_release>());
+}
 
-  // m.def("coo_spmm_int32", &coo_spmm<int32_t>,
-  //       py::call_guard<py::gil_scoped_release>());
-  // m.def("coo_spmm_int64", &coo_spmm<int64_t>,
-  //       py::call_guard<py::gil_scoped_release>());
+void non_templated_gpu_func(py::module &m) {
+  m.def("coo_spmm_int32", &minkowski::coo_spmm<int32_t>,
+        py::call_guard<py::gil_scoped_release>());
+  m.def("coo_spmm_int64", &minkowski::coo_spmm<int64_t>,
+        py::call_guard<py::gil_scoped_release>());
 }
+
 #endif
 
 void initialize_non_templated_classes(py::module &m) {
diff --git a/pybind/minkowski.cpp b/pybind/minkowski.cpp
index 30d2b1ad..755403e5 100644
--- a/pybind/minkowski.cpp
+++ b/pybind/minkowski.cpp
@@ -59,5 +59,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
   instantiate_gpu_func(
       m, std::string(""));
+
+  non_templated_gpu_func(m);
 #endif
 }
diff --git a/pybind/minkowski.cu b/pybind/minkowski.cu
index b1e4ec06..3b8fcd02 100644
--- a/pybind/minkowski.cu
+++ b/pybind/minkowski.cu
@@ -59,5 +59,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
   instantiate_gpu_func(
       m, std::string(""));
+
+  non_templated_gpu_func(m);
 #endif
 }
diff --git a/setup.py b/setup.py
index 0c1bb988..d4cda3bb 100644
--- a/setup.py
+++ b/setup.py
@@ -193,13 +193,15 @@ def _argparse(pattern, argv, is_flag=True):
             "coordinate_map_gpu.cu",
             "convolution_kernel.cu",
             "convolution_transpose_gpu.cu",
+            "spmm.cu",
+            "gpu.cu",
         ],
         ["pybind/minkowski.cu"],
         [],
     ],
 }
 
-no_debug, argv = _argparse("--nodebug", argv)
+debug, argv = _argparse("--debug", argv)
 USE_NINJA = os.getenv("USE_NINJA") == "0"
 HERE = Path(os.path.dirname(__file__)).absolute()
@@ -220,7 +222,7 @@ def _argparse(pattern, argv, is_flag=True):
 
 NVCC_FLAGS = [f"-ccbin={CXX}", "--extended-lambda"]
 
-if not no_debug:
+if debug:
     CXX_FLAGS += ["-g", "-DDEBUG"]
     NVCC_FLAGS += ["-g", "-DDEBUG"]
 else:
diff --git a/src/convolution_gpu.cu b/src/convolution_gpu.cu
index 031443b3..f6e49f50 100644
--- a/src/convolution_gpu.cu
+++ b/src/convolution_gpu.cu
@@ -105,6 +105,7 @@ at::Tensor ConvolutionForwardGPU(
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasSetStream(handle, at::cuda::getCurrentCUDAStream().stream());
 
+  LOG_DEBUG("Convolution on", out_nrows, "x", kernel.size(2));
   AT_DISPATCH_FLOATING_TYPES(
       in_feat.scalar_type(), "convolution_forward_gpu", [&] {
         ConvolutionForwardKernelGPU
 #include
 #include
 #include
-#include
+#include
 #include // cuda driver types
 #include
diff --git a/src/spmm.cu b/src/spmm.cu
new file mode 100644
index 00000000..358e7e3f
--- /dev/null
+++ b/src/spmm.cu
@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) 2020 NVIDIA Corporation.
+ * Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu).
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
+ * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
+ * of the code.
+ */
+#include "gpu.cuh"
+
+#include <cusparse.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <torch/extension.h>
+
+namespace minkowski {
+
+cudaDataType getTensorCudaDataType(torch::Tensor const &self) {
+  cudaDataType cuda_data_type;
+  switch (self.scalar_type()) {
+  case torch::ScalarType::Float:
+    cuda_data_type = CUDA_R_32F;
+    break;
+  case torch::ScalarType::Double:
+    cuda_data_type = CUDA_R_64F;
+    break;
+  default:
+    TORCH_CHECK(false, "Tensor types must be either float32 or float64");
+    break;
+  }
+  return cuda_data_type;
+}
+
+template <typename th_int_type>
+torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
+                       torch::Tensor const &vals, int64_t const dim_i,
+                       int64_t const dim_j, torch::Tensor const &mat2,
+                       int64_t spmm_algorithm_id) {
+#if defined __HIP_PLATFORM_HCC__
+  TORCH_CHECK(false, "spmm sparse-dense is not supported on HIP");
+#elif defined(_WIN32) || defined(_WIN64)
+  TORCH_CHECK(false, "spmm sparse-dense CUDA is not supported on Windows");
+#elif !defined(CUDART_VERSION)
+  TORCH_CHECK(false, "CUDART_VERSION not defined");
+#endif
+
+  constexpr bool is_int32 = std::is_same<th_int_type, int32_t>::value;
+
+  cusparseSpMMAlg_t mm_alg;
+#if defined(CUDART_VERSION) && (CUDART_VERSION < 10010)
+  TORCH_CHECK(false, "spmm sparse-dense requires CUDA 10.1 or greater");
+#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 10010) && \
+    (CUDART_VERSION < 11000)
+  switch (spmm_algorithm_id) {
+  case 1:
+    mm_alg = CUSPARSE_COOMM_ALG1;
+    break;
+  case 2:
+    mm_alg = CUSPARSE_COOMM_ALG2;
+    break;
+  case 3:
+    mm_alg = CUSPARSE_COOMM_ALG3;
+    break;
+  default:
+    TORCH_CHECK(false, "Invalid algorithm id.", spmm_algorithm_id);
+    mm_alg = CUSPARSE_MM_ALG_DEFAULT;
+  }
+  TORCH_CHECK(is_int32, "int64 cusparseSpMM requires CUDA 11.0 or greater");
+#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11000)
+  switch (spmm_algorithm_id) {
+  case 1:
+    mm_alg = CUSPARSE_SPMM_COO_ALG1;
+    break;
+  case 2:
+    mm_alg = CUSPARSE_SPMM_COO_ALG2;
+    break;
+  case 3:
+    mm_alg = CUSPARSE_SPMM_COO_ALG3;
+    break;
+  case 4:
+    mm_alg = CUSPARSE_SPMM_COO_ALG4;
+    break;
+  default:
+    TORCH_CHECK(false, "Invalid algorithm id.", spmm_algorithm_id);
+    mm_alg = CUSPARSE_SPMM_ALG_DEFAULT;
+  }
+  TORCH_CHECK(std::is_same<th_int_type, int32_t>::value ||
+              (std::is_same<th_int_type, int64_t>::value &&
+               (mm_alg == CUSPARSE_SPMM_COO_ALG4)));
+#endif
+
+  at::ScalarType int_scalar_type = std::is_same<th_int_type, int32_t>::value
+                                       ? at::ScalarType::Int
+                                       : at::ScalarType::Long;
+
+  TORCH_CHECK(rows.scalar_type() == int_scalar_type, "int type mismatch.");
+  TORCH_CHECK(rows.scalar_type() == cols.scalar_type(),
+              "rows and cols must have the same scalar type.");
+  TORCH_CHECK(vals.scalar_type() == mat2.scalar_type(),
+              "vals and mat2 must have the same scalar type.");
+
+  TORCH_CHECK(rows.is_cuda(), "rows must be CUDA, but got CPU");
+  TORCH_CHECK(cols.is_cuda(), "cols must be CUDA, but got CPU");
+  TORCH_CHECK(vals.is_cuda(), "vals must be CUDA, but got CPU");
+  TORCH_CHECK(mat2.is_cuda(), "mat2 must be CUDA, but got CPU");
+  TORCH_CHECK(at::cuda::check_device({rows, cols, vals, mat2}));
+
+  TORCH_CHECK(mat2.dim() == 2, "Tensor 'mat2' must have 2 dims, but has ",
+              mat2.dim());
+
+  // int64_t dim_i = self.size(0);
+  // int64_t dim_j = self.size(1);
+  int64_t dim_k = mat2.size(1);
+
+  torch::Tensor result = at::empty({dim_k, dim_i}, mat2.options());
+
+  if ((dim_j == 0) || (dim_k == 0)) {
+    return result;
+  }
+
+  // Dense matrices have to be contiguous for cusparseSpMM to work
+  torch::Tensor const mat2_contig = mat2.contiguous();
+  auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle();
+
+  torch::Scalar beta = 0;
+  torch::Scalar alpha = 1;
+
+  size_t workspace_buffer_size = 0;
+  void *workspace_buffer = nullptr;
+
+  // Iterate through each set of 2D matrices within the 3D
+  // tensor inputs, performing a matrix multiply with each
+  AT_DISPATCH_FLOATING_TYPES(vals.scalar_type(), "coo_spmm", [&] {
+    scalar_t alpha_val = alpha.to<scalar_t>();
+    scalar_t beta_val = beta.to<scalar_t>();
+
+    // Create tensors to view just the current set of matrices
+    int64_t sparse_nnz = rows.numel();
+
+    cudaDataType cuda_data_type = getTensorCudaDataType(mat2_contig);
+    th_int_type *row_indices_ptr =
+        reinterpret_cast<th_int_type *>(rows.data_ptr());
+    th_int_type *col_indices_ptr =
+        reinterpret_cast<th_int_type *>(cols.data_ptr());
+    scalar_t *values_ptr = reinterpret_cast<scalar_t *>(vals.data_ptr());
+    scalar_t *mat2_ptr = reinterpret_cast<scalar_t *>(mat2_contig.data_ptr());
+    scalar_t *result_ptr = reinterpret_cast<scalar_t *>(result.data_ptr());
+
+    cusparseSpMatDescr_t sparse_descr;
+    CUSPARSE_CHECK(cusparseCreateCoo(&sparse_descr,            //
+                                     dim_i, dim_j, sparse_nnz, //
+                                     reinterpret_cast<void *>(row_indices_ptr),
+                                     reinterpret_cast<void *>(col_indices_ptr),
+                                     reinterpret_cast<void *>(values_ptr), //
+                                     std::is_same<th_int_type, int32_t>::value
+                                         ? CUSPARSE_INDEX_32I
+                                         : CUSPARSE_INDEX_64I,
+                                     CUSPARSE_INDEX_BASE_ZERO, cuda_data_type));
+
+    cusparseDnMatDescr_t dense_descr;
+    CUSPARSE_CHECK(cusparseCreateDnMat(&dense_descr,                       //
+                                       dim_k, dim_j, dim_k,                //
+                                       reinterpret_cast<void *>(mat2_ptr), //
+                                       cuda_data_type, CUSPARSE_ORDER_COL));
+
+    cusparseDnMatDescr_t result_descr;
+    CUSPARSE_CHECK(cusparseCreateDnMat(&result_descr,                        //
+                                       dim_i, dim_k, dim_i,                  //
+                                       reinterpret_cast<void *>(result_ptr), //
+                                       cuda_data_type, CUSPARSE_ORDER_COL));
+
+    size_t required_workspace_buffer_size = 0;
+    CUSPARSE_CHECK(cusparseSpMM_bufferSize(
+        cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_TRANSPOSE, (void *)&alpha_val, sparse_descr,
+        dense_descr, (void *)&beta_val, result_descr, cuda_data_type, mm_alg,
+        &required_workspace_buffer_size));
+
+    if (required_workspace_buffer_size > workspace_buffer_size) {
+      if (workspace_buffer != nullptr) {
+        cudaFree(workspace_buffer);
+      }
+      workspace_buffer_size = required_workspace_buffer_size;
+      cudaMallocManaged(&workspace_buffer, workspace_buffer_size);
+    }
+    CUSPARSE_CHECK(cusparseSpMM(cusparse_handle,                  //
+                                CUSPARSE_OPERATION_NON_TRANSPOSE, //
+                                CUSPARSE_OPERATION_TRANSPOSE,     //
+                                (void *)&alpha_val,               //
+                                sparse_descr, dense_descr,        //
+                                (void *)&beta_val, result_descr,  //
+                                cuda_data_type, mm_alg, workspace_buffer));
+    CUSPARSE_CHECK(cusparseDestroySpMat(sparse_descr));
+    CUSPARSE_CHECK(cusparseDestroyDnMat(dense_descr));
+    CUSPARSE_CHECK(cusparseDestroyDnMat(result_descr));
+  });
+
+  // Need to transpose the result matrices since cusparse stores
+  // them in column-major order in memory
+  result.transpose_(0, 1);
+
+  if (workspace_buffer != nullptr) {
+    cudaFree(workspace_buffer);
+  }
+
+  CUDA_CHECK(cudaGetLastError());
+
+  return result;
+}
+
+template torch::Tensor
+coo_spmm<int32_t>(torch::Tensor const &rows, torch::Tensor const &cols,
+                  torch::Tensor const &vals, int64_t const dim_i,
+                  int64_t const dim_j, torch::Tensor const &mat2,
+                  int64_t spmm_algorithm_id);
+
+template torch::Tensor
+coo_spmm<int64_t>(torch::Tensor const &rows, torch::Tensor const &cols,
+                  torch::Tensor const &vals, int64_t const dim_i,
+                  int64_t const dim_j, torch::Tensor const &mat2,
+                  int64_t spmm_algorithm_id);
+
+} // namespace minkowski
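The new kernel is exposed as coo_spmm_int32 / coo_spmm_int64 through non_templated_gpu_func, and the spmm wrapper imported in MinkowskiSparseTensor.py presumably dispatches to one of them based on the index dtype. The sketch below is illustrative only: the compiled extension module name is an assumption, and the argument order simply mirrors the declaration added to pybind/extern.hpp (rows, cols, vals, dim_i, dim_j, mat2, spmm_algorithm_id).

# Hypothetical direct call into the binding (module name is an assumption).
import torch
import MinkowskiEngineBackend._C as _C  # assumed compiled-extension name

rows = torch.tensor([0, 0, 1], dtype=torch.int32, device="cuda")  # COO row ids
cols = torch.tensor([0, 2, 1], dtype=torch.int32, device="cuda")  # COO col ids
vals = torch.ones(3, dtype=torch.float32, device="cuda")          # COO values
mat2 = torch.rand(3, 4, device="cuda")                            # dense operand

# result = A @ mat2, where A is the 2 x 3 sparse matrix given by (rows, cols, vals)
result = _C.coo_spmm_int32(rows, cols, vals, 2, 3, mat2, 1)       # algorithm id 1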
diff --git a/tests/cpp/convolution_cpu_test.py b/tests/cpp/convolution_cpu_test.py
index cccb456c..6ca1a66b 100644
--- a/tests/cpp/convolution_cpu_test.py
+++ b/tests/cpp/convolution_cpu_test.py
@@ -25,7 +25,7 @@ def test(self):
 
         # size, in, out
         kernel = torch.rand(9, IC, OC)
-        out_features = _C.ConvolutionForwardCPUf(
+        out_features = _C.ConvolutionForwardCPU(
             in_features,
             kernel,
             kernel_size,
@@ -55,7 +55,7 @@ def test_backward(self):
 
         # size, in, out
         kernel = torch.rand(9, IC, OC)
-        out_features = _C.ConvolutionForwardCPUf(
+        out_features = _C.ConvolutionForwardCPU(
             in_features,
             kernel,
             kernel_size,
@@ -69,7 +69,7 @@ def test_backward(self):
         )
 
         out_feat_grad = torch.rand_like(out_features)
-        in_feat_grad, kernel_grad = _C.ConvolutionBackwardCPUf(
+        in_feat_grad, kernel_grad = _C.ConvolutionBackwardCPU(
             in_features,
             out_feat_grad,
             kernel,
@@ -114,7 +114,7 @@ def test_pcd(self):
         out_key = in_key
 
         stime = time.time()
-        out_features = _C.ConvolutionForwardCPUf(
+        out_features = _C.ConvolutionForwardCPU(
             ucolors,
             kernel,
             kernel_size,
@@ -162,7 +162,7 @@ def test_pcd2(self):
         out_key = _C.CoordinateMapKey(4)
 
         stime = time.time()
-        out_features = _C.ConvolutionForwardCPUf(
+        out_features = _C.ConvolutionForwardCPU(
             in_feats,
             kernel,
             kernel_size,
diff --git a/tests/cpp/convolution_gpu_test.py b/tests/cpp/convolution_gpu_test.py
index 1bae19b1..656b151f 100644
--- a/tests/cpp/convolution_gpu_test.py
+++ b/tests/cpp/convolution_gpu_test.py
@@ -25,7 +25,7 @@ def test(self):
 
         # size, in, out
         kernel = torch.rand(3, IC, OC).to(0)
-        out_features = _C.ConvolutionForwardGPUf(
+        out_features = _C.ConvolutionForwardGPU(
             in_features,
             kernel,
             kernel_size,
@@ -55,7 +55,7 @@ def test_backward(self):
 
         # size, in, out
         kernel = torch.rand(9, IC, OC).to(0)
-        out_features = _C.ConvolutionForwardGPUf(
+        out_features = _C.ConvolutionForwardGPU(
             in_features,
             kernel,
             kernel_size,
@@ -69,7 +69,7 @@ def test_backward(self):
         )
 
         out_feat_grad = torch.rand_like(out_features)
-        in_feat_grad, kernel_grad = _C.ConvolutionBackwardGPUf(
+        in_feat_grad, kernel_grad = _C.ConvolutionBackwardGPU(
             in_features,
             out_feat_grad,
             kernel,
@@ -113,7 +113,7 @@ def test_pcd(self):
         out_key = _C.CoordinateMapKey(4)
 
         stime = time.time()
-        out_features = _C.ConvolutionForwardGPUf(
+        out_features = _C.ConvolutionForwardGPU(
             in_feats,
             kernel,
             kernel_size,
@@ -163,7 +163,7 @@ def test_pcd2(self):
         out_key = _C.CoordinateMapKey(4)
 
         stime = time.time()
-        out_features = _C.ConvolutionForwardGPUf(
+        out_features = _C.ConvolutionForwardGPU(
             in_feats,
             kernel,
             kernel_size,
diff --git a/tests/cpp/convolution_test.cpp b/tests/cpp/convolution_test.cpp
index a65268ff..7a1c7a81 100644
--- a/tests/cpp/convolution_test.cpp
+++ b/tests/cpp/convolution_test.cpp
@@ -38,7 +38,7 @@ namespace minkowski {
 
-template <typename coordinate_type, typename feature_type>
+template <typename coordinate_type>
 at::Tensor
 ConvolutionForwardCPU(at::Tensor const &in_feat, //
                       at::Tensor const &kernel,  //
@@ -51,7 +51,7 @@ ConvolutionForwardCPU(at::Tensor const &in_feat, //
                       CoordinateMapKey *p_out_map_key, //
                       cpu_manager_type *p_map_manager);
 
-template <typename coordinate_type, typename feature_type>
+template <typename coordinate_type>
 std::pair<at::Tensor, at::Tensor>
 ConvolutionBackwardCPU(at::Tensor const &in_feat,       //
                        at::Tensor const &grad_out_feat, //
@@ -114,23 +114,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            minkowski::cpu_manager_type::size)
       .def("kernel_map", &minkowski::cpu_manager_type::kernel_map);
 
-  m.def("ConvolutionForwardCPUf",
+  m.def("ConvolutionForwardCPU",
         &minkowski::ConvolutionForwardCPU<
-            minkowski::default_types::dcoordinate_type, float>,
+            minkowski::default_types::dcoordinate_type>,
         py::call_guard<py::gil_scoped_release>());
-  m.def("ConvolutionForwardCPUd",
-        &minkowski::ConvolutionForwardCPU<
-            minkowski::default_types::dcoordinate_type, double>,
-        py::call_guard<py::gil_scoped_release>());
-
-  m.def("ConvolutionBackwardCPUf",
-        &minkowski::ConvolutionBackwardCPU<
-            minkowski::default_types::dcoordinate_type, float>,
-        py::call_guard<py::gil_scoped_release>());
-
-  m.def("ConvolutionBackwardCPUd",
+  m.def("ConvolutionBackwardCPU",
         &minkowski::ConvolutionBackwardCPU<
-            minkowski::default_types::dcoordinate_type, double>,
+            minkowski::default_types::dcoordinate_type>,
         py::call_guard<py::gil_scoped_release>());
 }
diff --git a/tests/cpp/convolution_test.cu b/tests/cpp/convolution_test.cu
index d04f4982..9bb9ae3e 100644
--- a/tests/cpp/convolution_test.cu
+++ b/tests/cpp/convolution_test.cu
@@ -40,7 +40,7 @@ namespace minkowski {
 
-template <typename coordinate_type, typename feature_type,
+template <typename coordinate_type,
           template <typename C> class TemplatedAllocator>
 at::Tensor ConvolutionForwardGPU(
     at::Tensor const &in_feat, //
@@ -54,7 +54,7 @@ at::Tensor ConvolutionForwardGPU(
     CoordinateMapKey *p_out_map_key, //
     gpu_manager_type *p_map_manager);
 
-template <typename coordinate_type, typename feature_type,
+template <typename coordinate_type,
           template <typename C> class TemplatedAllocator>
 std::pair<at::Tensor, at::Tensor> ConvolutionBackwardGPU(
     at::Tensor const &in_feat, //
@@ -117,27 +117,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            minkowski::gpu_c10_manager_type::size)
       .def("kernel_map", &minkowski::gpu_c10_manager_type::kernel_map);
 
-  m.def("ConvolutionForwardGPUf",
+  m.def("ConvolutionForwardGPU",
        &minkowski::ConvolutionForwardGPU<
-            minkowski::default_types::dcoordinate_type, float,
+            minkowski::default_types::dcoordinate_type,
             minkowski::detail::c10_allocator>,
         py::call_guard<py::gil_scoped_release>());
-  m.def("ConvolutionForwardCPUd",
-        &minkowski::ConvolutionForwardGPU<
-            minkowski::default_types::dcoordinate_type, double,
-            minkowski::detail::c10_allocator>,
-        py::call_guard<py::gil_scoped_release>());
-
-  m.def("ConvolutionBackwardGPUf",
-        &minkowski::ConvolutionBackwardGPU<
-            minkowski::default_types::dcoordinate_type, float,
-            minkowski::detail::c10_allocator>,
-        py::call_guard<py::gil_scoped_release>());
-
-  m.def("ConvolutionBackwardGPUd",
+  m.def("ConvolutionBackwardGPU",
         &minkowski::ConvolutionBackwardGPU<
-            minkowski::default_types::dcoordinate_type, double,
+            minkowski::default_types::dcoordinate_type,
             minkowski::detail::c10_allocator>,
         py::call_guard<py::gil_scoped_release>());
 }
diff --git a/tests/python/strided_conv.py b/tests/python/strided_conv.py
index 4f53fefa..516cbee7 100644
--- a/tests/python/strided_conv.py
+++ b/tests/python/strided_conv.py
@@ -25,43 +25,55 @@
 import argparse
 import numpy as np
 from urllib.request import urlretrieve
+
 try:
     import open3d as o3d
 except ImportError:
-    raise ImportError('Please install open3d with `pip install open3d`.')
+    raise ImportError("Please install open3d with `pip install open3d`.")
 
 import torch
 import MinkowskiEngine as ME
+from MinkowskiCommon import convert_to_int_list
 from examples.common import Timer
 
 # Check if the weights and file exist and download
-if not os.path.isfile('1.ply'):
-    print('Downloading a room ply file...')
-    urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", '1.ply')
+if not os.path.isfile("1.ply"):
+    print("Downloading a room ply file...")
+    urlretrieve("http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", "1.ply")
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--file_name', type=str, default='1.ply')
-parser.add_argument('--voxel_size', type=float, default=0.02)
-parser.add_argument('--batch_size', type=int, default=1)
-parser.add_argument('--max_kernel_size', type=int, default=7)
+parser.add_argument("--file_name", type=str, default="1.ply")
+parser.add_argument("--voxel_size", type=float, default=0.02)
+parser.add_argument("--batch_size", type=int, default=1)
+parser.add_argument("--max_kernel_size", type=int, default=7)
+
+
+def quantize(coordinates):
+    D = coordinates.size(1) - 1
+    coordinate_manager = ME.CoordinateManager(
+        D=D, coordinate_map_type=ME.CoordinateMapType.CPU
+    )
+    coordinate_map_key = ME.CoordinateMapKey(convert_to_int_list(1, D), "")
+    key, (unique_map, inverse_map) = coordinate_manager.insert_and_map(
+        coordinates, *coordinate_map_key.get_key()
+    )
+    return unique_map, inverse_map
 
 
 def load_file(file_name, voxel_size):
     pcd = o3d.io.read_point_cloud(file_name)
-    coords = np.array(pcd.points)
-    feats = np.array(pcd.colors)
+    coords = torch.from_numpy(np.array(pcd.points))
+    feats = torch.from_numpy(np.array(pcd.colors)).float()
 
-    quantized_coords = np.floor(coords / voxel_size)
-    inds = ME.utils.sparse_quantize(quantized_coords, return_index=True)
+    quantized_coords = torch.floor(coords / voxel_size).int()
+    inds, inverse_inds = quantize(quantized_coords)
 
     return quantized_coords[inds], feats[inds], pcd
 
 
 def generate_input_sparse_tensor(file_name, voxel_size=0.05, batch_size=1):
     # Create a batch, this process is done in a data loader during training in parallel.
-    batch = [
-        load_file(file_name, voxel_size),
-    ] * batch_size
+    batch = [load_file(file_name, voxel_size),] * batch_size
     coordinates_, featrues_, pcds = list(zip(*batch))
     coordinates, features = ME.utils.sparse_collate(coordinates_, featrues_)
@@ -69,9 +81,9 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05, batch_size=1):
     return features, coordinates
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     config = parser.parse_args()
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Define a model and load the weights
     all_convs = {}
@@ -83,40 +95,44 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05, batch_size=1):
                 out_channels=out_ch,
                 kernel_size=k,
                 stride=2,
-                dimension=3).to(device)
+                dimension=3,
+            ).to(device)
 
     # Measure time
-    print('Initialization time')
+    print("Initialization time")
     features, coordinates = generate_input_sparse_tensor(
-        config.file_name,
-        voxel_size=config.voxel_size,
-        batch_size=config.batch_size)
+        config.file_name, voxel_size=config.voxel_size, batch_size=config.batch_size
+    )
 
     timer = Timer()
     for i in range(20):
         timer.tic()
-        sinput = ME.SparseTensor(features, coords=coordinates).to(device)
+        sinput = ME.SparseTensor(
+            features.to(device), coordinates=coordinates.to(device)
+        )
         timer.toc()
-    print(f'{timer.min_time:.12f} for initialization of {len(sinput)} voxels')
+    print(f"{timer.min_time:.12f} for initialization of {len(sinput)} voxels")
 
-    print('Forward')
+    print("Forward")
     for k, conv in all_convs.items():
         timer = Timer()
         features = torch.rand(len(coordinates), k[1]).to(device)
         # Feed-forward pass and get the prediction
         for i in range(20):
-            sinput = ME.SparseTensor(features, coords=coordinates).to(device)
+            sinput = ME.SparseTensor(
+                features.to(device), coordinates=coordinates.to(device)
+            )
             timer.tic()
             soutput = conv(sinput)
             timer.toc()
         print(
-            f'{timer.min_time:.12f} for {k} strided convolution with {len(sinput)} voxel'
+            f"{timer.min_time:.12f} for {k} strided convolution with {len(sinput)} voxel"
         )
 
-    print('Backward')
+    print("Backward")
     for k, conv in all_convs.items():
         timer = Timer()
         sinput._F = torch.rand(len(sinput), k[1]).to(device)
@@ -129,5 +145,5 @@ def generate_input_sparse_tensor(file_name, voxel_size=0.05, batch_size=1):
             loss.backward()
             timer.toc()
         print(
-            f'{timer.min_time:.12f} for {k} strided convolution with {len(sinput)} voxel'
+            f"{timer.min_time:.12f} for {k} strided convolution with {len(sinput)} voxel"
        )
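The benchmark above also moves to the v0.5 constructor pattern: features and coordinates are placed on the target device first and passed with the coordinates= keyword, instead of the earlier coords= keyword followed by .to(device) on the SparseTensor. A minimal sketch with toy data (the coordinates and feature sizes below are illustrative, not taken from the benchmark):

# Sketch of the updated construction pattern.
import torch
import MinkowskiEngine as ME

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
coordinates = torch.IntTensor([[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 1, 0]])  # (batch, x, y, z)
features = torch.rand(len(coordinates), 4)

# old call style: sinput = ME.SparseTensor(features, coords=coordinates).to(device)
sinput = ME.SparseTensor(features.to(device), coordinates=coordinates.to(device))
print(len(sinput), sinput.F.shape)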