Skip to content

Commit

Permalink
planning: add learning_based_task for online data training
Browse files Browse the repository at this point in the history
  • Loading branch information
jmtao authored and yifeijiang committed May 9, 2020
1 parent b3e876a commit 9698f57
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 94 deletions.
20 changes: 13 additions & 7 deletions modules/planning/conf/scenario/learning_model_sample_config.pb.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
scenario_type: LEARNING_MODEL_SAMPLE
learning_model_sample_config: {
model_file: "/apollo/modules/planning/tools/planning_demo_model.pt"
input_feature_num: 301084
}

stage_type: LEARNING_MODEL_RUN
stage_type: LEARNING_MODEL_RUN

stage_config: {
stage_type: LEARNING_MODEL_RUN
enabled: true
}
stage_config: {
stage_type: LEARNING_MODEL_RUN
enabled: true
task_type: LEARNING_BASED_TASK
task_config: {
task_type: LEARNING_BASED_TASK
learning_based_task_config {
model_file: "/apollo/modules/planning/tools/planning_demo_model.pt"
input_feature_num: 301084
}
}
}
9 changes: 7 additions & 2 deletions modules/planning/proto/planning_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ message TaskConfig {
PATH_REUSE_DECIDER = 30;
ST_BOUNDS_DECIDER = 31;
PIECEWISE_JERK_NONLINEAR_SPEED_OPTIMIZER = 32;
LEARNING_BASED_TASK = 33;
};
optional TaskType task_type = 1;
oneof task_config {
Expand All @@ -81,9 +82,15 @@ message TaskConfig {
STBoundsDeciderConfig st_bounds_decider_config = 28;
PiecewiseJerkNonlinearSpeedConfig piecewise_jerk_nonlinear_speed_config =
29;
LearningBasedTaskConfig learning_based_task_config = 30;
}
}

message LearningBasedTaskConfig {
optional string model_file = 1;
optional int32 input_feature_num = 2;
}

message ScenarioLaneFollowConfig {
}

Expand Down Expand Up @@ -194,8 +201,6 @@ message ScenarioYieldSignConfig {
}

message ScenarioLearningModelSampleConfig {
optional string model_file = 1;
optional int32 input_feature_num = 2;
}

// scenario configs
Expand Down
1 change: 0 additions & 1 deletion modules/planning/scenarios/learning_model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ cc_library(
"//modules/common/util:factory",
"//modules/planning/reference_line",
"//modules/planning/scenarios:scenario",
"//third_party:libtorch",
"@eigen",
],
)
Expand Down
74 changes: 4 additions & 70 deletions modules/planning/scenarios/learning_model/stage_run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,81 +36,15 @@ Stage::StageStatus LearningModelSampleStageRun::Process(

scenario_config_.CopyFrom(GetContext()->scenario_config);

const auto model_file = scenario_config_.model_file();
if (apollo::cyber::common::PathExists(model_file)) {
try {
model_ = torch::jit::load(model_file, device_);
}
catch (const c10::Error& e) {
AERROR << "error loading the model:" << model_file;
return StageStatus::ERROR;
}
bool plan_ok = ExecuteTaskOnReferenceLine(planning_init_point, frame);
if (!plan_ok) {
AERROR << "LearningModelSampleStageRun planning error";
return Stage::RUNNING;
}
input_feature_num_ = scenario_config_.input_feature_num();

std::vector<torch::jit::IValue> input_features;
ExtractFeatures(frame, &input_features);
InferenceModel(input_features, frame);

return FinishStage();
}

bool LearningModelSampleStageRun::ExtractFeatures(Frame* frame,
std::vector<torch::jit::IValue> *input_features) {
// TODO(all): generate learning features.
// TODO(all): adapt to new input feature shapes
std::vector<torch::jit::IValue> tuple;
tuple.push_back(torch::zeros({2, 3, 224, 224}));
tuple.push_back(torch::zeros({2, 14}));
// assumption: future learning model use one dimension input features.
input_features->push_back(torch::ivalue::Tuple::create(tuple));

return true;
}

bool LearningModelSampleStageRun::InferenceModel(
const std::vector<torch::jit::IValue> &input_features,
Frame* frame) {
auto reference_line_infos = frame->mutable_reference_line_infos();
if (reference_line_infos->empty()) {
AERROR << "no reference is found.";
return false;
}
// FIXME(all): current only pick up the first reference line to use
// learning model trajectory.
for (auto& reference_line_info : *reference_line_infos) {
reference_line_info.SetDrivable(false);
}
auto& picked_reference_line_info = reference_line_infos->front();
picked_reference_line_info.SetDrivable(true);
picked_reference_line_info.SetCost(0);

auto torch_output = model_.forward(input_features);
ADEBUG << torch_output;
auto torch_output_tensor = torch_output.toTensor();
auto output_shapes = torch_output_tensor.sizes();
if (output_shapes.empty() || output_shapes.size() < 3) {
AWARN << "invalid response from learning model.";
return false;
}

// TODO(all): only populate path data from learning model
// for initial version
std::vector<TrajectoryPoint> trajectory_points;
for (int i = 0; i < output_shapes[1]; ++i) {
TrajectoryPoint p;
p.mutable_path_point()->set_x(
torch_output_tensor.accessor<float, 3>()[0][i][0]);
p.mutable_path_point()->set_y(
torch_output_tensor.accessor<float, 3>()[0][i][1]);
trajectory_points.push_back(p);
}
picked_reference_line_info.SetTrajectory(
DiscretizedTrajectory(trajectory_points));

return true;
}

Stage::StageStatus LearningModelSampleStageRun::FinishStage() {
return FinishScenario();
}
Expand Down
16 changes: 2 additions & 14 deletions modules/planning/scenarios/learning_model/stage_run.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@

#pragma once

#include <vector>

#include "torch/script.h"
#include "torch/torch.h"

#include "modules/planning/scenarios/learning_model/learning_model_sample_scenario.h"
#include "modules/planning/scenarios/stage.h"

Expand All @@ -37,26 +32,19 @@ struct LearningModelSampleContext;
class LearningModelSampleStageRun : public Stage {
public:
explicit LearningModelSampleStageRun(
const ScenarioConfig::StageConfig& config)
: Stage(config), device_(torch::kCPU) {}
const ScenarioConfig::StageConfig& config) : Stage(config) {}

private:
Stage::StageStatus Process(const common::TrajectoryPoint& planning_init_point,
Frame* frame) override;
LearningModelSampleContext* GetContext() {
return GetContextAs<LearningModelSampleContext>();
}
bool ExtractFeatures(Frame* frame,
std::vector<torch::jit::IValue> *input_features);
bool InferenceModel(const std::vector<torch::jit::IValue> &input_features,
Frame* frame);

Stage::StageStatus FinishStage();

private:
ScenarioLearningModelSampleConfig scenario_config_;
torch::Device device_;
torch::jit::script::Module model_;
int input_feature_num_ = 0;
};

} // namespace scenario
Expand Down
23 changes: 23 additions & 0 deletions modules/planning/tasks/learning_model/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
load("//tools:cpplint.bzl", "cpplint")

package(default_visibility = ["//visibility:public"])

cc_library(
name = "learning_based_task",
srcs = ["learning_based_task.cc"],
hdrs = ["learning_based_task.h"],
copts = ["-DMODULE_NAME=\\\"planning\\\""],
deps = [
"//modules/common/configs:vehicle_config_helper",
"//modules/common/status",
"//modules/planning/common:frame",
"//modules/planning/common:planning_gflags",
"//modules/planning/common:reference_line_info",
"//modules/planning/proto:planning_config_proto",
"//modules/planning/proto:planning_proto",
"//modules/planning/tasks:task",
"//third_party:libtorch",
],
)

cpplint()
117 changes: 117 additions & 0 deletions modules/planning/tasks/learning_model/learning_based_task.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/******************************************************************************
* Copyright 2020 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/

/**
* @file
**/

#include "modules/planning/tasks/learning_model/learning_based_task.h"

#include "modules/planning/common/planning_gflags.h"

namespace apollo {
namespace planning {

using apollo::common::ErrorCode;
using apollo::common::Status;
using apollo::common::TrajectoryPoint;

Status LearningBasedTask::Execute(Frame *frame,
ReferenceLineInfo *reference_line_info) {
Task::Execute(frame, reference_line_info);
return Process(frame);
}

Status LearningBasedTask::Process(Frame *frame) {
const auto model_file = config_.model_file();
if (apollo::cyber::common::PathExists(model_file)) {
try {
model_ = torch::jit::load(model_file, device_);
}
catch (const c10::Error& e) {
AERROR << "error loading the model:" << model_file;
return Status(ErrorCode::PLANNING_ERROR,
"learning based task model file not exist");
}
}
input_feature_num_ = config_.input_feature_num();

std::vector<torch::jit::IValue> input_features;
ExtractFeatures(frame, &input_features);
InferenceModel(input_features, frame);

return Status::OK();
}

bool LearningBasedTask::ExtractFeatures(
Frame* frame,
std::vector<torch::jit::IValue> *input_features) {
// TODO(all): generate learning features.
// TODO(all): adapt to new input feature shapes
std::vector<torch::jit::IValue> tuple;
tuple.push_back(torch::zeros({2, 3, 224, 224}));
tuple.push_back(torch::zeros({2, 14}));
// assumption: future learning model use one dimension input features.
input_features->push_back(torch::ivalue::Tuple::create(tuple));

return true;
}

bool LearningBasedTask::InferenceModel(
const std::vector<torch::jit::IValue> &input_features,
Frame* frame) {
auto reference_line_infos = frame->mutable_reference_line_infos();
if (reference_line_infos->empty()) {
AERROR << "no reference is found.";
return false;
}
// FIXME(all): current only pick up the first reference line to use
// learning model trajectory.
for (auto& reference_line_info : *reference_line_infos) {
reference_line_info.SetDrivable(false);
}
auto& picked_reference_line_info = reference_line_infos->front();
picked_reference_line_info.SetDrivable(true);
picked_reference_line_info.SetCost(0);

auto torch_output = model_.forward(input_features);
ADEBUG << torch_output;
auto torch_output_tensor = torch_output.toTensor();
auto output_shapes = torch_output_tensor.sizes();
if (output_shapes.empty() || output_shapes.size() < 3) {
AWARN << "invalid response from learning model.";
return false;
}

// TODO(all): only populate path data from learning model
// for initial version
std::vector<TrajectoryPoint> trajectory_points;
for (int i = 0; i < output_shapes[1]; ++i) {
TrajectoryPoint p;
p.mutable_path_point()->set_x(
torch_output_tensor.accessor<float, 3>()[0][i][0]);
p.mutable_path_point()->set_y(
torch_output_tensor.accessor<float, 3>()[0][i][1]);
trajectory_points.push_back(p);
}
picked_reference_line_info.SetTrajectory(
DiscretizedTrajectory(trajectory_points));

return true;
}

} // namespace planning
} // namespace apollo
57 changes: 57 additions & 0 deletions modules/planning/tasks/learning_model/learning_based_task.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/******************************************************************************
* Copyright 2020 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/

/**
* @file
**/

#pragma once

#include <vector>

#include "torch/script.h"
#include "torch/torch.h"

#include "modules/planning/proto/planning_config.pb.h"
#include "modules/planning/tasks/task.h"

namespace apollo {
namespace planning {

class LearningBasedTask : public Task {
public:
explicit LearningBasedTask(const TaskConfig &config)
: Task(config), device_(torch::kCPU) {}

apollo::common::Status Execute(
Frame *frame, ReferenceLineInfo *reference_line_info) override;

private:
apollo::common::Status Process(Frame *frame);

bool ExtractFeatures(Frame* frame,
std::vector<torch::jit::IValue> *input_features);
bool InferenceModel(const std::vector<torch::jit::IValue> &input_features,
Frame* frame);
private:
LearningBasedTaskConfig config_;
torch::Device device_;
torch::jit::script::Module model_;
int input_feature_num_ = 0;
};

} // namespace planning
} // namespace apollo

0 comments on commit 9698f57

Please sign in to comment.