Skip to content

Commit

Permalink
Updated the input and output types of the sqrt_spd adj_jac_apply (Issue
Browse files Browse the repository at this point in the history
#1144)

Rearranged some includes in stan/math/rev/mat.hpp to get rid of some type traits errors. Not sure why these were moved around. Might need moved back.
  • Loading branch information
bbbales2 committed Mar 8, 2019
1 parent 73bf8b3 commit f85bae9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
6 changes: 3 additions & 3 deletions stan/math/rev/mat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

#include <stan/math/rev/mat/fun/Eigen_NumTraits.hpp>

#include <stan/math/rev/mat/vectorize/apply_scalar_unary.hpp>
#include <stan/math/prim/mat.hpp>
#include <stan/math/rev/arr.hpp>
#include <stan/math/rev/mat/vectorize/apply_scalar_unary.hpp>

#include <stan/math/rev/mat/fun/LDLT_alloc.hpp>
#include <stan/math/rev/mat/fun/LDLT_factor.hpp>
#include <stan/math/rev/mat/fun/cholesky_decompose.hpp>
#include <stan/math/rev/mat/fun/columns_dot_product.hpp>
#include <stan/math/rev/mat/fun/columns_dot_self.hpp>
Expand All @@ -26,6 +24,8 @@
#include <stan/math/rev/mat/fun/gp_periodic_cov.hpp>
#include <stan/math/rev/mat/fun/grad.hpp>
#include <stan/math/rev/mat/fun/initialize_variable.hpp>
#include <stan/math/rev/mat/fun/LDLT_alloc.hpp>
#include <stan/math/rev/mat/fun/LDLT_factor.hpp>
#include <stan/math/rev/mat/fun/log_determinant.hpp>
#include <stan/math/rev/mat/fun/log_determinant_ldlt.hpp>
#include <stan/math/rev/mat/fun/log_determinant_spd.hpp>
Expand Down
22 changes: 13 additions & 9 deletions stan/math/rev/mat/fun/sqrt_spd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,23 @@ class sqrt_spd_op {
// derivative-or-differential-of-symmetric-square-root-of-a-matrix

template <std::size_t size>
std::tuple<Eigen::VectorXd>
std::tuple<Eigen::MatrixXd>
multiply_adjoint_jacobian(const std::array<bool, size> &needs_adj,
const Eigen::VectorXd &adj) const {
const Eigen::MatrixXd &adj) const {
using Eigen::kroneckerProduct;
using Eigen::MatrixXd;
Eigen::MatrixXd output(K_, K_);
Eigen::Map<const Eigen::VectorXd> map_input(adj.data(), adj.size());
Eigen::Map<Eigen::VectorXd> map_output(output.data(), output.size());
Eigen::Map<MatrixXd> sqrt_m(y_, K_, K_);
return std::make_tuple(
(kroneckerProduct(sqrt_m, MatrixXd::Identity(K_, K_)) +
kroneckerProduct(MatrixXd::Identity(K_, K_), sqrt_m))
// the above is symmetric so can skip the transpose
.ldlt()
.solve(adj)
.eval());

map_output = (kroneckerProduct(sqrt_m, MatrixXd::Identity(K_, K_)) +
kroneckerProduct(MatrixXd::Identity(K_, K_), sqrt_m))
// the above is symmetric so can skip the transpose
.ldlt()
.solve(map_input);

return std::make_tuple(output);
}
};
} // namespace
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/rev/mat/fun/sqrt_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TEST(AgradRevMatrix, check_varis_on_stack) {
test::check_varis_on_stack(sqrt_spd(a));
}

/*

struct make_zero {
template <typename T>
Eigen::Matrix<T, Eigen::Dynamic, 1> operator()(
Expand Down Expand Up @@ -57,4 +57,4 @@ TEST(AgradRevMatrix, sqrt_spd) {
if (i != j) EXPECT_NEAR(J(i, j), 0, TOL);
EXPECT_TRUE(J.array().any());
}
*/

0 comments on commit f85bae9

Please sign in to comment.