
Commit

transposed pool cpu/gpu
chrischoy committed Dec 30, 2020
1 parent 547e229 commit d0614b6
Showing 10 changed files with 641 additions and 412 deletions.
181 changes: 71 additions & 110 deletions MinkowskiEngine/MinkowskiPooling.py
@@ -27,7 +27,7 @@
import torch
from torch.autograd import Function

from MinkowskiEngineBackend._C import CoordinateMapKey, RegionType, PoolingMode
from MinkowskiEngineBackend._C import CoordinateMapKey, PoolingMode
from MinkowskiSparseTensor import SparseTensor, _get_coordinate_map_key
from MinkowskiCoordinateManager import CoordinateManager
from MinkowskiKernelGenerator import KernelGenerator, save_ctx
@@ -434,85 +434,76 @@ def __init__(
)


class MinkowskiPoolingTransposeFunction(Function):
class MinkowskiLocalPoolingTransposeFunction(Function):
@staticmethod
def forward(
ctx,
input_features,
tensor_stride=1,
stride=1,
kernel_size=-1,
dilation=1,
region_type=-1,
region_offset=None,
average=False,
in_coords_key=None,
out_coords_key=None,
coords_manager=None,
input_features: torch.Tensor,
pooling_mode: PoolingMode,
kernel_generator: KernelGenerator,
in_coordinate_map_key: CoordinateMapKey,
out_coordinate_map_key: CoordinateMapKey = None,
coordinate_manager: CoordinateManager = None,
):
assert isinstance(region_type, RegionType)
if out_coords_key is None:
out_coords_key = CoordsKey(in_coords_key.D)
assert in_coords_key.D == out_coords_key.D
tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
tensor_stride, stride, kernel_size, dilation, region_type, in_coords_key.D
)

if region_offset is None:
region_offset = torch.IntTensor()
if out_coordinate_map_key is None:
out_coordinate_map_key = CoordinateMapKey(
in_coordinate_map_key.get_coordinate_size()
)

ctx.in_feat = input_features
out_feat = input_features.new()
ctx.num_nonzero = input_features.new()
input_features = input_features.contiguous()
ctx.input_features = input_features
ctx = save_ctx(
ctx,
tensor_stride,
stride,
kernel_size,
dilation,
region_type,
in_coords_key,
out_coords_key,
coords_manager,
kernel_generator,
in_coordinate_map_key,
out_coordinate_map_key,
coordinate_manager,
)
D = in_coords_key.D
fw_fn = get_minkowski_function("PoolingTransposeForward", input_features)
fw_fn(
ctx.in_feat,
out_feat,
ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
convert_to_int_list(ctx.dilation, D),
region_type,
region_offset,
ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_man.CPPCoordsManager,
ctx.pooling_mode = pooling_mode

fw_fn = get_minkowski_function("LocalPoolingTransposeForward", input_features)
out_feat, num_nonzero = fw_fn(
ctx.input_features,
kernel_generator.kernel_size,
kernel_generator.kernel_stride,
kernel_generator.kernel_dilation,
kernel_generator.region_type,
kernel_generator.region_offsets,
kernel_generator.expand_coordinates,
pooling_mode,
ctx.in_coordinate_map_key,
ctx.out_coordinate_map_key,
ctx.coordinate_manager._manager,
)
ctx.num_nonzero = num_nonzero
return out_feat

@staticmethod
def backward(ctx, grad_out_feat):
grad_in_feat = grad_out_feat.new()
D = ctx.in_coords_key.D
bw_fn = get_minkowski_function("PoolingTransposeBackward", grad_out_feat)
bw_fn(
ctx.in_feat,
grad_in_feat,
grad_out_feat = grad_out_feat.contiguous()
bw_fn = get_minkowski_function("LocalPoolingTransposeBackward", grad_out_feat)
grad_in_feat = bw_fn(
ctx.input_features,
grad_out_feat,
ctx.num_nonzero,
convert_to_int_list(ctx.tensor_stride, D),
convert_to_int_list(ctx.stride, D),
convert_to_int_list(ctx.kernel_size, D),
convert_to_int_list(ctx.dilation, D),
ctx.region_type,
ctx.in_coords_key.CPPCoordsKey,
ctx.out_coords_key.CPPCoordsKey,
ctx.coords_man.CPPCoordsManager,
ctx.kernel_generator.kernel_size,
ctx.kernel_generator.kernel_stride,
ctx.kernel_generator.kernel_dilation,
ctx.kernel_generator.region_type,
ctx.kernel_generator.region_offsets,
ctx.pooling_mode,
ctx.in_coordinate_map_key,
ctx.out_coordinate_map_key,
ctx.coordinate_manager._manager,
)
return (
grad_in_feat,
None,
None,
None,
None,
None,
)
return grad_in_feat, None, None, None, None, None, None, None, None, None, None
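
A minimal sketch of how the rewritten Function can be invoked directly, with the argument order taken from the forward() signature above; the PoolingMode member and the SparseTensor attribute names are assumptions based on the 0.5 API, not the exact wiring inside MinkowskiPoolingBase.forward:

# Hedged sketch: calling the rewritten autograd Function directly.
# PoolingMode.LOCAL_SUM_POOLING and the SparseTensor attributes are assumptions.
import MinkowskiEngine as ME
from MinkowskiEngineBackend._C import PoolingMode

def unpool_features(x: ME.SparseTensor, kernel_generator: ME.KernelGenerator):
    # Passing None lets forward() allocate a fresh CoordinateMapKey for the output.
    return ME.MinkowskiLocalPoolingTransposeFunction.apply(
        x.F,                      # input_features (made contiguous inside forward)
        PoolingMode.LOCAL_SUM_POOLING,
        kernel_generator,
        x.coordinate_map_key,     # in_coordinate_map_key
        None,                     # out_coordinate_map_key
        x.coordinate_manager,     # coordinate_manager
    )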


class MinkowskiPoolingTranspose(MinkowskiPoolingBase):
@@ -523,7 +514,7 @@ class MinkowskiPoolingTranspose(MinkowskiPoolingBase):
"""

def __init__(
self, kernel_size, stride, dilation=1, kernel_generator=None, dimension=None
self, kernel_size, stride, dilation=1, kernel_generator=None, expand_coordinates=False, dimension=None
):
r"""a high-dimensional unpooling layer for sparse tensors.
@@ -547,66 +538,36 @@ def __init__(
:attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`,
optional): define custom kernel shape.
:attr:`expand_coordinates` (bool, optional): Force generation of
new coordinates. When True, the output coordinates will be the
outer product of the kernel shape and the input coordinates.
`False` by default.
:attr:`dimension` (int): the spatial dimension of the space where
all the inputs and the network are defined. For example, images are
in a 2D space, meshes and 3D shapes are in a 3D space.
"""
is_transpose = True
if kernel_generator is None:
kernel_generator = KernelGenerator(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
expand_coordinates=expand_coordinates,
dimension=dimension,
)

MinkowskiPoolingBase.__init__(
self,
kernel_size,
stride,
dilation,
kernel_generator,
is_transpose,
average=False,
dimension=dimension,
)
self.pooling = MinkowskiPoolingTransposeFunction()

def forward(
self,
input: SparseTensor,
coords: Union[torch.IntTensor, CoordinateMapKey, SparseTensor] = None,
):
r"""
:attr:`input` (`MinkowskiEngine.SparseTensor`): Input sparse tensor to apply a
convolution on.
:attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`,
`MinkowskiEngine.SparseTensor`), optional): If provided, generate
results on the provided coordinates. None by default.
"""
assert isinstance(input, SparseTensor)
assert input.D == self.dimension

# Create a region_offset
self.region_type_, self.region_offset_, _ = self.kernel_generator.get_kernel(
input.tensor_stride, self.is_transpose
)

# Get a new coords key or extract one from the coords
out_coords_key = _get_coords_key(input, coords)

output = self.pooling.apply(
input.F,
input.tensor_stride,
self.stride,
self.kernel_size,
self.dilation,
self.region_type_,
self.region_offset_,
self.average,
input.coords_key,
out_coords_key,
input.coords_man,
)

return SparseTensor(
output, coords_key=out_coords_key, coords_manager=input.coords_man
)
self.pooling = MinkowskiLocalPoolingTransposeFunction()
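
A minimal usage sketch of the updated module, assuming the MinkowskiEngine 0.5 SparseTensor API; the coordinates and feature values below are illustrative only:

# Hedged sketch: stride-2 pooling followed by the transposed (un)pooling layer.
# Coordinates carry a leading batch index.
import torch
import MinkowskiEngine as ME

coords = torch.IntTensor([[0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 3, 2]])  # (batch, x, y)
feats = torch.rand(coords.shape[0], 8)
x = ME.SparseTensor(features=feats, coordinates=coords)

pool = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=2)
unpool = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=2)

y = pool(x)    # tensor_stride 1 -> 2
z = unpool(y)  # tensor_stride 2 -> 1; features scattered back over the kernel support
print(y.tensor_stride, z.tensor_stride)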


class MinkowskiGlobalPoolingFunction(Function):
4 changes: 2 additions & 2 deletions MinkowskiEngine/__init__.py
@@ -108,8 +108,8 @@
MinkowskiSumPooling,
MinkowskiAvgPooling,
MinkowskiMaxPooling,
# MinkowskiPoolingTransposeFunction,
# MinkowskiPoolingTranspose,
MinkowskiLocalPoolingTransposeFunction,
MinkowskiPoolingTranspose,
MinkowskiGlobalPoolingFunction,
MinkowskiGlobalPooling,
MinkowskiGlobalSumPooling,
4 changes: 4 additions & 0 deletions examples/multigpu.py
@@ -86,6 +86,10 @@ def generate_input(file_name, voxel_size):
num_devices = torch.cuda.device_count()
num_devices = min(config.max_ngpu, num_devices)
devices = list(range(num_devices))
print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
print("' WARNING: This example is deprecated. '"
print("' Please use DistributedDataParallel or pytorch-lightning'")
print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
print(
f"Testing {num_devices} GPUs. Total batch size: {num_devices * config.batch_size}"
)
80 changes: 80 additions & 0 deletions pybind/extern.hpp
@@ -243,6 +243,71 @@ at::Tensor LocalPoolingBackwardGPU(
gpu_manager_type<coordinate_type, TemplatedAllocator> *p_map_manager);
#endif

/*************************************
* Local Pooling Transpose
*************************************/
template <typename coordinate_type>
std::pair<at::Tensor, at::Tensor> LocalPoolingTransposeForwardCPU(
at::Tensor const &in_feat,
default_types::stride_type const &kernel_size, //
default_types::stride_type const &kernel_stride, //
default_types::stride_type const &kernel_dilation, //
RegionType::Type const region_type, //
at::Tensor const &offset, //
bool generate_new_coordinates, //
PoolingMode::Type pooling_mode, //
CoordinateMapKey *p_in_map_key, //
CoordinateMapKey *p_out_map_key, //
cpu_manager_type<coordinate_type> *p_map_manager);

template <typename coordinate_type>
at::Tensor LocalPoolingTransposeBackwardCPU(
at::Tensor const &in_feat, //
at::Tensor const &grad_out_feat, //
at::Tensor const &num_nonzero, //
default_types::stride_type const &kernel_size, //
default_types::stride_type const &kernel_stride, //
default_types::stride_type const &kernel_dilation, //
RegionType::Type const region_type, //
at::Tensor const &offset, //
PoolingMode::Type pooling_mode, //
CoordinateMapKey *p_in_map_key, //
CoordinateMapKey *p_out_map_key, //
cpu_manager_type<coordinate_type> *p_map_manager);

#ifndef CPU_ONLY
template <typename coordinate_type,
template <typename C> class TemplatedAllocator>
std::pair<at::Tensor, at::Tensor> LocalPoolingTransposeForwardGPU(
at::Tensor const &in_feat,
default_types::stride_type const &kernel_size, //
default_types::stride_type const &kernel_stride, //
default_types::stride_type const &kernel_dilation, //
RegionType::Type const region_type, //
at::Tensor const &offset, //
bool generate_new_coordinates, //
PoolingMode::Type pooling_mode, //
CoordinateMapKey *p_in_map_key, //
CoordinateMapKey *p_out_map_key, //
gpu_manager_type<coordinate_type, TemplatedAllocator> *p_map_manager);

template <typename coordinate_type,
template <typename C> class TemplatedAllocator>
at::Tensor LocalPoolingTransposeBackwardGPU(
at::Tensor const &in_feat, //
at::Tensor const &grad_out_feat, //
at::Tensor const &num_nonzero, //
default_types::stride_type const &kernel_size, //
default_types::stride_type const &kernel_stride, //
default_types::stride_type const &kernel_dilation, //
RegionType::Type const region_type, //
at::Tensor const &offset, //
PoolingMode::Type pooling_mode, //
CoordinateMapKey *p_in_map_key, //
CoordinateMapKey *p_out_map_key, //
gpu_manager_type<coordinate_type, TemplatedAllocator> *p_map_manager);
#endif

/*************************************
* Global Pooling
*************************************/
@@ -461,6 +526,13 @@ void instantiate_cpu_func(py::module &m, const std::string &dtypestr) {
&minkowski::LocalPoolingBackwardCPU<coordinate_type>,
py::call_guard<py::gil_scoped_release>());

m.def((std::string("LocalPoolingTransposeForwardCPU") + dtypestr).c_str(),
&minkowski::LocalPoolingTransposeForwardCPU<coordinate_type>,
py::call_guard<py::gil_scoped_release>());
m.def((std::string("LocalPoolingTransposeBackwardCPU") + dtypestr).c_str(),
&minkowski::LocalPoolingTransposeBackwardCPU<coordinate_type>,
py::call_guard<py::gil_scoped_release>());

m.def((std::string("GlobalPoolingForwardCPU") + dtypestr).c_str(),
&minkowski::GlobalPoolingForwardCPU<coordinate_type>,
py::call_guard<py::gil_scoped_release>());
@@ -521,6 +593,14 @@ void instantiate_gpu_func(py::module &m, const std::string &dtypestr) {
&minkowski::LocalPoolingBackwardGPU<coordinate_type, TemplatedAllocator>,
py::call_guard<py::gil_scoped_release>());

m.def((std::string("LocalPoolingTransposeForwardGPU") + dtypestr).c_str(),
&minkowski::LocalPoolingTransposeForwardGPU<coordinate_type, TemplatedAllocator>,
py::call_guard<py::gil_scoped_release>());
m.def(
(std::string("LocalPoolingTransposeBackwardGPU") + dtypestr).c_str(),
&minkowski::LocalPoolingTransposeBackwardGPU<coordinate_type, TemplatedAllocator>,
py::call_guard<py::gil_scoped_release>());

m.def(
(std::string("GlobalPoolingForwardGPU") + dtypestr).c_str(),
&minkowski::GlobalPoolingForwardGPU<coordinate_type, TemplatedAllocator>,
2 changes: 2 additions & 0 deletions setup.py
@@ -211,6 +211,7 @@ def _argparse(pattern, argv, is_flag=True):
"convolution_cpu.cpp",
"convolution_transpose_cpu.cpp",
"local_pooling_cpu.cpp",
"local_pooling_transpose_cpu.cpp",
"global_pooling_cpu.cpp",
"broadcast_cpu.cpp",
"pruning_cpu.cpp",
@@ -233,6 +234,7 @@ def _argparse(pattern, argv, is_flag=True):
"pooling_avg_kernel.cu",
"pooling_max_kernel.cu",
"local_pooling_gpu.cu",
"local_pooling_transpose_gpu.cu",
"global_pooling_gpu.cu",
"broadcast_kernel.cu",
"broadcast_gpu.cu",
6 changes: 3 additions & 3 deletions src/local_pooling_gpu.cu
@@ -58,7 +58,7 @@ std::pair<at::Tensor, at::Tensor> LocalPoolingForwardGPU(
gpu_manager_type<coordinate_type, TemplatedAllocator> *p_map_manager) {

ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous");
ASSERT(in_feat.is_cuda(), "in_feat must be CPU");
ASSERT(in_feat.is_cuda(), "in_feat must be on CUDA");
ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim());

coordinate_map_key_type in_key = p_in_map_key->get_key();
@@ -154,8 +154,8 @@ at::Tensor LocalPoolingBackwardGPU(
ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous");
ASSERT(grad_out_feat.is_contiguous(), "grad_out_feat must be contiguous");

ASSERT(in_feat.is_cuda(), "in_feat must be CPU");
ASSERT(grad_out_feat.is_cuda(), "in_feat must be CPU");
ASSERT(in_feat.is_cuda(), "in_feat must be on CUDA");
ASSERT(grad_out_feat.is_cuda(), "in_feat must be on CUDA");

ASSERT(in_feat.scalar_type() == grad_out_feat.scalar_type(), "type mismatch");

