Skip to content

Commit

Permalink
add contiguous() after transpose() at pointrcnn_head (open-mmlab#373)
Browse files Browse the repository at this point in the history
* bugfixed: ignore empty boxes in visualization

* add contiguous() after transpose() at pointrcnn_head.py
  • Loading branch information
sshaoshuai authored Nov 25, 2020
1 parent 1145f83 commit 53b2b93
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pcdet/models/dense_heads/anchor_head_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def forward(self, spatial_features_2d):

class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True):
predict_boxes_when_training=True, **kwargs):
super().__init__(
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/dense_heads/anchor_head_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class AnchorHeadSingle(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True):
predict_boxes_when_training=True, **kwargs):
super().__init__(
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range,
predict_boxes_when_training=predict_boxes_when_training
Expand Down
2 changes: 1 addition & 1 deletion pcdet/models/roi_heads/pointrcnn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def forward(self, batch_dict):

pooled_features = self.roipool3d_gpu(batch_dict) # (total_rois, num_sampled_points, 3 + C)

xyz_input = pooled_features[..., 0:self.num_prefix_channels].transpose(1, 2).unsqueeze(dim=3)
xyz_input = pooled_features[..., 0:self.num_prefix_channels].transpose(1, 2).unsqueeze(dim=3).contiguous()
xyz_features = self.xyz_up_layer(xyz_input)
point_features = pooled_features[..., self.num_prefix_channels:].transpose(1, 2).unsqueeze(dim=3)
merged_features = torch.cat((xyz_features, point_features), dim=1)
Expand Down

0 comments on commit 53b2b93

Please sign in to comment.