Skip to content

Commit

Permalink
Removed a couple rep_matrix specializations and tests (Issue #1805)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Feb 26, 2021
1 parent 374866a commit 63a1ad6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 200 deletions.
28 changes: 6 additions & 22 deletions stan/math/prim/fun/rep_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace stan {
namespace math {

/**
* Impl of rep_matrix returning an Eigen matrix with scalar
* Implementation of rep_matrix returning an Eigen matrix with scalar
* type equal to the input scalar type.
* @tparam Ret An Eigen type.
* @tparam T A Scalar type.
Expand All @@ -29,7 +29,7 @@ inline auto rep_matrix(const T& x, int m, int n) {
}

/**
* Default Implimentation of rep_matrix returning an Eigen matrix with scalar
* Default Implementation of rep_matrix returning an Eigen matrix with scalar
* type equal to the input scalar type.
* @tparam T A Scalar type.
* @param x A Scalar whose values are propogated to all values in the return
Expand All @@ -43,15 +43,14 @@ inline auto rep_matrix(const T& x, int m, int n) {
}

/**
* Impl of rep_matrix returning an Eigen matrix from an Eigen vector.
* @tparam Ret An Eigen type.
* @tparam Vec An Eigen vector with Arithmetic scalar type.
* Implementation of rep_matrix returning an Eigen matrix from an Eigen
* vector.
* @tparam Vec An Eigen vector.
* @param x An Eigen vector. For Row vectors the values are replacated rowwise.
* and for column vectors the values are repliacated colwise.
* @param n Number of rows or columns.
*/
template <typename Ret, typename Vec,
require_eigen_vt<is_stan_scalar, Ret>* = nullptr,
template <typename Vec,
require_eigen_vector_t<Vec>* = nullptr>
inline auto rep_matrix(const Vec& x, int n) {
if (is_eigen_row_vector<Vec>::value) {
Expand All @@ -63,21 +62,6 @@ inline auto rep_matrix(const Vec& x, int n) {
}
}

/**
* Default Implimentation of rep_matrix returning an Eigen matrix from an Eigen
* vector.
* @tparam Vec An Eigen vector with Arithmetic scalar type.
* @param x An Eigen vector. For Row vectors the values are replacated rowwise
* and for column vectors the values are repliacated colwise.
* @param n Number of rows or columns.
*/
template <typename Vec, require_vector_t<Vec>* = nullptr>
inline auto rep_matrix(const Vec& x, int n) {
using scalar_t = value_type_t<Vec>;
using ret_t = Eigen::Matrix<scalar_t, Eigen::Dynamic, Eigen::Dynamic>;
return rep_matrix<ret_t>(x, n);
}

} // namespace math
} // namespace stan

Expand Down
46 changes: 3 additions & 43 deletions stan/math/rev/fun/rep_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,6 @@
namespace stan {
namespace math {

/**
* Impl of rep_matrix returning an `var_value<Eigen::Matrix>` with a double
* scalar type.
* @tparam Ret A `var_value` with inner Eigen type.
* @tparam T A Scalar type.
* @param x A Scalar whose values are propogated to all values in the return
* matrix.
* @param m Number or rows.
* @param n Number of columns.
*/
template <typename Ret, typename T, require_var_matrix_t<Ret>* = nullptr,
require_arithmetic_t<T>* = nullptr>
inline auto rep_matrix(const T& x, int m, int n) {
check_nonnegative("rep_matrix", "rows", m);
check_nonnegative("rep_matrix", "cols", n);
return Ret(value_type_t<Ret>::Constant(m, n, x));
}

/**
* Impl of rep_matrix returning an `var_value<Eigen::Matrix>` with a var scalar
* type.
Expand All @@ -48,27 +30,6 @@ inline auto rep_matrix(const T& x, int m, int n) {
[x](auto& rep) mutable { x.adj() += rep.adj().sum(); });
}

/**
* Impl of rep_matrix returning an `var_value<Eigen::Matrix>` from an Eigen
* matrix with Arithmetic scalar.
* @tparam Ret A `var_value` with inner Eigen dynamic matrix type.
* @tparam Vec An Eigen vector with Arithmetic scalar type.
* @param x An Eigen vector. For Row vectors the values are replacated rowwise
* and for column vectors the values are repliacated colwise.
* @param n Number of rows or columns.
*/
template <typename Ret, typename Vec, require_var_matrix_t<Ret>* = nullptr,
require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
inline auto rep_matrix(const Vec& x, int n) {
if (is_eigen_row_vector<Vec>::value) {
check_nonnegative("rep_matrix", "rows", n);
return Ret(x.replicate(n, 1));
} else {
check_nonnegative("rep_matrix", "cols", n);
return Ret(x.replicate(1, n));
}
}

/**
* Impl of rep_matrix returning a `var_value<Eigen::Matrix>` from a `var_value`
* with an inner Eigen vector type.
Expand All @@ -79,18 +40,17 @@ inline auto rep_matrix(const Vec& x, int n) {
* repliacated colwise.
* @param n Number of rows or columns.
*/
template <typename Ret, typename Vec, require_var_matrix_t<Ret>* = nullptr,
require_vector_st<is_var, Vec>* = nullptr>
template <typename Vec, require_var_matrix_t<Vec>* = nullptr>
inline auto rep_matrix(const Vec& x, int n) {
if (is_row_vector<Vec>::value) {
check_nonnegative("rep_matrix", "rows", n);
return make_callback_var(x.val().replicate(n, 1), [x](auto& rep) mutable {
x.adj() += rep.adj().rowwise().sum();
x.adj() += rep.adj().colwise().sum();
});
} else {
check_nonnegative("rep_matrix", "cols", n);
return make_callback_var(x.val().replicate(1, n), [x](auto& rep) mutable {
x.adj() += rep.adj().colwise().sum();
x.adj() += rep.adj().rowwise().sum();
});
}
}
Expand Down
137 changes: 2 additions & 135 deletions test/unit/math/mix/fun/rep_matrix_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,143 +22,10 @@ TEST(MathMixMatFun, repMatrix) {
Eigen::VectorXd a(3);
a << 3, 3, 3;
stan::test::expect_ad(g(2), a);
stan::test::expect_ad_matvar(g(2), a);

Eigen::RowVectorXd b(2);
b << 2, 2;
stan::test::expect_ad(g(3), b);
}

TEST(MathMixMatFun, repVarMatrix) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
auto x_var = var(1.0);
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_var, 5, 5);
auto x_sum = sum(x);
x_sum.grad();

EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_EQ(x_var.val(), 1.0);
EXPECT_EQ(x_var.adj(), 25.0);
}
TEST(MathMixMatFun, repVarMatrixArithmetic) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
double x_dbl = 1.0;
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_dbl, 5, 5);
auto x_sum = sum(x);
x_sum.grad();

EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), Eigen::MatrixXd::Ones(5, 5));
}

TEST(MathMixMatFun, repVarMatrixVec) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
var_value<Eigen::VectorXd> x_var(Eigen::VectorXd::Ones(5));
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_var, 5);
for (int j = 0; j < x.cols(); ++j) {
for (int i = 0; i < x.rows(); ++i) {
x.adj()(i, j) = i;
}
}
auto x_sum = sum(x);
x_sum.grad();
EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
Eigen::VectorXd expected_cols(5);
expected_cols << 1, 2, 3, 4, 5;
auto expected_adjs = expected_cols.replicate(1, 5).eval();
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), expected_adjs);
EXPECT_MATRIX_EQ(x_var.val(), Eigen::VectorXd::Ones(5));
Eigen::VectorXd expected_x_var_adjs(5);
expected_x_var_adjs << 15, 15, 15, 15, 15;
EXPECT_MATRIX_EQ(x_var.adj(), expected_x_var_adjs);
}

TEST(MathMixMatFun, repVarMatrixVecArithmetic) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
auto x_dbl = Eigen::VectorXd::Ones(5).eval();
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_dbl, 5);
for (int j = 0; j < x.cols(); ++j) {
for (int i = 0; i < x.rows(); ++i) {
x.adj()(i, j) = i;
}
}
auto x_sum = sum(x);
x_sum.grad();
EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
Eigen::VectorXd expected_cols(5);
expected_cols << 1, 2, 3, 4, 5;
auto expected_adjs = expected_cols.replicate(1, 5).eval();
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), expected_adjs);
}

TEST(MathMixMatFun, repVarMatrixRowVec) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
var_value<Eigen::RowVectorXd> x_var(Eigen::RowVectorXd::Ones(5));
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_var, 5);
for (int j = 0; j < x.cols(); ++j) {
for (int i = 0; i < x.rows(); ++i) {
x.adj()(i, j) = i;
}
}
auto x_sum = sum(x);
x_sum.grad();

EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
Eigen::VectorXd expected_cols(5);
expected_cols << 1, 2, 3, 4, 5;
auto expected_adjs = expected_cols.replicate(1, 5).eval();
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), expected_adjs);
EXPECT_MATRIX_EQ(x_var.val(), Eigen::RowVectorXd::Ones(5));
Eigen::RowVectorXd expected_x_var_adjs(5);
expected_x_var_adjs << 5, 10, 15, 20, 25;
EXPECT_MATRIX_EQ(x_var.adj(), expected_x_var_adjs);
}

TEST(MathMixMatFun, repVarMatrixRowVecArithmetic) {
using stan::math::rep_matrix;
using stan::math::sum;
using stan::math::var;
using stan::math::var_value;
auto x_dbl = Eigen::RowVectorXd::Ones(5).eval();
auto x = rep_matrix<var_value<Eigen::Matrix<double, -1, -1>>>(x_dbl, 5);
for (int j = 0; j < x.cols(); ++j) {
for (int i = 0; i < x.rows(); ++i) {
x.adj()(i, j) = i;
}
}
auto x_sum = sum(x);
x_sum.grad();

EXPECT_EQ(x_sum.val(), 25.0);
EXPECT_EQ(x_sum.adj(), 1.0);
Eigen::VectorXd expected_cols(5);
expected_cols << 1, 2, 3, 4, 5;
auto expected_adjs = expected_cols.replicate(1, 5).eval();
EXPECT_MATRIX_EQ(x.val(), Eigen::MatrixXd::Ones(5, 5));
EXPECT_MATRIX_EQ(x.adj(), expected_adjs);
//stan::test::expect_ad_matvar(g(3), b);
}

0 comments on commit 63a1ad6

Please sign in to comment.