Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds simple vectorized functions for var<matrix> #2461

Merged
merged 9 commits into from
May 5, 2021
3 changes: 2 additions & 1 deletion stan/math/prim/fun/bessel_first_kind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ inline T2 bessel_first_kind(int v, const T2 z) {
* @param b Second input
* @return Bessel first kind function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_not_var_matrix_t<T2>* = nullptr>
inline auto bessel_first_kind(const T1& a, const T2& b) {
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
return bessel_first_kind(c, d);
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/beta.hpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ inline return_type_t<T1, T2> beta(const T1 a, const T2 b) {
* @param b Second input
* @return Beta function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_all_not_var_matrix_t<T1, T2>* = nullptr>
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
inline auto beta(const T1& a, const T2& b) {
return apply_scalar_binary(
a, b, [&](const auto& c, const auto& d) { return beta(c, d); });
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/binary_log_loss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ inline T binary_log_loss(int y, const T& y_hat) {
* @param b Second input
* @return Binary log loss function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_not_var_matrix_t<T2>* = nullptr>
inline auto binary_log_loss(const T1& a, const T2& b) {
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
return binary_log_loss(c, d);
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/ceil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ inline auto ceil(const Container& x) {
* @return Least integer >= each value in x.
*/
template <typename Container,
require_container_st<std::is_arithmetic, Container>* = nullptr>
require_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto ceil(const Container& x) {
return apply_vector_unary<Container>::apply(
x, [](const auto& v) { return v.array().ceil(); });
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/erf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ struct erf_fun {
* @param x container
* @return Error function applied to each value in x.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
// Vectorized prim-path erf. Excludes OpenCL kernel expressions and
// var<Matrix> inputs so that the specialized reverse-mode overload
// (rev/fun/erf.hpp, added in this PR) is selected for var<Matrix>.
template <
    typename T,
    require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
    require_not_var_matrix_t<T>* = nullptr>
inline auto erf(const T& x) {
  // Apply erf_fun elementwise across the container.
  return apply_scalar_unary<erf_fun, T>::apply(x);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/erfc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ struct erfc_fun {
* @param x container
* @return Complementary error function applied to each value in x.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
// Vectorized prim-path erfc. Excludes OpenCL kernel expressions and
// var<Matrix> inputs; var<Matrix> is expected to be handled by a dedicated
// reverse-mode overload.
template <
    typename T,
    require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
    require_not_var_matrix_t<T>* = nullptr>
inline auto erfc(const T& x) {
  // Apply erfc_fun elementwise across the container.
  return apply_scalar_unary<erfc_fun, T>::apply(x);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/exp2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ struct exp2_fun {
* @param x container
* @return Elementwise exp2 of members of container.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
// Vectorized prim-path exp2 (2^x elementwise). Excludes OpenCL kernel
// expressions and var<Matrix> inputs; var<Matrix> is expected to be handled
// by a dedicated reverse-mode overload.
template <
    typename T,
    require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
    require_not_var_matrix_t<T>* = nullptr>
inline auto exp2(const T& x) {
  // Apply exp2_fun elementwise across the container.
  return apply_scalar_unary<exp2_fun, T>::apply(x);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/expm1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ struct expm1_fun {
* @param x container
* @return Natural exponential of each value in x minus one.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
// Vectorized prim-path expm1 (exp(x) - 1 elementwise). Excludes OpenCL
// kernel expressions and var<Matrix> inputs; var<Matrix> is expected to be
// handled by a dedicated reverse-mode overload.
template <
    typename T,
    require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
    require_not_var_matrix_t<T>* = nullptr>
inline auto expm1(const T& x) {
  // Apply expm1_fun elementwise across the container.
  return apply_scalar_unary<expm1_fun, T>::apply(x);
}
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/fun/falling_factorial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace math {
*/
template <typename T, require_arithmetic_t<T>* = nullptr>
inline return_type_t<T> falling_factorial(const T& x, int n) {
static const char* function = "falling_factorial";
constexpr const char* function = "falling_factorial";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably save this for a separate PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can, though it's a teeny change that shouldn't affect much if you don't mind leaving it. In general we should sweep through the math library looking for static const char* function = and make them constexpr

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think probably best to pull it out for now, just so that all functions are consistent. But I'm not super tied to it, so feel free to ignore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave it for now and I'll open up another PR this week just making a bunch of these constexpr

check_not_nan(function, "first argument", x);
check_nonnegative(function, "second argument", n);
return boost::math::falling_factorial(x, n, boost_policy_t<>());
Expand All @@ -78,7 +78,8 @@ inline return_type_t<T> falling_factorial(const T& x, int n) {
* @param b Second input
* @return Falling factorial function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_all_not_var_matrix_t<T1, T2>* = nullptr>
inline auto falling_factorial(const T1& a, const T2& b) {
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
return falling_factorial(c, d);
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/floor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct floor_fun {
// Generic floor for non-arithmetic-scalar containers (e.g. autodiff or
// complex element types). Excludes arithmetic-scalar containers (handled by
// the Eigen-array overload below), OpenCL kernel expressions, and
// var<Matrix> inputs, which get their own specialized overloads.
template <typename Container,
          require_not_container_st<std::is_arithmetic, Container>* = nullptr,
          require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
              Container>* = nullptr,
          require_not_var_matrix_t<Container>* = nullptr>
inline auto floor(const Container& x) {
  // Apply floor_fun elementwise across the container.
  return apply_scalar_unary<floor_fun, Container>::apply(x);
}
Expand All @@ -50,7 +51,8 @@ inline auto floor(const Container& x) {
* @return Greatest integer <= each value in x.
*/
template <typename Container,
require_container_st<std::is_arithmetic, Container>* = nullptr>
require_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto floor(const Container& x) {
return apply_vector_unary<Container>::apply(
x, [](const auto& v) { return v.array().floor(); });
Expand Down
16 changes: 15 additions & 1 deletion stan/math/rev/fun/bessel_first_kind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace stan {
namespace math {

inline var bessel_first_kind(int v, const var& a) {
double ret_val = bessel_first_kind(v, a.val());
auto ret_val = bessel_first_kind(v, a.val());
auto precomp_bessel
= v * ret_val / a.val() - bessel_first_kind(v + 1, a.val());
return make_callback_var(ret_val,
Expand All @@ -18,6 +18,20 @@ inline var bessel_first_kind(int v, const var& a) {
});
}

/**
 * Bessel function of the first kind, J_v(a), for a reverse-mode matrix
 * argument `a` with integer order(s) `v` (a scalar int or a container of
 * ints, broadcast elementwise via as_array_or_scalar).
 *
 * Uses the derivative identity d/dx J_v(x) = (v/x) J_v(x) - J_{v+1}(x).
 */
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
          require_eigen_t<T2>* = nullptr>
inline auto bessel_first_kind(const T1& v, const var_value<T2>& a) {
  // Forward-pass values; .eval() materializes them so they can be reused in
  // the derivative expression below and stored in the returned var.
  auto ret_val = bessel_first_kind(v, a.val()).array().eval();
  auto v_map = as_array_or_scalar(v);
  // Precompute dJ_v/dx on the arena so the reverse-pass callback can reuse
  // it after this function's stack frame is gone.
  auto precomp_bessel
      = to_arena(v_map * ret_val / a.val().array()
                 - bessel_first_kind(v_map + 1, a.val().array()));
  return make_callback_var(
      ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
        // Chain rule: accumulate adj(out) * dJ_v/dx into adj(a).
        a.adj().array() += vi.adj().array() * precomp_bessel;
      });
}

andrjohns marked this conversation as resolved.
Show resolved Hide resolved
} // namespace math
} // namespace stan
#endif
14 changes: 14 additions & 0 deletions stan/math/rev/fun/bessel_second_kind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ inline var bessel_second_kind(int v, const var& a) {
});
}

/**
 * Bessel function of the second kind, Y_v(a), for a reverse-mode matrix
 * argument `a` with integer order(s) `v` (a scalar int or a container of
 * ints, broadcast elementwise via as_array_or_scalar).
 *
 * Uses the derivative identity d/dx Y_v(x) = (v/x) Y_v(x) - Y_{v+1}(x).
 */
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
          require_eigen_t<T2>* = nullptr>
inline auto bessel_second_kind(const T1& v, const var_value<T2>& a) {
  // Forward-pass values, materialized for reuse in the derivative below.
  auto ret_val = bessel_second_kind(v, a.val()).array().eval();
  auto v_map = as_array_or_scalar(v);
  // Precompute dY_v/dx on the arena for the reverse-pass callback.
  auto precomp_bessel
      = to_arena(v_map * ret_val / a.val().array()
                 - bessel_second_kind(v_map + 1, a.val().array()));
  return make_callback_var(
      ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
        // Chain rule: accumulate adj(out) * dY_v/dx into adj(a).
        a.adj().array() += vi.adj().array() * precomp_bessel;
      });
}
andrjohns marked this conversation as resolved.
Show resolved Hide resolved

} // namespace math
} // namespace stan
#endif
132 changes: 132 additions & 0 deletions stan/math/rev/fun/beta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,138 @@ inline var beta(double a, const var& b) {
});
}

/**
 * Elementwise beta function for two matrix inputs where at least one is a
 * `var_value` matrix.
 *
 * Reverse pass uses d/da B(a, b) = B(a, b) * (digamma(a) - digamma(a + b))
 * (and symmetrically for b); B(a, b) is read back from `vi.val()` so it is
 * never recomputed in the callback.
 *
 * @tparam Mat1 type of the first matrix input
 * @tparam Mat2 type of the second matrix input
 * @param a First input
 * @param b Second input
 * @return Beta function applied elementwise to the two inputs.
 */
template <typename Mat1, typename Mat2,
          require_any_var_matrix_t<Mat1, Mat2>* = nullptr,
          require_all_matrix_t<Mat1, Mat2>* = nullptr>
inline auto beta(const Mat1& a, const Mat2& b) {
  if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
    // Both arguments need adjoints.
    arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
    arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
    // digamma(a + b) is shared by both gradients; store it once on the arena.
    auto digamma_ab
        = to_arena(digamma(arena_a.val().array() + arena_b.val().array()));
    return make_callback_var(
        beta(arena_a.val(), arena_b.val()),
        [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          // adj(out) * B(a, b), evaluated once and reused for both inputs.
          const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
          arena_a.adj().array()
              += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
          arena_b.adj().array()
              += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
        });
  } else if (!is_constant<Mat1>::value) {
    // Only `a` needs adjoints; `b` is stored as plain doubles.
    arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
    arena_t<promote_scalar_t<double, Mat2>> arena_b = value_of(b);
    auto digamma_ab
        = to_arena(digamma(arena_a.val()).array()
                   - digamma(arena_a.val().array() + arena_b.array()));
    return make_callback_var(beta(arena_a.val(), arena_b),
                             [arena_a, arena_b, digamma_ab](auto& vi) mutable {
                               arena_a.adj().array() += vi.adj().array()
                                                        * digamma_ab
                                                        * vi.val().array();
                             });
  } else if (!is_constant<Mat2>::value) {
    // Only `b` needs adjoints; fold B(a, b) into the precomputed gradient.
    arena_t<promote_scalar_t<double, Mat1>> arena_a = value_of(a);
    arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
    auto beta_val = beta(arena_a, arena_b.val());
    auto digamma_ab
        = to_arena((digamma(arena_b.val()).array()
                    - digamma(arena_a.array() + arena_b.val().array()))
                   * beta_val.array());
    return make_callback_var(
        beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
        });
  }
}

/**
 * Elementwise beta function for a scalar first argument and a `var_value`
 * matrix second argument.
 *
 * Reverse pass uses d/da B(a, b) = B(a, b) * (digamma(a) - digamma(a + b))
 * (and symmetrically for b); B(a, b) is read back from `vi.val()` so it is
 * never recomputed in the callback. The scalar adjoint accumulates the sum
 * over all matrix elements.
 *
 * @tparam Scalar type of the scalar input
 * @tparam VarMat type of the var matrix input
 * @param a First input
 * @param b Second input
 * @return Beta function applied elementwise over the matrix.
 */
template <typename Scalar, typename VarMat,
          require_var_matrix_t<VarMat>* = nullptr,
          require_stan_scalar_t<Scalar>* = nullptr>
inline auto beta(const Scalar& a, const VarMat& b) {
  if (!is_constant<Scalar>::value && !is_constant<VarMat>::value) {
    // Both arguments need adjoints.
    var arena_a = a;
    arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
    // digamma(a + b) is shared by both gradients; store it once on the arena.
    auto digamma_ab = to_arena(digamma(arena_a.val() + arena_b.val().array()));
    return make_callback_var(
        beta(arena_a.val(), arena_b.val()),
        [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          // adj(out) * B(a, b), evaluated once and reused for both inputs.
          const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
          arena_a.adj()
              += (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum();
          arena_b.adj().array()
              += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
        });
  } else if (!is_constant<Scalar>::value) {
    // Only the scalar `a` needs adjoints; `b` is stored as plain doubles.
    var arena_a = a;
    arena_t<promote_scalar_t<double, VarMat>> arena_b = value_of(b);
    auto digamma_ab = to_arena(digamma(arena_a.val())
                               - digamma(arena_a.val() + arena_b.array()));
    return make_callback_var(
        beta(arena_a.val(), arena_b),
        [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          arena_a.adj()
              += (vi.adj().array() * digamma_ab * vi.val().array()).sum();
        });
  } else if (!is_constant<VarMat>::value) {
    // Only the matrix `b` needs adjoints; fold B(a, b) into the gradient.
    double arena_a = value_of(a);
    arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
    auto beta_val = beta(arena_a, arena_b.val());
    auto digamma_ab = to_arena((digamma(arena_b.val()).array()
                                - digamma(arena_a + arena_b.val().array()))
                               * beta_val.array());
    return make_callback_var(
        beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
        });
  }
}

/**
 * Elementwise beta function for a `var_value` matrix first argument and a
 * scalar second argument.
 *
 * Reverse pass uses d/da B(a, b) = B(a, b) * (digamma(a) - digamma(a + b))
 * (and symmetrically for b); B(a, b) is read back from `vi.val()` so it is
 * never recomputed in the callback. The scalar adjoint accumulates the sum
 * over all matrix elements.
 *
 * @tparam VarMat type of the var matrix input
 * @tparam Scalar type of the scalar input
 * @param a First input
 * @param b Second input
 * @return Beta function applied elementwise over the matrix.
 */
template <typename VarMat, typename Scalar,
          require_var_matrix_t<VarMat>* = nullptr,
          require_stan_scalar_t<Scalar>* = nullptr>
inline auto beta(const VarMat& a, const Scalar& b) {
  if (!is_constant<VarMat>::value && !is_constant<Scalar>::value) {
    // Both arguments need adjoints.
    arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
    var arena_b = b;
    // digamma(a + b) is shared by both gradients; store it once on the arena.
    auto digamma_ab = to_arena(digamma(arena_a.val().array() + arena_b.val()));
    return make_callback_var(
        beta(arena_a.val(), arena_b.val()),
        [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          // adj(out) * B(a, b), evaluated once and reused for both inputs.
          const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
          arena_a.adj().array()
              += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
          arena_b.adj()
              += (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum();
        });
  } else if (!is_constant<VarMat>::value) {
    // Only the matrix `a` needs adjoints; `b` is stored as a plain double.
    arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
    double arena_b = value_of(b);
    auto digamma_ab = to_arena(digamma(arena_a.val()).array()
                               - digamma(arena_a.val().array() + arena_b));
    return make_callback_var(beta(arena_a.val(), arena_b),
                             [arena_a, arena_b, digamma_ab](auto& vi) mutable {
                               arena_a.adj().array() += vi.adj().array()
                                                        * digamma_ab
                                                        * vi.val().array();
                             });
  } else if (!is_constant<Scalar>::value) {
    // Only the scalar `b` needs adjoints; fold B(a, b) into the gradient.
    arena_t<promote_scalar_t<double, VarMat>> arena_a = value_of(a);
    var arena_b = b;
    auto beta_val = beta(arena_a, arena_b.val());
    auto digamma_ab = to_arena(
        (digamma(arena_b.val()) - digamma(arena_a.array() + arena_b.val()))
        * beta_val.array());
    return make_callback_var(
        beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
          arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum();
        });
  }
}

} // namespace math
} // namespace stan
#endif
32 changes: 32 additions & 0 deletions stan/math/rev/fun/binary_log_loss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,38 @@ inline var binary_log_loss(int y, const var& y_hat) {
}
}

/**
 * Elementwise binary log loss of a matrix of var predictions against a
 * single integer outcome y (treated as 0 vs nonzero).
 *
 * For y == 0 the loss is -log1p(-y_hat) with derivative 1 / (1 - y_hat);
 * otherwise it is -log(y_hat) with derivative -1 / y_hat.
 */
template <typename Mat, require_eigen_t<Mat>* = nullptr>
inline auto binary_log_loss(int y, const var_value<Mat>& y_hat) {
  if (y == 0) {
    return make_callback_var(
        -(-y_hat.val().array()).log1p(), [y_hat](auto& vi) mutable {
          // d/dp [-log(1 - p)] = 1 / (1 - p)
          y_hat.adj().array() += vi.adj().array() / (1.0 - y_hat.val().array());
        });
  } else {
    return make_callback_var(
        -y_hat.val().array().log(), [y_hat](auto& vi) mutable {
          // d/dp [-log(p)] = -1 / p
          y_hat.adj().array() -= vi.adj().array() / y_hat.val().array();
        });
  }
}

/**
 * Elementwise binary log loss of a matrix of var predictions against a
 * vector of integer outcomes (one outcome per prediction; nonzero ints are
 * treated as 1 via the bool cast).
 */
template <typename Mat, require_eigen_t<Mat>* = nullptr>
inline auto binary_log_loss(const std::vector<int>& y,
                            const var_value<Mat>& y_hat) {
  // Copy the outcomes onto the arena as booleans: the reverse-pass callback
  // still needs them after this function returns.
  arena_t<Eigen::Array<bool, -1, 1>> arena_y
      = Eigen::Map<const Eigen::Array<int, -1, 1>>(y.data(), y.size())
            .cast<bool>();
  // Per element: -log1p(-y_hat) when y == 0, else -log(y_hat).
  auto ret_val
      = -(arena_y == 0)
             .select((-y_hat.val().array()).log1p(), y_hat.val().array().log());
  return make_callback_var(ret_val, [y_hat, arena_y](auto& vi) mutable {
    // Derivatives: 1 / (1 - p) for y == 0, -1 / p for y == 1; the second case
    // is expressed by dividing by -p in the select below.
    y_hat.adj().array()
        += vi.adj().array()
           / (arena_y == 0)
                 .select((1.0 - y_hat.val().array()), -y_hat.val().array());
  });
}

} // namespace math
} // namespace stan
#endif
5 changes: 5 additions & 0 deletions stan/math/rev/fun/ceil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ namespace math {
*/
inline var ceil(const var& a) { return var(std::ceil(a.val())); }

// Elementwise ceil of a var matrix. ceil is flat almost everywhere, so its
// derivative is treated as zero: the result is a fresh var with no
// reverse-pass callback, matching the scalar overload above.
// NOTE(review): `a.val().array().ceil()` is an Eigen array expression used to
// construct var_value<T> where T is a matrix type — confirm no explicit
// `.matrix()` conversion is required here.
template <typename T, require_matrix_t<T>* = nullptr>
inline auto ceil(const var_value<T>& a) {
  return var_value<T>(a.val().array().ceil());
}

} // namespace math
} // namespace stan
#endif
9 changes: 9 additions & 0 deletions stan/math/rev/fun/erf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ inline var erf(const var& a) {
});
}

/**
 * Elementwise error function of a var matrix.
 *
 * Uses d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2), precomputed once and stored
 * on the arena so the reverse-pass callback does not recompute it.
 */
template <typename T, require_matrix_t<T>* = nullptr>
inline auto erf(const var_value<T>& a) {
  auto precomp_erf
      = to_arena(TWO_OVER_SQRT_PI * (-a.val().array().square()).exp());
  return make_callback_var(erf(a.val()), [a, precomp_erf](auto& vi) mutable {
    // Chain rule: accumulate adj(out) * d erf / dx into adj(a).
    a.adj().array() += vi.adj().array() * precomp_erf;
  });
}

} // namespace math
} // namespace stan
#endif
Loading