From 9698f5791f5b1484a1e2a7f16cb4cbbb344b9822 Mon Sep 17 00:00:00 2001 From: jmtao Date: Fri, 8 May 2020 14:39:32 -0700 Subject: [PATCH] planning: add learning_based_task for online data training --- .../learning_model_sample_config.pb.txt | 20 +-- modules/planning/proto/planning_config.proto | 9 +- .../planning/scenarios/learning_model/BUILD | 1 - .../scenarios/learning_model/stage_run.cc | 74 +---------- .../scenarios/learning_model/stage_run.h | 16 +-- modules/planning/tasks/learning_model/BUILD | 23 ++++ .../learning_model/learning_based_task.cc | 117 ++++++++++++++++++ .../learning_model/learning_based_task.h | 57 +++++++++ 8 files changed, 223 insertions(+), 94 deletions(-) create mode 100644 modules/planning/tasks/learning_model/BUILD create mode 100644 modules/planning/tasks/learning_model/learning_based_task.cc create mode 100644 modules/planning/tasks/learning_model/learning_based_task.h diff --git a/modules/planning/conf/scenario/learning_model_sample_config.pb.txt b/modules/planning/conf/scenario/learning_model_sample_config.pb.txt index c1356c97a4a..1ed3ad4a499 100644 --- a/modules/planning/conf/scenario/learning_model_sample_config.pb.txt +++ b/modules/planning/conf/scenario/learning_model_sample_config.pb.txt @@ -1,12 +1,18 @@ scenario_type: LEARNING_MODEL_SAMPLE learning_model_sample_config: { - model_file: "/apollo/modules/planning/tools/planning_demo_model.pt" - input_feature_num: 301084 } - stage_type: LEARNING_MODEL_RUN +stage_type: LEARNING_MODEL_RUN - stage_config: { - stage_type: LEARNING_MODEL_RUN - enabled: true - } +stage_config: { + stage_type: LEARNING_MODEL_RUN + enabled: true + task_type: LEARNING_BASED_TASK + task_config: { + task_type: LEARNING_BASED_TASK + learning_based_task_config { + model_file: "/apollo/modules/planning/tools/planning_demo_model.pt" + input_feature_num: 301084 + } + } +} diff --git a/modules/planning/proto/planning_config.proto b/modules/planning/proto/planning_config.proto index d20dfac6fd6..354d797e08d 100644 --- a/modules/planning/proto/planning_config.proto +++ b/modules/planning/proto/planning_config.proto @@ -56,6 +56,7 @@ message TaskConfig { PATH_REUSE_DECIDER = 30; ST_BOUNDS_DECIDER = 31; PIECEWISE_JERK_NONLINEAR_SPEED_OPTIMIZER = 32; + LEARNING_BASED_TASK = 33; }; optional TaskType task_type = 1; oneof task_config { @@ -81,9 +82,15 @@ message TaskConfig { STBoundsDeciderConfig st_bounds_decider_config = 28; PiecewiseJerkNonlinearSpeedConfig piecewise_jerk_nonlinear_speed_config = 29; + LearningBasedTaskConfig learning_based_task_config = 30; } } +message LearningBasedTaskConfig { + optional string model_file = 1; + optional int32 input_feature_num = 2; +} + message ScenarioLaneFollowConfig { } @@ -194,8 +201,6 @@ message ScenarioYieldSignConfig { } message ScenarioLearningModelSampleConfig { - optional string model_file = 1; - optional int32 input_feature_num = 2; } // scenario configs diff --git a/modules/planning/scenarios/learning_model/BUILD b/modules/planning/scenarios/learning_model/BUILD index 8d787178180..f0dcb391639 100644 --- a/modules/planning/scenarios/learning_model/BUILD +++ b/modules/planning/scenarios/learning_model/BUILD @@ -28,7 +28,6 @@ cc_library( "//modules/common/util:factory", "//modules/planning/reference_line", "//modules/planning/scenarios:scenario", - "//third_party:libtorch", "@eigen", ], ) diff --git a/modules/planning/scenarios/learning_model/stage_run.cc b/modules/planning/scenarios/learning_model/stage_run.cc index 89b3877feff..91eb92f1e44 100644 --- a/modules/planning/scenarios/learning_model/stage_run.cc +++ b/modules/planning/scenarios/learning_model/stage_run.cc @@ -36,81 +36,15 @@ Stage::StageStatus LearningModelSampleStageRun::Process( scenario_config_.CopyFrom(GetContext()->scenario_config); - const auto model_file = scenario_config_.model_file(); - if (apollo::cyber::common::PathExists(model_file)) { - try { - model_ = torch::jit::load(model_file, device_); - } - catch (const c10::Error& e) { - AERROR << "error loading the model:" << model_file; - return StageStatus::ERROR; - } + bool plan_ok = ExecuteTaskOnReferenceLine(planning_init_point, frame); + if (!plan_ok) { + AERROR << "LearningModelSampleStageRun planning error"; + return Stage::RUNNING; } - input_feature_num_ = scenario_config_.input_feature_num(); - - std::vector input_features; - ExtractFeatures(frame, &input_features); - InferenceModel(input_features, frame); return FinishStage(); } -bool LearningModelSampleStageRun::ExtractFeatures(Frame* frame, - std::vector *input_features) { - // TODO(all): generate learning features. - // TODO(all): adapt to new input feature shapes - std::vector tuple; - tuple.push_back(torch::zeros({2, 3, 224, 224})); - tuple.push_back(torch::zeros({2, 14})); - // assumption: future learning model use one dimension input features. - input_features->push_back(torch::ivalue::Tuple::create(tuple)); - - return true; -} - -bool LearningModelSampleStageRun::InferenceModel( - const std::vector &input_features, - Frame* frame) { - auto reference_line_infos = frame->mutable_reference_line_infos(); - if (reference_line_infos->empty()) { - AERROR << "no reference is found."; - return false; - } - // FIXME(all): current only pick up the first reference line to use - // learning model trajectory. - for (auto& reference_line_info : *reference_line_infos) { - reference_line_info.SetDrivable(false); - } - auto& picked_reference_line_info = reference_line_infos->front(); - picked_reference_line_info.SetDrivable(true); - picked_reference_line_info.SetCost(0); - - auto torch_output = model_.forward(input_features); - ADEBUG << torch_output; - auto torch_output_tensor = torch_output.toTensor(); - auto output_shapes = torch_output_tensor.sizes(); - if (output_shapes.empty() || output_shapes.size() < 3) { - AWARN << "invalid response from learning model."; - return false; - } - - // TODO(all): only populate path data from learning model - // for initial version - std::vector trajectory_points; - for (int i = 0; i < output_shapes[1]; ++i) { - TrajectoryPoint p; - p.mutable_path_point()->set_x( - torch_output_tensor.accessor()[0][i][0]); - p.mutable_path_point()->set_y( - torch_output_tensor.accessor()[0][i][1]); - trajectory_points.push_back(p); - } - picked_reference_line_info.SetTrajectory( - DiscretizedTrajectory(trajectory_points)); - - return true; -} - Stage::StageStatus LearningModelSampleStageRun::FinishStage() { return FinishScenario(); } diff --git a/modules/planning/scenarios/learning_model/stage_run.h b/modules/planning/scenarios/learning_model/stage_run.h index 5c187bffa96..2647e210ff7 100644 --- a/modules/planning/scenarios/learning_model/stage_run.h +++ b/modules/planning/scenarios/learning_model/stage_run.h @@ -20,11 +20,6 @@ #pragma once -#include - -#include "torch/script.h" -#include "torch/torch.h" - #include "modules/planning/scenarios/learning_model/learning_model_sample_scenario.h" #include "modules/planning/scenarios/stage.h" @@ -37,8 +32,7 @@ struct LearningModelSampleContext; class LearningModelSampleStageRun : public Stage { public: explicit LearningModelSampleStageRun( - const ScenarioConfig::StageConfig& config) - : Stage(config), device_(torch::kCPU) {} + const ScenarioConfig::StageConfig& config) : Stage(config) {} private: Stage::StageStatus Process(const common::TrajectoryPoint& planning_init_point, @@ -46,17 +40,11 @@ class LearningModelSampleStageRun : public Stage { LearningModelSampleContext* GetContext() { return GetContextAs(); } - bool ExtractFeatures(Frame* frame, - std::vector *input_features); - bool InferenceModel(const std::vector &input_features, - Frame* frame); + Stage::StageStatus FinishStage(); private: ScenarioLearningModelSampleConfig scenario_config_; - torch::Device device_; - torch::jit::script::Module model_; - int input_feature_num_ = 0; }; } // namespace scenario diff --git a/modules/planning/tasks/learning_model/BUILD b/modules/planning/tasks/learning_model/BUILD new file mode 100644 index 00000000000..682dbf6aa3d --- /dev/null +++ b/modules/planning/tasks/learning_model/BUILD @@ -0,0 +1,23 @@ +load("//tools:cpplint.bzl", "cpplint") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "learning_based_task", + srcs = ["learning_based_task.cc"], + hdrs = ["learning_based_task.h"], + copts = ["-DMODULE_NAME=\\\"planning\\\""], + deps = [ + "//modules/common/configs:vehicle_config_helper", + "//modules/common/status", + "//modules/planning/common:frame", + "//modules/planning/common:planning_gflags", + "//modules/planning/common:reference_line_info", + "//modules/planning/proto:planning_config_proto", + "//modules/planning/proto:planning_proto", + "//modules/planning/tasks:task", + "//third_party:libtorch", + ], +) + +cpplint() diff --git a/modules/planning/tasks/learning_model/learning_based_task.cc b/modules/planning/tasks/learning_model/learning_based_task.cc new file mode 100644 index 00000000000..b8e17d062d3 --- /dev/null +++ b/modules/planning/tasks/learning_model/learning_based_task.cc @@ -0,0 +1,117 @@ +/****************************************************************************** + * Copyright 2020 The Apollo Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *****************************************************************************/ + +/** + * @file + **/ + +#include "modules/planning/tasks/learning_model/learning_based_task.h" + +#include "modules/planning/common/planning_gflags.h" + +namespace apollo { +namespace planning { + +using apollo::common::ErrorCode; +using apollo::common::Status; +using apollo::common::TrajectoryPoint; + +Status LearningBasedTask::Execute(Frame *frame, + ReferenceLineInfo *reference_line_info) { + Task::Execute(frame, reference_line_info); + return Process(frame); +} + +Status LearningBasedTask::Process(Frame *frame) { + const auto model_file = config_.model_file(); + if (apollo::cyber::common::PathExists(model_file)) { + try { + model_ = torch::jit::load(model_file, device_); + } + catch (const c10::Error& e) { + AERROR << "error loading the model:" << model_file; + return Status(ErrorCode::PLANNING_ERROR, + "learning based task model file not exist"); + } + } + input_feature_num_ = config_.input_feature_num(); + + std::vector input_features; + ExtractFeatures(frame, &input_features); + InferenceModel(input_features, frame); + + return Status::OK(); +} + +bool LearningBasedTask::ExtractFeatures( + Frame* frame, + std::vector *input_features) { + // TODO(all): generate learning features. + // TODO(all): adapt to new input feature shapes + std::vector tuple; + tuple.push_back(torch::zeros({2, 3, 224, 224})); + tuple.push_back(torch::zeros({2, 14})); + // assumption: future learning model use one dimension input features. + input_features->push_back(torch::ivalue::Tuple::create(tuple)); + + return true; +} + +bool LearningBasedTask::InferenceModel( + const std::vector &input_features, + Frame* frame) { + auto reference_line_infos = frame->mutable_reference_line_infos(); + if (reference_line_infos->empty()) { + AERROR << "no reference is found."; + return false; + } + // FIXME(all): current only pick up the first reference line to use + // learning model trajectory. + for (auto& reference_line_info : *reference_line_infos) { + reference_line_info.SetDrivable(false); + } + auto& picked_reference_line_info = reference_line_infos->front(); + picked_reference_line_info.SetDrivable(true); + picked_reference_line_info.SetCost(0); + + auto torch_output = model_.forward(input_features); + ADEBUG << torch_output; + auto torch_output_tensor = torch_output.toTensor(); + auto output_shapes = torch_output_tensor.sizes(); + if (output_shapes.empty() || output_shapes.size() < 3) { + AWARN << "invalid response from learning model."; + return false; + } + + // TODO(all): only populate path data from learning model + // for initial version + std::vector trajectory_points; + for (int i = 0; i < output_shapes[1]; ++i) { + TrajectoryPoint p; + p.mutable_path_point()->set_x( + torch_output_tensor.accessor()[0][i][0]); + p.mutable_path_point()->set_y( + torch_output_tensor.accessor()[0][i][1]); + trajectory_points.push_back(p); + } + picked_reference_line_info.SetTrajectory( + DiscretizedTrajectory(trajectory_points)); + + return true; +} + +} // namespace planning +} // namespace apollo diff --git a/modules/planning/tasks/learning_model/learning_based_task.h b/modules/planning/tasks/learning_model/learning_based_task.h new file mode 100644 index 00000000000..07ee40bea59 --- /dev/null +++ b/modules/planning/tasks/learning_model/learning_based_task.h @@ -0,0 +1,57 @@ +/****************************************************************************** + * Copyright 2020 The Apollo Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *****************************************************************************/ + +/** + * @file + **/ + +#pragma once + +#include + +#include "torch/script.h" +#include "torch/torch.h" + +#include "modules/planning/proto/planning_config.pb.h" +#include "modules/planning/tasks/task.h" + +namespace apollo { +namespace planning { + +class LearningBasedTask : public Task { + public: + explicit LearningBasedTask(const TaskConfig &config) + : Task(config), device_(torch::kCPU) {} + + apollo::common::Status Execute( + Frame *frame, ReferenceLineInfo *reference_line_info) override; + + private: + apollo::common::Status Process(Frame *frame); + + bool ExtractFeatures(Frame* frame, + std::vector *input_features); + bool InferenceModel(const std::vector &input_features, + Frame* frame); + private: + LearningBasedTaskConfig config_; + torch::Device device_; + torch::jit::script::Module model_; + int input_feature_num_ = 0; +}; + +} // namespace planning +} // namespace apollo