Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds testing suite for nested binary var matrix functions #2502

Merged
merged 15 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion stan/math/prim/fun/as_array_or_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,33 @@ inline auto as_array_or_scalar(T&& v) {
* @param v Specified vector.
* @return Matrix converted to an array.
*/
template <typename T, require_std_vector_t<T>* = nullptr>
template <typename T, require_std_vector_t<T>* = nullptr,
require_not_std_vector_t<value_type_t<T>>* = nullptr>
inline auto as_array_or_scalar(T&& v) {
using T_map
= Eigen::Map<const Eigen::Array<value_type_t<T>, Eigen::Dynamic, 1>>;
return make_holder([](auto& x) { return T_map(x.data(), x.size()); },
std::forward<T>(v));
}

/**
* Converts an std::vector<std::vector> to an Eigen Array.
* @tparam T A standard vector with inner container of a standard vector
* with an inner stan scalar.
* @param v specified vector of vectorised
* @return An Eigen Array with dynamic rows and columns.
*/
template <typename T, require_std_vector_vt<is_std_vector, T>* = nullptr,
require_std_vector_vt<is_stan_scalar, value_type_t<T>>* = nullptr>
inline auto as_array_or_scalar(T&& v) {
Eigen::Array<scalar_type_t<T>, -1, -1> ret(v.size(), v[0].size());
for (size_t i = 0; i < v.size(); ++i) {
ret.row(i) = Eigen::Map<const Eigen::Array<scalar_type_t<T>, 1, -1>>(
v[i].data(), v[i].size());
}
return ret;
}

} // namespace math
} // namespace stan

Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/beta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_all_not_var_matrix_t<T1, T2>* = nullptr>
inline auto beta(const T1& a, const T2& b) {
return apply_scalar_binary(
a, b, [&](const auto& c, const auto& d) { return beta(c, d); });
a, b, [](const auto& c, const auto& d) { return beta(c, d); });
}

} // namespace math
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/functor/apply_scalar_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,9 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @param f functor to apply to std::vector inputs.
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_all_std_vector_vt<is_container, T1, T2>* = nullptr>
template <
typename T1, typename T2, typename F,
require_all_std_vector_vt<is_container_or_var_matrix, T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
check_matching_sizes("Binary function", "x", x, "y", y);
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y[0], f))>;
Expand All @@ -348,7 +349,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
* @return std::vector with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_std_vector_vt<is_container, T1>* = nullptr,
require_std_vector_vt<is_container_or_var_matrix, T1>* = nullptr,
require_stan_scalar_t<T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x[0], y, f))>;
Expand Down Expand Up @@ -376,7 +377,7 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
*/
template <typename T1, typename T2, typename F,
require_stan_scalar_t<T1>* = nullptr,
require_std_vector_vt<is_container, T2>* = nullptr>
require_std_vector_vt<is_container_or_var_matrix, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
using T_return = plain_type_t<decltype(apply_scalar_binary(x, y[0], f))>;
size_t y_size = y.size();
Expand Down
8 changes: 1 addition & 7 deletions stan/math/prim/functor/apply_vector_unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,6 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
}
};

namespace internal {
template <typename T>
using is_container_or_var_matrix
= disjunction<is_container<T>, is_var_matrix<T>>;
}

/**
* Specialisation for use with nested containers (std::vectors).
* For each of the member functions, an std::vector with the appropriate
Expand All @@ -177,7 +171,7 @@ using is_container_or_var_matrix
*/
template <typename T>
struct apply_vector_unary<
T, require_std_vector_vt<internal::is_container_or_var_matrix, T>> {
T, require_std_vector_vt<is_container_or_var_matrix, T>> {
using T_vt = value_type_t<T>;

/**
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
#include <stan/math/prim/meta/is_complex.hpp>
#include <stan/math/prim/meta/is_constant.hpp>
#include <stan/math/prim/meta/is_container.hpp>
#include <stan/math/prim/meta/is_container_or_var_matrix.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
#include <stan/math/prim/meta/is_eigen_dense_dynamic.hpp>
Expand Down
34 changes: 34 additions & 0 deletions stan/math/prim/meta/is_container_or_var_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef STAN_MATH_PRIM_META_IS_CONTAINER_OR_VAR_MATRIX_HPP
#define STAN_MATH_PRIM_META_IS_CONTAINER_OR_VAR_MATRIX_HPP

#include <stan/math/prim/meta/bool_constant.hpp>
#include <stan/math/prim/meta/disjunction.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_vector.hpp>
#include <stan/math/prim/meta/is_container.hpp>
#include <stan/math/prim/meta/is_var_matrix.hpp>
#include <stan/math/prim/meta/scalar_type.hpp>
#include <stan/math/prim/meta/value_type.hpp>
#include <stan/math/prim/meta/require_helpers.hpp>

#include <type_traits>

namespace stan {

/**
* Deduces whether type is eigen matrix, standard vector, or var<Matrix>.
* @tparam Container type to check
*/
template <typename Container>
using is_container_or_var_matrix
= bool_constant<math::disjunction<is_container<Container>,
is_var_matrix<Container>>::value>;

STAN_ADD_REQUIRE_UNARY(container_or_var_matrix, is_container_or_var_matrix,
general_types);
STAN_ADD_REQUIRE_CONTAINER(container_or_var_matrix, is_container_or_var_matrix,
general_types);

} // namespace stan

#endif
4 changes: 4 additions & 0 deletions stan/math/rev/fun/bessel_first_kind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ inline var bessel_first_kind(int v, const var& a) {
});
}

/**
* Overload with `var_value<Matrix>` for `int`, `std::vector<int>`, and
* `std::vector<std::vector<int>>`
*/
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
require_eigen_t<T2>* = nullptr>
inline auto bessel_first_kind(const T1& v, const var_value<T2>& a) {
Expand Down
4 changes: 4 additions & 0 deletions stan/math/rev/fun/bessel_second_kind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ inline var bessel_second_kind(int v, const var& a) {
});
}

/**
* Overload with `var_value<Matrix>` for `int`, `std::vector<int>`, and
* `std::vector<std::vector<int>>`
*/
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
require_eigen_t<T2>* = nullptr>
inline auto bessel_second_kind(const T1& v, const var_value<T2>& a) {
Expand Down
17 changes: 11 additions & 6 deletions stan/math/rev/fun/binary_log_loss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ inline var binary_log_loss(int y, const var& y_hat) {
}
}

/**
* Overload with `int` and `var_value<Matrix>`
*/
template <typename Mat, require_eigen_t<Mat>* = nullptr>
inline auto binary_log_loss(int y, const var_value<Mat>& y_hat) {
if (y == 0) {
Expand All @@ -70,12 +73,14 @@ inline auto binary_log_loss(int y, const var_value<Mat>& y_hat) {
}
}

template <typename Mat, require_eigen_t<Mat>* = nullptr>
inline auto binary_log_loss(const std::vector<int>& y,
const var_value<Mat>& y_hat) {
arena_t<Eigen::Array<bool, -1, 1>> arena_y
= Eigen::Map<const Eigen::Array<int, -1, 1>>(y.data(), y.size())
.cast<bool>();
/**
* Overload with `var_value<Matrix>` for `std::vector<int>` and
* `std::vector<std::vector<int>>`
*/
template <typename StdVec, typename Mat, require_eigen_t<Mat>* = nullptr,
require_st_integral<StdVec>* = nullptr>
inline auto binary_log_loss(const StdVec& y, const var_value<Mat>& y_hat) {
auto arena_y = to_arena(as_array_or_scalar(y).template cast<bool>());
auto ret_val
= -(arena_y == 0)
.select((-y_hat.val().array()).log1p(), y_hat.val().array().log());
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/rev/functor/algebra_solver_newton.hpp>
#include <stan/math/rev/functor/algebra_system.hpp>
#include <stan/math/rev/functor/apply_scalar_unary.hpp>
#include <stan/math/rev/functor/apply_scalar_binary.hpp>
#include <stan/math/rev/functor/apply_vector_unary.hpp>
#include <stan/math/rev/functor/coupled_ode_system.hpp>
#include <stan/math/rev/functor/cvodes_integrator.hpp>
Expand Down
102 changes: 102 additions & 0 deletions stan/math/rev/functor/apply_scalar_binary.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#ifndef STAN_MATH_REV_FUNCTOR_APPLY_SCALAR_BINARY_HPP
#define STAN_MATH_REV_FUNCTOR_APPLY_SCALAR_BINARY_HPP

#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/prim/err/check_matching_dims.hpp>
#include <stan/math/prim/err/check_matching_sizes.hpp>
#include <stan/math/prim/fun/num_elements.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Specialisation for use with combinations of
* `Eigen::Matrix` and `var_value<Eigen::Matrix>` inputs.
* Eigen's binaryExpr framework is used for more efficient indexing of both row-
* and column-major inputs without separate loops.
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
*
* @tparam T1 Type of first argument to which functor is applied.
* @tparam T2 Type of second argument to which functor is applied.
* @tparam F Type of functor to apply.
* @param x First Matrix input to which operation is applied.
* @param y Second Matrix input to which operation is applied.
* @param f functor to apply to Matrix inputs.
* @return `var_value<Matrix>` with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_any_var_matrix_t<T1, T2>* = nullptr,
require_all_matrix_t<T1, T2>* = nullptr>
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
check_matching_dims("Binary function", "x", x, "y", y);
return f(x, y);
}

/**
* Specialisation for use with one `var_value<Eigen vector>` (row or column) and
* a one-dimensional std::vector of integer types
*
* @tparam T1 Type of first argument to which functor is applied.
* @tparam T2 Type of second argument to which functor is applied.
* @tparam F Type of functor to apply.
* @param x Matrix input to which operation is applied.
* @param y Integer std::vector input to which operation is applied.
* @param f functor to apply to inputs.
* @return var_value<Eigen> object with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_any_var_matrix_t<T1, T2>* = nullptr,
require_any_std_vector_vt<std::is_integral, T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
check_matching_sizes("Binary function", "x", x, "y", y);
return f(x, y);
}

/**
* Specialisation for use with a two-dimensional std::vector of integer types
* and one `var_value<Matrix>`.
*
* @tparam T1 Type of first argument to which functor is applied.
* @tparam T2 Type of second argument to which functor is applied.
* @tparam F Type of functor to apply.
* @param x Either a var matrix or nested integer std::vector input to which
* operation is applied.
* @param x Either a var matrix or nested integer std::vector input to which
* operation is applied.
* @param f functor to apply to inputs.
* @return Eigen object with result of applying functor to inputs.
*/
template <typename T1, typename T2, typename F,
require_any_std_vector_vt<is_std_vector, T1, T2>* = nullptr,
require_any_std_vector_st<std::is_integral, T1, T2>* = nullptr,
require_any_var_matrix_t<T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
return f(x, y);
}

/**
* Specialisation for use when the one input is an `var_value<Eigen> type and
* the other is a scalar.
*
* @tparam T1 Type of either `var_value<Matrix>` or scalar object to which
* functor is applied.
* @tparam T2 Type of either `var_value<Matrix>` or scalar object to which
* functor is applied.
* @tparam F Type of functor to apply.
* @param x Matrix or Scalar input to which operation is applied.
* @param x Matrix or Scalar input to which operation is applied.
* @param f functor to apply to var matrix and scalar inputs.
* @return `var_value<Matrix> object with result of applying functor to inputs.
*
*/
template <typename T1, typename T2, typename F,
require_any_stan_scalar_t<T1, T2>* = nullptr,
require_any_var_matrix_t<T1, T2>* = nullptr>
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
return f(x, y);
}

} // namespace math
} // namespace stan
#endif
8 changes: 3 additions & 5 deletions test/unit/math/mix/fun/bessel_first_kind_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ TEST(mathMixScalFun, besselFirstKind_matvec) {
};

std::vector<int> std_in1{3, 1};
Eigen::VectorXd in2(2);
in2 << 0.5, 3.4;
stan::test::expect_ad_matvar(f, std_in1, in2);

stan::test::expect_ad_matvar(f, std_in1[0], in2);
Eigen::MatrixXd in2(2, 2);
in2 << 0.5, 3.4, 0.5, 3.4;
stan::test::expect_ad_vectorized_matvar(f, std_in1, in2);
}
7 changes: 3 additions & 4 deletions test/unit/math/mix/fun/bessel_second_kind_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ TEST(mathMixScalFun, besselSecondKind_matvec) {
};

std::vector<int> std_in1{3, 1};
Eigen::VectorXd in2(2);
in2 << 0.5, 3.4;
stan::test::expect_ad_matvar(f, std_in1, in2);
stan::test::expect_ad_matvar(f, std_in1[0], in2);
Eigen::MatrixXd in2(2, 2);
in2 << 0.5, 3.4, 0.5, 3.4;
stan::test::expect_ad_vectorized_matvar(f, std_in1, in2);
}
14 changes: 14 additions & 0 deletions test/unit/math/mix/fun/beta2_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <test/unit/math/test_ad.hpp>
andrjohns marked this conversation as resolved.
Show resolved Hide resolved

TEST(mathMixScalFun, beta_varmat_vectorized) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::beta;
return beta(x1, x2);
};

Eigen::MatrixXd in1(2, 2);
in1 << 0.5, 3.4, 5.2, 0.5;
Eigen::MatrixXd in2(2, 2);
in2 << 3.3, 0.9, 6.7, 3.3;
stan::test::expect_ad_vectorized_matvar(f, in1, in2);
}
14 changes: 13 additions & 1 deletion test/unit/math/mix/fun/binary_log_loss_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST(mathMixScalFun, binaryLogLossvec) {
stan::test::expect_ad_vectorized_binary(f, std_std_in1, mat_in2);
}

TEST(mathMixScalFun, binaryLogLossmatvar) {
TEST(mathMixScalFun, binaryLogLossMatVar) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::binary_log_loss;
return binary_log_loss(x1, x2);
Expand All @@ -41,3 +41,15 @@ TEST(mathMixScalFun, binaryLogLossmatvar) {
stan::test::expect_ad_matvar(f, std_in1, in2);
stan::test::expect_ad_matvar(f, std_in1[0], in2);
}

TEST(mathMixScalFun, binaryLogLossMatVarVec) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::binary_log_loss;
return binary_log_loss(x1, x2);
};

std::vector<int> std_in1{3, 1};
Eigen::MatrixXd in2(2, 2);
in2 << 0.5, 3.4, 0.5, 3.5;
stan::test::expect_ad_vectorized_matvar(f, std_in1, in2);
}
6 changes: 5 additions & 1 deletion test/unit/math/mix/fun/falling_factorial_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ TEST(mathMixScalFun, fallingFactorial_matvar) {
return falling_factorial(x1, x2);
};

std::vector<int> std_in2{3, 1};
Eigen::VectorXd in1(2);
in1 << 0.5, 3.4;
std::vector<int> std_in2{3, 1};
Eigen::MatrixXd mat(2, 2);
mat << 0.5, 3.4, 0.5, 3.4;

stan::test::expect_ad_matvar(f, in1, std_in2);
stan::test::expect_ad_matvar(f, in1, std_in2[0]);
stan::test::expect_ad_vectorized_matvar(f, mat, std_in2);
}
Loading