From d471a69333808e1ad88e47774b9ce32e89bd7d9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cagent-sgs=E2=80=9D?= Date: Wed, 14 Sep 2022 14:33:28 +0800 Subject: [PATCH 1/3] pillarnet --- pcdet/models/backbones_2d/__init__.py | 5 +- .../models/backbones_2d/base_bev_backbone.py | 92 ++++++ pcdet/models/backbones_3d/__init__.py | 5 +- .../models/backbones_3d/spconv_backbone_2d.py | 300 ++++++++++++++++++ pcdet/models/backbones_3d/vfe/__init__.py | 3 +- .../backbones_3d/vfe/dynamic_pillar_vfe.py | 98 ++++++ 6 files changed, 499 insertions(+), 4 deletions(-) create mode 100644 pcdet/models/backbones_3d/spconv_backbone_2d.py diff --git a/pcdet/models/backbones_2d/__init__.py b/pcdet/models/backbones_2d/__init__.py index f5aa5cddf..b648212d9 100644 --- a/pcdet/models/backbones_2d/__init__.py +++ b/pcdet/models/backbones_2d/__init__.py @@ -1,5 +1,6 @@ -from .base_bev_backbone import BaseBEVBackbone +from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1 __all__ = { - 'BaseBEVBackbone': BaseBEVBackbone + 'BaseBEVBackbone': BaseBEVBackbone, + 'BaseBEVBackboneV1': BaseBEVBackboneV1 } diff --git a/pcdet/models/backbones_2d/base_bev_backbone.py b/pcdet/models/backbones_2d/base_bev_backbone.py index 07fe70f08..55e675004 100644 --- a/pcdet/models/backbones_2d/base_bev_backbone.py +++ b/pcdet/models/backbones_2d/base_bev_backbone.py @@ -110,3 +110,95 @@ def forward(self, data_dict): data_dict['spatial_features_2d'] = x return data_dict + + +class BaseBEVBackboneV1(nn.Module): + def __init__(self, model_cfg, **kwargs): + super().__init__() + self.model_cfg = model_cfg + + layer_nums = self.model_cfg.LAYER_NUMS + num_filters = self.model_cfg.NUM_FILTERS + assert len(layer_nums) == len(num_filters) == 2 + + num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS + upsample_strides = self.model_cfg.UPSAMPLE_STRIDES + assert len(num_upsample_filters) == len(upsample_strides) + + num_levels = len(layer_nums) + self.blocks = nn.ModuleList() + self.deblocks = nn.ModuleList() + for idx in range(num_levels): + cur_layers = [ + nn.ZeroPad2d(1), + nn.Conv2d( + num_filters[idx], num_filters[idx], kernel_size=3, + stride=1, padding=0, bias=False + ), + nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01), + nn.ReLU() + ] + for k in range(layer_nums[idx]): + cur_layers.extend([ + nn.Conv2d(num_filters[idx], num_filters[idx], kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01), + nn.ReLU() + ]) + self.blocks.append(nn.Sequential(*cur_layers)) + if len(upsample_strides) > 0: + stride = upsample_strides[idx] + if stride >= 1: + self.deblocks.append(nn.Sequential( + nn.ConvTranspose2d( + num_filters[idx], num_upsample_filters[idx], + upsample_strides[idx], + stride=upsample_strides[idx], bias=False + ), + nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01), + nn.ReLU() + )) + else: + stride = np.round(1 / stride).astype(np.int) + self.deblocks.append(nn.Sequential( + nn.Conv2d( + num_filters[idx], num_upsample_filters[idx], + stride, + stride=stride, bias=False + ), + nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01), + nn.ReLU() + )) + + c_in = sum(num_upsample_filters) + if len(upsample_strides) > num_levels: + self.deblocks.append(nn.Sequential( + nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False), + nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01), + nn.ReLU(), + )) + + self.num_bev_features = c_in + + def forward(self, data_dict): + """ + Args: + data_dict: + spatial_features + Returns: + """ + spatial_features = data_dict['multi_scale_2d_features'] + + x_conv4 = spatial_features['x_conv4'] + x_conv5 = spatial_features['x_conv5'] + + ups = [self.deblocks[0](x_conv4)] + + x = self.blocks[1](x_conv5) + ups.append(self.deblocks[1](x)) + + x = torch.cat(ups, dim=1) + x = self.blocks[0](x) + + data_dict['spatial_features_2d'] = x + + return data_dict diff --git a/pcdet/models/backbones_3d/__init__.py b/pcdet/models/backbones_3d/__init__.py index f58b4f9cc..61d33c3f2 100644 --- a/pcdet/models/backbones_3d/__init__.py +++ b/pcdet/models/backbones_3d/__init__.py @@ -1,5 +1,6 @@ from .pointnet2_backbone import PointNet2Backbone, PointNet2MSG from .spconv_backbone import VoxelBackBone8x, VoxelResBackBone8x +from .spconv_backbone_2d import PillarBackBone8x, PillarRes18BackBone8x from .spconv_backbone_focal import VoxelBackBone8xFocal from .spconv_unet import UNetV2 @@ -9,5 +10,7 @@ 'PointNet2Backbone': PointNet2Backbone, 'PointNet2MSG': PointNet2MSG, 'VoxelResBackBone8x': VoxelResBackBone8x, - 'VoxelBackBone8xFocal': VoxelBackBone8xFocal + 'VoxelBackBone8xFocal': VoxelBackBone8xFocal, + 'PillarBackBone8x': PillarBackBone8x, + 'PillarRes18BackBone8x': PillarRes18BackBone8x } diff --git a/pcdet/models/backbones_3d/spconv_backbone_2d.py b/pcdet/models/backbones_3d/spconv_backbone_2d.py new file mode 100644 index 000000000..3784ada1f --- /dev/null +++ b/pcdet/models/backbones_3d/spconv_backbone_2d.py @@ -0,0 +1,300 @@ +from functools import partial + +import torch.nn as nn + +from ...utils.spconv_utils import replace_feature, spconv + + +def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0, + conv_type='subm', norm_fn=None): + + if conv_type == 'subm': + conv = spconv.SubMConv2d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key) + elif conv_type == 'spconv': + conv = spconv.SparseConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, + bias=False, indice_key=indice_key) + elif conv_type == 'inverseconv': + conv = spconv.SparseInverseConv2d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False) + else: + raise NotImplementedError + + m = spconv.SparseSequential( + conv, + norm_fn(out_channels), + nn.ReLU(), + ) + + return m + + +def post_act_block_dense(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, norm_fn=None): + m = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, bias=False), + norm_fn(out_channels), + nn.ReLU(), + ) + + return m + + +class SparseBasicBlock(spconv.SparseModule): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None): + super(SparseBasicBlock, self).__init__() + + assert norm_fn is not None + bias = norm_fn is not None + self.conv1 = spconv.SubMConv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key + ) + self.bn1 = norm_fn(planes) + self.relu = nn.ReLU() + self.conv2 = spconv.SubMConv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key + ) + self.bn2 = norm_fn(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = replace_feature(out, self.bn1(out.features)) + out = replace_feature(out, self.relu(out.features)) + + out = self.conv2(out) + out = replace_feature(out, self.bn2(out.features)) + + if self.downsample is not None: + identity = self.downsample(x) + + out = replace_feature(out, out.features + identity.features) + out = replace_feature(out, self.relu(out.features)) + + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None): + super(BasicBlock, self).__init__() + + assert norm_fn is not None + bias = norm_fn is not None + self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=bias) + self.bn1 = norm_fn(planes) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=bias) + self.bn2 = norm_fn(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = out + identity + out = self.relu(out) + + return out + + +class PillarBackBone8x(nn.Module): + def __init__(self, model_cfg, input_channels, grid_size, **kwargs): + super().__init__() + self.model_cfg = model_cfg + norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) + self.sparse_shape = grid_size[[1, 0]] + + block = post_act_block + dense_block = post_act_block_dense + + self.conv1 = spconv.SparseSequential( + block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), + block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), + ) + + self.conv2 = spconv.SparseSequential( + # [1600, 1408] <- [800, 704] + block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), + block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), + block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), + ) + + self.conv3 = spconv.SparseSequential( + # [800, 704] <- [400, 352] + block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), + block(128, 128, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), + block(128, 128, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), + ) + + self.conv4 = spconv.SparseSequential( + # [400, 352] <- [200, 176] + block(128, 256, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv4', conv_type='spconv'), + block(256, 256, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), + block(256, 256, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), + ) + + norm_fn = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01) + self.conv5 = nn.Sequential( + # [200, 176] <- [100, 88] + dense_block(256, 256, 3, norm_fn=norm_fn, stride=2, padding=1), + dense_block(256, 256, 3, norm_fn=norm_fn, padding=1), + dense_block(256, 256, 3, norm_fn=norm_fn, padding=1), + ) + + self.num_point_features = 256 + self.backbone_channels = { + 'x_conv1': 32, + 'x_conv2': 64, + 'x_conv3': 128, + 'x_conv4': 256, + 'x_conv5': 256 + } + + + def forward(self, batch_dict): + pillar_features, pillar_coords = batch_dict['pillar_features'], batch_dict['pillar_coords'] + batch_size = batch_dict['batch_size'] + input_sp_tensor = spconv.SparseConvTensor( + features=pillar_features, + indices=pillar_coords.int(), + spatial_shape=self.sparse_shape, + batch_size=batch_size + ) + + x_conv1 = self.conv1(input_sp_tensor) + x_conv2 = self.conv2(x_conv1) + x_conv3 = self.conv3(x_conv2) + x_conv4 = self.conv4(x_conv3) + x_conv4 = x_conv4.dense() + x_conv5 = self.conv5(x_conv4) + + batch_dict.update({ + 'multi_scale_2d_features': { + 'x_conv1': x_conv1, + 'x_conv2': x_conv2, + 'x_conv3': x_conv3, + 'x_conv4': x_conv4, + 'x_conv5': x_conv5, + } + }) + batch_dict.update({ + 'multi_scale_2d_strides': { + 'x_conv1': 1, + 'x_conv2': 2, + 'x_conv3': 4, + 'x_conv4': 8, + 'x_conv5': 16, + } + }) + + return batch_dict + + +class PillarRes18BackBone8x(nn.Module): + def __init__(self, model_cfg, input_channels, grid_size, **kwargs): + super().__init__() + self.model_cfg = model_cfg + norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) + self.sparse_shape = grid_size[[1, 0]] + + block = post_act_block + dense_block = post_act_block_dense + + self.conv1 = spconv.SparseSequential( + SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'), + SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'), + ) + + self.conv2 = spconv.SparseSequential( + # [1600, 1408] <- [800, 704] + block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), + SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'), + SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'), + ) + + self.conv3 = spconv.SparseSequential( + # [800, 704] <- [400, 352] + block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), + SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'), + SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'), + ) + + self.conv4 = spconv.SparseSequential( + # [400, 352] <- [200, 176] + block(128, 256, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv4', conv_type='spconv'), + SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'), + SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'), + ) + + norm_fn = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01) + self.conv5 = nn.Sequential( + # [200, 176] <- [100, 88] + dense_block(256, 256, 3, norm_fn=norm_fn, stride=2, padding=1), + BasicBlock(256, 256, norm_fn=norm_fn), + BasicBlock(256, 256, norm_fn=norm_fn), + ) + + self.num_point_features = 256 + self.backbone_channels = { + 'x_conv1': 32, + 'x_conv2': 64, + 'x_conv3': 128, + 'x_conv4': 256, + 'x_conv5': 256 + } + + def forward(self, batch_dict): + pillar_features, pillar_coords = batch_dict['pillar_features'], batch_dict['pillar_coords'] + batch_size = batch_dict['batch_size'] + input_sp_tensor = spconv.SparseConvTensor( + features=pillar_features, + indices=pillar_coords.int(), + spatial_shape=self.sparse_shape, + batch_size=batch_size + ) + + x_conv1 = self.conv1(input_sp_tensor) + x_conv2 = self.conv2(x_conv1) + x_conv3 = self.conv3(x_conv2) + x_conv4 = self.conv4(x_conv3) + x_conv4 = x_conv4.dense() + x_conv5 = self.conv5(x_conv4) + + # batch_dict.update({ + # 'encoded_spconv_tensor': out, + # 'encoded_spconv_tensor_stride': 8 + # }) + batch_dict.update({ + 'multi_scale_2d_features': { + 'x_conv1': x_conv1, + 'x_conv2': x_conv2, + 'x_conv3': x_conv3, + 'x_conv4': x_conv4, + 'x_conv5': x_conv5, + } + }) + batch_dict.update({ + 'multi_scale_2d_strides': { + 'x_conv1': 1, + 'x_conv2': 2, + 'x_conv3': 4, + 'x_conv4': 8, + 'x_conv5': 16, + } + }) + + return batch_dict diff --git a/pcdet/models/backbones_3d/vfe/__init__.py b/pcdet/models/backbones_3d/vfe/__init__.py index e544dfe0f..cf30a399b 100644 --- a/pcdet/models/backbones_3d/vfe/__init__.py +++ b/pcdet/models/backbones_3d/vfe/__init__.py @@ -1,7 +1,7 @@ from .mean_vfe import MeanVFE from .pillar_vfe import PillarVFE from .dynamic_mean_vfe import DynamicMeanVFE -from .dynamic_pillar_vfe import DynamicPillarVFE +from .dynamic_pillar_vfe import DynamicPillarVFE, DynamicPillarPFE from .image_vfe import ImageVFE from .vfe_template import VFETemplate @@ -12,4 +12,5 @@ 'ImageVFE': ImageVFE, 'DynMeanVFE': DynamicMeanVFE, 'DynPillarVFE': DynamicPillarVFE, + 'DynamicPillarPFE': DynamicPillarPFE } diff --git a/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py b/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py index 5e6e3ea27..3521ca972 100644 --- a/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py +++ b/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py @@ -140,3 +140,101 @@ def forward(self, batch_dict, **kwargs): batch_dict['pillar_features'] = features batch_dict['voxel_coords'] = voxel_coords return batch_dict + + +class DynamicPillarPFE(VFETemplate): + def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs): + super().__init__(model_cfg=model_cfg) + + self.use_norm = self.model_cfg.USE_NORM + self.with_distance = self.model_cfg.WITH_DISTANCE + self.use_absolute_xyz = self.model_cfg.USE_ABSLOTE_XYZ + self.use_cluster_xyz = self.model_cfg.get('USE_CLUSTER_XYZ', True) + if self.use_absolute_xyz: + num_point_features += 3 + if self.use_cluster_xyz: + num_point_features += 3 + if self.with_distance: + num_point_features += 1 + + self.num_filters = self.model_cfg.NUM_FILTERS + assert len(self.num_filters) > 0 + num_filters = [num_point_features] + list(self.num_filters) + + pfn_layers = [] + for i in range(len(num_filters) - 1): + in_filters = num_filters[i] + out_filters = num_filters[i + 1] + pfn_layers.append( + PFNLayerV2(in_filters, out_filters, self.use_norm, last_layer=(i >= len(num_filters) - 2)) + ) + self.pfn_layers = nn.ModuleList(pfn_layers) + + self.voxel_x = voxel_size[0] + self.voxel_y = voxel_size[1] + self.voxel_z = voxel_size[2] + self.x_offset = self.voxel_x / 2 + point_cloud_range[0] + self.y_offset = self.voxel_y / 2 + point_cloud_range[1] + self.z_offset = self.voxel_z / 2 + point_cloud_range[2] + + self.scale_xy = grid_size[0] * grid_size[1] + self.scale_y = grid_size[1] + + self.grid_size = torch.tensor(grid_size[:2]).cuda() + self.voxel_size = torch.tensor(voxel_size).cuda() + self.point_cloud_range = torch.tensor(point_cloud_range).cuda() + + def get_output_feature_dim(self): + return self.num_filters[-1] + + def forward(self, batch_dict, **kwargs): + points = batch_dict['points'] # (batch_idx, x, y, z, i, e) + + points_coords = torch.floor( + (points[:, [1, 2]] - self.point_cloud_range[[0, 1]]) / self.voxel_size[[0, 1]]).int() + mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0, 1]])).all(dim=1) + points = points[mask] + points_coords = points_coords[mask] + points_xyz = points[:, [1, 2, 3]].contiguous() + + merge_coords = points[:, 0].int() * self.scale_xy + \ + points_coords[:, 0] * self.scale_y + \ + points_coords[:, 1] + + unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0) + + f_center = torch.zeros_like(points_xyz) + f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset) + f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset) + f_center[:, 2] = points_xyz[:, 2] - self.z_offset + + features = [f_center] + if self.use_absolute_xyz: + features.append(points[:, 1:]) + else: + features.append(points[:, 4:]) + + if self.use_cluster_xyz: + points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0) + f_cluster = points_xyz - points_mean[unq_inv, :] + features.append(f_cluster) + + if self.with_distance: + points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True) + features.append(points_dist) + features = torch.cat(features, dim=-1) + + for pfn in self.pfn_layers: + features = pfn(features, unq_inv) + + # generate voxel coordinates + unq_coords = unq_coords.int() + pillar_coords = torch.stack((unq_coords // self.scale_xy, + (unq_coords % self.scale_xy) // self.scale_y, + unq_coords % self.scale_y, + ), dim=1) + pillar_coords = pillar_coords[:, [0, 2, 1]] + + batch_dict['pillar_features'] = features + batch_dict['pillar_coords'] = pillar_coords + return batch_dict From 12919ddb52cfeda2bd63cc04341ec7e253d81b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cagent-sgs=E2=80=9D?= Date: Mon, 19 Sep 2022 10:15:17 +0800 Subject: [PATCH 2/3] pfe --- .../models/backbones_3d/vfe/dynamic_pillar_vfe.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py b/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py index 3521ca972..726783fc8 100644 --- a/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py +++ b/pcdet/models/backbones_3d/vfe/dynamic_pillar_vfe.py @@ -149,11 +149,11 @@ def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_c self.use_norm = self.model_cfg.USE_NORM self.with_distance = self.model_cfg.WITH_DISTANCE self.use_absolute_xyz = self.model_cfg.USE_ABSLOTE_XYZ - self.use_cluster_xyz = self.model_cfg.get('USE_CLUSTER_XYZ', True) + # self.use_cluster_xyz = self.model_cfg.get('USE_CLUSTER_XYZ', True) if self.use_absolute_xyz: num_point_features += 3 - if self.use_cluster_xyz: - num_point_features += 3 + # if self.use_cluster_xyz: + # num_point_features += 3 if self.with_distance: num_point_features += 1 @@ -214,10 +214,10 @@ def forward(self, batch_dict, **kwargs): else: features.append(points[:, 4:]) - if self.use_cluster_xyz: - points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0) - f_cluster = points_xyz - points_mean[unq_inv, :] - features.append(f_cluster) + # if self.use_cluster_xyz: + # points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0) + # f_cluster = points_xyz - points_mean[unq_inv, :] + # features.append(f_cluster) if self.with_distance: points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True) From 90944c5aa617ae97db143b73e3f763c4c18d8df9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cagent-sgs=E2=80=9D?= Date: Mon, 19 Sep 2022 10:18:32 +0800 Subject: [PATCH 3/3] cfg --- .gitignore | 1 - tools/cfgs/kitti_models/pillarnet.yaml | 123 +++++++++++++ .../cbgs_pillar0075_res2d_centerpoint.yaml | 161 ++++++++++++++++++ tools/cfgs/waymo_models/pillarnet.yaml | 97 +++++++++++ 4 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 tools/cfgs/kitti_models/pillarnet.yaml create mode 100644 tools/cfgs/nuscenes_models/cbgs_pillar0075_res2d_centerpoint.yaml create mode 100644 tools/cfgs/waymo_models/pillarnet.yaml diff --git a/.gitignore b/.gitignore index 7d6eb6000..c4f4b92c0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,6 @@ data/ venv/ *.idea/ *.so -*.yaml *.sh *.pth *.pkl diff --git a/tools/cfgs/kitti_models/pillarnet.yaml b/tools/cfgs/kitti_models/pillarnet.yaml new file mode 100644 index 000000000..26dcf4586 --- /dev/null +++ b/tools/cfgs/kitti_models/pillarnet.yaml @@ -0,0 +1,123 @@ +CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] + +DATA_CONFIG: + _BASE_CONFIG_: cfgs/dataset_configs/kitti_dataset.yaml + + +MODEL: + NAME: PillarNet + + VFE: + NAME: DynamicPillarPFE + WITH_DISTANCE: False + USE_ABSLOTE_XYZ: True + USE_CLUSTER_XYZ: False + USE_NORM: True + NUM_FILTERS: [32] + + BACKBONE_3D: + NAME: PillarBackBone8x + + + BACKBONE_2D: + NAME: BaseBEVBackboneV1 + + LAYER_NUMS: [5, 5] + LAYER_STRIDES: [1, 2] + NUM_FILTERS: [256, 256] + UPSAMPLE_STRIDES: [1, 2] + NUM_UPSAMPLE_FILTERS: [128, 128] + + DENSE_HEAD: + NAME: AnchorHeadSingle + CLASS_AGNOSTIC: False + + USE_DIRECTION_CLASSIFIER: True + DIR_OFFSET: 0.78539 + DIR_LIMIT_OFFSET: 0.0 + NUM_DIR_BINS: 2 + + ANCHOR_GENERATOR_CONFIG: [ + { + 'class_name': 'Car', + 'anchor_sizes': [[3.9, 1.6, 1.56]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-1.78], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.6, + 'unmatched_threshold': 0.45 + }, + { + 'class_name': 'Pedestrian', + 'anchor_sizes': [[0.8, 0.6, 1.73]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-0.6], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.5, + 'unmatched_threshold': 0.35 + }, + { + 'class_name': 'Cyclist', + 'anchor_sizes': [[1.76, 0.6, 1.73]], + 'anchor_rotations': [0, 1.57], + 'anchor_bottom_heights': [-0.6], + 'align_center': False, + 'feature_map_stride': 8, + 'matched_threshold': 0.5, + 'unmatched_threshold': 0.35 + } + ] + + TARGET_ASSIGNER_CONFIG: + NAME: AxisAlignedTargetAssigner + POS_FRACTION: -1.0 + SAMPLE_SIZE: 512 + NORM_BY_NUM_EXAMPLES: False + MATCH_HEIGHT: False + BOX_CODER: ResidualCoder + + LOSS_CONFIG: + LOSS_WEIGHTS: { + 'cls_weight': 1.0, + 'loc_weight': 2.0, + 'dir_weight': 0.2, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + SCORE_THRESH: 0.1 + OUTPUT_RAW_SCORE: False + + EVAL_METRIC: kitti + + NMS_CONFIG: + MULTI_CLASSES_NMS: False + NMS_TYPE: nms_gpu + NMS_THRESH: 0.01 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 4 + NUM_EPOCHS: 80 + + OPTIMIZER: adam_onecycle + LR: 0.003 + WEIGHT_DECAY: 0.01 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10 diff --git a/tools/cfgs/nuscenes_models/cbgs_pillar0075_res2d_centerpoint.yaml b/tools/cfgs/nuscenes_models/cbgs_pillar0075_res2d_centerpoint.yaml new file mode 100644 index 000000000..4f2820192 --- /dev/null +++ b/tools/cfgs/nuscenes_models/cbgs_pillar0075_res2d_centerpoint.yaml @@ -0,0 +1,161 @@ +CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer', + 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'] + +DATA_CONFIG: + _BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml + POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0] + + DATA_AUGMENTOR: + DISABLE_AUG_LIST: ['placeholder'] + AUG_CONFIG_LIST: + - NAME: gt_sampling + DB_INFO_PATH: + - nuscenes_dbinfos_10sweeps_withvelo.pkl + PREPARE: { + filter_by_min_points: [ + 'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5', + 'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5' + ], + } + + SAMPLE_GROUPS: [ + 'car:2','truck:3', 'construction_vehicle:7', 'bus:4', 'trailer:6', + 'barrier:2', 'motorcycle:6', 'bicycle:6', 'pedestrian:2', 'traffic_cone:2' + ] + + NUM_POINT_FEATURES: 5 + DATABASE_WITH_FAKELIDAR: False + REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0] + LIMIT_WHOLE_SCENE: True + + - NAME: random_world_flip + ALONG_AXIS_LIST: ['x', 'y'] + + - NAME: random_world_rotation + WORLD_ROT_ANGLE: [-0.78539816, 0.78539816] + + - NAME: random_world_scaling + WORLD_SCALE_RANGE: [0.9, 1.1] + + - NAME: random_world_translation + NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5] + + + DATA_PROCESSOR: + - NAME: mask_points_and_boxes_outside_range + REMOVE_OUTSIDE_BOXES: True + + - NAME: shuffle_points + SHUFFLE_ENABLED: { + 'train': True, + 'test': True + } + + - NAME: transform_points_to_voxels + VOXEL_SIZE: [0.075, 0.075, 0.2] + MAX_POINTS_PER_VOXEL: 10 + MAX_NUMBER_OF_VOXELS: { + 'train': 120000, + 'test': 160000 + } + + +MODEL: + NAME: PillarNet + + VFE: + NAME: DynamicPillarPFE + WITH_DISTANCE: False + USE_ABSLOTE_XYZ: True + USE_CLUSTER_XYZ: False + USE_NORM: True + NUM_FILTERS: [ 32 ] + + BACKBONE_3D: + NAME: PillarRes18BackBone8x + + BACKBONE_2D: + NAME: BaseBEVBackboneV1 + + LAYER_NUMS: [ 5, 5 ] + LAYER_STRIDES: [ 1, 2 ] + NUM_FILTERS: [ 256, 256 ] + UPSAMPLE_STRIDES: [ 1, 2 ] + NUM_UPSAMPLE_FILTERS: [ 128, 128 ] + + DENSE_HEAD: + NAME: CenterHead + CLASS_AGNOSTIC: False + + CLASS_NAMES_EACH_HEAD: [ + ['car'], + ['truck', 'construction_vehicle'], + ['bus', 'trailer'], + ['barrier'], + ['motorcycle', 'bicycle'], + ['pedestrian', 'traffic_cone'], + ] + + SHARED_CONV_CHANNEL: 64 + USE_BIAS_BEFORE_NORM: True + NUM_HM_CONV: 2 + SEPARATE_HEAD_CFG: + HEAD_ORDER: ['center', 'center_z', 'dim', 'rot', 'vel'] + HEAD_DICT: { + 'center': {'out_channels': 2, 'num_conv': 2}, + 'center_z': {'out_channels': 1, 'num_conv': 2}, + 'dim': {'out_channels': 3, 'num_conv': 2}, + 'rot': {'out_channels': 2, 'num_conv': 2}, + 'vel': {'out_channels': 2, 'num_conv': 2}, + } + + TARGET_ASSIGNER_CONFIG: + FEATURE_MAP_STRIDE: 8 + NUM_MAX_OBJS: 500 + GAUSSIAN_OVERLAP: 0.1 + MIN_RADIUS: 2 + + LOSS_CONFIG: + LOSS_WEIGHTS: { + 'cls_weight': 1.0, + 'loc_weight': 0.25, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0] + } + + POST_PROCESSING: + SCORE_THRESH: 0.1 + POST_CENTER_LIMIT_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0] + MAX_OBJ_PER_SAMPLE: 500 + NMS_CONFIG: + NMS_TYPE: nms_gpu + NMS_THRESH: 0.2 + NMS_PRE_MAXSIZE: 1000 + NMS_POST_MAXSIZE: 83 + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + + EVAL_METRIC: kitti + + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 4 + NUM_EPOCHS: 20 + + OPTIMIZER: adam_onecycle + LR: 0.001 + WEIGHT_DECAY: 0.01 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10 diff --git a/tools/cfgs/waymo_models/pillarnet.yaml b/tools/cfgs/waymo_models/pillarnet.yaml new file mode 100644 index 000000000..42c814811 --- /dev/null +++ b/tools/cfgs/waymo_models/pillarnet.yaml @@ -0,0 +1,97 @@ +CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist'] + +DATA_CONFIG: + _BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml + +MODEL: + NAME: PillarNet + + VFE: + NAME: DynamicPillarPFE + WITH_DISTANCE: False + USE_ABSLOTE_XYZ: True + USE_CLUSTER_XYZ: False + USE_NORM: True + NUM_FILTERS: [32] + + BACKBONE_3D: + NAME: PillarRes18BackBone8x + + BACKBONE_2D: + NAME: BaseBEVBackboneV1 + + LAYER_NUMS: [5, 5] + LAYER_STRIDES: [1, 2] + NUM_FILTERS: [256, 256] + UPSAMPLE_STRIDES: [1, 2] + NUM_UPSAMPLE_FILTERS: [128, 128] + + DENSE_HEAD: + NAME: CenterHead + CLASS_AGNOSTIC: False + + CLASS_NAMES_EACH_HEAD: [ + ['Vehicle', 'Pedestrian', 'Cyclist'] + ] + + SHARED_CONV_CHANNEL: 64 + USE_BIAS_BEFORE_NORM: True + NUM_HM_CONV: 2 + SEPARATE_HEAD_CFG: + HEAD_ORDER: ['center', 'center_z', 'dim', 'rot'] + HEAD_DICT: { + 'center': {'out_channels': 2, 'num_conv': 2}, + 'center_z': {'out_channels': 1, 'num_conv': 2}, + 'dim': {'out_channels': 3, 'num_conv': 2}, + 'rot': {'out_channels': 2, 'num_conv': 2}, + } + + TARGET_ASSIGNER_CONFIG: + FEATURE_MAP_STRIDE: 8 + NUM_MAX_OBJS: 500 + GAUSSIAN_OVERLAP: 0.1 + MIN_RADIUS: 2 + + LOSS_CONFIG: + LOSS_WEIGHTS: { + 'cls_weight': 1.0, + 'loc_weight': 2.0, + 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + } + + POST_PROCESSING: + SCORE_THRESH: 0.1 + POST_CENTER_LIMIT_RANGE: [-75.2, -75.2, -2, 75.2, 75.2, 4] + MAX_OBJ_PER_SAMPLE: 500 + NMS_CONFIG: + NMS_TYPE: nms_gpu + NMS_THRESH: 0.7 + NMS_PRE_MAXSIZE: 4096 + NMS_POST_MAXSIZE: 500 + + POST_PROCESSING: + RECALL_THRESH_LIST: [0.3, 0.5, 0.7] + + EVAL_METRIC: waymo + + +OPTIMIZATION: + BATCH_SIZE_PER_GPU: 4 + NUM_EPOCHS: 30 + + OPTIMIZER: adam_onecycle + LR: 0.003 + WEIGHT_DECAY: 0.01 + MOMENTUM: 0.9 + + MOMS: [0.95, 0.85] + PCT_START: 0.4 + DIV_FACTOR: 10 + DECAY_STEP_LIST: [35, 45] + LR_DECAY: 0.1 + LR_CLIP: 0.0000001 + + LR_WARMUP: False + WARMUP_EPOCH: 1 + + GRAD_NORM_CLIP: 10