Skip to content

Commit

Permalink
fix transpose, MinkowskiOps
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent ace1cbc commit 19a3b05
Show file tree
Hide file tree
Showing 20 changed files with 382 additions and 123 deletions.
4 changes: 2 additions & 2 deletions MinkowskiEngine/MinkowskiConvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,11 @@ def forward(
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

if self.use_mm and coords is None:
if self.use_mm and coordinates is None:
# If the kernel_size == 1, the convolution is simply a matrix
# multiplication
outfeat = input.F.mm(self.kernel)
out_coordinate_map_key = input.coords_key
out_coordinate_map_key = input.coordinate_map_key
else:
# Get a new coordinate_map_key or extract one from the coords
out_coordinate_map_key = _get_coordinate_map_key(input, coordinates)
Expand Down
42 changes: 28 additions & 14 deletions MinkowskiEngine/MinkowskiOps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
# of the code.
import torch
from torch.nn.modules import Module
from SparseTensor import SparseTensor, COORDS_MAN_DIFFERENT_ERROR, COORDS_KEY_DIFFERENT_ERROR
from MinkowskiSparseTensor import (
SparseTensor,
COORDINATE_MANAGER_DIFFERENT_ERROR,
COORDINATE_KEY_DIFFERENT_ERROR,
)


class MinkowskiLinear(Module):

def __init__(self, in_features, out_features, bias=True):
super(MinkowskiLinear, self).__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
Expand All @@ -36,13 +39,16 @@ def forward(self, input):
output = self.linear(input.F)
return SparseTensor(
output,
coords_key=input.coords_key,
coords_manager=input.coords_man)
coordinate_map_key=input.coordinate_map_key,
coordinate_manager=input.coordinate_manager,
)

def __repr__(self):
s = '(in_features={}, out_features={}, bias={})'.format(
self.linear.in_features, self.linear.out_features,
self.linear.bias is not None)
s = "(in_features={}, out_features={}, bias={})".format(
self.linear.in_features,
self.linear.out_features,
self.linear.bias is not None,
)
return self.__class__.__name__ + s


Expand All @@ -58,22 +64,30 @@ def cat(*sparse_tensors):
>>> import MinkowskiEngine as ME
>>> sin = ME.SparseTensor(feats, coords)
>>> sin2 = ME.SparseTensor(feats2, coords_key=sin.coords_key, coords_man=sin.coords_man)
>>> sin2 = ME.SparseTensor(feats2, coordinate_map_key=sin.coordinate_map_key, coords_man=sin.coordinate_manager)
>>> sout = UNet(sin) # Returns an output sparse tensor on the same coordinates
>>> sout2 = ME.cat(sin, sin2, sout) # Can concatenate multiple sparse tensors
"""
for s in sparse_tensors:
assert isinstance(s, SparseTensor), "Inputs must be sparse tensors."
coords_man = sparse_tensors[0].coords_man
coords_key = sparse_tensors[0].coords_key
coordinate_manager = sparse_tensors[0].coordinate_manager
coordinate_map_key = sparse_tensors[0].coordinate_map_key
for s in sparse_tensors:
assert coords_man == s.coords_man, COORDS_MAN_DIFFERENT_ERROR
assert coords_key == s.coords_key, COORDS_KEY_DIFFERENT_ERROR
assert (
coordinate_manager == s.coordinate_manager
), COORDINATE_MANAGER_DIFFERENT_ERROR
assert coordinate_map_key == s.coordinate_map_key, (
COORDINATE_KEY_DIFFERENT_ERROR
+ str(coordinate_map_key)
+ " != "
+ str(s.coordinate_map_key)
)
tens = []
for s in sparse_tensors:
tens.append(s.F)
return SparseTensor(
torch.cat(tens, dim=1),
coords_key=sparse_tensors[0].coords_key,
coords_manager=coords_man)
coordinate_map_key=coordinate_map_key,
coordinate_manager=coordinate_manager,
)
10 changes: 5 additions & 5 deletions MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@
# from MinkowskiUnion import MinkowskiUnion, MinkowskiUnionFunction
#
# from MinkowskiNetwork import MinkowskiNetwork
#
# import MinkowskiOps
#
# from MinkowskiOps import MinkowskiLinear, cat
#

import MinkowskiOps

from MinkowskiOps import MinkowskiLinear, cat

# import MinkowskiFunctional
#
import MinkowskiEngine.utils as utils
Expand Down
57 changes: 57 additions & 0 deletions MinkowskiEngine/sparse_matrix_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2020 NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import torch

import MinkowskiEngineBackend._C as MEB


def spmm(
rows: torch.Tensor,
cols: torch.Tensor,
vals: torch.Tensor,
size: torch.Size,
mat: torch.Tensor,
cuda_spmm_alg: int = 1,
):
if mat.is_cuda:
assert rows.is_cuda and cols.is_cuda and vals.is_cuda
if MEB.cuda_version() < 11000:
rows = rows.int()
cols = cols.int()
return MEB.coo_spmm_int32(
rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
)
else:
if rows.dtype == torch.int32:
return MEB.coo_spmm_int32(
rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
)
else:
return MEB.coo_spmm_int64(
rows, cols, vals, size[0], size[1], mat, cuda_spmm_alg
)
else:
COO = torch.stack((rows, cols), 0,)
sp = torch.sparse.Tensor(COO, vals, size)
return sp.matmul(mat)
1 change: 0 additions & 1 deletion examples/indoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def load_file(file_name):
# Measure time
with torch.no_grad():
voxel_size = 0.02

# Feed-forward pass and get the prediction
sinput = ME.SparseTensor(
features=torch.from_numpy(colors).float().to(device),
Expand Down
5 changes: 3 additions & 2 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include <torch/extension.h>

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -413,8 +414,8 @@ void initialize_non_templated_classes(py::module &m) {
.def("set_key", (void (minkowski::CoordinateMapKey::*)(
minkowski::default_types::stride_type, std::string)) &
minkowski::CoordinateMapKey::set_key)
.def("get_tensor_stride",
&minkowski::CoordinateMapKey::get_tensor_stride);
.def("get_tensor_stride", &minkowski::CoordinateMapKey::get_tensor_stride)
.def(py::self == py::self);
}

template <typename manager_type>
Expand Down
2 changes: 2 additions & 0 deletions src/convolution_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ at::Tensor ConvolutionForwardGPU(
LOG_DEBUG("Convolution on", out_nrows, "x", kernel.size(2));
AT_DISPATCH_FLOATING_TYPES(
in_feat.scalar_type(), "convolution_forward_gpu", [&] {
LOG_DEBUG("ConvolutionForwardKernelGPU with",
std::is_same<float, scalar_t>::value ? "float" : "double");
ConvolutionForwardKernelGPU<scalar_t, default_types::index_type,
TemplatedAllocator<char>>(
in_feat.template data_ptr<scalar_t>(), //
Expand Down
Loading

0 comments on commit 19a3b05

Please sign in to comment.