Skip to content

Commit

Permalink
Added dimension checks and tests (Issue #1805)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Mar 3, 2021
1 parent 6f77316 commit af9f338
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 125 deletions.
8 changes: 8 additions & 0 deletions stan/math/prim/fun/fma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ template <typename T1, typename T2, typename T3,
require_any_matrix_t<T1, T2, T3>* = nullptr,
require_not_var_t<return_type_t<T1, T2, T3>>* = nullptr>
inline auto fma(T1&& x, T2&& y, T3&& z) {
if(is_matrix<T1>::value && is_matrix<T2>::value) {
check_matching_dims("fma", "x", x, "y", y);
}
if(is_matrix<T1>::value && is_matrix<T3>::value) {
check_matching_dims("fma", "x", x, "z", z);
} else if(is_matrix<T2>::value && is_matrix<T3>::value) {
check_matching_dims("fma", "y", y, "z", z);
}
return make_holder(
[](auto&& x, auto&& y, auto&& z) {
return ((as_array_or_scalar(x) * as_array_or_scalar(y))
Expand Down
22 changes: 18 additions & 4 deletions stan/math/prim/fun/offset_multiplier_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,15 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu,
const char* function = "offset_multiplier_constrain";
const auto& mu_ref = to_ref(mu);
const auto& sigma_ref = to_ref(sigma);
check_consistent_sizes(function, "offset", mu, "multiplier", sigma,
"parameter", x);
if(is_matrix<T>::value && is_matrix<M>::value) {
check_matching_dims("function", "x", x, "mu", mu);
}
if(is_matrix<T>::value && is_matrix<S>::value) {
check_matching_dims("function", "x", x, "sigma", sigma);
} else if(is_matrix<M>::value && is_matrix<S>::value) {
check_matching_dims("function", "mu", mu, "sigma", sigma);
}

check_finite(function, "offset", value_of_rec(mu_ref));
check_positive_finite(function, "multiplier", value_of_rec(sigma_ref));
return fma(sigma_ref, x, mu_ref);
Expand Down Expand Up @@ -82,8 +89,15 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu, const S& sigma,
const char* function = "offset_multiplier_constrain";
const auto& mu_ref = to_ref(mu);
const auto& sigma_ref = to_ref(sigma);
check_consistent_sizes(function, "offset", mu, "multiplier", sigma,
"parameter", x);
if(is_matrix<T>::value && is_matrix<M>::value) {
check_matching_dims("function", "x", x, "mu", mu);
}
if(is_matrix<T>::value && is_matrix<S>::value) {
check_matching_dims("function", "x", x, "sigma", sigma);
} else if(is_matrix<M>::value && is_matrix<S>::value) {
check_matching_dims("function", "mu", mu, "sigma", sigma);
}

check_finite(function, "offset", value_of_rec(mu_ref));
check_positive_finite(function, "multiplier", value_of_rec(sigma_ref));
if (size(sigma_ref) == 1 && size(x) > 1) {
Expand Down
15 changes: 11 additions & 4 deletions stan/math/prim/fun/offset_multiplier_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ namespace math {
* @throw std::domain_error if mu is not finite
* @throw std::invalid_argument if non-scalar arguments don't match in size
*/
template <typename T, typename L, typename S>
inline auto offset_multiplier_free(const T& y, const L& mu, const S& sigma) {
template <typename T, typename M, typename S>
inline auto offset_multiplier_free(const T& y, const M& mu, const S& sigma) {
const char* function = "offset_multiplier_free";
auto&& mu_ref = to_ref(mu);
auto&& sigma_ref = to_ref(sigma);
check_consistent_sizes(function, "offset", mu, "multiplier", sigma,
"parameter", y);
if(is_matrix<T>::value && is_matrix<M>::value) {
check_matching_dims("function", "y", y, "mu", mu);
}
if(is_matrix<T>::value && is_matrix<S>::value) {
check_matching_dims("function", "y", y, "sigma", sigma);
} else if(is_matrix<M>::value && is_matrix<S>::value) {
check_matching_dims("function", "mu", mu, "sigma", sigma);
}

check_finite(function, "offset", value_of(mu_ref));
check_positive_finite(function, "multiplier", value_of(sigma_ref));
return divide(subtract(y, mu_ref), sigma_ref);
Expand Down
Loading

0 comments on commit af9f338

Please sign in to comment.