Skip to content

Commit

Permalink
Merge branch 'pillarnet' of https://github.com/agent-sgs/OpenPCDet-raw
Browse files Browse the repository at this point in the history
…into agent-sgs-pillarnet
  • Loading branch information
sshaoshuai committed Sep 19, 2022
2 parents 3e3712e + 90944c5 commit 4e0962e
Show file tree
Hide file tree
Showing 10 changed files with 880 additions and 5 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ data/
venv/
*.idea/
*.so
*.yaml
*.sh
*.pth
*.pkl
Expand Down
5 changes: 3 additions & 2 deletions pcdet/models/backbones_2d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_bev_backbone import BaseBEVBackbone
from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1

__all__ = {
'BaseBEVBackbone': BaseBEVBackbone
'BaseBEVBackbone': BaseBEVBackbone,
'BaseBEVBackboneV1': BaseBEVBackboneV1
}
92 changes: 92 additions & 0 deletions pcdet/models/backbones_2d/base_bev_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,95 @@ def forward(self, data_dict):
data_dict['spatial_features_2d'] = x

return data_dict


class BaseBEVBackboneV1(nn.Module):
def __init__(self, model_cfg, **kwargs):
super().__init__()
self.model_cfg = model_cfg

layer_nums = self.model_cfg.LAYER_NUMS
num_filters = self.model_cfg.NUM_FILTERS
assert len(layer_nums) == len(num_filters) == 2

num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
assert len(num_upsample_filters) == len(upsample_strides)

num_levels = len(layer_nums)
self.blocks = nn.ModuleList()
self.deblocks = nn.ModuleList()
for idx in range(num_levels):
cur_layers = [
nn.ZeroPad2d(1),
nn.Conv2d(
num_filters[idx], num_filters[idx], kernel_size=3,
stride=1, padding=0, bias=False
),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
]
for k in range(layer_nums[idx]):
cur_layers.extend([
nn.Conv2d(num_filters[idx], num_filters[idx], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
])
self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
stride = upsample_strides[idx]
if stride >= 1:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
else:
stride = np.round(1 / stride).astype(np.int)
self.deblocks.append(nn.Sequential(
nn.Conv2d(
num_filters[idx], num_upsample_filters[idx],
stride,
stride=stride, bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))

c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
nn.ReLU(),
))

self.num_bev_features = c_in

def forward(self, data_dict):
"""
Args:
data_dict:
spatial_features
Returns:
"""
spatial_features = data_dict['multi_scale_2d_features']

x_conv4 = spatial_features['x_conv4']
x_conv5 = spatial_features['x_conv5']

ups = [self.deblocks[0](x_conv4)]

x = self.blocks[1](x_conv5)
ups.append(self.deblocks[1](x))

x = torch.cat(ups, dim=1)
x = self.blocks[0](x)

data_dict['spatial_features_2d'] = x

return data_dict
5 changes: 4 additions & 1 deletion pcdet/models/backbones_3d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .pointnet2_backbone import PointNet2Backbone, PointNet2MSG
from .spconv_backbone import VoxelBackBone8x, VoxelResBackBone8x
from .spconv_backbone_2d import PillarBackBone8x, PillarRes18BackBone8x
from .spconv_backbone_focal import VoxelBackBone8xFocal
from .spconv_unet import UNetV2

Expand All @@ -9,5 +10,7 @@
'PointNet2Backbone': PointNet2Backbone,
'PointNet2MSG': PointNet2MSG,
'VoxelResBackBone8x': VoxelResBackBone8x,
'VoxelBackBone8xFocal': VoxelBackBone8xFocal
'VoxelBackBone8xFocal': VoxelBackBone8xFocal,
'PillarBackBone8x': PillarBackBone8x,
'PillarRes18BackBone8x': PillarRes18BackBone8x
}
Loading

0 comments on commit 4e0962e

Please sign in to comment.