Skip to content

Commit

Permalink
abstract load_data_to_gpu function in train/test (open-mmlab#116)
Browse files Browse the repository at this point in the history
* abstract load_data_to_gpu function in train/test
  • Loading branch information
sshaoshuai authored Jul 2, 2020
1 parent be0507c commit 4497eb5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 29 deletions.
36 changes: 20 additions & 16 deletions pcdet/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,26 @@ def collate_batch(batch_list, _unused=False):
ret = {}

for key, val in data_dict.items():
if key in ['voxels', 'voxel_num_points']:
ret[key] = np.concatenate(val, axis=0)
elif key in ['points', 'voxel_coords']:
coors = []
for i, coor in enumerate(val):
coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
coors.append(coor_pad)
ret[key] = np.concatenate(coors, axis=0)
elif key in ['gt_boxes']:
max_gt = max([len(x) for x in val])
batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
for k in range(batch_size):
batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
ret[key] = batch_gt_boxes3d
else:
ret[key] = np.stack(val, axis=0)
try:
if key in ['voxels', 'voxel_num_points']:
ret[key] = np.concatenate(val, axis=0)
elif key in ['points', 'voxel_coords']:
coors = []
for i, coor in enumerate(val):
coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
coors.append(coor_pad)
ret[key] = np.concatenate(coors, axis=0)
elif key in ['gt_boxes']:
max_gt = max([len(x) for x in val])
batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
for k in range(batch_size):
batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
ret[key] = batch_gt_boxes3d
else:
ret[key] = np.stack(val, axis=0)
except:
print('Error in collate_batch: key=%s' % key)
raise TypeError

ret['batch_size'] = batch_size
return ret
16 changes: 10 additions & 6 deletions pcdet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@ def build_network(model_cfg, num_class, dataset):
return model


def load_data_to_gpu(batch_dict):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id', 'metadata', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()


def model_fn_decorator():
ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])

def model_func(model, batch_dict):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()
load_data_to_gpu(batch_dict)
ret_dict, tb_dict, disp_dict = model(batch_dict)

loss = ret_dict['loss'].mean()
Expand Down
9 changes: 2 additions & 7 deletions tools/eval_utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
from pcdet.utils import common_utils
from pcdet.models import load_data_to_gpu


def statistics_info(cfg, ret_dict, metric, disp_dict):
Expand Down Expand Up @@ -51,13 +52,7 @@ def eval_one_epoch(cfg, model, dataloader, epoch_id, logger, dist_test=False, sa
progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)
start_time = time.time()
for i, batch_dict in enumerate(dataloader):
for key, val in batch_dict.items():
if not isinstance(val, np.ndarray):
continue
if key in ['frame_id', 'calib', 'image_shape']:
continue
batch_dict[key] = torch.from_numpy(val).float().cuda()

load_data_to_gpu(batch_dict)
with torch.no_grad():
pred_dicts, ret_dict = model(batch_dict)
disp_dict = {}
Expand Down

0 comments on commit 4497eb5

Please sign in to comment.