From 8633896e5531ede729ab5f920c43309a4d59a28f Mon Sep 17 00:00:00 2001 From: Chris Choy Date: Wed, 29 Jul 2020 13:29:17 -0700 Subject: [PATCH] MinkowskiConvolution CPU/GPU --- MinkowskiEngine/Common.py | 369 --------- MinkowskiEngine/MinkowskiCommon.py | 148 ++++ MinkowskiEngine/MinkowskiConvolution.py | 543 +++++++------ ...oords.py => MinkowskiCoordinateManager.py} | 115 ++- MinkowskiEngine/MinkowskiKernelGenerator.py | 360 +++++++++ MinkowskiEngine/MinkowskiSparseTensor.py | 730 ++++++++++-------- MinkowskiEngine/__init__.py | 40 +- MinkowskiEngine/utils/coords.py | 2 +- docs/migration_05.md | 94 +++ pybind/extern.hpp | 80 +- pybind/minkowski.cpp | 33 +- pybind/minkowski.cu | 79 +- setup.py | 3 + src/convolution_cpu.cpp | 1 + src/convolution_gpu.cu | 1 + src/coordinate_map_cpu.hpp | 18 +- src/coordinate_map_gpu.cu | 12 +- src/coordinate_map_key.hpp | 4 +- src/coordinate_map_manager.cpp | 32 +- src/coordinate_map_manager.cu | 46 +- src/math_functions.cu | 22 +- src/quantization.cpp | 2 + tests/cpp/convolution_cpu_test.py | 52 +- tests/cpp/convolution_gpu_test.py | 61 +- tests/cpp/kernel_region_cpu_test.cpp | 17 +- tests/cpp/kernel_region_cpu_test.py | 24 +- tests/python/{conv.py => convolution.py} | 197 +++-- tests/python/coordinate_manager.py | 127 +++ tests/python/sparse_tensor.py | 50 +- 29 files changed, 1978 insertions(+), 1284 deletions(-) delete mode 100644 MinkowskiEngine/Common.py create mode 100644 MinkowskiEngine/MinkowskiCommon.py rename MinkowskiEngine/{MinkowskiCoords.py => MinkowskiCoordinateManager.py} (84%) create mode 100644 MinkowskiEngine/MinkowskiKernelGenerator.py create mode 100644 docs/migration_05.md rename tests/python/{conv.py => convolution.py} (51%) create mode 100644 tests/python/coordinate_manager.py diff --git a/MinkowskiEngine/Common.py b/MinkowskiEngine/Common.py deleted file mode 100644 index c8e7ddf6..00000000 --- a/MinkowskiEngine/Common.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright (c) Chris Choy 
(chrischoy@ai.stanford.edu). -# -# Permission is hereby granted, free of charge, to any person obtaining a copy of -# this software and associated documentation files (the "Software"), to deal in -# the Software without restriction, including without limitation the rights to -# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -# of the Software, and to permit persons to whom the Software is furnished to do -# so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# -# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural -# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part -# of the code. 
-import math -from collections import Sequence -import numpy as np -from enum import Enum -from itertools import product -from typing import Union - -import torch - -from torch.nn import Module - -import MinkowskiEngineBackend as MEB - - -class GlobalPoolingMode(Enum): - """ - Define the global pooling mode - """ - AUTO = 0, 'AUTO' - INDEX_SELECT = 1, 'INDEX_SELECT' - SPARSE = 2, 'SPARSE' - - def __new__(cls, value, name): - member = object.__new__(cls) - member._value_ = value - member.fullname = name - return member - - def __int__(self): - return self.value - - -class RegionType(Enum): - """ - Define the kernel region type - """ - HYPERCUBE = 0, 'HYPERCUBE' - HYPERCROSS = 1, 'HYPERCROSS' - CUSTOM = 2, 'CUSTOM' - HYBRID = 3, 'HYBRID' - - def __new__(cls, value, name): - member = object.__new__(cls) - member._value_ = value - member.fullname = name - return member - - def __int__(self): - return self.value - - -def convert_to_int_list(arg: Union[int, Sequence, np.ndarray, torch.Tensor], - dimension: int): - if isinstance(arg, list): - assert len(arg) == dimension - return arg - - if isinstance(arg, (Sequence, np.ndarray, torch.Tensor)): - tmp = [i for i in arg] - assert len(tmp) == dimension - elif np.isscalar(arg): # Assume that it is a scalar - tmp = [int(arg) for i in range(dimension)] - else: - raise ValueError('Input must be a scalar or a sequence') - - return tmp - - -def convert_to_int_tensor( - arg: Union[int, Sequence, np.ndarray, torch.IntTensor], dimension: int): - if isinstance(arg, torch.IntTensor): - assert arg.numel() == dimension - return arg - - if isinstance(arg, (Sequence, np.ndarray)): - tmp = torch.IntTensor([i for i in arg]) - assert tmp.numel() == dimension - elif np.isscalar(arg): # Assume that it is a scalar - tmp = torch.IntTensor([int(arg) for i in range(dimension)]) - else: - raise ValueError('Input must be a scalar or a sequence') - - return tmp - - -def prep_args(tensor_stride: Union[int, Sequence, np.ndarray, torch.IntTensor], - 
stride: Union[int, Sequence, np.ndarray, torch.IntTensor], - kernel_size: Union[int, Sequence, np.ndarray, torch.IntTensor], - dilation: Union[int, Sequence, np.ndarray, torch.IntTensor], - region_type: Union[int, RegionType], - D=-1): - assert torch.prod( - kernel_size > 0 - ), f"kernel_size must be a positive integer, provided {kernel_size}" - assert D > 0, f"dimension must be a positive integer, {D}" - tensor_stride = convert_to_int_tensor(tensor_stride, D) - stride = convert_to_int_tensor(stride, D) - kernel_size = convert_to_int_tensor(kernel_size, D) - dilation = convert_to_int_tensor(dilation, D) - region_type = int(region_type) - return tensor_stride, stride, kernel_size, dilation, region_type, - - -def get_postfix(tensor: torch.Tensor): - postfix = 'GPU' if tensor.is_cuda else 'CPU' - if isinstance(tensor, torch.DoubleTensor) or isinstance( - tensor, torch.cuda.DoubleTensor): - postfix += 'd' - else: - postfix += 'f' - return postfix - - -def get_kernel_volume(region_type, kernel_size, region_offset, axis_types, - dimension): - """ - when center is True, the custom region_offset will be centered at the - origin. Currently, for HYPERCUBE, HYPERCROSS with odd kernel sizes cannot - use center=False. - """ - if region_type == RegionType.HYPERCUBE: - assert region_offset is None, "Region offset must be None when region_type is given" - assert axis_types is None, "Axis types must be None when region_type is given" - # Typical convolution kernel - assert torch.prod(kernel_size > 0) == 1 - - # Convolution kernel with even numbered kernel size not defined. - kernel_volume = int(torch.prod(kernel_size)) - - elif region_type == RegionType.HYPERCROSS: - assert torch.prod(kernel_size > 0) == 1, "kernel_size must be positive" - assert ( - kernel_size % - 2).prod() == 1, "kernel_size must be odd for region_type HYPERCROSS" - # 0th: itself, (1, 2) for 0th dim neighbors, (3, 4) for 1th dim ... 
- kernel_volume = int(torch.sum(kernel_size - 1) + 1) - - elif region_type == RegionType.HYBRID: - assert region_offset is None, \ - "region_offset must be None when region_type is HYBRID" - kernel_size_list = kernel_size.tolist() - kernel_volume = 1 - # First HYPERCUBE - for axis_type, curr_kernel_size, d in \ - zip(axis_types, kernel_size_list, range(dimension)): - if axis_type == RegionType.HYPERCUBE: - kernel_volume *= curr_kernel_size - - # Second, HYPERCROSS - for axis_type, curr_kernel_size, d in \ - zip(axis_types, kernel_size_list, range(dimension)): - if axis_type == RegionType.HYPERCROSS: - kernel_volume += (curr_kernel_size - 1) - - elif region_type == RegionType.CUSTOM: - assert region_offset.numel( - ) > 0, "region_offset must be non empty when region_type is CUSTOM" - assert region_offset.size( - 1 - ) == dimension, "region_offset must have the same dimension as the network" - kernel_volume = int(region_offset.size(0)) - - else: - raise NotImplementedError() - - return kernel_volume - - -def convert_region_type( - region_type: RegionType, - tensor_stride: Union[Sequence, np.ndarray, torch.IntTensor], - kernel_size: Union[Sequence, np.ndarray, torch.IntTensor], - up_stride: Union[Sequence, np.ndarray, torch.IntTensor], - dilation: Union[Sequence, np.ndarray, torch.IntTensor], - region_offset: Union[Sequence, np.ndarray, torch.IntTensor], - axis_types: Union[Sequence, np.ndarray, torch.IntTensor], - dimension: int, - center: bool = True): - """ - when center is True, the custom region_offset will be centered at the - origin. Currently, for HYPERCUBE, HYPERCROSS with odd kernel sizes cannot - use center=False. 
- - up_stride: stride for conv_transpose, otherwise set it as 1 - """ - if region_type == RegionType.HYPERCUBE: - assert region_offset is None, "Region offset must be None when region_type is given" - assert axis_types is None, "Axis types must be None when region_type is given" - # Typical convolution kernel - assert torch.prod(kernel_size > 0) == 1 - # assert torch.unique(dilation).numel() == 1 - kernel_volume = int(torch.prod(kernel_size)) - - elif region_type == RegionType.HYPERCROSS: - assert torch.prod(kernel_size > 0) == 1, "kernel_size must be positive" - assert ( - kernel_size % - 2).prod() == 1, "kernel_size must be odd for region_type HYPERCROSS" - # 0th: itself, (1, 2) for 0th dim neighbors, (3, 4) for 1th dim ... - kernel_volume = int(torch.sum(kernel_size - 1) + 1) - - elif region_type == RegionType.HYBRID: - assert region_offset is None, \ - "region_offset must be None when region_type is HYBRID" - region_offset = [[ - 0, - ] * dimension] - kernel_size_list = kernel_size.tolist() - # First HYPERCUBE - for axis_type, curr_kernel_size, d in \ - zip(axis_types, kernel_size_list, range(dimension)): - new_offset = [] - if axis_type == RegionType.HYPERCUBE: - for offset in region_offset: - for curr_offset in range(curr_kernel_size): - off_center = int( - math.floor( - (curr_kernel_size - 1) / 2)) if center else 0 - offset = offset.copy() # Do not modify the original - # Exclude the coord (0, 0, ..., 0) - if curr_offset == off_center: - continue - offset[d] = (curr_offset - off_center) * \ - dilation[d] * (tensor_stride[d] / up_stride[d]) - new_offset.append(offset) - region_offset.extend(new_offset) - - # Second, HYPERCROSS - for axis_type, curr_kernel_size, d in \ - zip(axis_types, kernel_size_list, range(dimension)): - new_offset = [] - if axis_type == RegionType.HYPERCROSS: - for curr_offset in range(curr_kernel_size): - off_center = int(math.floor( - (curr_kernel_size - 1) / 2)) if center else 0 - offset = [ - 0, - ] * dimension - # Exclude the coord 
(0, 0, ..., 0) - if curr_offset == off_center: - continue - offset[d] = (curr_offset - off_center) * \ - dilation[d] * (tensor_stride[d] / up_stride[d]) - new_offset.append(offset) - region_offset.extend(new_offset) - - # Convert to CUSTOM type - region_type = RegionType.CUSTOM - region_offset = torch.IntTensor(region_offset) - kernel_volume = int(region_offset.size(0)) - - elif region_type == RegionType.CUSTOM: - assert region_offset.numel( - ) > 0, "region_offset must be non empty when region_type is CUSTOM" - assert region_offset.size( - 1 - ) == dimension, "region_offset must have the same dimension as the network" - kernel_volume = int(region_offset.size(0)) - assert isinstance( - region_offset.dtype, - torch.IntTensor), "region_offset must be a torch.IntTensor." - else: - raise NotImplementedError() - - if region_offset is None: - region_offset = torch.IntTensor() - - return region_type, region_offset, kernel_volume - - -class KernelGenerator: - - def __init__(self, - kernel_size=-1, - stride=1, - dilation=1, - is_transpose=False, - region_type=RegionType.HYPERCUBE, - region_offsets=None, - axis_types=None, - dimension=-1): - r""" - :attr:`region_type` (RegionType, optional): defines the kernel - shape. Please refer to MinkowskiEngine.Comon for details. - - :attr:`region_offset` (torch.IntTensor, optional): when the - :attr:`region_type` is :attr:`RegionType.CUSTOM`, the convolution - kernel uses the provided `region_offset` to define offsets. It - should be a matrix of size :math:`N \times D` where :math:`N` is - the number of offsets and :math:`D` is the dimension of the - space. - - :attr:`axis_types` (list of RegionType, optional): If given, it - uses different methods to create a kernel for each axis. e.g., when - it is `[RegionType.HYPERCUBE, RegionType.HYPERCUBE, - RegionType.HYPERCROSS]`, the kernel would be rectangular for the - first two dimensions and cross shaped for the thrid dimension. 
- """ - assert dimension > 0 - assert isinstance(region_type, RegionType) - - stride = convert_to_int_tensor(stride, dimension) - kernel_size = convert_to_int_tensor(kernel_size, dimension) - dilation = convert_to_int_tensor(dilation, dimension) - - self.cache = {} - self.kernel_size = kernel_size - self.stride = stride - self.dilation = dilation - self.region_type = region_type - self.region_offsets = region_offsets - self.axis_types = axis_types - self.dimension = dimension - self.kernel_volume = get_kernel_volume(region_type, kernel_size, - region_offsets, axis_types, - dimension) - - def get_kernel(self, tensor_stride, is_transpose): - assert len(tensor_stride) == self.dimension - if tuple(tensor_stride) not in self.cache: - up_stride = self.stride \ - if is_transpose else torch.Tensor([1, ] * self.dimension) - - self.cache[tuple(tensor_stride)] = convert_region_type( - self.region_type, tensor_stride, self.kernel_size, up_stride, - self.dilation, self.region_offsets, self.axis_types, - self.dimension) - - return self.cache[tuple(tensor_stride)] - - -class MinkowskiModuleBase(Module): - pass - - -def get_minkowski_function(name, variable): - fn_name = name + get_postfix(variable) - if hasattr(MEB, fn_name): - return getattr(MEB, fn_name) - else: - if variable.is_cuda: - raise ValueError( - f"Function {fn_name} not available. Please compile MinkowskiEngine where `torch.cuda.is_available()` is `True`." - ) - else: - raise ValueError(f"Function {fn_name} not available.") diff --git a/MinkowskiEngine/MinkowskiCommon.py b/MinkowskiEngine/MinkowskiCommon.py new file mode 100644 index 00000000..ce5f8fdf --- /dev/null +++ b/MinkowskiEngine/MinkowskiCommon.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. 
+import math +from collections import Sequence +import numpy as np +from enum import Enum +from itertools import product +from typing import Union + +import torch + +from torch.nn import Module + +import MinkowskiEngineBackend._C as MEB + + +StrideType = Union[int, Sequence, np.ndarray, torch.IntTensor] + + +class GlobalPoolingMode(Enum): + """ + Define the global pooling mode + """ + + AUTO = 0, "AUTO" + INDEX_SELECT = 1, "INDEX_SELECT" + SPARSE = 2, "SPARSE" + + def __new__(cls, value, name): + member = object.__new__(cls) + member._value_ = value + member.fullname = name + return member + + def __int__(self): + return self.value + + +def convert_to_int_list( + arg: Union[int, Sequence, np.ndarray, torch.Tensor], dimension: int +): + if isinstance(arg, list): + assert len(arg) == dimension + return arg + + if isinstance(arg, (Sequence, np.ndarray, torch.Tensor)): + tmp = [i for i in arg] + assert len(tmp) == dimension + elif np.isscalar(arg): # Assume that it is a scalar + tmp = [int(arg) for i in range(dimension)] + else: + raise ValueError("Input must be a scalar or a sequence") + + return tmp + + +def convert_to_int_tensor( + arg: Union[int, Sequence, np.ndarray, torch.IntTensor], dimension: int +): + if isinstance(arg, torch.IntTensor): + assert arg.numel() == dimension + return arg + + if isinstance(arg, (Sequence, np.ndarray)): + tmp = torch.IntTensor([i for i in arg]) + assert tmp.numel() == dimension + elif np.isscalar(arg): # Assume that it is a scalar + tmp = torch.IntTensor([int(arg) for i in range(dimension)]) + else: + raise ValueError("Input must be a scalar or a sequence") + + return tmp + + +def prep_args( + tensor_stride: Union[int, Sequence, np.ndarray, torch.IntTensor], + stride: Union[int, Sequence, np.ndarray, torch.IntTensor], + kernel_size: Union[int, Sequence, np.ndarray, torch.IntTensor], + dilation: Union[int, Sequence, np.ndarray, torch.IntTensor], + region_type: Union[int, MEB.RegionType], + D=-1, +): + assert torch.prod( + kernel_size 
> 0 + ), f"kernel_size must be a positive integer, provided {kernel_size}" + assert D > 0, f"dimension must be a positive integer, {D}" + tensor_stride = convert_to_int_tensor(tensor_stride, D) + stride = convert_to_int_tensor(stride, D) + kernel_size = convert_to_int_tensor(kernel_size, D) + dilation = convert_to_int_tensor(dilation, D) + region_type = int(region_type) + return ( + tensor_stride, + stride, + kernel_size, + dilation, + region_type, + ) + + +def get_postfix(tensor: torch.Tensor): + postfix = "GPU" if tensor.is_cuda else "CPU" + if isinstance(tensor, torch.DoubleTensor) or isinstance( + tensor, torch.cuda.DoubleTensor + ): + postfix += "d" + else: + postfix += "f" + return postfix + + +class MinkowskiModuleBase(Module): + pass + + +def get_minkowski_function(name, variable): + fn_name = name + get_postfix(variable) + if hasattr(MEB, fn_name): + return getattr(MEB, fn_name) + else: + if variable.is_cuda: + raise ValueError( + f"Function {fn_name} not available. Please compile MinkowskiEngine with `torch.cuda.is_available()` is `True`." + ) + else: + raise ValueError(f"Function {fn_name} not available.") diff --git a/MinkowskiEngine/MinkowskiConvolution.py b/MinkowskiEngine/MinkowskiConvolution.py index a1b95474..fe3f3599 100644 --- a/MinkowskiEngine/MinkowskiConvolution.py +++ b/MinkowskiEngine/MinkowskiConvolution.py @@ -1,4 +1,5 @@ -# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). 
# # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in @@ -28,229 +29,256 @@ from torch.autograd import Function from torch.nn import Parameter -from SparseTensor import SparseTensor, _get_coords_key -from Common import RegionType, MinkowskiModuleBase, KernelGenerator, \ - prep_args, convert_to_int_list, convert_to_int_tensor, \ - get_minkowski_function -from MinkowskiCoords import CoordsKey, save_ctx +from MinkowskiEngineBackend._C import CoordinateMapKey, RegionType +from MinkowskiSparseTensor import SparseTensor, _get_coordinate_map_key +from MinkowskiCommon import ( + MinkowskiModuleBase, + prep_args, + convert_to_int_list, + get_minkowski_function, +) +from MinkowskiCoordinateManager import CoordinateManager +from MinkowskiKernelGenerator import KernelRegion, KernelGenerator, save_ctx class MinkowskiConvolutionFunction(Function): - @staticmethod - def forward(ctx, - input_features, - kernel, - tensor_stride=1, - stride=1, - kernel_size=-1, - dilation=1, - region_type=0, - region_offset=None, - in_coords_key=None, - out_coords_key=None, - coords_manager=None): - """ - region_type=0 HyperCube - """ - # Prep arguments - # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat) - assert input_features.shape[1] == kernel.shape[1], \ - "The input shape " + str(list(input_features.shape)) + \ - " does not match the kernel shape " + str(list(kernel.shape)) - if out_coords_key is None: - out_coords_key = CoordsKey(in_coords_key.D) - assert in_coords_key.D == out_coords_key.D - assert input_features.type() == kernel.type(), \ - f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}" + def forward( + ctx, + input_features: torch.Tensor, + kernel_weights: torch.Tensor, + kernel_generator: KernelGenerator, + in_coordinate_map_key: CoordinateMapKey, + out_coordinate_map_key: CoordinateMapKey = None, + coordinate_manager: CoordinateManager = None, + 
): + if out_coordinate_map_key is None: + out_coordinate_map_key = CoordinateMapKey( + in_coordinate_map_key.get_coordinate_size() + ) + assert ( + input_features.type() == kernel_weights.type() + ), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}" if not input_features.is_contiguous(): input_features = input_features.contiguous() - tensor_stride, stride, kernel_size, dilation, region_type = prep_args( - tensor_stride, stride, kernel_size, dilation, region_type, - in_coords_key.D) - - if region_offset is None: - region_offset = torch.IntTensor() - - ctx.in_feat = input_features - ctx.kernel = kernel - ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation, - region_type, in_coords_key, out_coords_key, - coords_manager) - - D = in_coords_key.D - out_feat = input_features.new() - - fw_fn = get_minkowski_function('ConvolutionForward', input_features) - fw_fn(ctx.in_feat, out_feat, kernel, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), region_type, region_offset, - ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager) - return out_feat + ctx.input_features = input_features + ctx.kernel_weights = kernel_weights + ctx = save_ctx( + ctx, + kernel_generator, + in_coordinate_map_key, + out_coordinate_map_key, + coordinate_manager, + ) + + D = in_coordinate_map_key.get_coordinate_size() - 1 + + fw_fn = get_minkowski_function("ConvolutionForward", input_features) + return fw_fn( + ctx.input_features, + kernel_weights, + kernel_generator.kernel_size, + kernel_generator.kernel_stride, + kernel_generator.kernel_dilation, + kernel_generator.region_type, + kernel_generator.region_offsets, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager._manager, + ) @staticmethod - def backward(ctx, grad_out_feat): - if not grad_out_feat.is_contiguous(): - 
grad_out_feat = grad_out_feat.contiguous() - - grad_in_feat = grad_out_feat.new() - grad_kernel = grad_out_feat.new() - D = ctx.in_coords_key.D - bw_fn = get_minkowski_function('ConvolutionBackward', grad_out_feat) - bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel, grad_kernel, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), ctx.region_type, - ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager) - return grad_in_feat, grad_kernel, None, None, None, None, None, None, None, None, None + def backward(ctx, grad_out_feat: torch.Tensor): + D = ctx.in_coordinate_map_key.get_coordinate_size() - 1 + + bw_fn = get_minkowski_function("ConvolutionBackward", grad_out_feat) + grad_in_feat, grad_kernel = bw_fn( + ctx.input_features, + grad_out_feat, + ctx.kernel_weights, + ctx.kernel_generator.kernel_size, + ctx.kernel_generator.kernel_stride, + ctx.kernel_generator.kernel_dilation, + ctx.kernel_generator.region_type, + ctx.kernel_generator.region_offsets, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager._manager, + ) + return ( + grad_in_feat, + grad_kernel, + None, + None, + None, + None, + ) class MinkowskiConvolutionTransposeFunction(Function): - @staticmethod - def forward(ctx, - input_features, - kernel, - tensor_stride=1, - stride=1, - kernel_size=-1, - dilation=1, - region_type=0, - region_offset=None, - generate_new_coords=False, - in_coords_key=None, - out_coords_key=None, - coords_manager=None): - """ - region_type=0 HyperCube - """ - # Prep arguments - # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat) - assert input_features.shape[1] == kernel.shape[1], \ - "The input shape " + str(list(input_features.shape)) + \ - " does not match the kernel shape " + str(list(kernel.shape)) - if out_coords_key is None: - out_coords_key = CoordsKey(in_coords_key.D) - 
assert in_coords_key.D == out_coords_key.D - assert input_features.type() == kernel.type(), \ - f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}" + def forward( + ctx, + input_features: torch.Tensor, + kernel_weights: torch.Tensor, + kernel_generator: KernelGenerator, + generate_new_coordinates: bool = False, + in_coordinate_map_key: CoordinateMapKey = None, + out_coordinate_map_key: CoordinateMapKey = None, + coordinate_manager: CoordinateManager = None, + ): + if out_coordinate_map_key is None: + out_coordinate_map_key = CoordinateMapKey( + in_coordinate_map_key.get_coordinate_size() + ) + assert ( + input_features.type() == kernel.type() + ), f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}" if not input_features.is_contiguous(): input_features = input_features.contiguous() - tensor_stride, stride, kernel_size, dilation, region_type = prep_args( - tensor_stride, stride, kernel_size, dilation, region_type, - in_coords_key.D) - - if region_offset is None: - region_offset = torch.IntTensor() - - ctx.in_feat = input_features - ctx.kernel = kernel - ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation, - region_type, in_coords_key, out_coords_key, - coords_manager) - - D = in_coords_key.D - out_feat = input_features.new() - - fw_fn = get_minkowski_function('ConvolutionTransposeForward', - input_features) - fw_fn(ctx.in_feat, out_feat, kernel, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), region_type, region_offset, - ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager, generate_new_coords) - return out_feat + ctx.input_features = input_features + ctx.kernel_weights = kernel_weights + ctx = save_ctx( + ctx, + kernel_region, + in_coordinate_map_key, + out_coordinate_map_key, + coordinate_manager, + ) + + D = 
in_coordinate_map_key.get_coordinate_size() - 1 + + fw_fn = get_minkowski_function("ConvolutionTransposeForward", input_features) + return fw_fn( + ctx.input_features, + kernel_weights, + convert_to_int_list(kernel_region.kernel_size, D), + convert_to_int_list(kernel_region.kernel_stride, D), + convert_to_int_list(kernel_region.kernel_dilation, D), + kernel_region.region_type, + kernel_region.region_offset, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager, + generate_new_coordinates, + ) @staticmethod - def backward(ctx, grad_out_feat): - if not grad_out_feat.is_contiguous(): - grad_out_feat = grad_out_feat.contiguous() - - grad_in_feat = grad_out_feat.new() - grad_kernel = grad_out_feat.new() - D = ctx.in_coords_key.D - bw_fn = get_minkowski_function('ConvolutionTransposeBackward', - grad_out_feat) - bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel, grad_kernel, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), ctx.region_type, - ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager) - return grad_in_feat, grad_kernel, None, None, None, None, None, None, None, None, None, None + def backward(ctx, grad_out_feat: torch.Tensor): + D = ctx.in_coordinate_map_key.get_coordinate_size() - 1 + + bw_fn = get_minkowski_function("ConvolutionTransposeBackward", grad_out_feat) + grad_in_feat, grad_kernel = bw_fn( + ctx.input_features, + grad_out_feat, + ctx.kernel_weights, + convert_to_int_list(ctx.kernel_region.kernel_size, D), + convert_to_int_list(ctx.kernel_region.kernel_stride, D), + convert_to_int_list(ctx.kernel_region.kernel_dilation, D), + ctx.kernel_region.region_type, + ctx.kernel_region.region_offset, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager, + ) + return ( + grad_in_feat, + grad_kernel, + None, + None, + None, + None, + 
None, + ) class MinkowskiConvolutionBase(MinkowskiModuleBase): - def __init__(self, - in_channels, - out_channels, - kernel_size=-1, - stride=1, - dilation=1, - has_bias=False, - kernel_generator=None, - is_transpose=False, - dimension=-1): + __slots__ = ( + "in_channels", + "out_channels", + "is_transpose", + "kernel_generator", + "dimension", + "use_mm", + "weight", + "bias", + "conv", + ) + + def __init__( + self, + in_channels, + out_channels, + kernel_size=-1, + stride=1, + dilation=1, + bias=False, + kernel_generator=None, + is_transpose=False, # only the base class has this argument + dimension=-1, + ): + r""" + + .. note:: + + When the kernel generator is provided, all kernel related arguments + (kernel_size, stride, dilation) will be ignored. + + """ super(MinkowskiConvolutionBase, self).__init__() assert dimension > 0, f"dimension must be a positive integer, {dimension}" if kernel_generator is None: kernel_generator = KernelGenerator( kernel_size=kernel_size, - stride=stride, - dilation=dilation, - dimension=dimension) - else: - kernel_size = kernel_generator.kernel_size - - stride = convert_to_int_tensor(stride, dimension) - kernel_size = convert_to_int_tensor(kernel_size, dimension) - dilation = convert_to_int_tensor(dilation, dimension) - - kernel_volume = kernel_generator.kernel_volume + kernel_stride=stride, + kernel_dilation=dilation, + dimension=dimension, + ) self.is_transpose = is_transpose self.in_channels = in_channels self.out_channels = out_channels - self.kernel_size = kernel_size - self.kernel_volume = kernel_volume - self.stride = stride - self.dilation = dilation + self.kernel_generator = kernel_generator self.dimension = dimension - self.use_mm = False # use matrix multiplication when kernel is 1 + self.use_mm = False # use matrix multiplication when kernel_volume is 1 Tensor = torch.FloatTensor - if torch.prod(kernel_size) == 1 and torch.prod(stride) == 1: - self.kernel_shape = (self.in_channels, self.out_channels) + if ( + 
self.kernel_generator.kernel_volume == 1 + and self.kernel_generator.requires_strided_coordiantes + ): + kernel_shape = (self.in_channels, self.out_channels) self.use_mm = True else: - self.kernel_shape = (self.kernel_volume, self.in_channels, - self.out_channels) - - self.kernel = Parameter(Tensor(*self.kernel_shape)) - self.bias = Parameter(Tensor(1, out_channels)) if has_bias else None - self.has_bias = has_bias - - def forward(self, - input: SparseTensor, - coords: Union[torch.IntTensor, CoordsKey, SparseTensor] = None): + kernel_shape = ( + self.kernel_generator.kernel_volume, + self.in_channels, + self.out_channels, + ) + + self.weight = Parameter(Tensor(*kernel_shape)) + self.bias = Parameter(Tensor(1, out_channels)) if bias else None + self.conv = ( + MinkowskiConvolutionTransposeFunction() + if is_transpose + else MinkowskiConvolutionFunction() + ) + + def forward( + self, + input: SparseTensor, + coordinates: Union[torch.Tensor, CoordinateMapKey, SparseTensor] = None, + ): r""" :attr:`input` (`MinkowskiEngine.SparseTensor`): Input sparse tensor to apply a convolution on. - :attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`, + :attr:`coordinates` ((`torch.IntTensor`, `MinkowskiEngine.CoordinateMapKey`, `MinkowskiEngine.SparseTensor`), optional): If provided, generate results on the provided coordinates. None by default. 
@@ -258,53 +286,51 @@ def forward(self, assert isinstance(input, SparseTensor) assert input.D == self.dimension - # Create a region_offset - self.region_type_, self.region_offset_, _ = \ - self.kernel_generator.get_kernel(input.tensor_stride, self.is_transpose) - if self.use_mm and coords is None: # If the kernel_size == 1, the convolution is simply a matrix # multiplication outfeat = input.F.mm(self.kernel) out_coords_key = input.coords_key else: - if self.is_transpose: - conv = MinkowskiConvolutionTransposeFunction() - else: - conv = MinkowskiConvolutionFunction() - # Get a new coords key or extract one from the coords - out_coords_key = _get_coords_key(input, coords) - outfeat = conv.apply(input.F, self.kernel, input.tensor_stride, - self.stride, self.kernel_size, self.dilation, - self.region_type_, self.region_offset_, - input.coords_key, out_coords_key, - input.coords_man) - if self.has_bias: + # Get a new coordinate_map_key or extract one from the coords + out_coordinate_map_key = _get_coordinate_map_key(input, coordinates) + outfeat = self.conv.apply( + input.F, + self.weight, + self.kernel_generator, + input.coordinate_map_key, + out_coordinate_map_key, + input._manager, + ) + if self.bias is not None: outfeat += self.bias return SparseTensor( - outfeat, coords_key=out_coords_key, coords_manager=input.coords_man) + outfeat, + coordinate_map_key=out_coordinate_map_key, + coordinate_manager=input._manager, + ) def reset_parameters(self, is_transpose=False): - n = (self.out_channels - if is_transpose else self.in_channels) * self.kernel_volume - stdv = 1. 
/ math.sqrt(n) - self.kernel.data.uniform_(-stdv, stdv) + n = ( + self.out_channels if is_transpose else self.in_channels + ) * self.kernel_generator.kernel_volume + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def __repr__(self): - s = '(in={}, out={}, region_type={}, '.format( - self.in_channels, self.out_channels, - self.kernel_generator.region_type) - if self.kernel_generator.region_type in [ - RegionType.HYBRID, RegionType.CUSTOM - ]: - s += 'kernel_volume={}, '.format(self.kernel_volume) + s = "(in={}, out={}, region_type={}, ".format( + self.in_channels, self.out_channels, self.kernel_generator.region_type + ) + if self.kernel_generator.region_type in [RegionType.CUSTOM]: + s += "kernel_volume={}, ".format(self.kernel_generator.kernel_volume) else: - s += 'kernel_size={}, '.format(self.kernel_size.tolist()) - s += 'stride={}, dilation={})'.format(self.stride.tolist(), - self.dilation.tolist()) + s += "kernel_size={}, ".format(self.kernel_generator.kernel_size) + s += "stride={}, dilation={})".format( + self.kernel_generator.kernel_stride, self.kernel_generator.kernel_dilation, + ) return self.__class__.__name__ + s @@ -331,15 +357,17 @@ class MinkowskiConvolution(MinkowskiConvolutionBase): """ - def __init__(self, - in_channels, - out_channels, - kernel_size=-1, - stride=1, - dilation=1, - has_bias=False, - kernel_generator=None, - dimension=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size=-1, + stride=1, + dilation=1, + bias=False, + kernel_generator=None, + dimension=None, + ): r"""convolution on a sparse tensor Args: @@ -383,10 +411,11 @@ def __init__(self, kernel_size, stride, dilation, - has_bias, + bias, kernel_generator, is_transpose=False, - dimension=dimension) + dimension=dimension, + ) self.reset_parameters() @@ -394,16 +423,18 @@ class MinkowskiConvolutionTranspose(MinkowskiConvolutionBase): r"""A generalized sparse transposed 
convolution or deconvolution layer. """ - def __init__(self, - in_channels, - out_channels, - kernel_size=-1, - stride=1, - dilation=1, - has_bias=False, - kernel_generator=None, - generate_new_coords=False, - dimension=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size=-1, + stride=1, + dilation=1, + bias=False, + kernel_generator=None, + generate_new_coordinates=False, + dimension=None, + ): r"""a generalized sparse transposed convolution layer. Args: @@ -455,21 +486,24 @@ def __init__(self, kernel_size, stride, dilation, - has_bias, + bias, kernel_generator, is_transpose=True, - dimension=dimension) + dimension=dimension, + ) self.reset_parameters(True) - self.generate_new_coords = generate_new_coords + self.generate_new_coordinates = generate_new_coordinates - def forward(self, - input: SparseTensor, - coords: Union[torch.IntTensor, CoordsKey, SparseTensor] = None): + def forward( + self, + input: SparseTensor, + coordinates: Union[torch.Tensor, CoordinateMapKey, SparseTensor] = None, + ): r""" :attr:`input` (`MinkowskiEngine.SparseTensor`): Input sparse tensor to apply a convolution on. - :attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`, + :attr:`coordinates` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`, `MinkowskiEngine.SparseTensor`), optional): If provided, generate results on the provided coordinates. None by default. 
@@ -478,24 +512,35 @@ def forward(self, assert input.D == self.dimension # Create a region_offset - self.region_type_, self.region_offset_, _ = \ - self.kernel_generator.get_kernel(input.tensor_stride, self.is_transpose) + self.region_type_, self.region_offset_, _ = self.kernel_generator.get_kernel( + input.tensor_stride, self.is_transpose + ) if self.use_mm and coords is None: # If the kernel_size == 1, the convolution is simply a matrix # multiplication outfeat = input.F.mm(self.kernel) - out_coords_key = input.coords_key + out_coordinate_map_key = input.coordinate_map_key else: # Get a new coords key or extract one from the coords out_coords_key = _get_coords_key(input, coords, tensor_stride=1) outfeat = MinkowskiConvolutionTransposeFunction().apply( - input.F, self.kernel, input.tensor_stride, self.stride, - self.kernel_size, self.dilation, self.region_type_, - self.region_offset_, self.generate_new_coords, input.coords_key, - out_coords_key, input.coords_man) + input.F, + self.kernel, + input.tensor_stride, + self.stride, + self.kernel_size, + self.dilation, + self.region_type_, + self.region_offset_, + self.generate_new_coords, + input.coords_key, + out_coords_key, + input.coords_man, + ) if self.has_bias: outfeat += self.bias return SparseTensor( - outfeat, coords_key=out_coords_key, coords_manager=input.coords_man) + outfeat, coords_key=out_coords_key, coords_manager=input.coords_man + ) diff --git a/MinkowskiEngine/MinkowskiCoords.py b/MinkowskiEngine/MinkowskiCoordinateManager.py similarity index 84% rename from MinkowskiEngine/MinkowskiCoords.py rename to MinkowskiEngine/MinkowskiCoordinateManager.py index ea168d52..7c71ff24 100644 --- a/MinkowskiEngine/MinkowskiCoords.py +++ b/MinkowskiEngine/MinkowskiCoordinateManager.py @@ -1,4 +1,5 @@ -# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). 
# # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in @@ -28,21 +29,35 @@ import warnings import torch -from Common import convert_to_int_list, convert_to_int_tensor, prep_args +from MinkowskiCommon import convert_to_int_list, convert_to_int_tensor, prep_args import MinkowskiEngineBackend._C as _C from MinkowskiEngineBackend._C import ( CoordinateMapKey, GPUMemoryAllocatorType, CoordinateMapType, RegionType, + CUDAKernelMapMode, ) CPU_COUNT = os.cpu_count() if "OMP_NUM_THREADS" in os.environ: CPU_COUNT = int(os.environ["OMP_NUM_THREADS"]) -_allocator_type = GPUMemoryAllocator.PYTORCH -_map_type = CoordinateMapType.CUDA if _C.is_cuda_available() else CoordinateMapType.CPU +_allocator_type = GPUMemoryAllocatorType.PYTORCH +_coordinate_map_type = ( + CoordinateMapType.CUDA if _C.is_cuda_available() else CoordinateMapType.CPU +) +_kernel_map_mode = CUDAKernelMapMode.SPEED_OPTIMIZED + + +def set_coordinate_map_type(coordinate_map_type: CoordinateMapType): + r"""Set the default coordinate map type. + + The MinkowskiEngine automatically set the coordinate_map_type to CUDA if + a NVIDIA GPU is available. 
To control the + """ + global _coordinate_map_type + _coordinate_map_type = coordinate_map_type def set_gpu_allocator(backend: GPUMemoryAllocatorType): @@ -62,9 +77,9 @@ def set_gpu_allocator(backend: GPUMemoryAllocatorType): >>> import MinkowskiEngine as ME >>> # Set the GPU memory manager backend to raw CUDA calls - >>> ME.set_gpu_allocator(ME.GPUMemoryAllocator.CUDA) + >>> ME.set_gpu_allocator(ME.GPUMemoryAllocatorType.CUDA) >>> # Set the GPU memory manager backend to the pytorch c10 allocator - >>> ME.set_gpu_allocator(ME.GPUMemoryAllocator.PYTORCH) + >>> ME.set_gpu_allocator(ME.GPUMemoryAllocatorType.PYTORCH) """ assert isinstance( @@ -74,7 +89,7 @@ def set_gpu_allocator(backend: GPUMemoryAllocatorType): _allocator_type = backend -def set_memory_manager_backend(backend: GPUMemoryAllocator): +def set_memory_manager_backend(backend: GPUMemoryAllocatorType): r"""Alias for set_gpu_allocator. Deprecated and will be removed. """ warnings.warn( @@ -93,36 +108,39 @@ def __init__(*args, **kwargs): class CoordinateManager: def __init__( self, - D: int = -1, + D: int = 0, num_threads: int = -1, - map_type: CoordinateMapType = None, + coordinate_map_type: CoordinateMapType = None, allocator_type: GPUMemoryAllocatorType = None, - map_mode=None, # TODO + kernel_map_mode: CUDAKernelMapMode = None, ): r""" :attr:`D`: The order, or dimension of the coordinates. 
""" + global _coordinate_map_type, _allocator_type, _kernel_map_mode if D < 1: raise ValueError(f"Invalid rank D > 0, D = {D}.") if num_threads < 0: num_threads = min(CPU_COUNT, 20) - if map_type is None: - global _map_type - map_type = _map_type + if coordinate_map_type is None: + coordinate_map_type = _coordinate_map_type if allocator_type is None: - global _allocator_type allocator_type = _allocator_type + if kernel_map_mode is None: + kernel_map_mode = _kernel_map_mode - postfix = "CPU" if map_type == CoordinateMapType.CPU else "GPU" - if map_type == CoordinateMapType.GPU: - postfix += ( + postfix = "" + if coordinate_map_type == CoordinateMapType.CPU: + postfix = "CPU" + else: + postfix = "GPU" + ( "_default" if allocator_type == GPUMemoryAllocatorType.CUDA else "_c10" ) self.D = D self._CoordinateManagerClass = getattr(_C, "CoordinateMapManager" + postfix) - self._manager = self._CoordinateManagerClass() # TODO kernel_map_mode + self._manager = self._CoordinateManagerClass(kernel_map_mode, num_threads) # TODO: insert without remap, unique_map, inverse_mapa # @@ -136,7 +154,8 @@ def insert_and_map( ) -> Tuple[CoordinateMapKey, Tuple[torch.IntTensor, torch.IntTensor]]: r"""create a new coordinate map and returns - :attr:`coordinates`: `torch.IntTensor` (`CUDA` if map_type == `CoordinateMapType.GPU`) that defines the coordinates. + :attr:`coordinates`: `torch.IntTensor` (`CUDA` if coordinate_map_type + == `CoordinateMapType.GPU`) that defines the coordinates. Example:: @@ -204,7 +223,7 @@ def stride( # ) # return strided_key - def _get_coordinate_key(self, key_or_tensor_strides): + def _get_coordinate_map_key(self, key_or_tensor_strides): r"""Helper function that retrieves a coordinate map key from tensor stride. 
""" assert isinstance(key_or_tensor_strides, CoordinateMapKey) or isinstance( @@ -222,10 +241,8 @@ def _get_coordinate_key(self, key_or_tensor_strides): return coords_key def get_coordinates(self, coords_key_or_tensor_strides): - coords_key = self._get_coords_key(coords_key_or_tensor_strides) - coords = torch.IntTensor() - self.CPPCoordsManager.getCoords(coords, coords_key.CPPCoordsKey) - return coords + key = self._get_coordinate_map_key(coords_key_or_tensor_strides) + return self._manager.get_coordinates(key) # def get_batch_size(self): # return self.CPPCoordsManager.getBatchSize() @@ -302,8 +319,8 @@ def get_coordinates(self, coords_key_or_tensor_strides): # if region_offset is None: # region_offset = torch.IntTensor() - # in_coords_key = self._get_coords_key(in_key_or_tensor_strides) - # out_coords_key = self._get_coords_key(out_key_or_tensor_strides) + # in_coords_key = self._get_coordinate_map_key(in_key_or_tensor_strides) + # out_coords_key = self._get_coordinate_map_key(out_key_or_tensor_strides) # tensor_strides = convert_to_int_tensor(in_tensor_strides, self.D) # strides = convert_to_int_tensor(stride, self.D) @@ -351,8 +368,8 @@ def get_coordinates(self, coords_key_or_tensor_strides): # print(f"{i} -> {o}") # """ - # in_coords_key = self._get_coords_key(in_key_or_tensor_strides) - # out_coords_key = self._get_coords_key(out_key_or_tensor_strides) + # in_coords_key = self._get_coordinate_map_key(in_key_or_tensor_strides) + # out_coords_key = self._get_coordinate_map_key(out_key_or_tensor_strides) # return self.CPPCoordsManager.getCoordsMap( # in_coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey @@ -392,8 +409,8 @@ def get_coordinates(self, coords_key_or_tensor_strides): # return self.CPPCoordsManager.getCoordsSize(coords_key.CPPCoordsKey) # def get_mapping_by_tensor_strides(self, in_tensor_strides, out_tensor_strides): - # in_key = self._get_coords_key(in_tensor_strides) - # out_key = self._get_coords_key(out_tensor_strides) + # in_key = 
self._get_coordinate_map_key(in_tensor_strides) + # out_key = self._get_coordinate_map_key(out_tensor_strides) # return self.get_mapping_by_coords_key(in_key, out_key) # def permute_label( @@ -402,8 +419,8 @@ def get_coordinates(self, coords_key_or_tensor_strides): # if target_tensor_stride == label_tensor_stride: # return label - # label_coords_key = self._get_coords_key(label_tensor_stride) - # target_coords_key = self._get_coords_key(target_tensor_stride) + # label_coords_key = self._get_coordinate_map_key(label_tensor_stride) + # target_coords_key = self._get_coordinate_map_key(target_tensor_stride) # permutation = self.get_mapping_by_coords_key( # label_coords_key, target_coords_key @@ -417,31 +434,9 @@ def get_coordinates(self, coords_key_or_tensor_strides): # np.add.at(counter, (permutation, label), 1) # return torch.from_numpy(np.argmax(counter, 1)) - def print_diagnostics(self, coords_key: CoordsKey): - assert isinstance(coords_key, CoordsKey) - self.CPPCoordsManager.printDiagnostics(coords_key.CPPCoordsKey) - - def __repr__(self): - return str(self.CPPCoordsManager) - - -def save_ctx( - ctx, # function object context - tensor_stride: torch.IntTensor, - stride: torch.IntTensor, - kernel_size: torch.IntTensor, - dilation: torch.IntTensor, - region_type: int, - in_coords_key: CoordsKey, - out_coords_key: CoordsKey, - coords_man: CoordsManager, -): - ctx.tensor_stride = tensor_stride - ctx.stride = stride - ctx.kernel_size = kernel_size - ctx.dilation = dilation - ctx.region_type = region_type - ctx.in_coords_key = in_coords_key - ctx.out_coords_key = out_coords_key - ctx.coords_man = coords_man - return ctx + # def print_diagnostics(self, coords_key: CoordsKey): + # assert isinstance(coords_key, CoordsKey) + # self.CPPCoordsManager.printDiagnostics(coords_key.CPPCoordsKey) + + # def __repr__(self): + # return str(self._manager) diff --git a/MinkowskiEngine/MinkowskiKernelGenerator.py b/MinkowskiEngine/MinkowskiKernelGenerator.py new file mode 100644 index 
00000000..564b3eb8 --- /dev/null +++ b/MinkowskiEngine/MinkowskiKernelGenerator.py @@ -0,0 +1,360 @@ +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. 
+import math +from collections import Sequence, namedtuple +from functools import reduce +import numpy as np +from enum import Enum +from itertools import product +from typing import Union + +import torch +from MinkowskiCommon import convert_to_int_list +from MinkowskiEngineBackend._C import CoordinateMapKey, RegionType +from MinkowskiCoordinateManager import CoordinateManager + + +def get_kernel_volume(region_type, kernel_size, region_offset, axis_types, dimension): + """ + when center is True, the custom region_offset will be centered at the + origin. Currently, for HYPER_CUBE, HYPER_CROSS with odd kernel sizes cannot + use center=False. + """ + if region_type == RegionType.HYPER_CUBE: + assert reduce( + lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + ), "kernel_size must be positive" + assert ( + region_offset is None + ), "Region offset must be None when region_type is given" + assert axis_types is None, "Axis types must be None when region_type is given" + # Typical convolution kernel + + # Convolution kernel with even numbered kernel size not defined. + kernel_volume = torch.prod(torch.IntTensor(kernel_size)).item() + + elif region_type == RegionType.HYPER_CROSS: + assert reduce( + lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + ), "kernel_size must be positive" + assert ( + kernel_size % 2 + ).prod() == 1, "kernel_size must be odd for region_type HYPER_CROSS" + # 0th: itself, (1, 2) for 0th dim neighbors, (3, 4) for 1th dim ... 
+ kernel_volume = int(torch.sum(kernel_size - 1) + 1) + + # elif region_type == RegionType.HYBRID: + # assert reduce( + # lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + # ), "kernel_size must be positive" + # assert ( + # region_offset is None + # ), "region_offset must be None when region_type is HYBRID" + # kernel_size_list = kernel_size.tolist() + # kernel_volume = 1 + # # First HYPER_CUBE + # for axis_type, curr_kernel_size, d in zip( + # axis_types, kernel_size_list, range(dimension) + # ): + # if axis_type == RegionType.HYPER_CUBE: + # kernel_volume *= curr_kernel_size + + # # Second, HYPER_CROSS + # for axis_type, curr_kernel_size, d in zip( + # axis_types, kernel_size_list, range(dimension) + # ): + # if axis_type == RegionType.HYPER_CROSS: + # kernel_volume += curr_kernel_size - 1 + + elif region_type == RegionType.CUSTOM: + assert ( + region_offset.numel() > 0 + ), "region_offset must be non empty when region_type is CUSTOM" + assert ( + region_offset.size(1) == dimension + ), "region_offset must have the same dimension as the network" + kernel_volume = int(region_offset.size(0)) + + else: + raise NotImplementedError() + + return kernel_volume + + +def convert_region_type( + region_type: RegionType, + tensor_stride: Union[Sequence, np.ndarray, torch.IntTensor], + kernel_size: Union[Sequence, np.ndarray, torch.IntTensor], + up_stride: Union[Sequence, np.ndarray, torch.IntTensor], + dilation: Union[Sequence, np.ndarray, torch.IntTensor], + region_offset: Union[Sequence, np.ndarray, torch.IntTensor], + axis_types: Union[Sequence, np.ndarray, torch.IntTensor], + dimension: int, + center: bool = True, +): + """ + when center is True, the custom region_offset will be centered at the + origin. Currently, for HYPER_CUBE, HYPER_CROSS with odd kernel sizes cannot + use center=False. 
+ + up_stride: stride for conv_transpose, otherwise set it as 1 + """ + if region_type == RegionType.HYPER_CUBE: + assert ( + region_offset is None + ), "Region offset must be None when region_type is given" + assert axis_types is None, "Axis types must be None when region_type is given" + # Typical convolution kernel + assert reduce( + lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + ), "kernel_size must be positive" + # assert torch.unique(dilation).numel() == 1 + kernel_volume = reduce(lambda k1, k2: k1 * k2, kernel_size) + + elif region_type == RegionType.HYPER_CROSS: + assert reduce( + lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + ), "kernel_size must be positive" + assert ( + kernel_size % 2 + ).prod() == 1, "kernel_size must be odd for region_type HYPER_CROSS" + # 0th: itself, (1, 2) for 0th dim neighbors, (3, 4) for 1th dim ... + kernel_volume = ( + reduce(lambda k1, k2: k1 + k2, map(lambda k: k - 1, kernel_size)) + 1 + ) + + elif region_type == RegionType.HYBRID: + assert reduce( + lambda k1, k2: k1 > 0 and k2 > 0, kernel_size + ), "kernel_size must be positive" + assert ( + region_offset is None + ), "region_offset must be None when region_type is HYBRID" + region_offset = [[0,] * dimension] + kernel_size_list = kernel_size.tolist() + # First HYPER_CUBE + for axis_type, curr_kernel_size, d in zip( + axis_types, kernel_size_list, range(dimension) + ): + new_offset = [] + if axis_type == RegionType.HYPER_CUBE: + for offset in region_offset: + for curr_offset in range(curr_kernel_size): + off_center = ( + int(math.floor((curr_kernel_size - 1) / 2)) if center else 0 + ) + offset = offset.copy() # Do not modify the original + # Exclude the coord (0, 0, ..., 0) + if curr_offset == off_center: + continue + offset[d] = ( + (curr_offset - off_center) + * dilation[d] + * (tensor_stride[d] / up_stride[d]) + ) + new_offset.append(offset) + region_offset.extend(new_offset) + + # Second, HYPER_CROSS + for axis_type, curr_kernel_size, d in zip( + axis_types, 
kernel_size_list, range(dimension) + ): + new_offset = [] + if axis_type == RegionType.HYPER_CROSS: + for curr_offset in range(curr_kernel_size): + off_center = ( + int(math.floor((curr_kernel_size - 1) / 2)) if center else 0 + ) + offset = [0,] * dimension + # Exclude the coord (0, 0, ..., 0) + if curr_offset == off_center: + continue + offset[d] = ( + (curr_offset - off_center) + * dilation[d] + * (tensor_stride[d] / up_stride[d]) + ) + new_offset.append(offset) + region_offset.extend(new_offset) + + # Convert to CUSTOM type + region_type = RegionType.CUSTOM + region_offset = torch.IntTensor(region_offset) + kernel_volume = int(region_offset.size(0)) + + elif region_type == RegionType.CUSTOM: + assert ( + region_offset.numel() > 0 + ), "region_offset must be non empty when region_type is CUSTOM" + assert ( + region_offset.size(1) == dimension + ), "region_offset must have the same dimension as the network" + kernel_volume = int(region_offset.size(0)) + assert isinstance( + region_offset.dtype, torch.IntTensor + ), "region_offset must be a torch.IntTensor." + else: + raise NotImplementedError() + + if region_offset is None: + region_offset = torch.IntTensor() + + return region_type, region_offset, kernel_volume + + +class KernelGenerator: + __slots__ = ( + "cache", + "kernel_size", + "kernel_stride", + "kernel_dilation", + "region_type", + "region_offsets", + "axis_types", + "dimension", + "kernel_volume", + "requires_strided_coordinates", + ) + + def __init__( + self, + kernel_size=-1, + kernel_stride=1, + kernel_dilation=1, + is_transpose: bool = False, + region_type: RegionType = RegionType.HYPER_CUBE, + region_offsets: torch.Tensor = None, + axis_types=None, + dimension=-1, + ): + r""" + :attr:`region_type` (RegionType, optional): defines the kernel + shape. Please refer to MinkowskiEngine.Comon for details. 
+ + :attr:`region_offset` (torch.IntTensor, optional): when the + :attr:`region_type` is :attr:`RegionType.CUSTOM`, the convolution + kernel uses the provided `region_offset` to define offsets. It + should be a matrix of size :math:`N \times D` where :math:`N` is + the number of offsets and :math:`D` is the dimension of the + space. + + :attr:`axis_types` (list of RegionType, optional): If given, it + uses different methods to create a kernel for each axis. e.g., when + it is `[RegionType.HYPER_CUBE, RegionType.HYPER_CUBE, + RegionType.HYPER_CROSS]`, the kernel would be rectangular for the + first two dimensions and cross shaped for the thrid dimension. + """ + assert dimension > 0 + assert isinstance(region_type, RegionType) + + kernel_size = convert_to_int_list(kernel_size, dimension) + kernel_stride = convert_to_int_list(kernel_stride, dimension) + kernel_dilation = convert_to_int_list(kernel_dilation, dimension) + + self.cache = {} + self.kernel_size = kernel_size + self.kernel_stride = kernel_stride + self.kernel_dilation = kernel_dilation + self.region_type = region_type + self.region_offsets = region_offsets if region_offsets else torch.IntTensor() + self.axis_types = axis_types + self.dimension = dimension + self.kernel_volume = get_kernel_volume( + region_type, kernel_size, region_offsets, axis_types, dimension + ) + self.requires_strided_coordinates = reduce( + lambda s1, s2: s1 == 1 and s2 == 1, kernel_stride + ) + + def get_kernel(self, tensor_stride, is_transpose): + assert len(tensor_stride) == self.dimension + if tuple(tensor_stride) not in self.cache: + up_stride = ( + self.stride if is_transpose else torch.Tensor([1,] * self.dimension) + ) + + self.cache[tuple(tensor_stride)] = convert_region_type( + self.region_type, + tensor_stride, + self.kernel_size, + up_stride, + self.kernel_dilation, + self.region_offsets, + self.axis_types, + self.dimension, + ) + + return self.cache[tuple(tensor_stride)] + + +class KernelRegion( + namedtuple( + 
"KernelRegion", + ( + "kernel_size", + "kernel_stride", + "kernel_dilation", + "region_type", + "offset", + "D", + ), + ) +): + """adding functionality to a named tuple""" + + __slots__ = () + + def __init__( + self, + kernel_size, + kernel_stride, + kernel_dilation, + region_type, + offset, + dimension, + ): + kernel_size = convert_to_int_list(kernel_size, dimension) + kernel_stride = convert_to_int_list(kernel_stride, dimension) + kernel_dilation = convert_to_int_list(kernel_dilation, dimension) + super(KernelRegion, self).__init__( + kernel_size, kernel_stride, kernel_dilation, region_type, offset, dimension + ) + + def __str__(self): + return "kernel_size:{self.kernel_size}, kernel_stride:{self.kernel_stride}, region_type:{self.region_type}" + + +def save_ctx( + ctx, # function object context + kernel_generator: KernelGenerator, + in_coords_key: CoordinateMapKey, + out_coords_key: CoordinateMapKey, + coordinate_manager: CoordinateManager, +): + ctx.kernel_generator = kernel_generator + ctx.in_coordinate_map_key = in_coords_key + ctx.out_coordinate_map_key = out_coords_key + ctx.coordinate_manager = coordinate_manager + return ctx diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py index f607f3b1..8cb72d58 100644 --- a/MinkowskiEngine/MinkowskiSparseTensor.py +++ b/MinkowskiEngine/MinkowskiSparseTensor.py @@ -1,4 +1,5 @@ -# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). 
# # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in @@ -30,34 +31,50 @@ from collections import Sequence import numpy as np -from Common import convert_to_int_list -from MinkowskiCoords import CoordsKey, CoordsManager -import MinkowskiEngineBackend as MEB -from MinkowskiEngineBackend import MemoryManagerBackend +from MinkowskiCommon import convert_to_int_list, StrideType +from MinkowskiEngineBackend._C import ( + GPUMemoryAllocatorType, + CUDAKernelMapMode, + CoordinateMapType, + CoordinateMapKey, +) +from MinkowskiCoordinateManager import ( + CoordinateManager, + _allocator_type, + _coordinate_map_type, +) class SparseTensorOperationMode(Enum): + r"""Enum class for SparseTensor internal instantiation modes. + + :attr:`SEPARATE_COORDINATE_MANAGER`: always create a new coordinate manager. + + :attr:`SHARE_COORDINATE_MANAGER`: always use the globally defined coordinate + manager. Must clear the coordinate manager manually by + :attr:`MinkowskiEngine.SparseTensor.clear_global_coordinate_mananager`. + """ - `SEPARATE_COORDS_MANAGER`: always create a new coordinate manager. - `SHARE_COORDS_MANAGER`: always use the globally defined coordinate manager. Must clear the coordinate manager manually by :attr:`MinkowskiEngine.SparseTensor.clear_global_coords_man` - """ - SEPARATE_COORDS_MANAGER = 0 - SHARE_COORDS_MANAGER = 1 + SEPARATE_COORDINATE_MANAGER = 0 + SHARE_COORDINATE_MANAGER = 1 class SparseTensorQuantizationMode(Enum): - """ + r""" `RANDOM_SUBSAMPLE`: Subsample one coordinate per each quantization block randomly. `UNWEIGHTED_AVERAGE`: average all features within a quantization block equally. + `UNWEIGHTED_SUM`: sum all features within a quantization block equally. 
""" RANDOM_SUBSAMPLE = 0 UNWEIGHTED_AVERAGE = 1 + UNWEIGHTED_SUM = 2 + +_sparse_tensor_operation_mode = SparseTensorOperationMode.SEPARATE_COORDINATE_MANAGER +_global_coordinate_manager = None -_sparse_tensor_operation_mode = SparseTensorOperationMode.SEPARATE_COORDS_MANAGER -_global_coords_man = None -COORDS_MAN_DIFFERENT_ERROR = "SparseTensors must share the same coordinate manager for this operation. Please refer to the SparseTensor creation API (https://nvidia.github.io/MinkowskiEngine/sparse_tensor.html) to share the coordinate manager, or set the sparse tensor operation mode with `set_sparse_tensor_operation_mode` to share it by default." -COORDS_KEY_DIFFERENT_ERROR = "SparseTensors must have the same coords_key." +COORDINATE_MANAGER_DIFFERENT_ERROR = "SparseTensors must share the same coordinate manager for this operation. Please refer to the SparseTensor creation API (https://nvidia.github.io/MinkowskiEngine/sparse_tensor.html) to share the coordinate manager, or set the sparse tensor operation mode with `set_sparse_tensor_operation_mode` to share it by default." +COORDINATE_KEY_DIFFERENT_ERROR = "SparseTensors must have the same coordinate_map_key." def set_sparse_tensor_operation_mode(operation_mode: SparseTensorOperationMode): @@ -69,27 +86,28 @@ def set_sparse_tensor_operation_mode(operation_mode: SparseTensorOperationMode): :attr:`MinkowskiEngine.SparseTensorOperationMode.SHARE_COORDS_MANAGER`, you can share the coordinate manager globally with other sparse tensors. However, you must explicitly clear the coordinate manger after use. Please - refer to :attr:`MinkowskiEngine.clear_global_coords_man`. + refer to :attr:`MinkowskiEngine.clear_global_coordinate_mananager`. Args: :attr:`operation_mode` (:attr:`MinkowskiEngine.SparseTensorOperationMode`): The operation mode for the sparse tensor coordinate manager. By default - :attr:`MinkowskiEngine.SparseTensorOperationMode.SEPARATE_COORDS_MANAGER`. 
+ :attr:`MinkowskiEngine.SparseTensorOperationMode.SEPARATE_COORDINATE_MANAGER`. Example: >>> import MinkowskiEngine as ME - >>> ME.set_sparse_tensor_operation_mode(ME.SparseTensorOperationMode.SHARE_COORDS_MANAGER) + >>> ME.set_sparse_tensor_operation_mode(ME.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER) >>> ... - >>> a = ME.SparseTensor(coords=A_C, feats=A_F) - >>> b = ME.SparseTensor(coords=B_C, feats=B_C) # coords_man shared + >>> a = ME.SparseTensor(...) + >>> b = ME.SparseTensor(...) # coords_man shared >>> ... # one feed forward and backward - >>> ME.clear_global_coords_man() # Must use to clear the coordinates after one forward/backward + >>> ME.clear_global_coordinate_mananager() # Must use to clear the coordinates after one forward/backward """ - assert isinstance(operation_mode, SparseTensorOperationMode), \ - f"Input must be an instance of SparseTensorOperationMode not {operation_mode}" + assert isinstance( + operation_mode, SparseTensorOperationMode + ), f"Input must be an instance of SparseTensorOperationMode not {operation_mode}" global _sparse_tensor_operation_mode _sparse_tensor_operation_mode = operation_mode @@ -99,18 +117,18 @@ def sparse_tensor_operation_mode(): return copy.deepcopy(_sparse_tensor_operation_mode) -def clear_global_coords_man(): +def clear_global_coordinate_mananager(): r"""Clear the global coordinate manager cache. When you use the operation mode: - :attr:`MinkowskiEngine.SparseTensor.SparseTensorOperationMode.SHARE_COORDS_MANAGER`, + :attr:`MinkowskiEngine.SparseTensor.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER`, you must explicitly clear the coordinate manager after each feed forward/backward. """ - global _global_coords_man - _global_coords_man = None + global _global_coordinate_manager + _global_coordinate_manager = None -class SparseTensor(): +class SparseTensor: r"""A sparse tensor class. Can be accessed via :attr:`MinkowskiEngine.SparseTensor`. 
@@ -141,6 +159,22 @@ class SparseTensor(): x_i^D)`, and the associated feature :math:`\mathbf{f}_i`. Internally, we handle the batch index as an additional spatial dimension. + Example:: + + >>> coords, feats = ME.utils.sparse_collate([coords_batch0, coords_batch1], [feats_batch0, feats_batch1]) + >>> A = ME.SparseTensor(features=feats, coordinates=coords) + >>> B = ME.SparseTensor(features=feats, coordinate_map_key=A.coordiante_map_key, coordinate_manager=A.coordinate_manager) + >>> C = ME.SparseTensor(features=feats, coordinates=coords, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE) + >>> D = ME.SparseTensor(features=feats, coordinates=coords, tensor_stride=2) + + .. warning:: + + To use the GPU-backend for coordinate management, the + :attr:`coordinates` must be a torch tensor on GPU. Applying `to(device)` + after a :attr:`MinkowskiEngine.SparseTensor` initialization with a CPU + `coordinates` will waste time and computation for creating a CPU + CoordinateMap since GPU CoordinateMap will be created from scratch. + .. 
warning:: Before MinkowskiEngine version 0.4, we put the batch indices on the last @@ -167,61 +201,47 @@ class SparseTensor(): """ def __init__( - self, - feats, - coords=None, - coords_key=None, - coords_manager=None, - force_creation=False, - allow_duplicate_coords=False, - quantization_mode=SparseTensorQuantizationMode.RANDOM_SUBSAMPLE, - memory_manager_backend: MemoryManagerBackend = None, - tensor_stride=1): + self, + features: torch.Tensor, + coordinates: torch.Tensor = None, + # optional coordinate related arguments + tensor_stride: StrideType = 1, + coordinate_map_key: CoordinateMapKey = None, + coordinate_manager: CoordinateManager = None, + quantization_mode: SparseTensorQuantizationMode = SparseTensorQuantizationMode.RANDOM_SUBSAMPLE, + # optional manager related arguments + allocator_type: GPUMemoryAllocatorType = None, + kernel_map_mode: CUDAKernelMapMode = None, + ): r""" Args: - :attr:`feats` (:attr:`torch.FloatTensor`, + :attr:`features` (:attr:`torch.FloatTensor`, :attr:`torch.DoubleTensor`, :attr:`torch.cuda.FloatTensor`, or - :attr:`torch.cuda.DoubleTensor`): The features of the sparse + :attr:`torch.cuda.DoubleTensor`): The features of a sparse tensor. - :attr:`coords` (:attr:`torch.IntTensor`): The coordinates - associated to the features. If not provided, :attr:`coords_key` + :attr:`coordinates` (:attr:`torch.IntTensor`): The coordinates + associated to the features. If not provided, :attr:`coordinate_map_key` must be provided. - :attr:`coords_key` (:attr:`MinkowskiEngine.CoordsKey`): When the - coordinates are already cached in the MinkowskiEngine, we could - reuse the same coordinates by simply providing the coordinate hash - key. In most case, this process is done automatically. When you - provide a `coords_key`, all other arguments will be be ignored. 
- - :attr:`coords_manager` (:attr:`MinkowskiEngine.CoordsManager`): The - MinkowskiEngine creates a dynamic computation graph and all - coordinates inside the same computation graph are managed by a - CoordsManager object. If not provided, the MinkowskiEngine will - create a new computation graph. In most cases, this process is - handled automatically and you do not need to use this. When you use - it, make sure you understand what you are doing. - - :attr:`force_creation` (:attr:`bool`): Force creation of the - coordinates. This allows generating a new set of coordinates even - when there exists another set of coordinates with the same - tensor stride. This could happen when you manually feed the same - :attr:`coords_manager`. - - :attr:`allow_duplicate_coords` (:attr:`bool`): Allow duplicate - coordinates when creating the sparse tensor. Internally, it will - generate a new unique set of coordinates and use features of at the - corresponding unique coordinates. In general, setting - `allow_duplicate_coords=True` is not recommended as it could hide - obvious errors in your data loading and preprocessing steps. Please - refer to the quantization and data loading tutorial on `here - `_ - for more details. - - :attr:`quantizatino_mode` - (:attr:`MinkowskiEngine.SparseTensorQuantizationMode`): Defines the - quantization method and how to define features of a sparse tensor. + :attr:`coordinate_map_key` + (:attr:`MinkowskiEngine.CoordinateMapKey`): When the coordinates + are already cached in the MinkowskiEngine, we could reuse the same + coordinate map by simply providing the coordinate map key. In most + case, this process is done automatically. When you provide a + `coordinate_map_key`, `coordinates` will be be ignored. + + :attr:`coordinate_manager` + (:attr:`MinkowskiEngine.CoordinateManager`): The MinkowskiEngine + manages all coordinate maps using the `_C.CoordinateMapManager`. If + not provided, the MinkowskiEngine will create a new computation + graph. 
In most cases, this process is handled automatically and you + do not need to use this. + + :attr:`quantization_mode` + (:attr:`MinkowskiEngine.SparseTensorQuantizationMode`): Defines how + continuous coordinates will be quantized to define a sparse tensor. Please refer to :attr:`SparseTensorQuantizationMode` for details. :attr:`tensor_stride` (:attr:`int`, :attr:`list`, @@ -229,123 +249,139 @@ def __init__( of the current sparse tensor. By default, it is 1. """ - assert isinstance(feats, - torch.Tensor), "Features must be a torch.Tensor" - assert feats.ndim == 2, f"The feature should be a matrix, The input feature is an order-{feats.ndim} tensor." + # Type checks + assert isinstance(features, torch.Tensor), "Features must be a torch.Tensor" + assert ( + features.ndim == 2 + ), f"The feature should be a matrix, The input feature is an order-{features.ndim} tensor." assert isinstance(quantization_mode, SparseTensorQuantizationMode) self.quantization_mode = quantization_mode - if coords is None and coords_key is None: - raise ValueError('Either coords or coords_key must be provided') - - if coords_key is None: - assert coords_manager is not None or coords is not None - D = -1 - if coords_manager is None: - D = coords.size(1) - 1 - else: - D = coords_manager.D - coords_key = CoordsKey(D) - coords_key.setTensorStride(convert_to_int_list(tensor_stride, D)) - else: - assert isinstance(coords_key, CoordsKey) - - if coords is not None: - assert isinstance(coords, torch.Tensor), \ - "Coordinate must be of type torch.Tensor" - - if not isinstance(coords, (torch.IntTensor, torch.cuda.IntTensor)): + if coordinates is not None: + assert isinstance(coordinates, torch.Tensor) + if coordinate_map_key is not None: + assert isinstance(coordinate_map_key, CoordinateMapKey) + if coordinate_manager is not None: + assert isinstance(coordinate_manager, CoordinateManager) + + # Coordinate Management + self.D = 0 # coordinate size - 1 + if coordinates is None and ( + coordinate_map_key is 
None or coordinate_manager is None + ): + raise ValueError( + "Either coordinates or (coordinate_map_key, coordinate_manager) pair must be provided." + ) + elif coordinates is not None: + if not isinstance(coordinates, (torch.IntTensor, torch.cuda.IntTensor)): warnings.warn( - 'Coords implicitly converted to torch.IntTensor. ' + - 'To remove this warning, use `.int()` to convert the ' + - 'coords into an torch.IntTensor') - coords = torch.floor(coords).int() - - if coords.device.type != 'cpu': - warnings.warn( - 'Coords implicitly converted to CPU type. ' + - 'To remove this warning, use `.cpu()` to convert the ' + - 'coords into a CPU type') - coords = coords.cpu() - - assert feats.shape[0] == coords.shape[0], \ - "The number of rows in features and coordinates do not match." - - coords = coords.contiguous() + "coordinates implicitly converted to torch.IntTensor. " + + "To remove this warning, use `.int()` to convert the " + + "coords into an torch.IntTensor" + ) + coordinates = torch.floor(coordinates).int() + assert ( + features.shape[0] == coordinates.shape[0] + ), "The number of rows in features and coordinates must match." 
+ self.D = coordinates.size(1) - 1 + + coordinate_map_key = CoordinateMapKey( + convert_to_int_list(tensor_stride, self.D), "" + ) + self._manager = coordinate_manager + else: + # not (coordinate_map_key is None or coordinate_manager is None) + self.D = coordinate_manager.D + coordinate_map_key = CoordinateMapKey( + convert_to_int_list(tensor_stride, self.D), "" + ) + self._manager = coordinate_manager ########################## # Setup CoordsManager ########################## - if coords_manager is None: + if coordinate_manager is None: # If set to share the coords man, use the global coords man - global _sparse_tensor_operation_mode, _global_coords_man - if _sparse_tensor_operation_mode == SparseTensorOperationMode.SHARE_COORDS_MANAGER: - if _global_coords_man is None: - _global_coords_man = CoordsManager( - memory_manager_backend=memory_manager_backend, - D=coords.size(1) - 1) - coords_manager = _global_coords_man + global _sparse_tensor_operation_mode, _global_coordinate_manager, _allocator_type + if ( + _sparse_tensor_operation_mode + == SparseTensorOperationMode.SHARE_COORDINATE_MANAGER + ): + if _global_coordinate_manager is None: + _global_coordinate_manager = CoordinateManager( + D=self.D, + coordinate_map_type=CoordinateMapType.CUDA + if coordinates.is_cuda + else CoordinateMapType.CPU, + allocator_type=allocator_type, + ) + coordinate_manager = _global_coordinate_manager else: - assert coords is not None, "Initial coordinates must be given" - coords_manager = CoordsManager(D=coords.size(1) - 1) - - else: - assert isinstance(coords_manager, CoordsManager) + coordinate_manager = CoordinateManager( + D=coordinates.size(1) - 1, + coordinate_map_type=CoordinateMapType.CUDA + if coordinates.is_cuda + else CoordinateMapType.CPU, + allocator_type=allocator_type, + kernel_map_mode=kernel_map_mode, + ) + self._manager = coordinate_manager ########################## # Initialize coords ########################## - if not coords_key.isKeySet() and coords is not None 
and len(coords) > 0: - if quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: - force_remap = True - return_inverse = False - elif quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: - force_remap = True - return_inverse = True - - self.unique_index, self.inverse_mapping = coords_manager.initialize( - coords, - coords_key, - force_creation=force_creation, - force_remap=force_remap, - allow_duplicate_coords=allow_duplicate_coords, - return_inverse=return_inverse) - - if quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: - self._CF = feats - self._CC = coords - feats = MEB.quantization_average_features( - feats, torch.arange(len(feats)), self.inverse_mapping, - len(self.unique_index), 0) - coords = coords[self.unique_index] - elif force_remap: - assert len(self.unique_index) > 0 - self._CC = coords - self._CF = feats - coords = coords[self.unique_index] - feats = feats[self.unique_index] - - elif coords is not None: # empty / invalid coords - assert isinstance(coords, torch.IntTensor) - assert coords.ndim == 2 - coords_manager.initialize( - coords, - coords_key, - force_creation=force_creation, - force_remap=False, - allow_duplicate_coords=False, - return_inverse=False) - elif coords_key is not None: - assert coords_key.isKeySet() - - self._F = feats.contiguous() - self._C = coords - self.coords_key = coords_key - self.coords_man = coords_manager + if coordinates is not None: + assert ( + features.is_cuda == coordinates.is_cuda + ), "Features and coordinates must have the same backend." 
+ ( + self.coordinate_map_key, + (unique_index, self.inverse_mapping), + ) = self._manager.insert_and_map(coordinates, *coordinate_map_key.get_key()) + self.unique_index = unique_index.long() + coordinates = coordinates[self.unique_index] + + if self.quantization_mode in [ + SparseTensorQuantizationMode.UNWEIGHTED_SUM, + SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, + ]: + N = len(features) + COO = torch.stack( + (self.inverse_mapping.long(), torch.arange(N)), 0 + ).long() + self.sp_mapping = torch.sparse.FloatTensor( + COO, + torch.ones(N), + torch.Size([len(self.unique_index), len(features)]), + ) + if ( + self.quantization_mode + == SparseTensorQuantizationMode.UNWEIGHTED_SUM + ): + features = self.sp_mapping.matmul(features) + else: + features = self.sp_mapping.matmul( + features + ) / self.sp_mapping.matmul(torch.ones(len(features), 1)) + else: + features = features[self.unique_index] + + elif coordinate_map_key is not None: + assert ( + coordinate_map_key.is_key_set() + ), "The coordinate key must be a valid key." + self.coordinate_map_key = coordinate_map_key + + self._F = features + self._C = coordinates + + @property + def coordinate_manager(self): + return self._manager @property def tensor_stride(self): - return self.coords_key.getTensorStride() + return self.coordinate_map_key.get_tensor_stride() @tensor_stride.setter def tensor_stride(self, p): @@ -353,19 +389,19 @@ def tensor_stride(self, p): This function is not recommended to be used directly. """ p = convert_to_int_list(p, self.D) - self.coords_key.setTensorStride(p) + self.coordinate_map_key.set_tensor_stride(p) - def _get_coords(self): - return self.coords_man.get_coords(self.coords_key) + def _get_coordinates(self): + return self._manager.get_coordinates(self.coordinate_map_key) @property def C(self): r"""The alias of :attr:`coords`. """ - return self.coords + return self.coordinates @property - def coords(self): + def coordinates(self): r""" The coordinates of the current sparse tensor. 
The coordinates are represented as a :math:`N \times (D + 1)` dimensional matrix where @@ -376,7 +412,7 @@ def coords(self): different instances in a batch. """ if self._C is None: - self._C = self._get_coords() + self._C = self._get_coordinates() return self._C @property @@ -388,8 +424,7 @@ def decomposed_coordinates(self): zero elements in the :math:`i`th batch index in :math:`D` dimensional space. """ - row_inds_list = self.coords_man.get_row_indices_per_batch( - self.coords_key) + row_inds_list = self._manager.get_row_indices_per_batch(self.coordinate_map_key) return [self.C[row_inds, 1:] for row_inds in row_inds_list] def coordinates_at(self, batch_index): @@ -400,8 +435,9 @@ def coordinates_at(self, batch_index): is the number of non zero elements in the :math:`i`th batch index in :math:`D` dimensional space. """ - row_inds = self.coords_man.get_row_indices_at(self.coords_key, - batch_index) + row_inds = self._manager.get_row_indices_at( + self.coordinate_map_key, batch_index + ) return self.C[row_inds, 1:] @property @@ -411,7 +447,7 @@ def F(self): return self._F @property - def feats(self): + def features(self): r""" The features of the current sparse tensor. The features are :math:`N \times D_F` where :math:`N` is the number of points in the space and @@ -429,8 +465,7 @@ def decomposed_features(self): zero elements in the :math:`i`th batch index in :math:`D` dimensional space. """ - row_inds_list = self.coords_man.get_row_indices_per_batch( - self.coords_key) + row_inds_list = self._manager.get_row_indices_per_batch(self.coordinate_map_key) return [self._F[row_inds] for row_inds in row_inds_list] def features_at(self, batch_index): @@ -441,8 +476,9 @@ def features_at(self, batch_index): zero elements in the specified batch index and :math:`N_F` is the number of channels. 
""" - row_inds = self.coords_man.get_row_indices_at(self.coords_key, - batch_index) + row_inds = self._manager.get_row_indices_at( + self.coordinate_map_key, batch_index + ) return self._F[row_inds] def coordinates_and_features_at(self, batch_index): @@ -456,8 +492,9 @@ def coordinates_and_features_at(self, batch_index): matrix :math:`N` is the number of non zero elements in the specified batch index and :math:`N_F` is the number of channels. """ - row_inds = self.coords_man.get_row_indices_at(self.coords_key, - batch_index) + row_inds = self._manager.get_row_indices_at( + self.coordinate_map_key, batch_index + ) return self.C[row_inds, 1:], self._F[row_inds] @property @@ -465,18 +502,11 @@ def decomposed_coordinates_and_features(self): r"""Returns a list of coordinates and a list of features per batch.abs """ - row_inds_list = self.coords_man.get_row_indices_per_batch( - self.coords_key) - return [self.C[row_inds, 1:] for row_inds in row_inds_list], \ - [self._F[row_inds] for row_inds in row_inds_list] - - @property - def D(self): - r""" - The spatial dimension of the sparse tensor. This is equal to the number - of columns of :attr:`C` minus 1. 
- """ - return self.coords_key.D + row_inds_list = self._manager.get_row_indices_per_batch(self.coordinate_map_key) + return ( + [self.C[row_inds, 1:] for row_inds in row_inds_list], + [self._F[row_inds] for row_inds in row_inds_list], + ) @property def dimension(self): @@ -499,16 +529,28 @@ def double(self): def set_tensor_stride(self, s): ss = convert_to_int_list(s, self.D) - self.coords_key.setTensorStride(ss) + self.coordinate_map_key.set_tensor_stride(ss) def __repr__(self): - return self.__class__.__name__ + '(' + os.linesep \ - + ' Coords=' + str(self.C) + os.linesep \ - + ' Feats=' + str(self.F) + os.linesep \ - + ' coords_key=' + str(self.coords_key) \ - + ' tensor_stride=' + str(self.coords_key.getTensorStride()) + os.linesep \ - + ' coords_man=' + str(self.coords_man) \ - + ' spatial dimension=' + str(self.D) + ')' + return ( + self.__class__.__name__ + + "(" + + os.linesep + + " coordinates=" + + str(self.C) + + os.linesep + + " features=" + + str(self.F) + + os.linesep + + " coordinate_map_key=" + + str(self.coordinate_map_key) + + os.linesep + + " coordinate_manager=" + + str(self._manager) + + " spatial dimension=" + + str(self.D) + + ")" + ) def __len__(self): return len(self._F) @@ -539,36 +581,31 @@ def dtype(self): def get_device(self): return self._F.get_device() - # Operation overloading - def __iadd__(self, other): + def _is_same_key(self, other): assert isinstance(other, SparseTensor) - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR - assert self.coords_key == other.coords_key, COORDS_KEY_DIFFERENT_ERROR + assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR + assert ( + self.coordinate_map_key == other.coordinate_map_key + ), COORDS_KEY_DIFFERENT_ERROR + # Operation overloading + def __iadd__(self, other): + self._is_same_key(other) self._F += other.F return self def __isub__(self, other): - assert isinstance(other, SparseTensor) - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR - 
assert self.coords_key == other.coords_key, COORDS_KEY_DIFFERENT_ERROR - + self._is_same_key(other) self._F -= other.F return self def __imul__(self, other): - assert isinstance(other, SparseTensor) - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR - assert self.coords_key == other.coords_key, COORDS_KEY_DIFFERENT_ERROR - + self._is_same_key(other) self._F *= other.F return self def __idiv__(self, other): - assert isinstance(other, SparseTensor) - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR - assert self.coords_key == other.coords_key, COORDS_KEY_DIFFERENT_ERROR - + self._is_same_key(other) self._F /= other.F return self @@ -582,31 +619,35 @@ def __add__(self, other): """ assert isinstance(other, (SparseTensor, torch.Tensor)) if isinstance(other, SparseTensor): - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR + assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR - if self.coords_key == other.coords_key: + if self.coordinate_map_key == other.coordinate_map_key: return SparseTensor( self._F + other.F, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) else: # Generate union maps - out_key = CoordsKey(self.coords_man.D) - ins, outs = self.coords_man.get_union_map( - (self.coords_key, other.coords_key), out_key) - N_out = self.coords_man.get_coords_size_by_coords_key(out_key) - out_F = torch.zeros((N_out, self._F.size(1)), - dtype=self.dtype, - device=self.device) + out_key = CoordsKey(self._manager.D) + ins, outs = self._manager.get_union_map( + (self.coordinate_map_key, other.coordinate_map_key), out_key + ) + N_out = self._manager.get_coords_size_by_coordinate_map_key(out_key) + out_F = torch.zeros( + (N_out, self._F.size(1)), dtype=self.dtype, device=self.device + ) out_F[outs[0]] = self._F[ins[0]] out_F[outs[1]] += other._F[ins[1]] return SparseTensor( - out_F, coords_key=out_key, 
coords_manager=self.coords_man) + out_F, coordinate_map_key=out_key, coords_manager=self._manager + ) else: # when it is a torch.Tensor return SparseTensor( self._F + other, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) def __sub__(self, other): r""" @@ -617,32 +658,36 @@ def __sub__(self, other): """ assert isinstance(other, (SparseTensor, torch.Tensor)) if isinstance(other, SparseTensor): - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR + assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR - if self.coords_key == other.coords_key: + if self.coordinate_map_key == other.coordinate_map_key: return SparseTensor( self._F - other.F, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) else: # Generate union maps - out_key = CoordsKey(self.coords_man.D) - ins, outs = self.coords_man.get_union_map( - (self.coords_key, other.coords_key), out_key) - N_out = self.coords_man.get_coords_size_by_coords_key(out_key) - out_F = torch.zeros((N_out, self._F.size(1)), - dtype=self.dtype, - device=self.device) + out_key = CoordsKey(self._manager.D) + ins, outs = self._manager.get_union_map( + (self.coordinate_map_key, other.coordinate_map_key), out_key + ) + N_out = self._manager.get_coords_size_by_coordinate_map_key(out_key) + out_F = torch.zeros( + (N_out, self._F.size(1)), dtype=self.dtype, device=self.device + ) out_F[outs[0]] = self._F[ins[0]] out_F[outs[1]] -= other._F[ins[1]] return SparseTensor( - out_F, coords_key=out_key, coords_manager=self.coords_man) + out_F, coordinate_map_key=out_key, coords_manager=self._manager + ) else: # when it is a torch.Tensor return SparseTensor( self._F - other, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) def 
__mul__(self, other): r""" @@ -654,31 +699,35 @@ def __mul__(self, other): """ assert isinstance(other, (SparseTensor, torch.Tensor)) if isinstance(other, SparseTensor): - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR + assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR - if self.coords_key == other.coords_key: + if self.coordinate_map_key == other.coordinate_map_key: return SparseTensor( self._F * other.F, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) else: # Generate union maps - out_key = CoordsKey(self.coords_man.D) - ins, outs = self.coords_man.get_union_map( - (self.coords_key, other.coords_key), out_key) - N_out = self.coords_man.get_coords_size_by_coords_key(out_key) - out_F = torch.zeros((N_out, self._F.size(1)), - dtype=self.dtype, - device=self.device) + out_key = CoordsKey(self._manager.D) + ins, outs = self._manager.get_union_map( + (self.coordinate_map_key, other.coordinate_map_key), out_key + ) + N_out = self._manager.get_coords_size_by_coordinate_map_key(out_key) + out_F = torch.zeros( + (N_out, self._F.size(1)), dtype=self.dtype, device=self.device + ) out_F[outs[0]] = self._F[ins[0]] out_F[outs[1]] *= other._F[ins[1]] return SparseTensor( - out_F, coords_key=out_key, coords_manager=self.coords_man) + out_F, coordinate_map_key=out_key, coords_manager=self._manager + ) else: # when it is a torch.Tensor return SparseTensor( self._F * other, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) def __truediv__(self, other): r""" @@ -690,37 +739,42 @@ def __truediv__(self, other): """ assert isinstance(other, (SparseTensor, torch.Tensor)) if isinstance(other, SparseTensor): - assert self.coords_man == other.coords_man, COORDS_MAN_DIFFERENT_ERROR + assert self._manager == other._manager, COORDS_MAN_DIFFERENT_ERROR - if 
self.coords_key == other.coords_key: + if self.coordinate_map_key == other.coordinate_map_key: return SparseTensor( self._F / other.F, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) else: # Generate union maps - out_key = CoordsKey(self.coords_man.D) - ins, outs = self.coords_man.get_union_map( - (self.coords_key, other.coords_key), out_key) - N_out = self.coords_man.get_coords_size_by_coords_key(out_key) - out_F = torch.zeros((N_out, self._F.size(1)), - dtype=self.dtype, - device=self.device) + out_key = CoordsKey(self._manager.D) + ins, outs = self._manager.get_union_map( + (self.coordinate_map_key, other.coordinate_map_key), out_key + ) + N_out = self._manager.get_coords_size_by_coordinate_map_key(out_key) + out_F = torch.zeros( + (N_out, self._F.size(1)), dtype=self.dtype, device=self.device + ) out_F[outs[0]] = self._F[ins[0]] out_F[outs[1]] /= other._F[ins[1]] return SparseTensor( - out_F, coords_key=out_key, coords_manager=self.coords_man) + out_F, coordinate_map_key=out_key, coords_manager=self._manager + ) else: # when it is a torch.Tensor return SparseTensor( self._F / other, - coords_key=self.coords_key, - coords_manager=self.coords_man) + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) def __power__(self, power): return SparseTensor( - self._F**power, - coords_key=self.coords_key, - coords_manager=self.coords_man) + self._F ** power, + coordinate_map_key=self.coordinate_map_key, + coords_manager=self._manager, + ) # Conversion functions def sparse(self, min_coords=None, max_coords=None, contract_coords=True): @@ -770,14 +824,14 @@ def torch_sparse_Tensor(coords, feats, size=None): elif feats.dtype == torch.float32: return torch.sparse.FloatTensor(coords, feats) else: - raise ValueError('Feature type not supported.') + raise ValueError("Feature type not supported.") else: if feats.dtype == torch.float64: return 
torch.sparse.DoubleTensor(coords, feats, size) elif feats.dtype == torch.float32: return torch.sparse.FloatTensor(coords, feats, size) else: - raise ValueError('Feature type not supported.') + raise ValueError("Feature type not supported.") # Use int tensor for all operations tensor_stride = torch.IntTensor(self.tensor_stride) @@ -792,14 +846,18 @@ def torch_sparse_Tensor(coords, feats, size=None): elif min_coords.ndim == 1: min_coords = min_coords.unsqueeze(0) - assert (min_coords % tensor_stride).sum() == 0, \ - "The minimum coordinates must be divisible by the tensor stride." + assert ( + min_coords % tensor_stride + ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride." if max_coords is not None: if max_coords.ndim == 1: max_coords = max_coords.unsqueeze(0) - assert (max_coords % tensor_stride).sum() == 0, \ + assert ( + max_coords % tensor_stride + ).sum() == 0, ( "The maximum coordinates must be divisible by the tensor stride." + ) coords -= min_coords @@ -823,11 +881,12 @@ def torch_sparse_Tensor(coords, feats, size=None): # Squeeze to make the size one-dimensional size = size.squeeze() - max_batch = max(self.coords_man.get_batch_indices()) + max_batch = max(self._manager.get_batch_indices()) size = torch.Size([max_batch + 1, *size, self.F.size(1)]) - sparse_tensor = torch_sparse_Tensor(new_coords.t().to(self.F.device), - self.F, size) + sparse_tensor = torch_sparse_Tensor( + new_coords.t().to(self.F.device), self.F, size + ) tensor_stride = torch.IntTensor(self.tensor_stride) return sparse_tensor, min_coords, tensor_stride @@ -884,14 +943,18 @@ def dense(self, min_coords=None, max_coords=None, contract_coords=True): elif min_coords.ndim == 1: min_coords = min_coords.unsqueeze(0) - assert (min_coords % tensor_stride).sum() == 0, \ - "The minimum coordinates must be divisible by the tensor stride." + assert ( + min_coords % tensor_stride + ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride." 
if max_coords is not None: if max_coords.ndim == 1: max_coords = max_coords.unsqueeze(0) - assert (max_coords % tensor_stride).sum() == 0, \ + assert ( + max_coords % tensor_stride + ).sum() == 0, ( "The maximum coordinates must be divisible by the tensor stride." + ) coords -= min_coords @@ -907,7 +970,7 @@ def dense(self, min_coords=None, max_coords=None, contract_coords=True): size = None nchannels = self.F.size(1) - max_batch = max(self.coords_man.get_batch_indices()) + max_batch = max(self._manager.get_batch_indices()) if max_coords is not None: size = max_coords - min_coords + 1 # inclusive # Squeeze to make the size one-dimensional @@ -921,9 +984,11 @@ def dense(self, min_coords=None, max_coords=None, contract_coords=True): tcoords = coords.t().long() batch_indices = batch_indices.long() - exec("dense_F[batch_indices, :, " + - ", ".join([f"tcoords[{i}]" for i in range(len(tcoords))]) + - "] = self.F") + exec( + "dense_F[batch_indices, :, " + + ", ".join([f"tcoords[{i}]" for i in range(len(tcoords))]) + + "] = self.F" + ) tensor_stride = torch.IntTensor(self.tensor_stride) return dense_F, min_coords, tensor_stride @@ -955,7 +1020,11 @@ def slice(self, X, slicing_mode=0): >>> len(outputs) == len(coords) # recovers the original ordering and length """ # Currently only supports unweighted slice. - return self.feats[X.inverse_mapping] + assert X.quantization_mode in [ + SparseTensorQuantizationMode.RANDOM_SUBSAMPLE, + SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, + ], "slice only available for sparse tensors with quantization RANDOM_SUBSAMPLE or UNWEIGHTED_AVERAGE" + return self.F[X.inverse_mapping] def features_at_coords(self, query_coords: torch.Tensor): r"""Extract features at the specified coordinate matrix. @@ -975,41 +1044,42 @@ def features_at_coords(self, query_coords: torch.Tensor): `query_feats` will be 0. 
""" - cm = self.coords_man + cm = self._manager - self_key = self.coords_key - query_key = cm.create_coords_key(query_coords) + self_key = self.coordinate_map_key + query_key = cm.create_coordinate_map_key(query_coords) self_indices, query_indices = cm.get_kernel_map( - self_key, query_key, kernel_size=1) - query_feats = torch.zeros((len(query_coords), self._F.size(1)), - dtype=self.dtype, - device=self.device) + self_key, query_key, kernel_size=1 + ) + query_feats = torch.zeros( + (len(query_coords), self._F.size(1)), dtype=self.dtype, device=self.device + ) if len(self_indices[0]) > 0: query_feats[query_indices[0]] = self._F[self_indices[0]] return query_feats, query_indices[0] -def _get_coords_key( - input: SparseTensor, - coords: Union[torch.IntTensor, CoordsKey, SparseTensor] = None, - tensor_stride: Union[Sequence, np.ndarray, torch.IntTensor] = 1): +def _get_coordinate_map_key( + input: SparseTensor, + coordinates: torch.Tensor = None, + tensor_stride: StrideType = 1, +): r"""Process coords according to its type. 
""" - if coords is not None: - assert isinstance(coords, (CoordsKey, torch.IntTensor, SparseTensor)) - if isinstance(coords, torch.IntTensor): - coords_key = input.coords_man.create_coords_key( - coords, - tensor_stride=tensor_stride, - force_creation=True, - force_remap=True, - allow_duplicate_coords=True) - elif isinstance(coords, SparseTensor): - coords_key = coords.coords_key - else: # CoordsKey type due to the previous assertion - coords_key = coords - else: - coords_key = CoordsKey(input.D) - return coords_key + if coordinates is not None: + assert isinstance(coords, (CoordinateMapKey, torch.Tensor, SparseTensor)) + if isinstance(coordinates, torch.Tensor): + coordinate_map_key = input._manager.create_coordinate_map_key( + coordinates, tensor_stride=tensor_stride + ) + elif isinstance(coordinates, SparseTensor): + coordinate_map_key = coordinates.coordinate_map_key + else: # CoordinateMapKey type due to the previous assertion + coordinate_map_key = coordinates + else: # coordinates is None + coordinate_map_key = CoordinateMapKey( + input.coordinate_map_key.get_coordinate_size() + ) + return coordinate_map_key diff --git a/MinkowskiEngine/__init__.py b/MinkowskiEngine/__init__.py index b57b5b11..26893c29 100644 --- a/MinkowskiEngine/__init__.py +++ b/MinkowskiEngine/__init__.py @@ -40,21 +40,42 @@ RegionType, ) -# from SparseTensor import SparseTensor, SparseTensorOperationMode, SparseTensorQuantizationMode, \ -# set_sparse_tensor_operation_mode, sparse_tensor_operation_mode, clear_global_coords_man +from MinkowskiKernelGenerator import ( + KernelRegion, + KernelGenerator, + convert_region_type, + get_kernel_volume, +) -# from Common import RegionType, convert_to_int_tensor, convert_region_type, \ -# MinkowskiModuleBase, KernelGenerator, GlobalPoolingMode -# -from MinkowskiCoords import ( +from MinkowskiSparseTensor import ( + SparseTensor, + SparseTensorOperationMode, + SparseTensorQuantizationMode, + set_sparse_tensor_operation_mode, + 
sparse_tensor_operation_mode, + clear_global_coordinate_mananager, +) + +from MinkowskiCommon import ( + convert_to_int_tensor, + MinkowskiModuleBase, + GlobalPoolingMode, +) + +from MinkowskiCoordinateManager import ( set_memory_manager_backend, set_gpu_allocator, CoordsManager, CoordinateManager, ) -# from MinkowskiConvolution import MinkowskiConvolutionFunction, MinkowskiConvolution, \ -# MinkowskiConvolutionTransposeFunction, MinkowskiConvolutionTranspose +from MinkowskiConvolution import ( + MinkowskiConvolutionFunction, + MinkowskiConvolution, + MinkowskiConvolutionTransposeFunction, + MinkowskiConvolutionTranspose, +) + # # from MinkowskiChannelwiseConvolution import MinkowskiChannelwiseConvolution # @@ -89,5 +110,6 @@ # # import MinkowskiFunctional # -# import MinkowskiEngine.utils as utils +import MinkowskiEngine.utils as utils + # import MinkowskiEngine.modules as modules diff --git a/MinkowskiEngine/utils/coords.py b/MinkowskiEngine/utils/coords.py index aae51884..6766e7e7 100644 --- a/MinkowskiEngine/utils/coords.py +++ b/MinkowskiEngine/utils/coords.py @@ -21,7 +21,7 @@ # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part # of the code. 
-from SparseTensor import SparseTensor +from MinkowskiSparseTensor import SparseTensor def get_coords_map(x, y): diff --git a/docs/migration_05.md b/docs/migration_05.md new file mode 100644 index 00000000..0fe7db9e --- /dev/null +++ b/docs/migration_05.md @@ -0,0 +1,94 @@ +# Migration Guide from v0.4.x to v0.5.0 + +## Summary + +```python +# 0.4 +ME.SparseTensor(feats=feats, coords=coords, D=3) +# 0.5 +ME.SparseTensor(feats=feats, coords=coords, D=3) +``` + + +``` +# 0.4 +ME.MinkowskiConvolution(..., has_bias=True) +# 0.5 +ME.MinkowskiConvolution(..., bias=True) +``` + + +``` +# 0.4 +RegionType.HYPERCUBE +# 0.5 +RegionType.HYPER_CUBE +``` + + +## Definitions + +### `CoordinateMap` + +A coordinate map refers to a map object that converts a D-dimensional +coordinate into a row index for a feature matrix where the corresponding +feature for the coordinate is located. This can be implemented using +`std::map`, `std::unordered_map`, or a hash table with a suitable hash function +and equality function. + +### `CoordinateKey` + +A `CoordinateKey` or `CoordinateMapKey` refers to a unique identifier that can +be used to retrieve a `CoordinateMap`. + +### `tensor_stride` + +A tensor stride is the minimum distance between non-zero elements in a sparse +tensor. If we take a stride-2 convolution on a sparse tensor with tensor +stride 1, the resulting sparse tensor will have tensor stride 2. If we apply +two stride-2 convolutions on a sparse tensor with tensor stride 3, the +resulting sparse tensor will have tensor stride 2 x 2 x 3 = 12. + +## From CoordsKey to CoordinateMapKey + +In most cases you never need to use CoordsKey directly, but in the rare cases where +you did, please review this section to update your code. + +One of the major differences is that we expose the pybind11 object directly to +the Python side to remove the redundant abstraction layer.
+ +In v0.4, Minkowski Engine uses a `uint64_t` hash key to identify a +`CoordinateMap`, but from v0.5, we use a tensor stride and a string identifier. + + +## From CoordsManager to CoordinateManager + +CoordinateManager should not be called in most cases, but if you do, please review this section. + + +### Initialization + +```python +# 0.4.x +manager = CoordsManager(D=3) +# 0.5.x +manager = CoordinateManager(D=3) +``` + +## Initializing a new CoordinateMap + +```python +# 0.4.x +manager = CoordsManager(D = 3) +manager.initialize(torch.IntTens + def initialize(self, + coords: torch.IntTensor, + coords_key: CoordsKey, + force_creation: bool = False, + force_remap: bool = False, + allow_duplicate_coords: bool = False, + return_inverse: bool = False) -> torch.LongTensor: +``` + + +## Consistent Layer Arguments diff --git a/pybind/extern.hpp b/pybind/extern.hpp index cffa18bc..96906b24 100644 --- a/pybind/extern.hpp +++ b/pybind/extern.hpp @@ -103,6 +103,29 @@ std::pair ConvolutionBackwardGPU( gpu_manager_type *p_map_manager); #endif +/************************************* + * Quantization + *************************************/ +/* +template +std::vector +quantize_np(py::array_t coords); + +vector quantize_label_np( + py::array_t coords, + py::array_t labels, + int invalid_label); + +template vector quantize_th(at::Tensor coords); + +vector quantize_label_th(at::Tensor coords, at::Tensor labels, + int invalid_label); + +at::Tensor quantization_average_features(at::Tensor in_feat, at::Tensor in_map, + at::Tensor out_map, int out_nrows, + int mode); +*/ + } // end namespace minkowski namespace py = pybind11; @@ -113,7 +136,7 @@ void instantiate_cpu_func(py::module &m, const std::string &dtypestr) { &minkowski::ConvolutionForwardCPU, py::call_guard()); - m.def((std::string("ConvolutionForwardCPU") + dtypestr).c_str(), + m.def((std::string("ConvolutionBackwardCPU") + dtypestr).c_str(), &minkowski::ConvolutionBackwardCPU, py::call_guard()); @@ -264,25 +287,62 @@ void instantiate_gpu_func(py::module &m, const
std::string &dtypestr) { TemplatedAllocator>, py::call_guard()); - m.def((std::string("ConvolutionForwardGPU") + dtypestr).c_str(), + m.def((std::string("ConvolutionBackwardGPU") + dtypestr).c_str(), &minkowski::ConvolutionBackwardGPU, py::call_guard()); } #endif -template class TemplatedAllocator, - template class A> - class CoordinateMapType> -void instantiate_manager(py::module &m, const std::string &dtypestr) { - using manager_type = - minkowski::CoordinateMapManager; +void initialize_non_templated_classes(py::module &m) { + // Enums + py::enum_( + m, "GPUMemoryAllocatorType") + .value("PYTORCH", minkowski::GPUMemoryAllocatorBackend::Type::PYTORCH) + .value("CUDA", minkowski::GPUMemoryAllocatorBackend::Type::CUDA) + .export_values(); + + py::enum_(m, "CUDAKernelMapMode") + .value("MEMORY_EFFICIENT", + minkowski::CUDAKernelMapMode::Mode::MEMORY_EFFICIENT) + .value("SPEED_OPTIMIZED", + minkowski::CUDAKernelMapMode::Mode::SPEED_OPTIMIZED) + .export_values(); + + py::enum_(m, "CoordinateMapType") + .value("CPU", minkowski::CoordinateMapBackend::Type::CPU) + .value("CUDA", minkowski::CoordinateMapBackend::Type::CUDA) + .export_values(); + py::enum_(m, "RegionType") + .value("HYPER_CUBE", minkowski::RegionType::Type::HYPER_CUBE) + .value("HYPER_CROSS", minkowski::RegionType::Type::HYPER_CROSS) + .value("CUSTOM", minkowski::RegionType::Type::CUSTOM) + .export_values(); + + // Classes + py::class_(m, "CoordinateMapKey") + .def(py::init()) + .def(py::init()) + .def("__repr__", &minkowski::CoordinateMapKey::to_string) + .def("is_key_set", &minkowski::CoordinateMapKey::is_key_set) + .def("get_coordinate_size", + &minkowski::CoordinateMapKey::get_coordinate_size) + .def("get_key", &minkowski::CoordinateMapKey::get_key) + .def("set_key", (void (minkowski::CoordinateMapKey::*)( + minkowski::default_types::stride_type, std::string)) & + minkowski::CoordinateMapKey::set_key) + .def("get_tensor_stride", + &minkowski::CoordinateMapKey::get_tensor_stride); +} + +template +void 
instantiate_manager(py::module &m, const std::string &dtypestr) { py::class_( m, (std::string("CoordinateMapManager") + dtypestr).c_str()) .def(py::init<>()) + .def(py::init()) // TODO .def("insert", &manager_type::insert) .def("insert_and_map", &manager_type::insert_and_map) .def("stride", diff --git a/pybind/minkowski.cpp b/pybind/minkowski.cpp index 3011d2ee..7c4fee7c 100644 --- a/pybind/minkowski.cpp +++ b/pybind/minkowski.cpp @@ -37,38 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Constant function m.def("is_cuda_available", &is_cuda_available); - // Enums - py::enum_( - m, "GPUMemoryAllocatorType") - .value("PYTORCH", minkowski::GPUMemoryAllocatorBackend::Type::PYTORCH) - .value("CUDA", minkowski::GPUMemoryAllocatorBackend::Type::CUDA) - .export_values(); - - py::enum_(m, "CoordinateMapType") - .value("CPU", minkowski::CoordinateMapBackend::Type::CPU) - .value("CUDA", minkowski::CoordinateMapBackend::Type::CUDA) - .export_values(); - - py::enum_(m, "RegionType") - .value("HYPER_CUBE", minkowski::RegionType::Type::HYPER_CUBE) - .value("HYPER_CROSS", minkowski::RegionType::Type::HYPER_CROSS) - .value("CUSTOM", minkowski::RegionType::Type::CUSTOM) - .export_values(); - - // Classes - py::class_(m, "CoordinateMapKey") - .def(py::init()) - .def(py::init()) - .def("__repr__", &minkowski::CoordinateMapKey::to_string) - .def("is_key_set", &minkowski::CoordinateMapKey::is_key_set) - .def("get_coordinate_size", - &minkowski::CoordinateMapKey::get_coordinate_size) - .def("get_key", &minkowski::CoordinateMapKey::get_key) - .def("set_key", (void (minkowski::CoordinateMapKey::*)( - minkowski::default_types::stride_type, std::string)) & - minkowski::CoordinateMapKey::set_key) - .def("get_tensor_stride", - &minkowski::CoordinateMapKey::get_tensor_stride); + initialize_non_templated_classes(m); // Manager instantiate_manager( diff --git a/pybind/minkowski.cu b/pybind/minkowski.cu index fc1c5b7d..7bc66296 100644 --- a/pybind/minkowski.cu +++ b/pybind/minkowski.cu @@ 
-34,62 +34,61 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // Enums - py::enum_(m, "GPUMemoryAllocator") - .value("PYTORCH", minkowski::GPUMemoryAllocatorBackend::Type::PYTORCH) - .value("CUDA", minkowski::GPUMemoryAllocatorBackend::Type::CUDA) - .export_values(); + // Constant function + m.def("is_cuda_available", &is_cuda_available); - py::enum_(m, "CoordinateMap") - .value("CPU", minkowski::CoordinateMapBackend::Type::CPU) - .value("PYTORCH", minkowski::CoordinateMapBackend::Type::CUDA) - .export_values(); + initialize_non_templated_classes(m); - py::enum_(m, "RegionType") - .value("HYPER_CUBE", minkowski::RegionType::Type::HYPER_CUBE) - .value("HYPER_CROSS", minkowski::RegionType::Type::HYPER_CROSS) - .value("CUSTOM", minkowski::RegionType::Type::CUSTOM) - .export_values(); + /* + py::class_>(m, "CoordinateMapManagerCPU") + .def(py::init<>()) + .def(py::init()) + .def("insert_and_map", &minkowski::cpu_manager_type::insert_and_map) + .def("kernel_map", &minkowski::cpu_manager_type::kernel_map); - // Classes - py::class_(m, "CoordinateMapKey") - .def(py::init()) - .def(py::init()) - .def("__repr__", &minkowski::CoordinateMapKey::to_string) - .def("is_key_set", &minkowski::CoordinateMapKey::is_key_set) - .def("get_coordinate_size", - &minkowski::CoordinateMapKey::get_coordinate_size) - .def("get_key", &minkowski::CoordinateMapKey::get_key) - .def("set_key", (void (minkowski::CoordinateMapKey::*)( - minkowski::default_types::stride_type, std::string)) & - minkowski::CoordinateMapKey::set_key) - .def("get_tensor_stride", - &minkowski::CoordinateMapKey::get_tensor_stride); + py::class_>( + m, "CoordinateMapManagerGPU_c10") + .def(py::init<>()) + .def(py::init()) + .def("insert_and_map", + &minkowski::gpu_c10_manager_type::insert_and_map) + .def("stride", + (typename py::object // + (minkowski::gpu_c10_manager_type::*)( + minkowski::CoordinateMapKey const *in_map_key, + minkowski::default_types::stride_type const &kernel_stride)) & + 
minkowski::gpu_c10_manager_type::stride) + .def("size", (typename minkowski::default_types::size_type // + (minkowski::gpu_c10_manager_type::*)( + minkowski::CoordinateMapKey const *map_key) const) & + minkowski::gpu_c10_manager_type::size) + .def("kernel_map", &minkowski::gpu_c10_manager_type::kernel_map); + */ // Manager - instantiate_manager( - m, std::string("CPU")); + instantiate_manager>(m, + std::string("CPU")); #ifndef CPU_ONLY - instantiate_manager(m, - std::string("GPU_default")); - instantiate_manager(m, std::string("GPU_c10")); + instantiate_manager>( + m, std::string("GPU_default")); + instantiate_manager>( + m, std::string("GPU_c10")); #endif - // Functions instantiate_cpu_func(m, std::string("f")); instantiate_cpu_func(m, std::string("d")); #ifndef CPU_ONLY instantiate_gpu_func( - m, std::string("fd")); + m, std::string("f")); instantiate_gpu_func( - m, std::string("dd")); + m, std::string("d")); instantiate_gpu_func( - m, std::string("fc")); + m, std::string("f")); instantiate_gpu_func( - m, std::string("dc")); + m, std::string("d")); #endif } diff --git a/setup.py b/setup.py index 6820a2fe..1c5d66f7 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,9 @@ def _argparse(pattern, argv, is_flag=True): print("--------------------------------") Extension = CppExtension else: + print("--------------------------------") + print("| CUDA compilation set |") + print("--------------------------------") # system python installation libraries.append("cusparse") diff --git a/src/convolution_cpu.cpp b/src/convolution_cpu.cpp index 9e3fb871..71753ebe 100644 --- a/src/convolution_cpu.cpp +++ b/src/convolution_cpu.cpp @@ -74,6 +74,7 @@ ConvolutionForwardCPU(at::Tensor const &in_feat, // : torch::kFloat64); torch::checkDim(c, arg_in_feat, 2); + torch::checkDim(c, arg_kernel, 3); ASSERT(in_feat.size(1) == kernel.size(1), "Input feature size and kernel size mismatch"); diff --git a/src/convolution_gpu.cu b/src/convolution_gpu.cu index e6079411..8b84035e 100644 --- 
a/src/convolution_gpu.cu +++ b/src/convolution_gpu.cu @@ -78,6 +78,7 @@ at::Tensor ConvolutionForwardGPU( : torch::kFloat64); torch::checkDim(c, arg_in_feat, 2); + torch::checkDim(c, arg_kernel, 3); ASSERT(in_feat.size(1) == kernel.size(1), "Input feature size and kernel size mismatch"); diff --git a/src/coordinate_map_cpu.hpp b/src/coordinate_map_cpu.hpp index 733d2d4c..fbfba221 100644 --- a/src/coordinate_map_cpu.hpp +++ b/src/coordinate_map_cpu.hpp @@ -118,13 +118,13 @@ class CoordinateMapCPU : public CoordinateMap @@ -357,17 +357,19 @@ class CoordinateMapCPU : public CoordinateMapfirst.data(), m_coordinate_size, dst_coordinate + m_coordinate_size * it->second); diff --git a/src/coordinate_map_gpu.cu b/src/coordinate_map_gpu.cu index 6203df5c..88f7724d 100644 --- a/src/coordinate_map_gpu.cu +++ b/src/coordinate_map_gpu.cu @@ -261,7 +261,7 @@ stride_copy(coordinate_type const *__restrict__ src_coordinates, // size_type const *__restrict__ stride, // coordinate_type *__restrict__ dst_coordinates, // size_type const num_threads, size_type const coordinate_size) { - extern __shared__ coordinate_type sh_stride[]; + extern __shared__ size_type sh_stride[]; auto const tx = threadIdx.x; auto const bx = blockIdx.x; @@ -276,9 +276,12 @@ stride_copy(coordinate_type const *__restrict__ src_coordinates, // dst_coordinates[dst_start] = src_coordinates[src_start]; for (index_type j = 1; j < coordinate_size; ++j) { dst_coordinates[dst_start + j] = - ((coordinate_type)floorf( - __fdiv_rd(src_coordinates[src_start + j], sh_stride[j - 1]))) * + (__float2int_rd(__fdiv_rd(src_coordinates[src_start + j], + sh_stride[j - 1]))) * sh_stride[j - 1]; + // (__double2int_rd( + // __ddiv_rn(src_coordinates[src_start + j], sh_stride[j - 1]))) * + // sh_stride[j - 1]; } } } @@ -311,8 +314,7 @@ CoordinateMapGPU::stride( auto const num_blocks = GET_BLOCKS(num_threads, CUDA_NUM_THREADS); detail::stride_copy - <<>>( + <<>>( const_coordinate_data(), 
thrust::raw_pointer_cast(m_valid_row_index.data()), thrust::raw_pointer_cast(stride_map.m_device_tensor_stride.data()), diff --git a/src/coordinate_map_key.hpp b/src/coordinate_map_key.hpp index a20c138a..6ecb4f78 100644 --- a/src/coordinate_map_key.hpp +++ b/src/coordinate_map_key.hpp @@ -133,7 +133,9 @@ class CoordinateMapKey { // misc functions std::string to_string() const { Formatter out; - out << "coordinate map key:" << m_key.first << ":" << m_key.second; + out << "coordinate map key:" << m_key.first; + if (m_key.second.length() > 0) + out << ":" << m_key.second; return out; } diff --git a/src/coordinate_map_manager.cpp b/src/coordinate_map_manager.cpp index 05c0d717..7783a2c2 100644 --- a/src/coordinate_map_manager.cpp +++ b/src/coordinate_map_manager.cpp @@ -183,15 +183,22 @@ struct insert_and_map_functor(); + for (default_types::index_type i = 0; i < mapping.size(); ++i) { + p_mapping[i] = mapping[i]; + } - std::copy_n(mapping.begin(), mapping.size(), - th_mapping.data_ptr()); - std::copy_n(inverse_mapping.begin(), inverse_mapping.size(), - th_inverse_mapping.data_ptr()); + int64_t *p_inverse_mapping = th_inverse_mapping.data_ptr(); + for (default_types::index_type i = 0; i < inverse_mapping.size(); ++i) { + p_inverse_mapping[i] = inverse_mapping[i]; + } return std::make_pair(std::move(th_mapping), std::move(th_inverse_mapping)); } @@ -1017,8 +1024,13 @@ CoordinateMapManager:: auto const ncols = map.coordinate_size(); // CPU torch.IntTensor - at::Tensor coordinates = torch::empty( - {(long)nrows, (long)ncols}, torch::TensorOptions().dtype(torch::kInt)); + auto options = torch::TensorOptions().dtype(torch::kInt).requires_grad(false); + if (!detail::is_cpu_coordinate_map::value) { + int device_id; + CUDA_CHECK(cudaGetDevice(&device_id)); + options = options.device(torch::kCUDA, device_id); + } + at::Tensor coordinates = torch::empty({(long)nrows, (long)ncols}, options); // copy to the out coords map.copy_coordinates(coordinates.template data_ptr()); diff 
--git a/src/coordinate_map_manager.cu b/src/coordinate_map_manager.cu index 4b1db942..92652f89 100644 --- a/src/coordinate_map_manager.cu +++ b/src/coordinate_map_manager.cu @@ -37,9 +37,9 @@ namespace minkowski { namespace detail { -template -__global__ void dtypeCopy(SrcType const *src, DstType *dst, size_t n) { - CUDA_KERNEL_LOOP(index, n) { dst[index] = src[index]; } +template +__global__ void cuda_copy_n(src_type const *src, uint32_t N, dst_type *dst) { + CUDA_KERNEL_LOOP(index, N) { dst[index] = src[index]; } } template (), - thrust::raw_pointer_cast(mapping.data()), - mapping.size() * sizeof(default_types::index_type), - cudaMemcpyDeviceToDevice)); - CUDA_CHECK( - cudaMemcpy(th_inverse_mapping.data_ptr(), - thrust::raw_pointer_cast(inverse_mapping.data()), - inverse_mapping.size() * sizeof(default_types::index_type), - cudaMemcpyDeviceToDevice)); + // TODO int64_t + at::Tensor th_mapping = torch::empty( + {(int64_t)mapping.size()}, + th_coordinate.options().requires_grad(false).dtype(torch::kInt64)); + at::Tensor th_inverse_mapping = torch::empty( + {(int64_t)inverse_mapping.size()}, + th_coordinate.options().requires_grad(false).dtype(torch::kInt64)); + + auto const num_blocks = + (mapping.size() + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; + + detail::cuda_copy_n + <<>>( + thrust::raw_pointer_cast(mapping.data()), mapping.size(), + th_mapping.data_ptr()); + + auto const num_inv_blocks = + (inverse_mapping.size() + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; + + detail::cuda_copy_n + <<>>( + thrust::raw_pointer_cast(inverse_mapping.data()), + inverse_mapping.size(), th_inverse_mapping.data_ptr()); return std::make_pair(std::move(th_mapping), std::move(th_inverse_mapping)); } diff --git a/src/math_functions.cu b/src/math_functions.cu index 5f96fe6a..a51b6a41 100644 --- a/src/math_functions.cu +++ b/src/math_functions.cu @@ -1,22 +1,22 @@ /* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
* - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies - * of the Software, and to permit persons to whom the Software is furnished to do - * so, subject to the following conditions: + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
* * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part diff --git a/src/quantization.cpp b/src/quantization.cpp index a1060bf7..c83abd32 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -39,6 +39,7 @@ namespace py = pybind11; namespace minkowski { +/* struct IndexLabel { int index; int label; @@ -230,6 +231,7 @@ template InOutMaps CopyToInOutMap(at::Tensor th_map) { return vec_map; } +*/ #ifndef CPU_ONLY template pInOutMaps CopyToInOutMapGPU(at::Tensor th_map) { diff --git a/tests/cpp/convolution_cpu_test.py b/tests/cpp/convolution_cpu_test.py index 1ddbf730..cccb456c 100644 --- a/tests/cpp/convolution_cpu_test.py +++ b/tests/cpp/convolution_cpu_test.py @@ -87,7 +87,7 @@ def test_pcd(self): IC, OC = 3, 16 coords, colors, pcd = load_file("1.ply") kernel_size = [3, 3, 3] - kernel_stride = [1, 1, 1] + kernel_stride = [2, 2, 2] kernel_dilation = [1, 1, 1] # size, in, out @@ -129,3 +129,53 @@ def test_pcd(self): min_time = min(time.time() - stime, min_time) print(f"{batch_size}\t{voxel_size}\t{manager.size(in_key)}\t{min_time}") + + def test_pcd2(self): + IC, OC = 3, 16 + coords, colors, pcd = load_file("1.ply") + kernel_size = [3, 3, 3] + kernel_stride = [2, 2, 2] + kernel_dilation = [1, 1, 1] + + for IC in [3, 8, 16, 32, 64, 128]: + for OC in [16, 32, 64, 128, 256]: + # size, in, out + kernel = torch.rand(np.prod(kernel_size), IC, OC) + for batch_size in [1]: + for voxel_size in [0.02]: + + min_time = 100000 + + dcoords = torch.from_numpy(np.floor(coords / voxel_size)).int() + bcoords = batched_coordinates( + [dcoords for i in range(batch_size)] + ) + + for i in range(10): + manager = _C.CoordinateMapManager() + + # batch insert + in_key, (unique_map, inverse_map) = manager.insert_and_map( + bcoords, [1, 1, 1], "" + ) + in_feats = torch.rand(manager.size(in_key), IC) + out_key = _C.CoordinateMapKey(4) + + stime = time.time() + out_features = 
_C.ConvolutionForwardCPUf( + in_feats, + kernel, + kernel_size, + kernel_stride, + kernel_dilation, + _C.RegionType.HYPER_CUBE, + torch.IntTensor(), + in_key, + out_key, + manager, + ) + min_time = min(time.time() - stime, min_time) + + print( + f"{batch_size}\t{manager.size(in_key)}\t{manager.size(out_key)}\t{IC}\t{OC}\t{min_time}" + ) diff --git a/tests/cpp/convolution_gpu_test.py b/tests/cpp/convolution_gpu_test.py index ea81464d..1bae19b1 100644 --- a/tests/cpp/convolution_gpu_test.py +++ b/tests/cpp/convolution_gpu_test.py @@ -102,9 +102,6 @@ def test_pcd(self): dcoords = torch.from_numpy(np.floor(coords / voxel_size)).int() bcoords = batched_coordinates([dcoords for i in range(batch_size)]) - tcolors = torch.from_numpy(colors).float() - bcolors = torch.cat([tcolors for i in range(batch_size)]).to(0) - for i in range(10): manager = _C.CoordinateMapManager() @@ -112,12 +109,12 @@ def test_pcd(self): in_key, (unique_map, inverse_map) = manager.insert_and_map( bcoords.to(0), [1, 1, 1], "" ) - ucolors = bcolors[unique_map.long()] + in_feats = torch.rand(manager.size(in_key), IC).to(0) out_key = _C.CoordinateMapKey(4) stime = time.time() out_features = _C.ConvolutionForwardGPUf( - ucolors, + in_feats, kernel, kernel_size, kernel_stride, @@ -130,4 +127,56 @@ def test_pcd(self): ) min_time = min(time.time() - stime, min_time) - print(f"{batch_size}\t{voxel_size}\t{manager.size(in_key)}\t{min_time}") + print( + f"{batch_size}\t{manager.size(in_key)}\t{manager.size(out_key)}\t{min_time}" + ) + + def test_pcd2(self): + IC, OC = 128, 128 + coords, colors, pcd = load_file("1.ply") + kernel_size = [3, 3, 3] + kernel_stride = [2, 2, 2] + kernel_dilation = [1, 1, 1] + + for IC in [3, 8, 16, 32, 64, 128]: + for OC in [16, 32, 64, 128, 256]: + # size, in, out + kernel = torch.rand(np.prod(kernel_size), IC, OC).to(0) + for batch_size in [1]: + for voxel_size in [0.02]: + + min_time = 100000 + + dcoords = torch.from_numpy(np.floor(coords / voxel_size)).int() + bcoords = 
batched_coordinates( + [dcoords for i in range(batch_size)] + ) + + for i in range(10): + manager = _C.CoordinateMapManager() + + # batch insert + in_key, (unique_map, inverse_map) = manager.insert_and_map( + bcoords.to(0), [1, 1, 1], "" + ) + in_feats = torch.rand(manager.size(in_key), IC).to(0) + out_key = _C.CoordinateMapKey(4) + + stime = time.time() + out_features = _C.ConvolutionForwardGPUf( + in_feats, + kernel, + kernel_size, + kernel_stride, + kernel_dilation, + _C.RegionType.HYPER_CUBE, + torch.IntTensor(), + in_key, + out_key, + manager, + ) + min_time = min(time.time() - stime, min_time) + + print( + f"{batch_size}\t{manager.size(in_key)}\t{manager.size(out_key)}\t{IC}\t{OC}\t{min_time}" + ) diff --git a/tests/cpp/kernel_region_cpu_test.cpp b/tests/cpp/kernel_region_cpu_test.cpp index 0e8557a4..65b03bcb 100644 --- a/tests/cpp/kernel_region_cpu_test.cpp +++ b/tests/cpp/kernel_region_cpu_test.cpp @@ -71,7 +71,7 @@ region_iterator_test(const torch::Tensor &coordinates, } auto region = cpu_kernel_region( - REGION_TYPE::HYPER_CUBE, D, tensor_stride.data(), s_kernel_size.data(), + RegionType::HYPER_CUBE, D, tensor_stride.data(), s_kernel_size.data(), dilation.data()); std::vector lb(D), ub(D); @@ -102,7 +102,7 @@ kernel_map_test(const torch::Tensor &in_coordinates, torch::TensorArg arg_out_coordinates(out_coordinates, "coordinates", 1); torch::TensorArg arg_kernel_size(kernel_size, "kernel_size", 2); - torch::CheckedFrom c = "region_iterator_test"; + torch::CheckedFrom c = "kernel_map_test"; torch::checkContiguous(c, arg_in_coordinates); torch::checkContiguous(c, arg_out_coordinates); torch::checkContiguous(c, arg_kernel_size); @@ -133,18 +133,13 @@ kernel_map_test(const torch::Tensor &in_coordinates, auto in_coordinate_range = coordinate_range(N_in, D, ptr); simple_range iter_in{N_in}; - in_map.insert(in_coordinate_range.begin(), // key begin - in_coordinate_range.end(), // key end - iter_in.begin(), // value begin - iter_in.end()); // value end + 
in_map.insert(ptr, + ptr + N_in * D); auto out_coordinate_range = coordinate_range(N_out, D, ptr_out); simple_range iter_out{N_out}; - out_map.insert(out_coordinate_range.begin(), // key begin - out_coordinate_range.end(), // key end - iter_out.begin(), // value begin - iter_out.end()); // value end + out_map.insert(ptr_out, ptr_out + N_out * D); LOG_DEBUG("coordinate initialization"); @@ -161,7 +156,7 @@ kernel_map_test(const torch::Tensor &in_coordinates, LOG_DEBUG("kernel_region initialization"); auto region = cpu_kernel_region( - REGION_TYPE::HYPER_CUBE, D, tensor_stride.data(), s_kernel_size.data(), + RegionType::HYPER_CUBE, D, tensor_stride.data(), s_kernel_size.data(), dilation.data()); timer t; diff --git a/tests/cpp/kernel_region_cpu_test.py b/tests/cpp/kernel_region_cpu_test.py index b91408d1..93292fe7 100644 --- a/tests/cpp/kernel_region_cpu_test.py +++ b/tests/cpp/kernel_region_cpu_test.py @@ -10,6 +10,16 @@ class KernelRegionTestCase(unittest.TestCase): def test(self): + coordinates = torch.IntTensor( + [[0, 1, -1], [0, 1, 0], [0, 1, 1], [0, 2, -1], [0, 2, 0], [0, 2, 1]] + ) + kernel_size = torch.IntTensor([3, 3]) + + (in_maps, out_maps), N, t = MinkowskiEngineTest._C.kernel_map_test( + coordinates, coordinates, kernel_size + ) + + def test2(self): coordinates = torch.IntTensor([[0, 1, -1], [0, 2, 1]]) kernel_size = torch.IntTensor([3, 3]) @@ -75,6 +85,18 @@ def test_even3(self): self.assertEqual(regions[10], [0, 1, 0, 4]) self.assertEqual(regions[11], [0, 2, 0, 4]) + def test_kernel_map1(self): + in_coordinates = torch.IntTensor([[0, 1, -1], [0, 2, 1]]) + out_coordinates = torch.IntTensor([[0, 1, -1], [0, 2, 1], [1, 2, 1]]) + kernel_size = torch.IntTensor([1, 1]) + + (in_maps, out_maps), num, t = MinkowskiEngineTest._C.kernel_map_test( + in_coordinates, out_coordinates, kernel_size + ) + + self.assertEqual(in_maps[0], [0, 1]) + self.assertEqual(out_maps[0], [0, 1]) + def test_kernel_map(self): in_coordinates = torch.IntTensor([[0, 1, -1], [0, 2, 1]]) 
out_coordinates = torch.IntTensor([[0, 1, 0], [0, 1, 2], [1, 2, 1]]) @@ -87,8 +109,6 @@ def test_kernel_map(self): in_maps = kernel_map[0] out_maps = kernel_map[1] self.assertEqual(len(in_maps), torch.prod(kernel_size).item()) - print(in_maps) - print(out_maps) self.assertEqual(in_maps[1], [0]) self.assertEqual(out_maps[1], [0]) diff --git a/tests/python/conv.py b/tests/python/convolution.py similarity index 51% rename from tests/python/conv.py rename to tests/python/convolution.py index 20cb81b0..102d9d83 100644 --- a/tests/python/conv.py +++ b/tests/python/convolution.py @@ -1,4 +1,5 @@ -# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# Copyright (c) 2020 NVIDIA CORPORATION. +# Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu). # # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in @@ -24,15 +25,21 @@ import torch import unittest -from MinkowskiEngine import SparseTensor, MinkowskiConvolution, MinkowskiConvolutionFunction, \ - MinkowskiConvolutionTranspose, MinkowskiConvolutionTransposeFunction +import MinkowskiEngineBackend._C as _C -from tests.common import data_loader +from MinkowskiEngine import ( + SparseTensor, + MinkowskiConvolution, + MinkowskiConvolutionFunction, + MinkowskiConvolutionTranspose, + MinkowskiConvolutionTransposeFunction, +) + +from tests.python.common import data_loader from utils.gradcheck import gradcheck class TestConvolution(unittest.TestCase): - def test_gpu(self): print(f"{self.__class__.__name__}: test_gpu") if not torch.cuda.is_available(): @@ -41,26 +48,23 @@ def test_gpu(self): coords, feats, labels = data_loader(in_channels) feats = feats.double() feats.requires_grad_() - input = SparseTensor(feats, coords=coords) + # Initialize context conv = MinkowskiConvolution( - in_channels, - out_channels, - kernel_size=3, - stride=2, - has_bias=True, - dimension=D) + in_channels, out_channels, kernel_size=3, 
stride=2, bias=True, dimension=D + ) + print(conv) + input = SparseTensor(feats, coordinates=coords) conv = conv.double() output = conv(input) print(output) - device = torch.device('cuda') - input = input.to(device) + device = torch.device("cuda") + input = SparseTensor(feats.to(device), coordinates=coords.to(device)) conv = conv.to(device) output = conv(input) print(output) - print(output.F, output.coords) # Check backward fn = MinkowskiConvolutionFunction() @@ -70,10 +74,18 @@ def test_gpu(self): output.F.backward(grad) self.assertTrue( - gradcheck(fn, (input.F, conv.kernel, input.tensor_stride, - conv.stride, conv.kernel_size, conv.dilation, - conv.region_type_, conv.region_offset_, - input.coords_key, None, input.coords_man))) + gradcheck( + fn, + ( + input.F, + conv.weight, + conv.kernel_generator, + input.coordinate_map_key, + None, + input.coordinate_manager, + ), + ) + ) def test(self): print(f"{self.__class__.__name__}: test") @@ -81,61 +93,74 @@ def test(self): coords, feats, labels = data_loader(in_channels) feats = feats.double() feats.requires_grad_() - input = SparseTensor(feats, coords=coords) + input = SparseTensor(feats, coordinates=coords) # Initialize context conv = MinkowskiConvolution( - in_channels, - out_channels, - kernel_size=3, - stride=2, - has_bias=True, - dimension=D) + in_channels, out_channels, kernel_size=3, stride=2, bias=True, dimension=D + ) conv = conv.double() output = conv(input) print(output) - kernel_map = input.coords_man.get_kernel_map( - 1, 2, stride=2, kernel_size=3) - print(kernel_map) + # kernel_map = input.coords_man.get_kernel_map( + # 1, 2, stride=2, kernel_size=3) + # print(kernel_map) # Check backward fn = MinkowskiConvolutionFunction() self.assertTrue( - gradcheck(fn, (input.F, conv.kernel, input.tensor_stride, - conv.stride, conv.kernel_size, conv.dilation, - conv.region_type_, conv.region_offset_, - input.coords_key, None, input.coords_man))) + gradcheck( + fn, + ( + input.F, + conv.weight, + 
conv.kernel_generator, + input.coordinate_map_key, + None, + input.coordinate_manager, + ), + ) + ) class TestConvolutionTranspose(unittest.TestCase): - def test_gpu(self): print(f"{self.__class__.__name__}: test_gpu") if not torch.cuda.is_available(): return - device = torch.device('cuda') + device = torch.device("cuda") in_channels, out_channels, D = 2, 3, 2 coords, feats, labels = data_loader(in_channels) feats = feats.double() feats.requires_grad_() input = SparseTensor(feats, coords=coords).to(device) # Initialize context - conv = MinkowskiConvolution( - in_channels, - out_channels, - kernel_size=3, - stride=2, - has_bias=True, - dimension=D).double().to(device) - conv_tr = MinkowskiConvolutionTranspose( - out_channels, - in_channels, - kernel_size=3, - stride=2, - has_bias=True, - dimension=D).double().to(device) + conv = ( + MinkowskiConvolution( + in_channels, + out_channels, + kernel_size=3, + stride=2, + bias=True, + dimension=D, + ) + .double() + .to(device) + ) + conv_tr = ( + MinkowskiConvolutionTranspose( + out_channels, + in_channels, + kernel_size=3, + stride=2, + bias=True, + dimension=D, + ) + .double() + .to(device) + ) tr_input = conv(input) print(tr_input) output = conv_tr(tr_input) @@ -145,11 +170,24 @@ def test_gpu(self): fn = MinkowskiConvolutionTransposeFunction() self.assertTrue( - gradcheck(fn, - (tr_input.F, conv_tr.kernel, tr_input.tensor_stride, - conv_tr.stride, conv_tr.kernel_size, conv_tr.dilation, - conv_tr.region_type_, conv_tr.region_offset_, False, - tr_input.coords_key, None, tr_input.coords_man))) + gradcheck( + fn, + ( + tr_input.F, + conv_tr.kernel, + tr_input.tensor_stride, + conv_tr.stride, + conv_tr.kernel_size, + conv_tr.dilation, + conv_tr.region_type_, + conv_tr.region_offset_, + False, + tr_input.coords_key, + None, + tr_input.coords_man, + ), + ) + ) def test(self): print(f"{self.__class__.__name__}: test") @@ -161,37 +199,38 @@ def test(self): # Initialize context conv = MinkowskiConvolution( - in_channels, - 
out_channels, - kernel_size=3, - stride=2, - has_bias=True, - dimension=D).double() + in_channels, out_channels, kernel_size=3, stride=2, bias=True, dimension=D + ).double() conv_tr = MinkowskiConvolutionTranspose( - out_channels, - in_channels, - kernel_size=2, - stride=2, - has_bias=True, - dimension=D).double() - - print('Initial input: ', input) + out_channels, in_channels, kernel_size=2, stride=2, bias=True, dimension=D + ).double() + + print("Initial input: ", input) input = conv(input) - print('Conv output: ', input) + print("Conv output: ", input) output = conv_tr(input) - print('Conv tr output: ', output) + print("Conv tr output: ", output) # Check backward fn = MinkowskiConvolutionTransposeFunction() self.assertTrue( - gradcheck(fn, - (input.F, conv_tr.kernel, input.tensor_stride, - conv_tr.stride, conv_tr.kernel_size, conv_tr.dilation, - conv_tr.region_type_, conv_tr.region_offset_, False, - input.coords_key, None, input.coords_man))) - - -if __name__ == '__main__': - unittest.main() + gradcheck( + fn, + ( + input.F, + conv_tr.kernel, + input.tensor_stride, + conv_tr.stride, + conv_tr.kernel_size, + conv_tr.dilation, + conv_tr.region_type_, + conv_tr.region_offset_, + False, + input.coords_key, + None, + input.coords_man, + ), + ) + ) diff --git a/tests/python/coordinate_manager.py b/tests/python/coordinate_manager.py new file mode 100644 index 00000000..874ed0f7 --- /dev/null +++ b/tests/python/coordinate_manager.py @@ -0,0 +1,127 @@ +# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. 
+import unittest + +import torch +import numpy as np + +import MinkowskiEngine as ME + + +class CoordinateManagerTestCase(unittest.TestCase): + def test_coordinate_manager(self): + + coordinates = torch.IntTensor( + [[0, 1], [0, 1], [0, 2], [0, 2], [1, 0], [1, 0], [1, 1]] + ) + + manager = ME.CoordinateManager( + D=1, coordinate_map_type=ME.CoordinateMapType.CPU + ) + key, (unique_map, inverse_map) = manager.insert_and_map(coordinates, [1]) + + # mapping and inverse mapping should recover the original coordinates + self.assertTrue( + torch.all(coordinates[unique_map.long()][inverse_map.long()] == coordinates) + ) + + # copied coordinates should retrieve the original coordinates + retrieved_coordinates = manager.get_coordinates(key) + self.assertTrue( + torch.all(coordinates == retrieved_coordinates[inverse_map.long()]) + ) + + # Create a strided map + stride_key = manager.stride(key, [4]) + strided_coords = manager.get_coordinates(stride_key) + self.assertTrue(len(strided_coords) == 2) + + # # Create a transposed stride map + # transposed_key = cm.transposed_stride(stride_key, [2], [3], [1]) + # print("Transposed Stride: ", cm.get_coords(transposed_key)) + # print(cm) + + # # Create a transposed stride map + # transposed_key = cm.transposed_stride( + # stride_key, [2], [3], [1], force_creation=True + # ) + # print("Forced Transposed Stride: ", cm.get_coords(transposed_key)) + # print(cm) + + # # Create a reduction map + # key = cm.reduce() + # print("Reduction: ", cm.get_coords(key)) + # print(cm) + + # print("Reduction mapping: ", cm.get_row_indices_per_batch(stride_key)) + # print(cm) + + def test_negative_coords(self): + coords = torch.IntTensor( + [[0, -3], [0, -2], [0, -1], [0, 0], [0, 1], [0, 2], [0, 3]] + ) + + # Initialize map + manager = ME.CoordinateManager( + D=1, coordinate_map_type=ME.CoordinateMapType.CPU + ) + key, (unique_map, inverse_map) = manager.insert_and_map(coords, [1]) + + # Create a strided map + stride_key = manager.stride(key, [2]) + 
strided_coords = manager.get_coordinates(stride_key).numpy().tolist() + self.assertTrue(len(strided_coords) == 4) + self.assertTrue([0, -4] in strided_coords) + self.assertTrue([0, -2] in strided_coords) + self.assertTrue([0, 2] in strided_coords) + + # def test_batch_size_initialize(self): + # cm = CoordsManager(D=1) + # coords = torch.IntTensor( + # [[0, -3], [0, -2], [0, -1], [0, 0], [1, 1], [1, 2], [1, 3]] + # ) + + # # key with batch_size 2 + # cm.create_coords_key(coords) + # self.assertTrue(cm.get_batch_size() == 2) + + # coords = torch.IntTensor( + # [[0, -3], [0, -2], [0, -1], [0, 0], [0, 1], [0, 2], [0, 3]] + # ) + # cm.create_coords_key(coords) + + # self.assertTrue(cm.get_batch_size() == 2) + + def test_gpu_allocator(self): + # Set the global GPU memory manager backend. By default PYTORCH. + ME.set_gpu_allocator(ME.GPUMemoryAllocatorType.PYTORCH) + ME.set_gpu_allocator(ME.GPUMemoryAllocatorType.CUDA) + + # Create a coords man with the specified GPU memory manager backend. + # No effect with CPU_ONLY build + manager = ME.CoordinateManager( + D=1, + coordinate_map_type=ME.CoordinateMapType.CPU, + allocator_type=ME.GPUMemoryAllocatorType.CUDA, + ) diff --git a/tests/python/sparse_tensor.py b/tests/python/sparse_tensor.py index 26cc0ed4..04175751 100644 --- a/tests/python/sparse_tensor.py +++ b/tests/python/sparse_tensor.py @@ -28,35 +28,32 @@ SparseTensorQuantizationMode, set_sparse_tensor_operation_mode) -from tests.common import data_loader +from tests.python.common import data_loader -class Test(unittest.TestCase): - +class SparseTensorTestCase(unittest.TestCase): def test(self): print(f"{self.__class__.__name__}: test SparseTensor") coords, feats, labels = data_loader(nchannel=2) - input = SparseTensor(feats, coords=coords) + input = SparseTensor(feats, coordinates=coords) print(input) def test_empty(self): print(f"{self.__class__.__name__}: test_empty SparseTensor") feats = torch.FloatTensor(0, 16) coords = torch.IntTensor(0, 4) - input = 
SparseTensor(feats, coords=coords) + input = SparseTensor(feats, coordinates=coords) print(input) def test_force_creation(self): print(f"{self.__class__.__name__}: test_force_creation") coords, feats, labels = data_loader(nchannel=2) - input1 = SparseTensor(feats, coords=coords, tensor_stride=1) + input1 = SparseTensor(feats, coordinates=coords) input2 = SparseTensor( feats, - coords=coords, - tensor_stride=1, - coords_manager=input1.coords_man, - force_creation=True) - print(input2) + coordinates=coords, + coordinate_manager=input1.coordinate_manager) + print(input1.coordinate_map_key, input2.coordinate_map_key) def test_duplicate_coords(self): print(f"{self.__class__.__name__}: test_duplicate_coords") @@ -64,13 +61,11 @@ def test_duplicate_coords(self): # create duplicate coords coords[0] = coords[1] coords[2] = coords[3] - input = SparseTensor(feats, coords=coords, allow_duplicate_coords=True) + input = SparseTensor(feats, coordinates=coords) self.assertTrue(len(input) == len(coords) - 2) - print(coords) - print(input) input = SparseTensor( feats, - coords=coords, + coordinates=coords, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE) self.assertTrue(len(coords) == 16) self.assertTrue(len(input) == 14) @@ -81,14 +76,14 @@ def test_duplicate_coords(self): feats = torch.FloatTensor([[0, 1, 2, 3, 5, 6, 7]]).T # 0.5, 2.5, 5.5, 7 sinput = SparseTensor( - coords=coords, - feats=feats, + coordinates=coords, + features=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE) self.assertTrue(len(sinput) == 4) - self.assertTrue(0.5 in sinput.feats) - self.assertTrue(2.5 in sinput.feats) - self.assertTrue(5.5 in sinput.feats) - self.assertTrue(7 in sinput.feats) + self.assertTrue(0.5 in sinput.features) + self.assertTrue(2.5 in sinput.features) + self.assertTrue(5.5 in sinput.features) + self.assertTrue(7 in sinput.features) self.assertTrue(len(sinput.slice(sinput)) == len(coords)) def test_extraction(self): @@ -123,18 +118,17 @@ def 
test_extraction(self): def test_operation_mode(self): # Set to use the global sparse tensor coords manager by default set_sparse_tensor_operation_mode( - SparseTensorOperationMode.SHARE_COORDS_MANAGER) + SparseTensorOperationMode.SHARE_COORDINATE_MANAGER) coords, feats, labels = data_loader(nchannel=2) # Create a sparse tensor on two different coordinates. - A = SparseTensor(torch.rand(feats.shape), coords, force_creation=True) + A = SparseTensor(torch.rand(feats.shape), coordinates=coords) B = SparseTensor( torch.rand(4, 2), - torch.IntTensor([[0, 0, 0], [1, 1, 1], [0, 1, 0], [1, 0, 1]]), - force_creation=True) + coordinates=torch.IntTensor([[0, 0, 0], [1, 1, 1], [0, 1, 0], [1, 0, 1]])) - self.assertTrue(A.coords_man == B.coords_man) + self.assertTrue(A.coordinate_manager == B.coordinate_manager) A.requires_grad_(True) B.requires_grad_(True) @@ -156,7 +150,3 @@ def test_operation_mode(self): A -= D A *= D A /= D - - -if __name__ == '__main__': - unittest.main()