From d0614b6b450937710aab2c82a2dae21ce95dd57f Mon Sep 17 00:00:00 2001 From: chrischoy Date: Wed, 30 Dec 2020 13:45:02 +0900 Subject: [PATCH] transposed pool cpu/gpu --- MinkowskiEngine/MinkowskiPooling.py | 181 ++++++++----------- MinkowskiEngine/__init__.py | 4 +- examples/multigpu.py | 4 + pybind/extern.hpp | 80 +++++++++ setup.py | 2 + src/local_pooling_gpu.cu | 6 +- src/local_pooling_transpose_cpu.cpp | 198 +++++++++++++++++++++ src/local_pooling_transpose_gpu.cu | 256 +++++++++++++++++++++++++++ src/pooling_transpose.cpp | 261 ---------------------------- tests/python/pool.py | 61 +++---- 10 files changed, 641 insertions(+), 412 deletions(-) create mode 100644 src/local_pooling_transpose_cpu.cpp create mode 100644 src/local_pooling_transpose_gpu.cu delete mode 100644 src/pooling_transpose.cpp diff --git a/MinkowskiEngine/MinkowskiPooling.py b/MinkowskiEngine/MinkowskiPooling.py index b1b3228a..c9386ca2 100644 --- a/MinkowskiEngine/MinkowskiPooling.py +++ b/MinkowskiEngine/MinkowskiPooling.py @@ -27,7 +27,7 @@ import torch from torch.autograd import Function -from MinkowskiEngineBackend._C import CoordinateMapKey, RegionType, PoolingMode +from MinkowskiEngineBackend._C import CoordinateMapKey, PoolingMode from MinkowskiSparseTensor import SparseTensor, _get_coordinate_map_key from MinkowskiCoordinateManager import CoordinateManager from MinkowskiKernelGenerator import KernelGenerator, save_ctx @@ -434,85 +434,76 @@ def __init__( ) -class MinkowskiPoolingTransposeFunction(Function): +class MinkowskiLocalPoolingTransposeFunction(Function): @staticmethod def forward( ctx, - input_features, - tensor_stride=1, - stride=1, - kernel_size=-1, - dilation=1, - region_type=-1, - region_offset=None, - average=False, - in_coords_key=None, - out_coords_key=None, - coords_manager=None, + input_features: torch.Tensor, + pooling_mode: PoolingMode, + kernel_generator: KernelGenerator, + in_coordinate_map_key: CoordinateMapKey, + out_coordinate_map_key: CoordinateMapKey = None, + coordinate_manager: CoordinateManager = None, ): - assert isinstance(region_type, RegionType) - if out_coords_key is None: - out_coords_key = CoordsKey(in_coords_key.D) - assert in_coords_key.D == out_coords_key.D - tensor_stride, stride, kernel_size, dilation, region_type = prep_args( - tensor_stride, stride, kernel_size, dilation, region_type, in_coords_key.D - ) - - if region_offset is None: - region_offset = torch.IntTensor() + if out_coordinate_map_key is None: + out_coordinate_map_key = CoordinateMapKey( + in_coordinate_map_key.get_coordinate_size() + ) - ctx.in_feat = input_features - out_feat = input_features.new() - ctx.num_nonzero = input_features.new() + input_features = input_features.contiguous() + ctx.input_features = input_features ctx = save_ctx( ctx, - tensor_stride, - stride, - kernel_size, - dilation, - region_type, - in_coords_key, - out_coords_key, - coords_manager, + kernel_generator, + in_coordinate_map_key, + out_coordinate_map_key, + coordinate_manager, ) - D = in_coords_key.D - fw_fn = get_minkowski_function("PoolingTransposeForward", input_features) - fw_fn( - ctx.in_feat, - out_feat, - ctx.num_nonzero, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), - region_type, - region_offset, - ctx.in_coords_key.CPPCoordsKey, - ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager, + ctx.pooling_mode = pooling_mode + + fw_fn = get_minkowski_function("LocalPoolingTransposeForward", 
input_features) + out_feat, num_nonzero = fw_fn( + ctx.input_features, + kernel_generator.kernel_size, + kernel_generator.kernel_stride, + kernel_generator.kernel_dilation, + kernel_generator.region_type, + kernel_generator.region_offsets, + kernel_generator.expand_coordinates, + pooling_mode, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager._manager, ) + ctx.num_nonzero = num_nonzero return out_feat @staticmethod def backward(ctx, grad_out_feat): - grad_in_feat = grad_out_feat.new() - D = ctx.in_coords_key.D - bw_fn = get_minkowski_function("PoolingTransposeBackward", grad_out_feat) - bw_fn( - ctx.in_feat, - grad_in_feat, + grad_out_feat = grad_out_feat.contiguous() + bw_fn = get_minkowski_function("LocalPoolingTransposeBackward", grad_out_feat) + grad_in_feat = bw_fn( + ctx.input_features, grad_out_feat, ctx.num_nonzero, - convert_to_int_list(ctx.tensor_stride, D), - convert_to_int_list(ctx.stride, D), - convert_to_int_list(ctx.kernel_size, D), - convert_to_int_list(ctx.dilation, D), - ctx.region_type, - ctx.in_coords_key.CPPCoordsKey, - ctx.out_coords_key.CPPCoordsKey, - ctx.coords_man.CPPCoordsManager, + ctx.kernel_generator.kernel_size, + ctx.kernel_generator.kernel_stride, + ctx.kernel_generator.kernel_dilation, + ctx.kernel_generator.region_type, + ctx.kernel_generator.region_offsets, + ctx.pooling_mode, + ctx.in_coordinate_map_key, + ctx.out_coordinate_map_key, + ctx.coordinate_manager._manager, + ) + return ( + grad_in_feat, + None, + None, + None, + None, + None, ) - return grad_in_feat, None, None, None, None, None, None, None, None, None, None class MinkowskiPoolingTranspose(MinkowskiPoolingBase): @@ -523,7 +514,7 @@ class MinkowskiPoolingTranspose(MinkowskiPoolingBase): """ def __init__( - self, kernel_size, stride, dilation=1, kernel_generator=None, dimension=None + self, kernel_size, stride, dilation=1, kernel_generator=None, expand_coordinates=False, dimension=None ): r"""a high-dimensional unpooling layer for sparse tensors. @@ -547,12 +538,26 @@ def __init__( :attr:`kernel_generator` (:attr:`MinkowskiEngine.KernelGenerator`, optional): define custom kernel shape. + :attr:`expand_coordinates` (bool, optional): Force generation of + new coordinates. When True, the output coordinates will be the + outer product of the kernel shape and the input coordinates. + `False` by default. + :attr:`dimension` (int): the spatial dimension of the space where all the inputs and the network are defined. For example, images are in a 2D space, meshes and 3D shapes are in a 3D space. """ is_transpose = True + if kernel_generator is None: + kernel_generator = KernelGenerator( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + expand_coordinates=expand_coordinates, + dimension=dimension, + ) + MinkowskiPoolingBase.__init__( self, kernel_size, @@ -560,53 +565,9 @@ def __init__( dilation, kernel_generator, is_transpose, - average=False, dimension=dimension, ) - self.pooling = MinkowskiPoolingTransposeFunction() - - def forward( - self, - input: SparseTensor, - coords: Union[torch.IntTensor, CoordinateMapKey, SparseTensor] = None, - ): - r""" - :attr:`input` (`MinkowskiEngine.SparseTensor`): Input sparse tensor to apply a - convolution on. - - :attr:`coords` ((`torch.IntTensor`, `MinkowskiEngine.CoordsKey`, - `MinkowskiEngine.SparseTensor`), optional): If provided, generate - results on the provided coordinates. None by default. 
-
-        """
-        assert isinstance(input, SparseTensor)
-        assert input.D == self.dimension
-
-        # Create a region_offset
-        self.region_type_, self.region_offset_, _ = self.kernel_generator.get_kernel(
-            input.tensor_stride, self.is_transpose
-        )
-
-        # Get a new coords key or extract one from the coords
-        out_coords_key = _get_coords_key(input, coords)
-
-        output = self.pooling.apply(
-            input.F,
-            input.tensor_stride,
-            self.stride,
-            self.kernel_size,
-            self.dilation,
-            self.region_type_,
-            self.region_offset_,
-            self.average,
-            input.coords_key,
-            out_coords_key,
-            input.coords_man,
-        )
-
-        return SparseTensor(
-            output, coords_key=out_coords_key, coords_manager=input.coords_man
-        )
+        self.pooling = MinkowskiLocalPoolingTransposeFunction()


 class MinkowskiGlobalPoolingFunction(Function):
diff --git a/MinkowskiEngine/__init__.py b/MinkowskiEngine/__init__.py
index d97125db..027954f7 100644
--- a/MinkowskiEngine/__init__.py
+++ b/MinkowskiEngine/__init__.py
@@ -108,8 +108,8 @@
     MinkowskiSumPooling,
     MinkowskiAvgPooling,
     MinkowskiMaxPooling,
-    # MinkowskiPoolingTransposeFunction,
-    # MinkowskiPoolingTranspose,
+    MinkowskiLocalPoolingTransposeFunction,
+    MinkowskiPoolingTranspose,
     MinkowskiGlobalPoolingFunction,
     MinkowskiGlobalPooling,
     MinkowskiGlobalSumPooling,
diff --git a/examples/multigpu.py b/examples/multigpu.py
index daed7b3c..bc501d6b 100644
--- a/examples/multigpu.py
+++ b/examples/multigpu.py
@@ -86,6 +86,10 @@ def generate_input(file_name, voxel_size):
     num_devices = torch.cuda.device_count()
     num_devices = min(config.max_ngpu, num_devices)
     devices = list(range(num_devices))
+    print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
+    print("' WARNING: This example is deprecated.                   '")
+    print("' Please use DistributedDataParallel or pytorch-lightning'")
+    print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
     print(
         f"Testing {num_devices} GPUs.
Total batch size: {num_devices * config.batch_size}" ) diff --git a/pybind/extern.hpp b/pybind/extern.hpp index 87837e0b..af58b9ff 100644 --- a/pybind/extern.hpp +++ b/pybind/extern.hpp @@ -243,6 +243,71 @@ at::Tensor LocalPoolingBackwardGPU( gpu_manager_type *p_map_manager); #endif +/************************************* + * Local Pooling Transpose + *************************************/ +template +std::pair LocalPoolingTransposeForwardCPU( + at::Tensor const &in_feat, + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + bool generate_new_coordinates, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + cpu_manager_type *p_map_manager); + +template +at::Tensor LocalPoolingTransposeBackwardCPU( + at::Tensor const &in_feat, // + at::Tensor const &grad_out_feat, // + at::Tensor const &num_nonzero, // + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + cpu_manager_type *p_map_manager); + +#ifndef CPU_ONLY +template class TemplatedAllocator> +std::pair LocalPoolingTransposeForwardGPU( + at::Tensor const &in_feat, + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + bool generate_new_coordinates, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + gpu_manager_type *p_map_manager); + +template class TemplatedAllocator> +at::Tensor LocalPoolingTransposeBackwardGPU( + at::Tensor const &in_feat, // + at::Tensor const &grad_out_feat, // + at::Tensor const &num_nonzero, // + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + gpu_manager_type *p_map_manager); +#endif + /************************************* * Global Pooling *************************************/ @@ -461,6 +526,13 @@ void instantiate_cpu_func(py::module &m, const std::string &dtypestr) { &minkowski::LocalPoolingBackwardCPU, py::call_guard()); + m.def((std::string("LocalPoolingTransposeForwardCPU") + dtypestr).c_str(), + &minkowski::LocalPoolingTransposeForwardCPU, + py::call_guard()); + m.def((std::string("LocalPoolingTransposeBackwardCPU") + dtypestr).c_str(), + &minkowski::LocalPoolingTransposeBackwardCPU, + py::call_guard()); + m.def((std::string("GlobalPoolingForwardCPU") + dtypestr).c_str(), &minkowski::GlobalPoolingForwardCPU, py::call_guard()); @@ -521,6 +593,14 @@ void instantiate_gpu_func(py::module &m, const std::string &dtypestr) { &minkowski::LocalPoolingBackwardGPU, py::call_guard()); + m.def((std::string("LocalPoolingTransposeForwardGPU") + dtypestr).c_str(), + &minkowski::LocalPoolingTransposeForwardGPU, + py::call_guard()); + m.def( + 
(std::string("LocalPoolingTransposeBackwardGPU") + dtypestr).c_str(), + &minkowski::LocalPoolingTransposeBackwardGPU, + py::call_guard()); + m.def( (std::string("GlobalPoolingForwardGPU") + dtypestr).c_str(), &minkowski::GlobalPoolingForwardGPU, diff --git a/setup.py b/setup.py index 8836a4e7..690a2ec4 100644 --- a/setup.py +++ b/setup.py @@ -211,6 +211,7 @@ def _argparse(pattern, argv, is_flag=True): "convolution_cpu.cpp", "convolution_transpose_cpu.cpp", "local_pooling_cpu.cpp", + "local_pooling_transpose_cpu.cpp", "global_pooling_cpu.cpp", "broadcast_cpu.cpp", "pruning_cpu.cpp", @@ -233,6 +234,7 @@ def _argparse(pattern, argv, is_flag=True): "pooling_avg_kernel.cu", "pooling_max_kernel.cu", "local_pooling_gpu.cu", + "local_pooling_transpose_gpu.cu", "global_pooling_gpu.cu", "broadcast_kernel.cu", "broadcast_gpu.cu", diff --git a/src/local_pooling_gpu.cu b/src/local_pooling_gpu.cu index da1483be..b2fa3486 100644 --- a/src/local_pooling_gpu.cu +++ b/src/local_pooling_gpu.cu @@ -58,7 +58,7 @@ std::pair LocalPoolingForwardGPU( gpu_manager_type *p_map_manager) { ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous"); - ASSERT(in_feat.is_cuda(), "in_feat must be CPU"); + ASSERT(in_feat.is_cuda(), "in_feat must be on CUDA"); ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim()); coordinate_map_key_type in_key = p_in_map_key->get_key(); @@ -154,8 +154,8 @@ at::Tensor LocalPoolingBackwardGPU( ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous"); ASSERT(grad_out_feat.is_contiguous(), "grad_out_feata must be contiguous"); - ASSERT(in_feat.is_cuda(), "in_feat must be CPU"); - ASSERT(grad_out_feat.is_cuda(), "in_feat must be CPU"); + ASSERT(in_feat.is_cuda(), "in_feat must be on CUDA"); + ASSERT(grad_out_feat.is_cuda(), "in_feat must be on CUDA"); ASSERT(in_feat.scalar_type() == grad_out_feat.scalar_type(), "type mismatch"); diff --git a/src/local_pooling_transpose_cpu.cpp b/src/local_pooling_transpose_cpu.cpp new file mode 100644 index 00000000..20826f07 --- /dev/null +++ b/src/local_pooling_transpose_cpu.cpp @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2020 NVIDIA Corporation. + * Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural + * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part + * of the code. 
+ */ +#include "coordinate_map.hpp" +#include "coordinate_map_cpu.hpp" +#include "coordinate_map_key.hpp" +#include "coordinate_map_manager.hpp" +#include "errors.hpp" +#include "types.hpp" +#include "utils.hpp" + +#include "pooling_avg_kernel.hpp" + +#include +#include + +namespace minkowski { + +template +std::pair LocalPoolingTransposeForwardCPU( + at::Tensor const &in_feat, + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + bool generate_new_coordinates, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + cpu_manager_type *p_map_manager) { + + ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous"); + ASSERT(!in_feat.is_cuda(), "in_feat must be CPU"); + ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim()); + + coordinate_map_key_type in_key = p_in_map_key->get_key(); + ASSERT(p_map_manager->exists(in_key), ERROR_MAP_NOT_FOUND); + + ASSERT(in_feat.size(0) == p_map_manager->size(in_key), "Invalid in_feat size", + in_feat.size(0), "!=", p_map_manager->size(in_key)); + + // create an output coordinate map + if (!p_out_map_key->is_key_set()) { + auto map_it = p_map_manager->find(p_in_map_key->get_key()); + ASSERT(map_it != p_map_manager->map_end(), ERROR_MAP_NOT_FOUND); + auto const &in_map = (*map_it).second; + + auto out_tensor_stride = detail::stride_tensor_stride( + in_map.get_tensor_stride(), kernel_stride, true /* is_transpose */); + auto kernel_region = cpu_kernel_region( + region_type, // + in_map.coordinate_size(), // + out_tensor_stride.data(), // + kernel_size.data(), // + kernel_dilation.data(), // + 0, // volume. 
Will be initialized automatically
+        offset.data_ptr(), offset.size(0),
+        true // is_transpose
+    );
+
+    coordinate_map_key_type out_key = std::get<0>(p_map_manager->stride_region(
+        in_key, kernel_region, out_tensor_stride, generate_new_coordinates));
+    p_out_map_key->set_key(out_key);
+  }
+
+  cpu_kernel_map const &in_out = p_map_manager->kernel_map(
+      p_in_map_key,    //
+      p_out_map_key,   //
+      kernel_size,     //
+      kernel_stride,   //
+      kernel_dilation, //
+      region_type,     //
+      offset, true /* is_transpose */, true /* is_pool */);
+
+  auto const out_nrows = p_map_manager->size(p_out_map_key->get_key());
+  at::Tensor out_feat =
+      torch::zeros({out_nrows, in_feat.size(1)}, in_feat.options());
+  LOG_DEBUG("Allocated", out_nrows, "x", in_feat.size(1), "features.");
+
+  at::Tensor num_nonzero =
+      torch::empty({0}, in_feat.options().requires_grad(false));
+  AT_DISPATCH_FLOATING_TYPES(
+      in_feat.scalar_type(), "local_pooling_forward_cpu", [&] {
+        NonzeroAvgPoolingForwardKernelCPU(
+            in_feat.template data_ptr(),
+            out_feat.template data_ptr(),
+            num_nonzero.template data_ptr(), in_feat.size(1),
+            in_out.first, in_out.second, out_nrows, false);
+      });
+  return std::make_pair(out_feat, num_nonzero);
+}
+
+template
+at::Tensor LocalPoolingTransposeBackwardCPU(
+    at::Tensor const &in_feat,                          //
+    at::Tensor const &grad_out_feat,                    //
+    at::Tensor const &num_nonzero,                      //
+    default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    PoolingMode::Type pooling_mode,                     //
+    CoordinateMapKey *p_in_map_key,                     //
+    CoordinateMapKey *p_out_map_key,                    //
+    cpu_manager_type *p_map_manager) {
+  ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous");
+  ASSERT(grad_out_feat.is_contiguous(), "grad_out_feat must be contiguous");
+
+  ASSERT(!in_feat.is_cuda(), "in_feat must be CPU");
+  ASSERT(!grad_out_feat.is_cuda(), "grad_out_feat must be CPU");
+
+  ASSERT(in_feat.scalar_type() == grad_out_feat.scalar_type(), "type mismatch");
+
+  ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim());
+  ASSERT(grad_out_feat.dim() == 2, "grad_out_feat.dim():", grad_out_feat.dim());
+
+  coordinate_map_key_type in_key = p_in_map_key->get_key();
+  ASSERT(p_map_manager->exists(in_key), ERROR_MAP_NOT_FOUND);
+  coordinate_map_key_type out_key = p_out_map_key->get_key();
+  ASSERT(p_map_manager->exists(out_key), ERROR_MAP_NOT_FOUND);
+
+  cpu_kernel_map const &in_out = p_map_manager->kernel_map(
+      p_in_map_key,    //
+      p_out_map_key,   //
+      kernel_size,     //
+      kernel_stride,   //
+      kernel_dilation, //
+      region_type,     //
+      offset, true /* is_transpose */, true /* is_pool */);
+
+  at::Tensor grad_in_feat =
+      torch::zeros({in_feat.size(0), in_feat.size(1)}, in_feat.options());
+
+  AT_DISPATCH_FLOATING_TYPES(
+      in_feat.scalar_type(), "local_pooling_backward_cpu", [&] {
+        NonzeroAvgPoolingBackwardKernelCPU(
+            grad_in_feat.template data_ptr(), in_feat.size(0),
+            grad_out_feat.template data_ptr(),
+            num_nonzero.template data_ptr(), in_feat.size(1),
+            in_out.first, in_out.second, false /* avg */);
+      });
+  return grad_in_feat;
+}
+
+template std::pair
+LocalPoolingTransposeForwardCPU(
+    at::Tensor const &in_feat,
+    default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    bool generate_new_coordinates,                      //
+    PoolingMode::Type
pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + cpu_manager_type *p_map_manager); + +template at::Tensor LocalPoolingTransposeBackwardCPU( + at::Tensor const &in_feat, // + at::Tensor const &grad_out_feat, // + at::Tensor const &num_nonzero, // + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + cpu_manager_type *p_map_manager); + +} // end namespace minkowski diff --git a/src/local_pooling_transpose_gpu.cu b/src/local_pooling_transpose_gpu.cu new file mode 100644 index 00000000..53f9a01f --- /dev/null +++ b/src/local_pooling_transpose_gpu.cu @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2020 NVIDIA Corporation. + * Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural + * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part + * of the code. 
+ */ +#include "coordinate_map.hpp" +#include "coordinate_map_cpu.hpp" +#include "coordinate_map_key.hpp" +#include "coordinate_map_manager.hpp" +#include "errors.hpp" +#include "types.hpp" +#include "utils.hpp" + +#include "pooling_avg_kernel.cuh" + +// Ninja +#include "local_pooling_transpose_cpu.cpp" + +#include +#include + +namespace minkowski { + +template class TemplatedAllocator> +std::pair LocalPoolingTransposeForwardGPU( + at::Tensor const &in_feat, + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + bool generate_new_coordinates, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + gpu_manager_type *p_map_manager) { + + ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous"); + ASSERT(in_feat.is_cuda(), "in_feat must be CUDA"); + ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim()); + + coordinate_map_key_type in_key = p_in_map_key->get_key(); + ASSERT(p_map_manager->exists(in_key), ERROR_MAP_NOT_FOUND); + + ASSERT(in_feat.size(0) == p_map_manager->size(in_key), "Invalid in_feat size", + in_feat.size(0), "!=", p_map_manager->size(in_key)); + + // create an output coordinate map + if (!p_out_map_key->is_key_set()) { + auto map_it = p_map_manager->find(p_in_map_key->get_key()); + ASSERT(map_it != p_map_manager->map_end(), ERROR_MAP_NOT_FOUND); + auto const &in_map = (*map_it).second; + + auto out_tensor_stride = detail::stride_tensor_stride( + in_map.get_tensor_stride(), kernel_stride, true /* is_transpose */); + auto kernel_region = cpu_kernel_region( + region_type, // + in_map.coordinate_size(), // + out_tensor_stride.data(), // + kernel_size.data(), // + kernel_dilation.data(), // + 0, // volume + offset.data_ptr(), offset.size(0), + true // is_transpose + ); + + coordinate_map_key_type out_key = std::get<0>(p_map_manager->stride_region( + in_key, kernel_region, out_tensor_stride, generate_new_coordinates)); + LOG_DEBUG("PoolingTranspose out key:", out_key); + p_out_map_key->set_key(out_key); + } + + auto const &in_out = p_map_manager->kernel_map( + p_in_map_key, // + p_out_map_key, // + kernel_size, // + kernel_stride, // + kernel_dilation, // + region_type, // + offset, true /* is_transpose */, true /* is_pool */); + + auto const out_nrows = p_map_manager->size(p_out_map_key->get_key()); + at::Tensor out_feat = + torch::empty({out_nrows, in_feat.size(1)}, in_feat.options()); + LOG_DEBUG("Allocated", out_nrows, "x", in_feat.size(1), "features."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + at::Tensor num_nonzero = + torch::empty({0}, in_feat.options().requires_grad(false)); + + cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle(); + cusparseSetStream(handle, stream); + + AT_DISPATCH_FLOATING_TYPES( + in_feat.scalar_type(), "local_pooling_forward_gpu", [&] { + TemplatedAllocator byte_allocator; + NonzeroAvgPoolingForwardKernelGPU>( + in_feat.template data_ptr(), in_feat.size(0), + out_feat.template data_ptr(), out_nrows, + num_nonzero.template data_ptr(), in_feat.size(1), in_out, + false /* avg */, byte_allocator, handle, stream); + }); + + return std::make_pair(out_feat, num_nonzero); +} + +template class TemplatedAllocator> +at::Tensor LocalPoolingTransposeBackwardGPU( + at::Tensor const &in_feat, // + at::Tensor const &grad_out_feat, // + at::Tensor const &num_nonzero, // + 
default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    PoolingMode::Type pooling_mode,                     //
+    CoordinateMapKey *p_in_map_key,                     //
+    CoordinateMapKey *p_out_map_key,                    //
+    gpu_manager_type *p_map_manager) {
+  ASSERT(in_feat.is_contiguous(), "in_feat must be contiguous");
+  ASSERT(grad_out_feat.is_contiguous(), "grad_out_feat must be contiguous");
+
+  ASSERT(in_feat.is_cuda(), "in_feat must be CUDA");
+  ASSERT(grad_out_feat.is_cuda(), "grad_out_feat must be CUDA");
+
+  ASSERT(in_feat.scalar_type() == grad_out_feat.scalar_type(), "type mismatch");
+
+  ASSERT(in_feat.dim() == 2, "in_feat.dim():", in_feat.dim());
+  ASSERT(grad_out_feat.dim() == 2, "grad_out_feat.dim():", grad_out_feat.dim());
+
+  coordinate_map_key_type in_key = p_in_map_key->get_key();
+  ASSERT(p_map_manager->exists(in_key), ERROR_MAP_NOT_FOUND);
+  coordinate_map_key_type out_key = p_out_map_key->get_key();
+  ASSERT(p_map_manager->exists(out_key), ERROR_MAP_NOT_FOUND);
+
+  auto const &in_out = p_map_manager->kernel_map(
+      p_in_map_key,    //
+      p_out_map_key,   //
+      kernel_size,     //
+      kernel_stride,   //
+      kernel_dilation, //
+      region_type,     //
+      offset, true /* is_transpose */, true /* is_pool */);
+
+  at::Tensor grad_in_feat =
+      torch::zeros({in_feat.size(0), in_feat.size(1)}, in_feat.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+  AT_DISPATCH_FLOATING_TYPES(
+      in_feat.scalar_type(), "local_pooling_backward_gpu", [&] {
+        NonzeroAvgPoolingBackwardKernelGPU>(
+            grad_in_feat.template data_ptr(), in_feat.size(0),
+            grad_out_feat.template data_ptr(), grad_out_feat.size(0),
+            num_nonzero.template data_ptr(), in_feat.size(1), in_out,
+            false /* avg */, stream);
+      });
+  return grad_in_feat;
+}
+
+// Forward
+template std::pair
+LocalPoolingTransposeForwardGPU(
+    at::Tensor const &in_feat,
+    default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    bool generate_new_coordinates,                      //
+    PoolingMode::Type pooling_mode,                     //
+    CoordinateMapKey *p_in_map_key,                     //
+    CoordinateMapKey *p_out_map_key,                    //
+    gpu_manager_type
+        *p_map_manager);
+
+template std::pair
+LocalPoolingTransposeForwardGPU(
+    at::Tensor const &in_feat,
+    default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    bool generate_new_coordinates,                      //
+    PoolingMode::Type pooling_mode,                     //
+    CoordinateMapKey *p_in_map_key,                     //
+    CoordinateMapKey *p_out_map_key,                    //
+    gpu_manager_type
+        *p_map_manager);
+
+// Backward
+template at::Tensor
+LocalPoolingTransposeBackwardGPU(
+    at::Tensor const &in_feat,                          //
+    at::Tensor const &grad_out_feat,                    //
+    at::Tensor const &num_nonzero,                      //
+    default_types::stride_type const &kernel_size,      //
+    default_types::stride_type const &kernel_stride,    //
+    default_types::stride_type const &kernel_dilation,  //
+    RegionType::Type const region_type,                 //
+    at::Tensor const &offset,                           //
+    PoolingMode::Type pooling_mode,                     //
+    CoordinateMapKey *p_in_map_key,                     //
+    CoordinateMapKey *p_out_map_key,                    //
+    gpu_manager_type
+        *p_map_manager);
+
+template at::Tensor
+LocalPoolingTransposeBackwardGPU(
+    at::Tensor const &in_feat, //
+    
at::Tensor const &grad_out_feat, // + at::Tensor const &num_nonzero, // + default_types::stride_type const &kernel_size, // + default_types::stride_type const &kernel_stride, // + default_types::stride_type const &kernel_dilation, // + RegionType::Type const region_type, // + at::Tensor const &offset, // + PoolingMode::Type pooling_mode, // + CoordinateMapKey *p_in_map_key, // + CoordinateMapKey *p_out_map_key, // + gpu_manager_type + *p_map_manager); + +} // end namespace minkowski diff --git a/src/pooling_transpose.cpp b/src/pooling_transpose.cpp deleted file mode 100644 index a32c6aa7..00000000 --- a/src/pooling_transpose.cpp +++ /dev/null @@ -1,261 +0,0 @@ -/* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - * IN THE SOFTWARE. - * - * Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural - * Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part - * of the code. 
- */ -#include "common.hpp" - -#include "pooling_avg.hpp" -#ifndef CPU_ONLY -#include "pooling_avg.cuh" -#endif - -#include - -namespace minkowski { - -template -void PoolingTransposeForwardCPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor num_nonzero, - vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, - int region_type, at::Tensor offsets, - py::object py_in_coords_key, - py::object py_out_coords_key, - py::object py_coords_manager) { - CoordsManager *p_coords_manager = - py_coords_manager.cast *>(); - const auto &in_out = p_coords_manager->getInOutMaps( - tensor_strides, strides, kernel_sizes, dilations, region_type, offsets, - py_in_coords_key, py_out_coords_key, true, true); - - const int out_nrows = p_coords_manager->getCoordsSize(py_out_coords_key); - out_feat.resize_({out_nrows, in_feat.size(1)}); - out_feat.zero_(); - num_nonzero.resize_({out_nrows}); - num_nonzero.zero_(); - - NonzeroAvgPoolingForwardKernelCPU( - in_feat.template data(), out_feat.template data(), - num_nonzero.template data(), in_feat.size(1), in_out.first, - in_out.second, out_nrows, false); -} - -template -void PoolingTransposeBackwardCPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager) { - CoordsManager *p_coords_manager = - py_coords_manager.cast *>(); - bool reverse_map = false; - const InOutMapKey rev_map_key = p_coords_manager->getMapHashKey( - tensor_strides, strides, kernel_sizes, dilations, region_type, - py_out_coords_key, py_in_coords_key, false, true); - const InOutMapKey map_key = p_coords_manager->getMapHashKey( - tensor_strides, strides, kernel_sizes, dilations, region_type, - py_in_coords_key, py_out_coords_key, true, true); - - // Check if the reverse map exists first - if (p_coords_manager->in_maps.find(rev_map_key) != - p_coords_manager->in_maps.end()) - reverse_map = true; - - grad_in_feat.resize_as_(in_feat); - grad_in_feat.zero_(); - - if (!reverse_map) { - ASSERT( - p_coords_manager->in_maps.find(map_key) != - p_coords_manager->in_maps.end(), - "The in-out map doesn't exist for backward. Did you run forward pass?"); - - NonzeroAvgPoolingBackwardKernelCPU( - grad_in_feat.template data(), in_feat.size(0), - grad_out_feat.template data(), - num_nonzero.template data(), in_feat.size(1), - p_coords_manager->in_maps[map_key], p_coords_manager->out_maps[map_key], - false); - } else { - ASSERT( - p_coords_manager->in_maps.find(rev_map_key) != - p_coords_manager->in_maps.end(), - "The in-out map doesn't exist for backward. 
Did you run forward pass?"); - - NonzeroAvgPoolingBackwardKernelCPU( - grad_in_feat.template data(), in_feat.size(0), - grad_out_feat.template data(), - num_nonzero.template data(), in_feat.size(1), - p_coords_manager->out_maps[rev_map_key], - p_coords_manager->in_maps[rev_map_key], false); - } -} - -#ifndef CPU_ONLY -template -void PoolingTransposeForwardGPU(at::Tensor in_feat, at::Tensor out_feat, - at::Tensor num_nonzero, - vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, - int region_type, at::Tensor offsets, - py::object py_in_coords_key, - py::object py_out_coords_key, - py::object py_coords_manager) { - CoordsManager *p_coords_manager = - py_coords_manager.cast *>(); - const auto &in_out = p_coords_manager->getInOutMapsGPU( - tensor_strides, strides, kernel_sizes, dilations, region_type, offsets, - py_in_coords_key, py_out_coords_key, true, true); - - const int out_nrows = p_coords_manager->getCoordsSize(py_out_coords_key); - out_feat.resize_({out_nrows, in_feat.size(1)}); - out_feat.zero_(); - num_nonzero.resize_({out_nrows}); - num_nonzero.zero_(); - - cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle(); - cusparseSetStream(handle, at::cuda::getCurrentCUDAStream()); - - NonzeroAvgPoolingForwardKernelGPU( - in_feat.template data(), in_feat.size(0), - out_feat.template data(), out_nrows, - num_nonzero.template data(), in_feat.size(1), get<0>(in_out), - get<1>(in_out), false, handle, at::cuda::getCurrentCUDAStream()); -} - -template -void PoolingTransposeBackwardGPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager) { - CoordsManager *p_coords_manager = - py_coords_manager.cast *>(); - bool reverse_map = false; - const InOutMapKey rev_map_key = p_coords_manager->getMapHashKey( - tensor_strides, strides, kernel_sizes, dilations, region_type, - py_out_coords_key, py_in_coords_key, false, true); - const InOutMapKey map_key = p_coords_manager->getMapHashKey( - tensor_strides, strides, kernel_sizes, dilations, region_type, - py_in_coords_key, py_out_coords_key, true, true); - - // Check if the reverse map exists first - if (p_coords_manager->in_maps.find(rev_map_key) != - p_coords_manager->in_maps.end()) - reverse_map = true; - - grad_in_feat.resize_as_(in_feat); - grad_in_feat.zero_(); - - if (!reverse_map) { - ASSERT( - p_coords_manager->d_in_maps.find(map_key) != - p_coords_manager->d_in_maps.end(), - "The in-out map doesn't exist for backward. Did you run forward pass?"); - - NonzeroAvgPoolingBackwardKernelGPU( - grad_in_feat.template data(), in_feat.size(0), - grad_out_feat.template data(), grad_out_feat.size(0), - num_nonzero.template data(), in_feat.size(1), - p_coords_manager->d_in_maps[map_key], - p_coords_manager->d_out_maps[map_key], false, - at::cuda::getCurrentCUDAStream()); - } else { - ASSERT( - p_coords_manager->d_in_maps.find(rev_map_key) != - p_coords_manager->d_in_maps.end(), - "The in-out map doesn't exist for backward. 
Did you run forward pass?"); - - NonzeroAvgPoolingBackwardKernelGPU( - grad_in_feat.template data(), in_feat.size(0), - grad_out_feat.template data(), grad_out_feat.size(0), - num_nonzero.template data(), in_feat.size(1), - p_coords_manager->d_out_maps[rev_map_key], - p_coords_manager->d_in_maps[rev_map_key], false, - at::cuda::getCurrentCUDAStream()); - } -} -#endif - -template void PoolingTransposeForwardCPU( - at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero, - vector tensor_strides, vector strides, vector kernel_sizes, - vector dilations, int region_type, at::Tensor offsets, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeForwardCPU( - at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero, - vector tensor_strides, vector strides, vector kernel_sizes, - vector dilations, int region_type, at::Tensor offsets, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeBackwardCPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeBackwardCPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -#ifndef CPU_ONLY - -template void PoolingTransposeForwardGPU( - at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero, - vector tensor_strides, vector strides, vector kernel_sizes, - vector dilations, int region_type, at::Tensor offsets, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeForwardGPU( - at::Tensor in_feat, at::Tensor out_feat, at::Tensor num_nonzero, - vector tensor_strides, vector strides, vector kernel_sizes, - vector dilations, int region_type, at::Tensor offsets, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeBackwardGPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); - -template void PoolingTransposeBackwardGPU( - at::Tensor in_feat, at::Tensor grad_in_feat, at::Tensor grad_out_feat, - at::Tensor num_nonzero, vector tensor_strides, vector strides, - vector kernel_sizes, vector dilations, int region_type, - py::object py_in_coords_key, py::object py_out_coords_key, - py::object py_coords_manager); -#endif // CPU_ONLY - -} // end namespace minkowski diff --git a/tests/python/pool.py b/tests/python/pool.py index 32438676..4209e893 100644 --- a/tests/python/pool.py +++ b/tests/python/pool.py @@ -32,8 +32,8 @@ MinkowskiSumPooling, MinkowskiAvgPooling, MinkowskiMaxPooling, - # MinkowskiPoolingTransposeFunction, - # MinkowskiPoolingTranspose, + MinkowskiLocalPoolingTransposeFunction, + MinkowskiPoolingTranspose, MinkowskiGlobalPoolingFunction, MinkowskiGlobalPooling, MinkowskiGlobalSumPooling, @@ -233,35 +233,30 
@@ def test_unpool(self): print(output) # Check backward - fn = MinkowskiPoolingTransposeFunction() + fn = MinkowskiLocalPoolingTransposeFunction() self.assertTrue( gradcheck( fn, ( input.F, - input.tensor_stride, - unpool.stride, - unpool.kernel_size, - unpool.dilation, - unpool.region_type_, - unpool.region_offset_, - False, - input.coords_key, + unpool.pooling_mode, + unpool.kernel_generator, + input.coordinate_map_key, None, - input.coords_man, + input.coordinate_manager, ), ) ) - def test_unpooling_gpu(self): + def test_unpool_gpu(self): if not torch.cuda.is_available(): return in_channels, out_channels, D = 2, 3, 2 coords, feats, labels = data_loader(in_channels) feats = feats.double() - input = SparseTensor(feats, coords=coords) + input = SparseTensor(feats, coords) conv = MinkowskiConvolution( in_channels, out_channels, kernel_size=3, stride=2, dimension=D ) @@ -271,30 +266,27 @@ def test_unpooling_gpu(self): output = unpool(input) print(output) # Check backward - fn = MinkowskiPoolingTransposeFunction() + fn = MinkowskiLocalPoolingTransposeFunction() self.assertTrue( gradcheck( fn, ( input.F, - input.tensor_stride, - unpool.stride, - unpool.kernel_size, - unpool.dilation, - unpool.region_type_, - unpool.region_offset_, - False, - input.coords_key, + unpool.pooling_mode, + unpool.kernel_generator, + input.coordinate_map_key, None, - input.coords_man, + input.coordinate_manager, ), ) ) - device = torch.device("cuda") with torch.cuda.device(0): - input = input.to(device) + conv = conv.to("cuda") + input = SparseTensor(feats, coords, device="cuda") + input = conv(input) + input.requires_grad_() output = unpool(input) print(output) @@ -304,26 +296,22 @@ def test_unpooling_gpu(self): fn, ( input.F, - input.tensor_stride, - unpool.stride, - unpool.kernel_size, - unpool.dilation, - unpool.region_type_, - unpool.region_offset_, - True, - input.coords_key, + unpool.pooling_mode, + unpool.kernel_generator, + input.coordinate_map_key, None, - input.coords_man, + input.coordinate_manager, ), ) ) + class TestGlobalAvgPooling(unittest.TestCase): def test_gpu(self): if not torch.cuda.is_available(): return - in_channels, D = 2, 2 + in_channels = 2 coords, feats, labels = data_loader(in_channels) feats = feats.double() feats.requires_grad_() @@ -379,6 +367,7 @@ def test(self): ) ) + class TestGlobalMaxPooling(unittest.TestCase): def test_gpu(self): if not torch.cuda.is_available():
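-- 

Usage note (editor's addition, not part of the patch): a minimal sketch of the
renamed unpooling API after this change. Module and class names follow the
patched MinkowskiEngine/MinkowskiPooling.py and tests/python/pool.py; the toy
coordinates and channel counts below are illustrative, not taken from the patch.

    import torch
    import MinkowskiEngine as ME

    # Three 2D points with 2-channel features; the first coordinate column
    # is the batch index, as required by SparseTensor.
    coords = torch.IntTensor([[0, 0, 0], [0, 0, 2], [0, 2, 0]])
    feats = torch.rand(3, 2).double()
    x = ME.SparseTensor(feats, coords)

    # Downsample to tensor stride 2 with a strided convolution, then unpool
    # back to stride 1 with the transposed pooling layer added in this patch.
    conv = ME.MinkowskiConvolution(2, 3, kernel_size=3, stride=2, dimension=2).double()
    unpool = ME.MinkowskiPoolingTranspose(kernel_size=3, stride=2, dimension=2)
    y = unpool(conv(x))
    print(y.tensor_stride)  # back to [1, 1]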
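Gradient-check note (editor's addition): the backward path of the new autograd
function can be verified directly, mirroring the updated tests/python/pool.py;
`conv`, `unpool`, and `x` are the objects built in the sketch above.

    from torch.autograd import gradcheck
    from MinkowskiEngine import MinkowskiLocalPoolingTransposeFunction

    fn = MinkowskiLocalPoolingTransposeFunction()
    input = conv(x)  # conv parameters require grad, so input.F does too
    assert gradcheck(
        fn,
        (
            input.F,
            unpool.pooling_mode,
            unpool.kernel_generator,
            input.coordinate_map_key,
            None,
            input.coordinate_manager,
        ),
    )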