Skip to content

Commit

Permalink
SignalLogger: Deprecate/replace threading hazard
Browse files Browse the repository at this point in the history
Relevant to: RobotLocomotion#10228

Add storage modes to SignalLogger; allow the internal SignalLog member to
become a cache entry. Limit the data access methods on SignalLogger to the
legacy unsafe mode, and alternatively provide ergonomic lookup methods to get
the SignalLog from a context in the per-context mode.

Make SignalLog copyable and movable for compatibility with the cache.

Provide complete python bindings for SignalLogger, SignalLog, and the
(temporary) LogStorageMode enum.

Update numerous tests and examples.
  • Loading branch information
rpoyner-tri committed Jul 21, 2021
1 parent cff0f37 commit 70b5120
Show file tree
Hide file tree
Showing 21 changed files with 446 additions and 148 deletions.
109 changes: 94 additions & 15 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ PYBIND11_MODULE(primitives, m) {
m.doc() = "Bindings for the primitives portion of the Systems framework.";
constexpr auto& doc = pydrake_doc.drake.systems;

py::enum_<LogStorageMode>(m, "LogStorageMode", doc.LogStorageMode.doc)
.value("kDeprecatedLogPerSystem", kDeprecatedLogPerSystem,
doc.LogStorageMode.kDeprecatedLogPerSystem.doc)
.value("kLogPerContext", kLogPerContext,
doc.LogStorageMode.kLogPerContext.doc)
.export_values();

py::module::import("pydrake.systems.framework");
// N.B. Capturing `&doc` should not be required; workaround per #9600.
auto bind_common_scalar_types = [m, &doc](auto dummy) {
Expand Down Expand Up @@ -229,38 +236,110 @@ PYBIND11_MODULE(primitives, m) {
py::arg("min_value"), py::arg("max_value"),
doc.Saturation.ctor.doc_2args);

DefineTemplateClassWithDefault<SignalLogger<T>, LeafSystem<T>>(
m, "SignalLogger", GetPyParam<T>(), doc.SignalLogger.doc)
DefineTemplateClassWithDefault<SignalLog<T>>(
m, "SignalLog", GetPyParam<T>(), doc.SignalLog.doc)
.def(py::init<int, int>(), py::arg("input_size"),
py::arg("batch_allocation_size") = 1000, doc.SignalLogger.ctor.doc)
.def("set_publish_period", &SignalLogger<T>::set_publish_period,
py::arg("period"), doc.SignalLogger.set_publish_period.doc)
.def("set_forced_publish_only",
&SignalLogger<T>::set_forced_publish_only,
doc.SignalLogger.set_forced_publish_only.doc)
py::arg("batch_allocation_size") = 1000, doc.SignalLog.ctor.doc)
.def("num_samples", &SignalLog<T>::num_samples,
doc.SignalLog.num_samples.doc)
.def(
"sample_times",
[](const SignalLog<T>* self) {
// Reference
return CopyIfNotPodType(self->sample_times());
},
return_value_policy_for_scalar_type<T>(),
doc.SignalLog.sample_times.doc)
.def(
"data",
[](const SignalLog<T>* self) {
// Reference.
return CopyIfNotPodType(self->data());
},
return_value_policy_for_scalar_type<T>(), doc.SignalLog.data.doc)
.def("reset", &SignalLog<T>::reset, doc.SignalLog.reset.doc)
.def("AddData", &SignalLog<T>::AddData, py::arg("time"),
py::arg("sample"), doc.SignalLog.AddData.doc)
.def("get_input_size", &SignalLog<T>::get_input_size,
doc.SignalLog.get_input_size.doc);

auto cls =
DefineTemplateClassWithDefault<SignalLogger<T>, LeafSystem<T>>(
m, "SignalLogger", GetPyParam<T>(), doc.SignalLogger.doc)
.def(py::init<int, int, LogStorageMode>(), py::arg("input_size"),
py::arg("batch_allocation_size") = 1000,
py::arg("storage_mode") = kDeprecatedLogPerSystem,
doc.SignalLogger.ctor.doc_3args)
.def(py::init<int, LogStorageMode>(), py::arg("input_size"),
py::arg("storage_mode"), doc.SignalLogger.ctor.doc_2args)
.def("set_publish_period", &SignalLogger<T>::set_publish_period,
py::arg("period"), doc.SignalLogger.set_publish_period.doc)
.def("set_forced_publish_only",
&SignalLogger<T>::set_forced_publish_only,
doc.SignalLogger.set_forced_publish_only.doc)
.def(
"GetLog",
[](const SignalLogger<T>* self, const Context<T>& context)
-> const SignalLog<T>& { return self->GetLog(context); },
py::arg("context"), py_rvp::reference,
doc.SignalLogger.GetLog.doc_1args)
.def(
"GetLog",
[](const SignalLogger<T>* self, const System<T>& outer_system,
const Context<T>& outer_context) -> const SignalLog<T>& {
return self->GetLog(outer_system, outer_context);
},
py::arg("outer_system"), py::arg("outer_context"),
py_rvp::reference, doc.SignalLogger.GetLog.doc_2args)
.def(
"GetMutableLog",
[](const SignalLogger<T>* self, const Context<T>& context)
-> SignalLog<T>& { return self->GetMutableLog(context); },
py::arg("context"), py_rvp::reference,
doc.SignalLogger.GetMutableLog.doc_1args)
.def(
"GetMutableLog",
[](const SignalLogger<T>* self, const System<T>& outer_system,
const Context<T>& outer_context) -> SignalLog<T>& {
return self->GetMutableLog(outer_system, outer_context);
},
py::arg("outer_system"), py::arg("outer_context"),
py_rvp::reference, doc.SignalLogger.GetMutableLog.doc_2args);

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
cls.def("num_samples", &SignalLogger<T>::num_samples,
doc.SignalLogger.num_samples.doc_deprecated)
.def(
"sample_times",
[](const SignalLogger<T>* self) {
// Reference
return CopyIfNotPodType(self->sample_times());
},
return_value_policy_for_scalar_type<T>(),
doc.SignalLogger.sample_times.doc)
doc.SignalLogger.sample_times.doc_deprecated)
.def(
"data",
[](const SignalLogger<T>* self) {
// Reference.
return CopyIfNotPodType(self->data());
},
return_value_policy_for_scalar_type<T>(), doc.SignalLogger.data.doc)
.def("reset", &SignalLogger<T>::reset, doc.SignalLogger.reset.doc);
return_value_policy_for_scalar_type<T>(),
doc.SignalLogger.data.doc_deprecated)
.def("reset", &SignalLogger<T>::reset,
doc.SignalLogger.reset.doc_deprecated);
#pragma GCC diagnostic pop

AddTemplateFunction(m, "LogOutput", &LogOutput<T>, GetPyParam<T>(),
py::arg("src"), py::arg("builder"),
AddTemplateFunction(m, "LogOutput",
py::overload_cast<const OutputPort<T>&, DiagramBuilder<T>*, int,
LogStorageMode>(&LogOutput<T>),
GetPyParam<T>(), py::arg("src"), py::arg("builder"),
py::arg("batch_allocation_size") = 1000,
py::arg("storage_mode") = kDeprecatedLogPerSystem,
// Keep alive, ownership: `return` keeps `builder` alive.
py::keep_alive<0, 2>(),
// See #11531 for why `py_rvp::reference` is needed.
py_rvp::reference, doc.LogOutput.doc);
py_rvp::reference, doc.LogOutput.doc_4args);

DefineTemplateClassWithDefault<StateInterpolatorWithDiscreteDerivative<T>,
Diagram<T>>(m, "StateInterpolatorWithDiscreteDerivative",
Expand Down Expand Up @@ -527,7 +606,7 @@ PYBIND11_MODULE(primitives, m) {
py::arg("threshold") = std::nullopt, doc.IsObservable.doc);

// TODO(eric.cousineau): Add more systems as needed.
}
} // NOLINT(readability/fn_size)

} // namespace pydrake
} // namespace drake
6 changes: 3 additions & 3 deletions bindings/pydrake/systems/pyplot_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from warnings import warn

from pydrake.systems.framework import LeafSystem, PublishEvent, TriggerType
from pydrake.systems.primitives import SignalLogger
from pydrake.systems.primitives import SignalLog
from pydrake.trajectories import Trajectory


Expand Down Expand Up @@ -108,7 +108,7 @@ def get_recording_as_animation(self, **kwargs):
def animate(self, log, resample=True, **kwargs):
"""
Args:
log: A reference to a pydrake.systems.primitives.SignalLogger
log: A reference to a pydrake.systems.primitives.SignalLog
or a pydrake.trajectories.Trajectory that contains the plant
state after running a simulation.
resample: Whether we should do a resampling operation to make the
Expand All @@ -117,7 +117,7 @@ def animate(self, log, resample=True, **kwargs):
matches the sample timestep of the log.
Additional kwargs are passed through to FuncAnimation.
"""
if isinstance(log, SignalLogger):
if isinstance(log, SignalLog):
t = log.sample_times()
x = log.data()

Expand Down
29 changes: 19 additions & 10 deletions bindings/pydrake/systems/test/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Integrator, Integrator_,
IsControllable,
IsObservable,
kLogPerContext,
Linearize,
LinearSystem, LinearSystem_,
LinearTransformDensity, LinearTransformDensity_,
Expand Down Expand Up @@ -116,7 +117,8 @@ def test_signal_logger(self, T):
source = builder.AddSystem(ConstantVectorSource_[T]([kValue]))
kSize = 1
integrator = builder.AddSystem(Integrator_[T](kSize))
logger_per_step = builder.AddSystem(SignalLogger_[T](kSize))
logger_per_step = builder.AddSystem(
SignalLogger_[T](kSize, storage_mode=kLogPerContext))
builder.Connect(source.get_output_port(0),
integrator.get_input_port(0))
builder.Connect(integrator.get_output_port(0),
Expand All @@ -125,43 +127,50 @@ def test_signal_logger(self, T):
# Add a redundant logger via the helper method.
if T == float:
logger_per_step_2 = LogOutput(
integrator.get_output_port(0), builder
integrator.get_output_port(0), builder,
storage_mode=kLogPerContext
)
else:
logger_per_step_2 = LogOutput[T](
integrator.get_output_port(0), builder
integrator.get_output_port(0), builder,
storage_mode=kLogPerContext
)

# Add a periodic logger
logger_periodic = builder.AddSystem(SignalLogger_[T](kSize))
logger_periodic = builder.AddSystem(SignalLogger_[T](
kSize, storage_mode=kLogPerContext))
kPeriod = 0.1
logger_periodic.set_publish_period(kPeriod)
builder.Connect(integrator.get_output_port(0),
logger_periodic.get_input_port(0))

diagram = builder.Build()
simulator = Simulator_[T](diagram)
context = simulator.get_context()
kTime = 1.
simulator.AdvanceTo(kTime)

# Verify outputs of the every-step logger
t = logger_per_step.sample_times()
x = logger_per_step.data()
log_per_step = logger_per_step.GetMutableLog(diagram, context)
t = log_per_step.sample_times()
x = log_per_step.data()

self.assertTrue(t.shape[0] > 2)
self.assertTrue(t.shape[0] == x.shape[1])
numpy_compare.assert_allclose(
t[-1]*kValue, x[0, -1], atol=1e-15, rtol=0
)
numpy_compare.assert_equal(x, logger_per_step_2.data())
log_per_step_2 = logger_per_step_2.GetLog(diagram, context)
numpy_compare.assert_equal(x, log_per_step_2.data())

# Verify outputs of the periodic logger
t = logger_periodic.sample_times()
x = logger_periodic.data()
log_periodic = logger_periodic.GetLog(diagram, context)
t = log_periodic.sample_times()
x = log_periodic.data()
# Should log exactly once every kPeriod, up to and including kTime.
self.assertTrue(t.shape[0] == np.floor(kTime / kPeriod) + 1.)

logger_per_step.reset()
log_per_step.reset()

# Verify that t and x retain their values after systems are deleted.
t_copy = t.copy()
Expand Down
8 changes: 5 additions & 3 deletions bindings/pydrake/systems/test/pyplot_visualizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydrake.systems.analysis import Simulator
from pydrake.systems.framework import (
Context, DiagramBuilder, PortDataType, VectorSystem, kUseDefaultName)
from pydrake.systems.primitives import SignalLogger
from pydrake.systems.primitives import SignalLogger, kLogPerContext
from pydrake.systems.pyplot_visualizer import PyPlotVisualizer
from pydrake.trajectories import PiecewisePolynomial

Expand Down Expand Up @@ -73,7 +73,8 @@ class TestPyplotVisualizer(unittest.TestCase):
def test_simple_visualizer(self):
builder = DiagramBuilder()
system = builder.AddSystem(SimpleContinuousTimeSystem())
logger = builder.AddSystem(SignalLogger(1))
logger = builder.AddSystem(SignalLogger(
1, storage_mode=kLogPerContext))
builder.Connect(system.get_output_port(0), logger.get_input_port(0))
visualizer = builder.AddSystem(TestVisualizer(1))
builder.Connect(system.get_output_port(0),
Expand All @@ -84,9 +85,10 @@ def test_simple_visualizer(self):
context.SetContinuousState([0.9])

simulator = Simulator(diagram, context)
log = logger.GetLog(diagram, context)
simulator.AdvanceTo(.1)

ani = visualizer.animate(logger, repeat=True)
ani = visualizer.animate(log, repeat=True)
self.assertIsInstance(ani, animation.FuncAnimation)

def test_trajectory(self):
Expand Down
21 changes: 13 additions & 8 deletions examples/acrobot/run_lqr_w_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,19 @@ int do_main() {
builder.Connect(controller->get_output_port(), observer->get_input_port(1));

// Log the true state and the estimated state.
auto x_logger = LogOutput(acrobot_w_encoder->get_output_port(1), &builder);
auto x_logger = LogOutput(acrobot_w_encoder->get_output_port(1), &builder,
systems::kLogPerContext);
x_logger->set_name("x_logger");
auto xhat_logger = LogOutput(observer->get_output_port(0), &builder);
auto xhat_logger = LogOutput(observer->get_output_port(0), &builder,
systems::kLogPerContext);
xhat_logger->set_name("xhat_logger");

// Build the system/simulator.
auto diagram = builder.Build();
systems::Simulator<double> simulator(*diagram);
const auto& context = simulator.get_context();
const auto& x_log = x_logger->GetLog(*diagram, context);
const auto& xhat_log = xhat_logger->GetLog(*diagram, context);

// Set an initial condition near the upright fixed point.
AcrobotState<double>& x0 = acrobot_w_encoder->get_mutable_acrobot_state(
Expand Down Expand Up @@ -135,18 +140,18 @@ int do_main() {
using common::ToPythonTuple;
CallPython("figure", 1);
CallPython("clf");
CallPython("plot", x_logger->sample_times(),
(x_logger->data().row(0).array() - M_PI)
CallPython("plot", x_log.sample_times(),
(x_log.data().row(0).array() - M_PI)
.matrix().transpose());
CallPython("plot", x_logger->sample_times(),
x_logger->data().row(1).transpose());
CallPython("plot", x_log.sample_times(),
x_log.data().row(1).transpose());
CallPython("legend", ToPythonTuple("theta1 - PI", "theta2"));
CallPython("axis", "tight");

CallPython("figure", 2);
CallPython("clf");
CallPython("plot", x_logger->sample_times(),
(x_logger->data().array() - xhat_logger->data().array())
CallPython("plot", x_log.sample_times(),
(x_log.data().array() - xhat_log.data().array())
.matrix().transpose());
CallPython("ylabel", "error");
CallPython("legend", ToPythonTuple("theta1", "theta2", "theta1dot",
Expand Down
5 changes: 3 additions & 2 deletions examples/acrobot/spong_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ std::string Simulate(const YAML::Node& scenario_node) {

builder.Connect(plant->get_output_port(0), controller->get_input_port(0));
builder.Connect(controller->get_output_port(0), plant->get_input_port(0));
auto state_logger = LogOutput(plant->get_output_port(0), &builder);
auto state_logger = LogOutput(plant->get_output_port(0), &builder,
systems::kLogPerContext);
state_logger->set_publish_period(scenario.tape_period);

auto diagram = builder.Build();
Expand All @@ -120,7 +121,7 @@ std::string Simulate(const YAML::Node& scenario_node) {
simulator.AdvanceTo(scenario.t_final);

Output output;
output.x_tape = state_logger->data();
output.x_tape = state_logger->GetLog(*diagram, context).data();
drake::yaml::YamlWriteArchive writer;
writer.Accept(output);
// The EmitString call below saves a document like so:
Expand Down
8 changes: 5 additions & 3 deletions examples/acrobot/spong_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from pydrake.systems.analysis import Simulator
from pydrake.systems.framework import DiagramBuilder
from pydrake.systems.primitives import LogOutput
from pydrake.systems.primitives import LogOutput, kLogPerContext

from drake.examples.acrobot.acrobot_io import load_scenario, save_output

Expand All @@ -28,7 +28,8 @@ def simulate(*, initial_state, controller_params, t_final, tape_period):

builder.Connect(plant.get_output_port(0), controller.get_input_port(0))
builder.Connect(controller.get_output_port(0), plant.get_input_port(0))
state_logger = LogOutput(plant.get_output_port(0), builder)
state_logger = LogOutput(plant.get_output_port(0), builder,
storage_mode=kLogPerContext)
state_logger.set_publish_period(tape_period)

diagram = builder.Build()
Expand All @@ -37,14 +38,15 @@ def simulate(*, initial_state, controller_params, t_final, tape_period):
plant_context = diagram.GetMutableSubsystemContext(plant, context)
controller_context = diagram.GetMutableSubsystemContext(
controller, context)
log = state_logger.GetLog(diagram, context)

plant_context.SetContinuousState(initial_state)
controller_context.get_mutable_numeric_parameter(0).SetFromVector(
controller_params)

simulator.AdvanceTo(t_final)

x_tape = state_logger.data()
x_tape = log.data()
return x_tape


Expand Down
Loading

0 comments on commit 70b5120

Please sign in to comment.