Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

var matrix specializations for a-c unary functions in rev #2256

Merged
merged 30 commits into from
Dec 19, 2020
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0c52267
adds unary functions from a to c for varmat
SteveBronder Dec 10, 2020
17b853d
Merge remote-tracking branch 'origin/develop' into feature/varmat-a-t…
SteveBronder Dec 12, 2020
2f8da9f
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 12, 2020
fe2354c
update docs
SteveBronder Dec 14, 2020
afda49a
Merge commit '7f3792baa9e91438e79a2cf6594a38d105c9b24e' into HEAD
yashikno Dec 14, 2020
91d1472
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 14, 2020
da21330
Merge remote-tracking branch 'origin/develop' into feature/varmat-a-t…
SteveBronder Dec 15, 2020
86d6f8a
add tests for vectors of varmat
SteveBronder Dec 15, 2020
13974c1
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 15, 2020
77f77e8
update abs so it can be vectorized
SteveBronder Dec 15, 2020
1768e1e
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 15, 2020
198bec3
update abs so it can be vectorized
SteveBronder Dec 15, 2020
010ebaf
Merge branch 'feature/varmat-a-to-c-unary' of github.com:stan-dev/mat…
SteveBronder Dec 15, 2020
fac9b65
update with develop, write out sin and tan
SteveBronder Dec 17, 2020
11e3459
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 17, 2020
3a02c30
updates to use cwiseproduct()
SteveBronder Dec 17, 2020
237f6a0
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 17, 2020
0e610cf
update tan test
SteveBronder Dec 17, 2020
37d61ac
update tan requires
SteveBronder Dec 17, 2020
b0738f7
Merge commit '72552dcc6040018b35d2bb1303b461493b815d95' into HEAD
yashikno Dec 17, 2020
ba7dae7
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 17, 2020
8129f72
update tan requires
SteveBronder Dec 17, 2020
88c35d9
update tan requires
SteveBronder Dec 17, 2020
4fa4816
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 17, 2020
0db6223
Fixes tests and docs for var<matrix> trig functions and apply_scalar_…
SteveBronder Dec 18, 2020
a7deb36
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 18, 2020
7e41553
Updated abs/fabs tests to be like the others (Issue #2101)
bbbales2 Dec 18, 2020
b26159f
remove .noalias() from trig functions in rev
SteveBronder Dec 18, 2020
812ae25
Merge commit '3d03131bdcbb2d33878f4117c857ac6404c7e10e' into HEAD
yashikno Dec 18, 2020
c5caa59
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions stan/math/prim/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,57 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/hypot.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return floating-point absolute value.
* Structure to wrap `abs()` so it can be vectorized.
*
* Delegates to <code>fabs(double)</code> rather than
* <code>std::abs(int)</code>.
* @tparam T type of variable
* @param x argument
* @return Arc cosine of variable in radians.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolute value

*/
struct abs_fun {
template <typename T>
static inline T fun(const T& x) {
using std::fabs;
return fabs(x);
}
};

/**
* Returns the elementwise `abs()` of the input,
* which may be a scalar or any Stan container of numeric scalars.
*
* @param x scalar
* @return absolute value of scalar
* @tparam Container type of container
* @param x argument
* @return Arc cosine of each variable in the container, in radians.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolute value

*/
inline double abs(double x) { return std::fabs(x); }
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto abs(const Container& x) {
return apply_scalar_unary<abs_fun, Container>::apply(x);
}

/**
* Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
* or expressions, and containers of these.
*
* @tparam Container Type of x
* @param x argument
* @return Arc cosine of each variable in the container, in radians.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolute value

*/
template <typename Container,
require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto abs(const Container& x) {
return apply_vector_unary<Container>::apply(
x, [](const auto& v) { return v.array().abs(); });
}

namespace internal {
/**
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/acos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct acos_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto acos(const Container& x) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/fun/acosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ struct acosh_fun {
* @param x container
* @return Elementwise acosh of members of container.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not require_container_st?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is the other thing I'd expect:

template <typename Container,
          require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto acosh(const Container& x) {
  return apply_vector_unary<Container>::apply(
      x, [](const auto& v) { return v.array().acosh(); });
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this can just be template <typename T, require_container_t<T>* = nullptr>

Copy link
Member

@bbbales2 bbbales2 Dec 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There were some functions here that didn't have apply_vector_unary/Container implementations. I'm assuming it's because the Eigen array functions didn't exist for them to be written (cbrt, this, and others). I don't think cleanup is what we're going for in this pull really, so it's fine to leave it.

Edit: didn't exist at the time these functions were written but do exist now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(But if you want to clean things, feel free, happy to review)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ups actually ignore that, this needs to be able to take in scalar types so that's why we can't use require_constainer_st

typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto acosh(const T& x) {
return apply_scalar_unary<acosh_fun, T>::apply(x);
}
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/asin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ struct asin_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto asin(const Container& x) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/fun/asinh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ struct asinh_fun {
* @param x container
* @return Inverse hyperbolic sine of each value in the container.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto asinh(const T& x) {
return apply_scalar_unary<asinh_fun, T>::apply(x);
}
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/fun/atan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ struct atan_fun {
* @param x container
* @return Arctan of each value in x, in radians.
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
template <
typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto atan(const Container& x) {
return apply_scalar_unary<atan_fun, Container>::apply(x);
}
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/fun/atanh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ struct atanh_fun {
* @param x container
* @return Elementwise atanh of members of container.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto atanh(const T& x) {
return apply_scalar_unary<atanh_fun, T>::apply(x);
}
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace math {
* @param ncols Number of columns in block.
* @throw std::out_of_range if either index is out of range.
*/
template <typename T, require_eigen_t<T>* = nullptr>
template <typename T, require_matrix_t<T>* = nullptr>
inline auto block(const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
check_row_index("block", "i", m, i);
check_row_index("block", "i+nrows-1", m, i + nrows - 1);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/cbrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct cbrt_fun {
* @param x container
* @return Cube root of each value in x.
*/
template <typename T>
template <typename T, require_not_var_matrix_t<T>* = nullptr>
inline auto cbrt(const T& x) {
return apply_scalar_unary<cbrt_fun, T>::apply(x);
}
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/cos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct cos_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto cos(const Container& x) {
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/cosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct cosh_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto cosh(const Container& x) {
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/fabs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ struct fabs_fun {
* @return Absolute value of each value in x.
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr>
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto fabs(const Container& x) {
return apply_scalar_unary<fabs_fun, Container>::apply(x);
}
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/sin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct sin_fun {
*/
template <
typename T, require_not_container_st<std::is_arithmetic, T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto sin(const T& x) {
return apply_scalar_unary<sin_fun, T>::apply(x);
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/sinh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct sinh_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto sinh(const Container& x) {
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/tan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct tan_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto tan(const Container& x) {
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/tanh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct tanh_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto tanh(const Container& x) {
Expand Down
9 changes: 8 additions & 1 deletion stan/math/prim/functor/apply_vector_unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
}
};

namespace internal {
template <typename T>
using is_container_or_var_matrix
= disjunction<is_container<T>, is_var_matrix<T>>;
}
bbbales2 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Specialisation for use with nested containers (std::vectors).
* For each of the member functions, an std::vector with the appropriate
Expand All @@ -170,7 +176,8 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
*
*/
template <typename T>
struct apply_vector_unary<T, require_std_vector_vt<is_container, T>> {
struct apply_vector_unary<
T, require_std_vector_vt<internal::is_container_or_var_matrix, T>> {
using T_vt = value_type_t<T>;

/**
Expand Down
3 changes: 1 addition & 2 deletions stan/math/prim/meta/is_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace stan {
*/
template <typename Container>
using is_container = bool_constant<
math::disjunction<is_eigen<Container>, is_std_vector<Container>,
is_var_matrix<Container>>::value>;
math::disjunction<is_eigen<Container>, is_std_vector<Container>>::value>;

STAN_ADD_REQUIRE_UNARY(container, is_container, general_types);
STAN_ADD_REQUIRE_CONTAINER(container, is_container, general_types);
Expand Down
6 changes: 5 additions & 1 deletion stan/math/rev/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ namespace math {
\end{cases}
\f]
*
* @tparam T A floating point type or an Eigen type with floating point scalar.
* @param a Variable input.
* @return Absolute value of variable.
*/
inline var abs(const var& a) { return fabs(a); }
template <typename T>
inline auto abs(const var_value<T>& a) {
return fabs(a);
}

/**
* Return the absolute value of the complex argument.
Expand Down
32 changes: 21 additions & 11 deletions stan/math/rev/fun/acos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,6 @@
namespace stan {
namespace math {

namespace internal {
class acos_vari : public op_v_vari {
public:
explicit acos_vari(vari* avi) : op_v_vari(std::acos(avi->val_), avi) {}
void chain() {
avi_->adj_ -= adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
}
};
} // namespace internal

/**
* Return the principal value of the arc cosine of a variable,
* in radians (cmath).
Expand Down Expand Up @@ -62,7 +52,27 @@ class acos_vari : public op_v_vari {
* @param x argument
* @return Arc cosine of variable, in radians.
*/
inline var acos(const var& x) { return var(new internal::acos_vari(x.vi_)); }
inline var acos(const var& x) {
return make_callback_var(std::acos(x.val()), [x](const auto& vi) mutable {
x.adj() -= vi.adj_ / std::sqrt(1.0 - (x.val() * x.val()));
});
}

/**
* Return the principal value of the arc cosine of a variable,
* in radians (cmath).
*
* @param x a `var_value` with inner Eigen type
* @return Arc cosine of variable, in radians.
*/
template <typename VarMat, require_var_matrix_t<VarMat>* = nullptr>
inline auto acos(const VarMat& x) {
return make_callback_var(
x.val().array().acos().matrix(), [x](const auto& vi) mutable {
x.adj().array()
-= vi.adj_.array() / (1.0 - (x.val().array().square())).sqrt();
});
}

/**
* Return the arc cosine of the complex argument.
Expand Down
36 changes: 23 additions & 13 deletions stan/math/rev/fun/acosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@
namespace stan {
namespace math {

namespace internal {
class acosh_vari : public op_v_vari {
public:
acosh_vari(double val, vari* avi) : op_v_vari(val, avi) {}
void chain() {
avi_->adj_ += adj_ / std::sqrt(avi_->val_ * avi_->val_ - 1.0);
}
};
} // namespace internal

/**
* The inverse hyperbolic cosine function for variables (C99).
*
Expand Down Expand Up @@ -67,11 +57,31 @@ class acosh_vari : public op_v_vari {
\frac{\partial \, \cosh^{-1}(x)}{\partial x} = \frac{1}{\sqrt{x^2-1}}
\f]
*
* @param a The variable.
* @param x The variable.
* @return Inverse hyperbolic cosine of the variable.
*/
inline var acosh(const var& x) {
return make_callback_var(acosh(x.val()), [x](const auto& vi) mutable {
x.adj() += vi.adj_ / std::sqrt(x.val() * x.val() - 1.0);
});
}
/**
* The inverse hyperbolic cosine function for variables (C99).
*
* For non-variable function, see ::acosh().
*
* @tparam Varmat a `var_value` with inner Eigen type
* @param x The variable
* @return Inverse hyperbolic cosine of the variable.
*/
inline var acosh(const var& a) {
return var(new internal::acosh_vari(acosh(a.val()), a.vi_));
template <typename VarMat, require_var_matrix_t<VarMat>* = nullptr>
inline auto acosh(const VarMat& x) {
return make_callback_var(
x.val().unaryExpr([](const auto x) { return acosh(x); }),
[x](const auto& vi) mutable {
x.adj().array()
+= vi.adj_.array() / (x.val().array().square() - 1.0).sqrt();
});
}

/**
Expand Down
Loading