Commit 7ccf01b

avg pooling for CUDA 10 (fix #246)

chrischoy committed Feb 4, 2021
1 parent ae13226 commit 7ccf01b
Showing 14 changed files with 234 additions and 111 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@
 - TensorField to sparse with coordinate map key
 - Sparse matrix multiplication
 - force contiguous matrix
+- Fix AveragePooling cudaErrorMisalignedAddress error for CUDA 10 (#246)

 ## [0.5.0] - 2020-12-24
2 changes: 1 addition & 1 deletion MinkowskiEngine/MinkowskiNormalization.py
@@ -199,7 +199,7 @@ def forward(
         in_coords_key: CoordinateMapKey,
         glob_coords_key: CoordinateMapKey = None,
         coords_manager: CoordinateManager = None,
-        gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_PYTORCH_INDEX,
+        gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_KERNEL,
     ):
         if glob_coords_key is None:
             glob_coords_key = CoordinateMapKey(in_coords_key.get_coordinate_size())
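The default global pooling mode moves from the PyTorch-index implementation to the fused kernel, which avoids the cudaErrorMisalignedAddress failure on CUDA 10. A minimal sketch of selecting the mode explicitly; the module and enum names follow MinkowskiEngine 0.5, but treat the exact constructor signature as an assumption:

import torch
import MinkowskiEngine as ME
from MinkowskiEngine import PoolingMode

# A tiny sparse tensor: 4 points in 2 batches (the first coordinate
# column is the batch index), 3-channel features.
coords = torch.IntTensor([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]])
feats = torch.rand(4, 3)
x = ME.SparseTensor(feats, coordinates=coords)

# Assumption: MinkowskiGlobalPooling accepts a PoolingMode argument; the
# kernel variant is the new default that sidesteps the CUDA 10 bug.
pool = ME.MinkowskiGlobalPooling(mode=PoolingMode.GLOBAL_AVG_POOLING_KERNEL)
y = pool(x)  # one averaged feature row per batch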
15 changes: 8 additions & 7 deletions MinkowskiEngine/MinkowskiSparseTensor.py
@@ -314,13 +314,14 @@ def initialize_coordinates(self, coordinates, features, coordinate_map_key):
             self.quantization_mode
             == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
         ):
-            nums = spmm.apply(
-                self.inverse_mapping,
-                cols,
-                vals,
-                size,
-                vals.reshape(N, 1),
-            )
+            with torch.no_grad():
+                nums = spmm.apply(
+                    self.inverse_mapping,
+                    cols,
+                    vals,
+                    size,
+                    vals.reshape(N, 1),
+                )
             features /= nums
         elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
             features = features[self.unique_index]
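The new torch.no_grad() guard keeps the count computation out of the autograd graph: the per-voxel point counts are constants, so tracking them would only waste memory and backward time. A self-contained sketch of the same averaging pattern, with hypothetical names rather than the library's internal API:

import torch

def average_duplicates(features, inverse_mapping, num_voxels):
    # inverse_mapping[i] is the voxel row that input point i maps to.
    summed = torch.zeros(num_voxels, features.size(1))
    summed.index_add_(0, inverse_mapping, features)
    with torch.no_grad():  # counts are constants; keep them off the graph
        counts = torch.zeros(num_voxels, 1)
        counts.index_add_(0, inverse_mapping, torch.ones(features.size(0), 1))
    return summed / counts.clamp(min=1)

feats = torch.rand(5, 3, requires_grad=True)
inv = torch.tensor([0, 0, 1, 2, 2])  # points 0,1 share voxel 0; 3,4 share voxel 2
voxel_feats = average_duplicates(feats, inv, num_voxels=3)
voxel_feats.sum().backward()  # gradients flow to feats, not to the counts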
15 changes: 8 additions & 7 deletions MinkowskiEngine/MinkowskiTensorField.py
@@ -298,13 +298,14 @@ def sparse(
             self.quantization_mode
             == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
         ):
-            nums = spmm.apply(
-                inverse_mapping,
-                cols,
-                vals,
-                size,
-                vals.reshape(N, 1),
-            )
+            with torch.no_grad():
+                nums = spmm.apply(
+                    inverse_mapping,
+                    cols,
+                    vals,
+                    size,
+                    vals.reshape(N, 1),
+                )
             features /= nums
         elif self.quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
             features = self._F[unique_index]
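This is the same guard, applied on the TensorField-to-SparseTensor path. For orientation, a short sketch of the call that reaches this branch; the constructor kwargs follow ME 0.5 and should be treated as an assumption:

import torch
import MinkowskiEngine as ME

# Continuous coordinates; the first column is the batch index. The first
# two points fall into the same voxel under default quantization.
coords = torch.FloatTensor([[0, 0.1, 0.4], [0, 0.2, 0.3], [0, 2.1, 3.2]])
feats = torch.rand(3, 4)
tfield = ME.TensorField(
    features=feats,
    coordinates=coords,
    quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
)
stensor = tfield.sparse()  # duplicates within a voxel are averaged here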
48 changes: 38 additions & 10 deletions MinkowskiEngine/sparse_matrix_functions.py
@@ -33,6 +33,7 @@ def spmm(
     vals: torch.Tensor,
     size: torch.Size,
     mat: torch.Tensor,
+    return_num_nonzero: bool = False,
     cuda_spmm_alg: int = 1,
 ):

@@ -41,11 +42,13 @@ def spmm(
     assert vals.dtype == mat.dtype, "dtype mismatch"
     assert vals.device == mat.device, "device mismatch"
     if mat.is_cuda:
-        assert rows.is_cuda and cols.is_cuda and vals.is_cuda
+        assert (
+            rows.is_cuda and cols.is_cuda and vals.is_cuda
+        ), "All inputs must be on cuda"
         rows = rows.int()
         cols = cols.int()
-        return MEB.coo_spmm_int32(
-            rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
+        result, num_nonzero = MEB.coo_spmm_int32(
+            rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg, return_num_nonzero
         )

         # WARNING: TODO: not sorting the vals. Should not be used for generic SPMM
@@ -54,7 +57,10 @@ def spmm(
         #     rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
         # )
     else:
-        COO = torch.stack((rows, cols), 0,).long()
+        COO = torch.stack(
+            (rows, cols),
+            0,
+        ).long()
         torchSparseTensor = None
         if vals.dtype == torch.float64:
             torchSparseTensor = torch.sparse.DoubleTensor
@@ -64,7 +70,14 @@ def spmm(
             raise ValueError(f"Unsupported data type: {vals.dtype}")

         sp = torchSparseTensor(COO, vals, size)
-        return sp.matmul(mat)
+        result = sp.matmul(mat)
+        if return_num_nonzero:
+            num_nonzero = sp.matmul(torch.ones((size[1], 1), dtype=vals.dtype))
+
+    if return_num_nonzero:
+        return result, num_nonzero
+    else:
+        return result


 class MinkowskiSPMMFunction(Function):
@@ -78,19 +91,34 @@ def forward(
         mat: torch.Tensor,
         cuda_spmm_alg: int = 1,
     ):
-        ctx.save_for_backward(rows, cols, vals)
         ctx.misc_args = size, cuda_spmm_alg
+        ctx.save_for_backward(rows, cols, vals)
         mat = mat.contiguous()
-        out = spmm(rows, cols, vals, size, mat, cuda_spmm_alg)
-        return out
+        return spmm(
+            rows,
+            cols,
+            vals,
+            size,
+            mat,
+            return_num_nonzero=False,
+            cuda_spmm_alg=cuda_spmm_alg,
+        )

     @staticmethod
     def backward(ctx, grad: torch.Tensor):
-        rows, cols, vals = ctx.saved_tensors
         size, cuda_spmm_alg = ctx.misc_args
+        rows, cols, vals = ctx.saved_tensors
         new_size = torch.Size([size[1], size[0]])
         grad = grad.contiguous()
-        grad = spmm(cols, rows, vals, new_size, grad, cuda_spmm_alg)
+        grad = spmm(
+            cols,
+            rows,
+            vals,
+            new_size,
+            grad,
+            return_num_nonzero=False,
+            cuda_spmm_alg=cuda_spmm_alg,
+        )
         return (
             None,
             None,
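On the CUDA path, MEB.coo_spmm_int32 now returns both the product and the per-row nonzero counts in a single call, which is what lets average pooling divide sums by counts without the separate kernel launch that triggered the misaligned access on CUDA 10. A CPU-only sketch of the same contract, using the modern torch.sparse_coo_tensor constructor in place of the legacy torch.sparse.FloatTensor above, with a hypothetical helper name:

import torch

def coo_spmm_with_counts(rows, cols, vals, size, mat):
    # Sparse COO (size[0] x size[1]) times dense mat. The count column is
    # S @ ones, i.e. each row's value sum, which equals the row's number
    # of nonzeros when vals are all ones, as in pooling maps.
    coo = torch.stack((rows, cols), 0).long()
    sp = torch.sparse_coo_tensor(coo, vals, size)
    result = sp.matmul(mat)
    num_nonzero = sp.matmul(torch.ones((size[1], 1), dtype=vals.dtype))
    return result, num_nonzero

rows = torch.tensor([0, 0, 1])
cols = torch.tensor([0, 1, 2])
vals = torch.ones(3)
mat = torch.rand(3, 4)
summed, counts = coo_spmm_with_counts(rows, cols, vals, torch.Size([2, 3]), mat)
averaged = summed / counts.clamp(min=1)  # average pooling from sums and counts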
5 changes: 4 additions & 1 deletion examples/multigpu_lightning.py
@@ -196,6 +196,9 @@ def configure_optimizers(self):
     print(f"Testing {num_devices} GPUs.")

     # Training
-    pl_module = MinkowskiSegmentationModule(DummyNetwork(3, 20, D=3), lr=args.lr)
+    model = DummyNetwork(3, 20, D=3)
+    if args.ngpus > 1:
+        model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)
+    pl_module = MinkowskiSegmentationModule(model, lr=args.lr)
     trainer = Trainer(max_epochs=args.max_epochs, gpus=num_devices, accelerator="ddp")
     trainer.fit(pl_module)
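The example now converts BatchNorm layers to their synchronized variant before multi-GPU training, so normalization statistics are computed across all replicas rather than per device. A minimal sketch of the conversion call; the network definition is illustrative only:

import torch.nn as nn
import MinkowskiEngine as ME

net = nn.Sequential(
    ME.MinkowskiConvolution(3, 8, kernel_size=3, dimension=3),
    ME.MinkowskiBatchNorm(8),
    ME.MinkowskiReLU(),
)
# Swap every MinkowskiBatchNorm for MinkowskiSyncBatchNorm in the module tree.
net = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(net)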
12 changes: 6 additions & 6 deletions pybind/extern.hpp
@@ -491,10 +491,11 @@ at::Tensor quantization_average_features(at::Tensor in_feat, at::Tensor in_map,

 #ifndef CPU_ONLY
 template <typename th_int_type>
-torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
-                       torch::Tensor const &vals, int64_t const dim_i,
-                       int64_t const dim_j, torch::Tensor const &mat2,
-                       int64_t spmm_algorithm_id);
+std::pair<torch::Tensor, torch::Tensor>
+coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
+         torch::Tensor const &vals, int64_t const dim_i, int64_t const dim_j,
+         torch::Tensor const &mat2, int64_t const spmm_algorithm_id,
+         bool const return_num_nonzero);

 std::pair<size_t, size_t> get_memory_info();
 #endif

@@ -757,8 +758,7 @@ void instantiate_manager(py::module &m, const std::string &dtypestr) {
                          &manager_type::to_string, py::const_))
       .def("insert_and_map", &manager_type::insert_and_map)
       .def("insert_field", &manager_type::insert_field)
-      .def("field_to_sparse_map",
-           &manager_type::field_to_sparse_map)
+      .def("field_to_sparse_map", &manager_type::field_to_sparse_map)
       .def("field_to_sparse_insert_and_map",
            &manager_type::field_to_sparse_insert_and_map)
       .def("exists_field_to_sparse",
3 changes: 0 additions & 3 deletions src/coordinate_map_manager.cpp
@@ -325,9 +325,6 @@ CoordinateMapManager<coordinate_type, coordinate_field_type, TemplatedAllocator,
 #endif
 }

-  LOG_DEBUG("initializing a field map with tensor stride:", map_key.first,
-            "string id:", map_key.second);
-
   auto const map_inverse_map =
       sparse_map.field_map(field_map.const_coordinate_data(), field_map.size());
12 changes: 7 additions & 5 deletions src/interpolation_gpu.cu
@@ -74,10 +74,10 @@ std::vector<at::Tensor> InterpolationForwardGPU(
   auto const &out_maps = map_weight[1];
   auto const &weights = map_weight[2];

-  auto out_feat = coo_spmm<int>(out_maps, in_maps, weights, tfield.size(0),
-                                in_feat.size(0), in_feat, 1);
+  auto out_feat_pair = coo_spmm<int>(out_maps, in_maps, weights, tfield.size(0),
+                                     in_feat.size(0), in_feat, 1, false);
   // to out_feats
-  map_weight.insert(map_weight.begin(), out_feat);
+  map_weight.insert(map_weight.begin(), out_feat_pair.first);
   return map_weight;
 }

@@ -102,8 +102,10 @@ at::Tensor InterpolationBackwardGPU(
   uint32_t const in_nrows = p_map_manager->size(in_key);

   LOG_DEBUG("InterpolationBackwardKernelGPU");
-  return coo_spmm<int>(in_maps, out_maps, weights, in_nrows,
-                       grad_out_feat.size(0), grad_out_feat, 1);
+  auto out_feat_pair =
+      coo_spmm<int>(in_maps, out_maps, weights, in_nrows, grad_out_feat.size(0),
+                    grad_out_feat, 1, false);
+  return out_feat_pair.first;
 }

 // Forward
0 comments on commit 7ccf01b