Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using local_nested_autodiff for all instances of nested autodiff #1706

Merged
Prev Previous commit
Next Next commit
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
…gs/RELEASE_500/final)
  • Loading branch information
stan-buildbot committed Feb 12, 2020
commit 9af29af2d3eafe2461c72c7e94e4527b2909eb14
24 changes: 8 additions & 16 deletions stan/math/rev/core/local_nested_autodiff.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,24 @@ namespace math {
/**
* A class following the RAII idiom to start and recover nested autodiff scopes.
* This is the preferred way to use nested autodiff. Example:
*
*
* var a; // allocated normally
* {
* local_nested_autodiff nested; // Starts nested autodiff
*
*
* var nested_var; //allocated on the nested stack
* // Do stuff on the nested stack
*
*
* // Nested stack is automatically recovered at the end of scope where
* // nested was declared, including exceptions, returns, etc.
* // nested was declared, including exceptions, returns, etc.
* }
* var b;
*/
class local_nested_autodiff {
public:
local_nested_autodiff()
{
start_nested();
}
public:
local_nested_autodiff() { start_nested(); }

~local_nested_autodiff()
{
recover_memory_nested();
}
~local_nested_autodiff() { recover_memory_nested(); }

// Prevent undesirable operations
local_nested_autodiff(const local_nested_autodiff&) = delete;
Expand All @@ -45,9 +39,7 @@ class local_nested_autodiff {
* Reset all adjoint values in this nested stack
* to zero.
**/
void set_zero_all_adjoints() {
set_zero_all_adjoints_nested();
}
void set_zero_all_adjoints() { set_zero_all_adjoints_nested(); }
};

} // namespace math
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/recover_memory_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ namespace math {
* is nothing on the nested stack, then a
* <code>std::logic_error</code> exception is thrown.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff as it handles recovery of memory automatically.
*
*
* @throw std::logic_error if <code>empty_nested()</code> returns
* <code>true</code>
*/
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/set_zero_all_adjoints_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace math {
/**
* Reset all adjoint values in the top nested portion of the stack
* to zero.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff class as it handles recovery of memory automatically.
*/
static void set_zero_all_adjoints_nested() {
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/start_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace math {
/**
* Record the current position so that <code>recover_memory_nested()</code>
* can find it.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff as it handles recovery of memory automatically.
*/
static inline void start_nested() {
Expand Down
12 changes: 6 additions & 6 deletions stan/math/rev/functor/coupled_ode_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ struct coupled_ode_system<F, double, var> {

vector<var> dy_dt_vars = f_(t, y_vars, theta_nochain_, x_, x_int_, msgs_);

check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(), "states",
N_);

for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
Expand Down Expand Up @@ -281,8 +281,8 @@ struct coupled_ode_system<F, var, double> {

vector<var> dy_dt_vars = f_(t, y_vars, theta_dbl_, x_, x_int_, msgs_);

check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(), "states",
N_);

for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
Expand Down Expand Up @@ -457,8 +457,8 @@ struct coupled_ode_system<F, var, var> {

vector<var> dy_dt_vars = f_(t, y_vars, theta_nochain_, x_, x_int_, msgs_);

check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
"states", N_);
check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(), "states",
N_);

for (size_t i = 0; i < N_; i++) {
dz_dt[i] = dy_dt_vars[i].val();
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/functor/idas_forward_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ class idas_forward_system : public idas_system<F, Tyy, Typ, Tpar> {
MatrixXd J, r;
VectorXd f_val;

auto fyy
= [&t, &vyp, &vtheta, &N, &dae](const matrix_v& x) -> vector_v {
auto fyy = [&t, &vyp, &vtheta, &N, &dae](const matrix_v& x) -> vector_v {
std::vector<var> yy(x.data(), x.data() + N);
auto eval
= dae->f_(t, yy, vyp, vtheta, dae->x_r_, dae->x_i_, dae->msgs_);
Expand All @@ -138,8 +137,7 @@ class idas_forward_system : public idas_system<F, Tyy, Typ, Tpar> {
stan::math::jacobian(fyy, vec_yy, f_val, J);
r = J * yys_mat;

auto fyp
= [&t, &vyy, &vtheta, &N, &dae](const matrix_v& x) -> vector_v {
auto fyp = [&t, &vyy, &vtheta, &N, &dae](const matrix_v& x) -> vector_v {
std::vector<var> yp(x.data(), x.data() + N);
auto eval
= dae->f_(t, vyy, yp, vtheta, dae->x_r_, dae->x_i_, dae->msgs_);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ inline double gradient_of_f(const F &f, const double &x, const double &xc,
gradient = 0;
} else {
throw_domain_error("gradient_of_f", "The gradient of f", n,
"is nan for parameter ", "");
"is nan for parameter ", "");
}
}

Expand Down
36 changes: 17 additions & 19 deletions stan/math/rev/functor/map_rect_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ struct map_rect_reduce<F, var, var> {
vector_v shared_params_v = to_var(shared_params);
vector_v job_specific_params_v = to_var(job_specific_params);

vector_v fx_v
= F()(shared_params_v, job_specific_params_v, x_r, x_i, msgs);
vector_v fx_v = F()(shared_params_v, job_specific_params_v, x_r, x_i, msgs);

const size_type size_f = fx_v.rows();

Expand All @@ -45,8 +44,7 @@ struct map_rect_reduce<F, var, var> {
out(1 + j, i) = shared_params_v(j).vi_->adj_;
}
for (size_type j = 0; j < num_job_specific_params; ++j) {
out(1 + num_shared_params + j, i)
= job_specific_params_v(j).vi_->adj_;
out(1 + num_shared_params + j, i) = job_specific_params_v(j).vi_->adj_;
}
}
return out;
Expand Down Expand Up @@ -96,28 +94,28 @@ struct map_rect_reduce<F, var, double> {
const size_type num_shared_params = shared_params.rows();
matrix_d out(1 + num_shared_params, 0);

// Run nested autodiff in this scope
local_nested_autodiff nested;
// Run nested autodiff in this scope
local_nested_autodiff nested;

vector_v shared_params_v = to_var(shared_params);
vector_v shared_params_v = to_var(shared_params);

vector_v fx_v = F()(shared_params_v, job_specific_params, x_r, x_i, msgs);
vector_v fx_v = F()(shared_params_v, job_specific_params, x_r, x_i, msgs);

const size_type size_f = fx_v.rows();
const size_type size_f = fx_v.rows();

out.resize(Eigen::NoChange, size_f);
out.resize(Eigen::NoChange, size_f);

for (size_type i = 0; i < size_f; ++i) {
out(0, i) = fx_v(i).val();
nested.set_zero_all_adjoints();
fx_v(i).grad();
for (size_type j = 0; j < num_shared_params; ++j) {
out(1 + j, i) = shared_params_v(j).vi_->adj_;
for (size_type i = 0; i < size_f; ++i) {
out(0, i) = fx_v(i).val();
nested.set_zero_all_adjoints();
fx_v(i).grad();
for (size_type j = 0; j < num_shared_params; ++j) {
out(1 + j, i) = shared_params_v(j).vi_->adj_;
}
}

return out;
}

return out;
}
};

} // namespace internal
Expand Down
4 changes: 1 addition & 3 deletions test/unit/math/rev/core/gradable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ struct gradable {
EXPECT_FLOAT_EQ(g_expected_(i), g[i]);
}

double adj() {
return f_.adj();
}
double adj() { return f_.adj(); }
};

gradable setup_quad_form() {
Expand Down
9 changes: 4 additions & 5 deletions test/unit/math/rev/core/local_nested_autodiff_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ struct AgradLocalNested : public testing::Test {
}
};


TEST_F(AgradLocalNested, local_nested_autodiff_base) {
{
stan::math::local_nested_autodiff nested;
EXPECT_THROW(stan::math::recover_memory(), std::logic_error);
}
stan::math::recover_memory(); // Should not throw
stan::math::recover_memory(); // Should not throw

gradable g_out = setup_quad_form();
for (int i = 0; i < 100; ++i) {
Expand All @@ -25,7 +24,7 @@ TEST_F(AgradLocalNested, local_nested_autodiff_base) {
g.test();
nested.set_zero_all_adjoints();
EXPECT_EQ(g.adj(), 0);
}
}
g_out.test();
}

Expand All @@ -50,7 +49,7 @@ TEST_F(AgradLocalNested, local_nested_autodiff_Gradient1) {
stan::math::recover_memory();
}

TEST_F(AgradLocalNested,local_nested_autodiff_Gradient2) {
TEST_F(AgradLocalNested, local_nested_autodiff_Gradient2) {
using stan::math::local_nested_autodiff;

gradable g0 = setup_quad_form();
Expand Down Expand Up @@ -85,7 +84,7 @@ TEST_F(AgradLocalNested, local_nested_autodiff_Gradient3) {
gradable g3 = setup_quad_form();
{
local_nested_autodiff nested4;
gradable g4 = setup_simple();
gradable g4 = setup_simple();
g4.test();
}
g3.test();
Expand Down
3 changes: 1 addition & 2 deletions test/unit/math/rev/functor/coupled_ode_system_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ TEST_F(StanAgradRevOde, coupled_ode_system_dv) {
EXPECT_FLOAT_EQ(-1.075, dz_dt[1]);
EXPECT_FLOAT_EQ(2, dz_dt[2]);
EXPECT_FLOAT_EQ(-1.8, dz_dt[3]);

}
TEST_F(StanAgradRevOde, initial_state_dv) {
using stan::math::coupled_ode_system;
Expand Down Expand Up @@ -321,7 +320,7 @@ TEST_F(StanAgradRevOde, coupled_ode_system_vv) {

// Run nested autodiff in this scope
stan::math::local_nested_autodiff nested;

const size_t N = 2;
const size_t M = 1;
const size_t z_size = N + N * N + N * M;
Expand Down