Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using local_nested_autodiff for all instances of nested autodiff #1706

Merged
Next Next commit
Using local_nested_autodiff for all instances of nested autodiff
  • Loading branch information
martinmodrak committed Feb 12, 2020
commit baf3a011a6a407f9dd84e21f44d05ca393b7c90e
38 changes: 17 additions & 21 deletions stan/math/mix/functor/grad_hessian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,25 @@ void grad_hessian(
int d = x.size();
H.resize(d, d);
grad_H.resize(d, Matrix<double, Dynamic, Dynamic>(d, d));
try {
for (int i = 0; i < d; ++i) {
for (int j = i; j < d; ++j) {
start_nested();
Matrix<fvar<fvar<var> >, Dynamic, 1> x_ffvar(d);
for (int k = 0; k < d; ++k) {
x_ffvar(k)
= fvar<fvar<var> >(fvar<var>(x(k), i == k), fvar<var>(j == k, 0));
}
fvar<fvar<var> > fx_ffvar = f(x_ffvar);
H(i, j) = fx_ffvar.d_.d_.val();
H(j, i) = H(i, j);
grad(fx_ffvar.d_.d_.vi_);
for (int k = 0; k < d; ++k) {
grad_H[i](j, k) = x_ffvar(k).val_.val_.adj();
grad_H[j](i, k) = grad_H[i](j, k);
}
recover_memory_nested();
for (int i = 0; i < d; ++i) {
for (int j = i; j < d; ++j) {
// Run nested autodiff in this scope
local_nested_autodiff nested;

Matrix<fvar<fvar<var> >, Dynamic, 1> x_ffvar(d);
for (int k = 0; k < d; ++k) {
x_ffvar(k)
= fvar<fvar<var> >(fvar<var>(x(k), i == k), fvar<var>(j == k, 0));
}
fvar<fvar<var> > fx_ffvar = f(x_ffvar);
H(i, j) = fx_ffvar.d_.d_.val();
H(j, i) = H(i, j);
grad(fx_ffvar.d_.d_.vi_);
for (int k = 0; k < d; ++k) {
grad_H[i](j, k) = x_ffvar(k).val_.val_.adj();
grad_H[j](i, k) = grad_H[i](j, k);
}
}
} catch (const std::exception& e) {
recover_memory_nested();
throw;
}
}

Expand Down
57 changes: 27 additions & 30 deletions stan/math/mix/functor/grad_tr_mat_times_hessian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,38 @@ void grad_tr_mat_times_hessian(
Eigen::Matrix<double, Eigen::Dynamic, 1>& grad_tr_MH) {
using Eigen::Dynamic;
using Eigen::Matrix;
start_nested();
try {
grad_tr_MH.resize(x.size());

Matrix<var, Dynamic, 1> x_var(x.size());
for (int i = 0; i < x.size(); ++i) {
x_var(i) = x(i);
}
// Run nested autodiff in this scope
local_nested_autodiff nested;

Matrix<fvar<var>, Dynamic, 1> x_fvar(x.size());

var sum(0.0);
Matrix<double, Dynamic, 1> M_n(x.size());
for (int n = 0; n < x.size(); ++n) {
for (int k = 0; k < x.size(); ++k) {
M_n(k) = M(n, k);
}
for (int k = 0; k < x.size(); ++k) {
x_fvar(k) = fvar<var>(x_var(k), k == n);
}
fvar<var> fx;
fvar<var> grad_fx_dot_v;
gradient_dot_vector<fvar<var>, double>(f, x_fvar, M_n, fx, grad_fx_dot_v);
sum += grad_fx_dot_v.d_;
}
grad_tr_MH.resize(x.size());

Matrix<var, Dynamic, 1> x_var(x.size());
for (int i = 0; i < x.size(); ++i) {
x_var(i) = x(i);
}

grad(sum.vi_);
for (int i = 0; i < x.size(); ++i) {
grad_tr_MH(i) = x_var(i).adj();
Matrix<fvar<var>, Dynamic, 1> x_fvar(x.size());

var sum(0.0);
Matrix<double, Dynamic, 1> M_n(x.size());
for (int n = 0; n < x.size(); ++n) {
for (int k = 0; k < x.size(); ++k) {
M_n(k) = M(n, k);
}
for (int k = 0; k < x.size(); ++k) {
x_fvar(k) = fvar<var>(x_var(k), k == n);
}
} catch (const std::exception& e) {
recover_memory_nested();
throw;
fvar<var> fx;
fvar<var> grad_fx_dot_v;
gradient_dot_vector<fvar<var>, double>(f, x_fvar, M_n, fx, grad_fx_dot_v);
sum += grad_fx_dot_v.d_;
}

grad(sum.vi_);
for (int i = 0; i < x.size(); ++i) {
grad_tr_MH(i) = x_var(i).adj();
}
recover_memory_nested();
}

} // namespace math
Expand Down
36 changes: 16 additions & 20 deletions stan/math/mix/functor/hessian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,23 @@ void hessian(const F& f, const Eigen::Matrix<double, Eigen::Dynamic, 1>& x,
fx = f(x);
return;
}
try {
for (int i = 0; i < x.size(); ++i) {
start_nested();
Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> x_fvar(x.size());
for (int j = 0; j < x.size(); ++j) {
x_fvar(j) = fvar<var>(x(j), i == j);
}
fvar<var> fx_fvar = f(x_fvar);
grad(i) = fx_fvar.d_.val();
if (i == 0) {
fx = fx_fvar.val_.val();
}
stan::math::grad(fx_fvar.d_.vi_);
for (int j = 0; j < x.size(); ++j) {
H(i, j) = x_fvar(j).val_.adj();
}
recover_memory_nested();
for (int i = 0; i < x.size(); ++i) {
// Run nested autodiff in this scope
local_nested_autodiff nested;

Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> x_fvar(x.size());
for (int j = 0; j < x.size(); ++j) {
x_fvar(j) = fvar<var>(x(j), i == j);
}
fvar<var> fx_fvar = f(x_fvar);
grad(i) = fx_fvar.d_.val();
if (i == 0) {
fx = fx_fvar.val_.val();
}
stan::math::grad(fx_fvar.d_.vi_);
for (int j = 0; j < x.size(); ++j) {
H(i, j) = x_fvar(j).val_.adj();
}
} catch (const std::exception& e) {
recover_memory_nested();
throw;
}
}

Expand Down
35 changes: 16 additions & 19 deletions stan/math/mix/functor/hessian_times_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,23 @@ void hessian_times_vector(const F& f,
double& fx,
Eigen::Matrix<double, Eigen::Dynamic, 1>& Hv) {
using Eigen::Matrix;
start_nested();
try {
Matrix<var, Eigen::Dynamic, 1> x_var(x.size());
for (int i = 0; i < x_var.size(); ++i) {
x_var(i) = x(i);
}
var fx_var;
var grad_fx_var_dot_v;
gradient_dot_vector(f, x_var, v, fx_var, grad_fx_var_dot_v);
fx = fx_var.val();
grad(grad_fx_var_dot_v.vi_);
Hv.resize(x.size());
for (int i = 0; i < x.size(); ++i) {
Hv(i) = x_var(i).adj();
}
} catch (const std::exception& e) {
recover_memory_nested();
throw;

// Run nested autodiff in this scope
local_nested_autodiff nested;

Matrix<var, Eigen::Dynamic, 1> x_var(x.size());
for (int i = 0; i < x_var.size(); ++i) {
x_var(i) = x(i);
}
var fx_var;
var grad_fx_var_dot_v;
gradient_dot_vector(f, x_var, v, fx_var, grad_fx_var_dot_v);
fx = fx_var.val();
grad(grad_fx_var_dot_v.vi_);
Hv.resize(x.size());
for (int i = 0; i < x.size(); ++i) {
Hv(i) = x_var(i).adj();
}
recover_memory_nested();
}
template <typename T, typename F>
void hessian_times_vector(const F& f,
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <stan/math/rev/core/empty_nested.hpp>
#include <stan/math/rev/core/gevv_vvv_vari.hpp>
#include <stan/math/rev/core/grad.hpp>
#include <stan/math/rev/core/local_nested_autodiff.hpp>
#include <stan/math/rev/core/matrix_vari.hpp>
#include <stan/math/rev/core/nested_size.hpp>
#include <stan/math/rev/core/operator_addition.hpp>
Expand Down
55 changes: 55 additions & 0 deletions stan/math/rev/core/local_nested_autodiff.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef STAN_MATH_REV_CORE_LOCAL_NESTED_AUTODIFF_HPP
#define STAN_MATH_REV_CORE_LOCAL_NESTED_AUTODIFF_HPP

#include <stan/math/rev/core/recover_memory_nested.hpp>
#include <stan/math/rev/core/set_zero_all_adjoints_nested.hpp>
#include <stan/math/rev/core/start_nested.hpp>

namespace stan {
namespace math {

/**
 * Scope guard for nested autodiff: the constructor calls
 * <code>start_nested()</code> and the destructor calls
 * <code>recover_memory_nested()</code>, so the nested stack is recovered
 * whenever the guard goes out of scope. This is the preferred way to use
 * nested autodiff. Example:
 *
 * var a;  // allocated normally
 * {
 *   local_nested_autodiff nested;  // Starts nested autodiff
 *
 *   var nested_var;  // allocated on the nested stack
 *   // Do stuff on the nested stack
 *
 *   // Nested stack is automatically recovered at the end of scope where
 *   // nested was declared, including exceptions, returns, etc.
 * }
 * var b;
 *
 * Instances are meant to be plain local variables: copying, assignment and
 * heap allocation (both <code>new</code> and <code>new[]</code>) are
 * disabled so that the lifetime of the guard always matches a lexical scope.
 */
class local_nested_autodiff {
 public:
  /**
   * Start a nested autodiff scope by calling <code>start_nested()</code>.
   */
  local_nested_autodiff() { start_nested(); }

  /**
   * End the nested autodiff scope by calling
   * <code>recover_memory_nested()</code>.
   *
   * NOTE(review): <code>recover_memory_nested()</code> is documented to throw
   * <code>std::logic_error</code> when nothing is on the nested stack. A
   * destructor is implicitly noexcept, so such a throw would terminate the
   * program; avoid calling <code>recover_memory_nested()</code> manually
   * while an instance of this class is alive.
   */
  ~local_nested_autodiff() { recover_memory_nested(); }

  // Prevent undesirable operations: the guard must not be copied, assigned
  // or placed on the heap (as a single object or as an array), because any
  // of these would decouple its lifetime from the enclosing scope.
  local_nested_autodiff(const local_nested_autodiff&) = delete;
  local_nested_autodiff& operator=(const local_nested_autodiff&) = delete;
  void* operator new(std::size_t) = delete;
  void* operator new[](std::size_t) = delete;

  /**
   * Reset all adjoint values in this nested stack
   * to zero (forwards to <code>set_zero_all_adjoints_nested()</code>).
   **/
  void set_zero_all_adjoints() { set_zero_all_adjoints_nested(); }
};

} // namespace math
} // namespace stan
#endif
3 changes: 3 additions & 0 deletions stan/math/rev/core/recover_memory_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ namespace math {
* is nothing on the nested stack, then a
* <code>std::logic_error</code> exception is thrown.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff as it handles recovery of memory automatically.
*
* @throw std::logic_error if <code>empty_nested()</code> returns
* <code>true</code>
*/
Expand Down
3 changes: 3 additions & 0 deletions stan/math/rev/core/set_zero_all_adjoints_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ namespace math {
/**
* Reset all adjoint values in the top nested portion of the stack
* to zero.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff as it handles recovery of memory automatically.
*/
static void set_zero_all_adjoints_nested() {
if (empty_nested()) {
Expand Down
3 changes: 3 additions & 0 deletions stan/math/rev/core/start_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ namespace math {
/**
* Record the current position so that <code>recover_memory_nested()</code>
* can find it.
*
* It is preferred to use the <code>local_nested_autodiff</code> class for
* nested autodiff as it handles recovery of memory automatically.
*/
static inline void start_nested() {
ChainableStack::instance_->nested_var_stack_sizes_.push_back(
Expand Down
Loading