Skip to content

Commit

Permalink
Removed MinkowskiConvolutionFunction from MinkowskiConvolution to all…
Browse files Browse the repository at this point in the history
…ow pickling (#139)
  • Loading branch information
edraizen committed May 15, 2020
1 parent c351d10 commit 6618442
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions MinkowskiEngine/MinkowskiConvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,13 @@ def forward(self,
outfeat = input.F.mm(self.kernel)
out_coords_key = input.coords_key
else:
if self.is_transpose:
conv = MinkowskiConvolutionTransposeFunction()
else:
conv = MinkowskiConvolutionFunction()
# Get a new coords key or extract one from the coords
out_coords_key = _get_coords_key(input, coords)
outfeat = self.conv.apply(input.F, self.kernel, input.tensor_stride,
outfeat = conv.apply(input.F, self.kernel, input.tensor_stride,
self.stride, self.kernel_size,
self.dilation, self.region_type_,
self.region_offset_, input.coords_key,
Expand Down Expand Up @@ -384,7 +388,6 @@ def __init__(self,
is_transpose=False,
dimension=dimension)
self.reset_parameters()
self.conv = MinkowskiConvolutionFunction()


class MinkowskiConvolutionTranspose(MinkowskiConvolutionBase):
Expand Down Expand Up @@ -458,7 +461,6 @@ def __init__(self,
dimension=dimension)
self.reset_parameters(True)
self.generate_new_coords = generate_new_coords
self.conv = MinkowskiConvolutionTransposeFunction()

def forward(self,
input: SparseTensor,
Expand Down Expand Up @@ -487,7 +489,7 @@ def forward(self,
else:
# Get a new coords key or extract one from the coords
out_coords_key = _get_coords_key(input, coords, tensor_stride=1)
outfeat = self.conv.apply(
outfeat = MinkowskiConvolutionTransposeFunction().apply(
input.F, self.kernel, input.tensor_stride, self.stride,
self.kernel_size, self.dilation, self.region_type_,
self.region_offset_, self.generate_new_coords, input.coords_key,
Expand All @@ -496,4 +498,4 @@ def forward(self,
outfeat += self.bias

return SparseTensor(
outfeat, coords_key=out_coords_key, coords_manager=input.coords_man)
outfeat, coords_key=out_coords_key, coords_manager=input.coords_man)

0 comments on commit 6618442

Please sign in to comment.