Skip to content

Commit

Permalink
support multi-classes nms for multi-head, not checked
Browse files Browse the repository at this point in the history
  • Loading branch information
sshaoshuai committed Jul 27, 2020
1 parent 6901df6 commit a236428
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 26 deletions.
18 changes: 5 additions & 13 deletions pcdet/models/dense_heads/anchor_head_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,11 @@ def forward(self, data_dict):
)

if isinstance(batch_cls_preds, list):
all_pred_labels = []
all_cls_preds = []
for idx, cls_pred in enumerate(batch_cls_preds):
pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
pred_label = self.rpn_heads[idx].head_label_indices[pred_head_label]

all_pred_labels.append(pred_label)
all_cls_preds.append(pred_score[:, :, None])

batch_cls_preds = torch.cat(all_cls_preds, dim=1)
batch_pred_labels = torch.cat(all_pred_labels, dim=1)
data_dict['batch_pred_labels'] = batch_pred_labels
data_dict['has_class_labels'] = True
multihead_label_mapping = []
for idx in range(len(batch_cls_preds)):
multihead_label_mapping.append(self.rpn_heads[idx].head_label_indices)

data_dict['multihead_label_mapping'] = multihead_label_mapping

data_dict['batch_cls_preds'] = batch_cls_preds
data_dict['batch_box_preds'] = batch_box_preds
Expand Down
58 changes: 45 additions & 13 deletions pcdet/models/detectors/detector3d_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .. import backbones_3d, backbones_2d, dense_heads, roi_heads
from ..backbones_3d import vfe, pfe
from ..backbones_2d import map_to_bev
from ..model_utils.model_nms_utils import class_agnostic_nms
from ..model_utils import model_nms_utils
from ...ops.iou3d_nms import iou3d_nms_utils


Expand Down Expand Up @@ -169,6 +169,8 @@ def post_processing(self, batch_dict):
batch_dict:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
multihead_label_mapping: [(num_class1), (num_class2), ...]
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
Expand All @@ -184,32 +186,62 @@ def post_processing(self, batch_dict):
pred_dicts = []
for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_cls_preds'].shape.__len__() == 2
assert batch_dict['batch_box_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
assert batch_dict['batch_box_preds'].shape.__len__() == 3
batch_mask = index

box_preds = batch_dict['batch_box_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask]

src_cls_preds = cls_preds
src_box_preds = box_preds
assert cls_preds.shape[1] in [1, self.num_class]

if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
if not isinstance(batch_dict['batch_cls_preds'], list):
cls_preds = batch_dict['batch_cls_preds'][batch_mask]

src_cls_preds = cls_preds
assert cls_preds.shape[1] in [1, self.num_class]

if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
else:
cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']]
src_cls_preds = cls_preds
if not batch_dict['cls_preds_normalized']:
cls_preds = [torch.sigmoid(x) for x in cls_preds]

if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError
if not isinstance(cls_preds, list):
cls_preds = [cls_preds]
multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)]
else:
multihead_label_mapping = batch_dict['multihead_label_mapping']

cur_start_idx = 0
pred_scores, pred_labels, pred_boxes = [], [], []
for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping):
assert cur_cls_preds.shape[1] == len(cur_label_mapping)
cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]]
cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms(
cls_scores=cur_cls_preds, box_preds=cur_box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)
cur_pred_labels = cur_label_mapping[cur_pred_labels]
pred_scores.append(cur_pred_scores)
pred_labels.append(cur_pred_labels)
pred_boxes.append(cur_pred_boxes)

final_scores = torch.cat(pred_scores, dim=0)
final_labels = torch.cat(pred_labels, dim=0)
final_boxes = torch.cat(pred_boxes, dim=0)
else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
if batch_dict.get('has_class_labels', False):
label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
label_preds = batch_dict[label_key][index]
else:
label_preds + 1

selected, selected_scores = class_agnostic_nms(
label_preds = label_preds + 1
selected, selected_scores = model_nms_utils.class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
Expand Down
43 changes: 43 additions & 0 deletions pcdet/models/model_utils/model_nms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,46 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
original_idxs = scores_mask.nonzero().view(-1)
selected = original_idxs[selected]
return selected, src_box_scores[selected]


def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
"""
Args:
cls_scores: (N, num_class)
box_preds: (N, 7 + C)
nms_config:
score_thresh:
Returns:
"""
pred_scores, pred_labels, pred_boxes = [], [], []
for k in range(cls_scores.shape[0]):
if score_thresh is not None:
scores_mask = (cls_scores[:, k] >= score_thresh)
box_scores = cls_scores[scores_mask, k]
box_preds = box_preds[scores_mask]
else:
box_scores = cls_scores[:, k]

selected = []
if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
)
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]

if score_thresh is not None:
selected = scores_mask.nonzero().view(-1)

pred_scores.append(box_scores[selected])
pred_labels.append(box_scores.new_ones(selected.shape[0]) * k)
pred_boxes.append(box_preds[selected])

pred_scores = torch.cat(pred_scores, dim=0)
pred_labels = torch.cat(pred_labels, dim=0)
pred_boxes = torch.cat(pred_boxes, dim=0)

return pred_scores, pred_labels, pred_boxes

0 comments on commit a236428

Please sign in to comment.