Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Jul 13, 2020
2 parents b7553c8 + 7ac5625 commit e5ba23f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
12 changes: 8 additions & 4 deletions pcdet/models/backbones_3d/vfe/pillar_vfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn.functional as F
from .vfe_template import VFETemplate


class PFNLayer(nn.Module):
def __init__(self,
in_channels,
Expand All @@ -28,12 +29,14 @@ def forward(self, inputs):
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part])
for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
x = self.linear(inputs)
torch.backends.cudnn.enabled = False
x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
torch.backends.cudnn.enabled = True
x = F.relu(x)
x_max = torch.max(x, dim=1, keepdim=True)[0]

Expand All @@ -44,6 +47,7 @@ def forward(self, inputs):
x_concatenated = torch.cat([x, x_repeat], dim=2)
return x_concatenated


class PillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, point_cloud_range):
super().__init__(model_cfg=model_cfg)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import numpy as np
from ....utils import box_utils
from ....ops.iou3d_nms import iou3d_nms_utils

Expand All @@ -8,7 +9,7 @@ def __init__(self, anchor_target_cfg, anchor_generator_cfg, class_names, box_cod
super().__init__()
self.box_coder = box_coder
self.match_height = match_height
self.class_names = class_names
self.class_names = np.array(class_names)
self.anchor_class_names = [config['class_name'] for config in anchor_generator_cfg]
self.pos_fraction = anchor_target_cfg.POS_FRACTION if anchor_target_cfg.POS_FRACTION >= 0 else None
self.sample_size = anchor_target_cfg.SAMPLE_SIZE
Expand Down Expand Up @@ -46,7 +47,11 @@ def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False

target_list = []
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool)
if cur_gt_classes.shape[0] > 1:
mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)
else:
mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
for c in cur_gt_classes], dtype=torch.bool)

if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
Expand Down Expand Up @@ -75,15 +80,16 @@ def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False
else:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(*feature_map_size, -1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(*feature_map_size, -1, self.box_coder.code_size)
for t in target_list],
'reg_weights': [t['reg_weights'].view(*feature_map_size, -1) for t in target_list]
}

target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'], dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_reg_targets'] = torch.cat(target_dict['box_reg_targets'],
dim=-2).view(-1, self.box_coder.code_size)
target_dict['box_cls_labels'] = torch.cat(target_dict['box_cls_labels'], dim=-1).view(-1)
target_dict['reg_weights'] = torch.cat(target_dict['reg_weights'], dim=-1).view(-1)


bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])
Expand All @@ -109,25 +115,25 @@ def assign_targets_single(self, anchors,

num_anchors = anchors.shape[0]
num_gt = gt_boxes.shape[0]
box_ndim = anchors.shape[1]
# box_ndim = anchors.shape[1]

labels = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1
gt_ids = torch.ones((num_anchors,), dtype=torch.int32, device=anchors.device) * -1

if len(gt_boxes) > 0 and anchors.shape[0] > 0:
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors, gt_boxes) if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors, gt_boxes)
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors),
anchor_to_gt_argmax]

gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
gt_to_anchor_max = anchor_by_gt_overlap[
gt_to_anchor_argmax,
torch.arange(num_gt)]
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])

anchor_to_gt_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=1)).cuda()
anchor_to_gt_max = anchor_by_gt_overlap[
torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax
]

gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]
empty_gt_mask = gt_to_anchor_max == 0
gt_to_anchor_max[empty_gt_mask] = -1
anchors_with_max_overlap = torch.nonzero(
anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]
anchors_with_max_overlap = torch.nonzero(anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]

gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]
Expand All @@ -139,7 +145,7 @@ def assign_targets_single(self, anchors,
gt_ids[pos_inds] = gt_inds_over_thresh.int()
bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
else:
bg_inds = torch.arange(num_anchors)
bg_inds = torch.arange(num_anchors, device=anchors.device)

fg_inds = torch.nonzero(labels > 0)[:, 0]

Expand All @@ -155,7 +161,7 @@ def assign_targets_single(self, anchors,
if len(bg_inds) > num_bg:
enable_inds = bg_inds[torch.randint(0, len(bg_inds), size=(num_bg,))]
labels[enable_inds] = 0
bg_inds = torch.nonzero(labels == 0)[:, 0]
# bg_inds = torch.nonzero(labels == 0)[:, 0]
else:
if len(gt_boxes) == 0 or anchors.shape[0] == 0:
labels[:] = 0
Expand Down

0 comments on commit e5ba23f

Please sign in to comment.