Skip to content

Commit

Permalink
fix broadcast modules, fix #103
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Mar 14, 2020
1 parent a1fab7a commit f1a2a95
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 99 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
- Fix an error in examples.convolution
- Changed `features_at`, `coordinates_at` to take a batch index not the index of the unique batch indices. (Issue #100)
- Fix an error torch.range --> torch.arange in `sparse_quantize` (Issue #101)
- Fix BLAS installation link error (Issue #94)
- Fix `MinkowskiBroadcast` and `MinkowskiBroadcastConcatenation` to use arbitrary channel sizes


## [0.4.1] - 2020-01-28
Expand Down
4 changes: 2 additions & 2 deletions MinkowskiEngine/MinkowskiBroadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def forward(self, input, input_glob):
assert isinstance(input, SparseTensor)
assert isinstance(input_glob, SparseTensor)

broadcast_feat = torch.empty_like(input.F)
broadcast_feat = input.F.new(len(input), input_glob.size()[1])
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
for b, row_ind in enumerate(row_inds):
broadcast_feat[row_ind] = input_glob.F[b]
Expand Down Expand Up @@ -226,7 +226,7 @@ def forward(self, input, input_glob):
assert isinstance(input, SparseTensor)
assert isinstance(input_glob, SparseTensor)

broadcast_feat = torch.empty_like(input.F)
broadcast_feat = input.F.new(len(input), input_glob.size()[1])
row_inds = input.coords_man.get_row_indices_per_batch(input.coords_key)
for b, row_ind in enumerate(row_inds):
broadcast_feat[row_ind] = input_glob.F[b]
Expand Down
206 changes: 109 additions & 97 deletions examples/pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
try:
import open3d as o3d
except ImportError:
raise ImportError(
'Please install open3d with `pip install open3d`.')
raise ImportError('Please install open3d with `pip install open3d`.')

import torch
import torch.nn as nn
Expand All @@ -51,54 +50,62 @@ def __init__(self, D=3):
s = self.STRIDES
c = self.CONV_CHANNELS

self.conv1 = ME.MinkowskiConvolution(
3, c[0], kernel_size=k[0], stride=s[0], has_bias=False, dimension=3)
self.conv2 = ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3)
self.conv3 = ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3)
self.block1 = nn.Sequential(
ME.MinkowskiConvolution(
3,
c[0],
kernel_size=k[0],
stride=s[0],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[0]),
ME.MinkowskiReLU())
self.block2 = nn.Sequential(
ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[1]),
ME.MinkowskiReLU())
self.block3 = nn.Sequential(
ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[2]),
ME.MinkowskiReLU())

# Use the kernelsize 1 convolution for linear layers. If kernel size ==
# 1, minkowski engine internally uses a linear function.
self.fc4 = ME.MinkowskiConvolution(
c[2], c[3], kernel_size=1, has_bias=False, dimension=3)
self.fc5 = ME.MinkowskiConvolution(
c[3], c[4], kernel_size=1, has_bias=False, dimension=3)
self.block4 = nn.Sequential(
ME.MinkowskiConvolution(
c[2], c[3], kernel_size=1, has_bias=False, dimension=3),
ME.MinkowskiInstanceNorm(c[3]), ME.MinkowskiReLU())
self.block5 = nn.Sequential(
ME.MinkowskiConvolution(
c[3], c[4], kernel_size=1, has_bias=False, dimension=3),
ME.MinkowskiInstanceNorm(c[4]), ME.MinkowskiReLU())
self.fc6 = ME.MinkowskiConvolution(
c[4], 9, kernel_size=1, has_bias=True, dimension=3)

self.relu = ME.MinkowskiReLU(inplace=True)
self.avgpool = ME.MinkowskiGlobalPooling()
self.broadcast = ME.MinkowskiBroadcast()

self.bn1 = ME.MinkowskiInstanceNorm(c[0], dimension=3)
self.bn2 = ME.MinkowskiInstanceNorm(c[1], dimension=3)
self.bn3 = ME.MinkowskiInstanceNorm(c[2], dimension=3)
self.bn4 = ME.MinkowskiInstanceNorm(c[3], dimension=3)
self.bn5 = ME.MinkowskiInstanceNorm(c[4], dimension=3)

def forward(self, in_x):
x = self.relu(self.bn1(self.conv1(in_x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
x = self.block1(in_x)
x = self.block2(x)
x = self.block3(x)

# batch size x channel
x = self.avgpool(x)

x = self.relu(self.bn4(self.fc4(x)))
x = self.relu(self.bn5(self.fc5(x)))
x = self.block4(x)
x = self.block5(x)

# get the features only
# get the features batch-wise
x = self.fc6(x)

# Add identity transformation
Expand Down Expand Up @@ -130,32 +137,34 @@ def __init__(self):
c = self.CONV_CHANNELS

self.stn = STN3d(D=3)
self.conv1 = ME.MinkowskiConvolution(
6,
c[0],
kernel_size=k[0],
stride=s[0],
has_bias=False,
dimension=3)
self.conv2 = ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3)
self.conv3 = ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3)
self.bn1 = ME.MinkowskiInstanceNorm(c[0], dimension=3)
self.bn2 = ME.MinkowskiInstanceNorm(c[1], dimension=3)
self.bn3 = ME.MinkowskiInstanceNorm(c[2], dimension=3)

self.relu = ME.MinkowskiReLU(inplace=True)
self.block1 = nn.Sequential(
ME.MinkowskiConvolution(
6,
c[0],
kernel_size=k[0],
stride=s[0],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[0]),
ME.MinkowskiReLU())
self.block2 = nn.Sequential(
ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[1]),
ME.MinkowskiReLU())
self.block3 = nn.Sequential(
ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[2]),
ME.MinkowskiReLU())

self.avgpool = ME.MinkowskiGlobalPooling()
self.concat = ME.MinkowskiBroadcastConcatenation()

Expand All @@ -170,15 +179,16 @@ def forward(self, x):
T = self.stn(x)

# Apply the transformation
coords_feat_stn = torch.squeeze(torch.bmm(x.F.view(-1, 1, 3), T.F.view(-1, 3, 3)))
coords_feat_stn = torch.squeeze(
torch.bmm(x.F.view(-1, 1, 3), T.F.view(-1, 3, 3)))
x = ME.SparseTensor(
torch.cat((coords_feat_stn, x.F), 1),
coords_key=x.coords_key,
coords_manager=x.coords_man)

point_feat = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(point_feat)))
x = self.bn3(self.conv3(x))
point_feat = self.block1(x)
x = self.block2(point_feat)
x = self.block3(x)
glob_feat = self.avgpool(x)
return self.concat(point_feat, glob_feat)

Expand All @@ -203,45 +213,47 @@ def __init__(self, out_channels, D=3):
c = self.CONV_CHANNELS

self.feat = PointNetFeature()
self.conv1 = ME.MinkowskiConvolution(
1280,
c[0],
kernel_size=k[0],
stride=s[0],
has_bias=False,
dimension=3)
self.conv2 = ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3)
self.conv3 = ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3)
self.block1 = nn.Sequential(
ME.MinkowskiConvolution(
1280,
c[0],
kernel_size=k[0],
stride=s[0],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[0]),
ME.MinkowskiReLU())
self.block2 = nn.Sequential(
ME.MinkowskiConvolution(
c[0],
c[1],
kernel_size=k[1],
stride=s[1],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[1]),
ME.MinkowskiReLU())
self.block3 = nn.Sequential(
ME.MinkowskiConvolution(
c[1],
c[2],
kernel_size=k[2],
stride=s[2],
has_bias=False,
dimension=3), ME.MinkowskiInstanceNorm(c[2]),
ME.MinkowskiReLU())

# Last FC layer. Note that kernel_size 1 == linear layer
self.conv4 = ME.MinkowskiConvolution(
c[2], out_channels, kernel_size=1, has_bias=True, dimension=3)

self.bn1 = ME.MinkowskiInstanceNorm(c[0], dimension=3)
self.bn2 = ME.MinkowskiInstanceNorm(c[1], dimension=3)
self.bn3 = ME.MinkowskiInstanceNorm(c[2], dimension=3)
self.relu = ME.MinkowskiReLU(inplace=True)

def forward(self, x):
"""
Assume that x.F (features) are normalized coordinates or centered coordinates
"""
assert isinstance(x, ME.SparseTensor)
x = self.feat(x)
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
return self.conv4(x)


Expand All @@ -253,7 +265,7 @@ def forward(self, x):

if __name__ == '__main__':
voxel_size = 2e-3 # High resolution grid works better just like high-res image is better for 2D classification
pointnet = PointNet(20)
pointnet = PointNet(20).float()

pcd = o3d.io.read_point_cloud(bunny_file)

Expand All @@ -267,6 +279,6 @@ def forward(self, x):
inds = ME.utils.sparse_quantize(quantized_coords, return_index=True)
quantized_coords, feats = ME.utils.sparse_collate([quantized_coords[inds]],
[feats[inds]])
sinput = ME.SparseTensor(feats, quantized_coords)
sinput = ME.SparseTensor(feats.float(), quantized_coords)

pointnet(sinput)
print(pointnet(sinput))

0 comments on commit f1a2a95

Please sign in to comment.