Skip to content

Commit

Permalink
Add DynamicPillarVFE (open-mmlab#754)
Browse files Browse the repository at this point in the history
* add dynamic pillar vfe

* make the upperbound unaccessible

* add place holder for voxel generation

* add DynPillarVFE

* add PFNLayerV2

* add try except for torch_scatter package

* add dynamic pillar in readme

* add the cfg file of centerpoint with dynamic pillar
  • Loading branch information
djiajunustc authored Jan 15, 2022
1 parent 7f977ea commit 0f4d3f1
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 5 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18


## Changelog
[2022-01-14] Added support for dynamic pillar voxelization, following the implementation proposed in `H^23D R-CNN` with unique operation and [`torch_scatter`](https://github.com/rusty1s/pytorch_scatter) package.

[2022-01-05] **NEW:** Update `OpenPCDet` to v0.5.2:
* The code of [PV-RCNN++](https://arxiv.org/abs/2102.00463) has been released to this repo, with higher performance, faster training/inference speed and less memory consumption than PV-RCNN.
* Add performance of several models trained with full training set of [Waymo Open Dataset](#waymo-open-dataset-baselines).
Expand All @@ -39,9 +41,9 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
* Support config [`USE_SHARED_MEMORY`](tools/cfgs/dataset_configs/waymo_dataset.yaml) to use shared memory to potentially speed up the training process in case you suffer from an IO problem.
* Support better and faster [visualization script](tools/visual_utils/open3d_vis_utils.py), and you need to install [Open3D](https://github.com/isl-org/Open3D) firstly.

[2021-06-08] Added support for the voxel-based 3D object detection model [`Voxel R-CNN`](#KITTI-3D-Object-Detection-Baselines)
[2021-06-08] Added support for the voxel-based 3D object detection model [`Voxel R-CNN`](#KITTI-3D-Object-Detection-Baselines).

[2021-05-14] Added support for the monocular 3D object detection model [`CaDDN`](#KITTI-3D-Object-Detection-Baselines)
[2021-05-14] Added support for the monocular 3D object detection model [`CaDDN`](#KITTI-3D-Object-Detection-Baselines).

[2020-11-27] Bugfixed: Please re-prepare the validation infos of Waymo dataset (version 1.2) if you would like to
use our provided Waymo evaluation tool (see [PR](https://github.com/open-mmlab/OpenPCDet/pull/383)).
Expand Down Expand Up @@ -144,6 +146,7 @@ By default, all models are trained with **a single frame** of **20% data (~32k f
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 70.96/70.34|62.58/62.02|65.23/54.24 |57.22/47.49| 57.13/55.62 | 54.97/53.53 |
| [PointPillar](tools/cfgs/waymo_models/pointpillar_1x.yaml) | 70.43/69.83 | 62.18/61.64 | 66.21/46.32|58.18/40.64|55.26/51.75|53.18/49.80 |
[CenterPoint-Pillar](tools/cfgs/waymo_models/centerpoint_pillar_1x.yaml)| 70.50/69.96|62.18/61.69|73.11/61.97|65.06/55.00|65.44/63.85|62.98/61.46|
[CenterPoint-Dynamic-Pillar](tools/cfgs/waymo_models/centerpoint_dyn_pillar_1x.yaml)| 70.46/69.93|62.06/61.58|73.92/63.35|65.91/56.33|66.24/64.69|63.73/62.24|
[CenterPoint](tools/cfgs/waymo_models/centerpoint_without_resnet.yaml)| 71.33/70.76|63.16/62.65| 72.09/65.49 |64.27/58.23| 68.68/67.39 |66.11/64.87|
| [CenterPoint (ResNet)](tools/cfgs/waymo_models/centerpoint.yaml)|72.76/72.23|64.91/64.42 |74.19/67.96 |66.03/60.34| 71.04/69.79 |68.49/67.28 |
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 74.66/74.12 |65.82/65.32 |71.71/62.24 |62.46/54.06 |66.53/65.18 |64.05/62.75 |
Expand Down
10 changes: 10 additions & 0 deletions pcdet/datasets/processor/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ def shuffle_points(self, data_dict=None, config=None):

return data_dict

def transform_points_to_voxels_placeholder(self, data_dict=None, config=None):
# just calculate grid size
if data_dict is None:
grid_size = (self.point_cloud_range[3:6] - self.point_cloud_range[0:3]) / np.array(config.VOXEL_SIZE)
self.grid_size = np.round(grid_size).astype(np.int64)
self.voxel_size = config.VOXEL_SIZE
return partial(self.transform_points_to_voxels_placeholder, config=config)

return data_dict

def transform_points_to_voxels(self, data_dict=None, config=None):
if data_dict is None:
grid_size = (self.point_cloud_range[3:6] - self.point_cloud_range[0:3]) / np.array(config.VOXEL_SIZE)
Expand Down
4 changes: 3 additions & 1 deletion pcdet/models/backbones_3d/vfe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .mean_vfe import MeanVFE
from .pillar_vfe import PillarVFE
from .dynamic_pillar_vfe import DynamicPillarVFE
from .image_vfe import ImageVFE
from .vfe_template import VFETemplate

__all__ = {
'VFETemplate': VFETemplate,
'MeanVFE': MeanVFE,
'PillarVFE': PillarVFE,
'ImageVFE': ImageVFE
'ImageVFE': ImageVFE,
'DynPillarVFE': DynamicPillarVFE
}
144 changes: 144 additions & 0 deletions pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
import torch_scatter
except Exception as e:
# Incase someone doesn't want to use dynamic pillar vfe and hasn't installed torch_scatter
pass

from .vfe_template import VFETemplate


class PFNLayerV2(nn.Module):
def __init__(self,
in_channels,
out_channels,
use_norm=True,
last_layer=False):
super().__init__()

self.last_vfe = last_layer
self.use_norm = use_norm
if not self.last_vfe:
out_channels = out_channels // 2

if self.use_norm:
self.linear = nn.Linear(in_channels, out_channels, bias=False)
self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)
else:
self.linear = nn.Linear(in_channels, out_channels, bias=True)

self.relu = nn.ReLU()

def forward(self, inputs, unq_inv):

x = self.linear(inputs)
x = self.norm(x) if self.use_norm else x
x = self.relu(x)
x_max = torch_scatter.scatter_max(x, unq_inv, dim=0)[0]

if self.last_vfe:
return x_max
else:
x_concatenated = torch.cat([x, x_max[unq_inv, :]], dim=1)
return x_concatenated


class DynamicPillarVFE(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
super().__init__(model_cfg=model_cfg)

self.use_norm = self.model_cfg.USE_NORM
self.with_distance = self.model_cfg.WITH_DISTANCE
self.use_absolute_xyz = self.model_cfg.USE_ABSLOTE_XYZ
num_point_features += 6 if self.use_absolute_xyz else 3
if self.with_distance:
num_point_features += 1

self.num_filters = self.model_cfg.NUM_FILTERS
assert len(self.num_filters) > 0
num_filters = [num_point_features] + list(self.num_filters)

pfn_layers = []
for i in range(len(num_filters) - 1):
in_filters = num_filters[i]
out_filters = num_filters[i + 1]
pfn_layers.append(
PFNLayerV2(in_filters, out_filters, self.use_norm, last_layer=(i >= len(num_filters) - 2))
)
self.pfn_layers = nn.ModuleList(pfn_layers)

self.voxel_x = voxel_size[0]
self.voxel_y = voxel_size[1]
self.voxel_z = voxel_size[2]
self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

self.point_cloud_range = point_cloud_range
self.scale_xy = grid_size[0] * grid_size[1]
self.scale_y = grid_size[1]


def get_output_feature_dim(self):
return self.num_filters[-1]

def forward(self, batch_dict, **kwargs):
batch_size = batch_dict['batch_size']
points = batch_dict['points'] # (batch_idx, x, y, z, i, e)

points_xyz = points[:, [1,2,3]].contiguous()
# points_coords = (points_xyz[:, [0,1]] - self.point_cloud_range[[0,1]]) / self.voxel_size[[0,1]]
points_coords_x = (points_xyz[:, 0] - self.point_cloud_range[0]) / self.voxel_x
points_coords_y = (points_xyz[:, 1] - self.point_cloud_range[1]) / self.voxel_y

points_coords_x = points_coords_x.floor()
points_coords_y = points_coords_y.floor()
points_coords = torch.stack([points_coords_x, points_coords_y], dim=-1)

merge_coords = points[:, 0].int() * self.scale_xy + \
points_coords[:, 0] * self.scale_y + \
points_coords[:, 1]

unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0)

points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
f_cluster = points_xyz - points_mean[unq_inv, :]

f_center = torch.zeros_like(points_xyz)
f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset)
f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset)
f_center[:, 2] = points_xyz[:, 2] - self.z_offset

if self.use_absolute_xyz:
features = [points[:, 1:], f_cluster, f_center]
else:
features = [points[:, 4:], f_cluster, f_center]

if self.with_distance:
points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True)
features.append(points_dist)
features = torch.cat(features, dim=-1)

for pfn in self.pfn_layers:
features = pfn(features, unq_inv)
# features = self.linear1(features)
# features_max = torch_scatter.scatter_max(features, unq_inv, dim=0)[0]
# features = torch.cat([features, features_max[unq_inv, :]], dim=1)
# features = self.linear2(features)
# features = torch_scatter.scatter_max(features, unq_inv, dim=0)[0]

# generate voxel coordinates
unq_coords = unq_coords.int()
voxel_coords = torch.stack((unq_coords // self.scale_xy,
(unq_coords % self.scale_xy) // self.scale_y,
unq_coords % self.scale_y,
torch.zeros(unq_coords.shape[0]).to(unq_coords.device).int()
), dim=1)
voxel_coords = voxel_coords[:, [0, 3, 2, 1]]

batch_dict['pillar_features'] = features
batch_dict['voxel_coords'] = voxel_coords
return batch_dict
6 changes: 4 additions & 2 deletions pcdet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def rotate_points_along_z(points, angle):


def mask_points_by_range(points, limit_range):
mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \
& (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4])
# mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \
# & (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4])
mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] < limit_range[3]) \
& (points[:, 1] >= limit_range[1]) & (points[:, 1] < limit_range[4])
return mask


Expand Down
110 changes: 110 additions & 0 deletions tools/cfgs/waymo_models/centerpoint_dyn_pillar_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']

DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml

POINT_CLOUD_RANGE: [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True

- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}

- NAME: transform_points_to_voxels_placeholder
VOXEL_SIZE: [ 0.32, 0.32, 6.0 ]

MODEL:
NAME: CenterPoint

VFE:
NAME: DynPillarVFE
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_NORM: True
NUM_FILTERS: [ 64, 64 ]

MAP_TO_BEV:
NAME: PointPillarScatter
NUM_BEV_FEATURES: 64

BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [ 3, 5, 5 ]
LAYER_STRIDES: [ 1, 2, 2 ]
NUM_FILTERS: [ 64, 128, 256 ]
UPSAMPLE_STRIDES: [ 1, 2, 4 ]
NUM_UPSAMPLE_FILTERS: [ 128, 128, 128 ]

DENSE_HEAD:
NAME: CenterHead
CLASS_AGNOSTIC: False

CLASS_NAMES_EACH_HEAD: [
['Vehicle', 'Pedestrian', 'Cyclist']
]

SHARED_CONV_CHANNEL: 64
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'center_z': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
}

TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 1
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2

LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}

POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [-80, -80, -10.0, 80, 80, 10.0]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.7
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500

POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]

EVAL_METRIC: waymo


OPTIMIZATION:
BATCH_SIZE_PER_GPU: 2
NUM_EPOCHS: 30

OPTIMIZER: adam_onecycle
LR: 0.003
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9

MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001

LR_WARMUP: False
WARMUP_EPOCH: 1

GRAD_NORM_CLIP: 10

0 comments on commit 0f4d3f1

Please sign in to comment.