
operator+ for var matrices and matrix of vars #2115

Merged Nov 16, 2020 (50 commits)

Changes from 1 commit

Commits
4df2ba7
adds operator+ and add() overloads for var matrix types
SteveBronder Sep 30, 2020
18a985c
use arena_t
SteveBronder Sep 30, 2020
16c77bb
use arena_t
SteveBronder Sep 30, 2020
e54858d
update docs
SteveBronder Sep 30, 2020
789a94d
Merge commit '2cf4310702eec1b14a355d883e704af0b788abd8' into HEAD
yashikno Sep 30, 2020
b098c2c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 30, 2020
1a42196
return back matrix type for addition
SteveBronder Sep 30, 2020
00ca894
return matrix type for addition
SteveBronder Sep 30, 2020
c8c06e3
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 30, 2020
a062f9c
remove is_equal
SteveBronder Oct 2, 2020
f67c251
Merge commit '15ccafff6cf5b9ee03bd4cee1a0ac00185522c97' into HEAD
yashikno Oct 2, 2020
efbf8b3
[Jenkins] auto-formatting by clang-format version 6.0.1-14 (tags/RELE…
stan-buildbot Oct 2, 2020
1f237d6
add coeffRef to adj_op plugin and use loops in operator_addition
SteveBronder Oct 2, 2020
ee4f556
include rev/meta header in operator_addition
SteveBronder Oct 2, 2020
a7260da
fixup headers in rev/meta.hpp
SteveBronder Oct 2, 2020
3a368dd
use var_value<double> instead of var in promote_var_matrix_t
SteveBronder Oct 2, 2020
759efe4
Merge commit 'a7901bca151f0e52f56c9a99cd2456a7fb1e278f' into HEAD
yashikno Oct 2, 2020
cd91e89
[Jenkins] auto-formatting by clang-format version 6.0.1-14 (tags/RELE…
stan-buildbot Oct 2, 2020
c2dcf5b
Merge remote-tracking branch 'origin/develop' into feature/varmat-ope…
SteveBronder Oct 3, 2020
728de5a
have operator+ use add() to avoid overload issue with matrix<var>
SteveBronder Oct 4, 2020
8e138f0
Merge commit 'b3297de86605da8fa97b40311e24bc60f6c8081b' into HEAD
yashikno Oct 4, 2020
bc0ca4e
[Jenkins] auto-formatting by clang-format version 6.0.1-14 (tags/RELE…
stan-buildbot Oct 4, 2020
0f1abb6
non-loop for var<matrix> addition
SteveBronder Oct 4, 2020
e9c32b4
add specialization for check_matching_dims for mixes of matrices and …
SteveBronder Oct 4, 2020
fc47c80
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 4, 2020
90bd655
add add() specializations for (var, var) and combinations of var and …
SteveBronder Oct 4, 2020
0814f7b
Merge branch 'feature/varmat-operatorplus' of github.com:stan-dev/mat…
SteveBronder Oct 4, 2020
c5d5a88
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 4, 2020
c3cb134
add expect_ad(scalar, scalar) test for add()
SteveBronder Oct 5, 2020
0a43d5c
Merge branch 'feature/varmat-operatorplus' of github.com:stan-dev/mat…
SteveBronder Oct 20, 2020
75b786c
merge to develop
SteveBronder Oct 20, 2020
c43241d
update is_nan and use loop that makes local copy of ret adj
SteveBronder Oct 21, 2020
232e608
use arena_a.val() instead of a.val() and for b in creating the return…
SteveBronder Oct 21, 2020
23d3d05
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 21, 2020
a70b3aa
fix header includes and use loop for var mat and mat var
SteveBronder Oct 21, 2020
a4674aa
Merge branch 'feature/varmat-operatorplus' of github.com:stan-dev/mat…
SteveBronder Oct 21, 2020
a1ed9d5
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 21, 2020
8a55108
remove changes to eigen_numtraits and move add() to operator_addition
SteveBronder Oct 21, 2020
ee27ace
Merge branch 'feature/varmat-operatorplus' of github.com:stan-dev/mat…
SteveBronder Oct 21, 2020
df959b6
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 21, 2020
b1f2982
move Eigen_NumTraits, read_var, and typedefs headers to core
SteveBronder Oct 21, 2020
e3bc35e
Merge branch 'feature/varmat-operatorplus' of github.com:stan-dev/mat…
SteveBronder Oct 21, 2020
145a743
fix double include
SteveBronder Oct 21, 2020
3464880
fix includes
SteveBronder Oct 21, 2020
e793e6c
fix expression test failure
SteveBronder Oct 22, 2020
d0997e3
Merge remote-tracking branch 'origin/develop' into feature/varmat-ope…
SteveBronder Oct 28, 2020
6ea6704
fix headers
SteveBronder Oct 28, 2020
e4615df
update to develop
SteveBronder Nov 3, 2020
b0b1440
merge to develop and remove changes for eigen_plugins
SteveBronder Nov 12, 2020
1507a74
kickoff jenkins
SteveBronder Nov 13, 2020
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final)
stan-buildbot committed Sep 30, 2020
commit b098c2cdf5833c5305c757b9f41c5c0118acc66d
17 changes: 10 additions & 7 deletions stan/math/prim/err/is_matching_dims.hpp
@@ -23,9 +23,9 @@ namespace math {
* @param y2 second matrix to test
* @return <code>true</code> if the dimensions of the matrices match
*/
-template <typename EigMat1, typename EigMat2, require_all_matrix_t<EigMat1, EigMat2>* = nullptr>
-inline bool is_matching_dims(const EigMat1& y1,
-                             const EigMat2& y2) {
+template <typename EigMat1, typename EigMat2,
+          require_all_matrix_t<EigMat1, EigMat2>* = nullptr>
+inline bool is_matching_dims(const EigMat1& y1, const EigMat2& y2) {
return is_size_match(y1.rows(), y2.rows())
&& is_size_match(y1.cols(), y2.cols());
}
@@ -46,10 +46,13 @@ inline bool is_matching_dims(const EigMat1& y1,
* @param y2 second matrix to test
* @return <code>true</code> if the dimensions of the matrices match
*/
-template <bool check_compile, typename EigMat1, typename EigMat2, require_all_matrix_t<EigMat1, EigMat2>* = nullptr>
-inline bool is_matching_dims(const EigMat1& y1,
-                             const EigMat2& y2) {
-  return !(check_compile && (EigMat1::RowsAtCompileTime != EigMat2::RowsAtCompileTime || EigMat1::ColsAtCompileTime != EigMat2::ColsAtCompileTime)) && is_matching_dims(y1, y2);
+template <bool check_compile, typename EigMat1, typename EigMat2,
+          require_all_matrix_t<EigMat1, EigMat2>* = nullptr>
+inline bool is_matching_dims(const EigMat1& y1, const EigMat2& y2) {
+  return !(check_compile
+           && (EigMat1::RowsAtCompileTime != EigMat2::RowsAtCompileTime
+               || EigMat1::ColsAtCompileTime != EigMat2::ColsAtCompileTime))
+         && is_matching_dims(y1, y2);
}

} // namespace math
83 changes: 41 additions & 42 deletions stan/math/rev/core/operator_addition.hpp
@@ -85,9 +85,9 @@ inline var operator+(const var& a, const var& b) {
var ret(a.val() + b.val());
if (unlikely(is_any_nan(a.val(), b.val()))) {
reverse_pass_callback([a, b]() mutable {
Contributor comment: I don't think special handling of NaNs benefits either correctness or performance. The same propagation of NaNs happens in the general branch as well. Maybe benchmark it?

      a.adj() = NOT_A_NUMBER;
      b.adj() = NOT_A_NUMBER;
    });
} else {
reverse_pass_callback([ret, a, b]() mutable {
a.adj() += ret.adj();
@@ -116,13 +116,9 @@ inline var operator+(const var& a, Arith b) {
} else {
var ret(a.val() + b);
    if (unlikely(is_any_nan(a.val(), b))) {
-      reverse_pass_callback([a]() mutable {
-        a.adj() = NOT_A_NUMBER;
-      });
+      reverse_pass_callback([a]() mutable { a.adj() = NOT_A_NUMBER; });
    } else {
-      reverse_pass_callback([ret, a]() mutable {
-        a.adj() += ret.adj();
-      });
+      reverse_pass_callback([ret, a]() mutable { a.adj() += ret.adj(); });
    }
return ret;
}
@@ -154,17 +150,19 @@ inline var operator+(Arith a, const var& b) {
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
-template <typename VarMat1, typename VarMat2, require_all_rev_matrix_t<VarMat1, VarMat2>* = nullptr>
+template <typename VarMat1, typename VarMat2,
+          require_all_rev_matrix_t<VarMat1, VarMat2>* = nullptr>
inline auto operator+(const VarMat1& a, const VarMat2& b) {
check_matching_dims("operator+", "a", a, "b", b);
using ret_type = decltype(a.val() + b.val());
-  promote_var_matrix_t<ret_type, VarMat1, VarMat2> ret((a.val() + b.val()).eval());
+  promote_var_matrix_t<ret_type, VarMat1, VarMat2> ret(
+      (a.val() + b.val()).eval());
arena_t<VarMat1> arena_a = a;
arena_t<VarMat2> arena_b = b;
reverse_pass_callback([ret, arena_a, arena_b]() mutable {
    arena_a.adj() += ret.adj_op();
    arena_b.adj() += ret.adj_op();
  });
return ret;
}

@@ -177,22 +175,23 @@ inline auto operator+(const VarMat1& a, const VarMat2& b) {
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
-template <typename Arith, typename VarMat, require_st_arithmetic<Arith>* = nullptr,
-          require_rev_matrix_t<VarMat>* = nullptr>
+template <typename Arith, typename VarMat,
+          require_st_arithmetic<Arith>* = nullptr,
+          require_rev_matrix_t<VarMat>* = nullptr>
inline auto operator+(const VarMat& a, const Arith& b) {
if (is_eigen<Arith>::value) {
check_matching_dims("operator+", "a", a, "b", b);
}
-  using ret_inner_type = plain_type_t<decltype((a.val().array() + as_array_or_scalar(b)).matrix())>;
+  using ret_inner_type = plain_type_t<decltype(
+      (a.val().array() + as_array_or_scalar(b)).matrix())>;
using ret_type = promote_var_matrix_t<ret_inner_type, VarMat>;
if (is_equal(b, 0.0)) {
return ret_type(a);
} else {
arena_t<VarMat> arena_a = a;
ret_type ret(a.val().array() + as_array_or_scalar(b));
-    reverse_pass_callback([ret, arena_a]() mutable {
-      arena_a.adj() += ret.adj_op();
-    });
+    reverse_pass_callback(
+        [ret, arena_a]() mutable { arena_a.adj() += ret.adj_op(); });
return ret;
}
}
@@ -206,8 +205,9 @@ inline auto operator+(const VarMat& a, const Arith& b) {
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
-template <typename Arith, typename VarMat, require_st_arithmetic<Arith>* = nullptr,
-          require_rev_matrix_t<VarMat>* = nullptr>
+template <typename Arith, typename VarMat,
+          require_st_arithmetic<Arith>* = nullptr,
+          require_rev_matrix_t<VarMat>* = nullptr>
inline auto operator+(const Arith& a, const VarMat& b) {
return b + a;
}
@@ -222,17 +222,14 @@ inline auto operator+(const Arith& a, const VarMat& b) {
* @return Variable result of adding two variables.
*/
template <typename Var, typename EigMat,
          require_eigen_vt<std::is_arithmetic, EigMat>* = nullptr,
          require_var_vt<std::is_arithmetic, Var>* = nullptr>
Collaborator comment: Not related to this pull, but I didn't realise we had this require_var_vt; that will simplify some of the wacky templating I've been resorting to. These require generics have been such a good addition!

inline auto operator+(const Var& a, const EigMat& b) {
  arena_t<promote_scalar_t<var, EigMat>> ret(a.val() + b.array());
-  reverse_pass_callback([ret, a]() mutable {
-    a.adj() += ret.adj().sum();
-  });
+  reverse_pass_callback([ret, a]() mutable { a.adj() += ret.adj().sum(); });
return ret;
}


/**
* Addition operator for a variable and arithmetic matrix (C++).
*
@@ -242,24 +239,26 @@ inline auto operator+(const Var& a, const EigMat& b) {
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
-template <typename EigMat, typename Var, require_var_vt<std::is_arithmetic, Var>* = nullptr,
-          require_eigen_vt<std::is_arithmetic, EigMat>* = nullptr>
+template <typename EigMat, typename Var,
+          require_var_vt<std::is_arithmetic, Var>* = nullptr,
+          require_eigen_vt<std::is_arithmetic, EigMat>* = nullptr>
inline auto operator+(const EigMat& a, const Var& b) {
return b + a;
}

/**
* Addition operator for a variable and variable matrix (C++).
*
- * @tparam VarMat An Eigen Matrix type with a variable Scalar type or a `var_value` with an underlying matrix type.
+ * @tparam VarMat An Eigen Matrix type with a variable Scalar type or a
+ *   `var_value` with an underlying matrix type.
* @tparam Var A `var_value` with an underlying arithmetic type.
* @param a First variable operand.
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
template <typename Var, typename VarMat,
          require_rev_matrix_t<VarMat>* = nullptr,
          require_var_vt<std::is_arithmetic, Var>* = nullptr>
inline auto operator+(const Var& a, const VarMat& b) {
arena_t<VarMat> arena_b(b);
arena_t<VarMat> ret(a.val() + b.val().array());
@@ -270,22 +269,22 @@ inline auto operator+(const Var& a, const VarMat& b) {
return ret;
}


/**
* Addition operator for a variable matrix and variable (C++).
*
- * @tparam VarMat An Eigen Matrix type with a variable Scalar type or a `var_value` with an underlying matrix type.
+ * @tparam VarMat An Eigen Matrix type with a variable Scalar type or a
+ *   `var_value` with an underlying matrix type.
* @tparam Var A `var_value` with an underlying arithmetic type.
* @param a First variable operand.
* @param b Second variable operand.
* @return Variable result of adding two variables.
*/
template <typename Var, typename VarMat,
          require_rev_matrix_t<VarMat>* = nullptr,
          require_var_vt<std::is_arithmetic, Var>* = nullptr>
inline auto operator+(const VarMat& a, const Var& b) {
  return b + a;
}

} // namespace math
} // namespace stan
11 changes: 5 additions & 6 deletions stan/math/rev/core/operator_plus_equal.hpp
@@ -14,9 +14,9 @@ inline var_value<T>& var_value<T>::operator+=(const var_value<T>& b) {
vi_ = new vari(this->val() + b.val(), false);
if (unlikely(is_any_nan(old_vi->val_, b.val()))) {
reverse_pass_callback([old_vi, b]() mutable {
      old_vi->adj_ = NOT_A_NUMBER;
      b.adj() = NOT_A_NUMBER;
    });
} else {
reverse_pass_callback([new_vi = this->vi_, old_vi, b]() mutable {
old_vi->adj_ += new_vi->adj_;
@@ -34,9 +34,8 @@ inline var_value<T>& var_value<T>::operator+=(T b) {
auto* old_vi = this->vi_;
vi_ = new vari(this->val() + b, false);
if (unlikely(is_any_nan(old_vi->val_, b))) {
-    reverse_pass_callback([old_vi, b]() mutable {
-      old_vi->adj_ = NOT_A_NUMBER;
-    });
+    reverse_pass_callback(
+        [old_vi, b]() mutable { old_vi->adj_ = NOT_A_NUMBER; });
} else {
reverse_pass_callback([new_vi = this->vi_, old_vi, b]() mutable {
old_vi->adj_ += new_vi->adj_;
3 changes: 1 addition & 2 deletions stan/math/rev/fun/add.hpp
@@ -15,8 +15,7 @@ namespace math {
* @param b second scalar
* @return the sum of the scalars
*/
-template <typename T1, typename T2,
-          require_any_st_var<T1, T2>* = nullptr>
+template <typename T1, typename T2, require_any_st_var<T1, T2>* = nullptr>
inline auto add(const T1& a, const T2& b) {
return a + b;
}
Expand Down
11 changes: 6 additions & 5 deletions stan/math/rev/meta/promote_var_matrix.hpp
@@ -17,11 +17,12 @@ namespace stan {
* Else the type will be `Matrix<var>`
*/
template <typename ReturnType, typename... Types>
-using promote_var_matrix_t = std::conditional_t<
-    is_any_var_matrix<Types...>::value,
-    stan::math::var_value<
-        stan::math::promote_scalar_t<double, plain_type_t<ReturnType>>>,
-    arena_t<stan::math::promote_scalar_t<stan::math::var, plain_type_t<ReturnType>>>>;
+using promote_var_matrix_t
+    = std::conditional_t<is_any_var_matrix<Types...>::value,
+                         stan::math::var_value<stan::math::promote_scalar_t<
+                             double, plain_type_t<ReturnType>>>,
+                         arena_t<stan::math::promote_scalar_t<
+                             stan::math::var, plain_type_t<ReturnType>>>>;
} // namespace stan

#endif
14 changes: 8 additions & 6 deletions test/unit/math/mix/core/operator_addition_test.cpp
@@ -11,7 +11,8 @@ TEST(mathMixCore, operatorAddition) {

TEST(mathMixCore, operatorAdditionMatrixSmall) {
// This calls operator+ under the hood
-  auto f = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
+  auto f
+      = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
stan::test::ad_tolerances tols;
tols.hessian_hessian_ = 1e-1;
tols.hessian_fvar_hessian_ = 1e-1;
@@ -42,11 +43,11 @@ TEST(mathMixCore, operatorAdditionMatrixSmall) {
stan::test::expect_ad_matvar(tols, f, matrix_m11, vector_v1);
stan::test::expect_ad_matvar(tols, f, row_vector_rv1, matrix_m11);
stan::test::expect_ad_matvar(tols, f, matrix_m11, matrix_m11);

}

TEST(mathMixCore, operatorAdditionMatrixZeroSize) {
-  auto f = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
+  auto f
+      = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
stan::test::ad_tolerances tols;
tols.hessian_hessian_ = 1e-1;
tols.hessian_fvar_hessian_ = 1e-1;
@@ -76,7 +77,8 @@ TEST(mathMixCore, operatorAdditionMatrixZeroSize) {
}

TEST(mathMixCore, operatorAdditionMatrixNormal) {
-  auto f = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
+  auto f
+      = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
stan::test::ad_tolerances tols;
tols.hessian_hessian_ = 1e-1;
tols.hessian_fvar_hessian_ = 1e-1;
@@ -109,7 +111,8 @@ TEST(mathMixCore, operatorAdditionMatrixNormal) {
}

TEST(mathMixCore, operatorAdditionMatrixFailures) {
-  auto f = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
+  auto f
+      = [](const auto& x1, const auto& x2) { return stan::math::add(x1, x2); };
stan::test::ad_tolerances tols;
tols.hessian_hessian_ = 1e-1;
tols.hessian_fvar_hessian_ = 1e-1;
@@ -132,5 +135,4 @@ TEST(mathMixCore, operatorAdditionMatrixFailures) {
stan::test::expect_ad_matvar(tols, f, u_tr, u);
stan::test::expect_ad_matvar(tols, f, u, vv);
stan::test::expect_ad_matvar(tols, f, rvv, u);

}