Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds testing suite for nested binary var matrix functions #2502

Merged
merged 15 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove unneeded overloads of var<matrix> for apply_scalar_binary and …
…add is_container_or_var_matrix
  • Loading branch information
SteveBronder committed Jun 22, 2021
commit 3848b10ee4badaf3c74e2e5fc8ff58172f558d39
7 changes: 4 additions & 3 deletions stan/math/prim/functor/apply_scalar_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
return result;
}


/**
* Specialisation for use with two nested containers (std::vectors).
* The returned scalar type is deduced to allow for cases where the input and
Expand All @@ -321,7 +322,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_all_std_vector_vt<is_container, T1, T2>* = nullptr>
require_all_std_vector_vt<is_container_or_var_matrix, T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
check_matching_sizes("Binary function", "x", x, "y", y);
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y[0], f))>;
Expand All @@ -348,7 +349,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_std_vector_vt<is_container, T1>* = nullptr,
require_std_vector_vt<is_container_or_var_matrix, T1>* = nullptr,
require_stan_scalar_t<T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y, f))>;
Expand Down Expand Up @@ -376,7 +377,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
*/
template <typename T1, typename T2, typename F,
require_stan_scalar_t<T1>* = nullptr,
require_std_vector_vt<is_container, T2>* = nullptr>
require_std_vector_vt<is_container_or_var_matrix, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x, y[0], f))>;
size_t y_size = y.size();
Expand Down
8 changes: 1 addition & 7 deletions stan/math/prim/functor/apply_vector_unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,6 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
}
};

namespace internal {
template <typename T>
using is_container_or_var_matrix
= disjunction<is_container<T>, is_var_matrix<T>>;
}

/**
* Specialisation for use with nested containers (std::vectors).
* For each of the member functions, an std::vector with the appropriate
Expand All @@ -177,7 +171,7 @@ using is_container_or_var_matrix
*/
template <typename T>
struct apply_vector_unary<
T, require_std_vector_vt<internal::is_container_or_var_matrix, T>> {
T, require_std_vector_vt<is_container_or_var_matrix, T>> {
using T_vt = value_type_t<T>;

/**
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
#include <stan/math/prim/meta/is_complex.hpp>
#include <stan/math/prim/meta/is_constant.hpp>
#include <stan/math/prim/meta/is_container.hpp>
#include <stan/math/prim/meta/is_container_or_var_matrix.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
#include <stan/math/prim/meta/is_eigen_dense_dynamic.hpp>
Expand Down
30 changes: 30 additions & 0 deletions stan/math/prim/meta/is_container_or_var_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef STAN_MATH_PRIM_META_IS_CONTAINER_OR_VAR_MATRIX_HPP
#define STAN_MATH_PRIM_META_IS_CONTAINER_OR_VAR_MATRIX_HPP

#include <stan/math/prim/meta/bool_constant.hpp>
#include <stan/math/prim/meta/disjunction.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_vector.hpp>
#include <stan/math/prim/meta/is_var_matrix.hpp>
#include <stan/math/prim/meta/scalar_type.hpp>
#include <stan/math/prim/meta/value_type.hpp>
#include <stan/math/prim/meta/require_helpers.hpp>

#include <type_traits>

namespace stan {

/**
* Deduces whether type is eigen matrix, standard vector, or var<Matrix>.
* @tparam Container type to check
*/
template <typename Container>
using is_container_or_var_matrix = bool_constant<
math::disjunction<is_container<Container>, is_var_matrix<Container>>::value>;

STAN_ADD_REQUIRE_UNARY(container_or_var_matrix, is_container_or_var_matrix, general_types);
STAN_ADD_REQUIRE_CONTAINER(container_or_var_matrix, is_container_or_var_matrix, general_types);

} // namespace stan

#endif
108 changes: 1 addition & 107 deletions stan/math/rev/functor/apply_scalar_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,7 @@
namespace stan {
namespace math {

/**
* Specialisation for use with combinations of
* `Eigen::Matrix` and `var_value<Eigen::Matrix>` inputs.
* Eigen's binaryExpr framework is used for more efficient indexing of both row-
* and column-major inputs without separate loops.
*
* @tparam T1 Type of first argument to which functor is applied.
* @tparam T2 Type of second argument to which functor is applied.
* @tparam F Type of functor to apply.
* @param x First Matrix input to which operation is applied.
* @param y Second Matrix input to which operation is applied.
* @param f functor to apply to Matrix inputs.
* @return `var_value<Matrix>` with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_any_var_matrix_t<T1, T2>* = nullptr,
require_all_matrix_t<T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
check_matching_dims("Binary function", "x", x, "y", y);
return f(x, y);
}


/**
* Specialisation for use with one `var_value<Eigen vector>` (row or column) and
Expand Down Expand Up @@ -126,8 +106,6 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @param f functor to apply to var matrix and scalar inputs.
* @return `var_value<Matrix> object with result of applying functor to inputs.
*
* Note: The return expresssion needs to be evaluated, otherwise the captured
* function and scalar fall out of scope.
*/
template <typename T1, typename T2, typename F,
require_var_matrix_t<T1>* = nullptr,
Expand All @@ -149,8 +127,6 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @return var value with inner Eigen type with result of applying functor to
* inputs.
*
* Note: The return expresssion needs to be evaluated, otherwise the captured
* function and scalar fall out of scope.
*/
template <typename T1, typename T2, typename F,
require_stan_scalar_t<T1>* = nullptr,
Expand All @@ -159,89 +135,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
return f(x, y);
}

/**
* Specialisation for use when the first input is a nested std::vector and the
* second is a scalar. The returned scalar type is deduced to allow for cases
* where the input and return scalar types differ (e.g., functions implicitly
* promoting integers).
*
* @tparam T1 Type of std::vector to which functor is applied.
* @tparam T2 Type of scalar to which functor is applied.
* @tparam F Type of functor to apply.
* @param x std::vector input to which operation is applied.
* @param y Scalar input to which operation is applied.
* @param f functor to apply to inputs.
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_std_vector_t<T1>* = nullptr,
require_var_matrix_t<value_type_t<T1>>* = nullptr,
require_stan_scalar_t<T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y, f))>;
size_t x_size = x.size();
std::vector<T_return> result(x_size);
for (size_t i = 0; i < x_size; ++i) {
result[i] = apply_scalar_binary(x[i], y, f);
}
return result;
}

/**
* Specialisation for use when the first input is a scalar and the second is a
* nested std::vector. The returned scalar type is deduced to allow for cases
* where the input and return scalar types differ (e.g., functions implicitly
* promoting integers).
*
* @tparam T1 Type of scalar to which functor is applied.
* @tparam T2 Type of std::vector to which functor is applied.
* @tparam F Type of functor to apply.
* @param x Scalar input to which operation is applied.
* @param y std::vector input to which operation is applied.
* @param f functor to apply to inputs.
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_stan_scalar_t<T1>* = nullptr,
require_std_vector_t<T2>* = nullptr,
require_var_matrix_t<value_type_t<T2>>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x, y[0], f))>;
size_t y_size = y.size();
std::vector<T_return> result(y_size);
for (size_t i = 0; i < y_size; ++i) {
result[i] = apply_scalar_binary(x, y[i], f);
}
return result;
}

/**
* Specialisation for use with two nested containers (std::vectors).
* The returned scalar type is deduced to allow for cases where the input and
* return scalar types differ (e.g., functions implicitly promoting
* integers).
*
* @tparam T1 Type of first std::vector to which functor is applied.
* @tparam T2 Type of second std::vector to which functor is applied.
* @tparam F Type of functor to apply.
* @param x First std::vector input to which operation is applied.
* @param y Second std::vector input to which operation is applied.
* @param f functor to apply to std::vector inputs.
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_any_var_matrix_t<T1, T2>* = nullptr>
inline auto apply_scalar_binary(const std::vector<T1>& x,
const std::vector<T2>& y, const F& f) {
check_matching_sizes("Binary function", "x", x, "y", y);
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y[0], f))>;
size_t y_size = y.size();
std::vector<T_return> result(y_size);
for (size_t i = 0; i < y_size; ++i) {
result[i] = apply_scalar_binary(x[i], y[i], f);
}
return result;
}

} // namespace math
} // namespace stan
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/test_ad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ void expect_ad_vectorized_binary_impl(const ad_tolerances& tols, const F& f,
expect_ad(tols, f, nest_nest_x,
nest_nest_y); // nest<nest<mat>>, nest<nest<mat>>
expect_ad(tols, f, nest_nest_x, y[0]); // nest<nest<mat>, scal
expect_ad(tols, f, x[0], nest_nest_y); // scal, nest<nest<mat>
expect_ad(tols, f, x[0], nest_nest_y); // scal, nest<nest<mat>>
}

/**
Expand Down