diff --git a/CHANGELOG.md b/CHANGELOG.md index 49f6e373..0bd5d7b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,8 @@ - Fix `TensorField.sparse()` for no duplicate coordinates - Skip unnecessary spmm if `SparseTensor.initialize_coordinates()` has no duplicate coordinates - Model summary utility function added -- TensorField.splat function for interpolated features +- TensorField.splat function for splat features to a sparse tensor +- SparseTensor.interpolate function for extracting interpolated features ## [0.5.3] diff --git a/MinkowskiEngine/MinkowskiSparseTensor.py b/MinkowskiEngine/MinkowskiSparseTensor.py index e7e91263..f80fb222 100644 --- a/MinkowskiEngine/MinkowskiSparseTensor.py +++ b/MinkowskiEngine/MinkowskiSparseTensor.py @@ -536,6 +536,24 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True): tensor_stride = torch.IntTensor(self.tensor_stride) return dense_F, min_coordinate, tensor_stride + def interpolate(self, X): + from MinkowskiTensorField import TensorField + + assert isinstance(X, TensorField) + if self.coordinate_map_key in X._splat: + tensor_map, field_map, weights, size = X._splat[self.coordinate_map_key] + size = torch.Size([size[1], size[0]]) # transpose + features = MinkowskiSPMMFunction().apply( + field_map, tensor_map, weights, size, self._F + ) + else: + features = self.features_at_coordinates(X.C) + return TensorField( + features=features, + coordinate_field_map_key=X.coordinate_field_map_key, + coordinate_manager=X.coordinate_manager, + ) + def slice(self, X): r""" diff --git a/MinkowskiEngine/MinkowskiTensorField.py b/MinkowskiEngine/MinkowskiTensorField.py index 73ddb068..4ef6886e 100644 --- a/MinkowskiEngine/MinkowskiTensorField.py +++ b/MinkowskiEngine/MinkowskiTensorField.py @@ -66,7 +66,7 @@ def create_splat_coordinates(coordinates: torch.Tensor) -> torch.Tensor: offset[d] = 1 new_offset.append(offset) region_offset.extend(new_offset) - region_offset = torch.IntTensor(region_offset) + region_offset = torch.IntTensor(region_offset).to(coordinates.device) coordinates = torch.floor(coordinates).int().unsqueeze(1) + region_offset.unsqueeze( 0 ) @@ -244,6 +244,7 @@ def __init__( self.coordinate_field_map_key = coordinate_field_map_key self._batch_rows = None self._inverse_mapping = {} + self._splat = {} @property def C(self): @@ -325,7 +326,6 @@ def sparse( # Create features if quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_SUM: - spmm = MinkowskiSPMMFunction() N = len(self._F) cols = torch.arange( N, @@ -334,9 +334,10 @@ def sparse( ) vals = torch.ones(N, dtype=self._F.dtype, device=self._F.device) size = torch.Size([N_rows, len(inverse_mapping)]) - features = spmm.apply(inverse_mapping, cols, vals, size, self._F) + features = MinkowskiSPMMFunction().apply( + inverse_mapping, cols, vals, size, self._F + ) elif quantization_mode == SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: - spmm_avg = MinkowskiSPMMAverageFunction() N = len(self._F) cols = torch.arange( N, @@ -344,7 +345,9 @@ def sparse( device=inverse_mapping.device, ) size = torch.Size([N_rows, len(inverse_mapping)]) - features = spmm_avg.apply(inverse_mapping, cols, size, self._F) + features = MinkowskiSPMMAverageFunction().apply( + inverse_mapping, cols, size, self._F + ) elif quantization_mode == SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: features = self._F[unique_index] elif quantization_mode == SparseTensorQuantizationMode.MAX_POOL: @@ -370,21 +373,26 @@ def sparse( ) def splat(self): - splat_coordinates = create_splat_coordinates(self._C) - (coordinate_map_key, (unique_index, _)) = self._manager.insert_and_map( - splat_coordinates - ) + r""" + For slice, use Y.slice(X) where X is the tensor field and Y is the + resulting sparse tensor. + """ + splat_coordinates = create_splat_coordinates(self.C) + (coordinate_map_key, _) = self._manager.insert_and_map(splat_coordinates) N_rows = self._manager.size(coordinate_map_key) tensor_map, field_map, weights = self._manager.interpolation_map_weight( coordinate_map_key, self._C ) # features - spmm = MinkowskiSPMMFunction() N = len(self._F) assert weights.dtype == self._F.dtype size = torch.Size([N_rows, N]) - features = spmm.apply(tensor_map, field_map, weights, size, self._F) + # Save the results for slice + self._splat[coordinate_map_key] = (tensor_map, field_map, weights, size) + features = MinkowskiSPMMFunction().apply( + tensor_map, field_map, weights, size, self._F + ) return SparseTensor( features, coordinate_map_key=coordinate_map_key, @@ -400,27 +408,31 @@ def inverse_mapping(self, sparse_tensor_map_key: CoordinateMapKey): self.coordinate_field_map_key ) one_key = None - for key in sparse_keys: - if np.prod(key.get_tensor_stride()) == 1: - one_key = key - - if one_key is not None: - if one_key not in self._inverse_mapping: - ( - _, - self._inverse_mapping[one_key], - ) = self._manager.get_field_to_sparse_map( - self.coordinate_field_map_key, one_key - ) - _, stride_map = self.coordinate_manager.stride_map( - one_key, sparse_tensor_map_key - ) - field_map = self._inverse_mapping[one_key] - self._inverse_mapping[sparse_tensor_map_key] = stride_map[field_map] + if len(sparse_keys) > 0: + for key in sparse_keys: + if np.prod(key.get_tensor_stride()) == 1: + one_key = key else: - raise ValueError( - f"The field to sparse tensor mapping does not exists for the key: {sparse_tensor_map_key}. Please run TensorField.sparse() before you call slice." + one_key = CoordinateMapKey( + [ + 1, + ] + * self.D, + "", ) + + if one_key not in self._inverse_mapping: + ( + _, + self._inverse_mapping[one_key], + ) = self._manager.get_field_to_sparse_map( + self.coordinate_field_map_key, one_key + ) + _, stride_map = self.coordinate_manager.stride_map( + one_key, sparse_tensor_map_key + ) + field_map = self._inverse_mapping[one_key] + self._inverse_mapping[sparse_tensor_map_key] = stride_map[field_map] else: # Extract the mapping ( @@ -484,4 +496,5 @@ def __repr__(self): "quantization_mode", "_inverse_mapping", "_batch_rows", + "_splat", ) diff --git a/examples/classification_modelnet40.py b/examples/classification_modelnet40.py index 78c298e1..f7013a92 100644 --- a/examples/classification_modelnet40.py +++ b/examples/classification_modelnet40.py @@ -60,7 +60,7 @@ parser.add_argument( "--network", type=str, - choices=["pointnet", "minkpointnet", "minkfcnn"], + choices=["pointnet", "minkpointnet", "minkfcnn", "minksplatfcnn"], default="minkfcnn", ) @@ -107,22 +107,40 @@ def get_conv_block(self, in_channel, out_channel, kernel_size, stride): ) def network_initialization( - self, in_channel, out_channel, channels, embedding_channel, kernel_size, D=3, + self, + in_channel, + out_channel, + channels, + embedding_channel, + kernel_size, + D=3, ): self.mlp1 = self.get_mlp_block(in_channel, channels[0]) self.conv1 = self.get_conv_block( - channels[0], channels[1], kernel_size=kernel_size, stride=1, + channels[0], + channels[1], + kernel_size=kernel_size, + stride=1, ) self.conv2 = self.get_conv_block( - channels[1], channels[2], kernel_size=kernel_size, stride=2, + channels[1], + channels[2], + kernel_size=kernel_size, + stride=2, ) self.conv3 = self.get_conv_block( - channels[2], channels[3], kernel_size=kernel_size, stride=2, + channels[2], + channels[3], + kernel_size=kernel_size, + stride=2, ) self.conv4 = self.get_conv_block( - channels[3], channels[4], kernel_size=kernel_size, stride=2, + channels[3], + channels[4], + kernel_size=kernel_size, + stride=2, ) self.conv5 = nn.Sequential( self.get_conv_block( @@ -132,10 +150,16 @@ def network_initialization( stride=2, ), self.get_conv_block( - embedding_channel // 4, embedding_channel // 2, kernel_size=3, stride=2, + embedding_channel // 4, + embedding_channel // 2, + kernel_size=3, + stride=2, ), self.get_conv_block( - embedding_channel // 2, embedding_channel, kernel_size=3, stride=2, + embedding_channel // 2, + embedding_channel, + kernel_size=3, + stride=2, ), ) @@ -152,6 +176,7 @@ def network_initialization( ) # No, Dropout, last 256 linear, AVG_POOLING 92% + def weight_initialization(self): for m in self.modules(): if isinstance(m, ME.MinkowskiConvolution): @@ -191,8 +216,66 @@ def forward(self, x: ME.TensorField): return self.final(ME.cat(x1, x2)).F +class GlobalMaxAvgPool(torch.nn.Module): + def __init__(self): + torch.nn.Module.__init__(self) + self.global_max_pool = ME.MinkowskiGlobalMaxPooling() + self.global_avg_pool = ME.MinkowskiGlobalAvgPooling() + + def forward(self, tensor): + x = self.global_max_pool(tensor) + y = self.global_avg_pool(tensor) + return ME.cat(x, y) + + +class MinkowskiSplatFCNN(MinkowskiFCNN): + def __init__( + self, + in_channel, + out_channel, + embedding_channel=1024, + channels=(32, 48, 64, 96, 128), + D=3, + ): + MinkowskiFCNN.__init__( + self, in_channel, out_channel, embedding_channel, channels, D + ) + + def forward(self, x: ME.TensorField): + x = self.mlp1(x) + y = x.splat() + + y = self.conv1(y) + y1 = self.pool(y) + + y = self.conv2(y1) + y2 = self.pool(y) + + y = self.conv3(y2) + y3 = self.pool(y) + + y = self.conv4(y3) + y4 = self.pool(y) + + x1 = y1.interpolate(x) + x2 = y2.interpolate(x) + x3 = y3.interpolate(x) + x4 = y4.interpolate(x) + + x = ME.cat(x1, x2, x3, x4) + y = self.conv5(x.sparse()) + + x1 = self.global_max_pool(y) + x2 = self.global_avg_pool(y) + + return self.final(ME.cat(x1, x2)).F + + STR2NETWORK = dict( - pointnet=PointNet, minkpointnet=MinkowskiPointNet, minkfcnn=MinkowskiFCNN + pointnet=PointNet, + minkpointnet=MinkowskiPointNet, + minkfcnn=MinkowskiFCNN, + minksplatfcnn=MinkowskiSplatFCNN, ) @@ -200,7 +283,9 @@ def create_input_batch(batch, is_minknet, device="cuda", quantization_size=0.05) if is_minknet: batch["coordinates"][:, 1:] = batch["coordinates"][:, 1:] / quantization_size return ME.TensorField( - coordinates=batch["coordinates"], features=batch["features"], device=device, + coordinates=batch["coordinates"], + features=batch["features"], + device=device, ) else: return batch["coordinates"].permute(0, 2, 1).to(device) @@ -237,14 +322,21 @@ def make_data_loader(phase, is_minknet, config): def test(net, device, config, phase="val"): is_minknet = isinstance(net, ME.MinkowskiNetwork) - data_loader = make_data_loader("test", is_minknet, config=config,) + data_loader = make_data_loader( + "test", + is_minknet, + config=config, + ) net.eval() labels, preds = [], [] with torch.no_grad(): for batch in data_loader: input = create_input_batch( - batch, is_minknet, device=device, quantization_size=config.voxel_size, + batch, + is_minknet, + device=device, + quantization_size=config.voxel_size, ) logit = net(input) pred = torch.argmax(logit, 1) @@ -255,7 +347,7 @@ def test(net, device, config, phase="val"): def criterion(pred, labels, smoothing=True): - """ Calculate cross entropy loss, apply label smoothing if needed. """ + """Calculate cross entropy loss, apply label smoothing if needed.""" labels = labels.contiguous().view(-1) if smoothing: @@ -276,9 +368,15 @@ def criterion(pred, labels, smoothing=True): def train(net, device, config): is_minknet = isinstance(net, ME.MinkowskiNetwork) optimizer = optim.SGD( - net.parameters(), lr=config.lr, momentum=0.9, weight_decay=config.weight_decay, + net.parameters(), + lr=config.lr, + momentum=0.9, + weight_decay=config.weight_decay, + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=config.max_steps, ) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.max_steps,) print(optimizer) print(scheduler) diff --git a/tests/python/tensor_field.py b/tests/python/tensor_field.py index 5ed9c0cb..26daa873 100644 --- a/tests/python/tensor_field.py +++ b/tests/python/tensor_field.py @@ -259,6 +259,7 @@ def test_small(self): tensor = tfield.splat() print(tfield) print(tensor) + print(tensor.interpolate(tfield)) def test_small2(self): coords = torch.FloatTensor([[0, 0.1, 0.1], [0, 1.1, 1.1]]) @@ -266,4 +267,5 @@ def test_small2(self): tfield = TensorField(coordinates=coords, features=feats) tensor = tfield.splat() print(tfield) - print(tensor) \ No newline at end of file + print(tensor) + print(tensor.interpolate(tfield)) \ No newline at end of file