forked from open-mmlab/OpenPCDet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add WaymoDataset dataloader and evaluation tools, add related configs
- Loading branch information
1 parent
7bc7e55
commit 8f3bf93
Showing
6 changed files
with
1,120 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,364 @@ | ||
# OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset | ||
# Reference https://github.com/open-mmlab/OpenPCDet | ||
# Written by Shaoshuai Shi, Chaoxu Guo | ||
# All Rights Reserved 2019-2020. | ||
|
||
import os | ||
import pickle | ||
import copy | ||
import numpy as np | ||
import torch | ||
from path import Path | ||
from ...ops.roiaware_pool3d import roiaware_pool3d_utils | ||
from ...utils import box_utils, common_utils | ||
from ..dataset import DatasetTemplate | ||
|
||
|
||
class WaymoDataset(DatasetTemplate): | ||
def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None): | ||
super().__init__( | ||
dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger | ||
) | ||
self.data_path = self.root_path / self.dataset_cfg.PROCESSED_DATA_TAG | ||
self.split = self.dataset_cfg.DATA_SPLIT[self.mode] | ||
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt') | ||
self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()] | ||
|
||
self.infos = [] | ||
self.include_waymo_data(self.mode) | ||
|
||
def set_split(self, split): | ||
super().__init__( | ||
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, | ||
root_path=self.root_path, logger=self.logger | ||
) | ||
self.split = split | ||
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt') | ||
self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()] | ||
self.infos = [] | ||
self.include_waymo_data(self.mode) | ||
|
||
def include_waymo_data(self, mode): | ||
self.logger.info('Loading Waymo dataset') | ||
waymo_infos = [] | ||
|
||
num_skipped_infos = 0 | ||
for k in range(len(self.sample_sequence_list)): | ||
sequence_name = os.path.splitext(self.sample_sequence_list[k])[0] | ||
info_path = self.data_path / sequence_name / ('%s.pkl' % sequence_name) | ||
info_path = self.check_sequence_name_with_all_version(info_path) | ||
if not info_path.exists(): | ||
num_skipped_infos += 1 | ||
continue | ||
with open(info_path, 'rb') as f: | ||
infos = pickle.load(f) | ||
waymo_infos.extend(infos) | ||
|
||
self.infos.extend(waymo_infos[:]) | ||
self.logger.info('Total skipped info %s' % num_skipped_infos) | ||
self.logger.info('Total samples for Waymo dataset: %d' % (len(waymo_infos))) | ||
|
||
if self.dataset_cfg.SAMPLED_INTERVAL[mode] > 1: | ||
sampled_waymo_infos = [] | ||
for k in range(0, len(self.infos), self.dataset_cfg.SAMPLED_INTERVAL[mode]): | ||
sampled_waymo_infos.append(self.infos[k]) | ||
self.infos = sampled_waymo_infos | ||
self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos)) | ||
|
||
@staticmethod | ||
def check_sequence_name_with_all_version(sequence_file): | ||
if '_with_camera_labels' not in str(sequence_file) and not sequence_file.exists(): | ||
sequence_file = Path(str(sequence_file[:-9]) + '_with_camera_labels.tfrecord') | ||
if '_with_camera_labels' in str(sequence_file) and not sequence_file.exists(): | ||
sequence_file = Path(str(sequence_file).replace('_with_camera_labels', '')) | ||
|
||
return sequence_file | ||
|
||
def get_infos(self, raw_data_path, save_path, num_workers=4, has_label=True, sampled_interval=1): | ||
import concurrent.futures as futures | ||
from functools import partial | ||
from . import waymo_utils | ||
# tf.enable_eager_execution() | ||
print('The waymo sample interval is %d, total sequecnes is %d' % (sampled_interval, len(self.sample_sequence_list))) | ||
|
||
process_single_sequence = partial( | ||
waymo_utils.process_single_sequence, | ||
save_path=save_path, sampled_interval=sampled_interval, has_label=has_label | ||
) | ||
sample_sequence_file_list = [ | ||
self.check_sequence_name_with_all_version(raw_data_path / sequence_file) | ||
for sequence_file in self.sample_sequence_list | ||
] | ||
|
||
# process_single_sequence(sample_sequence_file_list[0]) | ||
with futures.ThreadPoolExecutor(num_workers) as executor: | ||
sequence_infos = executor.map(process_single_sequence, sample_sequence_file_list) | ||
sequence_infos = list(sequence_infos) | ||
all_sequences_infos = [item for infos in sequence_infos for item in infos] | ||
return all_sequences_infos | ||
|
||
def get_lidar(self, sequence_name, sample_idx): | ||
lidar_file = self.data_path / sequence_name / ('%04d.npy' % sample_idx) | ||
point_features = np.load(lidar_file) # (N, 7): [x, y, z, intensity, elongation, NLZ_flag] | ||
|
||
points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5] | ||
points_all = points_all[NLZ_flag == -1] | ||
points_all[:, 3] = np.tanh(points_all[:, 3]) | ||
return points_all | ||
|
||
def __len__(self): | ||
if self._merge_all_iters_to_one_epoch: | ||
return len(self.infos) * self.total_epochs | ||
|
||
return len(self.infos) | ||
|
||
def __getitem__(self, index): | ||
if self._merge_all_iters_to_one_epoch: | ||
index = index % len(self.infos) | ||
|
||
info = copy.deepcopy(self.infos[index]) | ||
pc_info = info['point_cloud'] | ||
sequence_name = pc_info['lidar_sequence'] | ||
sample_idx = pc_info['sample_idx'] | ||
points = self.get_lidar(sequence_name, sample_idx) | ||
|
||
input_dict = { | ||
'points': points, | ||
'frame_id': info['frame_id'], | ||
} | ||
|
||
if 'annos' in info: | ||
annos = info['annos'] | ||
annos = common_utils.drop_info_with_name(annos, name='unknown') | ||
|
||
if self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False): | ||
gt_boxes_lidar = box_utils.boxes3d_kitti_fakelidar_to_lidar(annos['gt_boxes_lidar']) | ||
else: | ||
gt_boxes_lidar = annos['gt_boxes_lidar'] | ||
|
||
input_dict.update({ | ||
'gt_names': annos['name'], | ||
'gt_boxes': gt_boxes_lidar, | ||
'num_points_in_gt': annos.get('num_points_in_gt', None) | ||
}) | ||
|
||
data_dict = self.prepare_data(data_dict=input_dict) | ||
data_dict['metadata'] = info.get('metadata', info['frame_id']) | ||
data_dict.pop('num_points_in_gt', None) | ||
return data_dict | ||
|
||
@staticmethod | ||
def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None): | ||
""" | ||
Args: | ||
batch_dict: | ||
frame_id: | ||
pred_dicts: list of pred_dicts | ||
pred_boxes: (N, 7), Tensor | ||
pred_scores: (N), Tensor | ||
pred_labels: (N), Tensor | ||
class_names: | ||
output_path: | ||
Returns: | ||
""" | ||
|
||
def get_template_prediction(num_samples): | ||
ret_dict = { | ||
'name': np.zeros(num_samples), 'score': np.zeros(num_samples), | ||
'boxes_lidar': np.zeros([num_samples, 7]) | ||
} | ||
return ret_dict | ||
|
||
def generate_single_sample_dict(box_dict): | ||
pred_scores = box_dict['pred_scores'].cpu().numpy() | ||
pred_boxes = box_dict['pred_boxes'].cpu().numpy() | ||
pred_labels = box_dict['pred_labels'].cpu().numpy() | ||
pred_dict = get_template_prediction(pred_scores.shape[0]) | ||
if pred_scores.shape[0] == 0: | ||
return pred_dict | ||
|
||
pred_dict['name'] = np.array(class_names)[pred_labels - 1] | ||
pred_dict['score'] = pred_scores | ||
pred_dict['boxes_lidar'] = pred_boxes | ||
|
||
return pred_dict | ||
|
||
annos = [] | ||
for index, box_dict in enumerate(pred_dicts): | ||
single_pred_dict = generate_single_sample_dict(box_dict) | ||
single_pred_dict['frame_id'] = batch_dict['frame_id'][index] | ||
single_pred_dict['metadata'] = batch_dict['metadata'][index] | ||
annos.append(single_pred_dict) | ||
|
||
return annos | ||
|
||
def evaluation(self, det_annos, class_names, **kwargs): | ||
if 'annos' not in self.infos[0].keys(): | ||
return 'No ground-truth boxes for evaluation', {} | ||
|
||
def kitti_eval(eval_det_annos, eval_gt_annos): | ||
from ..kitti.kitti_object_eval_python import eval as kitti_eval | ||
from ..kitti import kitti_utils | ||
|
||
map_name_to_kitti = { | ||
'Vehicle': 'Car', | ||
'Pedestrian': 'Pedestrian', | ||
'Cyclist': 'Cyclist', | ||
'Sign': 'Sign', | ||
'Car': 'Car' | ||
} | ||
kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti) | ||
kitti_utils.transform_annotations_to_kitti_format( | ||
eval_gt_annos, map_name_to_kitti=map_name_to_kitti, | ||
info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False) | ||
) | ||
kitti_class_names = [map_name_to_kitti[x] for x in class_names] | ||
ap_result_str, ap_dict = kitti_eval.get_official_eval_result( | ||
gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names | ||
) | ||
return ap_result_str, ap_dict | ||
|
||
def waymo_eval(eval_det_annos, eval_gt_annos): | ||
from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator | ||
eval = OpenPCDetWaymoDetectionMetricsEstimator() | ||
|
||
ap_dict = eval.waymo_evaluation( | ||
eval_det_annos, eval_gt_annos, class_name=class_names, | ||
distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False) | ||
) | ||
ap_result_str = '\n' | ||
for key in ap_dict: | ||
ap_dict[key] = ap_dict[key][0] | ||
ap_result_str += '%s: %.4f \n' % (key, ap_dict[key]) | ||
|
||
return ap_result_str, ap_dict | ||
|
||
eval_det_annos = copy.deepcopy(det_annos) | ||
eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.infos] | ||
|
||
if kwargs['eval_metric'] == 'kitti': | ||
ap_result_str, ap_dict = kitti_eval(eval_det_annos, eval_gt_annos) | ||
elif kwargs['eval_metric'] == 'waymo': | ||
ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos) | ||
else: | ||
raise NotImplementedError | ||
|
||
return ap_result_str, ap_dict | ||
|
||
def create_groundtruth_database(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, | ||
processed_data_tag=None): | ||
database_save_path = save_path / ('pcdet_gt_database_%s_sampled_%d' % (split, sampled_interval)) | ||
db_info_save_path = save_path / ('pcdet_waymo_dbinfos_%s_sampled_%d.pkl' % (split, sampled_interval)) | ||
|
||
database_save_path.mkdir(parents=True, exist_ok=True) | ||
all_db_infos = {} | ||
|
||
with open(info_path, 'rb') as f: | ||
infos = pickle.load(f) | ||
|
||
for k in range(0, len(infos), sampled_interval): | ||
print('gt_database sample: %d/%d' % (k + 1, len(infos))) | ||
info = infos[k] | ||
|
||
pc_info = info['point_cloud'] | ||
sequence_name = pc_info['lidar_sequence'] | ||
sample_idx = pc_info['sample_idx'] | ||
points = self.get_lidar(sequence_name, sample_idx) | ||
|
||
annos = info['annos'] | ||
names = annos['name'] | ||
difficulty = annos['difficulty'] | ||
gt_boxes = annos['gt_boxes_lidar'] | ||
|
||
num_obj = gt_boxes.shape[0] | ||
|
||
box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu( | ||
torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(), | ||
torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda() | ||
).long().squeeze(dim=0).cpu().numpy() | ||
|
||
for i in range(num_obj): | ||
filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i) | ||
filepath = database_save_path / filename | ||
gt_points = points[box_idxs_of_pts == i] | ||
gt_points[:, :3] -= gt_boxes[i, :3] | ||
|
||
if (used_classes is None) or names[i] in used_classes: | ||
with open(filepath, 'w') as f: | ||
gt_points.tofile(f) | ||
|
||
db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin | ||
db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name, | ||
'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i], | ||
'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]} | ||
if names[i] in all_db_infos: | ||
all_db_infos[names[i]].append(db_info) | ||
else: | ||
all_db_infos[names[i]] = [db_info] | ||
for k, v in all_db_infos.items(): | ||
print('Database %s: %d' % (k, len(v))) | ||
|
||
with open(db_info_save_path, 'wb') as f: | ||
pickle.dump(all_db_infos, f) | ||
|
||
|
||
def create_waymo_infos(dataset_cfg, class_names, data_path, save_path, | ||
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data', workers=4): | ||
dataset = WaymoDataset( | ||
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, | ||
training=False, logger=common_utils.create_logger() | ||
) | ||
train_split, val_split = 'train', 'val' | ||
|
||
train_filename = save_path / ('waymo_infos_%s.pkl' % train_split) | ||
val_filename = save_path / ('waymo_infos_%s.pkl' % val_split) | ||
|
||
print('---------------Start to generate data infos---------------') | ||
|
||
dataset.set_split(train_split) | ||
waymo_infos_train = dataset.get_infos( | ||
raw_data_path=data_path / raw_data_tag, | ||
save_path=save_path / processed_data_tag, num_workers=workers, has_label=True, | ||
sampled_interval=1 | ||
) | ||
with open(train_filename, 'wb') as f: | ||
pickle.dump(waymo_infos_train, f) | ||
print('Waymo info train file is saved to %s' % train_filename) | ||
|
||
dataset.set_split(val_split) | ||
waymo_infos_val = dataset.get_infos( | ||
raw_data_path=data_path / raw_data_tag, | ||
save_path=save_path / processed_data_tag, num_workers=workers, has_label=True, | ||
sampled_interval=1 | ||
) | ||
with open(val_filename, 'wb') as f: | ||
pickle.dump(waymo_infos_val, f) | ||
print('Waymo info val file is saved to %s' % val_filename) | ||
|
||
print('---------------Start create groundtruth database for data augmentation---------------') | ||
dataset.set_split(train_split) | ||
dataset.create_groundtruth_database( | ||
info_path=train_filename, save_path=save_path, split='train', sampled_interval=10, | ||
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'] | ||
) | ||
print('---------------Data preparation Done---------------') | ||
|
||
|
||
if __name__ == '__main__': | ||
import sys | ||
if sys.argv.__len__() > 1 and sys.argv[1] == 'create_waymo_infos': | ||
import yaml | ||
from pathlib import Path | ||
from easydict import EasyDict | ||
dataset_cfg = EasyDict(yaml.load(open(sys.argv[2]))) | ||
ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve() | ||
create_waymo_infos( | ||
dataset_cfg=dataset_cfg, | ||
class_names=['Vehicle', 'Pedestrian', 'Cyclist'], | ||
data_path=ROOT_DIR / 'data' / 'waymo', | ||
save_path=ROOT_DIR / 'data' / 'waymo', | ||
raw_data_tag='raw_data_v1_2', | ||
processed_data_tag=dataset_cfg.PROCESSED_DATA_TAG | ||
) |
Oops, something went wrong.