Skip to content

Commit

Permalink
sparse tensor interpolate(field) and class example update
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Apr 30, 2021
1 parent a3a174a commit 432ce88
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 47 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
- Fix `TensorField.sparse()` for no duplicate coordinates
- Skip unnecessary spmm if `SparseTensor.initialize_coordinates()` has no duplicate coordinates
- Model summary utility function added
- TensorField.splat function for interpolated features
- TensorField.splat function for splatting features to a sparse tensor
- SparseTensor.interpolate function for extracting interpolated features

## [0.5.3]

Expand Down
18 changes: 18 additions & 0 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,24 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True):
tensor_stride = torch.IntTensor(self.tensor_stride)
return dense_F, min_coordinate, tensor_stride

def interpolate(self, X):
    r"""Extract this sparse tensor's features at the coordinates of a tensor field.

    Args:
        X (:attr:`TensorField`): the tensor field whose coordinates define the
            locations at which features are interpolated.

    Returns:
        TensorField: a new tensor field that shares ``X``'s coordinate field
        map key and coordinate manager, carrying the interpolated features.
    """
    # Local import to avoid a circular dependency between the sparse-tensor
    # and tensor-field modules.
    from MinkowskiTensorField import TensorField

    assert isinstance(X, TensorField)
    if self.coordinate_map_key in X._splat:
        # Reuse the interpolation maps that X.splat() cached for this
        # coordinate map key, applying the splat weights in the transposed
        # direction (sparse tensor -> field points) via sparse matmul.
        tensor_map, field_map, weights, size = X._splat[self.coordinate_map_key]
        size = torch.Size([size[1], size[0]])  # transpose
        features = MinkowskiSPMMFunction().apply(
            field_map, tensor_map, weights, size, self._F
        )
    else:
        # No cached splat maps for this key: fall back to direct feature
        # evaluation at the field coordinates.
        features = self.features_at_coordinates(X.C)
    return TensorField(
        features=features,
        coordinate_field_map_key=X.coordinate_field_map_key,
        coordinate_manager=X.coordinate_manager,
    )

def slice(self, X):
r"""
Expand Down
73 changes: 43 additions & 30 deletions MinkowskiEngine/MinkowskiTensorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_splat_coordinates(coordinates: torch.Tensor) -> torch.Tensor:
offset[d] = 1
new_offset.append(offset)
region_offset.extend(new_offset)
region_offset = torch.IntTensor(region_offset)
region_offset = torch.IntTensor(region_offset).to(coordinates.device)
coordinates = torch.floor(coordinates).int().unsqueeze(1) + region_offset.unsqueeze(
0
)
Expand Down Expand Up @@ -244,6 +244,7 @@ def __init__(
self.coordinate_field_map_key = coordinate_field_map_key
self._batch_rows = None
self._inverse_mapping = {}
self._splat = {}

@property
def C(self):
Expand Down Expand Up @@ -325,7 +326,6 @@ def sparse(

# Create features
if quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_SUM:
spmm = MinkowskiSPMMFunction()
N = len(self._F)
cols = torch.arange(
N,
Expand All @@ -334,17 +334,20 @@ def sparse(
)
vals = torch.ones(N, dtype=self._F.dtype, device=self._F.device)
size = torch.Size([N_rows, len(inverse_mapping)])
features = spmm.apply(inverse_mapping, cols, vals, size, self._F)
features = MinkowskiSPMMFunction().apply(
inverse_mapping, cols, vals, size, self._F
)
elif quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE:
spmm_avg = MinkowskiSPMMAverageFunction()
N = len(self._F)
cols = torch.arange(
N,
dtype=inverse_mapping.dtype,
device=inverse_mapping.device,
)
size = torch.Size([N_rows, len(inverse_mapping)])
features = spmm_avg.apply(inverse_mapping, cols, size, self._F)
features = MinkowskiSPMMAverageFunction().apply(
inverse_mapping, cols, size, self._F
)
elif quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE:
features = self._F[unique_index]
elif quantization_mode == SparseTensorQuantizationMode.MAX_POOL:
Expand All @@ -370,21 +373,26 @@ def sparse(
)

def splat(self):
splat_coordinates = create_splat_coordinates(self._C)
(coordinate_map_key, (unique_index, _)) = self._manager.insert_and_map(
splat_coordinates
)
r"""
For slice, use Y.slice(X) where X is the tensor field and Y is the
resulting sparse tensor.
"""
splat_coordinates = create_splat_coordinates(self.C)
(coordinate_map_key, _) = self._manager.insert_and_map(splat_coordinates)
N_rows = self._manager.size(coordinate_map_key)

tensor_map, field_map, weights = self._manager.interpolation_map_weight(
coordinate_map_key, self._C
)
# features
spmm = MinkowskiSPMMFunction()
N = len(self._F)
assert weights.dtype == self._F.dtype
size = torch.Size([N_rows, N])
features = spmm.apply(tensor_map, field_map, weights, size, self._F)
# Save the results for slice
self._splat[coordinate_map_key] = (tensor_map, field_map, weights, size)
features = MinkowskiSPMMFunction().apply(
tensor_map, field_map, weights, size, self._F
)
return SparseTensor(
features,
coordinate_map_key=coordinate_map_key,
Expand All @@ -400,27 +408,31 @@ def inverse_mapping(self, sparse_tensor_map_key: CoordinateMapKey):
self.coordinate_field_map_key
)
one_key = None
for key in sparse_keys:
if np.prod(key.get_tensor_stride()) == 1:
one_key = key

if one_key is not None:
if one_key not in self._inverse_mapping:
(
_,
self._inverse_mapping[one_key],
) = self._manager.get_field_to_sparse_map(
self.coordinate_field_map_key, one_key
)
_, stride_map = self.coordinate_manager.stride_map(
one_key, sparse_tensor_map_key
)
field_map = self._inverse_mapping[one_key]
self._inverse_mapping[sparse_tensor_map_key] = stride_map[field_map]
if len(sparse_keys) > 0:
for key in sparse_keys:
if np.prod(key.get_tensor_stride()) == 1:
one_key = key
else:
raise ValueError(
f"The field to sparse tensor mapping does not exists for the key: {sparse_tensor_map_key}. Please run TensorField.sparse() before you call slice."
one_key = CoordinateMapKey(
[
1,
]
* self.D,
"",
)

if one_key not in self._inverse_mapping:
(
_,
self._inverse_mapping[one_key],
) = self._manager.get_field_to_sparse_map(
self.coordinate_field_map_key, one_key
)
_, stride_map = self.coordinate_manager.stride_map(
one_key, sparse_tensor_map_key
)
field_map = self._inverse_mapping[one_key]
self._inverse_mapping[sparse_tensor_map_key] = stride_map[field_map]
else:
# Extract the mapping
(
Expand Down Expand Up @@ -484,4 +496,5 @@ def __repr__(self):
"quantization_mode",
"_inverse_mapping",
"_batch_rows",
"_splat",
)
128 changes: 113 additions & 15 deletions examples/classification_modelnet40.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
parser.add_argument(
"--network",
type=str,
choices=["pointnet", "minkpointnet", "minkfcnn"],
choices=["pointnet", "minkpointnet", "minkfcnn", "minksplatfcnn"],
default="minkfcnn",
)

Expand Down Expand Up @@ -107,22 +107,40 @@ def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
)

def network_initialization(
self, in_channel, out_channel, channels, embedding_channel, kernel_size, D=3,
self,
in_channel,
out_channel,
channels,
embedding_channel,
kernel_size,
D=3,
):
self.mlp1 = self.get_mlp_block(in_channel, channels[0])
self.conv1 = self.get_conv_block(
channels[0], channels[1], kernel_size=kernel_size, stride=1,
channels[0],
channels[1],
kernel_size=kernel_size,
stride=1,
)
self.conv2 = self.get_conv_block(
channels[1], channels[2], kernel_size=kernel_size, stride=2,
channels[1],
channels[2],
kernel_size=kernel_size,
stride=2,
)

self.conv3 = self.get_conv_block(
channels[2], channels[3], kernel_size=kernel_size, stride=2,
channels[2],
channels[3],
kernel_size=kernel_size,
stride=2,
)

self.conv4 = self.get_conv_block(
channels[3], channels[4], kernel_size=kernel_size, stride=2,
channels[3],
channels[4],
kernel_size=kernel_size,
stride=2,
)
self.conv5 = nn.Sequential(
self.get_conv_block(
Expand All @@ -132,10 +150,16 @@ def network_initialization(
stride=2,
),
self.get_conv_block(
embedding_channel // 4, embedding_channel // 2, kernel_size=3, stride=2,
embedding_channel // 4,
embedding_channel // 2,
kernel_size=3,
stride=2,
),
self.get_conv_block(
embedding_channel // 2, embedding_channel, kernel_size=3, stride=2,
embedding_channel // 2,
embedding_channel,
kernel_size=3,
stride=2,
),
)

Expand All @@ -152,6 +176,7 @@ def network_initialization(
)

# No, Dropout, last 256 linear, AVG_POOLING 92%

def weight_initialization(self):
for m in self.modules():
if isinstance(m, ME.MinkowskiConvolution):
Expand Down Expand Up @@ -191,16 +216,76 @@ def forward(self, x: ME.TensorField):
return self.final(ME.cat(x1, x2)).F


class GlobalMaxAvgPool(torch.nn.Module):
    """Apply global max pooling and global average pooling to a sparse
    tensor and concatenate the two pooled results feature-wise."""

    def __init__(self):
        super().__init__()
        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()

    def forward(self, tensor):
        # Pool the same input two ways, then concatenate along features.
        pooled_max = self.global_max_pool(tensor)
        pooled_avg = self.global_avg_pool(tensor)
        return ME.cat(pooled_max, pooled_avg)


class MinkowskiSplatFCNN(MinkowskiFCNN):
    """MinkowskiFCNN variant that splats point features onto a sparse voxel
    grid, runs the convolutional trunk on the grid, and interpolates the
    multi-scale features back to the original field points before the
    classifier head.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        embedding_channel=1024,
        channels=(32, 48, 64, 96, 128),
        D=3,
    ):
        # NOTE(review): arguments are forwarded positionally; assumes
        # MinkowskiFCNN.__init__ accepts (in_channel, out_channel,
        # embedding_channel, channels, D) in exactly this order — confirm
        # against the parent class definition.
        MinkowskiFCNN.__init__(
            self, in_channel, out_channel, embedding_channel, channels, D
        )

    def forward(self, x: ME.TensorField):
        x = self.mlp1(x)
        # Splat per-point features onto a sparse tensor (voxel grid).
        y = x.splat()

        # Convolution/pooling pyramid; conv1..conv4 and self.pool are
        # inherited from MinkowskiFCNN (defined outside this view).
        y = self.conv1(y)
        y1 = self.pool(y)

        y = self.conv2(y1)
        y2 = self.pool(y)

        y = self.conv3(y2)
        y3 = self.pool(y)

        y = self.conv4(y3)
        y4 = self.pool(y)

        # Interpolate each pyramid level back to the field coordinates,
        # then concatenate the multi-scale per-point features.
        x1 = y1.interpolate(x)
        x2 = y2.interpolate(x)
        x3 = y3.interpolate(x)
        x4 = y4.interpolate(x)

        x = ME.cat(x1, x2, x3, x4)
        y = self.conv5(x.sparse())

        # Global max + average pooling, concatenated, then the final head.
        x1 = self.global_max_pool(y)
        x2 = self.global_avg_pool(y)

        return self.final(ME.cat(x1, x2)).F


STR2NETWORK = dict(
pointnet=PointNet, minkpointnet=MinkowskiPointNet, minkfcnn=MinkowskiFCNN
pointnet=PointNet,
minkpointnet=MinkowskiPointNet,
minkfcnn=MinkowskiFCNN,
minksplatfcnn=MinkowskiSplatFCNN,
)


def create_input_batch(batch, is_minknet, device="cuda", quantization_size=0.05):
if is_minknet:
batch["coordinates"][:, 1:] = batch["coordinates"][:, 1:] / quantization_size
return ME.TensorField(
coordinates=batch["coordinates"], features=batch["features"], device=device,
coordinates=batch["coordinates"],
features=batch["features"],
device=device,
)
else:
return batch["coordinates"].permute(0, 2, 1).to(device)
Expand Down Expand Up @@ -237,14 +322,21 @@ def make_data_loader(phase, is_minknet, config):

def test(net, device, config, phase="val"):
is_minknet = isinstance(net, ME.MinkowskiNetwork)
data_loader = make_data_loader("test", is_minknet, config=config,)
data_loader = make_data_loader(
"test",
is_minknet,
config=config,
)

net.eval()
labels, preds = [], []
with torch.no_grad():
for batch in data_loader:
input = create_input_batch(
batch, is_minknet, device=device, quantization_size=config.voxel_size,
batch,
is_minknet,
device=device,
quantization_size=config.voxel_size,
)
logit = net(input)
pred = torch.argmax(logit, 1)
Expand All @@ -255,7 +347,7 @@ def test(net, device, config, phase="val"):


def criterion(pred, labels, smoothing=True):
""" Calculate cross entropy loss, apply label smoothing if needed. """
"""Calculate cross entropy loss, apply label smoothing if needed."""

labels = labels.contiguous().view(-1)
if smoothing:
Expand All @@ -276,9 +368,15 @@ def criterion(pred, labels, smoothing=True):
def train(net, device, config):
is_minknet = isinstance(net, ME.MinkowskiNetwork)
optimizer = optim.SGD(
net.parameters(), lr=config.lr, momentum=0.9, weight_decay=config.weight_decay,
net.parameters(),
lr=config.lr,
momentum=0.9,
weight_decay=config.weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=config.max_steps,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.max_steps,)
print(optimizer)
print(scheduler)

Expand Down
Loading

0 comments on commit 432ce88

Please sign in to comment.