Skip to content

Commit

Permalink
conv/convtr feature type dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent 97d5417 commit f27bda1
Show file tree
Hide file tree
Showing 19 changed files with 493 additions and 677 deletions.
12 changes: 6 additions & 6 deletions MinkowskiEngine/MinkowskiCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ def prep_args(

def get_postfix(tensor: torch.Tensor):
postfix = "GPU" if tensor.is_cuda else "CPU"
if isinstance(tensor, torch.DoubleTensor) or isinstance(
tensor, torch.cuda.DoubleTensor
):
postfix += "d"
else:
postfix += "f"
# if isinstance(tensor, torch.DoubleTensor) or isinstance(
# tensor, torch.cuda.DoubleTensor
# ):
# postfix += "d"
# else:
# postfix += "f"
return postfix


Expand Down
11 changes: 6 additions & 5 deletions MinkowskiEngine/MinkowskiConvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@


class MinkowskiConvolutionFunction(Function):

@staticmethod
def forward(
ctx,
Expand Down Expand Up @@ -201,7 +202,7 @@ class MinkowskiConvolutionBase(MinkowskiModuleBase):
"kernel_generator",
"dimension",
"use_mm",
"weight",
"kernel",
"bias",
"conv",
)
Expand Down Expand Up @@ -248,7 +249,7 @@ def __init__(
Tensor = torch.FloatTensor
if (
self.kernel_generator.kernel_volume == 1
and self.kernel_generator.requires_strided_coordiantes
and self.kernel_generator.requires_strided_coordinates
):
kernel_shape = (self.in_channels, self.out_channels)
self.use_mm = True
Expand All @@ -259,7 +260,7 @@ def __init__(
self.out_channels,
)

self.weight = Parameter(Tensor(*kernel_shape))
self.kernel = Parameter(Tensor(*kernel_shape))
self.bias = Parameter(Tensor(1, out_channels)) if bias else None
self.conv = (
MinkowskiConvolutionTransposeFunction()
Expand Down Expand Up @@ -294,7 +295,7 @@ def forward(
out_coordinate_map_key = _get_coordinate_map_key(input, coordinates)
outfeat = self.conv.apply(
input.F,
self.weight,
self.kernel,
self.kernel_generator,
input.coordinate_map_key,
out_coordinate_map_key,
Expand All @@ -314,7 +315,7 @@ def reset_parameters(self, is_transpose=False):
self.out_channels if is_transpose else self.in_channels
) * self.kernel_generator.kernel_volume
stdv = 1.0 / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.kernel.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

Expand Down
6 changes: 3 additions & 3 deletions MinkowskiEngine/MinkowskiNonlinearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
from torch.nn import Module

from SparseTensor import SparseTensor
from MinkowskiSparseTensor import SparseTensor


class MinkowskiModuleBase(Module):
Expand All @@ -38,8 +38,8 @@ def forward(self, input):
output = self.module(input.F)
return SparseTensor(
output,
coords_key=input.coords_key,
coords_manager=input.coords_man)
coordinate_map_key=input.coordinate_map_key,
coordinate_manager=input.coordinate_manager)

def __repr__(self):
return self.__class__.__name__ + '()'
Expand Down
Loading

0 comments on commit f27bda1

Please sign in to comment.