channelwise conv
chrischoy committed Jan 1, 2021
1 parent 032bb51 commit d5c6fe8
Showing 6 changed files with 182 additions and 113 deletions.
156 changes: 83 additions & 73 deletions MinkowskiEngine/MinkowskiChannelwiseConvolution.py
@@ -1,4 +1,5 @@
# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
# Copyright (c) 2020 NVIDIA CORPORATION.
# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
@@ -27,13 +28,24 @@
import torch
from torch.nn import Parameter

from SparseTensor import SparseTensor
from Common import RegionType, MinkowskiModuleBase, KernelGenerator, \
prep_args, convert_to_int_list, convert_to_int_tensor
from MinkowskiCoords import CoordsKey
from MinkowskiSparseTensor import SparseTensor
from MinkowskiEngineBackend._C import CoordinateMapKey, RegionType
from MinkowskiCommon import MinkowskiModuleBase
from MinkowskiKernelGenerator import KernelGenerator


class MinkowskiChannelwiseConvolution(MinkowskiModuleBase):

__slots__ = (
"in_channels",
"out_channels",
"kernel_generator",
"dimension",
"kernel",
"bias",
"conv",
)

r"""Channelwise (Depthwise) Convolution layer for a sparse tensor.
@@ -57,14 +69,16 @@ class MinkowskiChannelwiseConvolution(MinkowskiModuleBase):
"""

def __init__(self,
in_channels,
kernel_size=-1,
stride=1,
dilation=1,
has_bias=False,
kernel_generator=None,
dimension=-1):
def __init__(
self,
in_channels,
kernel_size=-1,
stride=1,
dilation=1,
bias=False,
kernel_generator=None,
dimension=-1,
):
r"""convolution on a sparse tensor
Args:
@@ -87,7 +101,7 @@ def __init__(self,
convolution kernel. When a list is given, the length must be D and
each element is an axis specific dilation. All elements must be > 0.
:attr:`has_bias` (bool, optional): if True, the convolution layer
:attr:`bias` (bool, optional): if True, the convolution layer
has a bias.
:attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`,
@@ -107,97 +121,93 @@ def __init__(self,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
dimension=dimension)
else:
kernel_size = kernel_generator.kernel_size

stride = convert_to_int_tensor(stride, dimension)
kernel_size = convert_to_int_tensor(kernel_size, dimension)
dilation = convert_to_int_tensor(dilation, dimension)
dimension=dimension,
)

kernel_volume = kernel_generator.kernel_volume
self.kernel_generator = kernel_generator

self.in_channels = in_channels
self.kernel_size = kernel_size
self.kernel_volume = kernel_volume
self.stride = stride
self.dilation = dilation
self.kernel_generator = kernel_generator
self.dimension = dimension
self.use_mm = False # use matrix multiplication when kernel is 1

Tensor = torch.FloatTensor
self.kernel_shape = (self.kernel_volume, self.in_channels)
self.kernel_shape = (kernel_generator.kernel_volume, self.in_channels)

Tensor = torch.FloatTensor
self.kernel = Parameter(Tensor(*self.kernel_shape))
self.bias = Parameter(Tensor(1, in_channels)) if has_bias else None
self.has_bias = has_bias
self.bias = Parameter(Tensor(1, in_channels)) if bias else None

self.reset_parameters()

def forward(self,
input: SparseTensor,
coords: Union[torch.IntTensor, CoordsKey, SparseTensor] = None):
def forward(
self,
input: SparseTensor,
coords: Union[torch.IntTensor, CoordinateMapKey, SparseTensor] = None,
):
r"""
:attr:`input` (`MinkowskiEngine.SparseTensor`): Input sparse tensor to apply a
convolution on.
:attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`,
:attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordinateMapKey`,
`MinkowskiEngine.SparseTensor`), optional): If provided, generate
results on the provided coordinates. None by default.
"""
assert isinstance(input, SparseTensor)
assert input.D == self.dimension
assert (
self.in_channels == input.shape[1]
), f"Channel size mismatch {self.in_channels} != {input.shape[1]}"

# Create a region_offset
self.region_type_, self.region_offset_, _ = \
self.kernel_generator.get_kernel(input.tensor_stride, False)
region_type_, region_offset_, _ = self.kernel_generator.get_kernel(
input.tensor_stride, False
)

cm = input.coords_man
in_key = input.coords_key
on_gpu = input.device.type != 'cpu'
cm = input.coordinate_manager
in_key = input.coordinate_map_key

out_key = cm.stride(in_key, self.stride)
N_out = cm.get_coords_size_by_coords_key(out_key)
out_key = cm.stride(in_key, self.kernel_generator.kernel_stride)
N_out = cm.size(out_key)
out_F = input._F.new(N_out, self.in_channels).zero_()

in_maps, out_maps = cm.get_kernel_map(
kernel_map = cm.get_kernel_map(
in_key,
out_key,
self.stride,
self.kernel_size,
self.dilation,
self.region_type_,
self.region_offset_,
is_transpose=False,
is_pool=False,
on_gpu=on_gpu)

for k in range(self.kernel_volume):
out_F[out_maps[k]] += input.F[in_maps[k]] * self.kernel[k]

if self.has_bias:
self.kernel_generator.kernel_stride,
self.kernel_generator.kernel_size,
self.kernel_generator.kernel_dilation,
region_type=region_type_,
region_offset=region_offset_,
)

for k, in_out in kernel_map.items():
in_out = in_out.long().to(input.device)
out_F[in_out[1]] += input.F[in_out[0]] * self.kernel[k]

if self.bias is not None:
out_F += self.bias

return SparseTensor(out_F, coords_key=out_key, coords_manager=cm)
return SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=cm)

def reset_parameters(self, is_transpose=False):
n = (self.out_channels
if is_transpose else self.in_channels) * self.kernel_volume
stdv = 1. / math.sqrt(n)
self.kernel.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
with torch.no_grad():
n = (
self.out_channels if is_transpose else self.in_channels
) * self.kernel_generator.kernel_volume
stdv = 1.0 / math.sqrt(n)
self.kernel.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

def __repr__(self):
s = '(in={}, region_type={}, '.format(self.in_channels,
self.kernel_generator.region_type)
if self.kernel_generator.region_type in [
RegionType.HYBRID, RegionType.CUSTOM
]:
s += 'kernel_volume={}, '.format(self.kernel_volume)
s = "(in={}, region_type={}, ".format(
self.in_channels, self.kernel_generator.region_type
)
if self.kernel_generator.region_type in [RegionType.CUSTOM]:
s += "kernel_volume={}, ".format(self.kernel_generator.kernel_volume)
else:
s += 'kernel_size={}, '.format(self.kernel_size.tolist())
s += 'stride={}, dilation={})'.format(self.stride.tolist(),
self.dilation.tolist())
s += "kernel_size={}, ".format(self.kernel_generator.kernel_size)
s += "stride={}, dilation={})".format(
self.kernel_generator.kernel_stride,
self.kernel_generator.kernel_dilation,
)
return self.__class__.__name__ + s
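
A minimal usage sketch (not part of the commit) of the updated layer: the constructor now takes `bias` instead of `has_bias` and reads kernel parameters from the kernel generator. The coordinate layout with a leading batch-index column follows the usual MinkowskiEngine convention; the tensor values and shapes below are illustrative assumptions, and exact argument names may differ between versions.

# Usage sketch, assuming the MinkowskiEngine v0.5-style API this commit migrates to.
import torch
import MinkowskiEngine as ME

# Four points in 3D; the first column is the batch index.
coordinates = torch.IntTensor(
    [[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 1, 1]]
)
features = torch.rand(4, 8)  # 8 input channels

x = ME.SparseTensor(features, coordinates=coordinates)

conv = ME.MinkowskiChannelwiseConvolution(
    in_channels=8,   # depthwise: the output also has 8 channels
    kernel_size=3,
    stride=1,
    dilation=1,
    bias=True,       # renamed from `has_bias` in this commit
    dimension=3,
)

y = conv(x)
print(y.F.shape)  # (N_out, 8); each channel is convolved with its own kernel
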
8 changes: 4 additions & 4 deletions MinkowskiEngine/MinkowskiConvolution.py
@@ -36,7 +36,7 @@
get_minkowski_function,
)
from MinkowskiCoordinateManager import CoordinateManager
from MinkowskiKernelGenerator import KernelGenerator, save_ctx
from MinkowskiKernelGenerator import KernelGenerator


class MinkowskiConvolutionFunction(Function):
@@ -413,7 +413,7 @@ def __init__(
convolution kernel. When a list is given, the length must be D and
each element is an axis specific dilation. All elements must be > 0.
:attr:`has_bias` (bool, optional): if True, the convolution layer
:attr:`bias` (bool, optional): if True, the convolution layer
has a bias.
:attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`,
@@ -487,7 +487,7 @@ def __init__(
convolution kernel. When a list is given, the length must be D and
each element is an axis specific dilation. All elements must be > 0.
:attr:`has_bias` (bool, optional): if True, the convolution layer
:attr:`bias` (bool, optional): if True, the convolution layer
has a bias.
:attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`,
@@ -582,7 +582,7 @@ def __init__(
convolution kernel. When a list is given, the length must be D and
each element is an axis specific dilation. All elements must be > 0.
:attr:`has_bias` (bool, optional): if True, the convolution layer
:attr:`bias` (bool, optional): if True, the convolution layer
has a bias.
:attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`,
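
For comparison with the channelwise layer above, a short sketch (not part of the diff) of the standard convolution whose docstrings this file updates to the renamed `bias` keyword; the argument list is an assumption based on the same API version.

import MinkowskiEngine as ME

conv = ME.MinkowskiConvolution(
    in_channels=8,
    out_channels=16,  # unlike the channelwise layer, channels can be mixed
    kernel_size=3,
    stride=1,
    dilation=1,
    bias=True,        # documented here as `bias`, previously `has_bias`
    dimension=3,
)
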
18 changes: 10 additions & 8 deletions MinkowskiEngine/MinkowskiCoordinateManager.py
@@ -210,10 +210,10 @@ def stride(
stride = convert_to_int_list(stride, self.D)
return self._manager.stride(coordinate_map_key, stride, string_id)

def origin(self):
def origin(self) -> CoordinateMapKey:
return self._manager.origin()

def size(self, coordinate_map_key: CoordinateMapKey):
def size(self, coordinate_map_key: CoordinateMapKey) -> int:
return self._manager.size(coordinate_map_key)

# def transposed_stride(
@@ -249,7 +249,7 @@ def size(self, coordinate_map_key: CoordinateMapKey):
# )
# return strided_key

def _get_coordinate_map_key(self, key_or_tensor_strides):
def _get_coordinate_map_key(self, key_or_tensor_strides) -> CoordinateMapKey:
r"""Helper function that retrieves the first coordinate map key for the given tensor stride."""
assert isinstance(key_or_tensor_strides, CoordinateMapKey) or isinstance(
key_or_tensor_strides, (Sequence, np.ndarray, torch.IntTensor, int)
@@ -263,18 +263,20 @@ def _get_coordinate_map_key(self, key_or_tensor_strides):
assert len(keys) > 0
return keys[0]

def get_coordinates(self, coords_key_or_tensor_strides):
def get_coordinates(self, coords_key_or_tensor_strides) -> torch.Tensor:
key = self._get_coordinate_map_key(coords_key_or_tensor_strides)
return self._manager.get_coordinates(key)

def get_coordinate_field(self, coords_key_or_tensor_strides):
def get_coordinate_field(self, coords_key_or_tensor_strides) -> torch.Tensor:
key = self._get_coordinate_map_key(coords_key_or_tensor_strides)
return self._manager.get_coordinate_field(key)

def number_of_unique_batch_indices(self):
def number_of_unique_batch_indices(self) -> int:
return self._manager.origin_map_size()

def get_unique_coordinate_map_key(self, tensor_stride: Union[int, list]):
def get_unique_coordinate_map_key(
self, tensor_stride: Union[int, list]
) -> CoordinateMapKey:
"""
Returns a unique coordinate_map_key for a given tensor stride.
Expand All @@ -292,7 +294,7 @@ def get_kernel_map(
region_offset=None,
is_transpose=False,
is_pool=False,
):
) -> dict:
r"""Get kernel in-out maps for the specified coords keys or tensor strides.
returns dict{kernel_index: in_out_tensor} where in_out_tensor[0] is the input row indices that correspond to in_out_tensor[1], which is the row indices for output.
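
The new channelwise forward pass consumes the dictionary returned by get_kernel_map, documented above as dict{kernel_index: in_out_tensor} with input row indices in in_out_tensor[0] and matching output row indices in in_out_tensor[1]. The stand-alone sketch below reproduces that gather-scatter loop with made-up indices to make the layout concrete; the sizes and index values are illustrative assumptions, not output from the library.

# Stand-alone sketch of the kernel-map gather-scatter used by the channelwise forward.
import torch

N_in, N_out, C = 6, 4, 8
in_F = torch.rand(N_in, C)        # input features
kernel = torch.rand(3, C)         # one C-dim weight vector per kernel offset
kernel_map = {                    # {kernel_index: in_out_tensor}; values are made up
    0: torch.tensor([[0, 1, 2], [0, 1, 2]]),  # row 0: input rows, row 1: output rows
    1: torch.tensor([[2, 3], [1, 3]]),
    2: torch.tensor([[4, 5], [2, 3]]),
}

out_F = in_F.new_zeros(N_out, C)
for k, in_out in kernel_map.items():
    in_out = in_out.long()
    # gather input rows, scale them channel-wise by the k-th kernel, scatter-add to outputs
    out_F[in_out[1]] += in_F[in_out[0]] * kernel[k]

print(out_F.shape)  # (4, 8)
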
