diff --git a/CHANGELOG.md b/CHANGELOG.md index 540efb59..1cd0a4fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - TensorField to sparse with coordinate map key - Sparse matrix multiplication - force contiguous matrix +- Fix AveragePooling cudaErrorMisalignedAddress error for CUDA 10 (#246) ## [0.5.0] - 2020-12-24 diff --git a/MinkowskiEngine/MinkowskiNormalization.py b/MinkowskiEngine/MinkowskiNormalization.py index dc0035cc..1f57a15a 100644 --- a/MinkowskiEngine/MinkowskiNormalization.py +++ b/MinkowskiEngine/MinkowskiNormalization.py @@ -199,7 +199,7 @@ def forward( in_coords_key: CoordinateMapKey, glob_coords_key: CoordinateMapKey = None, coords_manager: CoordinateManager = None, - gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_PYTORCH_INDEX, + gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_KERNEL, ): if glob_coords_key is None: glob_coords_key = CoordinateMapKey(in_coords_key.get_coordinate_size()) diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py index ff4d47e1..745b0495 100644 --- a/MinkowskiEngine/MinkowskiSparseTensor.py +++ b/MinkowskiEngine/MinkowskiSparseTensor.py @@ -314,13 +314,14 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key): self.quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE ): - nums = spmm.apply( - self.inverse_mapping, - cols, - vals, - size, - vals.reshape(N, 1), - ) + with torch.no_grad(): + nums = spmm.apply( + self.inverse_mapping, + cols, + vals, + size, + vals.reshape(N, 1), + ) features /= nums elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: features = features[self.unique_index] diff --git a/MinkowskiEngine/MinkowskiTensorField.py b/MinkowskiEngine/MinkowskiTensorField.py index eae06cdf..4eb73467 100644 --- a/MinkowskiEngine/MinkowskiTensorField.py +++ b/MinkowskiEngine/MinkowskiTensorField.py @@ -298,13 +298,14 @@ def sparse( self.quantization_mode == 
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE ): - nums = spmm.apply( - inverse_mapping, - cols, - vals, - size, - vals.reshape(N, 1), - ) + with torch.no_grad(): + nums = spmm.apply( + inverse_mapping, + cols, + vals, + size, + vals.reshape(N, 1), + ) features /= nums elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: features = self._F[unique_index] diff --git a/MinkowskiEngine/sparse_matrix_functions.py b/MinkowskiEngine/sparse_matrix_functions.py index a098c491..9fd72203 100644 --- a/MinkowskiEngine/sparse_matrix_functions.py +++ b/MinkowskiEngine/sparse_matrix_functions.py @@ -33,6 +33,7 @@ def spmm( vals: torch.Tensor, size: torch.Size, mat: torch.Tensor, + return_num_nonzero: bool = False, cuda_spmm_alg: int = 1, ): @@ -41,11 +42,13 @@ def spmm( assert vals.dtype == mat.dtype, "dtype mismatch" assert vals.device == mat.device, "device mismatch" if mat.is_cuda: - assert rows.is_cuda and cols.is_cuda and vals.is_cuda + assert ( + rows.is_cuda and cols.is_cuda and vals.is_cuda + ), "All inputs must be on cuda" rows = rows.int() cols = cols.int() - return MEB.coo_spmm_int32( - rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg + result, num_nonzero = MEB.coo_spmm_int32( + rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg, return_num_nonzero ) # WARNING: TODO: not sorting the vals. 
Should not be used for generic SPMM @@ -54,7 +57,10 @@ def spmm( # rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg # ) else: - COO = torch.stack((rows, cols), 0,).long() + COO = torch.stack( + (rows, cols), + 0, + ).long() torchSparseTensor = None if vals.dtype == torch.float64: torchSparseTensor = torch.sparse.DoubleTensor @@ -64,7 +70,14 @@ def spmm( raise ValueError(f"Unsupported data type: {vals.dtype}") sp = torchSparseTensor(COO, vals, size) - return sp.matmul(mat) + result = sp.matmul(mat) + if return_num_nonzero: + num_nonzero = sp.matmul(torch.ones((size[1], 1), dtype=vals.dtype)) + + if return_num_nonzero: + return result, num_nonzero + else: + return result class MinkowskiSPMMFunction(Function): @@ -78,19 +91,34 @@ def forward( mat: torch.Tensor, cuda_spmm_alg: int = 1, ): - ctx.save_for_backward(rows, cols, vals) ctx.misc_args = size, cuda_spmm_alg + ctx.save_for_backward(rows, cols, vals) mat = mat.contiguous() - out = spmm(rows, cols, vals, size, mat, cuda_spmm_alg) - return out + return spmm( + rows, + cols, + vals, + size, + mat, + return_num_nonzero=False, + cuda_spmm_alg=cuda_spmm_alg, + ) @staticmethod def backward(ctx, grad: torch.Tensor): - rows, cols, vals = ctx.saved_tensors size, cuda_spmm_alg = ctx.misc_args + rows, cols, vals = ctx.saved_tensors new_size = torch.Size([size[1], size[0]]) grad = grad.contiguous() - grad = spmm(cols, rows, vals, new_size, grad, cuda_spmm_alg) + grad = spmm( + cols, + rows, + vals, + new_size, + grad, + return_num_nonzero=False, + cuda_spmm_alg=cuda_spmm_alg, + ) return ( None, None, diff --git a/examples/multigpu_lightning.py b/examples/multigpu_lightning.py index e2cf0ada..547adcef 100644 --- a/examples/multigpu_lightning.py +++ b/examples/multigpu_lightning.py @@ -196,6 +196,9 @@ def configure_optimizers(self): print(f"Testing {num_devices} GPUs.") # Training - pl_module = MinkowskiSegmentationModule(DummyNetwork(3, 20, D=3), lr=args.lr) + model = DummyNetwork(3, 20, D=3) + if args.ngpus > 1: + model 
= ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model) + pl_module = MinkowskiSegmentationModule(model, lr=args.lr) trainer = Trainer(max_epochs=args.max_epochs, gpus=num_devices, accelerator="ddp") trainer.fit(pl_module) diff --git a/pybind/extern.hpp b/pybind/extern.hpp index 01d891fa..456e0624 100644 --- a/pybind/extern.hpp +++ b/pybind/extern.hpp @@ -491,10 +491,11 @@ at::Tensor quantization_average_features(at::Tensor in_feat, at::Tensor in_map, #ifndef CPU_ONLY template -torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, - torch::Tensor const &vals, int64_t const dim_i, - int64_t const dim_j, torch::Tensor const &mat2, - int64_t spmm_algorithm_id); +std::pair +coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, + torch::Tensor const &vals, int64_t const dim_i, int64_t const dim_j, + torch::Tensor const &mat2, int64_t const spmm_algorithm_id, + bool const return_num_nonzero); std::pair get_memory_info(); #endif @@ -757,8 +758,7 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) { &manager_type::to_string, py::const_)) .def("insert_and_map", &manager_type::insert_and_map) .def("insert_field", &manager_type::insert_field) - .def("field_to_sparse_map", - &manager_type::field_to_sparse_map) + .def("field_to_sparse_map", &manager_type::field_to_sparse_map) .def("field_to_sparse_insert_and_map", &manager_type::field_to_sparse_insert_and_map) .def("exists_field_to_sparse", diff --git a/src/coordinate_map_manager.cpp b/src/coordinate_map_manager.cpp index 895aab8e..12d6c40b 100644 --- a/src/coordinate_map_manager.cpp +++ b/src/coordinate_map_manager.cpp @@ -325,9 +325,6 @@ CoordinateMapManager InterpolationForwardGPU( auto const &out_maps = map_weight[1]; auto const &weights = map_weight[2]; - auto out_feat = coo_spmm(out_maps, in_maps, weights, tfield.size(0), - in_feat.size(0), in_feat, 1); + auto out_feat_pair = coo_spmm(out_maps, in_maps, weights, tfield.size(0), + in_feat.size(0), in_feat, 1, false); // 
to out_feats - map_weight.insert(map_weight.begin(), out_feat); + map_weight.insert(map_weight.begin(), out_feat_pair.first); return map_weight; } @@ -102,8 +102,10 @@ at::Tensor InterpolationBackwardGPU( uint32_t const in_nrows = p_map_manager->size(in_key); LOG_DEBUG("InterpolationBackwardKernelGPU"); - return coo_spmm(in_maps, out_maps, weights, in_nrows, - grad_out_feat.size(0), grad_out_feat, 1); + auto out_feat_pair = + coo_spmm(in_maps, out_maps, weights, in_nrows, grad_out_feat.size(0), + grad_out_feat, 1, false); + return out_feat_pair.first; } // Forward diff --git a/src/pooling_avg_kernel.cu b/src/pooling_avg_kernel.cu index 36142f26..78ed1200 100644 --- a/src/pooling_avg_kernel.cu +++ b/src/pooling_avg_kernel.cu @@ -71,7 +71,7 @@ __global__ void col2row_major_with_div(const int n, const int nrows, CUDA_KERNEL_LOOP(index, n) { i = index % nrows; j = index / nrows; - if (num_nonzero[i]) { + if (num_nonzero[i] >= 1) { rowA[i * ncols + j] = colA[index] / num_nonzero[i]; } else { rowA[i * ncols + j] = colA[index]; @@ -79,6 +79,16 @@ __global__ void col2row_major_with_div(const int n, const int nrows, } } +template +__global__ void +unique_row2num_nonzero(const int n, Dtype *__restrict__ d_num_nonzero, + const Itype *__restrict__ unique_row_ptr, + const Dtype *__restrict__ reduced_val_ptr) { + CUDA_KERNEL_LOOP(index, n) { + d_num_nonzero[unique_row_ptr[index]] = reduced_val_ptr[index]; + } +} + template __global__ void set_gradient(const int n, const Dtype *d_grad_out, Dtype *d_grad_in, const Itype *out_index, @@ -109,7 +119,7 @@ set_gradient_nonzero_avg(const int n, const Dtype *d_grad_out, Dtype *d_grad_in, int nrow = index / nchannel; int ch = index % nchannel; int curr_num_nonzero = d_num_nonzero[out_map[nrow]]; - if (curr_num_nonzero > 0) + if (curr_num_nonzero >= 1) atomicAdd(&d_grad_in[in_map[nrow] * nchannel + ch], d_grad_out[out_map[nrow] * nchannel + ch] / curr_num_nonzero); } @@ -153,31 +163,21 @@ void NonzeroAvgPoolingForwardKernelGPU( /* sparse 
mm prep */ size_t const sparse_nnzs = kernel_map.in_maps.end() - kernel_map.in_maps.begin(); - size_t one_vector_size = ((use_avg ? in_nrows : 0) // in_nrows vector - + sparse_nnzs // coo vals - + nchannel * out_nrows // out tmp - ) * - sizeof(Dtype); static_assert(is_int32, "sort_coo supports int32"); sort_coo_gpu(cushandle, out_nrows, in_nrows, sparse_nnzs, (int *)kernel_map.out_maps.begin(), (int *)kernel_map.in_maps.begin(), allocator); - // one vector. - d_ones = (Dtype *)allocator.allocate(one_vector_size); - + // feature output + d_tmp_out_feat = + (Dtype *)allocator.allocate(nchannel * out_nrows * sizeof(Dtype)); + d_coo_val = (Dtype *)allocator.allocate(sparse_nnzs * sizeof(Dtype)); + fill<<>>(sparse_nnzs, d_coo_val, (Dtype)1.); if (use_avg) { - d_ones = d_ones; // in_nrows; - d_coo_val = d_ones + in_nrows; // sparse_nnzs - d_tmp_out_feat = d_coo_val + sparse_nnzs; // nchannel * out_nrows - fill<<>>(in_nrows + sparse_nnzs, d_ones, - (Dtype)1.); - } else { - d_coo_val = d_ones; // sparse_nnzs - d_tmp_out_feat = d_coo_val + sparse_nnzs; // nchannel * out_nrows + d_ones = (Dtype *)allocator.allocate(sparse_nnzs * sizeof(Dtype)); fill<<>>(sparse_nnzs, d_coo_val, (Dtype)1.); + 0, stream>>>(sparse_nnzs, d_ones, (Dtype)1.); } #ifdef DEBUG @@ -206,6 +206,20 @@ void NonzeroAvgPoolingForwardKernelGPU( std::cout << "done printing\n"; #endif + Itype *sorted_row_ptr = + (Itype *)allocator.allocate(2 * (sparse_nnzs + 1) * sizeof(Itype)); + Itype *sorted_col_ptr = sorted_row_ptr + sparse_nnzs + 1; + + CUDA_CHECK(cudaMemcpy(sorted_row_ptr, kernel_map.out_maps.begin(), + sparse_nnzs * sizeof(Itype), cudaMemcpyDeviceToDevice)); + CUDA_CHECK(cudaMemcpy(sorted_col_ptr, kernel_map.in_maps.begin(), + sparse_nnzs * sizeof(Itype), cudaMemcpyDeviceToDevice)); + + thrust::sort_by_key(thrust::device, // + sorted_row_ptr, // key begin + sorted_row_ptr + sparse_nnzs, // key end + sorted_col_ptr); + // +---------+ +---+ // | spm | | i | // +---------+ | n | @@ -219,11 +233,11 @@ void 
NonzeroAvgPoolingForwardKernelGPU( cusparseDnMatDescr_t dense_descr; cusparseDnMatDescr_t result_descr; CUSPARSE_CHECK( - cusparseCreateCoo(&sparse_descr, // - dim_i, dim_j, sparse_nnzs, // - kernel_map.out_maps.begin(), // rows - kernel_map.in_maps.begin(), // cols - d_coo_val, // coo vals + cusparseCreateCoo(&sparse_descr, // + dim_i, dim_j, sparse_nnzs, // + sorted_row_ptr, // rows + sorted_col_ptr, // cols + d_coo_val, // coo vals is_int32 ? CUSPARSE_INDEX_32I : CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO, cuda_data_type)); @@ -237,6 +251,12 @@ void NonzeroAvgPoolingForwardKernelGPU( (void *)d_tmp_out_feat, // cuda_data_type, CUSPARSE_ORDER_COL)); + size_t buffer_size = 0; + CUSPARSE_CHECK(cusparseSpMM_bufferSize( + cushandle, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE, + (void *)&alpha, sparse_descr, dense_descr, (void *)&beta, result_descr, + cuda_data_type, mm_alg, &buffer_size)); + // buffer size 0 for CUSPARSE_SPMM_COO_ALG1, CUSPARSE_SPMM_COO_ALG3, // CUSPARSE_SPMM_COO_ALG4, and CUSPARSE_SPMM_CSR_ALG1 @@ -248,31 +268,51 @@ void NonzeroAvgPoolingForwardKernelGPU( (void *)&alpha, // sparse_descr, dense_descr, // (void *)&beta, result_descr, // - cuda_data_type, mm_alg, 0)); + cuda_data_type, mm_alg, &buffer_size)); #ifdef DEBUG CUDA_CHECK(cudaStreamSynchronize(0)); #endif LOG_DEBUG("SPMM"); if (use_avg) { - cusparseDnVecDescr_t vecX, vecY; - // Create dense vector X - CUSPARSE_CHECK( - cusparseCreateDnVec(&vecX, in_nrows, d_ones, cuda_data_type)); - // Create dense vector y - CUSPARSE_CHECK( - cusparseCreateDnVec(&vecY, out_nrows, d_num_nonzero, cuda_data_type)); - - CUSPARSE_CHECK(cusparseSpMV(cushandle, // - CUSPARSE_OPERATION_NON_TRANSPOSE, // - (void *)&alpha, // - sparse_descr, vecX, // - (void *)&beta, vecY, // - cuda_data_type, CUSPARSE_COOMV_ALG, nullptr)); + Itype *unique_row_ptr = + (Itype *)allocator.allocate(sparse_nnzs * sizeof(Itype)); + Dtype *reduced_val_ptr = + (Dtype *)allocator.allocate(sparse_nnzs * sizeof(Dtype)); + 
+ // reduce by key + auto end = thrust::reduce_by_key(thrust::device, // policy + sorted_row_ptr, // key begin + sorted_row_ptr + sparse_nnzs, // key end + d_ones, // value begin + unique_row_ptr, // key out begin + reduced_val_ptr // value out begin + ); + + int num_unique_keys = end.first - unique_row_ptr; + LOG_DEBUG("Num unique keys:", num_unique_keys); + #ifdef DEBUG - CUDA_CHECK(cudaStreamSynchronize(0)); + Itype *p_unique_row = (Itype *)std::malloc(num_unique_keys * sizeof(Itype)); + CUDA_CHECK(cudaMemcpy(p_unique_row, unique_row_ptr, + num_unique_keys * sizeof(Itype), + cudaMemcpyDeviceToHost)); + std::cout << "[" << PtrToString(p_unique_row, num_unique_keys) << "]\n"; + std::free(p_unique_row); + + Dtype *p_reduced_val = + (Dtype *)std::malloc(num_unique_keys * sizeof(Dtype)); + CUDA_CHECK(cudaMemcpy(p_reduced_val, reduced_val_ptr, + num_unique_keys * sizeof(Dtype), + cudaMemcpyDeviceToHost)); + std::cout << "[" << PtrToString(p_reduced_val, num_unique_keys) << "]\n"; + std::free(p_reduced_val); #endif - LOG_DEBUG("SPMV"); + // Copy the results to the correct output + unique_row2num_nonzero + <<>>(num_unique_keys, d_num_nonzero, unique_row_ptr, + reduced_val_ptr); col2row_major_with_div <<<<>>( @@ -295,7 +336,11 @@ void NonzeroAvgPoolingForwardKernelGPU( CUSPARSE_CHECK(cusparseDestroyDnMat(dense_descr)); CUSPARSE_CHECK(cusparseDestroyDnMat(result_descr)); - allocator.deallocate((char *)d_ones, one_vector_size); + allocator.deallocate((char *)d_coo_val, sparse_nnzs * sizeof(Dtype)); + allocator.deallocate((char *)d_tmp_out_feat, + nchannel * out_nrows * sizeof(Dtype)); + if (use_avg) + allocator.deallocate((char *)d_ones, in_nrows * sizeof(Dtype)); CUDA_CHECK(cudaStreamSynchronize(0)); } diff --git a/src/spmm.cu b/src/spmm.cu index 5ca9c519..40977309 100644 --- a/src/spmm.cu +++ b/src/spmm.cu @@ -37,6 +37,16 @@ namespace minkowski { +template +__global__ void +unique_row2num_nonzero(const int n, Dtype *__restrict__ d_num_nonzero, + const Itype *__restrict__ 
unique_row_ptr, + const Dtype *__restrict__ reduced_val_ptr) { + CUDA_KERNEL_LOOP(index, n) { + d_num_nonzero[unique_row_ptr[index]] = reduced_val_ptr[index]; + } +} + cudaDataType getTensorCudaDataType(torch::Tensor const &self) { cudaDataType cuda_data_type; switch (self.scalar_type()) { @@ -54,10 +64,11 @@ cudaDataType getTensorCudaDataType(torch::Tensor const &self) { } template -torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, - torch::Tensor const &vals, int64_t const dim_i, - int64_t const dim_j, torch::Tensor const &mat2, - int64_t spmm_algorithm_id) { +std::pair +coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, + torch::Tensor const &vals, int64_t const dim_i, int64_t const dim_j, + torch::Tensor const &mat2, int64_t const spmm_algorithm_id, + bool const return_num_nonzero) { #if defined __HIP_PLATFORM_HCC__ TORCH_CHECK(false, "spmm sparse-dense is not supported on HIP"); #elif defined(_WIN32) || defined(_WIN64) @@ -145,9 +156,13 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, int64_t dim_k = mat2.size(1); torch::Tensor result = at::zeros({dim_k, dim_i}, mat2.options()); + torch::Tensor num_nonzero = at::zeros({0}, mat2.options()); - if ((dim_j == 0) || (dim_k == 0)) { - return result; + // Create tensors to view just the current set of matrices + int64_t const nnz = rows.numel(); + + if ((dim_j == 0) || (dim_k == 0) || (nnz == 0)) { + return std::make_pair(result, num_nonzero); } // Dense matrices have to be contiguous for cusparseSpMM to work @@ -157,15 +172,6 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, torch::Scalar beta = 0; torch::Scalar alpha = 1; - // Create tensors to view just the current set of matrices - int64_t const nnz = rows.numel(); - - if (nnz == 0) { - result.transpose_(0, 1); - result.zero_(); - return result; - } - cudaDataType cuda_data_type = getTensorCudaDataType(mat2_contig); th_int_type *row_indices_ptr = 
reinterpret_cast(rows.data_ptr()); @@ -267,6 +273,43 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, CUSPARSE_CHECK(cusparseDestroyDnMat(dense_descr)); CUSPARSE_CHECK(cusparseDestroyDnMat(result_descr)); + // Num nonzero + if (return_num_nonzero) { + th_int_type *unique_row_ptr = + (th_int_type *)c10::cuda::CUDACachingAllocator::raw_alloc( + nnz * sizeof(th_int_type)); + scalar_t *reduced_val_ptr = + (scalar_t *)c10::cuda::CUDACachingAllocator::raw_alloc( + nnz * sizeof(scalar_t)); + torch::Tensor ones = at::zeros({nnz}, mat2.options()); + + num_nonzero.resize_({dim_i, 1}); + num_nonzero.zero_(); + + // reduce by key + auto end = thrust::reduce_by_key( + thrust::device, // policy + sorted_row_ptr, // key begin + sorted_row_ptr + nnz, // key end + reinterpret_cast(ones.data_ptr()), // value begin + unique_row_ptr, // key out begin + reduced_val_ptr // value out begin + ); + + int num_unique_keys = end.first - unique_row_ptr; + LOG_DEBUG("Num unique keys:", num_unique_keys); + + // Copy the results to the correct output + unique_row2num_nonzero + <<>>( + num_unique_keys, + reinterpret_cast(num_nonzero.data_ptr()), + unique_row_ptr, reduced_val_ptr); + + c10::cuda::CUDACachingAllocator::raw_delete((void *)unique_row_ptr); + c10::cuda::CUDACachingAllocator::raw_delete((void *)reduced_val_ptr); + } + LOG_DEBUG("Dealloc"); c10::cuda::CUDACachingAllocator::raw_delete((void *)sorted_row_ptr); c10::cuda::CUDACachingAllocator::raw_delete((void *)sorted_val_ptr); @@ -282,14 +325,15 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, CUDA_CHECK(cudaGetLastError()); - return result; + return std::make_pair(result, num_nonzero); } -template torch::Tensor +template std::pair coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, torch::Tensor const &vals, int64_t const dim_i, int64_t const dim_j, torch::Tensor const &mat2, - int64_t spmm_algorithm_id); + int64_t const spmm_algorithm_id, + bool const 
return_num_nonzero); // template torch::Tensor // coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, diff --git a/src/spmm.cuh b/src/spmm.cuh index f9447e8d..2b5078aa 100644 --- a/src/spmm.cuh +++ b/src/spmm.cuh @@ -38,10 +38,11 @@ namespace minkowski { template -torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, - torch::Tensor const &vals, int64_t const dim_i, - int64_t const dim_j, torch::Tensor const &mat2, - int64_t spmm_algorithm_id); +std::pair +coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols, + torch::Tensor const &vals, int64_t const dim_i, int64_t const dim_j, + torch::Tensor const &mat2, int64_t const spmm_algorithm_id, + bool const return_num_nonzero); } #endif diff --git a/tests/python/spmm.py b/tests/python/spmm.py index 312c02b0..56f1f2d9 100644 --- a/tests/python/spmm.py +++ b/tests/python/spmm.py @@ -30,7 +30,6 @@ class TestSPMM(unittest.TestCase): - def test(self): rows = torch.Tensor([0, 0, 1, 1]).int() cols = torch.Tensor([0, 1, 2, 3]).int() @@ -75,9 +74,9 @@ def test_dtype(self): if not torch.cuda.is_available(): return - rows = torch.Tensor([0, 0, 1, 1]).float().to(0) - cols = torch.Tensor([0, 1, 2, 3]).double().to(0) - vals = torch.ones(4).float().to(0) + rows = torch.cuda.IntTensor([0, 0, 1, 1]) + cols = torch.cuda.IntTensor([0, 1, 2, 3]) + vals = torch.ones(4).double().to(0) size = [2, 4] mat = mat.to(0) mat.requires_grad_() diff --git a/tests/python/tensor_field.py b/tests/python/tensor_field.py index 66c87c17..429ae758 100644 --- a/tests/python/tensor_field.py +++ b/tests/python/tensor_field.py @@ -179,5 +179,6 @@ def field_to_sparse(self): ) otensor = network(tfield) + otensor.F.sum().backward() field_to_sparse = tfield.sparse(coordinate_map_key=otensor.coordinate_map_key) self.assertTrue(len(field_to_sparse.F) == len(otensor))