Commit
Tools: update data extraction. (1) Enable small channel storage procedure, as in bin file; (2) Enable different extraction rates for different channels; (3) Enable GPS extraction; (4) Add more sanity checks. (ApolloAuto#7678)
gchen-apollo committed Apr 5, 2019
1 parent 9fa5e30 commit 8471914
Showing 2 changed files with 216 additions and 22 deletions.
80 changes: 80 additions & 0 deletions modules/tools/vehicle_calibration/sensor_calibration/data_file_object.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python

###############################################################################
# Copyright 2019 The Apollo Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################

"""
Classes to manage file I/O for cyber record channel data.
"""
import struct

class FileObject(object):
    """Wrapper for a file object."""

    # Initialize the file object
    def __init__(self, file_path, operation='write', filetype='binary'):
        if operation not in ('write', 'read'):
            raise ValueError("Unsupported file operation: %s" % operation)

        if filetype not in ('binary', 'txt'):
            raise ValueError("Unsupported file type: %s" % filetype)

        mode = 'w' if operation == 'write' else 'r'
        mode += 'b' if filetype == 'binary' else ''

        try:
            self._file_object = open(file_path, mode)
except IOError:
raise ValueError("cannot open file: {}".format(file_path))

def file_object(self):
return self._file_object

def save_to_file(self, data):
raise NotImplementedError

    # Safely close the file; the attribute may be missing if open() failed
    def __del__(self):
        if hasattr(self, '_file_object'):
            self._file_object.close()

class OdometryFileObject(FileObject):
"""class to handle gnss/odometry topic"""

def save_to_file(self, data):
"""
odometry data: total 9 elements
[
double timestamp
        int32 position_type
double qw, qx, qy, qz
double x, y, z
]
"""
        if not isinstance(data, list):
            raise ValueError("Odometry data must be in a list")
data_size = len(data)
self._file_object.write(struct.pack('i', data_size))
        # Note: struct uses native alignment by default, so the int32 is
        # padded to 8 bytes and each record is 72 bytes, not 68 = 8 + 4 + 7*8.
        # A '=di7d' format would give a packed 68-byte record.
        s = struct.Struct('di7d')
        for d in data:
            # Each d is the 9-tuple described in the docstring above
            self._file_object.write(s.pack(*d))
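
For reference, a file written by OdometryFileObject can be read back with the same struct layout. Below is a minimal reader sketch (illustrative only, not part of this commit; read_odometry_bin is a hypothetical helper):

import struct

def read_odometry_bin(file_path):
    """Read back records written by OdometryFileObject.save_to_file()."""
    s = struct.Struct('di7d')  # 72 bytes per record with native alignment
    with open(file_path, 'rb') as f:
        (count,) = struct.unpack('i', f.read(4))
        # Each record unpacks to (ts, point_type, qw, qx, qy, qz, x, y, z)
        return [s.unpack(f.read(s.size)) for _ in range(count)]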
158 changes: 136 additions & 22 deletions modules/tools/vehicle_calibration/sensor_calibration/extract_data.py
@@ -39,12 +39,57 @@

from modules.drivers.proto import sensor_image_pb2
from modules.drivers.proto import pointcloud_pb2
from modules.localization.proto import gps_pb2
from data_file_object import *
#from scripts.record_bag import SMALL_TOPICS


SMALL_TOPICS = [
'/apollo/canbus/chassis',
'/apollo/canbus/chassis_detail',
'/apollo/control',
'/apollo/control/pad',
'/apollo/drive_event',
'/apollo/guardian',
'/apollo/hmi/status',
'/apollo/localization/pose',
'/apollo/localization/msf_gnss',
'/apollo/localization/msf_lidar',
'/apollo/localization/msf_status',
'/apollo/monitor',
'/apollo/monitor/system_status',
'/apollo/navigation',
'/apollo/perception/obstacles',
'/apollo/perception/traffic_light',
'/apollo/planning',
'/apollo/prediction',
'/apollo/relative_map',
'/apollo/routing_request',
'/apollo/routing_response',
'/apollo/routing_response_history',
'/apollo/sensor/conti_radar',
'/apollo/sensor/delphi_esr',
'/apollo/sensor/gnss/best_pose',
'/apollo/sensor/gnss/corrected_imu',
'/apollo/sensor/gnss/gnss_status',
'/apollo/sensor/gnss/imu',
'/apollo/sensor/gnss/ins_stat',
'/apollo/sensor/gnss/odometry',
'/apollo/sensor/gnss/raw_data',
'/apollo/sensor/gnss/rtk_eph',
'/apollo/sensor/gnss/rtk_obs',
'/apollo/sensor/gnss/heading',
'/apollo/sensor/mobileye',
'/tf',
'/tf_static',
]

CYBER_PATH = os.environ['CYBER_PATH']

CYBER_RECORD_HEADER_LENGTH = 2048

IMAGE_OBJ = sensor_image_pb2.Image()
POINTCLOUD_OBJ = pointcloud_pb2.PointCloud()
GPS_OBJ = gps_pb2.Gps()

def process_dir(path, operation):
"""Create or remove directory."""
@@ -65,7 +110,7 @@ def process_dir(path, operation):
return True

def extract_camera_data(dest_dir, msg):
"""Extract camera file from message according to ratio."""
"""Extract camera file from message according to rate."""
#TODO: change saving logic
cur_time_second = msg.timestamp
image = IMAGE_OBJ
@@ -154,16 +199,43 @@ def extract_pcd_data(dest_dir, msg):
pc_meta = make_xyzit_point_cloud(pointcloud.point)
pcd_file = os.path.join(dest_dir, '{}.pcd'.format(cur_time_second))
pypcd.save_point_cloud_bin_compressed(pc_meta, pcd_file)
    # TODO(gchen-Apollo): add sanity check
return True

def extract_gps_data(dest_dir, msg, out_msg_list):
"""
    Save GPS information to a bin file, to be fed into the following tools
"""
    if not isinstance(out_msg_list, list):
        raise ValueError("GPS/Odometry messages should be saved as a list, "
                         "not %s" % type(out_msg_list))

gps = GPS_OBJ
gps.ParseFromString(msg.message)

# all double, except point_type is int32
ts = gps.header.timestamp_sec
point_type = 0
qw = gps.localization.orientation.qw
qx = gps.localization.orientation.qx
qy = gps.localization.orientation.qy
qz = gps.localization.orientation.qz
x = gps.localization.position.x
y = gps.localization.position.y
z = gps.localization.position.z
    # save the 9 values as a tuple, for easier struct packing during storage
out_msg_list.append((ts, point_type, qw, qx, qy, qz, x, y, z))

    # TODO(gchen-Apollo): build a class to encapsulate the inner data structure
return out_msg_list

def get_sensor_channel_list(record_file):
"""Get the channel list of sensors for calibration."""
record_reader = RecordReader(record_file)
return set(channel_name for channel_name in record_reader.get_channellist()
if 'sensor' in channel_name)


def validate_record(record_file):
"""Validate the record file."""
@@ -195,7 +267,7 @@ def validate_record(record_file):
(record_file, header.message_number, header.channel_number))
return False

# There should be at least has one sensor channel
# There should be at least one sensor channel
sensor_channels = get_sensor_channel_list(record_file)
if len(sensor_channels) < 1:
print('Record file: %s. cannot find sensor channels.' % record_file)
@@ -204,20 +276,21 @@
return True


def extract_channel_data(output_path, msg):
def extract_channel_data(output_path, msg, channel_msgs_dict=None):
    """Process channel messages."""
    # Use None rather than a mutable default argument, which would
    # leak state across calls
    if channel_msgs_dict is None:
        channel_msgs_dict = {}
channel_desc = msg.data_type
if channel_desc == 'apollo.drivers.Image':
extract_camera_data(output_path, msg)
elif channel_desc == 'apollo.drivers.PointCloud':
extract_pcd_data(output_path, msg)
elif channel_desc == 'apollo.localization.Gps':
channel_msgs_dict[msg.topic] = extract_gps_data(output_path, msg,
channel_msgs_dict[msg.topic])
else:
        # TODO(LiuJie/gchen-Apollo): Handle binary data extraction.
print('Not implemented!')

return True
return True, channel_msgs_dict

def validate_channel_list(channels, dictionary):
ret = True
@@ -233,16 +306,12 @@ def in_range(v, s, e):
    return s <= v <= e

def extract_data(record_file, output_path, channel_list,
start_timestamp, end_timestamp, extraction_ratio):
start_timestamp, end_timestamp, extraction_rate_dict):
"""
Extract the desired channel messages if channel_list is specified.
Otherwise extract all sensor calibration messages according to
extraction ratio, 10% by default.
extraction rate, 10% by default.
"""
# Validate extration_ratio, and set it as an integer.
if extraction_ratio < 1.0:
raise ValueError("Extraction rate must be a number greater than 1.")
extraction_ratio = np.floor(extraction_ratio)

sensor_channels = get_sensor_channel_list(record_file)
if len(channel_list) > 0 and validate_channel_list(channel_list,
@@ -263,13 +332,17 @@ def extract_data(record_file, output_path, channel_list,
channel_success_dict = {}
channel_occur_time = {}
channel_output_path = {}
channel_msgs_dict = {}
for channel in channel_list:
channel_success_dict[channel] = True
channel_occur_time[channel] = -1
topic_name = channel.replace('/', '_')
channel_output_path[channel] = os.path.join(output_path, topic_name)
process_dir(channel_output_path[channel], operation='create')

if channel in SMALL_TOPICS:
channel_msgs_dict[channel] = list()

record_reader = RecordReader(record_file)
for msg in record_reader.read_messages():
if msg.topic in channel_list:
@@ -279,11 +352,12 @@
continue

channel_occur_time[msg.topic] += 1
# Extract the topic according to extraction_ratio
if channel_occur_time[msg.topic] % extraction_ratio != 0:
# Extract the topic according to extraction_rate
if channel_occur_time[msg.topic] % extraction_rate_dict[msg.topic] != 0:
continue

ret = extract_channel_data(channel_output_path[msg.topic], msg)
ret, channel_msgs_dict = extract_channel_data(channel_output_path[msg.topic], msg,
channel_msgs_dict)
# Calculate parsing statistics
if ret is False:
process_msg_failure_num += 1
@@ -293,6 +367,11 @@
process_channel_success_num -= 1
print('Failed to extract data from channel: %s' % msg.topic)

    # Traverse the dict; if a channel's messages were stored as a list,
    # save that list to a summary file (usually a binary file)
for channel, msg_list in channel_msgs_dict.items():
save_msg_list_to_file(channel_output_path[channel], channel, msg_list)

    # Log statistics about channel extraction
print('Extracted sensor channel number [%d] in record file: %s' %
(len(channel_list), record_file))
@@ -303,6 +382,16 @@

return True
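
As an aside, the modulo test in extract_data keeps one message out of every N occurrences of a channel, where N is that channel's entry in extraction_rate_dict; a rate of 10 keeps roughly 10% of the messages. A tiny illustration of the sampling rule (not part of the commit):

# With rate 10, occurrence indices 0, 10, 20, ... pass the modulo filter
rate = 10
kept = [i for i in range(35) if i % rate == 0]
print(kept)  # [0, 10, 20, 30] -- one message in every ten is extracted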

def save_msg_list_to_file(out_path, channel, msg_list):
if 'odometry' in channel:
# generate file objects for small topics I/O
file_path = os.path.join(out_path, 'messages.bin')
odometry_file_obj = OdometryFileObject(file_path)
        print('Saving %d odometry messages to %s' % (len(msg_list), file_path))
odometry_file_obj.save_to_file(msg_list)
else:
raise ValueError("saving func for {} not implemented".format(channel))

def generate_compressed_file(extracted_data_path,
compressed_file='sensor_data'):
"""
@@ -313,6 +402,27 @@
with tarfile.open(compressed_file + '.tar.gz', 'w:gz') as g_out:
g_out.add(extracted_data_path, arcname=compressed_file)

def generate_extraction_rate_dict(channel_list, large_topic_extraction_rate,
small_topic_extraction_rate=1):
"""
    The default extraction rate for small topics is 1, which means no downsampling
"""

    # Validate extraction rates, and set them as integers.
if large_topic_extraction_rate < 1.0 or small_topic_extraction_rate < 1.0:
raise ValueError("Extraction rate must be a number no less than 1.")

large_topic_extraction_rate = np.floor(large_topic_extraction_rate)
small_topic_extraction_rate = np.floor(small_topic_extraction_rate)

rate_dict = {}
for channel in channel_list:
if channel in SMALL_TOPICS:
rate_dict[channel] = small_topic_extraction_rate
else:
rate_dict[channel] = large_topic_extraction_rate

return rate_dict
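
For example, with one high-bandwidth camera channel and one small topic, the helper yields a per-channel rate mapping like the sketch below (the camera channel name is illustrative):

channels = ['/apollo/sensor/camera/front_6mm/image',  # illustrative large topic
            '/apollo/sensor/gnss/odometry']           # listed in SMALL_TOPICS
rates = generate_extraction_rate_dict(channels, large_topic_extraction_rate=10)
# rates -> {'/apollo/sensor/camera/front_6mm/image': 10.0,
#           '/apollo/sensor/gnss/odometry': 1.0}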

def main():
"""
@@ -343,8 +453,8 @@ def main():
parser.add_argument("-e", "--end_timestamp", action="store", type=float,
default=np.finfo(np.float32).max,
help="Specify the ending timestamp to extract data information.")
parser.add_argument("-r", "--extraction_ratio", action="store", type=int,
default=10, help="The output compressed file.")
parser.add_argument("-r", "--extraction_rate", action="store", type=int,
default=10, help="extraction rate for channel with large storage cost.")

args = parser.parse_args()

@@ -364,8 +474,12 @@ def main():
sys.exit(1)

channel_list = set(channel_name for channel_name in args.channel_list)
extraction_rate_dict = generate_extraction_rate_dict(channel_list,
large_topic_extraction_rate=args.extraction_rate,
small_topic_extraction_rate=1)

ret = extract_data(args.record_path, output_path, channel_list,
args.start_timestamp, args.end_timestamp, args.extraction_ratio)
args.start_timestamp, args.end_timestamp, extraction_rate_dict)
if ret is False:
print('Failed to extract data!')
