From 13789796f70a033435a8289b3e1361000fc4694c Mon Sep 17 00:00:00 2001 From: Shaoshuai Shi Date: Sun, 26 Dec 2021 16:57:10 +0100 Subject: [PATCH] Support PV-RCNN++ frameworks, support VectorPool aggregation --- .../backbones_3d/pfe/voxel_set_abstraction.py | 35 +- pcdet/models/detectors/__init__.py | 4 +- pcdet/models/roi_heads/pvrcnn_head.py | 23 +- .../pointnet2_stack/pointnet2_modules.py | 333 ++++++++++++ .../pointnet2_stack/pointnet2_utils.py | 149 ++++++ .../pointnet2_stack/src/pointnet2_api.cpp | 7 + .../pointnet2_stack/src/vector_pool.cpp | 203 ++++++++ .../pointnet2_stack/src/vector_pool_gpu.cu | 486 ++++++++++++++++++ .../pointnet2_stack/src/vector_pool_gpu.h | 71 +++ setup.py | 2 + tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml | 277 ++++++++++ 11 files changed, 1553 insertions(+), 37 deletions(-) create mode 100644 pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool.cpp create mode 100644 pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.cu create mode 100644 pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h create mode 100644 tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml diff --git a/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py b/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py index ecf289769..0f3b8ae93 100644 --- a/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py +++ b/pcdet/models/backbones_3d/pfe/voxel_set_abstraction.py @@ -139,38 +139,31 @@ def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=No if src_name in ['bev', 'raw_points']: continue self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR - mlps = SA_cfg[src_name].MLPS - for k in range(len(mlps)): - mlps[k] = [mlps[k][0]] + mlps[k] - cur_layer = pointnet2_stack_modules.StackSAModuleMSG( - radii=SA_cfg[src_name].POOL_RADIUS, - nsamples=SA_cfg[src_name].NSAMPLE, - mlps=mlps, - use_xyz=True, - pool_method='max_pool', + + if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None: + input_channels = SA_cfg[src_name].MLPS[0][0] \ + if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0] + else: + input_channels = SA_cfg[src_name]['INPUT_CHANNELS'] + + cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module( + input_channels=input_channels, config=SA_cfg[src_name] ) self.SA_layers.append(cur_layer) self.SA_layer_names.append(src_name) - c_in += sum([x[-1] for x in mlps]) + c_in += cur_num_c_out if 'bev' in self.model_cfg.FEATURES_SOURCE: c_bev = num_bev_features c_in += c_bev if 'raw_points' in self.model_cfg.FEATURES_SOURCE: - mlps = SA_cfg['raw_points'].MLPS - for k in range(len(mlps)): - mlps[k] = [num_rawpoint_features - 3] + mlps[k] - - self.SA_rawpoints = pointnet2_stack_modules.StackSAModuleMSG( - radii=SA_cfg['raw_points'].POOL_RADIUS, - nsamples=SA_cfg['raw_points'].NSAMPLE, - mlps=mlps, - use_xyz=True, - pool_method='max_pool' + self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module( + input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points'] ) - c_in += sum([x[-1] for x in mlps]) + + c_in += cur_num_c_out self.vsa_point_feature_fusion = nn.Sequential( nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False), diff --git a/pcdet/models/detectors/__init__.py b/pcdet/models/detectors/__init__.py index 8f0f167f8..09b24f35a 100644 --- a/pcdet/models/detectors/__init__.py +++ b/pcdet/models/detectors/__init__.py @@ -8,6 +8,7 @@ from .caddn import CaDDN from .voxel_rcnn import VoxelRCNN from .centerpoint import CenterPoint +from .pv_rcnn_plusplus import PVRCNNPlusPlus __all__ = { 'Detector3DTemplate': Detector3DTemplate, @@ -19,7 +20,8 @@ 'SECONDNetIoU': SECONDNetIoU, 'CaDDN': CaDDN, 'VoxelRCNN': VoxelRCNN, - 'CenterPoint': CenterPoint + 'CenterPoint': CenterPoint, + 'PVRCNNPlusPlus': PVRCNNPlusPlus } diff --git a/pcdet/models/roi_heads/pvrcnn_head.py b/pcdet/models/roi_heads/pvrcnn_head.py index 01bcf15a3..6ec6b9806 100644 --- a/pcdet/models/roi_heads/pvrcnn_head.py +++ b/pcdet/models/roi_heads/pvrcnn_head.py @@ -10,21 +10,12 @@ def __init__(self, input_channels, model_cfg, num_class=1, **kwargs): super().__init__(num_class=num_class, model_cfg=model_cfg) self.model_cfg = model_cfg - mlps = self.model_cfg.ROI_GRID_POOL.MLPS - for k in range(len(mlps)): - mlps[k] = [input_channels] + mlps[k] - - self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG( - radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS, - nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE, - mlps=mlps, - use_xyz=True, - pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD, + self.roi_grid_pool_layer, num_c_out = pointnet2_stack_modules.build_local_aggregation_module( + input_channels=input_channels, config=self.model_cfg.ROI_GRID_POOL ) GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE - c_out = sum([x[-1] for x in mlps]) - pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out + pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * num_c_out shared_fc_list = [] for k in range(0, self.model_cfg.SHARED_FC.__len__()): @@ -150,9 +141,11 @@ def forward(self, batch_dict): batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST'] ) if self.training: - targets_dict = self.assign_targets(batch_dict) - batch_dict['rois'] = targets_dict['rois'] - batch_dict['roi_labels'] = targets_dict['roi_labels'] + targets_dict = batch_dict.get('roi_targets_dict', None) + if targets_dict is None: + targets_dict = self.assign_targets(batch_dict) + batch_dict['rois'] = targets_dict['rois'] + batch_dict['roi_labels'] = targets_dict['roi_labels'] # RoI aware pooling pooled_features = self.roi_grid_pool(batch_dict) # (BxN, 6x6x6, C) diff --git a/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_modules.py b/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_modules.py index 659af35d9..0210ab296 100644 --- a/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_modules.py +++ b/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_modules.py @@ -7,6 +7,26 @@ from . import pointnet2_utils +def build_local_aggregation_module(input_channels, config): + local_aggregation_name = config.get('NAME', 'StackSAModuleMSG') + + if local_aggregation_name == 'StackSAModuleMSG': + mlps = config.MLPS + for k in range(len(mlps)): + mlps[k] = [input_channels] + mlps[k] + cur_layer = StackSAModuleMSG( + radii=config.POOL_RADIUS, nsamples=config.NSAMPLE, mlps=mlps, use_xyz=True, pool_method='max_pool', + ) + num_c_out = sum([x[-1] for x in mlps]) + elif local_aggregation_name == 'VectorPoolAggregationModuleMSG': + cur_layer = VectorPoolAggregationModuleMSG(input_channels=input_channels, config=config) + num_c_out = config.MSG_POST_MLPS[-1] + else: + raise NotImplementedError + + return cur_layer, num_c_out + + class StackSAModuleMSG(nn.Module): def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]], @@ -135,3 +155,316 @@ def forward(self, unknown, unknown_batch_cnt, known, known_batch_cnt, unknown_fe new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C) return new_features + + +class VectorPoolLocalInterpolateModule(nn.Module): + def __init__(self, mlp, num_voxels, max_neighbour_distance, nsample, neighbor_type, use_xyz=True, + neighbour_distance_multiplier=1.0, xyz_encoding_type='concat'): + """ + Args: + mlp: + num_voxels: + max_neighbour_distance: + neighbor_type: 1: ball, others: cube + nsample: find all (-1), find limited number(>0) + use_xyz: + neighbour_distance_multiplier: + xyz_encoding_type: + """ + super().__init__() + self.num_voxels = num_voxels # [num_grid_x, num_grid_y, num_grid_z]: number of grids in each local area centered at new_xyz + self.num_total_grids = self.num_voxels[0] * self.num_voxels[1] * self.num_voxels[2] + self.max_neighbour_distance = max_neighbour_distance + self.neighbor_distance_multiplier = neighbour_distance_multiplier + self.nsample = nsample + self.neighbor_type = neighbor_type + self.use_xyz = use_xyz + self.xyz_encoding_type = xyz_encoding_type + + if mlp is not None: + if self.use_xyz: + mlp[0] += 9 if self.xyz_encoding_type == 'concat' else 0 + shared_mlps = [] + for k in range(len(mlp) - 1): + shared_mlps.extend([ + nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False), + nn.BatchNorm2d(mlp[k + 1]), + nn.ReLU() + ]) + self.mlp = nn.Sequential(*shared_mlps) + else: + self.mlp = None + + self.num_avg_length_of_neighbor_idxs = 1000 + + def forward(self, support_xyz, support_features, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt): + """ + Args: + support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + support_features: (N1 + N2 ..., C) point-wise features + xyz_batch_cnt: (batch_size), [N1, N2, ...] + new_xyz: (M1 + M2 ..., 3) centers of the ball query + new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + Returns: + new_features: (N1 + N2 ..., C_out) + """ + with torch.no_grad(): + dist, idx, num_avg_length_of_neighbor_idxs = pointnet2_utils.three_nn_for_vector_pool_by_two_step( + support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt, + self.max_neighbour_distance, self.nsample, self.neighbor_type, + self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier + ) + self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item()) + + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=-1, keepdim=True) + weight = dist_recip / torch.clamp_min(norm, min=1e-8) + + empty_mask = (idx.view(-1, 3)[:, 0] == -1) + idx.view(-1, 3)[empty_mask] = 0 + + interpolated_feats = pointnet2_utils.three_interpolate(support_features, idx.view(-1, 3), weight.view(-1, 3)) + interpolated_feats = interpolated_feats.view(idx.shape[0], idx.shape[1], -1) # (M1 + M2 ..., num_total_grids, C) + if self.use_xyz: + near_known_xyz = support_xyz[idx.view(-1, 3).long()].view(-1, 3, 3) # ( (M1 + M2 ...)*num_total_grids, 3) + local_xyz = (new_xyz_grid_centers.view(-1, 1, 3) - near_known_xyz).view(-1, idx.shape[1], 9) + if self.xyz_encoding_type == 'concat': + interpolated_feats = torch.cat((interpolated_feats, local_xyz), dim=-1) # ( M1 + M2 ..., num_total_grids, 9+C) + else: + raise NotImplementedError + + new_features = interpolated_feats.view(-1, interpolated_feats.shape[-1]) # ((M1 + M2 ...) * num_total_grids, C) + new_features[empty_mask, :] = 0 + if self.mlp is not None: + new_features = new_features.permute(1, 0)[None, :, :, None] # (1, C, N1 + N2 ..., 1) + new_features = self.mlp(new_features) + + new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C) + return new_features + + +class VectorPoolAggregationModule(nn.Module): + def __init__( + self, input_channels, num_local_voxel=(3, 3, 3), local_aggregation_type='local_interpolation', + num_reduced_channels=30, num_channels_of_local_aggregation=32, post_mlps=(128,), + max_neighbor_distance=None, neighbor_nsample=-1, neighbor_type=0, neighbor_distance_multiplier=2.0): + super().__init__() + self.num_local_voxel = num_local_voxel + self.total_voxels = self.num_local_voxel[0] * self.num_local_voxel[1] * self.num_local_voxel[2] + self.local_aggregation_type = local_aggregation_type + assert self.local_aggregation_type in ['local_interpolation', 'voxel_avg_pool', 'voxel_random_choice'] + self.input_channels = input_channels + self.num_reduced_channels = input_channels if num_reduced_channels is None else num_reduced_channels + self.num_channels_of_local_aggregation = num_channels_of_local_aggregation + self.max_neighbour_distance = max_neighbor_distance + self.neighbor_nsample = neighbor_nsample + self.neighbor_type = neighbor_type # 1: ball, others: cube + + if self.local_aggregation_type == 'local_interpolation': + self.local_interpolate_module = VectorPoolLocalInterpolateModule( + mlp=None, num_voxels=self.num_local_voxel, + max_neighbour_distance=self.max_neighbour_distance, + nsample=self.neighbor_nsample, + neighbor_type=self.neighbor_type, + neighbour_distance_multiplier=neighbor_distance_multiplier, + ) + num_c_in = (self.num_reduced_channels + 9) * self.total_voxels + else: + self.local_interpolate_module = None + num_c_in = (self.num_reduced_channels + 3) * self.total_voxels + + num_c_out = self.total_voxels * self.num_channels_of_local_aggregation + + self.separate_local_aggregation_layer = nn.Sequential( + nn.Conv1d(num_c_in, num_c_out, kernel_size=1, groups=self.total_voxels, bias=False), + nn.BatchNorm1d(num_c_out), + nn.ReLU() + ) + + post_mlp_list = [] + c_in = num_c_out + for cur_num_c in post_mlps: + post_mlp_list.extend([ + nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False), + nn.BatchNorm1d(cur_num_c), + nn.ReLU() + ]) + c_in = cur_num_c + self.post_mlps = nn.Sequential(*post_mlp_list) + + self.num_mean_points_per_grid = 20 + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + + def extra_repr(self) -> str: + ret = f'radius={self.max_neighbour_distance}, local_voxels=({self.num_local_voxel}, ' \ + f'local_aggregation_type={self.local_aggregation_type}, ' \ + f'num_c_reduction={self.input_channels}->{self.num_reduced_channels}, ' \ + f'num_c_local_aggregation={self.num_channels_of_local_aggregation}' + return ret + + def vector_pool_with_voxel_query(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt): + use_xyz = 1 + pooling_type = 0 if self.local_aggregation_type == 'voxel_avg_pool' else 1 + + new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid = pointnet2_utils.vector_pool_with_voxel_query_op( + xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt, + self.num_local_voxel[0], self.num_local_voxel[1], self.num_local_voxel[2], + self.max_neighbour_distance, self.num_reduced_channels, use_xyz, + self.num_mean_points_per_grid, self.neighbor_nsample, self.neighbor_type, + pooling_type + ) + self.num_mean_points_per_grid = max(self.num_mean_points_per_grid, num_mean_points_per_grid.item()) + + num_new_pts = new_features.shape[0] + new_local_xyz = new_local_xyz.view(num_new_pts, -1, 3) # (N, num_voxel, 3) + new_features = new_features.view(num_new_pts, -1, self.num_reduced_channels) # (N, num_voxel, C) + new_features = torch.cat((new_local_xyz, new_features), dim=-1).view(num_new_pts, -1) + + return new_features, point_cnt_of_grid + + @staticmethod + def get_dense_voxels_by_center(point_centers, max_neighbour_distance, num_voxels): + """ + Args: + point_centers: (N, 3) + max_neighbour_distance: float + num_voxels: [num_x, num_y, num_z] + + Returns: + voxel_centers: (N, total_voxels, 3) + """ + R = max_neighbour_distance + device = point_centers.device + x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device) + y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device) + z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device) + x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z] + xyz_offset = torch.cat(( + x_offset.contiguous().view(-1, 1), + y_offset.contiguous().view(-1, 1), + z_offset.contiguous().view(-1, 1)), dim=-1 + ) + voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :] + return voxel_centers + + def vector_pool_with_local_interpolate(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt): + """ + Args: + xyz: (N, 3) + xyz_batch_cnt: (batch_size) + features: (N, C) + new_xyz: (M, 3) + new_xyz_batch_cnt: (batch_size) + Returns: + new_features: (M, total_voxels * C) + """ + voxel_centers = self.get_dense_voxels_by_center( + point_centers=new_xyz, max_neighbour_distance=self.max_neighbour_distance, num_voxels=self.num_local_voxel + ) # (M1 + M2 + ..., total_voxels, 3) + voxel_features = self.local_interpolate_module.forward( + support_xyz=xyz, support_features=features, xyz_batch_cnt=xyz_batch_cnt, + new_xyz=new_xyz, new_xyz_grid_centers=voxel_centers, new_xyz_batch_cnt=new_xyz_batch_cnt + ) # ((M1 + M2 ...) * total_voxels, C) + + voxel_features = voxel_features.contiguous().view(-1, self.total_voxels * voxel_features.shape[-1]) + return voxel_features + + def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features, **kwargs): + """ + :param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features + :param xyz_batch_cnt: (batch_size), [N1, N2, ...] + :param new_xyz: (M1 + M2 ..., 3) + :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + :param features: (N1 + N2 ..., C) tensor of the descriptors of the the features + :return: + new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz + new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors + """ + N, C = features.shape + + assert C % self.num_reduced_channels == 0, \ + f'the input channels ({C}) should be an integral multiple of num_reduced_channels({self.num_reduced_channels})' + + features = features.view(N, -1, self.num_reduced_channels).sum(dim=1) + + if self.local_aggregation_type in ['voxel_avg_pool', 'voxel_random_choice']: + vector_features, point_cnt_of_grid = self.vector_pool_with_voxel_query( + xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features, + new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt + ) + elif self.local_aggregation_type == 'local_interpolation': + vector_features = self.vector_pool_with_local_interpolate( + xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features, + new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt + ) # (M1 + M2 + ..., total_voxels * C) + else: + raise NotImplementedError + + vector_features = vector_features.permute(1, 0)[None, :, :] # (1, num_voxels * C, M1 + M2 ...) + + new_features = self.separate_local_aggregation_layer(vector_features) + + new_features = self.post_mlps(new_features) + new_features = new_features.squeeze(dim=0).permute(1, 0) + return new_xyz, new_features + + +class VectorPoolAggregationModuleMSG(nn.Module): + def __init__(self, input_channels, config): + super().__init__() + self.model_cfg = config + self.num_groups = self.model_cfg.NUM_GROUPS + + self.layers = [] + c_in = 0 + for k in range(self.num_groups): + cur_config = self.model_cfg[f'GROUP_CFG_{k}'] + cur_vector_pool_module = VectorPoolAggregationModule( + input_channels=input_channels, num_local_voxel=cur_config.NUM_LOCAL_VOXEL, + post_mlps=cur_config.POST_MLPS, + max_neighbor_distance=cur_config.MAX_NEIGHBOR_DISTANCE, + neighbor_nsample=cur_config.NEIGHBOR_NSAMPLE, + local_aggregation_type=self.model_cfg.LOCAL_AGGREGATION_TYPE, + num_reduced_channels=self.model_cfg.get('NUM_REDUCED_CHANNELS', None), + num_channels_of_local_aggregation=self.model_cfg.NUM_CHANNELS_OF_LOCAL_AGGREGATION, + neighbor_distance_multiplier=2.0 + ) + self.__setattr__(f'layer_{k}', cur_vector_pool_module) + c_in += cur_config.POST_MLPS[-1] + + c_in += 3 # use_xyz + + shared_mlps = [] + for cur_num_c in self.model_cfg.MSG_POST_MLPS: + shared_mlps.extend([ + nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False), + nn.BatchNorm1d(cur_num_c), + nn.ReLU() + ]) + c_in = cur_num_c + self.msg_post_mlps = nn.Sequential(*shared_mlps) + + def forward(self, **kwargs): + features_list = [] + for k in range(self.num_groups): + cur_xyz, cur_features = self.__getattr__(f'layer_{k}')(**kwargs) + features_list.append(cur_features) + + features = torch.cat(features_list, dim=-1) + features = torch.cat((cur_xyz, features), dim=-1) + features = features.permute(1, 0)[None, :, :] # (1, C, N) + new_features = self.msg_post_mlps(features) + new_features = new_features.squeeze(dim=0).permute(1, 0) # (N, C) + + return cur_xyz, new_features diff --git a/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py b/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py index 6aa364ffb..f6f77981d 100644 --- a/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py +++ b/pcdet/ops/pointnet2/pointnet2_stack/pointnet2_utils.py @@ -299,5 +299,154 @@ def backward(ctx, grad_out: torch.Tensor): three_interpolate = ThreeInterpolate.apply +class ThreeNNForVectorPoolByTwoStep(Function): + @staticmethod + def forward(ctx, support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt, + max_neighbour_distance, nsample, neighbor_type, avg_length_of_neighbor_idxs, num_total_grids, + neighbor_distance_multiplier): + """ + Args: + ctx: + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // nsample: find all (-1), find limited number(>0) + // neighbor_type: 1: ball, others: cube + // neighbor_distance_multiplier: query_distance = neighbor_distance_multiplier * max_neighbour_distance + + Returns: + // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn + // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn + """ + num_new_xyz = new_xyz.shape[0] + new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape) + new_xyz_grid_idxs = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape).int().fill_(-1) + + while True: + num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz + stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(num_max_sum_points) + start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2).int() + cumsum = new_xyz_grid_idxs.new_zeros(1) + + pointnet2.query_stacked_local_neighbor_idxs_wrapper_stack( + support_xyz.contiguous(), xyz_batch_cnt.contiguous(), + new_xyz.contiguous(), new_xyz_batch_cnt.contiguous(), + stack_neighbor_idxs.contiguous(), start_len.contiguous(), cumsum, + avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier, + nsample, neighbor_type + ) + avg_length_of_neighbor_idxs = cumsum[0] // num_new_xyz + int(cumsum[0] % num_new_xyz > 0) + + if cumsum[0] <= num_max_sum_points: + break + + stack_neighbor_idxs = stack_neighbor_idxs[:cumsum[0]] + pointnet2.query_three_nn_by_stacked_local_idxs_wrapper_stack( + support_xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs, new_xyz_grid_dist2, + stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids + ) + + return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, avg_length_of_neighbor_idxs + + +three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply + + +class VectorPoolWithVoxelQuery(Function): + @staticmethod + def forward(ctx, support_xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor, support_features: torch.Tensor, + new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor, num_grid_x, num_grid_y, num_grid_z, + max_neighbour_distance, num_c_out_each_grid, use_xyz, + num_mean_points_per_grid=100, nsample=-1, neighbor_type=0, pooling_type=0): + """ + Args: + ctx: + support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + xyz_batch_cnt: (batch_size), [N1, N2, ...] + support_features: (N1 + N2 ..., C) + new_xyz: (M1 + M2 ..., 3) centers of new positions + new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + num_grid_x: number of grids in each local area centered at new_xyz + num_grid_y: + num_grid_z: + max_neighbour_distance: + num_c_out_each_grid: + use_xyz: + neighbor_type: 1: ball, others: cube: + pooling_type: 0: avg_pool, 1: random choice + Returns: + new_features: (M1 + M2 ..., num_c_out) + """ + assert support_xyz.is_contiguous() + assert support_features.is_contiguous() + assert xyz_batch_cnt.is_contiguous() + assert new_xyz.is_contiguous() + assert new_xyz_batch_cnt.is_contiguous() + num_total_grids = num_grid_x * num_grid_y * num_grid_z + num_c_out = num_c_out_each_grid * num_total_grids + N, num_c_in = support_features.shape + M = new_xyz.shape[0] + + assert num_c_in % num_c_out_each_grid == 0, \ + f'the input channels ({num_c_in}) should be an integral multiple of num_c_out_each_grid({num_c_out_each_grid})' + + while True: + new_features = support_features.new_zeros((M, num_c_out)) + new_local_xyz = support_features.new_zeros((M, 3 * num_total_grids)) + point_cnt_of_grid = xyz_batch_cnt.new_zeros((M, num_total_grids)) + + num_max_sum_points = num_mean_points_per_grid * M + grouped_idxs = xyz_batch_cnt.new_zeros((num_max_sum_points, 3)) + + num_cum_sum = pointnet2.vector_pool_wrapper( + support_xyz, xyz_batch_cnt, support_features, new_xyz, new_xyz_batch_cnt, + new_features, new_local_xyz, point_cnt_of_grid, grouped_idxs, + num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, use_xyz, + num_max_sum_points, nsample, neighbor_type, pooling_type + ) + num_mean_points_per_grid = num_cum_sum // M + int(num_cum_sum % M > 0) + if num_cum_sum <= num_max_sum_points: + break + + grouped_idxs = grouped_idxs[:num_cum_sum] + + normalizer = torch.clamp_min(point_cnt_of_grid[:, :, None].float(), min=1e-6) + new_features = (new_features.view(-1, num_total_grids, num_c_out_each_grid) / normalizer).view(-1, num_c_out) + + if use_xyz: + new_local_xyz = (new_local_xyz.view(-1, num_total_grids, 3) / normalizer).view(-1, num_total_grids * 3) + + num_mean_points_per_grid = torch.Tensor([num_mean_points_per_grid]).int() + nsample = torch.Tensor([nsample]).int() + ctx.vector_pool_for_backward = (point_cnt_of_grid, grouped_idxs, N, num_c_in) + ctx.mark_non_differentiable(new_local_xyz, num_mean_points_per_grid, nsample, point_cnt_of_grid) + return new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid + + @staticmethod + def backward(ctx, grad_new_features: torch.Tensor, grad_local_xyz: torch.Tensor, grad_num_cum_sum, grad_point_cnt_of_grid): + """ + Args: + ctx: + grad_new_features: (M1 + M2 ..., num_c_out), num_c_out = num_c_out_each_grid * num_total_grids + + Returns: + grad_support_features: (N1 + N2 ..., C_in) + """ + point_cnt_of_grid, grouped_idxs, N, num_c_in = ctx.vector_pool_for_backward + grad_support_features = grad_new_features.new_zeros((N, num_c_in)) + + pointnet2.vector_pool_grad_wrapper( + grad_new_features.contiguous(), point_cnt_of_grid, grouped_idxs, + grad_support_features + ) + + return None, None, grad_support_features, None, None, None, None, None, None, None, None, None, None, None, None + + +vector_pool_with_voxel_query_op = VectorPoolWithVoxelQuery.apply + + if __name__ == '__main__': pass diff --git a/pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp b/pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp index abb94ca78..1b61e4158 100644 --- a/pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp +++ b/pcdet/ops/pointnet2/pointnet2_stack/src/pointnet2_api.cpp @@ -6,6 +6,7 @@ #include "sampling_gpu.h" #include "interpolate_gpu.h" #include "voxel_query_gpu.h" +#include "vector_pool_gpu.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -21,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack"); m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack"); m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack"); + + m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack"); + m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack"); + + m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_grad_wrapper_stack"); + m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack"); } diff --git a/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool.cpp b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool.cpp new file mode 100644 index 000000000..308beea79 --- /dev/null +++ b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool.cpp @@ -0,0 +1,203 @@ +/* +Vector-pool aggregation based local feature aggregation for point cloud. +PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection +https://arxiv.org/abs/2102.00463 + +Written by Shaoshuai Shi +All Rights Reserved 2020. +*/ + + +#include +#include +#include +#include +#include +#include "vector_pool_gpu.h" + +extern THCState *state; + +#define CHECK_CUDA(x) do { \ + if (!x.type().is_cuda()) { \ + fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ + exit(-1); \ + } \ +} while (0) +#define CHECK_CONTIGUOUS(x) do { \ + if (!x.is_contiguous()) { \ + fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ + exit(-1); \ + } \ +} while (0) +#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) + + +int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, + at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, + at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor, + int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn + // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn + // num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz + // nsample: find all (-1), find limited number(>0) + // neighbor_type: 1: ball, others: cube + + CHECK_INPUT(support_xyz_tensor); + CHECK_INPUT(xyz_batch_cnt_tensor); + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(new_xyz_batch_cnt_tensor); + CHECK_INPUT(stack_neighbor_idxs_tensor); + CHECK_INPUT(start_len_tensor); + CHECK_INPUT(cumsum_tensor); + + const float *support_xyz = support_xyz_tensor.data(); + const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data(); + const float *new_xyz = new_xyz_tensor.data(); + const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data(); + int *stack_neighbor_idxs = stack_neighbor_idxs_tensor.data(); + int *start_len = start_len_tensor.data(); + int *cumsum = cumsum_tensor.data(); + + int batch_size = xyz_batch_cnt_tensor.size(0); + int M = new_xyz_tensor.size(0); + + query_stacked_local_neighbor_idxs_kernel_launcher_stack( + support_xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, + stack_neighbor_idxs, start_len, cumsum, avg_length_of_neighbor_idxs, + max_neighbour_distance, batch_size, M, nsample, neighbor_type + ); + return 0; +} + + +int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor, + at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor, + at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor, + at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, + int M, int num_total_grids){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn + // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn + // stack_neighbor_idxs: (max_length_of_neighbor_idxs) + // start_len: (M1 + M2, 2) [start_offset, neighbor_length] + + CHECK_INPUT(support_xyz_tensor); + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(new_xyz_grid_centers_tensor); + CHECK_INPUT(new_xyz_grid_idxs_tensor); + CHECK_INPUT(new_xyz_grid_dist2_tensor); + CHECK_INPUT(stack_neighbor_idxs_tensor); + CHECK_INPUT(start_len_tensor); + + const float *support_xyz = support_xyz_tensor.data(); + const float *new_xyz = new_xyz_tensor.data(); + const float *new_xyz_grid_centers = new_xyz_grid_centers_tensor.data(); + int *new_xyz_grid_idxs = new_xyz_grid_idxs_tensor.data(); + float *new_xyz_grid_dist2 = new_xyz_grid_dist2_tensor.data(); + int *stack_neighbor_idxs = stack_neighbor_idxs_tensor.data(); + int *start_len = start_len_tensor.data(); + + query_three_nn_by_stacked_local_idxs_kernel_launcher_stack( + support_xyz, new_xyz, new_xyz_grid_centers, + new_xyz_grid_idxs, new_xyz_grid_dist2, stack_neighbor_idxs, start_len, + M, num_total_grids + ); + return 0; +} + + +int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, + at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, + at::Tensor new_features_tensor, at::Tensor new_local_xyz_tensor, + at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, + int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz, + int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){ + // support_xyz_tensor: (N1 + N2 ..., 3) xyz coordinates of the features + // support_features_tensor: (N1 + N2 ..., C) + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz_tensor: (M1 + M2 ..., 3) centers of new positions + // new_features_tensor: (M1 + M2 ..., C) + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // point_cnt_of_grid: (M1 + M2 ..., num_total_grids) + // grouped_idxs_tensor: (num_max_sum_points, 3) + // num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz + // use_xyz: whether to calculate new_local_xyz + // neighbor_type: 1: ball, others: cube + // pooling_type: 0: avg_pool, 1: random choice + + CHECK_INPUT(support_xyz_tensor); + CHECK_INPUT(support_features_tensor); + CHECK_INPUT(xyz_batch_cnt_tensor); + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(new_xyz_batch_cnt_tensor); + CHECK_INPUT(new_features_tensor); + CHECK_INPUT(new_local_xyz_tensor); + CHECK_INPUT(point_cnt_of_grid_tensor); + CHECK_INPUT(grouped_idxs_tensor); + + const float *support_xyz = support_xyz_tensor.data(); + const float *support_features = support_features_tensor.data(); + const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data(); + const float *new_xyz = new_xyz_tensor.data(); + const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data(); + float *new_features = new_features_tensor.data(); + float *new_local_xyz = new_local_xyz_tensor.data(); + int *point_cnt_of_grid = point_cnt_of_grid_tensor.data(); + int *grouped_idxs = grouped_idxs_tensor.data(); + + int N = support_xyz_tensor.size(0); + int batch_size = xyz_batch_cnt_tensor.size(0); + int M = new_xyz_tensor.size(0); + int num_c_out = new_features_tensor.size(1); + int num_c_in = support_features_tensor.size(1); + int num_total_grids = point_cnt_of_grid_tensor.size(1); + + int cum_sum = vector_pool_kernel_launcher_stack( + support_xyz, support_features, xyz_batch_cnt, + new_xyz, new_features, new_local_xyz, new_xyz_batch_cnt, + point_cnt_of_grid, grouped_idxs, + num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, + batch_size, N, M, num_c_in, num_c_out, num_total_grids, use_xyz, num_max_sum_points, nsample, neighbor_type, pooling_type + ); + return cum_sum; +} + + +int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor, + at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, + at::Tensor grad_support_features_tensor) { + // grad_new_features_tensor: (M1 + M2 ..., C_out) + // point_cnt_of_grid_tensor: (M1 + M2 ..., num_total_grids) + // grouped_idxs_tensor: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz] + // grad_support_features_tensor: (N1 + N2 ..., C_in) + + CHECK_INPUT(grad_new_features_tensor); + CHECK_INPUT(point_cnt_of_grid_tensor); + CHECK_INPUT(grouped_idxs_tensor); + CHECK_INPUT(grad_support_features_tensor); + + int M = grad_new_features_tensor.size(0); + int num_c_out = grad_new_features_tensor.size(1); + int N = grad_support_features_tensor.size(0); + int num_c_in = grad_support_features_tensor.size(1); + int num_total_grids = point_cnt_of_grid_tensor.size(1); + int num_max_sum_points = grouped_idxs_tensor.size(0); + + const float *grad_new_features = grad_new_features_tensor.data(); + const int *point_cnt_of_grid = point_cnt_of_grid_tensor.data(); + const int *grouped_idxs = grouped_idxs_tensor.data(); + float *grad_support_features = grad_support_features_tensor.data(); + + vector_pool_grad_kernel_launcher_stack( + grad_new_features, point_cnt_of_grid, grouped_idxs, grad_support_features, + N, M, num_c_out, num_c_in, num_total_grids, num_max_sum_points + ); + return 1; +} diff --git a/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.cu b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.cu new file mode 100644 index 000000000..8f05e266c --- /dev/null +++ b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.cu @@ -0,0 +1,486 @@ +/* +Vector-pool aggregation based local feature aggregation for point cloud. +PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection +https://arxiv.org/abs/2102.00463 + +Written by Shaoshuai Shi +All Rights Reserved 2020. +*/ + + +#include +#include +#include + +#include "vector_pool_gpu.h" +#include "cuda_utils.h" + + +__global__ void query_three_nn_by_stacked_local_idxs_kernel( + const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers, + int *new_xyz_grid_idxs, float *new_xyz_grid_dist2, + const int *stack_neighbor_idxs, const int *start_len, + int M, int num_total_grids){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn + // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn + // stack_neighbor_idxs: (max_length_of_neighbor_idxs) + // start_len: (M1 + M2, 2) [start_offset, neighbor_length] + + int grid_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (pt_idx >= M || grid_idx >= num_total_grids) return; + + new_xyz += pt_idx * 3; + new_xyz_grid_centers += pt_idx * num_total_grids * 3 + grid_idx * 3; + new_xyz_grid_idxs += pt_idx * num_total_grids * 3 + grid_idx * 3; + new_xyz_grid_dist2 += pt_idx * num_total_grids * 3 + grid_idx * 3; + + start_len += pt_idx * 2; + stack_neighbor_idxs += start_len[0]; + int neighbor_length = start_len[1]; + + float center_x = new_xyz_grid_centers[0]; + float center_y = new_xyz_grid_centers[1]; + float center_z = new_xyz_grid_centers[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = -1, besti2 = -1, besti3 = -1; + for (int k = 0; k < neighbor_length; k++){ + int cur_neighbor_idx = stack_neighbor_idxs[k]; + + float x = support_xyz[cur_neighbor_idx * 3 + 0]; + float y = support_xyz[cur_neighbor_idx * 3 + 1]; + float z = support_xyz[cur_neighbor_idx * 3 + 2]; + + float d = (center_x - x) * (center_x - x) + (center_y - y) * (center_y - y) + (center_z - z) * (center_z - z); + + if (d < best1) { + best3 = best2; besti3 = besti2; + best2 = best1; besti2 = besti1; + best1 = d; besti1 = cur_neighbor_idx; + } + else if (d < best2) { + best3 = best2; besti3 = besti2; + best2 = d; besti2 = cur_neighbor_idx; + } + else if (d < best3) { + best3 = d; besti3 = cur_neighbor_idx; + } + } + if (besti2 == -1){ + besti2 = besti1; best2 = best1; + } + if (besti3 == -1){ + besti3 = besti1; best3 = best1; + } + new_xyz_grid_dist2[0] = best1; + new_xyz_grid_dist2[1] = best2; + new_xyz_grid_dist2[2] = best3; + new_xyz_grid_idxs[0] = besti1; + new_xyz_grid_idxs[1] = besti2; + new_xyz_grid_idxs[2] = besti3; +} + + +int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack( + const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers, + int *new_xyz_grid_idxs, float *new_xyz_grid_dist2, + const int *stack_neighbor_idxs, const int *start_len, + int M, int num_total_grids){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid + // new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn + // new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn + // stack_neighbor_idxs: (max_length_of_neighbor_idxs) + // start_len: (M1 + M2, 2) [start_offset, neighbor_length] + + cudaError_t err; + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), num_total_grids); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + query_three_nn_by_stacked_local_idxs_kernel<<>>( + support_xyz, new_xyz, new_xyz_grid_centers, + new_xyz_grid_idxs, new_xyz_grid_dist2, stack_neighbor_idxs, start_len, + M, num_total_grids + ); + + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + return 0; +} + + +__global__ void query_stacked_local_neighbor_idxs_kernel( + const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt, + int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs, + float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // stack_neighbor_idxs: (max_length_of_neighbor_idxs) + // start_len: (M1 + M2, 2) [start_offset, neighbor_length] + // cumsum: (1), max offset of current data in stack_neighbor_idxs + // max_neighbour_distance: float + // nsample: find all (-1), find limited number(>0) + // neighbor_type: 1: ball, others: cube + + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= M) return; + + int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0]; + for (int k = 1; k < batch_size; k++){ + if (pt_idx < pt_cnt) break; + pt_cnt += new_xyz_batch_cnt[k]; + bs_idx = k; + } + + int xyz_batch_start_idx = 0; + for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k]; + + support_xyz += xyz_batch_start_idx * 3; + new_xyz += pt_idx * 3; + start_len += pt_idx * 2; + + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + int n = xyz_batch_cnt[bs_idx]; + + float local_x, local_y, local_z; + float radius2 = max_neighbour_distance * max_neighbour_distance; + + int temp_idxs[1000]; + + int sample_cnt = 0; + for (int k = 0; k < n; ++k) { + local_x = support_xyz[k * 3 + 0] - new_x; + local_y = support_xyz[k * 3 + 1] - new_y; + local_z = support_xyz[k * 3 + 2] - new_z; + + if (neighbor_type == 1){ + // ball + if (local_x * local_x + local_y * local_y + local_z * local_z > radius2){ + continue; + } + } + else{ + // voxel + if ((fabs(local_x) > max_neighbour_distance) | + (fabs(local_y) > max_neighbour_distance) | + (fabs(local_z) > max_neighbour_distance)){ + continue; + } + } + if (sample_cnt < 1000){ + temp_idxs[sample_cnt] = k; + } + else{ + break; + } + sample_cnt++; + if (nsample > 0 && sample_cnt >= nsample) break; + } + start_len[0] = atomicAdd(cumsum, sample_cnt); + start_len[1] = sample_cnt; + + int max_thresh = avg_length_of_neighbor_idxs * M; + if (start_len[0] >= max_thresh) return; + + stack_neighbor_idxs += start_len[0]; + if (start_len[0] + sample_cnt >= max_thresh) sample_cnt = max_thresh - start_len[0]; + + for (int k = 0; k < sample_cnt; k++){ + stack_neighbor_idxs[k] = temp_idxs[k] + xyz_batch_start_idx; + } +} + + +int query_stacked_local_neighbor_idxs_kernel_launcher_stack( + const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt, + int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs, + float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // stack_neighbor_idxs: (max_length_of_neighbor_idxs) + // start_len: (M1 + M2, 2) [start_offset, neighbor_length] + // cumsum: (1), max offset of current data in stack_neighbor_idxs + // max_neighbour_distance: float + // nsample: find all (-1), find limited number(>0) + // neighbor_type: 1: ball, others: cube + + cudaError_t err; + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + query_stacked_local_neighbor_idxs_kernel<<>>( + support_xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, + stack_neighbor_idxs, start_len, cumsum, avg_length_of_neighbor_idxs, + max_neighbour_distance, batch_size, M, nsample, neighbor_type + ); + + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + return 0; +} + + +__global__ void vector_pool_kernel_stack( + const float *support_xyz, const float *support_features, const int *xyz_batch_cnt, + const float *new_xyz, float *new_features, float *new_local_xyz, const int *new_xyz_batch_cnt, + int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, + int batch_size, int M, int num_c_in, int num_c_out, + int num_c_each_grid, int num_total_grids, int *point_cnt_of_grid, int *grouped_idxs, + int use_xyz, float grid_size_x, float grid_size_y, + float grid_size_z, int *cum_sum, int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // support_features: (N1 + N2 ..., C) + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_features: (M1 + M2 ..., C), C = num_total_grids * num_c_each_grid + // new_local_xyz: (M1 + M2 ..., 3 * num_total_grids) + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz + // point_cnt_of_grid: (M1 + M2 ..., num_total_grids) + // grouped_idxs: (num_max_sum_points, 3)[idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz] + // use_xyz: whether to calculate new_local_xyz + // neighbor_type: 1: ball, others: cube + // pooling_type: 0: avg_pool, 1: random choice + + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= M) return; + + int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0]; + for (int k = 1; k < batch_size; k++){ + if (pt_idx < pt_cnt) break; + pt_cnt += new_xyz_batch_cnt[k]; + bs_idx = k; + } + + int xyz_batch_start_idx = 0; + for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k]; + + support_xyz += xyz_batch_start_idx * 3; + support_features += xyz_batch_start_idx * num_c_in; + + new_xyz += pt_idx * 3; + new_features += pt_idx * num_c_out; + point_cnt_of_grid += pt_idx * num_total_grids; + new_local_xyz += pt_idx * 3 * num_total_grids; + + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + int n = xyz_batch_cnt[bs_idx], grid_idx_x, grid_idx_y, grid_idx_z, grid_idx; + float local_x, local_y, local_z; + float radius2 = max_neighbour_distance * max_neighbour_distance; + + int sample_cnt = 0; + for (int k = 0; k < n; ++k) { + local_x = support_xyz[k * 3 + 0] - new_x; + local_y = support_xyz[k * 3 + 1] - new_y; + local_z = support_xyz[k * 3 + 2] - new_z; + + if (neighbor_type == 1){ + // ball + if (local_x * local_x + local_y * local_y + local_z * local_z > radius2){ + continue; + } + } + else{ + // voxel + if ((fabs(local_x) > max_neighbour_distance) | + (fabs(local_y) > max_neighbour_distance) | + (fabs(local_z) > max_neighbour_distance)){ + continue; + } + } + + grid_idx_x = floorf((local_x + max_neighbour_distance) / grid_size_x); + grid_idx_y = floorf((local_y + max_neighbour_distance) / grid_size_y); + grid_idx_z = floorf((local_z + max_neighbour_distance) / grid_size_z); + grid_idx = grid_idx_x * num_grid_y * num_grid_z + grid_idx_y * num_grid_z + grid_idx_z; + grid_idx = min(max(grid_idx, 0), num_total_grids - 1); + + if (pooling_type == 0){ + // avg pooling + point_cnt_of_grid[grid_idx] ++; + + for (int i = 0; i < num_c_in; i++){ + new_features[grid_idx * num_c_each_grid + i % num_c_each_grid] += support_features[k * num_c_in + i]; + } + if (use_xyz){ + new_local_xyz[grid_idx * 3 + 0] += local_x; + new_local_xyz[grid_idx * 3 + 1] += local_y; + new_local_xyz[grid_idx * 3 + 2] += local_z; + } + + int cnt = atomicAdd(cum_sum, 1); + if (cnt >= num_max_sum_points) continue; // continue to statistics the max number of points + + grouped_idxs[cnt * 3 + 0] = xyz_batch_start_idx + k; + grouped_idxs[cnt * 3 + 1] = pt_idx; + grouped_idxs[cnt * 3 + 2] = grid_idx; + + sample_cnt++; + if(nsample > 0 && sample_cnt >= nsample) break; + } + else if (pooling_type == 1){ + // random choose one within sub-voxel + // printf("new_xyz=(%.2f, %.2f, %.2f, ), find neighbor k=%d: support_xyz=(%.2f, %.2f, %.2f), local_xyz=(%.2f, %.2f, %.2f), neighbor=%.2f, grid_idx=%d, point_cnt_of_grid_idx=%d\n", + // new_x, new_y, new_z, k, support_xyz[k * 3 + 0], support_xyz[k * 3 + 1], support_xyz[k * 3 + 2], local_x, local_y, local_z, max_neighbour_distance, grid_idx, point_cnt_of_grid[grid_idx]); + + if (point_cnt_of_grid[grid_idx] == 0){ + point_cnt_of_grid[grid_idx] ++; + for (int i = 0; i < num_c_in; i++){ + new_features[grid_idx * num_c_each_grid + i % num_c_each_grid] = support_features[k * num_c_in + i]; + } + if (use_xyz){ + new_local_xyz[grid_idx * 3 + 0] = local_x; + new_local_xyz[grid_idx * 3 + 1] = local_y; + new_local_xyz[grid_idx * 3 + 2] = local_z; + } + + int cnt = atomicAdd(cum_sum, 1); + if (cnt >= num_max_sum_points) continue; // continue to statistics the max number of points + + grouped_idxs[cnt * 3 + 0] = xyz_batch_start_idx + k; + grouped_idxs[cnt * 3 + 1] = pt_idx; + grouped_idxs[cnt * 3 + 2] = grid_idx; + + sample_cnt++; + if(nsample > 0 && sample_cnt >= nsample || sample_cnt >= num_total_grids) break; + } + + } + + } +} + + +int vector_pool_kernel_launcher_stack( + const float *support_xyz, const float *support_features, const int *xyz_batch_cnt, + const float *new_xyz, float *new_features, float *new_local_xyz, const int *new_xyz_batch_cnt, + int *point_cnt_of_grid, int *grouped_idxs, + int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, + int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids, + int use_xyz, int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){ + // support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features + // support_features: (N1 + N2 ..., C) + // xyz_batch_cnt: (batch_size), [N1, N2, ...] + // new_xyz: (M1 + M2 ..., 3) centers of the ball query + // new_features: (M1 + M2 ..., C) + // new_local_xyz: (M1 + M2 ..., 3) + // new_xyz_batch_cnt: (batch_size), [M1, M2, ...] + // num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz + // use_xyz: whether to calculate new_local_xyz + // grouped_idxs: (num_max_sum_points, 3)[idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz] + // neighbor_type: 1: ball, others: cube + // pooling_type: 0: avg_pool, 1: random choice + + + cudaError_t err; + int num_c_each_grid = num_c_out / num_total_grids; + float grid_size_x = max_neighbour_distance * 2 / num_grid_x; + float grid_size_y = max_neighbour_distance * 2 / num_grid_y; + float grid_size_z = max_neighbour_distance * 2 / num_grid_z; + + dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + int cum_sum = 0; + int *p_cum_sum; + cudaMalloc((void**)&p_cum_sum, sizeof(int)); + cudaMemcpy(p_cum_sum, &cum_sum, sizeof(int), cudaMemcpyHostToDevice); + + vector_pool_kernel_stack<<>>( + support_xyz, support_features, xyz_batch_cnt, + new_xyz, new_features, new_local_xyz, new_xyz_batch_cnt, + num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, + batch_size, M, num_c_in, num_c_out, + num_c_each_grid, num_total_grids, point_cnt_of_grid, grouped_idxs, + use_xyz, grid_size_x, grid_size_y, grid_size_z, p_cum_sum, num_max_sum_points, + nsample, neighbor_type, pooling_type + ); + + cudaMemcpy(&cum_sum, p_cum_sum, sizeof(int), cudaMemcpyDeviceToHost); + + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + return cum_sum; +} + + +__global__ void vector_pool_grad_kernel_stack(const float *grad_new_features, + const int *point_cnt_of_grid, const int *grouped_idxs, + float *grad_support_features, int N, int M, int num_c_out, int num_c_in, + int num_c_each_grid, int num_total_grids, int num_max_sum_points){ + // grad_new_features: (M1 + M2 ..., C_out) + // point_cnt_of_grid: (M1 + M2 ..., num_total_grids) + // grouped_idxs: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz] + // grad_support_features: (N1 + N2 ..., C_in) + + int channel_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= num_max_sum_points || channel_idx >= num_c_in) return; + + int idx_of_support_xyz = grouped_idxs[index * 3 + 0]; + int idx_of_new_xyz = grouped_idxs[index * 3 + 1]; + int idx_of_grid_idx = grouped_idxs[index * 3 + 2]; + + int num_total_pts = point_cnt_of_grid[idx_of_new_xyz * num_total_grids + idx_of_grid_idx]; + grad_support_features += idx_of_support_xyz * num_c_in + channel_idx; + + grad_new_features += idx_of_new_xyz * num_c_out + idx_of_grid_idx * num_c_each_grid; + int channel_idx_of_cin = channel_idx % num_c_each_grid; + float cur_grad = 1 / fmaxf(float(num_total_pts), 1.0); + atomicAdd(grad_support_features, grad_new_features[channel_idx_of_cin] * cur_grad); +} + + +void vector_pool_grad_kernel_launcher_stack( + const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs, + float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids, + int num_max_sum_points){ + // grad_new_features: (M1 + M2 ..., C_out) + // point_cnt_of_grid: (M1 + M2 ..., num_total_grids) + // grouped_idxs: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz] + // grad_support_features: (N1 + N2 ..., C_in) + int num_c_each_grid = num_c_out / num_total_grids; + + cudaError_t err; + + dim3 blocks(DIVUP(num_max_sum_points, THREADS_PER_BLOCK), num_c_in); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + vector_pool_grad_kernel_stack<<>>( + grad_new_features, point_cnt_of_grid, grouped_idxs, grad_support_features, + N, M, num_c_out, num_c_in, num_c_each_grid, num_total_grids, num_max_sum_points + ); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h new file mode 100644 index 000000000..febfb8553 --- /dev/null +++ b/pcdet/ops/pointnet2/pointnet2_stack/src/vector_pool_gpu.h @@ -0,0 +1,71 @@ +/* +Vector-pool aggregation based local feature aggregation for point cloud. +PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection +https://arxiv.org/abs/2102.00463 + +Written by Shaoshuai Shi +All Rights Reserved 2020. +*/ + + +#ifndef _STACK_VECTOR_POOL_GPU_H +#define _STACK_VECTOR_POOL_GPU_H + +#include +#include +#include +#include + + +int query_stacked_local_neighbor_idxs_kernel_launcher_stack( + const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt, + int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs, + float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type); + +int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, + at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, + at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor, + int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type); + + +int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack( + const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers, + int *new_xyz_grid_idxs, float *new_xyz_grid_dist2, + const int *stack_neighbor_idxs, const int *start_len, + int M, int num_total_grids); + +int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor, + at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor, + at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor, + at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, + int M, int num_total_grids); + + +int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor, + at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor, + at::Tensor new_features_tensor, at::Tensor new_local_xyz, + at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, + int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz, + int num_max_sum_points, int nsample, int neighbor_type, int pooling_type); + + +int vector_pool_kernel_launcher_stack( + const float *support_xyz, const float *support_features, const int *xyz_batch_cnt, + const float *new_xyz, float *new_features, float * new_local_xyz, const int *new_xyz_batch_cnt, + int *point_cnt_of_grid, int *grouped_idxs, + int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, + int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids, int use_xyz, + int num_max_sum_points, int nsample, int neighbor_type, int pooling_type); + + +int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor, + at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor, + at::Tensor grad_support_features_tensor); + + +void vector_pool_grad_kernel_launcher_stack( + const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs, + float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids, + int num_max_sum_points); + +#endif diff --git a/setup.py b/setup.py index 935c65eea..418d3e39d 100644 --- a/setup.py +++ b/setup.py @@ -97,6 +97,8 @@ def write_version_to_file(version, target_file): 'src/interpolate_gpu.cu', 'src/voxel_query.cpp', 'src/voxel_query_gpu.cu', + 'src/vector_pool.cpp', + 'src/vector_pool_gpu.cu' ], ), make_cuda_ext( diff --git a/tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml b/tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml new file mode 100644 index 000000000..b16b86346 --- /dev/null +++ b/tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml @@ -0,0 +1,277 @@ +CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist'] + +DATA_CONFIG: + _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml + + +MODEL: + NAME: PVRCNNPlusPlus + + VFE: + NAME: MeanVFE + + BACKBONE_3D: + NAME: VoxelBackBone8x + + MAP_TO_BEV: + NAME: HeightCompression + NUM_BEV_FEATURES: 256 + + BACKBONE_2D: + NAME: BaseBEVBackbone + + LAYER_NUMS: [5, 5] + LAYER_STRIDES: [1, 2] + NUM_FILTERS: [128, 256] + UPSAMPLE_STRIDES: [1, 2] + NUM_UPSAMPLE_FILTERS: [256, 256] + + DENSE_HEAD: + NAME: CenterHead + CLASS_AGNOSTIC: False + + CLASS_NAMES_EACH_HEAD: [ + [ 'Vehicle', 'Pedestrian', 'Cyclist' ] + ] + + SHARED_CONV_CHANNEL: 64 + USE_BIAS_BEFORE_NORM: True + NUM_HM_CONV: 2 + SEPARATE_HEAD_CFG: + HEAD_ORDER: [ 'center', 'center_z', 'dim', 'rot' ] + HEAD_DICT: { + 'center': { 'out_channels': 2, 'num_conv': 2 }, + 'center_z': { 'out_channels': 1, 'num_conv': 2 }, + 'dim': { 'out_channels': 3, 'num_conv': 2 }, + 'rot': { 'out_channels': 2, 'num_conv': 2 }, + } + + TARGET_ASSIGNER_CONFIG: + FEATURE_MAP_STRIDE: 8 + NUM_MAX_OBJS: 500 + GAUSSIAN_OVERLAP: 0.1 + MIN_RADIUS: 2 + + LOSS_CONFIG: + LOSS_WEIGHTS: { + 'cls_weight': 1.0, + 'loc_weight': 2.0, + 'code_weights': [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ] + } + + POST_PROCESSING: + SCORE_THRESH: 0.1 + POST_CENTER_LIMIT_RANGE: [ -75.2, -75.2, -2, 75.2, 75.2, 4 ] + MAX_OBJ_PER_SAMPLE: 500 + NMS_CONFIG: + NMS_TYPE: nms_gpu + NMS_THRESH: 0.7 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + PFE: + NAME: VoxelSetAbstraction + POINT_SOURCE: raw_points + NUM_KEYPOINTS: 4096 + NUM_OUTPUT_FEATURES: 90 + SAMPLE_METHOD: SPC + SPC_SAMPLING: + NUM_SECTORS: 6 + SAMPLE_RADIUS_WITH_ROI: 1.6 + + FEATURES_SOURCE: ['bev', 'x_conv3', 'x_conv4', 'raw_points'] + SA_LAYER: + raw_points: + NAME: VectorPoolAggregationModuleMSG + NUM_GROUPS: 2 + LOCAL_AGGREGATION_TYPE: local_interpolation + NUM_REDUCED_CHANNELS: 2 + NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32 + MSG_POST_MLPS: [ 32 ] + FILTER_NEIGHBOR_WITH_ROI: True + RADIUS_OF_NEIGHBOR_WITH_ROI: 2.4 + + GROUP_CFG_0: + NUM_LOCAL_VOXEL: [ 2, 2, 2 ] + MAX_NEIGHBOR_DISTANCE: 0.2 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [ 32, 32 ] + GROUP_CFG_1: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 0.4 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [ 32, 32 ] + + x_conv3: + DOWNSAMPLE_FACTOR: 4 + INPUT_CHANNELS: 64 + + NAME: VectorPoolAggregationModuleMSG + NUM_GROUPS: 2 + LOCAL_AGGREGATION_TYPE: local_interpolation + NUM_REDUCED_CHANNELS: 32 + NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32 + MSG_POST_MLPS: [128] + FILTER_NEIGHBOR_WITH_ROI: True + RADIUS_OF_NEIGHBOR_WITH_ROI: 4.0 + + GROUP_CFG_0: + NUM_LOCAL_VOXEL: [3, 3, 3] + MAX_NEIGHBOR_DISTANCE: 1.2 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [64, 64] + GROUP_CFG_1: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 2.4 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [ 64, 64 ] + + x_conv4: + DOWNSAMPLE_FACTOR: 8 + INPUT_CHANNELS: 64 + + NAME: VectorPoolAggregationModuleMSG + NUM_GROUPS: 2 + LOCAL_AGGREGATION_TYPE: local_interpolation + NUM_REDUCED_CHANNELS: 32 + NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32 + MSG_POST_MLPS: [ 128 ] + FILTER_NEIGHBOR_WITH_ROI: True + RADIUS_OF_NEIGHBOR_WITH_ROI: 6.4 + + GROUP_CFG_0: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 2.4 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [ 64, 64 ] + GROUP_CFG_1: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 4.8 + NEIGHBOR_NSAMPLE: -1 + POST_MLPS: [ 64, 64 ] + + + POINT_HEAD: + NAME: PointHeadSimple + CLS_FC: [256, 256] + CLASS_AGNOSTIC: True + USE_POINT_FEATURES_BEFORE_FUSION: True + TARGET_CONFIG: + GT_EXTRA_WIDTH: [0.2, 0.2, 0.2] + LOSS_CONFIG: + LOSS_REG: smooth-l1 + LOSS_WEIGHTS: { + 'point_cls_weight': 1.0, + } + + ROI_HEAD: + NAME: PVRCNNHead + CLASS_AGNOSTIC: True + + SHARED_FC: [256, 256] + CLS_FC: [256, 256] + REG_FC: [256, 256] + DP_RATIO: 0.3 + + NMS_CONFIG: + TRAIN: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 9000 + NMS_POST_MAXSIZE: 512 + NMS_THRESH: 0.8 + TEST: + NMS_TYPE: nms_gpu + MULTI_CLASSES_NMS: False + NMS_PRE_MAXSIZE: 1024 + NMS_POST_MAXSIZE: 100 + NMS_THRESH: 0.7 + SCORE_THRESH: 0.1 + +# NMS_PRE_MAXSIZE: 4096 +# NMS_POST_MAXSIZE: 500 +# NMS_THRESH: 0.85 + + + ROI_GRID_POOL: + GRID_SIZE: 6 + + NAME: VectorPoolAggregationModuleMSG + NUM_GROUPS: 2 + LOCAL_AGGREGATION_TYPE: voxel_random_choice + NUM_REDUCED_CHANNELS: 30 + NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32 + MSG_POST_MLPS: [ 128 ] + + GROUP_CFG_0: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 0.8 + NEIGHBOR_NSAMPLE: 32 + POST_MLPS: [ 64, 64 ] + GROUP_CFG_1: + NUM_LOCAL_VOXEL: [ 3, 3, 3 ] + MAX_NEIGHBOR_DISTANCE: 1.6 + NEIGHBOR_NSAMPLE: 32 + POST_MLPS: [ 64, 64 ] + + TARGET_CONFIG: + BOX_CODER: ResidualCoder + ROI_PER_IMAGE: 128 + FG_RATIO: 0.5 + + SAMPLE_ROI_BY_EACH_CLASS: True + CLS_SCORE_TYPE: roi_iou + + CLS_FG_THRESH: 0.75 + CLS_BG_THRESH: 0.25 + CLS_BG_THRESH_LO: 0.1 + HARD_BG_RATIO: 0.8 + + REG_FG_THRESH: 0.55 + + LOSS_CONFIG: + CLS_LOSS: BinaryCrossEntropy + REG_LOSS: smooth-l1 + CORNER_LOSS_REGULARIZATION: True + LOSS_WEIGHTS: { + 'rcnn_cls_weight': 1.0, + 'rcnn_reg_weight': 1.0, + 'rcnn_corner_weight': 1.0, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + SCORE_THRESH: 0.1 + OUTPUT_RAW_SCORE: False + + EVAL_METRIC: waymo + + NMS_CONFIG: + MULTI_CLASSES_NMS: False + NMS_TYPE: nms_gpu + NMS_THRESH: 0.7 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 2 + NUM_EPOCHS: 30 + + OPTIMIZER: adam_onecycle + LR: 0.01 + WEIGHT_DECAY: 0.001 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10 \ No newline at end of file